In [21]:
from burned_embedder.utils import setup_device

In [23]:
# scripts/analyze_experiments_extended.py

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import rootutils
from scipy import stats

root_path = rootutils.find_root()



def results_to_dataframe(results):
    """Convert results list to pandas DataFrame with extended features"""
    if not results:
        raise ValueError("No results to analyze")
    
    data = []
    for result in results:
        row = {
            'experiment': result['experiment_name'],
            'test_f1': result['test_f1'],
            'test_precision': result['test_precision'],
            'test_recall': result['test_recall'],
            'test_roc_auc': result['test_roc_auc'],
            'val_f1': result['best_val_f1'],
            'overfit_gap': result['overfit_gap'],
            'best_epoch': result['best_epoch'],
            'input_type': result['config']['input_type'],
            'hidden_dims': str(result['config']['hidden_dims']),
            'dropout': result['config']['dropout'],
            'lr': result['config']['lr'],
            'weight_decay': result['config']['weight_decay'],
            'batch_size': result['config']['batch_size'],
            'patience': result['config'].get('patience', 15),
            'seed': result['config']['seed'],
        }
        
        # Compute derived metrics
        cm = np.array(result['confusion_matrix'])
        row['tn'] = cm[0, 0]
        row['fp'] = cm[0, 1]
        row['fn'] = cm[1, 0]
        row['tp'] = cm[1, 1]
        row['false_positive_rate'] = row['fp'] / (row['fp'] + row['tn']) if (row['fp'] + row['tn']) > 0 else 0
        row['false_negative_rate'] = row['fn'] / (row['fn'] + row['tp']) if (row['fn'] + row['tp']) > 0 else 0
        
        # Model complexity
        hidden_list = eval(result['config']['hidden_dims'])
        row['num_layers'] = len(hidden_list)
        row['total_hidden_units'] = sum(hidden_list)
        row['max_layer_size'] = max(hidden_list)
        
        data.append(row)
    
    df = pd.DataFrame(data)
    return df


def print_summary_table(df, top_n=20):
    """Print formatted summary table"""
    print("\n" + "="*120)
    print("EXPERIMENT RESULTS SUMMARY")
    print("="*120)
    
    df_sorted = df.sort_values('test_f1', ascending=False)
    
    print(f"\n{'Rank':<6}{'Experiment':<35}{'F1':<8}{'Prec':<8}{'Recall':<8}{'AUC':<8}{'Overfit':<10}{'Epoch':<7}")
    print("-"*120)
    
    for idx, (_, row) in enumerate(df_sorted.head(top_n).iterrows(), 1):
        print(f"{idx:<6}{row['experiment']:<35}{row['test_f1']:<8.4f}{row['test_precision']:<8.4f}"
              f"{row['test_recall']:<8.4f}{row['test_roc_auc']:<8.4f}{row['overfit_gap']:<10.4f}{row['best_epoch']:<7.0f}")
    
    # Print bottom performers too
    print("\n" + "-"*120)
    print("BOTTOM 5 PERFORMERS:")
    print("-"*120)
    for idx, (_, row) in enumerate(df_sorted.tail(5).iterrows(), len(df_sorted)-4):
        print(f"{idx:<6}{row['experiment']:<35}{row['test_f1']:<8.4f}{row['test_precision']:<8.4f}"
              f"{row['test_recall']:<8.4f}{row['test_roc_auc']:<8.4f}{row['overfit_gap']:<10.4f}{row['best_epoch']:<7.0f}")
    
    # Best by different metrics
    print(f"\n" + "="*120)
    print("BEST MODELS BY METRIC")
    print("="*120)
    
    metrics = {
        'F1 Score': ('test_f1', 'max'),
        'Precision': ('test_precision', 'max'),
        'Recall': ('test_recall', 'max'),
        'AUC': ('test_roc_auc', 'max'),
        'Least Overfit': ('overfit_gap', 'min'),
        'Fastest Convergence': ('best_epoch', 'min'),
        'Best Val-Test Agreement': ('overfit_gap', 'abs_min'),
    }
    
    for metric_name, (col, agg) in metrics.items():
        if agg == 'abs_min':
            idx = df[col].abs().idxmin()
        else:
            idx = df[col].idxmax() if agg == 'max' else df[col].idxmin()
        row = df.loc[idx]
        print(f"\n{metric_name}:")
        print(f"  Model: {row['experiment']}")
        print(f"  Value: {row[col]:.4f}")
        print(f"  Config: input={row['input_type']}, arch={row['hidden_dims']}, "
              f"dropout={row['dropout']}, lr={row['lr']}, wd={row['weight_decay']}")


def analyze_by_category(df):
    """Analyze results by different configuration categories"""
    print(f"\n" + "="*120)
    print("DETAILED ANALYSIS BY CONFIGURATION")
    print("="*120)
    
    # Helper function for pretty printing
    def print_analysis(group_col, title):
        print(f"\n{title}:")
        print("-"*120)
        analysis = df.groupby(group_col).agg({
            'test_f1': ['count', 'mean', 'std', 'min', 'max'],
            'overfit_gap': ['mean', 'std'],
            'test_precision': 'mean',
            'test_recall': 'mean',
            'test_roc_auc': 'mean',
            'best_epoch': 'mean'
        }).round(4)
        print(analysis)
        
        # Statistical test if enough groups
        if len(df[group_col].unique()) >= 2:
            groups = [group['test_f1'].values for name, group in df.groupby(group_col)]
            if len(groups) >= 2 and all(len(g) >= 2 for g in groups[:2]):
                statistic, pvalue = stats.f_oneway(*groups)
                print(f"\nANOVA F-statistic: {statistic:.4f}, p-value: {pvalue:.4e}")
                if pvalue < 0.05:
                    print("  → Statistically significant differences detected!")
                else:
                    print("  → No significant differences (might just be noise)")
    
    # Run analyses
    print_analysis('input_type', '1. BY INPUT TYPE')
    print_analysis('dropout', '2. BY DROPOUT RATE')
    print_analysis('lr', '3. BY LEARNING RATE')
    print_analysis('weight_decay', '4. BY WEIGHT DECAY')
    print_analysis('batch_size', '5. BY BATCH SIZE')
    print_analysis('num_layers', '6. BY NETWORK DEPTH')
    print_analysis('patience', '7. BY EARLY STOPPING PATIENCE')
    
    # Architecture size bins
    df['arch_size'] = pd.cut(df['total_hidden_units'], 
                              bins=[0, 300, 600, 900, 1500, 5000],
                              labels=['tiny', 'small', 'medium', 'large', 'huge'])
    print_analysis('arch_size', '8. BY ARCHITECTURE SIZE')


def deep_dive_concat(df):
    """Deep analysis specifically for concat experiments"""
    print(f"\n" + "="*120)
    print("DEEP DIVE: CONCAT INPUT EXPERIMENTS")
    print("="*120)
    
    concat_df = df[df['input_type'] == 'concat'].copy()
    
    if len(concat_df) == 0:
        print("No concat experiments found!")
        return
    
    print(f"\nTotal concat experiments: {len(concat_df)}")
    print(f"Best F1: {concat_df['test_f1'].max():.4f}")
    print(f"Worst F1: {concat_df['test_f1'].min():.4f}")
    print(f"Mean F1: {concat_df['test_f1'].mean():.4f} ± {concat_df['test_f1'].std():.4f}")
    print(f"Median overfit gap: {concat_df['overfit_gap'].median():.4f}")
    
    # Find optimal hyperparameters
    print("\n" + "-"*120)
    print("OPTIMAL HYPERPARAMETERS FOR CONCAT:")
    print("-"*120)
    
    for param in ['dropout', 'lr', 'weight_decay', 'batch_size', 'hidden_dims']:
        best_idx = concat_df['test_f1'].idxmax()
        top_5 = concat_df.nlargest(5, 'test_f1')
        
        print(f"\n{param.upper()}:")
        print(f"  Best single model: {concat_df.loc[best_idx, param]}")
        print(f"  Top 5 average: {top_5[param].mode().values if param == 'hidden_dims' else top_5[param].mean()}")
        
        # Show distribution
        if param != 'hidden_dims':
            value_counts = top_5[param].value_counts()
            print(f"  Distribution in top 5:")
            for val, count in value_counts.items():
                print(f"    {val}: {count}/5")
    
    # Correlation analysis
    print("\n" + "-"*120)
    print("HYPERPARAMETER CORRELATIONS WITH F1:")
    print("-"*120)
    
    numeric_params = ['dropout', 'lr', 'weight_decay', 'batch_size', 'total_hidden_units', 'num_layers']
    correlations = {}
    
    for param in numeric_params:
        if concat_df[param].nunique() > 1:
            corr = concat_df[param].corr(concat_df['test_f1'])
            correlations[param] = corr
            print(f"{param:>20}: {corr:>7.4f}")
    
    # Find sweet spots
    print("\n" + "-"*120)
    print("SWEET SPOT RANGES (values where F1 > 0.80):")
    print("-"*120)
    
    high_performers = concat_df[concat_df['test_f1'] > 0.80]
    if len(high_performers) > 0:
        for param in ['dropout', 'lr', 'weight_decay']:
            if high_performers[param].nunique() > 1:
                print(f"{param}: [{high_performers[param].min():.6f}, {high_performers[param].max():.6f}]")
    else:
        print("No models achieved F1 > 0.80")


def analyze_regularization_tradeoffs(df):
    """Analyze the relationship between regularization and performance"""
    print(f"\n" + "="*120)
    print("REGULARIZATION ANALYSIS")
    print("="*120)
    
    # Define regularization strength
    df['reg_strength'] = df['dropout'] + np.log10(df['weight_decay'] + 1e-10) / 5
    
    # Bin into categories
    df['reg_category'] = pd.cut(df['reg_strength'], 
                                 bins=[-np.inf, -1, 0, 1, np.inf],
                                 labels=['minimal', 'light', 'moderate', 'heavy'])
    
    print("\nPerformance by regularization strength:")
    reg_analysis = df.groupby('reg_category').agg({
        'test_f1': ['count', 'mean', 'std', 'max'],
        'overfit_gap': ['mean', 'std'],
        'test_precision': 'mean',
        'test_recall': 'mean',
    }).round(4)
    print(reg_analysis)
    
    # Overfitting analysis
    print("\n" + "-"*120)
    print("OVERFITTING STATISTICS:")
    print("-"*120)
    print(f"Models with overfit_gap > 0.15: {(df['overfit_gap'] > 0.15).sum()}/{len(df)}")
    print(f"Models with overfit_gap < 0.05: {(df['overfit_gap'] < 0.05).sum()}/{len(df)}")
    print(f"Models with negative gap (test > train): {(df['overfit_gap'] < 0).sum()}/{len(df)}")
    
    # Best regularization strategy
    well_generalized = df[df['overfit_gap'].abs() < 0.10]
    if len(well_generalized) > 0:
        best_generalized = well_generalized.loc[well_generalized['test_f1'].idxmax()]
        print(f"\nBest well-generalized model (|gap| < 0.10):")
        print(f"  {best_generalized['experiment']}")
        print(f"  F1: {best_generalized['test_f1']:.4f}, Gap: {best_generalized['overfit_gap']:.4f}")
        print(f"  Config: dropout={best_generalized['dropout']}, wd={best_generalized['weight_decay']}")


def analyze_precision_recall_tradeoff(df):
    """Analyze precision-recall tradeoff and identify Pareto frontier"""
    print(f"\n" + "="*120)
    print("PRECISION-RECALL TRADEOFF ANALYSIS")
    print("="*120)
    
    # Sort by F1 and show precision-recall balance
    print("\nTop 10 models by F1 with precision-recall breakdown:")
    print("-"*120)
    top_10 = df.nlargest(10, 'test_f1')
    
    for idx, (_, row) in enumerate(top_10.iterrows(), 1):
        prec_rec_diff = abs(row['test_precision'] - row['test_recall'])
        balance = "balanced" if prec_rec_diff < 0.05 else ("precision-biased" if row['test_precision'] > row['test_recall'] else "recall-biased")
        print(f"{idx:>2}. {row['experiment']:<35} F1:{row['test_f1']:.4f}  "
              f"P:{row['test_precision']:.4f}  R:{row['test_recall']:.4f}  [{balance}]")
    
    # Identify Pareto-optimal models
    print("\n" + "-"*120)
    print("PARETO FRONTIER (models where you can't improve one metric without hurting the other):")
    print("-"*120)
    
    pareto_models = []
    for idx, row in df.iterrows():
        is_pareto = True
        for _, other in df.iterrows():
            if (other['test_precision'] >= row['test_precision'] and 
                other['test_recall'] >= row['test_recall'] and
                (other['test_precision'] > row['test_precision'] or other['test_recall'] > row['test_recall'])):
                is_pareto = False
                break
        if is_pareto:
            pareto_models.append(row)
    
    pareto_df = pd.DataFrame(pareto_models).sort_values('test_f1', ascending=False)
    print(f"\nFound {len(pareto_df)} Pareto-optimal models:")
    for _, row in pareto_df.head(10).iterrows():
        print(f"  {row['experiment']:<35} P:{row['test_precision']:.4f}  R:{row['test_recall']:.4f}  F1:{row['test_f1']:.4f}")


def analyze_seed_variance(df):
    """Analyze variance across different random seeds"""
    print(f"\n" + "="*120)
    print("RANDOM SEED VARIANCE ANALYSIS")
    print("="*120)
    
    # Find experiments with multiple seeds
    df['exp_base'] = df['experiment'].str.replace(r'_seed\d+', '', regex=True)
    seed_groups = df.groupby('exp_base').filter(lambda x: len(x) > 1)
    
    if len(seed_groups) == 0:
        print("\nNo replicated experiments found (need same config with different seeds)")
        return
    
    print(f"\nFound {seed_groups['exp_base'].nunique()} configurations with multiple seeds:")
    print("-"*120)
    
    for exp_base, group in seed_groups.groupby('exp_base'):
        print(f"\n{exp_base}:")
        print(f"  Runs: {len(group)}")
        print(f"  F1: {group['test_f1'].mean():.4f} ± {group['test_f1'].std():.4f} "
              f"(range: [{group['test_f1'].min():.4f}, {group['test_f1'].max():.4f}])")
        print(f"  Precision: {group['test_precision'].mean():.4f} ± {group['test_precision'].std():.4f}")
        print(f"  Recall: {group['test_recall'].mean():.4f} ± {group['test_recall'].std():.4f}")
        print(f"  Overfit gap: {group['overfit_gap'].mean():.4f} ± {group['overfit_gap'].std():.4f}")
        
        if group['test_f1'].std() > 0.02:
            print(f"  ⚠️  High variance detected! Results may be unstable.")


def plot_extended_analysis(df, save_dir):
    """Create comprehensive visualizations"""
    fig = plt.figure(figsize=(24, 18))
    
    # Sort by F1
    df_sorted = df.sort_values('test_f1', ascending=True)
    
    # 1. Top 20 models comparison
    ax1 = plt.subplot(4, 4, 1)
    top_20 = df_sorted.tail(20)
    y_pos = np.arange(len(top_20))
    ax1.barh(y_pos, top_20['test_f1'], alpha=0.6, label='F1')
    ax1.barh(y_pos, top_20['test_precision'], alpha=0.6, label='Precision')
    ax1.barh(y_pos, top_20['test_recall'], alpha=0.6, label='Recall')
    ax1.set_yticks(y_pos)
    ax1.set_yticklabels([e.replace('exp_', '').replace('concat_', 'c_')[:20] for e in top_20['experiment']], fontsize=7)
    ax1.set_xlabel('Score')
    ax1.set_title('Top 20 Models')
    ax1.legend(fontsize=7)
    ax1.grid(axis='x', alpha=0.3)
    
    # 2. Overfitting scatter
    ax2 = plt.subplot(4, 4, 2)
    scatter = ax2.scatter(df['overfit_gap'], df['test_f1'], 
                         c=df['dropout'], s=100, alpha=0.6, cmap='coolwarm')
    ax2.axvline(x=0, color='black', linestyle='--', alpha=0.3)
    ax2.set_xlabel('Overfitting Gap')
    ax2.set_ylabel('Test F1')
    ax2.set_title('Overfitting vs Performance (color=dropout)')
    ax2.grid(alpha=0.3)
    plt.colorbar(scatter, ax=ax2, label='Dropout')
    
    # 3. Precision-Recall frontier
    ax3 = plt.subplot(4, 4, 3)
    ax3.scatter(df['test_recall'], df['test_precision'], 
               c=df['test_f1'], s=100, alpha=0.6, cmap='viridis')
    # Plot top 5
    top_5 = df.nlargest(5, 'test_f1')
    ax3.scatter(top_5['test_recall'], top_5['test_precision'], 
               c='red', s=200, marker='*', edgecolors='black', linewidths=2, label='Top 5')
    ax3.set_xlabel('Recall')
    ax3.set_ylabel('Precision')
    ax3.set_title('Precision-Recall Space (color=F1)')
    ax3.legend()
    ax3.grid(alpha=0.3)
    
    # 4. Input type comparison (boxplot)
    ax4 = plt.subplot(4, 4, 4)
    input_types = [df[df['input_type']==it]['test_f1'].values for it in df['input_type'].unique()]
    ax4.boxplot(input_types, labels=df['input_type'].unique())
    ax4.set_ylabel('Test F1')
    ax4.set_title('F1 Distribution by Input Type')
    ax4.grid(axis='y', alpha=0.3)
    
    # 5. Dropout effect (concat only)
    ax5 = plt.subplot(4, 4, 5)
    concat_df = df[df['input_type'] == 'concat']
    if len(concat_df) > 0:
        dropout_stats = concat_df.groupby('dropout').agg({'test_f1': ['mean', 'std', 'count']})
        dropout_stats = dropout_stats[dropout_stats[('test_f1', 'count')] >= 2]  # Only show if >= 2 samples
        if len(dropout_stats) > 0:
            ax5.errorbar(dropout_stats.index, 
                        dropout_stats[('test_f1', 'mean')],
                        yerr=dropout_stats[('test_f1', 'std')],
                        marker='o', capsize=5, linewidth=2)
            ax5.set_xlabel('Dropout Rate')
            ax5.set_ylabel('Mean Test F1 (concat)')
            ax5.set_title('Dropout Effect on Concat Models')
            ax5.grid(alpha=0.3)
    
    # 6. Weight decay effect (concat only)
    ax6 = plt.subplot(4, 4, 6)
    if len(concat_df) > 0:
        wd_data = concat_df[concat_df['weight_decay'] > 0]
        if len(wd_data) > 0:
            ax6.scatter(wd_data['weight_decay'], wd_data['test_f1'], s=100, alpha=0.6)
            ax6.set_xscale('log')
            ax6.set_xlabel('Weight Decay (log scale)')
            ax6.set_ylabel('Test F1')
            ax6.set_title('Weight Decay Effect (concat)')
            ax6.grid(alpha=0.3)
    
    # 7. Learning rate effect
    ax7 = plt.subplot(4, 4, 7)
    lr_stats = df.groupby('lr').agg({'test_f1': ['mean', 'std', 'count']})
    lr_stats = lr_stats[lr_stats[('test_f1', 'count')] >= 2]
    if len(lr_stats) > 0:
        ax7.errorbar(lr_stats.index, 
                    lr_stats[('test_f1', 'mean')],
                    yerr=lr_stats[('test_f1', 'std')],
                    marker='o', capsize=5, linewidth=2)
        ax7.set_xscale('log')
        ax7.set_xlabel('Learning Rate (log scale)')
        ax7.set_ylabel('Mean Test F1')
        ax7.set_title('Learning Rate Effect')
        ax7.grid(alpha=0.3)
    
    # 8. Architecture size vs performance
    ax8 = plt.subplot(4, 4, 8)
    ax8.scatter(df['total_hidden_units'], df['test_f1'], 
               c=df['overfit_gap'], s=100, alpha=0.6, cmap='RdYlGn_r')
    ax8.set_xlabel('Total Hidden Units')
    ax8.set_ylabel('Test F1')
    ax8.set_title('Model Size vs Performance (color=overfit)')
    ax8.grid(alpha=0.3)
    
    # 9. Batch size effect
    ax9 = plt.subplot(4, 4, 9)
    batch_stats = df.groupby('batch_size').agg({'test_f1': ['mean', 'std', 'count']})
    batch_stats = batch_stats[batch_stats[('test_f1', 'count')] >= 2]
    if len(batch_stats) > 0:
        ax9.errorbar(batch_stats.index, 
                    batch_stats[('test_f1', 'mean')],
                    yerr=batch_stats[('test_f1', 'std')],
                    marker='o', capsize=5, linewidth=2)
        ax9.set_xlabel('Batch Size')
        ax9.set_ylabel('Mean Test F1')
        ax9.set_title('Batch Size Effect')
        ax9.grid(alpha=0.3)
    
    # 10. Convergence speed
    ax10 = plt.subplot(4, 4, 10)
    ax10.scatter(df['best_epoch'], df['test_f1'], s=100, alpha=0.6)
    ax10.set_xlabel('Epoch of Best Validation')
    ax10.set_ylabel('Test F1')
    ax10.set_title('Convergence Speed vs Performance')
    ax10.grid(alpha=0.3)
    
    # 11. False positive vs false negative rates
    ax11 = plt.subplot(4, 4, 11)
    ax11.scatter(df['false_positive_rate'], df['false_negative_rate'],
                c=df['test_f1'], s=100, alpha=0.6, cmap='viridis')
    ax11.set_xlabel('False Positive Rate')
    ax11.set_ylabel('False Negative Rate')
    ax11.set_title('Error Types (color=F1)')
    ax11.grid(alpha=0.3)
    
    # 12. Regularization strength heatmap (dropout vs weight decay)
    ax12 = plt.subplot(4, 4, 12)
    concat_pivot = concat_df.pivot_table(values='test_f1', 
                                         index='dropout', 
                                         columns='weight_decay',
                                         aggfunc='mean')
    if not concat_pivot.empty:
        sns.heatmap(concat_pivot, annot=True, fmt='.3f', cmap='RdYlGn', ax=ax12, cbar_kws={'label': 'F1'})
        ax12.set_title('Concat: Dropout × Weight Decay')
    
    # 13. Architecture depth vs width
    ax13 = plt.subplot(4, 4, 13)
    ax13.scatter(df['num_layers'], df['max_layer_size'],
                c=df['test_f1'], s=100, alpha=0.6, cmap='viridis')
    ax13.set_xlabel('Number of Layers')
    ax13.set_ylabel('Max Layer Size')
    ax13.set_title('Architecture Shape (color=F1)')
    ax13.grid(alpha=0.3)
    
    # 14. Top 10 confusion matrices (simplified)
    ax14 = plt.subplot(4, 4, 14)
    top_10_models = df.nlargest(10, 'test_f1')
    fp_rates = top_10_models['false_positive_rate'].values
    fn_rates = top_10_models['false_negative_rate'].values
    ax14.scatter(fp_rates, fn_rates, s=200, alpha=0.6)
    for i, exp in enumerate(top_10_models['experiment'].values):
        ax14.annotate(f"{i+1}", (fp_rates[i], fn_rates[i]), 
                     ha='center', va='center', fontsize=8, fontweight='bold')
    ax14.set_xlabel('False Positive Rate')
    ax14.set_ylabel('False Negative Rate')
    ax14.set_title('Top 10 Models: Error Rates')
    ax14.grid(alpha=0.3)
    
    # 15. Seed variance (if available)
    ax15 = plt.subplot(4, 4, 15)
    df['exp_base'] = df['experiment'].str.replace(r'_seed\d+', '', regex=True)
    seed_groups = df.groupby('exp_base').filter(lambda x: len(x) > 1)
    if len(seed_groups) > 0:
        variance_data = seed_groups.groupby('exp_base')['test_f1'].agg(['mean', 'std'])
        variance_data = variance_data.sort_values('mean', ascending=False).head(10)
        y_pos = np.arange(len(variance_data))
        ax15.barh(y_pos, variance_data['mean'], xerr=variance_data['std'], capsize=5)
        ax15.set_yticks(y_pos)
        ax15.set_yticklabels([e[:25] for e in variance_data.index], fontsize=7)
        ax15.set_xlabel('F1 Score')
        ax15.set_title('Seed Variance (mean ± std)')
        ax15.grid(axis='x', alpha=0.3)
    else:
        ax15.text(0.5, 0.5, 'No replicated experiments', 
                 ha='center', va='center', transform=ax15.transAxes)
        ax15.set_title('Seed Variance')
    
    # 16. Performance improvement over baseline
    ax16 = plt.subplot(4, 4, 16)
    baseline_f1 = df[df['experiment'].str.contains('baseline')]['test_f1'].mean()
    df['improvement'] = df['test_f1'] - baseline_f1
    top_improvements = df.nlargest(15, 'improvement')
    y_pos = np.arange(len(top_improvements))
    colors = ['green' if x > 0 else 'red' for x in top_improvements['improvement']]
    ax16.barh(y_pos, top_improvements['improvement'], color=colors, alpha=0.6)
    ax16.set_yticks(y_pos)
    ax16.set_yticklabels([e.replace('exp_', '')[:20] for e in top_improvements['experiment']], fontsize=7)
    ax16.axvline(x=0, color='black', linestyle='--', linewidth=1)
    ax16.set_xlabel('F1 Improvement over Baseline')
    ax16.set_title('Top 15 Improvements')
    ax16.grid(axis='x', alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_dir / 'extended_analysis.png', dpi=300, bbox_inches='tight')
    print(f"\nExtended visualization saved to {save_dir / 'extended_analysis.png'}")
    plt.close()


def plot_concat_deep_dive(df, save_dir):
    """Detailed visualization specifically for concat experiments"""
    concat_df = df[df['input_type'] == 'concat'].copy()
    
    if len(concat_df) < 5:
        print("\nNot enough concat experiments for deep dive (need at least 5)")
        return
    
    fig = plt.figure(figsize=(20, 12))
    
    # 1. F1 distribution
    ax1 = plt.subplot(3, 3, 1)
    ax1.hist(concat_df['test_f1'], bins=20, alpha=0.7, edgecolor='black')
    ax1.axvline(concat_df['test_f1'].mean(), color='red', linestyle='--', linewidth=2, label='Mean')
    ax1.axvline(concat_df['test_f1'].median(), color='green', linestyle='--', linewidth=2, label='Median')
    ax1.set_xlabel('Test F1')
    ax1.set_ylabel('Count')
    ax1.set_title(f'F1 Distribution (n={len(concat_df)})')
    ax1.legend()
    ax1.grid(alpha=0.3)
    
    # 2. 3D scatter: dropout × weight_decay × F1
    ax2 = plt.subplot(3, 3, 2, projection='3d')
    scatter = ax2.scatter(concat_df['dropout'], 
                         np.log10(concat_df['weight_decay'] + 1e-10),
                         concat_df['test_f1'],
                         c=concat_df['test_f1'], s=100, alpha=0.6, cmap='viridis')
    ax2.set_xlabel('Dropout')
    ax2.set_ylabel('log10(Weight Decay)')
    ax2.set_zlabel('Test F1')
    ax2.set_title('Regularization Space')
    plt.colorbar(scatter, ax=ax2, shrink=0.5)
    
    # 3. Learning rate vs F1
    ax3 = plt.subplot(3, 3, 3)
    for lr in sorted(concat_df['lr'].unique()):
        lr_data = concat_df[concat_df['lr'] == lr]
        ax3.scatter([lr]*len(lr_data), lr_data['test_f1'], s=100, alpha=0.6, label=f'lr={lr}')
    ax3.set_xscale('log')
    ax3.set_xlabel('Learning Rate')
    ax3.set_ylabel('Test F1')
    ax3.set_title('Learning Rate Landscape')
    ax3.legend(fontsize=8)
    ax3.grid(alpha=0.3)
    
    # 4. Architecture size exploration
    ax4 = plt.subplot(3, 3, 4)
    ax4.scatter(concat_df['total_hidden_units'], concat_df['test_f1'], 
               s=concat_df['num_layers']*50, alpha=0.6, c=concat_df['num_layers'], cmap='plasma')
    ax4.set_xlabel('Total Hidden Units')
    ax4.set_ylabel('Test F1')
    ax4.set_title('Architecture Size (size=depth)')
    ax4.grid(alpha=0.3)
    
    # 5. Overfitting by regularization level
    ax5 = plt.subplot(3, 3, 5)
    concat_df['total_reg'] = concat_df['dropout'] + np.log10(concat_df['weight_decay'] + 1e-9)
    ax5.scatter(concat_df['total_reg'], concat_df['overfit_gap'], 
               c=concat_df['test_f1'], s=100, alpha=0.6, cmap='viridis')
    ax5.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    ax5.set_xlabel('Total Regularization (dropout + log10(wd))')
    ax5.set_ylabel('Overfitting Gap')
    ax5.set_title('Regularization vs Overfitting')
    ax5.grid(alpha=0.3)
    
    # 6. Batch size effect
    ax6 = plt.subplot(3, 3, 6)
    if concat_df['batch_size'].nunique() > 1:
        batch_grouped = concat_df.groupby('batch_size').agg({
            'test_f1': ['mean', 'std'],
            'overfit_gap': 'mean'
        })
        ax6_twin = ax6.twinx()
        ax6.errorbar(batch_grouped.index, batch_grouped[('test_f1', 'mean')],
                    yerr=batch_grouped[('test_f1', 'std')], 
                    marker='o', color='blue', capsize=5, label='F1')
        ax6_twin.plot(batch_grouped.index, batch_grouped[('overfit_gap', 'mean')],
                     marker='s', color='red', label='Overfit Gap')
        ax6.set_xlabel('Batch Size')
        ax6.set_ylabel('Test F1', color='blue')
        ax6_twin.set_ylabel('Overfit Gap', color='red')
        ax6.set_title('Batch Size Impact')
        ax6.tick_params(axis='y', labelcolor='blue')
        ax6_twin.tick_params(axis='y', labelcolor='red')
        ax6.grid(alpha=0.3)
    
    # 7. Top 10 architectures
    ax7 = plt.subplot(3, 3, 7)
    top_10_concat = concat_df.nlargest(10, 'test_f1')
    y_pos = np.arange(len(top_10_concat))
    ax7.barh(y_pos, top_10_concat['test_f1'], alpha=0.7)
    ax7.set_yticks(y_pos)
    ax7.set_yticklabels([f"{row['hidden_dims']}\nd={row['dropout']}" 
                         for _, row in top_10_concat.iterrows()], fontsize=7)
    ax7.set_xlabel('Test F1')
    ax7.set_title('Top 10 Concat Architectures')
    ax7.grid(axis='x', alpha=0.3)
    
    # 8. Patience effect
    ax8 = plt.subplot(3, 3, 8)
    if concat_df['patience'].nunique() > 1:
        patience_grouped = concat_df.groupby('patience').agg({
            'test_f1': 'mean',
            'best_epoch': 'mean'
        })
        ax8_twin = ax8.twinx()
        ax8.plot(patience_grouped.index, patience_grouped['test_f1'],
                marker='o', color='blue', linewidth=2, label='F1')
        ax8_twin.plot(patience_grouped.index, patience_grouped['best_epoch'],
                     marker='s', color='orange', linewidth=2, label='Best Epoch')
        ax8.set_xlabel('Early Stopping Patience')
        ax8.set_ylabel('Mean Test F1', color='blue')
        ax8_twin.set_ylabel('Mean Best Epoch', color='orange')
        ax8.set_title('Patience Impact')
        ax8.tick_params(axis='y', labelcolor='blue')
        ax8_twin.tick_params(axis='y', labelcolor='orange')
        ax8.grid(alpha=0.3)
    
    # 9. Performance correlation matrix
    ax9 = plt.subplot(3, 3, 9)
    corr_cols = ['test_f1', 'test_precision', 'test_recall', 'overfit_gap', 
                 'dropout', 'total_hidden_units', 'best_epoch']
    corr_matrix = concat_df[corr_cols].corr()
    sns.heatmap(corr_matrix, annot=True, fmt='.2f', cmap='coolwarm', 
                center=0, ax=ax9, cbar_kws={'label': 'Correlation'})
    ax9.set_title('Feature Correlation Matrix')
    
    plt.tight_layout()
    plt.savefig(save_dir / 'concat_deep_dive.png', dpi=300, bbox_inches='tight')
    print(f"Concat deep dive saved to {save_dir / 'concat_deep_dive.png'}")
    plt.close()


def plot_confusion_matrices(results, df, save_dir):
    """Plot confusion matrices for different categories of models"""
    
    # Get top performers, most generalizable, high precision, high recall
    categories = {
        'Top 3 F1': df.nlargest(3, 'test_f1')['experiment'].tolist(),
        'Most Generalizable': df.nsmallest(3, lambda x: abs(x['overfit_gap']))['experiment'].tolist(),
        'High Precision': df.nlargest(3, 'test_precision')['experiment'].tolist(),
        'High Recall': df.nlargest(3, 'test_recall')['experiment'].tolist(),
    }
    
    fig, axes = plt.subplots(4, 3, figsize=(15, 18))
    
    for row_idx, (category, exp_names) in enumerate(categories.items()):
        for col_idx, exp_name in enumerate(exp_names):
            ax = axes[row_idx, col_idx]
            
            # Find the result
            result = next((r for r in results if r['experiment_name'] == exp_name), None)
            if result is None:
                continue
            
            cm = np.array(result['confusion_matrix'])
            
            # Plot heatmap
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                       xticklabels=['Neg', 'Pos'], yticklabels=['Neg', 'Pos'],
                       cbar_kws={'label': 'Count'})
            
            # Add metrics to title
            f1 = result['test_f1']
            prec = result['test_precision']
            rec = result['test_recall']
            
            title = f"{exp_name.replace('exp_', '').replace('concat_', 'c_')[:25]}\n"
            title += f"F1:{f1:.3f} P:{prec:.3f} R:{rec:.3f}"
            ax.set_title(title, fontsize=9)
            ax.set_ylabel('True')
            ax.set_xlabel('Predicted')
    
    # Add category labels
    for row_idx, category in enumerate(categories.keys()):
        fig.text(0.02, 0.88 - row_idx*0.23, category, 
                rotation=90, fontsize=14, fontweight='bold',
                va='center', ha='center')
    
    plt.tight_layout(rect=[0.03, 0, 1, 1])
    plt.savefig(save_dir / 'confusion_matrices_by_category.png', dpi=300, bbox_inches='tight')
    print(f"Confusion matrices saved to {save_dir / 'confusion_matrices_by_category.png'}")
    plt.close()


def generate_recommendations(df):
    """Generate detailed recommendations based on all experiments"""
    print(f"\n" + "="*120)
    print("DETAILED RECOMMENDATIONS & INSIGHTS")
    print("="*120)
    
    best_overall = df.loc[df['test_f1'].idxmax()]
    least_overfit = df.loc[df['overfit_gap'].abs().idxmin()]
    best_recall = df.loc[df['test_recall'].idxmax()]
    best_precision = df.loc[df['test_precision'].idxmax()]
    best_auc = df.loc[df['test_roc_auc'].idxmax()]
    
    print(f"\n{'='*120}")
    print("1. PRODUCTION DEPLOYMENT RECOMMENDATIONS")
    print(f"{'='*120}")
    
    print(f"\n🏆 Best Overall Model (maximize F1):")
    print(f"   {best_overall['experiment']}")
    print(f"   • F1: {best_overall['test_f1']:.4f} | Precision: {best_overall['test_precision']:.4f} | Recall: {best_overall['test_recall']:.4f}")
    print(f"   • Config: {best_overall['input_type']}, {best_overall['hidden_dims']}")
    print(f"   • Regularization: dropout={best_overall['dropout']}, weight_decay={best_overall['weight_decay']}")
    print(f"   • Training: lr={best_overall['lr']}, batch={best_overall['batch_size']}")
    print(f"   • Overfitting gap: {best_overall['overfit_gap']:.4f}")
    
    print(f"\n🎯 Most Reliable (best generalization):")
    print(f"   {least_overfit['experiment']}")
    print(f"   • F1: {least_overfit['test_f1']:.4f} | Gap: {least_overfit['overfit_gap']:.4f}")
    print(f"   • Config: {least_overfit['input_type']}, {least_overfit['hidden_dims']}")
    print(f"   • Use when: You need consistent performance on unseen data")
    
    print(f"\n🔍 Best for Catching Deforestation (minimize false negatives):")
    print(f"   {best_recall['experiment']}")
    print(f"   • Recall: {best_recall['test_recall']:.4f} (catches {best_recall['test_recall']*100:.1f}% of deforestation)")
    print(f"   • Precision: {best_recall['test_precision']:.4f} (FP rate: {best_recall['false_positive_rate']:.4f})")
    print(f"   • Use when: Missing deforestation is more costly than false alarms")
    
    print(f"\n✅ Best for Reducing False Alarms (minimize false positives):")
    print(f"   {best_precision['experiment']}")
    print(f"   • Precision: {best_precision['test_precision']:.4f} ({best_precision['test_precision']*100:.1f}% of alerts are real)")
    print(f"   • Recall: {best_precision['test_recall']:.4f} (FN rate: {best_precision['false_negative_rate']:.4f})")
    print(f"   • Use when: False alarms are expensive or cause alert fatigue")
    
    print(f"\n📊 Best Ranking Model (highest AUC):")
    print(f"   {best_auc['experiment']}")
    print(f"   • AUC: {best_auc['test_roc_auc']:.4f}")
    print(f"   • Use when: You need to rank areas by deforestation likelihood")
    
    print(f"\n{'='*120}")
    print("2. KEY INSIGHTS FROM EXPERIMENTS")
    print(f"{'='*120}")
    
    # Input type insights
    print(f"\n📥 Input Representation:")
    input_analysis = df.groupby('input_type').agg({
        'test_f1': ['mean', 'max', 'count']
    }).round(4)
    for input_type in df['input_type'].unique():
        stats = input_analysis.loc[input_type]
        print(f"   • {input_type}: mean F1={stats[('test_f1', 'mean')]:.4f}, "
              f"max F1={stats[('test_f1', 'max')]:.4f} (n={int(stats[('test_f1', 'count')])})")
    
    best_input = df.groupby('input_type')['test_f1'].mean().idxmax()
    print(f"   ➜ Winner: {best_input}")
    
    # Regularization insights
    print(f"\n🛡️ Regularization:")
    avg_overfit = df['overfit_gap'].mean()
    severe_overfit = (df['overfit_gap'] > 0.15).sum()
    well_generalized = (df['overfit_gap'].abs() < 0.05).sum()
    
    print(f"   • Average overfitting gap: {avg_overfit:.4f}")
    print(f"   • Models with severe overfit (>0.15): {severe_overfit}/{len(df)}")
    print(f"   • Well-generalized models (|gap|<0.05): {well_generalized}/{len(df)}")
    
    if avg_overfit > 0.10:
        print(f"   ⚠️  STRONG OVERFITTING DETECTED!")
        
        # Find what helps
        concat_df = df[df['input_type'] == 'concat']
        if len(concat_df) > 5:
            high_reg = concat_df[concat_df['dropout'] >= 0.5]
            low_reg = concat_df[concat_df['dropout'] < 0.5]
            if len(high_reg) > 0 and len(low_reg) > 0:
                print(f"   • High dropout (≥0.5) gap: {high_reg['overfit_gap'].mean():.4f}")
                print(f"   • Low dropout (<0.5) gap: {low_reg['overfit_gap'].mean():.4f}")
    
    # Optimal hyperparameters
    print(f"\n⚙️ Optimal Hyperparameters (based on top 10 models):")
    top_10 = df.nlargest(10, 'test_f1')
    
    print(f"   • Dropout: {top_10['dropout'].mode().values[0]} (mode), "
          f"range: [{top_10['dropout'].min()}, {top_10['dropout'].max()}]")
    print(f"   • Learning rate: {top_10['lr'].mode().values[0]} (mode), "
          f"range: [{top_10['lr'].min()}, {top_10['lr'].max()}]")
    print(f"   • Weight decay: {top_10['weight_decay'].mode().values[0]} (mode), "
          f"range: [{top_10['weight_decay'].min()}, {top_10['weight_decay'].max()}]")
    print(f"   • Batch size: {top_10['batch_size'].mode().values[0]} (mode)")
    
    # Architecture insights
    print(f"\n🏗️ Architecture:")
    print(f"   • Depth: {top_10['num_layers'].mode().values[0]} layers (mode), "
          f"range: [{top_10['num_layers'].min()}, {top_10['num_layers'].max()}]")
    print(f"   • Total size: {top_10['total_hidden_units'].mean():.0f} units (mean), "
          f"range: [{top_10['total_hidden_units'].min()}, {top_10['total_hidden_units'].max()}]")
    
    # Convergence insights
    print(f"\n⏱️ Training Dynamics:")
    print(f"   • Average convergence: {df['best_epoch'].mean():.1f} epochs")
    print(f"   • Fastest: {df['best_epoch'].min():.0f} epochs ({df.loc[df['best_epoch'].idxmin(), 'experiment']})")
    print(f"   • Slowest: {df['best_epoch'].max():.0f} epochs")
    
    print(f"\n{'='*120}")
    print("3. ACTIONABLE NEXT STEPS")
    print(f"{'='*120}")
    
    print(f"\n✅ Implement immediately:")
    print(f"   1. Use {best_overall['experiment']} for production")
    print(f"   2. Set up ensemble with top 3-5 models for robustness")
    print(f"   3. Use {best_recall['experiment']} if false negatives are critical")
    
    print(f"\n🔬 Further experimentation:")
    
    # Check what hasn't been explored enough
    if df['input_type'].value_counts().get('combined', 0) < 5:
        print(f"   • Explore 'combined' input more (only {df['input_type'].value_counts().get('combined', 0)} experiments)")
    
    if len(df[df['dropout'] > 0.6]) < 5:
        print(f"   • Test higher dropout (>0.6) to reduce overfitting")
    
    if df['weight_decay'].nunique() < 5:
        print(f"   • Expand weight decay sweep")
    
    if (df['overfit_gap'] > 0.10).sum() > len(df) * 0.5:
        print(f"   • Focus on regularization - majority of models overfit")
    
    # Seed variance check
    df['exp_base'] = df['experiment'].str.replace(r'_seed\d+', '', regex=True)
    replicated = df.groupby('exp_base').size()
    if (replicated > 1).sum() < 3:
        print(f"   • Add more seed replicates for variance estimation")
    
    print(f"\n📊 Model selection guide:")
    print(f"   • Need maximum F1? → {best_overall['experiment']}")
    print(f"   • Need reliability? → {least_overfit['experiment']}")
    print(f"   • Need to catch everything? → {best_recall['experiment']}")
    print(f"   • Need to avoid false alarms? → {best_precision['experiment']}")
    print(f"   • Need probability ranking? → {best_auc['experiment']}")


def export_results_csv(df, save_dir):
    """Export detailed results to CSV"""
    output_file = save_dir / 'detailed_results.csv'
    df.to_csv(output_file, index=False)
    print(f"\n✅ Detailed results exported to {output_file}")


In [24]:

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import rootutils
from scipy import stats

root_path = rootutils.find_root()

def load_experiment_results(summary_file=None, folder_name="experiments_more"):
    """Load experiment results from JSON file"""
    experiments_dir = root_path / folder_name

    if summary_file is None:
        summary_files = sorted(experiments_dir.glob("summary_*.json"))
        if not summary_files:
            raise FileNotFoundError("No summary files found in experiments directory")
        summary_file = summary_files[-1]
        print(f"Loading most recent summary: {summary_file.name}")
    else:
        summary_file = Path(summary_file)
    
    with open(summary_file, 'r') as f:
        results = json.load(f)
    
    return results, summary_file


def results_to_dataframe(results):
    """Convert results list to pandas DataFrame with extended features"""
    if not results:
        raise ValueError("No results to analyze")
    
    data = []
    for result in results:
        row = {
            'experiment': result['experiment_name'],
            'test_f1': result['test_f1'],
            'test_precision': result['test_precision'],
            'test_recall': result['test_recall'],
            'test_auc': result['test_roc_auc'],
            'val_f1': result['best_val_f1'],
            'overfit_gap': result['overfit_gap'],
            'best_epoch': result['best_epoch'],
            'input_type': result['config']['input_type'],
            'hidden_dims': str(result['config']['hidden_dims']),
            'dropout': result['config']['dropout'],
            'lr': result['config']['lr'],
            'weight_decay': result['config']['weight_decay'],
            'batch_size': result['config']['batch_size'],
            'patience': result['config'].get('patience', 15),
            'seed': result['config']['seed'],
        }
        
        # Compute derived metrics
        cm = np.array(result['confusion_matrix'])
        row['tn'] = cm[0, 0]
        row['fp'] = cm[0, 1]
        row['fn'] = cm[1, 0]
        row['tp'] = cm[1, 1]
        row['false_positive_rate'] = row['fp'] / (row['fp'] + row['tn']) if (row['fp'] + row['tn']) > 0 else 0
        row['false_negative_rate'] = row['fn'] / (row['fn'] + row['tp']) if (row['fn'] + row['tp']) > 0 else 0
        
        # Model complexity - handle both list and string formats
        hidden_dims_raw = result['config']['hidden_dims']
        if isinstance(hidden_dims_raw, str):
            hidden_list = eval(hidden_dims_raw)
        else:
            hidden_list = hidden_dims_raw
        
        row['num_layers'] = len(hidden_list)
        row['total_hidden_units'] = sum(hidden_list)
        row['max_layer_size'] = max(hidden_list)
        
        data.append(row)
    
    df = pd.DataFrame(data)
    return df


def print_summary_table(df, top_n=20):
    """Print formatted summary table"""
    print("\n" + "="*120)
    print("EXPERIMENT RESULTS SUMMARY")
    print("="*120)
    
    df_sorted = df.sort_values('test_f1', ascending=False)
    
    print(f"\n{'Rank':<6}{'Experiment':<35}{'F1':<8}{'Prec':<8}{'Recall':<8}{'AUC':<8}{'Overfit':<10}{'Epoch':<7}")
    print("-"*120)
    
    for idx, (_, row) in enumerate(df_sorted.head(top_n).iterrows(), 1):
        print(f"{idx:<6}{row['experiment']:<35}{row['test_f1']:<8.4f}{row['test_precision']:<8.4f}"
              f"{row['test_recall']:<8.4f}{row['test_auc']:<8.4f}{row['overfit_gap']:<10.4f}{row['best_epoch']:<7.0f}")
    
    # Print bottom performers too
    if len(df_sorted) > 5:
        print("\n" + "-"*120)
        print("BOTTOM 5 PERFORMERS:")
        print("-"*120)
        for idx, (_, row) in enumerate(df_sorted.tail(5).iterrows(), len(df_sorted)-4):
            print(f"{idx:<6}{row['experiment']:<35}{row['test_f1']:<8.4f}{row['test_precision']:<8.4f}"
                  f"{row['test_recall']:<8.4f}{row['test_auc']:<8.4f}{row['overfit_gap']:<10.4f}{row['best_epoch']:<7.0f}")
    
    # Best by different metrics
    print(f"\n" + "="*120)
    print("BEST MODELS BY METRIC")
    print("="*120)
    
    metrics = {
        'F1 Score': ('test_f1', 'max'),
        'Precision': ('test_precision', 'max'),
        'Recall': ('test_recall', 'max'),
        'AUC': ('test_auc', 'max'),
        'Least Overfit': ('overfit_gap', 'min'),
        'Fastest Convergence': ('best_epoch', 'min'),
        'Best Val-Test Agreement': ('overfit_gap', 'abs_min'),
    }
    
    for metric_name, (col, agg) in metrics.items():
        if agg == 'abs_min':
            idx = df[col].abs().idxmin()
        else:
            idx = df[col].idxmax() if agg == 'max' else df[col].idxmin()
        row = df.loc[idx]
        print(f"\n{metric_name}:")
        print(f"  Model: {row['experiment']}")
        print(f"  Value: {row[col]:.4f}")
        print(f"  Config: input={row['input_type']}, arch={row['hidden_dims']}, "
              f"dropout={row['dropout']}, lr={row['lr']}, wd={row['weight_decay']}")


def analyze_by_category(df):
    """Analyze results by different configuration categories"""
    print(f"\n" + "="*120)
    print("DETAILED ANALYSIS BY CONFIGURATION")
    print("="*120)
    
    # Helper function for pretty printing
    def print_analysis(group_col, title):
        print(f"\n{title}:")
        print("-"*120)
        
        if group_col not in df.columns or df[group_col].isna().all():
            print("  No data available for this category")
            return
            
        analysis = df.groupby(group_col).agg({
            'test_f1': ['count', 'mean', 'std', 'min', 'max'],
            'overfit_gap': ['mean', 'std'],
            'test_precision': 'mean',
            'test_recall': 'mean',
            'test_auc': 'mean',
            'best_epoch': 'mean'
        }).round(4)
        print(analysis)
        
        # Statistical test if enough groups
        if len(df[group_col].unique()) >= 2:
            groups = [group['test_f1'].values for name, group in df.groupby(group_col) if len(group) >= 2]
            if len(groups) >= 2:
                try:
                    statistic, pvalue = stats.f_oneway(*groups)
                    print(f"\nANOVA F-statistic: {statistic:.4f}, p-value: {pvalue:.4e}")
                    if pvalue < 0.05:
                        print("  → Statistically significant differences detected!")
                    else:
                        print("  → No significant differences (might just be noise)")
                except:
                    pass
    
    # Run analyses
    print_analysis('input_type', '1. BY INPUT TYPE')
    print_analysis('dropout', '2. BY DROPOUT RATE')
    print_analysis('lr', '3. BY LEARNING RATE')
    print_analysis('weight_decay', '4. BY WEIGHT DECAY')
    print_analysis('batch_size', '5. BY BATCH SIZE')
    print_analysis('num_layers', '6. BY NETWORK DEPTH')
    
    if 'patience' in df.columns and df['patience'].nunique() > 1:
        print_analysis('patience', '7. BY EARLY STOPPING PATIENCE')
    
    # Architecture size bins
    df['arch_size'] = pd.cut(df['total_hidden_units'], 
                              bins=[0, 300, 600, 900, 1500, 5000],
                              labels=['tiny', 'small', 'medium', 'large', 'huge'])
    print_analysis('arch_size', '8. BY ARCHITECTURE SIZE')


def deep_dive_concat(df):
    """Deep analysis specifically for concat experiments"""
    print(f"\n" + "="*120)
    print("DEEP DIVE: CONCAT INPUT EXPERIMENTS")
    print("="*120)
    
    concat_df = df[df['input_type'] == 'concat'].copy()
    
    if len(concat_df) == 0:
        print("No concat experiments found!")
        return
    
    print(f"\nTotal concat experiments: {len(concat_df)}")
    print(f"Best F1: {concat_df['test_f1'].max():.4f}")
    print(f"Worst F1: {concat_df['test_f1'].min():.4f}")
    print(f"Mean F1: {concat_df['test_f1'].mean():.4f} ± {concat_df['test_f1'].std():.4f}")
    print(f"Median overfit gap: {concat_df['overfit_gap'].median():.4f}")
    
    # Find optimal hyperparameters
    print("\n" + "-"*120)
    print("OPTIMAL HYPERPARAMETERS FOR CONCAT:")
    print("-"*120)
    
    for param in ['dropout', 'lr', 'weight_decay', 'batch_size', 'hidden_dims']:
        best_idx = concat_df['test_f1'].idxmax()
        top_5 = concat_df.nlargest(min(5, len(concat_df)), 'test_f1')
        
        print(f"\n{param.upper()}:")
        print(f"  Best single model: {concat_df.loc[best_idx, param]}")
        
        if param == 'hidden_dims':
            print(f"  Most common in top {len(top_5)}: {top_5[param].mode().values[0] if len(top_5[param].mode()) > 0 else 'N/A'}")
        else:
            print(f"  Top {len(top_5)} average: {top_5[param].mean():.6f}")
        
        # Show distribution
        if param != 'hidden_dims' and len(top_5) > 0:
            value_counts = top_5[param].value_counts()
            print(f"  Distribution in top {len(top_5)}:")
            for val, count in value_counts.head(3).items():
                print(f"    {val}: {count}/{len(top_5)}")
    
    # Correlation analysis
    print("\n" + "-"*120)
    print("HYPERPARAMETER CORRELATIONS WITH F1:")
    print("-"*120)
    
    numeric_params = ['dropout', 'lr', 'weight_decay', 'batch_size', 'total_hidden_units', 'num_layers']
    correlations = {}
    
    for param in numeric_params:
        if param in concat_df.columns and concat_df[param].nunique() > 1:
            corr = concat_df[param].corr(concat_df['test_f1'])
            correlations[param] = corr
            print(f"{param:>20}: {corr:>7.4f}")
    
    # Find sweet spots
    print("\n" + "-"*120)
    print("SWEET SPOT RANGES (values where F1 > median):")
    print("-"*120)
    
    median_f1 = concat_df['test_f1'].median()
    high_performers = concat_df[concat_df['test_f1'] > median_f1]
    
    if len(high_performers) > 0:
        for param in ['dropout', 'lr', 'weight_decay']:
            if high_performers[param].nunique() > 1:
                print(f"{param}: [{high_performers[param].min():.6f}, {high_performers[param].max():.6f}]")
    else:
        print("No high performers found")


def analyze_regularization_tradeoffs(df):
    """Analyze the relationship between regularization and performance"""
    print(f"\n" + "="*120)
    print("REGULARIZATION ANALYSIS")
    print("="*120)
    
    # Define regularization strength
    df['reg_strength'] = df['dropout'] + np.log10(df['weight_decay'] + 1e-10) / 5
    
    # Bin into categories
    df['reg_category'] = pd.cut(df['reg_strength'], 
                                 bins=[-np.inf, -1, 0, 1, np.inf],
                                 labels=['minimal', 'light', 'moderate', 'heavy'])
    
    print("\nPerformance by regularization strength:")
    try:
        reg_analysis = df.groupby('reg_category').agg({
            'test_f1': ['count', 'mean', 'std', 'max'],
            'overfit_gap': ['mean', 'std'],
            'test_precision': 'mean',
            'test_recall': 'mean',
        }).round(4)
        print(reg_analysis)
    except Exception as e:
        print(f"  Could not compute regularization analysis: {e}")
    
    # Overfitting analysis
    print("\n" + "-"*120)
    print("OVERFITTING STATISTICS:")
    print("-"*120)
    print(f"Models with overfit_gap > 0.15: {(df['overfit_gap'] > 0.15).sum()}/{len(df)}")
    print(f"Models with overfit_gap < 0.05: {(df['overfit_gap'] < 0.05).sum()}/{len(df)}")
    print(f"Models with negative gap (test > train): {(df['overfit_gap'] < 0).sum()}/{len(df)}")
    
    # Best regularization strategy
    well_generalized = df[df['overfit_gap'].abs() < 0.10]
    if len(well_generalized) > 0:
        best_generalized = well_generalized.loc[well_generalized['test_f1'].idxmax()]
        print(f"\nBest well-generalized model (|gap| < 0.10):")
        print(f"  {best_generalized['experiment']}")
        print(f"  F1: {best_generalized['test_f1']:.4f}, Gap: {best_generalized['overfit_gap']:.4f}")
        print(f"  Config: dropout={best_generalized['dropout']}, wd={best_generalized['weight_decay']}")


def analyze_precision_recall_tradeoff(df):
    """Analyze precision-recall tradeoff and identify Pareto frontier"""
    print(f"\n" + "="*120)
    print("PRECISION-RECALL TRADEOFF ANALYSIS")
    print("="*120)
    
    # Sort by F1 and show precision-recall balance
    print("\nTop 10 models by F1 with precision-recall breakdown:")
    print("-"*120)
    top_10 = df.nlargest(min(10, len(df)), 'test_f1')
    
    for idx, (_, row) in enumerate(top_10.iterrows(), 1):
        prec_rec_diff = abs(row['test_precision'] - row['test_recall'])
        balance = "balanced" if prec_rec_diff < 0.05 else ("precision-biased" if row['test_precision'] > row['test_recall'] else "recall-biased")
        print(f"{idx:>2}. {row['experiment']:<35} F1:{row['test_f1']:.4f}  "
              f"P:{row['test_precision']:.4f}  R:{row['test_recall']:.4f}  [{balance}]")
    
    # Identify Pareto-optimal models
    print("\n" + "-"*120)
    print("PARETO FRONTIER (models where you can't improve one metric without hurting the other):")
    print("-"*120)
    
    pareto_models = []
    for idx, row in df.iterrows():
        is_pareto = True
        for _, other in df.iterrows():
            if (other['test_precision'] >= row['test_precision'] and 
                other['test_recall'] >= row['test_recall'] and
                (other['test_precision'] > row['test_precision'] or other['test_recall'] > row['test_recall'])):
                is_pareto = False
                break
        if is_pareto:
            pareto_models.append(row)
    
    pareto_df = pd.DataFrame(pareto_models).sort_values('test_f1', ascending=False)
    print(f"\nFound {len(pareto_df)} Pareto-optimal models:")
    for _, row in pareto_df.head(10).iterrows():
        print(f"  {row['experiment']:<35} P:{row['test_precision']:.4f}  R:{row['test_recall']:.4f}  F1:{row['test_f1']:.4f}")


def analyze_seed_variance(df):
    """Analyze variance across different random seeds"""
    print(f"\n" + "="*120)
    print("RANDOM SEED VARIANCE ANALYSIS")
    print("="*120)
    
    # Find experiments with multiple seeds
    df['exp_base'] = df['experiment'].str.replace(r'_seed\d+', '', regex=True)
    seed_groups = df.groupby('exp_base').filter(lambda x: len(x) > 1)
    
    if len(seed_groups) == 0:
        print("\nNo replicated experiments found (need same config with different seeds)")
        return
    
    print(f"\nFound {seed_groups['exp_base'].nunique()} configurations with multiple seeds:")
    print("-"*120)
    
    for exp_base, group in seed_groups.groupby('exp_base'):
        print(f"\n{exp_base}:")
        print(f"  Runs: {len(group)}")
        print(f"  F1: {group['test_f1'].mean():.4f} ± {group['test_f1'].std():.4f} "
              f"(range: [{group['test_f1'].min():.4f}, {group['test_f1'].max():.4f}])")
        print(f"  Precision: {group['test_precision'].mean():.4f} ± {group['test_precision'].std():.4f}")
        print(f"  Recall: {group['test_recall'].mean():.4f} ± {group['test_recall'].std():.4f}")
        print(f"  Overfit gap: {group['overfit_gap'].mean():.4f} ± {group['overfit_gap'].std():.4f}")
        
        if group['test_f1'].std() > 0.02:
            print(f"  ⚠️  High variance detected! Results may be unstable.")

In [25]:

def export_results_csv(df, save_dir):
    """Export detailed results to CSV"""
    output_file = save_dir / 'detailed_results.csv'
    df.to_csv(output_file, index=False)
    print(f"\n✅ Detailed results exported to {output_file}")



In [26]:

def main():
    """Main analysis function"""
    results, summary_file = load_experiment_results(folder_name="experiments")
    
    if not results:
        print("No experiment results found!")
        return
    
    df = results_to_dataframe(results)
    save_dir = root_path / "experiments"
    
    # Run text analyses
    print_summary_table(df, top_n=20)
    analyze_by_category(df)
    deep_dive_concat(df)
    analyze_regularization_tradeoffs(df)
    analyze_precision_recall_tradeoff(df)
    analyze_seed_variance(df)
    
    # Export to CSV
    export_results_csv(df, save_dir)
    
    print(f"\n{'='*120}")
    print("✅ ANALYSIS COMPLETE!")
    print(f"{'='*120}")
    print(f"\nResults saved to: {save_dir}")


if __name__ == "__main__":
    main()

Loading most recent summary: summary_20250929_184544.json

EXPERIMENT RESULTS SUMMARY

Rank  Experiment                         F1      Prec    Recall  AUC     Overfit   Epoch  
------------------------------------------------------------------------------------------------------------------------
1     concat_minimal_compression         0.8232  0.7670  0.8882  0.8824  0.1260    55     
2     concat_very_wide                   0.8173  0.7719  0.8684  0.8856  0.1105    43     
3     concat_wide_v2                     0.8162  0.7751  0.8618  0.8837  0.1352    49     
4     concat_wide_v1                     0.8050  0.7711  0.8421  0.8827  0.0904    20     

BEST MODELS BY METRIC

F1 Score:
  Model: concat_minimal_compression
  Value: 0.8232
  Config: input=concat, arch=[1400, 700, 350], dropout=0.4, lr=0.001, wd=0.0001

Precision:
  Model: concat_wide_v2
  Value: 0.7751
  Config: input=concat, arch=[1280, 640, 320], dropout=0.4, lr=0.001, wd=0.0001

Recall:
  Model: concat_minimal_compre

  analysis = df.groupby(group_col).agg({
  reg_analysis = df.groupby('reg_category').agg({


In [27]:

def main():
    """Main analysis function"""
    results, summary_file = load_experiment_results(folder_name="experiments_more")
    
    if not results:
        print("No experiment results found!")
        return
    
    df = results_to_dataframe(results)
    save_dir = root_path / "experiments_more"
    
    # Run text analyses
    print_summary_table(df, top_n=20)
    analyze_by_category(df)
    deep_dive_concat(df)
    analyze_regularization_tradeoffs(df)
    analyze_precision_recall_tradeoff(df)
    analyze_seed_variance(df)
    
    # Export to CSV
    export_results_csv(df, save_dir)
    
    print(f"\n{'='*120}")
    print("✅ ANALYSIS COMPLETE!")
    print(f"{'='*120}")
    print(f"\nResults saved to: {save_dir}")


if __name__ == "__main__":
    main()

Loading most recent summary: summary_20250929_191713.json

EXPERIMENT RESULTS SUMMARY

Rank  Experiment                         F1      Prec    Recall  AUC     Overfit   Epoch  
------------------------------------------------------------------------------------------------------------------------
1     concat_best_seed2                  0.8740  0.8464  0.9035  0.9310  0.0398    63     
2     concat_minimal_compression         0.8717  0.8393  0.9068  0.9201  0.0958    53     
3     concat_wide_v2                     0.8689  0.8261  0.9164  0.9188  0.0926    47     
4     concat_batch_32                    0.8674  0.8424  0.8939  0.9184  0.1018    52     
5     concat_ablate_wider                0.8673  0.8338  0.9035  0.9194  0.0950    51     
6     concat_best_seed4                  0.8661  0.8488  0.8842  0.9328  0.0539    50     
7     concat_extra_large                 0.8645  0.8130  0.9228  0.9215  0.0916    37     
8     combined_large                     0.8632  0.8107  0.9228 

  analysis = df.groupby(group_col).agg({
  groups = [group['test_f1'].values for name, group in df.groupby(group_col) if len(group) >= 2]
  reg_analysis = df.groupby('reg_category').agg({


In [28]:

def main():
    """Main analysis function"""
    results, summary_file = load_experiment_results(folder_name="experiments_more")
    
    if not results:
        print("No experiment results found!")
        return
    
    df = results_to_dataframe(results)
    save_dir = root_path / "experiments_more"
    
    # Run text analyses
    print_summary_table(df, top_n=20)
    analyze_by_category(df)
    deep_dive_concat(df)
    analyze_regularization_tradeoffs(df)
    analyze_precision_recall_tradeoff(df)
    analyze_seed_variance(df)
    
    # Export to CSV
    export_results_csv(df, save_dir)
    
    print(f"\n{'='*120}")
    print("✅ ANALYSIS COMPLETE!")
    print(f"{'='*120}")
    print(f"\nResults saved to: {save_dir}")


if __name__ == "__main__":
    main()

Loading most recent summary: summary_20250929_193213.json

EXPERIMENT RESULTS SUMMARY

Rank  Experiment                         F1      Prec    Recall  AUC     Overfit   Epoch  
------------------------------------------------------------------------------------------------------------------------
1     second_best_seed1                  0.8722  0.8667  0.8778  0.9386  0.0921    46     
2     second_best_seed2                  0.8717  0.8393  0.9068  0.9201  0.0958    53     
3     second_best_seed4                  0.8665  0.8378  0.8971  0.9313  0.0570    25     
4     best_model_seed4                   0.8661  0.8488  0.8842  0.9328  0.0539    50     
5     second_best_seed3                  0.8494  0.8466  0.8521  0.9243  0.0905    27     
6     best_model_seed3                   0.8467  0.8789  0.8167  0.9241  0.0647    37     
7     best_model_seed1                   0.8429  0.8403  0.8457  0.9223  0.0631    34     
8     best_model_seed2                   0.7443  0.6376  0.8939 

  analysis = df.groupby(group_col).agg({
  groups = [group['test_f1'].values for name, group in df.groupby(group_col) if len(group) >= 2]
  reg_analysis = df.groupby('reg_category').agg({


In [12]:
# # scripts/analyze_experiments.py

# import json
# import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt
# import seaborn as sns
# from pathlib import Path
# import rootutils

# root_path = rootutils.find_root()

# def load_experiment_results(summary_file=None):
#     """Load experiment results from JSON file"""
#     experiments_dir = root_path / "experiments"
    
#     if summary_file is None:
#         # Find the most recent summary file
#         summary_files = sorted(experiments_dir.glob("summary_*.json"))
#         if not summary_files:
#             raise FileNotFoundError("No summary files found in experiments directory")
#         summary_file = summary_files[-1]
#         print(f"Loading most recent summary: {summary_file.name}")
#     else:
#         summary_file = Path(summary_file)
    
#     with open(summary_file, 'r') as f:
#         results = json.load(f)
    
#     return results, summary_file


# def results_to_dataframe(results):
#     """Convert results list to pandas DataFrame"""
#     if not results:
#         raise ValueError("No results to analyze - all experiments failed")
    
#     data = []
#     for result in results:
#         row = {
#             'experiment': result['experiment_name'],
#             'test_f1': result['test_f1'],
#             'test_precision': result['test_precision'],
#             'test_recall': result['test_recall'],
#             'test_auc': result['test_roc_auc'],
#             'val_f1': result['best_val_f1'],
#             'overfit_gap': result['overfit_gap'],
#             'best_epoch': result['best_epoch'],
#             'input_type': result['config']['input_type'],
#             'hidden_dims': str(result['config']['hidden_dims']),
#             'dropout': result['config']['dropout'],
#             'lr': result['config']['lr'],
#             'weight_decay': result['config']['weight_decay'],
#             'batch_size': result['config']['batch_size'],
#         }
#         data.append(row)
    
#     df = pd.DataFrame(data)
#     return df


# def print_summary_table(df):
#     """Print formatted summary table"""
#     print("\n" + "="*100)
#     print("EXPERIMENT RESULTS SUMMARY")
#     print("="*100)
    
#     # Sort by test F1
#     df_sorted = df.sort_values('test_f1', ascending=False)
    
#     print(f"\n{'Rank':<6}{'Experiment':<30}{'F1':<8}{'Prec':<8}{'Recall':<8}{'AUC':<8}{'Overfit':<10}")
#     print("-"*100)
    
#     for idx, (_, row) in enumerate(df_sorted.iterrows(), 1):
#         print(f"{idx:<6}{row['experiment']:<30}{row['test_f1']:<8.4f}{row['test_precision']:<8.4f}"
#               f"{row['test_recall']:<8.4f}{row['test_auc']:<8.4f}{row['overfit_gap']:<10.4f}")
    
#     # Print best models by different metrics
#     print(f"\n" + "="*100)
#     print("BEST MODELS BY METRIC")
#     print("="*100)
    
#     print(f"\nBest F1 Score:      {df_sorted.iloc[0]['experiment']:<30} (F1: {df_sorted.iloc[0]['test_f1']:.4f})")
#     print(f"Best Precision:     {df.loc[df['test_precision'].idxmax(), 'experiment']:<30} "
#           f"(Prec: {df['test_precision'].max():.4f})")
#     print(f"Best Recall:        {df.loc[df['test_recall'].idxmax(), 'experiment']:<30} "
#           f"(Recall: {df['test_recall'].max():.4f})")
#     print(f"Best AUC:           {df.loc[df['test_auc'].idxmax(), 'experiment']:<30} "
#           f"(AUC: {df['test_auc'].max():.4f})")
#     print(f"Least Overfit:      {df.loc[df['overfit_gap'].idxmin(), 'experiment']:<30} "
#           f"(Gap: {df['overfit_gap'].min():.4f})")


# def analyze_by_category(df):
#     """Analyze results by different configuration categories"""
#     print(f"\n" + "="*100)
#     print("ANALYSIS BY CONFIGURATION")
#     print("="*100)
    
#     # By input type
#     print("\n1. By Input Type:")
#     input_analysis = df.groupby('input_type').agg({
#         'test_f1': ['mean', 'std', 'max'],
#         'overfit_gap': 'mean'
#     }).round(4)
#     print(input_analysis)
    
#     # By dropout
#     print("\n2. By Dropout Rate:")
#     dropout_analysis = df.groupby('dropout').agg({
#         'test_f1': ['mean', 'std', 'max'],
#         'overfit_gap': 'mean'
#     }).round(4)
#     print(dropout_analysis)
    
#     # By learning rate
#     print("\n3. By Learning Rate:")
#     lr_analysis = df.groupby('lr').agg({
#         'test_f1': ['mean', 'std', 'max'],
#         'overfit_gap': 'mean'
#     }).round(4)
#     print(lr_analysis)
    
#     # By weight decay
#     print("\n4. By Weight Decay:")
#     wd_analysis = df.groupby('weight_decay').agg({
#         'test_f1': ['mean', 'std', 'max'],
#         'overfit_gap': 'mean'
#     }).round(4)
#     print(wd_analysis)


# def plot_experiment_comparison(df, save_path=None):
#     """Create comprehensive visualization of experiment results"""
#     fig = plt.figure(figsize=(20, 12))
    
#     # Sort by F1 score
#     df_sorted = df.sort_values('test_f1', ascending=True)
    
#     # 1. Main metrics comparison (bar chart)
#     ax1 = plt.subplot(3, 3, 1)
#     x = np.arange(len(df_sorted))
#     width = 0.2
    
#     ax1.barh(x - width*1.5, df_sorted['test_f1'], width, label='F1', alpha=0.8)
#     ax1.barh(x - width*0.5, df_sorted['test_precision'], width, label='Precision', alpha=0.8)
#     ax1.barh(x + width*0.5, df_sorted['test_recall'], width, label='Recall', alpha=0.8)
#     ax1.barh(x + width*1.5, df_sorted['test_auc'], width, label='AUC', alpha=0.8)
    
#     ax1.set_yticks(x)
#     ax1.set_yticklabels(df_sorted['experiment'], fontsize=8)
#     ax1.set_xlabel('Score')
#     ax1.set_title('Test Metrics Comparison')
#     ax1.legend()
#     ax1.grid(axis='x', alpha=0.3)
    
#     # 2. Overfitting analysis
#     ax2 = plt.subplot(3, 3, 2)
#     ax2.scatter(df['overfit_gap'], df['test_f1'], s=100, alpha=0.6)
#     for idx, row in df.iterrows():
#         ax2.annotate(row['experiment'].replace('exp_', ''), 
#                      (row['overfit_gap'], row['test_f1']), 
#                      fontsize=7, alpha=0.7)
#     ax2.set_xlabel('Overfitting Gap (Train F1 - Test F1)')
#     ax2.set_ylabel('Test F1 Score')
#     ax2.set_title('Overfitting vs Performance')
#     ax2.grid(alpha=0.3)
    
#     # 3. Precision-Recall tradeoff
#     ax3 = plt.subplot(3, 3, 3)
#     ax3.scatter(df['test_recall'], df['test_precision'], s=100, alpha=0.6, c=df['test_f1'], 
#                 cmap='viridis')
#     for idx, row in df.iterrows():
#         ax3.annotate(row['experiment'].replace('exp_', ''), 
#                      (row['test_recall'], row['test_precision']), 
#                      fontsize=7, alpha=0.7)
#     ax3.set_xlabel('Recall')
#     ax3.set_ylabel('Precision')
#     ax3.set_title('Precision-Recall Tradeoff (color=F1)')
#     ax3.grid(alpha=0.3)
#     plt.colorbar(ax3.collections[0], ax=ax3, label='F1 Score')
    
#     # 4. Input type comparison
#     ax4 = plt.subplot(3, 3, 4)
#     input_data = df.groupby('input_type')['test_f1'].agg(['mean', 'std'])
#     input_data['mean'].plot(kind='bar', yerr=input_data['std'], ax=ax4, capsize=5)
#     ax4.set_xlabel('Input Type')
#     ax4.set_ylabel('Test F1 Score')
#     ax4.set_title('Performance by Input Type')
#     ax4.grid(axis='y', alpha=0.3)
#     plt.setp(ax4.xaxis.get_majorticklabels(), rotation=45, ha='right')
    
#     # 5. Dropout effect
#     ax5 = plt.subplot(3, 3, 5)
#     dropout_df = df.groupby('dropout').agg({'test_f1': 'mean', 'overfit_gap': 'mean'})
#     ax5_twin = ax5.twinx()
#     dropout_df['test_f1'].plot(ax=ax5, marker='o', color='blue', label='F1')
#     dropout_df['overfit_gap'].plot(ax=ax5_twin, marker='s', color='red', label='Overfit Gap')
#     ax5.set_xlabel('Dropout Rate')
#     ax5.set_ylabel('Test F1', color='blue')
#     ax5_twin.set_ylabel('Overfit Gap', color='red')
#     ax5.set_title('Dropout Effect on Performance')
#     ax5.grid(alpha=0.3)
#     ax5.legend(loc='upper left')
#     ax5_twin.legend(loc='upper right')
    
#     # 6. Learning rate effect
#     ax6 = plt.subplot(3, 3, 6)
#     lr_df = df.groupby('lr')['test_f1'].mean()
#     ax6.plot(lr_df.index, lr_df.values, marker='o', linewidth=2)
#     ax6.set_xscale('log')
#     ax6.set_xlabel('Learning Rate (log scale)')
#     ax6.set_ylabel('Mean Test F1')
#     ax6.set_title('Learning Rate Effect')
#     ax6.grid(alpha=0.3)
    
#     # 7. Weight decay effect
#     ax7 = plt.subplot(3, 3, 7)
#     wd_df = df[df['weight_decay'] > 0].groupby('weight_decay').agg({'test_f1': 'mean', 'overfit_gap': 'mean'})
#     if len(wd_df) > 0:
#         ax7_twin = ax7.twinx()
#         wd_df['test_f1'].plot(ax=ax7, marker='o', color='blue', label='F1')
#         wd_df['overfit_gap'].plot(ax=ax7_twin, marker='s', color='red', label='Overfit Gap')
#         ax7.set_xscale('log')
#         ax7.set_xlabel('Weight Decay (log scale)')
#         ax7.set_ylabel('Test F1', color='blue')
#         ax7_twin.set_ylabel('Overfit Gap', color='red')
#         ax7.set_title('Weight Decay Regularization')
#         ax7.grid(alpha=0.3)
#         ax7.legend(loc='upper left')
#         ax7_twin.legend(loc='upper right')
    
#     # 8. Model size analysis (count parameters)
#     ax8 = plt.subplot(3, 3, 8)
#     # Extract total layer sizes as proxy for model size
#     df['model_size'] = df['hidden_dims'].apply(lambda x: sum(eval(x)))
#     ax8.scatter(df['model_size'], df['test_f1'], s=100, alpha=0.6, c=df['overfit_gap'], 
#                 cmap='coolwarm')
#     ax8.set_xlabel('Model Size (sum of hidden dims)')
#     ax8.set_ylabel('Test F1')
#     ax8.set_title('Model Size vs Performance (color=overfit)')
#     ax8.grid(alpha=0.3)
#     plt.colorbar(ax8.collections[0], ax=ax8, label='Overfit Gap')
    
#     # 9. Training convergence (epochs to best)
#     ax9 = plt.subplot(3, 3, 9)
#     ax9.scatter(df['best_epoch'], df['test_f1'], s=100, alpha=0.6)
#     for idx, row in df.iterrows():
#         ax9.annotate(row['experiment'].replace('exp_', ''), 
#                      (row['best_epoch'], row['test_f1']), 
#                      fontsize=7, alpha=0.7)
#     ax9.set_xlabel('Epoch of Best Validation')
#     ax9.set_ylabel('Test F1')
#     ax9.set_title('Convergence Speed vs Performance')
#     ax9.grid(alpha=0.3)
    
#     plt.tight_layout()
    
#     if save_path:
#         plt.savefig(save_path, dpi=300, bbox_inches='tight')
#         print(f"\nVisualization saved to {save_path}")
    
#     plt.show()


# def plot_top_models_confusion(results, top_n=3, save_path=None):
#     """Plot confusion matrices for top N models"""
#     # Sort by F1 score
#     sorted_results = sorted(results, key=lambda x: x['test_f1'], reverse=True)
#     top_results = sorted_results[:top_n]
    
#     fig, axes = plt.subplots(1, top_n, figsize=(5*top_n, 4))
#     if top_n == 1:
#         axes = [axes]
    
#     for idx, result in enumerate(top_results):
#         cm = np.array(result['confusion_matrix'])
        
#         sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[idx],
#                    xticklabels=['Neg', 'Pos'], yticklabels=['Neg', 'Pos'])
        
#         axes[idx].set_xlabel('Predicted')
#         axes[idx].set_ylabel('Actual')
#         axes[idx].set_title(f"{result['experiment_name']}\nF1: {result['test_f1']:.4f}")
    
#     plt.tight_layout()
    
#     if save_path:
#         plt.savefig(save_path, dpi=300, bbox_inches='tight')
#         print(f"Confusion matrices saved to {save_path}")
    
#     plt.show()


# def generate_recommendations(df):
#     """Generate recommendations based on experiment results"""
#     print(f"\n" + "="*100)
#     print("RECOMMENDATIONS")
#     print("="*100)
    
#     best_overall = df.loc[df['test_f1'].idxmax()]
#     least_overfit = df.loc[df['overfit_gap'].idxmin()]
#     best_recall = df.loc[df['test_recall'].idxmax()]
#     best_precision = df.loc[df['test_precision'].idxmax()]
    
#     print(f"\n1. Best Overall Performance:")
#     print(f"   Model: {best_overall['experiment']}")
#     print(f"   Configuration: input={best_overall['input_type']}, dropout={best_overall['dropout']}, "
#           f"lr={best_overall['lr']}, wd={best_overall['weight_decay']}")
#     print(f"   Metrics: F1={best_overall['test_f1']:.4f}, Prec={best_overall['test_precision']:.4f}, "
#           f"Rec={best_overall['test_recall']:.4f}")
    
#     print(f"\n2. Most Generalizable (least overfitting):")
#     print(f"   Model: {least_overfit['experiment']}")
#     print(f"   Overfit gap: {least_overfit['overfit_gap']:.4f}")
#     print(f"   Test F1: {least_overfit['test_f1']:.4f}")
    
#     print(f"\n3. Best for Minimizing False Negatives (high recall):")
#     print(f"   Model: {best_recall['experiment']}")
#     print(f"   Recall: {best_recall['test_recall']:.4f} (catches {best_recall['test_recall']*100:.1f}% of deforestation)")
    
#     print(f"\n4. Best for Minimizing False Positives (high precision):")
#     print(f"   Model: {best_precision['experiment']}")
#     print(f"   Precision: {best_precision['test_precision']:.4f} ({best_precision['test_precision']*100:.1f}% of predictions are correct)")
    
#     # General insights
#     print(f"\n5. Key Insights:")
    
#     # Input type
#     best_input = df.groupby('input_type')['test_f1'].mean().idxmax()
#     print(f"   - Best input representation: {best_input}")
    
#     # Regularization
#     if df['overfit_gap'].mean() > 0.1:
#         print(f"   - Strong overfitting detected (avg gap: {df['overfit_gap'].mean():.3f})")
#         print(f"     → Consider: higher dropout, weight decay, or simpler models")
    
#     # Dropout sweet spot
#     if len(df['dropout'].unique()) > 1:
#         dropout_perf = df.groupby('dropout')['test_f1'].mean()
#         best_dropout = dropout_perf.idxmax()
#         print(f"   - Optimal dropout rate appears to be: {best_dropout}")


# def main():
#     """Main analysis function"""
#     # Load results
#     results, summary_file = load_experiment_results()
    
#     if not results:
#         print("No experiment results found or all experiments failed!")
#         return
    
#     # Convert to DataFrame
#     df = results_to_dataframe(results)
    
#     # Print summary table
#     print_summary_table(df)
    
#     # Analyze by category
#     analyze_by_category(df)
    
#     # Generate recommendations
#     generate_recommendations(df)
    
#     # Create visualizations
#     viz_path = root_path / "experiments" / f"analysis_{summary_file.stem}.png"
#     plot_experiment_comparison(df, save_path=viz_path)
    
#     # Plot confusion matrices for top 3
#     cm_path = root_path / "experiments" / f"confusion_matrices_{summary_file.stem}.png"
#     plot_top_models_confusion(results, top_n=3, save_path=cm_path)
    
#     print(f"\n{'='*100}")
#     print("Analysis complete!")
#     print(f"{'='*100}")


# if __name__ == "__main__":
#     main()