# Multi-Curriculum Smoke Test: Comprehensive Comparison

Enhanced smoke test comparing baseline training against multiple curriculum configurations:
- **Baseline**: No curriculum learning
- **Default Curriculum**: Balanced adversarial introduction
- **Aggressive Curriculum**: Rapid adversarial escalation
- **Conservative Curriculum**: Gradual adversarial introduction

This notebook provides comprehensive analysis, visualization, and ranking of all curriculum approaches.

In [None]:
# Clone repository and setup
# Delete any existing checkout first so the clone below does not fail on an
# already-present directory.
!rm -rf energy-based-model
!git clone https://github.com/mdkrasnow/energy-based-model.git
# Later cells invoke `python train.py` with a relative path, so the working
# directory has to be the repository root.
%cd energy-based-model

In [None]:
# Install dependencies
# -q is pip's quiet flag; it keeps the install log short.
!pip install -q torch torchvision einops accelerate tqdm tabulate matplotlib numpy pandas ema-pytorch ipdb seaborn scikit-learn

In [None]:
# Hyperparameters shared by every training run in this notebook
COMMON_ARGS = dict(
    model='mlp',
    batch_size=32,
    diffusion_steps=10,
    supervise_energy_landscape='True',
    train_num_steps=200,  # Increased for better curriculum comparison
    save_csv_logs=True,
    csv_log_interval=100,
)

# Datasets every configuration is trained on
TASKS = ['inverse', 'addition', 'lowrank']

# One entry per training variant: display name, description, plot colour and
# the extra command-line arguments appended to the training command.
CURRICULUM_CONFIGS = {
    variant: {
        'name': label,
        'description': summary,
        'color': hex_color,
        'args': [flag, value],
    }
    for variant, label, summary, hex_color, flag, value in (
        ('baseline', 'Baseline', 'No curriculum learning',
         '#1f77b4', '--disable-curriculum', 'True'),
        ('default', 'Default Curriculum', 'Balanced adversarial introduction',
         '#ff7f0e', '--curriculum-config', 'default'),
        ('aggressive', 'Aggressive Curriculum', 'Rapid adversarial escalation',
         '#d62728', '--curriculum-config', 'aggressive'),
        ('conservative', 'Conservative Curriculum', 'Gradual adversarial introduction',
         '#2ca02c', '--curriculum-config', 'conservative'),
    )
}

# Echo the experiment grid so the notebook log records what was run
print(f"Testing {len(CURRICULUM_CONFIGS)} curriculum configurations on {len(TASKS)} tasks:")
print("Tasks: " + ', '.join(TASKS))
for key, config in CURRICULUM_CONFIGS.items():
    print(f"  • {config['name']}: {config['description']}")

In [None]:
def build_training_command(curriculum_key: str, task: str = 'inverse') -> str:
    """Return the shell command that trains ``task`` under one configuration.

    The command combines the shared ``COMMON_ARGS`` values, a per-run CSV log
    directory (``./csv_logs_<task>_<curriculum_key>``) and the ``args`` list of
    the selected ``CURRICULUM_CONFIGS`` entry.
    """
    extra_args = CURRICULUM_CONFIGS[curriculum_key]['args']

    # Successive arguments are separated by this fixed run of spaces.
    separator = ' ' * 9

    pieces = [
        "python train.py",
        f"--dataset {task}",
        f"--model {COMMON_ARGS['model']}",
        f"--batch_size {COMMON_ARGS['batch_size']}",
        f"--diffusion_steps {COMMON_ARGS['diffusion_steps']}",
        f"--supervise-energy-landscape {COMMON_ARGS['supervise_energy_landscape']}",
        f"--train-num-steps {COMMON_ARGS['train_num_steps']}",
        "--save-csv-logs",
        f"--csv-log-interval {COMMON_ARGS['csv_log_interval']}",
        f"--csv-log-dir ./csv_logs_{task}_{curriculum_key}",
    ]

    # Curriculum-specific arguments go last
    if extra_args:
        pieces.extend(extra_args)

    return separator.join(pieces)

# Names used by the helper signatures in this cell and by the analysis cells
# further down; no earlier cell brings them into scope, and annotations are
# evaluated when the ``def`` statement runs.
from pathlib import Path
from typing import Dict, Optional

import numpy as np
import pandas as pd


def load_csv_data(csv_path: Path) -> Optional[pd.DataFrame]:
    """Read ``csv_path`` into a DataFrame, or return ``None`` if that fails.

    A missing file and any error raised while reading are reported with
    ``print`` instead of being raised.

    Args:
        csv_path: Location of the CSV file.

    Returns:
        The parsed DataFrame, or ``None`` when the file does not exist or
        cannot be read.
    """
    try:
        if not csv_path.exists():
            print(f"Warning: {csv_path} not found")
            return None
        return pd.read_csv(csv_path)
    except Exception as e:
        # Deliberately broad: one unreadable log must not abort the analysis.
        print(f"Error loading {csv_path}: {e}")
        return None

def safe_get_final_value(df: pd.DataFrame, column: str, default: float = 0.0) -> float:
    """Return the last entry of ``df[column]`` as a float.

    ``default`` is returned instead when ``df`` is ``None``, lacks ``column``,
    or has no rows.
    """
    has_value = df is not None and column in df.columns and len(df) > 0
    return float(df[column].iloc[-1]) if has_value else default

def safe_get_best_value(df: pd.DataFrame, column: str, minimize: bool = True, default: float = 0.0) -> float:
    """Return the smallest (or largest) entry of ``df[column]`` as a float.

    ``minimize`` selects the minimum when true and the maximum otherwise.
    ``default`` is returned when ``df`` is ``None``, lacks ``column``, or has
    no rows.
    """
    if df is not None and column in df.columns and len(df) > 0:
        series = df[column]
        best = series.min() if minimize else series.max()
        return float(best)
    return default

## Multi-Curriculum Training

In [None]:
# Train all tasks with all curriculum configurations
import subprocess
import sys

# training_results[task][curriculum_key] holds the outcome of that run:
# 0 on success, the process's non-zero exit code on failure, or -1 when
# launching or streaming the process raised an exception.
training_results = {}

# Train each task
for task in TASKS:
    print(f"\n{'#'*80}")
    print(f"# TASK: {task.upper()}")
    print(f"{'#'*80}")

    task_results = {}

    # For each curriculum configuration
    for curriculum_key, config in CURRICULUM_CONFIGS.items():
        print(f"\n{'='*60}")
        print(f"Starting {config['name']} training for {task} task...")
        print(f"Description: {config['description']}")
        print(f"{'='*60}")

        # Build training command
        cmd = build_training_command(curriculum_key, task)
        print(f"\nCommand: {cmd}")
        print("\nTraining output:")
        print("-" * 40)

        # Execute training with real-time output
        try:
            # The context manager closes the output pipe and waits for the
            # child even if streaming is interrupted part-way through.
            with subprocess.Popen(
                cmd,
                shell=True,  # cmd is one command string built from notebook constants
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,  # interleave stderr with stdout
                universal_newlines=True,
                bufsize=1  # line-buffered so output appears as it is produced
            ) as process:
                # Display output line by line as it comes
                for line in process.stdout:
                    print(line.rstrip())
                    sys.stdout.flush()

                # Wait for process to complete
                result = process.wait()
        except Exception as e:
            # Keep going: one failed launch must not stop the remaining runs.
            print(f"✗ Error during {config['name']} training for {task}: {e}")
            task_results[curriculum_key] = -1
        else:
            task_results[curriculum_key] = result

            if result == 0:
                print(f"✓ {config['name']} training for {task} completed successfully")
            else:
                print(f"✗ {config['name']} training for {task} failed with exit code {result}")

        print("-" * 40)

    training_results[task] = task_results

    # Print task summary
    successful = sum(1 for result in task_results.values() if result == 0)
    print(f"\n{task.upper()} Task Summary: {successful}/{len(task_results)} configurations completed successfully")

print(f"\n{'='*80}")
print("All training completed!")
print(f"{'='*80}")

# Print overall summary
print("\nOVERALL TRAINING SUMMARY:")
for task, task_results in training_results.items():
    successful = sum(1 for result in task_results.values() if result == 0)
    print(f"\n{task.upper()} Task: {successful}/{len(task_results)} configurations successful")
    for curriculum_key, result in task_results.items():
        status = "✓ Success" if result == 0 else "✗ Failed"
        print(f"  {CURRICULUM_CONFIGS[curriculum_key]['name']}: {status}")

In [None]:
def load_all_curriculum_results() -> Dict[str, Dict[str, Dict[str, pd.DataFrame]]]:
    """Collect the logged CSV metrics of every task/configuration pair.

    For each task in ``TASKS`` and each key of ``CURRICULUM_CONFIGS`` the
    directory ``./csv_logs_<task>_<key>`` is searched for five kinds of metric
    file. When several files match one kind, the most recently modified one
    is loaded.

    Returns:
        ``results[task][curriculum_key][kind]`` mapped to a DataFrame, or to
        ``None`` when no file of that kind was found or loading it failed.
    """
    import glob

    def _modified_time(file_name):
        # Files that vanished between globbing and stat-ing sort first.
        candidate = Path(file_name)
        return candidate.stat().st_mtime if candidate.exists() else 0

    # Metric kind -> file-name pattern (the wildcard covers a per-run suffix)
    file_patterns = {
        'training': 'training_metrics_*.csv',
        'validation': 'validation_metrics_*.csv',
        'energy': 'energy_metrics_*.csv',
        'curriculum': 'curriculum_metrics_*.csv',
        'robustness': 'robustness_metrics_*.csv'
    }

    results = {}
    for task in TASKS:
        print(f"\nLoading results for {task} task...")
        per_task = {}

        for curriculum_key in CURRICULUM_CONFIGS:
            print(f"  Loading {curriculum_key} data...")
            log_dir = Path(f"./csv_logs_{task}_{curriculum_key}")

            frames = {}
            for kind, pattern in file_patterns.items():
                matches = glob.glob(str(log_dir / pattern))
                if not matches:
                    frames[kind] = None
                    continue
                newest = max(matches, key=_modified_time)
                frames[kind] = load_csv_data(Path(newest))

            # Report how many of the metric kinds were actually loaded
            found = sum(frame is not None for frame in frames.values())
            print(f"    Found {found}/5 data files for {curriculum_key}")

            per_task[curriculum_key] = frames

        results[task] = per_task

    return results

def _validation_metric_values(validation_df, metric_name: str):
    """Return the ``metric_value`` entries logged for ``metric_name``.

    Returns ``None`` when there is no validation frame, when it lacks the
    ``metric_name``/``metric_value`` columns, or when no row matches.
    """
    if validation_df is None:
        return None
    if 'metric_name' not in validation_df.columns or 'metric_value' not in validation_df.columns:
        return None
    values = validation_df.loc[validation_df['metric_name'] == metric_name, 'metric_value']
    return None if values.empty else values


def _last_validation_metric(validation_df, metric_name: str, default: float) -> float:
    """Return the last logged value of ``metric_name``, or ``default`` if none."""
    values = _validation_metric_values(validation_df, metric_name)
    return default if values is None else values.iloc[-1]


def process_curriculum_data(all_results: Dict[str, Dict[str, Dict[str, pd.DataFrame]]]) -> pd.DataFrame:
    """Extract and standardize key metrics from all curriculum results across all tasks.

    Args:
        all_results: ``all_results[task][curriculum_key][kind]`` mapped to a
            DataFrame or ``None``; the ``'training'``, ``'validation'`` and
            ``'energy'`` kinds are read.

    Returns:
        One row per (task, configuration) pair. Missing validation error
        metrics are stored as ``inf``; missing accuracy, training and energy
        metrics are stored as ``0.0``.
    """
    processed_data = []

    for task, task_results in all_results.items():
        for curriculum_key, data in task_results.items():
            config = CURRICULUM_CONFIGS[curriculum_key]

            training_df = data.get('training')
            validation_df = data.get('validation')
            energy_df = data.get('energy')

            # Accuracy is logged as a fraction; keep both the last and the
            # highest value seen so "best" really is the best.
            accuracy_values = _validation_metric_values(validation_df, 'accuracy')
            if accuracy_values is None:
                final_accuracy = best_accuracy = 0.0
            else:
                final_accuracy = accuracy_values.iloc[-1]
                best_accuracy = accuracy_values.max()

            # Error metrics default to inf so "no data" never looks like a
            # perfect score.
            val_identity_error = _last_validation_metric(validation_df, 'identity_error', float('inf'))
            val_mse = _last_validation_metric(validation_df, 'mse', float('inf'))
            val_mae = _last_validation_metric(validation_df, 'mae', float('inf'))

            if training_df is not None and 'nn_time' in training_df.columns:
                avg_training_time = training_df['nn_time'].mean()
            else:
                avg_training_time = 0.0

            metrics = {
                'task': task,
                'curriculum': curriculum_key,
                'name': config['name'],
                'color': config['color'],

                # Training metrics
                # NOTE(review): these fall back to 0.0 when the training log is
                # missing, which reads as a perfect loss downstream.
                'final_total_loss': safe_get_final_value(training_df, 'total_loss'),
                'final_energy_loss': safe_get_final_value(training_df, 'loss_energy'),
                'final_denoise_loss': safe_get_final_value(training_df, 'loss_denoise'),
                'best_total_loss': safe_get_best_value(training_df, 'total_loss', minimize=True),
                'avg_training_time': avg_training_time,

                # Validation metrics (accuracy converted from fraction to percent)
                'final_val_accuracy': final_accuracy * 100,
                'best_val_accuracy': best_accuracy * 100,
                'final_identity_error': val_identity_error,
                'final_val_mse': val_mse,
                'final_val_mae': val_mae,

                # Energy metrics (the helper already returns 0.0 when the
                # frame or column is missing)
                'final_energy_margin': safe_get_final_value(energy_df, 'energy_margin'),
                'max_curriculum_weight': safe_get_best_value(energy_df, 'curriculum_weight', minimize=False),
            }

            processed_data.append(metrics)

    return pd.DataFrame(processed_data)

# Read every CSV log written by the training cell
print("Loading all curriculum results for all tasks...")
all_results = load_all_curriculum_results()

# Reduce the raw logs to one summary row per (task, configuration)
print("\nProcessing curriculum data...")
summary_df = process_curriculum_data(all_results)

print(f"\nLoaded data for {len(summary_df)} configurations across {len(TASKS)} tasks")

# Show a rounded per-task summary with the metric that matters for that task
for task in TASKS:
    display_cols = ['name', 'final_total_loss', 'final_val_mse']
    if task == 'inverse':
        display_cols += ['final_identity_error']
    elif task in ('addition', 'lowrank'):
        display_cols += ['final_val_mae']

    print(f"\n{task.upper()} Task Summary:")
    task_data = summary_df.loc[summary_df['task'] == task]

    display(task_data[display_cols].round(4))

## IRED Paper-Style Results Reporting

This section generates results tables in the format used by the IRED paper, focusing on:
- **Mean Squared Error (MSE)** as the primary metric for all matrix tasks
- **Identity Error** for matrix inverse task (||Pred @ Input - I||²)
- **Task-specific metrics** for comprehensive evaluation
- **Consolidated performance tables** across all tasks

In [None]:
## IRED-Style Results Table Generation

def _format_ired_metric(value: float) -> str:
    """Format ``value`` to four decimals, or ``'N/A'`` when it is inf or NaN."""
    import math

    return f"{value:.4f}" if math.isfinite(value) else "N/A"


def generate_ired_style_table(summary_df: pd.DataFrame, task: str):
    """Generate IRED paper-style results table for a specific task.

    Prints a grid with one row per method, sorted by ``final_val_mse``
    (lowest first), followed by the best method. The inverse task gets an
    extra identity-error column. Each metric is formatted on its own, so one
    missing value no longer changes how the other is printed.

    Args:
        summary_df: Summary frame with ``task``, ``name``, ``final_val_mse``
            and ``final_identity_error`` columns.
        task: Task whose rows are reported.

    Returns:
        ``None``; the table is written to standard output.
    """
    import math

    from tabulate import tabulate

    task_data = summary_df[summary_df['task'] == task]

    if len(task_data) == 0:
        print(f"No data available for {task} task")
        return

    # Sort by final MSE to rank methods; missing values (inf/NaN) sort last
    task_data = task_data.sort_values('final_val_mse')

    is_inverse = task == 'inverse'
    headers = ['Method', 'MSE']
    if is_inverse:
        headers.append('Identity Error')

    table_data = []
    for _, row in task_data.iterrows():
        cells = [row['name'], _format_ired_metric(row['final_val_mse'])]
        if is_inverse:
            cells.append(_format_ired_metric(row['final_identity_error']))
        table_data.append(cells)

    print(f"\n{'='*60}")
    print(f"IRED-Style Results Table: {task.upper()} Task")
    print(f"{'='*60}")
    print(tabulate(table_data, headers=headers, tablefmt='grid'))

    # Only name a winner when the top row actually has a measured MSE
    best_method = task_data.iloc[0]
    if not math.isfinite(best_method['final_val_mse']):
        return

    print(f"\n✓ Best performing method: {best_method['name']}")
    print(f"  MSE: {best_method['final_val_mse']:.4f}")
    if is_inverse and math.isfinite(best_method['final_identity_error']):
        print(f"  Identity Error: {best_method['final_identity_error']:.4f}")

def generate_consolidated_ired_table(summary_df: pd.DataFrame):
    """Generate consolidated IRED-style table for all matrix tasks.

    Prints one row per method with its final validation MSE on the addition,
    lowrank and inverse tasks plus their average, sorted by that average
    (lowest first). A task with no row shows ``-``; a non-finite MSE shows
    ``N/A``. The average covers only the tasks with a finite MSE.

    Args:
        summary_df: Summary frame with ``name``, ``task`` and
            ``final_val_mse`` columns.

    Returns:
        ``None``; the table is written to standard output.
    """
    import math

    from tabulate import tabulate
    import numpy as np

    print("\n" + "="*80)
    print("IRED-STYLE CONSOLIDATED RESULTS TABLE")
    print("Matrix Operations Performance (MSE)")
    print("="*80)

    # Column order must match the headers below
    task_columns = ['addition', 'lowrank', 'inverse']

    ranked_rows = []  # (numeric average used for sorting, printable row)
    for method in summary_df['name'].unique():
        method_data = summary_df[summary_df['name'] == method]

        row = [method]
        finite_mse = []
        for task in task_columns:
            task_row = method_data[method_data['task'] == task]
            if len(task_row) == 0:
                row.append("-")
                continue

            mse = task_row['final_val_mse'].iloc[0]
            if math.isfinite(mse):
                row.append(f"{mse:.4f}")
                finite_mse.append(mse)
            else:
                row.append("N/A")

        if finite_mse:
            avg_mse = float(np.mean(finite_mse))
            row.append(f"{avg_mse:.4f}")
        else:
            # Methods without any measured MSE rank last
            avg_mse = math.inf
            row.append("N/A")

        ranked_rows.append((avg_mse, row))

    # Sort on the unrounded average rather than the 4-decimal display string
    ranked_rows.sort(key=lambda item: item[0])
    table_data = [row for _, row in ranked_rows]

    headers = ['Method', 'Addition', 'Matrix Completion', 'Matrix Inverse', 'Average']
    print(tabulate(table_data, headers=headers, tablefmt='grid'))

    # Best method summary, only when the leader has a real average
    if ranked_rows and math.isfinite(ranked_rows[0][0]):
        print(f"\n✓ Best overall method: {table_data[0][0]}")
        print(f"  Average MSE: {table_data[0][-1]}")

def generate_task_specific_metrics_table(summary_df: pd.DataFrame):
    """Print a detailed per-method metrics grid for every task in ``TASKS``.

    Rows are ordered by ``final_val_mse``. The inverse task additionally shows
    identity error and accuracy. Below each grid the range, mean and standard
    deviation of the non-``inf`` MSE values are printed, when there are any.
    Tasks without rows in ``summary_df`` are skipped.
    """
    from tabulate import tabulate

    missing = float('inf')  # placeholder stored when a metric was not logged

    def _four_decimals_or_na(value):
        return "N/A" if value == missing else f"{value:.4f}"

    for task in TASKS:
        rows_for_task = summary_df[summary_df['task'] == task].copy()
        if len(rows_for_task) == 0:
            continue

        print(f"\n{'='*70}")
        print(f"Detailed Metrics Table: {task.upper()} Task")
        print(f"{'='*70}")

        is_inverse = task == 'inverse'
        ordered = rows_for_task.sort_values('final_val_mse')

        # Columns: method, MSE, [identity error, accuracy], final loss, best loss
        headers = ['Method', 'MSE']
        if is_inverse:
            headers += ['Identity Error', 'Accuracy']
        headers += ['Final Loss', 'Best Loss']

        table_data = []
        for _, record in ordered.iterrows():
            cells = [record['name'], _four_decimals_or_na(record['final_val_mse'])]

            if is_inverse:
                accuracy = record['final_val_accuracy']
                cells.append(_four_decimals_or_na(record['final_identity_error']))
                cells.append(f"{accuracy:.1f}%" if accuracy > 0 else "N/A")

            cells.append(f"{record['final_total_loss']:.4f}")
            cells.append(f"{record['best_total_loss']:.4f}")
            table_data.append(cells)

        print(tabulate(table_data, headers=headers, tablefmt='grid'))

        # Summary statistics over the methods that reported an MSE
        measured = ordered.loc[ordered['final_val_mse'] != missing, 'final_val_mse']
        if len(measured) > 0:
            print("\nStatistics:")
            print(f"  MSE Range: {measured.min():.4f} - {measured.max():.4f}")
            print(f"  MSE Mean: {measured.mean():.4f}")
            print(f"  MSE Std: {measured.std():.4f}")

# Generate all IRED-style tables
print(f"\n{'=' * 80}")
print("IRED PAPER-STYLE RESULTS REPORTING")
print('=' * 80)

# One table per task, then the cross-task summary, then the detailed breakdown
for task in TASKS:
    generate_ired_style_table(summary_df, task)
generate_consolidated_ired_table(summary_df)
generate_task_specific_metrics_table(summary_df)

In [None]:
def _plot_training_curves(ax, task_results, column, title, ylabel):
    """Plot ``column`` against ``step`` for every configuration that logged it."""
    for curriculum_key, data in task_results.items():
        training_df = data.get('training')
        if training_df is None or column not in training_df.columns:
            continue
        config = CURRICULUM_CONFIGS[curriculum_key]
        ax.plot(training_df['step'], training_df[column],
                color=config['color'], label=config['name'], linewidth=2, alpha=0.8)

    ax.set_title(title, fontweight='bold')
    ax.set_xlabel('Training Step')
    ax.set_ylabel(ylabel)
    if ax.lines:
        # legend() on an axes without labelled lines only emits a warning
        ax.legend()
    ax.grid(True, alpha=0.3)


def _plot_metric_bars(ax, names, values, colors, title, ylabel):
    """Draw one labelled bar per method whose value is not ``inf``.

    Returns ``True`` when at least one bar was drawn, ``False`` when every
    value was the ``inf`` "not logged" placeholder (the axes is then left
    untouched).
    """
    import matplotlib.pyplot as plt

    kept = [(name, value, color) for name, value, color in zip(names, values, colors)
            if value != float('inf')]
    if not kept:
        return False

    kept_names, kept_values, kept_colors = zip(*kept)
    bars = ax.bar(kept_names, kept_values, color=kept_colors, alpha=0.7, edgecolor='black')
    ax.set_title(title, fontweight='bold')
    ax.set_ylabel(ylabel)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    # Add value labels just above each bar
    for bar, value in zip(bars, kept_values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.001,
                f'{value:.4f}', ha='center', va='bottom', fontweight='bold')
    return True


def visualize_task_curriculum_comparison(all_results: Dict, summary_df: pd.DataFrame, task: str):
    """Create comprehensive side-by-side curriculum comparison visualizations for a specific task.

    Draws a 2x2 figure: total-loss curves, energy-loss curves, a task-specific
    bar chart (identity error for ``'inverse'``, MAE otherwise) and an MSE bar
    chart, then shows it.

    Args:
        all_results: ``all_results[task][curriculum_key]['training']`` is the
            training log used for the curves.
        summary_df: Summary frame providing the bar-chart values and colours.
        task: Task to visualize.
    """
    # Imported here because no cell of the notebook imports pyplot globally.
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(f'Multi-Curriculum Training Comparison - {task.upper()} Task', fontsize=16, fontweight='bold')

    task_results = all_results.get(task, {})
    task_summary = summary_df[summary_df['task'] == task]

    # 1. Total Loss Comparison
    _plot_training_curves(axes[0, 0], task_results, 'total_loss', 'Total Loss Curves', 'Total Loss')

    # 2. Energy Loss Comparison
    _plot_training_curves(axes[0, 1], task_results, 'loss_energy', 'Energy Loss Curves', 'Energy Loss')

    # Shared inputs for the bar charts
    colors = task_summary['color'].tolist()
    curricula = task_summary['name'].tolist()

    # 3. Task-specific metric comparison
    if task == 'inverse':
        _plot_metric_bars(axes[1, 0], curricula, task_summary['final_identity_error'].tolist(), colors,
                          'Identity Error (Lower = Better)', '||Pred @ Input - I||²')
    else:
        _plot_metric_bars(axes[1, 0], curricula, task_summary['final_val_mae'].tolist(), colors,
                          'Mean Absolute Error (Lower = Better)', 'MAE')

    # 4. MSE Comparison (common metric)
    ax = axes[1, 1]
    drew_mse = _plot_metric_bars(ax, curricula, task_summary['final_val_mse'].tolist(), colors,
                                 'Mean Squared Error', 'MSE')
    if not drew_mse:
        ax.text(0.5, 0.5, 'No validation data available\n(Run longer training for validation metrics)',
                ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_title('MSE (N/A)', fontweight='bold')

    plt.tight_layout()
    plt.show()

def visualize_cross_task_comparison(summary_df: pd.DataFrame):
    """Compare curriculum performance across different tasks.

    Draws one grouped bar chart per task (final total loss next to final
    validation MSE for every method) on a three-column grid and shows it.
    ``inf`` MSE placeholders are replaced by NaN so no bar is drawn for
    them, matching how the heatmap treats missing MSE values.

    Args:
        summary_df: Summary frame with ``task``, ``name``,
            ``final_total_loss`` and ``final_val_mse`` columns.
    """
    import math

    # Imported here because no cell of the notebook imports these globally.
    import matplotlib.pyplot as plt
    import numpy as np

    tasks = summary_df['task'].unique()

    # At least two rows (the original layout); more when there are > 6 tasks
    n_cols = 3
    n_rows = max(2, math.ceil(len(tasks) / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows))
    fig.suptitle('Cross-Task Curriculum Performance Comparison', fontsize=16, fontweight='bold')

    width = 0.35  # bar width; the two series sit half a width either side of x

    # Plot for each task
    for idx, task in enumerate(tasks):
        ax = axes[divmod(idx, n_cols)]

        task_data = summary_df[summary_df['task'] == task]
        mse_values = task_data['final_val_mse'].replace([float('inf')], np.nan)

        # Create grouped bar chart for final losses
        x = np.arange(len(task_data))
        ax.bar(x - width/2, task_data['final_total_loss'], width,
               label='Total Loss', alpha=0.7)
        ax.bar(x + width/2, mse_values, width,
               label='MSE', alpha=0.7)

        ax.set_title(f'{task.upper()} Task Performance', fontweight='bold')
        ax.set_ylabel('Loss/Error')
        ax.set_xticks(x)
        ax.set_xticklabels(task_data['name'], rotation=45, ha='right')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Hide unused subplots if any
    for idx in range(len(tasks), n_rows * n_cols):
        fig.delaxes(axes[divmod(idx, n_cols)])

    plt.tight_layout()
    plt.show()

def create_task_performance_heatmap(summary_df: pd.DataFrame):
    """Create a heatmap showing curriculum performance across tasks.

    Shows two annotated heatmaps side by side, methods as rows and tasks as
    columns: final total loss on the left, final validation MSE on the right.

    Args:
        summary_df: Summary frame with ``name``, ``task``,
            ``final_total_loss`` and ``final_val_mse`` columns.
    """
    # Imported here because no cell of the notebook imports these globally.
    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns

    # Pivot data for heatmap
    pivot_loss = summary_df.pivot(index='name', columns='task', values='final_total_loss')
    pivot_mse = summary_df.pivot(index='name', columns='task', values='final_val_mse')

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Total Loss heatmap
    sns.heatmap(pivot_loss, annot=True, fmt='.4f', cmap='RdYlGn_r', ax=ax1, cbar_kws={'label': 'Total Loss'})
    ax1.set_title('Total Loss by Curriculum and Task', fontweight='bold')
    ax1.set_xlabel('Task')
    ax1.set_ylabel('Curriculum')

    # MSE heatmap
    # Replace inf values with NaN for better visualization
    pivot_mse_clean = pivot_mse.replace([float('inf')], np.nan)
    sns.heatmap(pivot_mse_clean, annot=True, fmt='.4f', cmap='RdYlGn_r', ax=ax2, cbar_kws={'label': 'MSE'})
    ax2.set_title('MSE by Curriculum and Task', fontweight='bold')
    ax2.set_xlabel('Task')
    ax2.set_ylabel('Curriculum')

    plt.tight_layout()
    plt.show()

# Per-task figures first
print("Generating task-specific curriculum comparison visualizations...")
for task in TASKS:
    print(f"\nVisualizing {task.upper()} task...")
    visualize_task_curriculum_comparison(all_results, summary_df, task)

# Then the two figures that span all tasks
for banner, render in (
    ("\nGenerating cross-task comparison...", visualize_cross_task_comparison),
    ("\nGenerating performance heatmap...", create_task_performance_heatmap),
):
    print(banner)
    render(summary_df)

In [None]:
def generate_task_ranking(summary_df: pd.DataFrame):
    """Print per-task and cross-task curriculum rankings; return the overall table.

    For every task in ``TASKS`` each curriculum collects points on up to four
    criteria: final total loss, validation MSE, a task-specific error
    (identity error for ``inverse``, MAE for ``addition``/``lowrank``) and
    average training time. A criterion awards ``(n - rank + 1) * 25`` points,
    where ``n`` is the number of curricula with a usable value for it and
    lower values rank better. Scores are then rescaled so the best curriculum
    of the task gets 100.

    The overall ranking orders curricula by their mean final total loss
    across all tasks present in ``summary_df``.

    Args:
        summary_df: One row per (task, curriculum) run. Columns read here:
            ``task``, ``curriculum``, ``name``, ``final_total_loss``,
            ``final_val_mse``, ``avg_training_time`` and, depending on the
            task, ``final_identity_error`` or ``final_val_mae``.

    Returns:
        DataFrame with columns ``curriculum``, ``name``, ``avg_loss``,
        ``avg_mse``, ``avg_time``, ``num_tasks`` and ``rank``, sorted by
        ``rank``. It is empty (same columns) when no configured curriculum
        has any row in ``summary_df``.
    """

    def _award_points(frame, column, skip_inf):
        """Add rank-based points for ``column`` to ``frame['score']`` in place."""
        # NaN values are never ranked; otherwise a single missing metric would
        # turn the row's whole score into NaN.
        usable = frame[column].notna()
        if skip_inf:
            # inf is the "metric not available" sentinel for validation metrics.
            usable &= frame[column] != float('inf')
        if not usable.any():
            return
        ranks = frame.loc[usable, column].rank()
        frame.loc[usable, 'score'] += (len(ranks) - ranks + 1) * 25

    print("\n" + "="*80)
    print("MULTI-TASK CURRICULUM RANKING ANALYSIS")
    print("="*80)

    # Rank curricula for each task
    for task in TASKS:
        task_data = summary_df[summary_df['task'] == task].copy()

        print(f"\n📊 {task.upper()} TASK RANKING")
        print("-" * 50)

        if task_data.empty:
            # Nothing to rank; selecting a winner below would raise IndexError.
            print("No results available for this task")
            continue

        # Float from the start so fractional (tied) rank points fit the column.
        task_data['score'] = 0.0

        _award_points(task_data, 'final_total_loss', skip_inf=False)
        _award_points(task_data, 'final_val_mse', skip_inf=True)

        # Task-specific scoring
        if task == 'inverse':
            _award_points(task_data, 'final_identity_error', skip_inf=True)
        elif task in ['addition', 'lowrank']:
            _award_points(task_data, 'final_val_mae', skip_inf=True)

        # Training efficiency: shorter average time ranks better.
        _award_points(task_data, 'avg_training_time', skip_inf=False)

        # Normalize scores to 0-100
        max_score = task_data['score'].max()
        if max_score > 0:
            task_data['normalized_score'] = (task_data['score'] / max_score) * 100
        else:
            task_data['normalized_score'] = 0.0

        # Sort by score
        task_data = task_data.sort_values('normalized_score', ascending=False)

        # Display ranking
        rank_table = task_data[['name', 'normalized_score', 'final_total_loss', 'final_val_mse']].copy()
        rank_table['rank'] = range(1, len(rank_table) + 1)
        rank_table = rank_table[['rank', 'name', 'normalized_score', 'final_total_loss', 'final_val_mse']]
        rank_table.columns = ['Rank', 'Curriculum', 'Score', 'Final Loss', 'MSE']

        print(rank_table.round(2).to_string(index=False))

        # Winner for this task
        winner = task_data.iloc[0]
        print(f"\n🏆 Winner for {task.upper()}: {winner['name']} (Score: {winner['normalized_score']:.1f})")

        # Check if curriculum beats baseline
        baseline_score = task_data[task_data['curriculum'] == 'baseline']['normalized_score'].values
        if len(baseline_score) > 0 and winner['curriculum'] != 'baseline':
            improvement = winner['normalized_score'] - baseline_score[0]
            print(f"   Improvement over baseline: +{improvement:.1f} points")

    # Overall ranking across all tasks
    print("\n" + "="*60)
    print("📈 OVERALL CROSS-TASK RANKING")
    print("="*60)

    # Calculate average metrics across tasks for each curriculum
    overall_scores = []
    for curriculum_key, curriculum_cfg in CURRICULUM_CONFIGS.items():
        curriculum_data = summary_df[summary_df['curriculum'] == curriculum_key]
        if curriculum_data.empty:
            continue

        finite_mse = curriculum_data.loc[
            curriculum_data['final_val_mse'] != float('inf'), 'final_val_mse'
        ]
        avg_mse = finite_mse.mean()

        overall_scores.append({
            'curriculum': curriculum_key,
            'name': curriculum_cfg['name'],
            'avg_loss': curriculum_data['final_total_loss'].mean(),
            # Keep the inf sentinel when no task produced a finite MSE.
            'avg_mse': avg_mse if not np.isnan(avg_mse) else float('inf'),
            'avg_time': curriculum_data['avg_training_time'].mean(),
            'num_tasks': len(curriculum_data)
        })

    if not overall_scores:
        # Without this guard the column lookups below raise KeyError.
        print("\nNo curriculum results available for overall ranking")
        return pd.DataFrame(columns=['curriculum', 'name', 'avg_loss', 'avg_mse',
                                     'avg_time', 'num_tasks', 'rank'])

    overall_df = pd.DataFrame(overall_scores)

    # Rank based on average loss. method='min' keeps tied ranks integral, and
    # na_option='bottom' places NaN losses last instead of breaking astype(int).
    overall_df['rank'] = overall_df['avg_loss'].rank(method='min', na_option='bottom').astype(int)
    overall_df = overall_df.sort_values('rank')

    print("\nOverall Performance Across All Tasks:")
    display_df = overall_df[['rank', 'name', 'avg_loss', 'avg_mse', 'avg_time']].copy()
    display_df.columns = ['Rank', 'Curriculum', 'Avg Loss', 'Avg MSE', 'Avg Time']
    print(display_df.round(4).to_string(index=False))

    print("\n🎯 KEY FINDINGS:")
    overall_winner = overall_df.iloc[0]
    print(f"   • Best Overall: {overall_winner['name']}")
    print(f"   • Average Loss: {overall_winner['avg_loss']:.4f}")

    # Check curriculum vs baseline
    baseline_overall = overall_df[overall_df['curriculum'] == 'baseline']
    if len(baseline_overall) > 0:
        baseline_loss = baseline_overall['avg_loss'].iloc[0]
        if overall_winner['curriculum'] != 'baseline':
            # A zero baseline loss would make the relative improvement undefined.
            if baseline_loss != 0:
                improvement = (baseline_loss - overall_winner['avg_loss']) / baseline_loss * 100
                print(f"   • Improvement over baseline: {improvement:.1f}%")
        else:
            print("   • Baseline performed best overall")

    return overall_df

# Run the ranking analysis. The returned overall table is kept in
# `overall_ranking` because the export cell writes it to CSV and reads its
# first row as the best overall curriculum.
print("Generating multi-task ranking analysis...")
overall_ranking = generate_task_ranking(summary_df)

## Multi-Task Summary Analysis and Ranking

In [None]:
# Save comprehensive results
# Export cell: writes the summary tables to CSV and prints the final report.
# Depends on `summary_df`, `overall_ranking` and `training_results` created
# by earlier cells.
output_dir = Path('./multi_task_curriculum_results')
output_dir.mkdir(exist_ok=True)

# Export summary data
summary_df.to_csv(output_dir / 'multi_task_curriculum_summary.csv', index=False)
overall_ranking.to_csv(output_dir / 'overall_curriculum_ranking.csv', index=False)

# Create task-specific summaries
for task in TASKS:
    task_data = summary_df[summary_df['task'] == task]
    task_data.to_csv(output_dir / f'{task}_curriculum_summary.csv', index=False)

print("\n" + "="*80)
print("MULTI-TASK CURRICULUM TRAINING - FINAL REPORT")
print("="*80)

print(f"\n📁 Results saved to: {output_dir}")
print("   • multi_task_curriculum_summary.csv - All results across tasks")
print("   • overall_curriculum_ranking.csv - Overall curriculum rankings")
for task in TASKS:
    print(f"   • {task}_curriculum_summary.csv - {task.capitalize()} task specific results")

# Training statistics
# A run counts as successful when its recorded result equals 0.
# NOTE(review): `training_results` is presumably a two-level dict of process
# exit codes (outer key -> inner key -> code); confirm against the training cell.
total_configs = len(CURRICULUM_CONFIGS) * len(TASKS)
successful_runs = sum(1 for _, task_results in training_results.items() 
                     for result in task_results.values() if result == 0)

print(f"\n📊 TRAINING STATISTICS:")
print(f"   • Total configurations tested: {total_configs}")
print(f"   • Successful runs: {successful_runs}/{total_configs}")
print(f"   • Tasks evaluated: {', '.join(TASKS)}")
print(f"   • Curricula tested: {', '.join([cfg['name'] for cfg in CURRICULUM_CONFIGS.values()])}")

# The first row of `overall_ranking` is reported as the best overall curriculum.
print("\n🏆 EXECUTIVE SUMMARY:")
print(f"   • Best Overall Curriculum: {overall_ranking.iloc[0]['name']}")
print(f"   • Training steps per run: {COMMON_ARGS['train_num_steps']}")
print(f"   • CSV logging interval: {COMMON_ARGS['csv_log_interval']} steps")

# Task-specific winners
# Winners here are chosen by lowest final total loss only, which is a different
# criterion from the composite score printed by generate_task_ranking.
print("\n🎯 TASK-SPECIFIC WINNERS:")
for task in TASKS:
    task_data = summary_df[summary_df['task'] == task]
    if len(task_data) > 0:
        # Find best by loss
        best_idx = task_data['final_total_loss'].idxmin()
        winner = task_data.loc[best_idx]
        print(f"   • {task.upper()}: {winner['name']} (Loss: {winner['final_total_loss']:.4f})")

print("\n" + "="*80)
print("🚀 MULTI-TASK MULTI-CURRICULUM SMOKE TEST COMPLETED! 🚀")
print("="*80)
print("\nUse the generated CSV files and visualizations to:")
print("1. Select the best curriculum approach for each task")
print("2. Compare baseline vs curriculum learning effectiveness")
print("3. Understand trade-offs between different curriculum strategies")
print("4. Make informed decisions for production training")

## Results Export and Final Summary

In [None]:
def visualize_curriculum_comparison(all_results: Dict, summary_df: pd.DataFrame):
    """Show a 2x2 figure comparing curricula on loss curves and final errors.

    The top row plots total loss and energy loss against training step for
    every entry of ``all_results`` that has a ``'training'`` table with the
    relevant column. The bottom row shows bar charts of the final identity
    error and final validation MSE from ``summary_df``, leaving out entries
    equal to ``inf``.

    Args:
        all_results: Mapping from a ``CURRICULUM_CONFIGS`` key to a dict whose
            ``'training'`` entry (if any) is a DataFrame with a ``step`` column.
        summary_df: Table with ``color``, ``name``, ``final_identity_error``
            and ``final_val_mse`` columns.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Multi-Curriculum Training Comparison', fontsize=16, fontweight='bold')

    def _draw_training_curves(panel, column, title, y_label):
        # One line per curriculum whose training log contains `column`.
        for key, run in all_results.items():
            history = run.get('training')
            if history is None or column not in history.columns:
                continue
            style = CURRICULUM_CONFIGS[key]
            panel.plot(history['step'], history[column],
                       color=style['color'], label=style['name'], linewidth=2, alpha=0.8)

        panel.set_title(title, fontweight='bold')
        panel.set_xlabel('Training Step')
        panel.set_ylabel(y_label)
        panel.legend()
        panel.grid(True, alpha=0.3)

    def _draw_final_bars(panel, labels, values, palette, title, y_label, missing_title):
        # Entries equal to inf mark "no validation value" and are not drawn.
        usable = [(label, value, shade)
                  for label, value, shade in zip(labels, values, palette)
                  if value != float('inf')]

        if not usable:
            panel.text(0.5, 0.5, 'No validation data available\n(Run longer training for validation metrics)',
                       ha='center', va='center', transform=panel.transAxes, fontsize=12)
            panel.set_title(missing_title, fontweight='bold')
            return

        shown_labels, heights, shades = zip(*usable)
        rectangles = panel.bar(shown_labels, heights, color=shades, alpha=0.7, edgecolor='black')
        panel.set_title(title, fontweight='bold')
        panel.set_ylabel(y_label)
        plt.setp(panel.get_xticklabels(), rotation=45, ha='right')

        # Print each value just above its bar.
        for rectangle, height in zip(rectangles, heights):
            panel.text(rectangle.get_x() + rectangle.get_width()/2, rectangle.get_height() + 0.001,
                       f'{height:.4f}', ha='center', va='bottom', fontweight='bold')

    # Top row: training curves
    _draw_training_curves(axes[0, 0], 'total_loss', 'Total Loss Curves', 'Total Loss')
    _draw_training_curves(axes[0, 1], 'loss_energy', 'Energy Loss Curves', 'Energy Loss')

    # Bottom row: final validation metrics as bars
    bar_palette = summary_df['color'].tolist()
    identity_values = summary_df['final_identity_error'].tolist()
    curriculum_labels = summary_df['name'].tolist()

    _draw_final_bars(axes[1, 0], curriculum_labels, identity_values, bar_palette,
                     'Identity Error (Lower = Better)', '||Pred @ Input - I||²',
                     'Identity Error (N/A)')

    mse_values = summary_df['final_val_mse'].tolist()
    _draw_final_bars(axes[1, 1], curriculum_labels, mse_values, bar_palette,
                     'Mean Squared Error', 'MSE', 'MSE (N/A)')

    plt.tight_layout()
    plt.show()

def visualize_inverse_task_metrics(all_results: Dict, summary_df: pd.DataFrame):
    """Plot validation identity error and MSE over time plus a final-metric summary.

    The top row draws, per curriculum, the ``identity_error`` and ``mse``
    series found in each run's ``'validation'`` table (long format with
    ``metric_name``, ``metric_value`` and ``step`` columns). The bottom-left
    panel shows grouped bars of the final identity error and MSE from
    ``summary_df``, rescaled to ``1 - value / max`` so that higher is better;
    ``inf`` entries are drawn as 0. The fourth subplot is removed.

    Args:
        all_results: Mapping from a ``CURRICULUM_CONFIGS`` key to a dict that
            may hold a ``'validation'`` DataFrame.
        summary_df: Table with ``name``, ``final_identity_error`` and
            ``final_val_mse`` columns.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Validation Metrics (Without Accuracy)', fontsize=16, fontweight='bold')

    def _plot_metric_history(panel, metric_key, empty_notice, title, y_label):
        # Draw one marker line per curriculum that logged `metric_key`.
        anything_drawn = False
        for key, run in all_results.items():
            history = run.get('validation')
            if history is None:
                continue
            style = CURRICULUM_CONFIGS[key]

            rows = history[history['metric_name'] == metric_key]
            if rows.empty:
                continue
            panel.plot(rows['step'], rows['metric_value'],
                       color=style['color'], label=style['name'],
                       linewidth=2, alpha=0.8, marker='o')
            anything_drawn = True

        if not anything_drawn:
            panel.text(0.5, 0.5, empty_notice,
                       ha='center', va='center', transform=panel.transAxes, fontsize=12)

        panel.set_title(title, fontweight='bold')
        panel.set_xlabel('Training Step')
        panel.set_ylabel(y_label)
        if anything_drawn:
            panel.legend()
        panel.grid(True, alpha=0.3)

    # 1. Identity error over time
    _plot_metric_history(
        axes[0, 0], 'identity_error',
        'No identity error data available\n(Validation runs every 50 steps)',
        'Identity Error Evolution (Lower = Better)', '||Pred @ Input - I||²')

    # 2. MSE over time
    _plot_metric_history(
        axes[0, 1], 'mse',
        'No MSE data available\n(Validation runs every 50 steps)',
        'Mean Squared Error Evolution', 'MSE')

    # 3. Grouped bars of the final metrics
    panel = axes[1, 0]
    positions = np.arange(len(summary_df))
    bar_width = 0.35
    final_metrics = (('final_identity_error', 'Identity Error'),
                     ('final_val_mse', 'MSE'))

    any_metric_drawn = False
    for offset, (column, legend_label) in enumerate(final_metrics):
        raw = summary_df[column].values
        finite_part = raw[raw != float('inf')]
        if len(finite_part) == 0:
            continue
        any_metric_drawn = True

        # Errors are inverted so that a taller bar means a better result.
        ceiling = finite_part.max()
        scaled = 1 - (raw / (ceiling + 1e-8))
        scaled[raw == float('inf')] = 0

        panel.bar(positions + offset*bar_width, scaled, bar_width, label=legend_label, alpha=0.7)

    if any_metric_drawn:
        panel.set_title('Normalized Performance Comparison (Higher = Better)', fontweight='bold')
        panel.set_xlabel('Curriculum')
        panel.set_ylabel('Normalized Score (0-1)')
        panel.set_xticks(positions + bar_width / 2)
        panel.set_xticklabels(summary_df['name'], rotation=45, ha='right')
        panel.legend()
    else:
        panel.text(0.5, 0.5, 'No validation data available\n(Run longer training for metrics)',
                   ha='center', va='center', transform=panel.transAxes, fontsize=12)
        panel.set_title('Performance Comparison (N/A)', fontweight='bold')

    panel.grid(True, alpha=0.3)

    # 4. The fourth subplot is not used
    fig.delaxes(axes[1, 1])

    plt.tight_layout()
    plt.show()

def visualize_curriculum_convergence(all_results: Dict, summary_df: pd.DataFrame):
    """Show a 2x2 figure describing how each curriculum's total loss converges.

    Panels: smoothed loss curve, loss reduction per step, variance of the
    loss over the second half of training, and the first step at which the
    loss reaches 90% of its overall reduction. Initial and final loss are
    the means of the first and last five logged values.

    Args:
        all_results: Mapping from a ``CURRICULUM_CONFIGS`` key to a dict whose
            ``'training'`` entry (if any) is a DataFrame with ``step`` and
            ``total_loss`` columns.
        summary_df: Not read by this function; kept so the signature matches
            the other visualizers.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 10))
    fig.suptitle('Curriculum Convergence Analysis', fontsize=16, fontweight='bold')

    def _runs_with_loss(min_rows):
        # Yield (config, training table) for runs that logged `total_loss`
        # in at least `min_rows` rows.
        for key, run in all_results.items():
            history = run.get('training')
            if history is None or 'total_loss' not in history.columns:
                continue
            if len(history) < min_rows:
                continue
            yield CURRICULUM_CONFIGS[key], history

    def _bars_or_notice(panel, records, value_key, title, y_label, notice, na_title):
        # Bar chart of `value_key` per curriculum, or a centred notice when empty.
        if records:
            table = pd.DataFrame(records)
            panel.bar(table['name'], table[value_key], color=table['color'], alpha=0.7)
            panel.set_title(title, fontweight='bold')
            panel.set_ylabel(y_label)
            plt.setp(panel.get_xticklabels(), rotation=45, ha='right')
        else:
            panel.text(0.5, 0.5, notice,
                       ha='center', va='center', transform=panel.transAxes)
            panel.set_title(na_title, fontweight='bold')

    # 1. Smoothed loss curves
    panel = axes[0, 0]
    for style, history in _runs_with_loss(0):
        # Window is a fifth of the log, capped at 10; tiny logs are not drawn.
        span = min(10, len(history) // 5)
        if span <= 1:
            continue
        rolling_mean = history['total_loss'].rolling(window=span, min_periods=1).mean()
        panel.plot(history['step'], rolling_mean,
                   color=style['color'], label=style['name'], linewidth=2, alpha=0.8)

    panel.set_title('Smoothed Loss Convergence', fontweight='bold')
    panel.set_xlabel('Training Step')
    panel.set_ylabel('Total Loss (Moving Average)')
    panel.legend()
    panel.grid(True, alpha=0.3)

    # 2. Loss reduction per step
    efficiency_records = []
    for style, history in _runs_with_loss(2):
        losses = history['total_loss']
        start_level = losses.iloc[:5].mean()
        end_level = losses.iloc[-5:].mean()
        elapsed_steps = history['step'].iloc[-1] - history['step'].iloc[0]

        if elapsed_steps > 0 and start_level > 0:
            efficiency_records.append({
                'name': style['name'],
                'efficiency': (start_level - end_level) / elapsed_steps,
                'color': style['color']
            })

    _bars_or_notice(axes[0, 1], efficiency_records, 'efficiency',
                    'Learning Efficiency (Loss Reduction per Step)', 'Loss Reduction Rate',
                    'Insufficient data for efficiency calculation',
                    'Learning Efficiency (N/A)')

    # 3. Loss variance over the second half of training
    stability_records = []
    for style, history in _runs_with_loss(11):
        midpoint = len(history) // 2
        stability_records.append({
            'name': style['name'],
            'variance': history['total_loss'].iloc[midpoint:].var(),
            'color': style['color']
        })

    _bars_or_notice(axes[1, 0], stability_records, 'variance',
                    'Training Stability (Lower Variance = More Stable)',
                    'Loss Variance (Second Half)',
                    'Insufficient data for stability analysis',
                    'Training Stability (N/A)')

    # 4. Steps needed to reach 90% of the total loss reduction
    convergence_records = []
    for style, history in _runs_with_loss(6):
        losses = history['total_loss']
        start_level = losses.iloc[:5].mean()
        end_level = losses.iloc[-5:].mean()
        threshold = start_level - 0.9 * (start_level - end_level)

        # First logged step at or below the threshold; fall back to the last step.
        reached = history[losses <= threshold]
        if reached.empty:
            step_reached = history['step'].iloc[-1]
        else:
            step_reached = reached['step'].iloc[0]

        convergence_records.append({
            'name': style['name'],
            'steps_to_converge': step_reached,
            'color': style['color']
        })

    _bars_or_notice(axes[1, 1], convergence_records, 'steps_to_converge',
                    'Steps to 90% Convergence (Lower = Faster)', 'Training Steps',
                    'Insufficient data for convergence analysis',
                    'Convergence Speed (N/A)')

    plt.tight_layout()
    plt.show()

# Driver for this cell: render the three figures defined above, in order,
# from the `all_results` and `summary_df` objects built by earlier cells.
print("Generating curriculum comparison visualizations...")
visualize_curriculum_comparison(all_results, summary_df)

# Validation identity error / MSE over time plus the normalized summary bars.
print("\nGenerating inverse task specific visualizations...")
visualize_inverse_task_metrics(all_results, summary_df)

# Smoothed loss, efficiency, stability and convergence-speed panels.
print("\nGenerating convergence analysis...")
visualize_curriculum_convergence(all_results, summary_df)


In [None]:
## Final Consolidated Summary

def generate_final_summary_report(summary_df: pd.DataFrame, all_results: Dict):
    """Print the final consolidated ("IRED-style") performance report.

    The report has four printed sections:

    1. Best method per task, chosen by lowest finite ``final_val_mse``.
    2. Per-curriculum efficiency: mean training time, mean final loss and,
       when a training log is available, the relative loss reduction between
       the first and last five logged values.
    3. Key findings: best method by mean MSE, best curriculum versus the
       baseline, and the MSE spread per task.
    4. Recommendations, including the method minimising
       ``mean(MSE) + 0.5 * std(MSE)`` among methods with a finite MSE on
       every task.

    MSE or identity-error values equal to ``inf`` are treated as "not
    available" throughout.

    Args:
        summary_df: One row per (task, curriculum) run with ``task``,
            ``curriculum``, ``name``, ``final_val_mse``,
            ``final_identity_error``, ``final_total_loss`` and
            ``avg_training_time`` columns.
        all_results: Mapping looked up by ``CURRICULUM_CONFIGS`` key; an
            entry's ``'training'`` value, if present, is a DataFrame that may
            contain a ``total_loss`` column.

    Returns:
        None. Everything is printed.
    """
    from tabulate import tabulate
    import numpy as np

    print("\n" + "="*80)
    print("FINAL SUMMARY: IRED-STYLE PERFORMANCE REPORT")
    print("="*80)

    # 1. Overall Best Methods Table
    print("\n📊 BEST PERFORMING METHODS BY TASK")
    print("-" * 50)

    best_methods = []
    for task in TASKS:
        task_data = summary_df[summary_df['task'] == task]
        if len(task_data) > 0:
            # Find best by MSE
            valid_data = task_data[task_data['final_val_mse'] != float('inf')]
            if len(valid_data) > 0:
                best = valid_data.loc[valid_data['final_val_mse'].idxmin()]

                row = [
                    task.upper(),
                    best['name'],
                    f"{best['final_val_mse']:.4f}" if isinstance(best['final_val_mse'], float) else "N/A"
                ]

                # Identity error is only reported for the inverse task.
                if task == 'inverse':
                    identity = best['final_identity_error']
                    row.append(f"{identity:.4f}" if identity != float('inf') else "N/A")
                else:
                    row.append("-")

                best_methods.append(row)

    headers = ['Task', 'Best Method', 'MSE', 'Identity Error']
    print(tabulate(best_methods, headers=headers, tablefmt='grid'))

    # 2. Training Efficiency Summary
    print("\n⚡ TRAINING EFFICIENCY METRICS")
    print("-" * 50)

    efficiency_data = []
    for curriculum_key in CURRICULUM_CONFIGS.keys():
        curriculum_data = summary_df[summary_df['curriculum'] == curriculum_key]
        if len(curriculum_data) > 0:
            avg_time = curriculum_data['avg_training_time'].mean()
            avg_loss = curriculum_data['final_total_loss'].mean()

            # Calculate convergence rate if training data available
            convergence_rate = "N/A"
            if curriculum_key in all_results:
                training_df = all_results[curriculum_key].get('training')
                # Same column guard as the visualizers: a log without
                # `total_loss` yields "N/A" instead of a KeyError.
                if (training_df is not None
                        and 'total_loss' in training_df.columns
                        and len(training_df) > 10):
                    initial_loss = training_df['total_loss'].iloc[:5].mean()
                    final_loss = training_df['total_loss'].iloc[-5:].mean()
                    if initial_loss > 0:
                        improvement = (initial_loss - final_loss) / initial_loss * 100
                        convergence_rate = f"{improvement:.1f}%"

            efficiency_data.append([
                CURRICULUM_CONFIGS[curriculum_key]['name'],
                f"{avg_time:.2f}",
                f"{avg_loss:.4f}",
                convergence_rate
            ])

    headers = ['Method', 'Avg Time (s)', 'Avg Loss', 'Loss Reduction']
    print(tabulate(efficiency_data, headers=headers, tablefmt='grid'))

    # 3. Key Findings Summary
    print("\n🎯 KEY FINDINGS")
    print("-" * 50)

    # Find overall best method
    valid_mse = summary_df[summary_df['final_val_mse'] != float('inf')]
    if len(valid_mse) > 0:
        # Compute the per-method mean once and reuse it for name and value.
        mean_mse_by_method = valid_mse.groupby('name')['final_val_mse'].mean()
        overall_best = mean_mse_by_method.idxmin()
        overall_best_mse = mean_mse_by_method.min()

        print(f"• Best Overall Method: {overall_best}")
        print(f"  Average MSE across tasks: {overall_best_mse:.4f}")

    # Check baseline comparison
    baseline_data = summary_df[summary_df['curriculum'] == 'baseline']
    non_baseline = summary_df[summary_df['curriculum'] != 'baseline']

    if len(baseline_data) > 0 and len(non_baseline) > 0:
        baseline_avg = baseline_data[baseline_data['final_val_mse'] != float('inf')]['final_val_mse'].mean()
        best_non_baseline = non_baseline[non_baseline['final_val_mse'] != float('inf')].groupby('name')['final_val_mse'].mean().min()

        if baseline_avg > 0 and not np.isnan(baseline_avg) and not np.isnan(best_non_baseline):
            improvement = (baseline_avg - best_non_baseline) / baseline_avg * 100
            print(f"\n• Curriculum Learning Impact:")
            print(f"  Best curriculum improves over baseline by {improvement:.1f}%")

    # Task-specific insights
    print("\n• Task-Specific Insights:")
    for task in TASKS:
        task_data = summary_df[summary_df['task'] == task]
        valid_task_data = task_data[task_data['final_val_mse'] != float('inf')]

        # A spread needs at least two methods with a finite MSE.
        if len(valid_task_data) > 1:
            mse_range = valid_task_data['final_val_mse'].max() - valid_task_data['final_val_mse'].min()
            print(f"  {task.upper()}: MSE varies by {mse_range:.4f} across methods")

            if task == 'inverse':
                id_errors = task_data[task_data['final_identity_error'] != float('inf')]['final_identity_error']
                if len(id_errors) > 0:
                    print(f"    Identity error range: {id_errors.min():.4f} - {id_errors.max():.4f}")

    # 4. Recommendation
    print("\n💡 RECOMMENDATIONS")
    print("-" * 50)

    # Find most consistent performer
    method_scores = {}
    for method in summary_df['name'].unique():
        method_data = summary_df[summary_df['name'] == method]
        valid_mse = method_data[method_data['final_val_mse'] != float('inf')]['final_val_mse']

        if len(valid_mse) == len(TASKS):  # Has results for all tasks
            avg_mse = valid_mse.mean()
            std_mse = valid_mse.std()
            # Score based on low average and low variance (consistency)
            score = avg_mse + 0.5 * std_mse
            method_scores[method] = (avg_mse, std_mse, score)

    if method_scores:
        best_consistent = min(method_scores.items(), key=lambda x: x[1][2])
        print(f"• Most consistent method: {best_consistent[0]}")
        print(f"  Avg MSE: {best_consistent[1][0]:.4f}, Std: {best_consistent[1][1]:.4f}")

    print("\n• For production use, consider:")
    print("  - Use curriculum learning for improved convergence")
    print("  - Monitor identity error for matrix inverse tasks")
    print("  - Track MSE as primary metric for all matrix operations")

# Print the final consolidated report from the summary table and the
# per-run results collected by earlier cells.
generate_final_summary_report(summary_df, all_results)

In [None]:
def create_performance_radar_chart(summary_df: pd.DataFrame):
    """Draw a radar chart with one polygon per row of ``summary_df``.

    Five axes are shown, each scaled so that larger is better and the best
    row reaches 1.0: inverse final loss, inverse average training time,
    inverse best loss, best validation accuracy and final energy margin.

    Args:
        summary_df: Table with ``final_total_loss``, ``avg_training_time``,
            ``best_total_loss``, ``best_val_accuracy``,
            ``final_energy_margin``, ``name`` and ``color`` columns. Any
            index is accepted; rows are matched to metric values by position.
    """
    from math import pi

    # Raw per-row metrics, oriented so that higher = better.
    raw_metrics = {
        'Loss Reduction': 1 / (summary_df['final_total_loss'] + 1e-6),  # Lower loss = better
        'Training Speed': 1 / (summary_df['avg_training_time'] + 1e-6),   # Faster = better
        'Best Loss': 1 / (summary_df['best_total_loss'] + 1e-6),         # Lower loss = better
        'Val Accuracy': summary_df['best_val_accuracy'] / 100.0,          # Higher accuracy = better
        'Energy Margin': summary_df['final_energy_margin'] / (summary_df['final_energy_margin'].max() + 1e-6)  # Higher margin = better
    }

    # Rescale every metric by its maximum; metrics whose maximum is not
    # positive are left unchanged.
    metrics = {
        label: (values / values.max() if values.max() > 0 else values)
        for label, values in raw_metrics.items()
    }

    # Create radar chart
    categories = list(metrics.keys())
    N = len(categories)

    # Compute angles
    angles = [n / float(N) * 2 * pi for n in range(N)]
    angles += angles[:1]  # Complete the circle

    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(projection='polar'))
    ax.set_title('Multi-Curriculum Performance Radar Chart', fontsize=14, fontweight='bold', pad=20)

    # Plot each curriculum. iterrows() yields index *labels*, which are not
    # valid positions for .iloc unless the index is a default RangeIndex, so
    # an explicit positional counter is used instead.
    for position, (_, row) in enumerate(summary_df.iterrows()):
        values = [metrics[cat].iloc[position] for cat in categories]
        values += values[:1]  # Complete the circle

        ax.plot(angles, values, 'o-', linewidth=2, label=row['name'],
                color=row['color'], alpha=0.8)
        ax.fill(angles, values, alpha=0.15, color=row['color'])

    # Add labels
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories, fontweight='bold')
    ax.set_ylim(0, 1)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_yticklabels(['0.2', '0.4', '0.6', '0.8', '1.0'], alpha=0.7)
    ax.grid(True, alpha=0.3)

    plt.legend(loc='upper right', bbox_to_anchor=(1.2, 1.0))
    plt.tight_layout()
    plt.show()

def summary_analysis(summary_df: pd.DataFrame):
    """Print a curriculum comparison report and return the comparison table.

    The report is written to stdout in this order: a table of loss, accuracy
    and timing metrics; the best performer per metric (only when there is more
    than one row); the final-loss improvement of every non-baseline row
    relative to the row whose ``curriculum`` is ``'baseline'``; and finally
    the radar chart drawn by ``create_performance_radar_chart``.

    Args:
        summary_df: Summary frame (presumably one row per curriculum; confirm
            against the code that builds it). Columns read here are ``name``,
            ``curriculum``, ``final_total_loss``, ``best_total_loss``,
            ``final_val_accuracy``, ``best_val_accuracy`` and
            ``avg_training_time``.

    Returns:
        A copy of the metric columns of ``summary_df`` with display-friendly
        column headers (unrounded; rounding is applied only when printing).
    """
    print("\n" + "="*80)
    print("COMPREHENSIVE CURRICULUM ANALYSIS")
    print("="*80)

    # Performance metrics comparison
    print("\n📊 PERFORMANCE METRICS COMPARISON")
    print("-" * 50)

    metric_columns = [
        'name', 'final_total_loss', 'best_total_loss',
        'final_val_accuracy', 'best_val_accuracy', 'avg_training_time',
    ]
    comparison_table = summary_df[metric_columns].copy()
    comparison_table.columns = ['Curriculum', 'Final Loss', 'Best Loss', 'Final Acc (%)', 'Best Acc (%)', 'Avg Time (s)']

    print(comparison_table.round(4).to_string(index=False))

    # Statistical analysis
    print("\n📈 STATISTICAL INSIGHTS")
    print("-" * 50)

    # Rankings and relative improvements only make sense with several rows.
    if len(summary_df) > 1:
        # Best performers (lower is better for loss and time)
        best_final_loss = summary_df.loc[summary_df['final_total_loss'].idxmin(), 'name']
        best_convergence = summary_df.loc[summary_df['best_total_loss'].idxmin(), 'name']
        fastest_training = summary_df.loc[summary_df['avg_training_time'].idxmin(), 'name']

        print(f"🎯 Best Final Loss: {best_final_loss}")
        print(f"🎯 Best Convergence: {best_convergence}")
        print(f"⚡ Fastest Training: {fastest_training}")

        # Skip the accuracy winner when no row reports a positive accuracy.
        if summary_df['best_val_accuracy'].max() > 0:
            best_accuracy = summary_df.loc[summary_df['best_val_accuracy'].idxmax(), 'name']
            print(f"🎯 Best Accuracy: {best_accuracy}")

        # Improvement analysis vs baseline
        baseline_metrics = summary_df[summary_df['curriculum'] == 'baseline']
        if len(baseline_metrics) > 0:
            baseline_loss = baseline_metrics['final_total_loss'].iloc[0]
            # A zero or missing baseline loss has no meaningful relative
            # change; dividing by it would print inf/nan or raise.
            has_reference = pd.notna(baseline_loss) and baseline_loss != 0

            print("\n📈 IMPROVEMENT OVER BASELINE")
            print("-" * 30)

            for _, row in summary_df.iterrows():
                if row['curriculum'] == 'baseline':
                    continue
                if not has_reference:
                    print(f"{row['name']}: n/a (baseline loss is zero or missing)")
                    continue
                # Divide by the magnitude so a lower loss is always reported
                # as a positive improvement, even if the baseline is negative.
                improvement = (baseline_loss - row['final_total_loss']) / abs(baseline_loss) * 100
                print(f"{row['name']}: {improvement:+.1f}% loss improvement")

    # Generate radar chart
    print("\n🎯 MULTI-DIMENSIONAL PERFORMANCE ANALYSIS")
    print("-" * 50)
    create_performance_radar_chart(summary_df)

    return comparison_table

# Generate summary analysis: summary_analysis prints the report, draws the
# radar chart, and returns the renamed metrics table bound here.
# NOTE(review): relies on summary_df being defined earlier in the notebook — confirm.
comparison_table = summary_analysis(summary_df)