# Unified Model Comparison

**Objective**: Compare all models (Temporal GCN, Static GCN, MLP + Graph Features, Baselines) across observation windows K.

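**Data source**: Pre-computed multi-seed results (`all_seeds_all_metrics.csv`) produced by each experiment notebook. All comparisons below use test-split metrics, reported as mean ± std across seeds.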

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Notebook-wide plotting defaults: white grid background and a large
# default canvas for every figure created without an explicit figsize.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (14, 8)})

## Load All Results

In [None]:
# Define result directories for each model (display name -> directory that
# holds that model's multi-seed 'all_seeds_all_metrics.csv').
results_dirs = {
    'Temporal GCN': Path('../../results/evolve_gcn_multi_seed'),
    'Static GCN': Path('../../results/static_gcn_multi_seed'),
    'MLP + Graph Features': Path('../../results/graph_features_baseline_multi_seed'),
    'Logistic Regression': Path('../../results/baselines_multi_seed/logistic_regression'),
    'Random Forest': Path('../../results/baselines_multi_seed/random_forest'),
    'XGBoost': Path('../../results/baselines_multi_seed/xgboost')
}

# Load all results. A missing CSV is reported and skipped, so a partially
# finished experiment sweep can still be compared.
all_models_results = {}

for model_name, result_dir in results_dirs.items():
    csv_path = result_dir / 'all_seeds_all_metrics.csv'
    if not csv_path.exists():
        print(f"✗ Missing: {csv_path}")
        continue
    df = pd.read_csv(csv_path)
    # Tag every row with the display name so rows stay attributable after concat.
    df['model'] = model_name
    all_models_results[model_name] = df
    print(f"✓ Loaded {model_name}: {len(df)} rows")

# pd.concat on an empty collection raises an opaque "No objects to
# concatenate"; fail instead with a message that says where we looked.
if not all_models_results:
    searched = ", ".join(str(d) for d in results_dirs.values())
    raise FileNotFoundError(
        f"No 'all_seeds_all_metrics.csv' found in any result directory: {searched}"
    )

# Combine all results
combined_df = pd.concat(all_models_results.values(), ignore_index=True)
print(f"\nTotal combined results: {len(combined_df)} rows")
print(f"Models: {combined_df['model'].unique()}")
print(f"K values: {sorted(combined_df['K'].unique())}")
print(f"Splits: {combined_df['split'].unique()}")

## Compute Summary Statistics

In [None]:
# Keep only held-out test-split rows for the comparison.
test_df = combined_df.loc[combined_df['split'] == 'test'].copy()

# One row per (model, K) with the mean and std across seeds of each metric.
# Named aggregation yields flat column names (e.g. 'f1_mean', 'f1_std')
# directly, in metric order, with no MultiIndex to flatten afterwards.
metric_names = ['f1', 'auc', 'precision', 'recall', 'accuracy']
named_aggs = {
    f'{metric}_{stat}': (metric, stat)
    for metric in metric_names
    for stat in ('mean', 'std')
}
summary_stats = test_df.groupby(['model', 'K']).agg(**named_aggs).reset_index()

print("Summary statistics computed:")
print(summary_stats.head(10))

## Visualization 1: F1 Score Comparison (All Models)

In [None]:
# Mean test F1 vs observation window K, one line per model. Error bars are
# the std across seeds (std is NaN for a (model, K) cell with a single seed).
fig, ax = plt.subplots(figsize=(12, 6))

# Keep the configured model order, but skip models with no test results
# (e.g. their CSV was missing). Plotting an empty series would still add
# a data-less entry to the legend.
plotted_models = set(summary_stats['model'])

for model_name in results_dirs.keys():
    if model_name not in plotted_models:
        continue
    model_data = summary_stats[summary_stats['model'] == model_name]
    ax.errorbar(
        model_data['K'],
        model_data['f1_mean'],
        yerr=model_data['f1_std'],
        marker='o',
        linewidth=2,
        capsize=5,
        label=model_name
    )

ax.set_xlabel('Observation Window K', fontsize=14)
ax.set_ylabel('F1 Score', fontsize=14)
ax.set_title('F1 Score vs Observation Window (All Models)', fontsize=16, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Visualization 2: AUC Comparison (All Models)

In [None]:
# Mean test AUC vs observation window K, one line per model. Error bars are
# the std across seeds (std is NaN for a (model, K) cell with a single seed).
fig, ax = plt.subplots(figsize=(12, 6))

# Keep the configured model order, but skip models with no test results
# (e.g. their CSV was missing). Plotting an empty series would still add
# a data-less entry to the legend.
plotted_models = set(summary_stats['model'])

for model_name in results_dirs.keys():
    if model_name not in plotted_models:
        continue
    model_data = summary_stats[summary_stats['model'] == model_name]
    ax.errorbar(
        model_data['K'],
        model_data['auc_mean'],
        yerr=model_data['auc_std'],
        marker='o',
        linewidth=2,
        capsize=5,
        label=model_name
    )

ax.set_xlabel('Observation Window K', fontsize=14)
ax.set_ylabel('AUC', fontsize=14)
ax.set_title('AUC vs Observation Window (All Models)', fontsize=16, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Visualization 3: Heatmap (F1 Scores)

In [None]:
# Reshape the per-(model, K) mean F1 into a grid: rows = models, columns = K.
f1_pivot = summary_stats.pivot(index='model', columns='K', values='f1_mean')

# Annotated heatmap; red-yellow-green so stronger F1 reads as greener.
heatmap_options = dict(annot=True, fmt='.3f', cmap='RdYlGn', cbar_kws={'label': 'F1 Score'})
fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(f1_pivot, ax=ax, **heatmap_options)

title_font = dict(fontsize=14, fontweight='bold')
ax.set_title('F1 Score Heatmap: Models vs Observation Window K', **title_font)
ax.set_xlabel('Observation Window K', fontsize=12)
ax.set_ylabel('Model', fontsize=12)
plt.tight_layout()
plt.show()

## Visualization 4: Best Model per K

In [None]:
# For each K, select the full summary row of the model with the highest mean
# test score. idxmax returns row labels of summary_stats, which .loc then
# looks up.
best_per_k_f1 = summary_stats.loc[summary_stats.groupby('K')['f1_mean'].idxmax()]
best_per_k_auc = summary_stats.loc[summary_stats.groupby('K')['auc_mean'].idxmax()]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One panel per metric: (axis, best rows, column prefix, bar colour, y-label, title).
panels = [
    (axes[0], best_per_k_f1, 'f1', 'steelblue', 'Best F1 Score', 'Best Model by F1 Score per K'),
    (axes[1], best_per_k_auc, 'auc', 'green', 'Best AUC', 'Best Model by AUC per K'),
]

for ax, best_rows, metric, bar_color, ylabel, title in panels:
    means = best_rows[f'{metric}_mean']
    stds = best_rows[f'{metric}_std']
    ax.bar(best_rows['K'], means, color=bar_color, alpha=0.7)
    ax.errorbar(best_rows['K'], means, yerr=stds, fmt='none', color='black', capsize=5)

    # Anchor each model label just above the top of its error bar, not at
    # mean + 0.02, so the text does not collide with the bar cap when the
    # seed std exceeds 0.02. NaN std (single seed) counts as zero spread.
    label_y = means + stds.fillna(0) + 0.02
    for k, y, model in zip(best_rows['K'], label_y, best_rows['model']):
        ax.text(k, y, model, ha='center', va='bottom', fontsize=9, rotation=45)

    # Leave headroom so rotated labels near a score of 1.0 are not clipped.
    ax.set_ylim(0, max(1.0, label_y.max() + 0.15))
    ax.set_xlabel('Observation Window K', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

## Visualization 5: Seed-to-Seed Variability (Model Stability, average std across K)

In [None]:
# Per model: the mean over all K of the seed-to-seed std of F1 and AUC.
# Lower values mean results move less when the random seed changes.
stability = (
    summary_stats
    .groupby('model')[['f1_std', 'auc_std']]
    .mean()
    .sort_values('f1_std')
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (column, bar colour, x-label, title) for each horizontal-bar panel.
panel_specs = [
    ('f1_std', 'coral', 'Average F1 Std Deviation', 'Model Stability (F1 Variance)'),
    ('auc_std', 'skyblue', 'Average AUC Std Deviation', 'Model Stability (AUC Variance)'),
]

for ax, (column, bar_color, xlabel, title) in zip(axes, panel_specs):
    stability[column].plot(kind='barh', ax=ax, color=bar_color)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel('Model', fontsize=12)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='x')

plt.tight_layout()
plt.show()

print("\nModel Stability Ranking (lower std = more stable):")
print(stability)

## Comparison Tables

In [None]:
# Build a human-readable table with one "mean ± std" string per metric.
comparison_table = summary_stats.copy()

# Summary-column prefix -> display column name (insertion order = column order).
display_names = {'f1': 'F1', 'auc': 'AUC', 'precision': 'Precision', 'recall': 'Recall'}
for prefix, display_name in display_names.items():
    # Bind prefix as a default so each lambda formats its own metric.
    comparison_table[display_name] = comparison_table.apply(
        lambda row, p=prefix: f"{row[f'{p}_mean']:.3f} ± {row[f'{p}_std']:.3f}", axis=1
    )

display_table = comparison_table[['model', 'K', 'F1', 'AUC', 'Precision', 'Recall']]

rule = "=" * 80

print("\n" + rule)
print("COMPREHENSIVE MODEL COMPARISON (Test Set, Mean ± Std)")
print(rule)
print(display_table.to_string(index=False))

# Best performers (best_per_k_* are the per-K winners selected earlier).
print("\n" + rule)
print("BEST PERFORMERS")
print(rule)
print("\nBest F1 Score per K:")
print(best_per_k_f1[['K', 'model', 'f1_mean', 'f1_std']].to_string(index=False))
print("\nBest AUC per K:")
print(best_per_k_auc[['K', 'model', 'auc_mean', 'auc_std']].to_string(index=False))

## Save Combined Results

In [None]:
# Write the raw combined rows, the numeric summary and the formatted table
# as CSVs under the shared results directory (created if absent).
output_dir = Path('../../results')
output_dir.mkdir(parents=True, exist_ok=True)

exports = [
    (combined_df, 'all_models_comparison.csv'),
    (summary_stats, 'comparison_summary_statistics.csv'),
    (display_table, 'comparison_formatted.csv'),
]
for frame, filename in exports:
    destination = output_dir / filename
    frame.to_csv(destination, index=False)
    print(f"Saved: {destination}")

print("\n✅ All comparison results saved!")