In [None]:
# Final Consolidated Summary

def generate_final_summary_report(summary_df: pd.DataFrame, all_results: Dict):
    """Print a consolidated report covering every configured task and curriculum.

    Args:
        summary_df: Combined per-task / per-curriculum metrics table. It is
            printed verbatim (without its index) when it has any data.
        all_results: Raw loaded results. Not read by this function; kept so
            existing callers continue to work.
    """
    banner = "=" * 80
    n_tasks = len(TASKS)
    n_curricula = len(CURRICULA)

    print("\n" + banner)
    print("🔬 FINAL CONSOLIDATED SUMMARY REPORT")
    print(banner)

    print("\n📊 EXPERIMENTAL OVERVIEW:")
    print(f"   • Total tasks evaluated: {n_tasks}")
    print(f"   • Total curricula compared: {n_curricula}")
    print(f"   • Total experimental conditions: {n_tasks * n_curricula}")

    # One entry per configured task, with its model and training budget.
    print("\n🎯 TASK PERFORMANCE OVERVIEW:")
    for task_cfg in TASKS.values():
        print(f"   • {task_cfg['name']}: {task_cfg['description']}")
        print(f"     Model: {task_cfg['model']}, Training Steps: {task_cfg['train_steps']}")

    # One entry per curriculum being compared.
    print("\n📈 CURRICULUM COMPARISON OVERVIEW:")
    for curriculum_cfg in CURRICULA.values():
        print(f"   • {curriculum_cfg['name']}: {curriculum_cfg['description']}")

    if summary_df.empty:
        print("\n⚠️  No performance data available for consolidated summary")
    else:
        print("\n📋 PERFORMANCE SUMMARY:")
        print(summary_df.to_string(index=False))

    print("\n" + banner)
    print("📝 FINAL CONSOLIDATED SUMMARY COMPLETED")
    print(banner)

# Generate the final consolidated summary, but only once the per-task
# summaries (built by the data-processing cell) exist in this namespace.
if not ('task_summaries' in locals() and task_summaries):
    print("Task summaries not yet generated - run previous cells first")
else:
    # Tag every per-task table with its task key before stacking them;
    # assign() works on a copy, so the stored summaries are left untouched.
    all_summaries = []
    for task_key, summary_df in task_summaries.items():
        all_summaries.append(summary_df.assign(task=task_key))

    if not all_summaries:
        print("No task summaries available for consolidated report")
    else:
        consolidated_df = pd.concat(all_summaries, ignore_index=True)
        generate_final_summary_report(consolidated_df, all_results)

In [ ]:
def visualize_task_curriculum_comparison(all_results: Dict, summary_df: pd.DataFrame, task: str):
    """Placeholder for per-task curriculum comparison plots.

    Not implemented: it ignores its arguments, draws nothing and returns
    ``None``.
    """
    return None

In [ ]:
# IRED-Style Results Table Generation

def generate_ired_style_table(summary_df: pd.DataFrame, task: str):
    """Build, print and return an IRED-paper-style results table for one task.

    Args:
        summary_df: One row per curriculum. It must contain ``curriculum_name``.
            ``final_accuracy`` and ``best_accuracy`` (percent) and
            ``final_total_loss`` fall back to 0.0 when the column is absent.
        task: Display name of the task, copied into every row.

    Returns:
        A DataFrame with columns ``Method``, ``Task``, ``Final Acc (%)``,
        ``Best Acc (%)`` and ``Final Loss``. Numeric cells are pre-formatted
        strings (2 decimals for accuracies, 4 for the loss).
    """
    columns = ['Method', 'Task', 'Final Acc (%)', 'Best Acc (%)', 'Final Loss']

    results_table = [
        {
            'Method': row['curriculum_name'],
            'Task': task,
            'Final Acc (%)': f"{row.get('final_accuracy', 0.0):.2f}",
            'Best Acc (%)': f"{row.get('best_accuracy', 0.0):.2f}",
            'Final Loss': f"{row.get('final_total_loss', 0.0):.4f}",
        }
        for _, row in summary_df.iterrows()
    ]

    # Explicit columns keep the schema stable even when summary_df has no rows.
    results_df = pd.DataFrame(results_table, columns=columns)

    # A real newline here; the previous "\\n" printed a literal backslash-n.
    print(f"\n=== IRED-Style Results Table for {task} ===")
    print(results_df.to_string(index=False))

    return results_df

# Generate IRED-style tables for all tasks. Like the consolidated-summary cell,
# this guards against running before the cell that builds `task_summaries`
# (otherwise it would fail with a NameError).
if 'task_summaries' in globals() and task_summaries:
    print("Generating IRED-style results tables...")
    for task_key, summary_df in task_summaries.items():
        task_name = TASKS[task_key]['name']
        generate_ired_style_table(summary_df, task_name)
else:
    print("Task summaries not yet generated - run previous cells first")

In [None]:
def generate_task_ranking(summary_df: pd.DataFrame):
    """Rank the curricula of one task by final accuracy, print and return the table.

    Args:
        summary_df: One row per curriculum. It must contain ``curriculum_name``.
            ``final_accuracy`` and ``best_accuracy`` (percent) and
            ``final_total_loss`` fall back to 0.0 when the column is absent.

    Returns:
        A DataFrame with columns ``Curriculum``, ``Final Accuracy (%)``,
        ``Best Accuracy (%)``, ``Final Loss`` and ``Score``, sorted by
        ``Score`` (the final accuracy) in descending order. Ties keep their
        input order. An empty input gives an empty table with the same columns.
    """
    columns = ['Curriculum', 'Final Accuracy (%)', 'Best Accuracy (%)', 'Final Loss', 'Score']

    ranking_data = []
    for _, row in summary_df.iterrows():
        final_accuracy = row.get('final_accuracy', 0.0)
        ranking_data.append({
            'Curriculum': row['curriculum_name'],
            'Final Accuracy (%)': final_accuracy,
            'Best Accuracy (%)': row.get('best_accuracy', 0.0),
            'Final Loss': row.get('final_total_loss', 0.0),
            'Score': final_accuracy,  # Simple ranking by final accuracy
        })

    # Explicit columns so sort_values('Score') also works on an empty table.
    ranking_df = pd.DataFrame(ranking_data, columns=columns)
    # Higher accuracy is better; mergesort is stable, so ties are deterministic.
    ranking_df = ranking_df.sort_values('Score', ascending=False, kind='mergesort')

    print("\nTask Performance Ranking:")
    print(ranking_df.to_string(index=False))

    return ranking_df

# IRED Tasks: Baseline vs Aggressive Curriculum Comparison

Comparative analysis of baseline training vs aggressive adversarial curriculum on core IRED reasoning tasks:
- **Connectivity**: Graph connectivity reasoning (12x12 graphs, 0.1 edge probability)
- **Sudoku**: Sudoku completion task (standard difficulty)

This notebook evaluates whether aggressive curriculum learning improves identity error performance within limited training steps for reasoning tasks.

**Key Research Question**: Does aggressive curriculum learning outperform baseline training on identity error metrics for IRED reasoning tasks?

**Tasks Evaluated**: The easier, in-distribution (non-OOD) versions of the connectivity and sudoku tasks, as described in the IRED paper by Yilun Du et al.

In [None]:
# Clone repository and setup
# NOTE: `rm -rf` first deletes any existing local checkout, including anything
# written inside it by earlier runs, so every run starts from a fresh clone.
# The final %cd makes the repository root the working directory for later cells.
!rm -rf energy-based-model
!git clone https://github.com/mdkrasnow/energy-based-model.git
%cd energy-based-model

In [None]:
# Download the datasets with the repository's helper scripts (run from data/),
# then return to the repository root. The chmod calls are not strictly needed
# because the scripts are invoked through `bash`.
%cd data
!chmod +x download-rrn.sh
!chmod +x download-satnet.sh
!bash download-satnet.sh
!bash download-rrn.sh
%cd ..  

In [None]:
# Install dependencies
# -q silences pip's progress output. No versions are pinned, so installed
# package versions may differ between runs/environments.
!pip install -q torch torchvision einops accelerate tqdm tabulate matplotlib numpy pandas ema-pytorch ipdb seaborn scikit-learn

In [None]:
# Import required libraries
import subprocess
import sys
import glob
import os
from pathlib import Path
from typing import Dict, Optional, List, Tuple
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from math import pi

# Set plotting style
# Both calls change process-wide plotting state: reset matplotlib to its
# default style, then make seaborn's "husl" palette the default color cycle.
plt.style.use('default')
sns.set_palette("husl")

print("All libraries imported successfully!")

## Task and Curriculum Configuration

In [None]:
# Task-specific configurations based on IRED paper
TASKS = {
    'connectivity': {
        'name': 'Graph Connectivity',
        'description': 'Graph connectivity reasoning (12x12, p=0.1)',
        'dataset': 'connectivity',
        'model': 'gnn-conv-1d-v2',
        'batch_size': 64,  # From train.py validation_batch_size
        'train_steps': 600,  # Longer for reasoning tasks
        # Intended validation cadence. NOTE(review): the command builder does
        # not currently pass this to train.py.
        'validation_interval': 200,
        'extra_args': [],
        'color': '#2E8B57',
        'expected_metrics': ['accuracy', 'balanced_accuracy', 'precision', 'recall'],
        'note': 'Fixed: Now uses proper validation dataset instead of training samples'
    },
    'sudoku': {
        'name': 'Sudoku Completion',
        'description': 'Sudoku puzzle completion task',
        'dataset': 'sudoku',
        'model': 'sudoku',
        'batch_size': 64,  # From train.py validation_batch_size
        'train_steps': 600,
        'validation_interval': 200,  # See note on the connectivity entry
        'extra_args': ['--cond_mask', 'True'],  # Required for sudoku
        'color': '#4169E1',
        'expected_metrics': ['accuracy', 'consistency', 'board_accuracy']
    }
}

# Only 2 curricula to compare: baseline vs aggressive
CURRICULA = {
    'baseline': {
        'name': 'Baseline',
        'description': 'No curriculum learning',
        'args': ['--disable-curriculum', 'True'],
        'color': '#1f77b4'
    },
    'aggressive': {
        'name': 'Aggressive Curriculum',
        'description': 'Rapid adversarial escalation',
        'args': ['--curriculum-config', 'aggressive'],
        'color': '#d62728'
    }
}

# Common training parameters - Enhanced for validation tracking
COMMON_ARGS = {
    'diffusion_steps': 10,
    'supervise_energy_landscape': 'True',
    # Informational only: the command builder always passes --save-csv-logs.
    'save_csv_logs': True,
    'csv_log_interval': 10,  # Log training metrics every 10 steps
}

print(f"Testing {len(TASKS)} IRED reasoning tasks:")
for key, config in TASKS.items():
    # Derive the evaluation count from the config instead of hard-coding it,
    # so the message stays correct if train_steps or the interval change.
    n_evals = config['train_steps'] // config['validation_interval']
    print(f"  • {config['name']}: {config['description']}")
    print(f"    Model: {config['model']}, Steps: {config['train_steps']}")
    print(f"    Validation every {config['validation_interval']} steps ({n_evals} evaluations total)")
    print(f"    Expected metrics: {', '.join(config['expected_metrics'])}")
    if 'note' in config:
        print(f"    NOTE: {config['note']}")

print(f"\nComparing {len(CURRICULA)} curriculum approaches:")
for key, config in CURRICULA.items():
    print(f"  • {config['name']}: {config['description']}")

print("\n⚠️  VALIDATION FIX APPLIED:")
print("  • Connectivity now uses proper validation dataset (not training samples)")
print("  • Extra validation datasets (13x13, 15x15, etc.) run every 2 milestones")
print("  • Output now displays 'Train Sample Evaluation' vs 'Validation' correctly")

## Training Command Builder and Utility Functions

In [None]:
def build_task_training_command(task_key: str, curriculum_key: str) -> str:
    """Build the shell command that trains one task under one curriculum.

    Args:
        task_key: Key into ``TASKS``.
        curriculum_key: Key into ``CURRICULA``.

    Returns:
        A ``python train.py ...`` command string with one option (or token)
        per line, joined by backslash-newline shell continuations. It must
        therefore be executed through a shell. CSV logs are directed to
        ``./csv_logs_{task_key}_{curriculum_key}``.

    Raises:
        KeyError: if ``task_key`` or ``curriculum_key`` is not configured.
    """
    task = TASKS[task_key]
    curriculum = CURRICULA[curriculum_key]

    # NOTE(review): task['validation_interval'] is not forwarded to train.py,
    # so validation runs at whatever cadence train.py defaults to. Confirm the
    # exact train.py flag name before adding it here.
    parts = [
        "python train.py",
        f"--dataset {task['dataset']}",
        f"--model {task['model']}",
        f"--batch_size {task['batch_size']}",
        f"--diffusion_steps {COMMON_ARGS['diffusion_steps']}",
        f"--supervise-energy-landscape {COMMON_ARGS['supervise_energy_landscape']}",
        f"--train-num-steps {task['train_steps']}",
        "--save-csv-logs",
        f"--csv-log-interval {COMMON_ARGS['csv_log_interval']}",
        f"--csv-log-dir ./csv_logs_{task_key}_{curriculum_key}",
    ]
    # Task-specific arguments first, then curriculum arguments (one token per line).
    parts.extend(task['extra_args'])
    parts.extend(curriculum['args'])

    return " \\\n        ".join(parts)

def load_csv_data(csv_path: Path) -> Optional[pd.DataFrame]:
    """Read ``csv_path`` into a DataFrame, or return ``None`` if that fails.

    A missing file or any read error is reported with a printed message
    instead of being raised, so callers can treat absent logs as "no data".
    Non-empty files also print a short preview of their first columns.
    """
    try:
        if not csv_path.exists():
            print(f"Warning: {csv_path} not found")
            return None

        frame = pd.read_csv(csv_path)
        if len(frame) > 0:
            # Debug aid: row count plus the first few column names.
            preview = list(frame.columns)[:5]
            print(f"  Loaded {csv_path.name}: {len(frame)} rows, columns: {preview}...")
        return frame
    except Exception as e:
        print(f"Error loading {csv_path}: {e}")
        return None

# Fallback metric names, tried in order, when an expected metric is not logged
# under its canonical name in a validation CSV.
VALIDATION_METRIC_ALIASES = {
    'accuracy': ['acc'],
    'board_accuracy': ['board_acc', 'complete_boards'],
    'balanced_accuracy': ['bal_acc', 'balanced_acc'],
}


def extract_accuracy_from_validation_csv(df: pd.DataFrame, task_key: str) -> Dict[str, List[Tuple[int, float]]]:
    """Collect ``(step, value)`` series for a task's expected validation metrics.

    Args:
        df: Long-format validation log with ``step``, ``metric_name`` and
            ``metric_value`` columns. ``None`` is accepted.
        task_key: Key into ``TASKS``. Its ``expected_metrics`` (default
            ``['accuracy']``) select which metrics to extract.

    Returns:
        ``{metric_name: [(step, value), ...]}`` sorted by step. Each expected
        metric is looked up under its own name first, then under the names in
        ``VALIDATION_METRIC_ALIASES``, and omitted if none is present. The
        result is an empty dict when ``df`` is None or empty, or lacks the
        required columns (the last case also prints a warning).
    """
    if df is None or len(df) == 0:
        return {}

    metrics = {}
    expected_metrics = TASKS[task_key].get('expected_metrics', ['accuracy'])

    # Check if required columns exist
    required_cols = ['step', 'metric_name', 'metric_value']
    if not all(col in df.columns for col in required_cols):
        print(f"Warning: Missing required columns in validation CSV. Found: {df.columns.tolist()}")
        return metrics

    for metric_name in expected_metrics:
        # Canonical name first, then any known aliases; the first hit wins.
        candidates = [metric_name, *VALIDATION_METRIC_ALIASES.get(metric_name, [])]
        for candidate in candidates:
            metric_data = df[df['metric_name'] == candidate]
            if not metric_data.empty:
                metric_data = metric_data.sort_values('step')
                metrics[metric_name] = list(zip(metric_data['step'], metric_data['metric_value']))
                break

    return metrics

def safe_get_final_value(df: pd.DataFrame, column: str, default: float = 0.0) -> float:
    """Return the last entry of ``df[column]`` as a float.

    Falls back to ``default`` when ``df`` is None, lacks ``column``, or has no rows.
    """
    has_data = df is not None and column in df.columns and len(df) > 0
    return float(df[column].iloc[-1]) if has_data else default

def safe_get_best_value(df: pd.DataFrame, column: str, minimize: bool = True, default: float = 0.0) -> float:
    """Return the minimum (``minimize=True``) or maximum of ``df[column]`` as a float.

    Falls back to ``default`` when ``df`` is None, lacks ``column``, or has no rows.
    """
    if df is not None and column in df.columns and len(df):
        series = df[column]
        best = series.min() if minimize else series.max()
        return float(best)
    return default

def safe_get_accuracy_metrics(metrics_dict: Dict[str, List[Tuple[int, float]]]) -> Dict[str, float]:
    """Summarise each metric's ``(step, value)`` series.

    For every non-empty series named ``m`` the result gains ``m_final`` (last
    value), ``m_best`` (maximum value, since accuracies are maximised) and
    ``m_steps`` (number of recorded values). Empty series are skipped.
    """
    summary = {}
    for name, pairs in metrics_dict.items():
        if not pairs:
            continue
        series = [value for _step, value in pairs]
        summary.update({
            f'{name}_final': series[-1],
            f'{name}_best': max(series),
            f'{name}_steps': len(series),
        })
    return summary

def calculate_convergence_step(df: pd.DataFrame, column: str, threshold_pct: float = 0.9,
                               window: int = 5) -> int:
    """Return the first logged step at which ``column`` has made ``threshold_pct`` of its overall drop.

    The starting and ending levels are the means of the first and last
    ``window`` rows. The target is ``initial - threshold_pct * (initial - final)``.
    This is intended for lower-is-better series such as losses.

    Args:
        df: Log containing a ``step`` column and ``column``.
        column: Name of the series to analyse.
        threshold_pct: Fraction of the initial-to-final drop that counts as converged.
        window: Number of rows averaged at each end (must be >= 1).

    Returns:
        The ``step`` of the first row at or below the target. If no row reaches
        it, the last logged step. Returns -1 when ``df`` is None, lacks
        ``column`` or ``step``, or has fewer than ``window`` rows.

    Raises:
        ValueError: if ``window`` < 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    # 'step' is checked too: it is required to report the result.
    if (df is None or column not in df.columns or 'step' not in df.columns
            or len(df) < window):
        return -1

    values = df[column]
    initial_val = values.iloc[:window].mean()
    final_val = values.iloc[-window:].mean()
    target_val = initial_val - threshold_pct * (initial_val - final_val)

    reached = df[values <= target_val]
    if not reached.empty:
        return int(reached['step'].iloc[0])
    return int(df['step'].iloc[-1])

# Confirm the cell ran and show a truncated example command as a smoke test.
print("Utility functions defined successfully!")
print("\nExample command structure:")
print(f"{build_task_training_command('connectivity', 'baseline')[:300]}...")

## Multi-Task Multi-Curriculum Training

In [None]:
# Train all tasks with all curriculum configurations
import subprocess
import sys


def _run_training_command(cmd: str) -> int:
    """Run ``cmd`` in a shell, echoing its merged stdout/stderr live; return the exit code.

    ``shell=True`` is needed because the command relies on backslash line
    continuations. ``cmd`` comes from build_task_training_command and the
    TASKS/CURRICULA config.
    """
    # The context manager closes the pipe and reaps the child on every exit path.
    with subprocess.Popen(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        bufsize=1,
    ) as process:
        try:
            # readline() returns '' only at EOF, which ends the iteration.
            for line in iter(process.stdout.readline, ''):
                print(line.rstrip())
                sys.stdout.flush()
        except BaseException:
            # E.g. a notebook interrupt: stop the run rather than leaving it
            # running unattended. NOTE(review): with shell=True this kills the
            # shell process; confirm train.py itself also terminates.
            process.kill()
            raise
        return process.wait()


training_results = {}

# Train each task
for task_key, task_config in TASKS.items():
    print(f"\n{'='*80}")
    print(f"TRAINING TASK: {task_config['name']}")
    print(f"Description: {task_config['description']}")
    print(f"Dataset: {task_config['dataset']}, Model: {task_config['model']}")
    print(f"Training Steps: {task_config['train_steps']}, Batch Size: {task_config['batch_size']}")
    print(f"{'='*80}")

    task_results = {}

    for curriculum_key, curriculum_config in CURRICULA.items():
        print(f"\n{'-'*60}")
        print(f"Starting {task_config['name']} with {curriculum_config['name']}")
        print(f"Curriculum Description: {curriculum_config['description']}")
        print(f"{'-'*60}")

        # Build training command
        cmd = build_task_training_command(task_key, curriculum_key)
        print(f"\nCommand: {cmd}")
        print("\nTraining output:")
        print("-" * 40)

        try:
            result = _run_training_command(cmd)
        except Exception as e:
            # -1 marks a run that could not be launched or streamed.
            print(f"✗ Error during {task_config['name']} with {curriculum_config['name']} training: {e}")
            task_results[curriculum_key] = -1
        else:
            task_results[curriculum_key] = result
            if result == 0:
                print(f"✓ {task_config['name']} with {curriculum_config['name']} completed successfully")
            else:
                print(f"✗ {task_config['name']} with {curriculum_config['name']} failed with exit code {result}")

        print("-" * 40)

    training_results[task_key] = task_results

print(f"\n{'='*80}")
print("ALL TASK TRAINING COMPLETED!")
print(f"{'='*80}")

# Print comprehensive summary (exit code 0 counts as success)
print("\nTRAINING SUMMARY:")
print("=" * 40)
total_successful = 0
total_experiments = 0

for task_key, task_results in training_results.items():
    task_name = TASKS[task_key]['name']
    successful = sum(1 for result in task_results.values() if result == 0)
    total = len(task_results)
    total_successful += successful
    total_experiments += total

    print(f"\n{task_name}: {successful}/{total} completed successfully")
    for curriculum_key, result in task_results.items():
        curriculum_name = CURRICULA[curriculum_key]['name']
        status = "✓ Success" if result == 0 else "✗ Failed"
        print(f"  {curriculum_name}: {status}")

success_rate = (total_successful / total_experiments * 100) if total_experiments > 0 else 0
print(f"\nOVERALL SUCCESS RATE: {total_successful}/{total_experiments} ({success_rate:.1f}%)")

if success_rate >= 75:
    print("🎯 Training pipeline working well! Proceeding to analysis...")
elif success_rate >= 50:
    print("⚠️ Some training issues detected, but proceeding with available data...")
else:
    print("❌ Significant training issues detected. Check configurations and logs.")

## Task-Specific Data Loading and Processing

In [None]:
def load_all_curriculum_results() -> Dict[str, Dict[str, Dict[str, pd.DataFrame]]]:
    """Load curriculum results organized by task, then curriculum, then data type.

    For each task/curriculum pair this looks in
    ``./csv_logs_{task_key}_{curriculum_key}`` for ``<kind>_metrics_*.csv``.
    When several files match, it loads the most recently modified one via
    ``load_csv_data``.

    Returns:
        ``results[task_key][curriculum_key][kind]`` -> DataFrame, or ``None``
        when no file of that kind exists or it failed to load. ``kind`` is one
        of training, validation, energy, curriculum, robustness.
    """
    # Loop-invariant: which CSV files to look for, by data type.
    patterns = {
        'training': 'training_metrics_*.csv',
        'validation': 'validation_metrics_*.csv',
        'energy': 'energy_metrics_*.csv',
        'curriculum': 'curriculum_metrics_*.csv',
        'robustness': 'robustness_metrics_*.csv'
    }

    def _mtime(path_str: str) -> float:
        # A globbed file may vanish before stat(); rank it as oldest instead of crashing.
        try:
            return Path(path_str).stat().st_mtime
        except OSError:
            return 0

    results = {}

    for task_key in TASKS.keys():
        print(f"Loading {task_key} curriculum results...")
        task_data = {}

        for curriculum_key in CURRICULA.keys():
            csv_dir = Path(f"./csv_logs_{task_key}_{curriculum_key}")

            curriculum_data = {}
            for key, pattern in patterns.items():
                files = glob.glob(str(csv_dir / pattern))
                # Get most recent file, if any matched
                curriculum_data[key] = load_csv_data(Path(max(files, key=_mtime))) if files else None

            available = sum(1 for df in curriculum_data.values() if df is not None)
            print(f"  {curriculum_key}: {available}/{len(patterns)} data files found")
            task_data[curriculum_key] = curriculum_data

        results[task_key] = task_data

    return results

In [None]:
def load_all_task_results() -> Dict[str, Dict[str, Dict[str, pd.DataFrame]]]:
    """Load results organized by task, then curriculum, then data type.

    For each task/curriculum pair this looks in
    ``./csv_logs_{task_key}_{curriculum_key}`` for ``<kind>_metrics_*.csv``.
    When several files match, it loads the most recently modified one via
    ``load_csv_data``.

    Returns:
        ``results[task_key][curriculum_key][kind]`` -> DataFrame, or ``None``
        when no file of that kind exists or it failed to load. ``kind`` is one
        of training, validation, energy, curriculum, robustness.
    """
    # Loop-invariant: which CSV files to look for, by data type.
    patterns = {
        'training': 'training_metrics_*.csv',
        'validation': 'validation_metrics_*.csv',
        'energy': 'energy_metrics_*.csv',
        'curriculum': 'curriculum_metrics_*.csv',
        'robustness': 'robustness_metrics_*.csv'
    }

    def _mtime(path_str: str) -> float:
        # A globbed file may vanish before stat(); rank it as oldest instead of crashing.
        try:
            return Path(path_str).stat().st_mtime
        except OSError:
            return 0

    results = {}

    for task_key in TASKS.keys():
        print(f"Loading {task_key} results...")
        task_data = {}

        for curriculum_key in CURRICULA.keys():
            csv_dir = Path(f"./csv_logs_{task_key}_{curriculum_key}")

            curriculum_data = {}
            for key, pattern in patterns.items():
                files = glob.glob(str(csv_dir / pattern))
                # Get most recent file, if any matched
                curriculum_data[key] = load_csv_data(Path(max(files, key=_mtime))) if files else None

            available = sum(1 for df in curriculum_data.values() if df is not None)
            print(f"  {curriculum_key}: {available}/{len(patterns)} data files found")
            task_data[curriculum_key] = curriculum_data

        results[task_key] = task_data

    return results

def _accuracy_convergence_step(history: List[Tuple[int, float]], threshold_pct: float = 0.9) -> int:
    """Return the first step reaching ``threshold_pct`` of the accuracy gain, or -1.

    The gain runs from a starting level to the last evaluation. The starting
    level is the mean of the first five evaluations when at least five exist,
    otherwise the first evaluation alone. Sparse validation (e.g. 3
    evaluations) therefore still yields a step. Returns -1 for fewer than two
    evaluations or when no evaluation reaches the target.
    """
    if len(history) < 2:
        return -1
    values = [value for _, value in history]
    n_initial = 5 if len(values) >= 5 else 1
    initial_acc = np.mean(values[:n_initial])
    target_acc = initial_acc + threshold_pct * (values[-1] - initial_acc)
    for step, value in history:
        if value >= target_acc:
            return step
    return -1


def extract_task_metrics(task_key: str, task_data: Dict) -> pd.DataFrame:
    """Summarise every curriculum run of one task as one row of metrics.

    Args:
        task_key: Key into ``TASKS``. Its first ``expected_metrics`` entry is
            treated as the main accuracy metric.
        task_data: ``{curriculum_key: {kind: DataFrame or None}}``, where
            ``kind`` includes training, validation and energy.

    Returns:
        One row per curriculum with task/curriculum identifiers, loss,
        accuracy, convergence, timing and energy-landscape columns, plus
        ``accuracy_history`` (raw ``(step, value)`` pairs of the main metric).
        Accuracies are multiplied by 100. This presumably converts fractions to
        percent; TODO confirm the validation CSV stores fractions. Missing data
        yields 0.0 or -1 placeholders.
    """
    processed_data = []

    for curriculum_key, data in task_data.items():
        curriculum_config = CURRICULA[curriculum_key]
        training_df = data.get('training')
        validation_df = data.get('validation')
        energy_df = data.get('energy')

        # Extract accuracy metrics from validation CSV
        accuracy_metrics = extract_accuracy_from_validation_csv(validation_df, task_key)
        accuracy_summary = safe_get_accuracy_metrics(accuracy_metrics)

        # Main accuracy metric = first entry of expected_metrics
        main_metric = TASKS[task_key]['expected_metrics'][0]
        main_history = accuracy_metrics.get(main_metric, [])
        final_accuracy = accuracy_summary.get(f'{main_metric}_final', 0.0) * 100
        best_accuracy = accuracy_summary.get(f'{main_metric}_best', 0.0) * 100

        # Convergence step for total loss (-1 when unavailable)
        convergence_step = -1
        if training_df is not None and 'total_loss' in training_df.columns:
            convergence_step = calculate_convergence_step(training_df, 'total_loss')

        metrics = {
            'task': task_key,
            'task_name': TASKS[task_key]['name'],
            'curriculum': curriculum_key,
            'curriculum_name': curriculum_config['name'],
            'color': curriculum_config['color'],

            # Training metrics
            'final_total_loss': safe_get_final_value(training_df, 'total_loss'),
            'best_total_loss': safe_get_best_value(training_df, 'total_loss', minimize=True),
            'final_energy_loss': safe_get_final_value(training_df, 'loss_energy'),
            'final_denoise_loss': safe_get_final_value(training_df, 'loss_denoise'),

            # Key accuracy metrics
            'final_accuracy': final_accuracy,
            'best_accuracy': best_accuracy,
            'accuracy_convergence_step': _accuracy_convergence_step(main_history),

            # Additional task-specific metrics: every '<metric>_final', scaled by 100
            **{k: v * 100 for k, v in accuracy_summary.items() if k.endswith('_final')},

            # Performance metrics
            'training_time': training_df['nn_time'].mean() if training_df is not None and 'nn_time' in training_df.columns else 0.0,
            'convergence_step': convergence_step,
            'validation_samples': accuracy_summary.get(f'{main_metric}_steps', 0),

            # Energy landscape metrics (safe_get_* already default to 0.0 when data is missing)
            'final_energy_margin': safe_get_final_value(energy_df, 'energy_margin'),
            'curriculum_weight': safe_get_best_value(energy_df, 'curriculum_weight', minimize=False),

            # Accuracy time series for plotting
            'accuracy_history': main_history,
        }

        processed_data.append(metrics)

    return pd.DataFrame(processed_data)

def calculate_improvement_metrics(task_summary: pd.DataFrame) -> Dict[str, float]:
    """Compare the aggressive curriculum against the baseline for one task.

    Uses the first row whose ``curriculum`` is ``'baseline'`` and the first
    whose ``curriculum`` is ``'aggressive'``, and returns ``{}`` if either is
    missing. Otherwise the result has these keys:

    - ``total_loss``, ``accuracy_convergence_speed``, ``training_speed``,
      ``convergence_speed``: percentage reduction relative to the baseline,
      ``(baseline - aggressive) / baseline * 100``. Positive means aggressive
      is lower (better). Each is 0.0 unless both values are strictly positive.
    - ``accuracy``: aggressive minus baseline final accuracy (absolute).
    - ``accuracy_relative``: that difference as a percentage of the baseline
      accuracy, or 0.0 when the baseline accuracy is not positive.
    """
    labels = task_summary['curriculum']
    baseline_rows = task_summary[labels == 'baseline']
    aggressive_rows = task_summary[labels == 'aggressive']
    if not (len(baseline_rows) > 0 and len(aggressive_rows) > 0):
        return {}

    base = baseline_rows.iloc[0]
    aggr = aggressive_rows.iloc[0]

    def pct_reduction(column):
        # Relative drop from baseline to aggressive; only defined for positive values.
        b, a = base[column], aggr[column]
        if b > 0 and a > 0:
            return (b - a) / b * 100
        return 0.0

    total_loss = pct_reduction('final_total_loss')

    acc_delta = aggr['final_accuracy'] - base['final_accuracy']
    if base['final_accuracy'] > 0:
        acc_relative = acc_delta / base['final_accuracy'] * 100
    else:
        acc_relative = 0.0

    return {
        'total_loss': total_loss,
        'accuracy': acc_delta,
        'accuracy_relative': acc_relative,
        'accuracy_convergence_speed': pct_reduction('accuracy_convergence_step'),
        'training_speed': pct_reduction('training_time'),
        'convergence_speed': pct_reduction('convergence_step'),
    }

# Load every task/curriculum run from disk, then summarise each task and
# compare the aggressive curriculum against the baseline.
print("Loading all task-curriculum results...")
all_results = load_all_task_results()

# Per-task outputs, keyed by task key.
task_summaries = {}
task_improvements = {}

SUMMARY_DISPLAY_COLUMNS = ['curriculum_name', 'final_total_loss', 'final_accuracy', 'best_accuracy', 'accuracy_convergence_step', 'validation_samples']

print("\nProcessing task data and calculating improvements...")
for task_key, task_data in all_results.items():
    summary = extract_task_metrics(task_key, task_data)
    task_summaries[task_key] = summary
    task_improvements[task_key] = calculate_improvement_metrics(summary)

    print(f"\n{TASKS[task_key]['name']} Summary:")
    # Only show the columns that are actually present (none if no runs were found).
    display_cols = [col for col in SUMMARY_DISPLAY_COLUMNS if col in summary.columns]
    if not display_cols:
        print("  No data available yet - run training first")
    else:
        display(summary[display_cols].round(2))

    improvements = task_improvements[task_key]
    if not improvements:
        continue
    print("\nImprovements (Aggressive vs Baseline):")
    print(f"  Total Loss: {improvements.get('total_loss', 0):.1f}%")
    print(f"  Accuracy: {improvements.get('accuracy', 0):+.1f} percentage points ({improvements.get('accuracy_relative', 0):+.1f}% relative)")
    print(f"  Accuracy Convergence Speed: {improvements.get('accuracy_convergence_speed', 0):.1f}%")

print(f"\nLoaded and processed data for {len(task_summaries)} tasks")
print("Data loading and processing complete!")

## Task-Specific Visualization Functions

In [None]:
def visualize_task_comparison(task_key: str, task_data: Dict, task_summary: pd.DataFrame):
    """Create comprehensive visualizations for a single task with focus on accuracy.

    Draws a 3x2 dashboard comparing the curricula recorded for ``task_key``:

    1. smoothed training-loss curves (log scale),
    2. validation-accuracy history (the primary metric),
    3. final vs best accuracy bars, plus the aggressive-minus-baseline gain,
    4. final accuracy vs final loss scatter,
    5. steps to loss / accuracy convergence,
    6. up to three additional ``*_final`` metrics from the summary.

    Args:
        task_key: Key into the module-level ``TASKS`` config.
        task_data: Mapping ``curriculum_key -> run data``. Each value is read
            with ``.get('training')`` for a DataFrame holding ``step`` and
            ``total_loss`` columns.
        task_summary: One row per curriculum (as built by
            ``extract_task_metrics``). It must contain ``curriculum``,
            ``curriculum_name``, ``final_accuracy`` and ``color`` columns.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    task_config = TASKS[task_key]

    fig, axes = plt.subplots(3, 2, figsize=(16, 18))
    fig.suptitle(f'{task_config["name"]} - Baseline vs Aggressive Curriculum', fontsize=16, fontweight='bold')

    # 1. Training Loss Curves
    ax = axes[0, 0]
    for curriculum_key, data in task_data.items():
        training_df = data.get('training')
        if training_df is not None and 'total_loss' in training_df.columns:
            curriculum_config = CURRICULA[curriculum_key]
            # Light rolling-mean smoothing (window capped at 10) once there are enough points
            window = min(10, len(training_df) // 20) if len(training_df) > 20 else 1
            if window > 1:
                smoothed_loss = training_df['total_loss'].rolling(window=window, min_periods=1).mean()
            else:
                smoothed_loss = training_df['total_loss']

            ax.plot(training_df['step'], smoothed_loss,
                    color=curriculum_config['color'], label=curriculum_config['name'],
                    linewidth=2, alpha=0.8)

    ax.set_title('Training Loss Evolution (Smoothed)', fontweight='bold')
    ax.set_xlabel('Training Step')
    ax.set_ylabel('Total Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')  # Log scale for better loss visualization

    # 2. PRIMARY: Accuracy Evolution Over Time
    ax = axes[0, 1]
    accuracy_data_found = False

    for curriculum_key in task_data.keys():
        curriculum_config = CURRICULA[curriculum_key]
        summary_row = task_summary[task_summary['curriculum'] == curriculum_key]
        if len(summary_row) == 0 or 'accuracy_history' not in summary_row.columns:
            continue

        accuracy_history = summary_row.iloc[0]['accuracy_history']
        # The cell may hold None / a NaN placeholder, or an array whose truth
        # value is ambiguous, so test its length instead of its truthiness.
        try:
            has_history = len(accuracy_history) > 0
        except TypeError:
            has_history = False
        if not has_history:
            continue

        steps, values = zip(*accuracy_history)
        # Convert to percentage (history presumably stored as fractions in [0, 1])
        values = [v * 100 for v in values]
        ax.plot(steps, values,
                color=curriculum_config['color'], label=curriculum_config['name'],
                linewidth=3, alpha=0.9, marker='o', markersize=4)
        accuracy_data_found = True

        # Annotate the final value of each curve
        ax.annotate(f'{values[-1]:.1f}%',
                    xy=(steps[-1], values[-1]),
                    xytext=(5, 5), textcoords='offset points',
                    fontsize=10, fontweight='bold',
                    color=curriculum_config['color'])

    if accuracy_data_found:
        ax.set_title('Task Accuracy Evolution (PRIMARY METRIC)', fontweight='bold', color='darkred')
        ax.set_xlabel('Training Step')
        ax.set_ylabel('Accuracy (%)')
        # Draw the 90% reference line BEFORE building the legend (so its label
        # is included) and before fixing the y-limits (so autoscaling keeps it
        # in view even when every curve stays below 90%).
        ax.axhline(y=90, color='green', linestyle='--', alpha=0.5, label='90% Target')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_ylim(bottom=0)
    else:
        ax.text(0.5, 0.5, 'No accuracy data available\n(Run training with validation enabled)',
                ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_title('Accuracy Evolution (No Data)', fontweight='bold')

    # 3. Final Accuracy Comparison (Bar Chart)
    ax = axes[1, 0]
    curricula_names = task_summary['curriculum_name'].tolist()
    curricula_keys = task_summary['curriculum'].tolist()
    accuracies = task_summary['final_accuracy'].tolist()
    best_accuracies = task_summary['best_accuracy'].tolist() if 'best_accuracy' in task_summary.columns else accuracies
    colors = task_summary['color'].tolist()

    # Guard the empty summary: max() of an empty list raises ValueError.
    if accuracies and max(accuracies) > 0:
        x = np.arange(len(curricula_names))
        width = 0.35

        bars1 = ax.bar(x - width/2, accuracies, width, label='Final Accuracy',
                       color=colors, alpha=0.7, edgecolor='black', linewidth=2)
        bars2 = ax.bar(x + width/2, best_accuracies, width, label='Best Accuracy',
                       color=colors, alpha=0.5, edgecolor='black', linewidth=2, hatch='//')

        ax.set_title('Accuracy Comparison (TASK PERFORMANCE)', fontweight='bold', color='darkred')
        ax.set_ylabel('Accuracy (%)')
        ax.set_xlabel('Curriculum')
        ax.set_xticks(x)
        ax.set_xticklabels(curricula_names)
        top = max(max(accuracies), max(best_accuracies))
        ax.set_ylim(0, top * 1.2 if top > 0 else 100)  # headroom for value labels
        ax.legend()

        # Add value labels
        for bar1, bar2, acc, best in zip(bars1, bars2, accuracies, best_accuracies):
            ax.text(bar1.get_x() + bar1.get_width()/2, bar1.get_height() + 1,
                    f'{acc:.1f}%', ha='center', va='bottom', fontweight='bold')
            ax.text(bar2.get_x() + bar2.get_width()/2, bar2.get_height() + 1,
                    f'{best:.1f}%', ha='center', va='bottom', fontweight='bold', style='italic')

        # Gain annotation: identify the curricula by key rather than guessing
        # from display names, and only when both are present.
        if 'baseline' in curricula_keys and 'aggressive' in curricula_keys:
            baseline_idx = curricula_keys.index('baseline')
            aggressive_idx = curricula_keys.index('aggressive')

            improvement = accuracies[aggressive_idx] - accuracies[baseline_idx]
            color = 'green' if improvement > 0 else 'red'
            ax.text(0.5, 0.95, f'Accuracy Gain: {improvement:+.1f} percentage points',
                    ha='center', transform=ax.transAxes, fontweight='bold', fontsize=12,
                    bbox=dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.3))
    else:
        ax.text(0.5, 0.5, 'No accuracy data available',
                ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_title('Accuracy Comparison (No Data)', fontweight='bold')

    # 4. Accuracy vs Loss Trade-off
    ax = axes[1, 1]
    if 'final_accuracy' in task_summary.columns and 'final_total_loss' in task_summary.columns:
        for _, row in task_summary.iterrows():
            ax.scatter(row['final_total_loss'], row['final_accuracy'],
                       s=200, c=[row['color']], alpha=0.7,
                       edgecolors='black', linewidth=2,
                       label=row['curriculum_name'])
            # Add text annotation
            ax.annotate(row['curriculum_name'],
                        xy=(row['final_total_loss'], row['final_accuracy']),
                        xytext=(5, 5), textcoords='offset points',
                        fontsize=10)

        ax.set_title('Accuracy vs Loss Trade-off', fontweight='bold')
        ax.set_xlabel('Final Total Loss (Lower is Better)')
        ax.set_ylabel('Final Accuracy (%) (Higher is Better)')
        ax.set_xscale('log')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Arrow pointing toward the "better" corner (lower loss, higher accuracy)
        ax.annotate('', xy=(0.8, 0.9), xytext=(0.95, 0.75),
                    xycoords='axes fraction',
                    arrowprops=dict(arrowstyle='->', lw=2, color='green', alpha=0.5))
        ax.text(0.75, 0.95, 'Better', transform=ax.transAxes,
                fontsize=10, color='green', fontweight='bold')
    else:
        ax.text(0.5, 0.5, 'Insufficient data for trade-off analysis',
                ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_title('Accuracy vs Loss Trade-off (No Data)', fontweight='bold')

    # 5. Convergence Speed Comparison
    ax = axes[2, 0]
    conv_data = []
    if 'accuracy_convergence_step' in task_summary.columns:
        for _, row in task_summary.iterrows():
            if row['accuracy_convergence_step'] > 0:
                loss_conv = row.get('convergence_step', -1)
                conv_data.append({
                    'curriculum': row['curriculum_name'],
                    # -1 / NaN mean "never converged"; draw no bar rather than a negative one
                    'loss_conv': loss_conv if loss_conv > 0 else 0,
                    'acc_conv': row['accuracy_convergence_step'],
                    'color': row['color']
                })

    if conv_data:
        conv_df = pd.DataFrame(conv_data)
        x = np.arange(len(conv_df))
        width = 0.35

        bars1 = ax.bar(x - width/2, conv_df['loss_conv'], width, label='Loss Convergence',
                       color=conv_df['color'], alpha=0.5, edgecolor='black')
        bars2 = ax.bar(x + width/2, conv_df['acc_conv'], width, label='Accuracy Convergence',
                       color=conv_df['color'], alpha=0.7, edgecolor='black')

        ax.set_title('Convergence Speed Comparison (Lower is Faster)', fontweight='bold')
        ax.set_ylabel('Steps to Convergence')
        ax.set_xlabel('Curriculum')
        ax.set_xticks(x)
        ax.set_xticklabels(conv_df['curriculum'])
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        # Add value labels (skipped for missing / zero entries)
        for bar1, bar2, loss_c, acc_c in zip(bars1, bars2, conv_df['loss_conv'], conv_df['acc_conv']):
            if loss_c > 0:
                ax.text(bar1.get_x() + bar1.get_width()/2, bar1.get_height() + 1,
                        f'{int(loss_c)}', ha='center', va='bottom')
            if acc_c > 0:
                ax.text(bar2.get_x() + bar2.get_width()/2, bar2.get_height() + 1,
                        f'{int(acc_c)}', ha='center', va='bottom', fontweight='bold')
    else:
        ax.text(0.5, 0.5, 'No convergence data available',
                ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_title('Convergence Speed (No Data)', fontweight='bold')

    # 6. Task-Specific Metrics
    ax = axes[2, 1]
    task_specific_metrics = []

    # Get all metrics that end with '_final' except the main accuracy metric
    main_metric = TASKS[task_key]['expected_metrics'][0]
    for col in task_summary.columns:
        if col.endswith('_final') and not col.startswith(main_metric):
            metric_name = col.replace('_final', '').replace('_', ' ').title()
            task_specific_metrics.append((metric_name, col))

    metric_data = []
    for _, row in task_summary.iterrows():
        for metric_name, col in task_specific_metrics[:3]:  # Show top 3 additional metrics
            if col in row and row[col] > 0:
                metric_data.append({
                    'curriculum': row['curriculum_name'],
                    'metric': metric_name,
                    'value': row[col],
                    'color': row['color']
                })

    if metric_data:
        metric_df = pd.DataFrame(metric_data)
        metrics = metric_df['metric'].unique()
        curricula = metric_df['curriculum'].unique()
        # Take each curriculum's colour from its own summary row. Indexing
        # CURRICULA by position mismatched colours (or raised IndexError)
        # whenever the summary order/count differed from CURRICULA's.
        curriculum_colors = dict(zip(metric_df['curriculum'], metric_df['color']))

        x = np.arange(len(metrics))
        # Split a 0.8-wide slot evenly so any number of curricula is centred.
        width = 0.8 / len(curricula)

        for i, curriculum in enumerate(curricula):
            curr_data = metric_df[metric_df['curriculum'] == curriculum]
            values = []
            for m in metrics:
                matches = curr_data.loc[curr_data['metric'] == m, 'value']
                values.append(matches.iloc[0] if len(matches) > 0 else 0)
            offset = (i - (len(curricula) - 1) / 2) * width
            ax.bar(x + offset, values, width, label=curriculum,
                   color=curriculum_colors[curriculum], alpha=0.7, edgecolor='black')

        ax.set_title(f'Additional {task_config["name"]} Metrics', fontweight='bold')
        ax.set_ylabel('Value (%)')
        ax.set_xlabel('Metric')
        ax.set_xticks(x)
        ax.set_xticklabels(metrics, rotation=45, ha='right')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')
    else:
        ax.text(0.5, 0.5, 'No additional metrics available',
                ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_title('Additional Metrics (No Data)', fontweight='bold')

    plt.tight_layout()
    plt.show()

def create_cross_task_comparison(task_summaries: Dict, task_improvements: Dict):
    """Compare aggressive curriculum performance across tasks with focus on accuracy.

    Draws a 2x3 figure:

    - final accuracy per task;
    - accuracy improvement in percentage points;
    - best accuracy achieved;
    - final loss per task;
    - an accuracy-vs-loss scatter over all tasks;
    - a composite score (0.7 * accuracy pp + 0.3 * loss-improvement %).

    Args:
        task_summaries: ``task_key -> summary DataFrame`` with one row per
            curriculum (``curriculum``, ``final_total_loss`` and
            ``final_accuracy`` columns; ``best_accuracy`` optional).
        task_improvements: ``task_key -> dict`` of aggressive-vs-baseline
            improvements (``accuracy``, ``accuracy_relative``, ``total_loss``).

    Tasks missing either curriculum contribute zeros and are filtered out of
    the panels that require data. The figure is shown with ``plt.show()``.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('Cross-Task Accuracy Performance: Aggressive vs Baseline Curriculum', fontsize=16, fontweight='bold')

    # Bar width shared by every grouped-bar panel. It was previously defined
    # only inside panel 1's data branch, so panels 3/4 raised
    # UnboundLocalError whenever no task had accuracy data.
    width = 0.35

    # Extract data for comparison
    task_names = []
    baseline_losses = []
    aggressive_losses = []
    baseline_accuracies = []
    aggressive_accuracies = []
    baseline_best_acc = []
    aggressive_best_acc = []
    accuracy_improvements = []
    accuracy_improvements_rel = []

    for task_key, summary_df in task_summaries.items():
        task_names.append(TASKS[task_key]['name'])

        baseline_row = summary_df[summary_df['curriculum'] == 'baseline']
        aggressive_row = summary_df[summary_df['curriculum'] == 'aggressive']

        if len(baseline_row) > 0 and len(aggressive_row) > 0:
            baseline = baseline_row.iloc[0]
            aggressive = aggressive_row.iloc[0]

            baseline_losses.append(baseline['final_total_loss'])
            aggressive_losses.append(aggressive['final_total_loss'])
            baseline_accuracies.append(baseline['final_accuracy'])
            aggressive_accuracies.append(aggressive['final_accuracy'])
            baseline_best_acc.append(baseline.get('best_accuracy', baseline['final_accuracy']))
            aggressive_best_acc.append(aggressive.get('best_accuracy', aggressive['final_accuracy']))

            # Get improvements from pre-calculated data
            improvements = task_improvements[task_key]
            accuracy_improvements.append(improvements.get('accuracy', 0))
            accuracy_improvements_rel.append(improvements.get('accuracy_relative', 0))
        else:
            # Zero placeholders keep every list aligned with task_names.
            baseline_losses.append(0)
            aggressive_losses.append(0)
            baseline_accuracies.append(0)
            aggressive_accuracies.append(0)
            baseline_best_acc.append(0)
            aggressive_best_acc.append(0)
            accuracy_improvements.append(0)
            accuracy_improvements_rel.append(0)

    # 1. PRIMARY: Accuracy Comparison by Task
    ax = axes[0, 0]
    valid_acc_indices = [i for i, (b, a) in enumerate(zip(baseline_accuracies, aggressive_accuracies))
                         if b > 0 or a > 0]

    if valid_acc_indices:
        valid_names = [task_names[i] for i in valid_acc_indices]
        valid_baseline_acc = [baseline_accuracies[i] for i in valid_acc_indices]
        valid_aggressive_acc = [aggressive_accuracies[i] for i in valid_acc_indices]
        valid_x = np.arange(len(valid_names))

        baseline_bars = ax.bar(valid_x - width/2, valid_baseline_acc, width, label='Baseline',
                               color=CURRICULA['baseline']['color'], alpha=0.7, edgecolor='black')
        aggressive_bars = ax.bar(valid_x + width/2, valid_aggressive_acc, width, label='Aggressive',
                                 color=CURRICULA['aggressive']['color'], alpha=0.7, edgecolor='black')

        ax.set_title('Final Accuracy by Task (PRIMARY METRIC)', fontweight='bold', color='darkred')
        ax.set_ylabel('Accuracy (%)')
        ax.set_xlabel('Task')
        ax.set_xticks(valid_x)
        ax.set_xticklabels(valid_names, rotation=45, ha='right')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        # Add value labels
        for bar, val in zip(baseline_bars, valid_baseline_acc):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                    f'{val:.1f}%', ha='center', va='bottom', fontsize=9)
        for bar, val in zip(aggressive_bars, valid_aggressive_acc):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                    f'{val:.1f}%', ha='center', va='bottom', fontsize=9, fontweight='bold')
    else:
        ax.text(0.5, 0.5, 'No accuracy data available', ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.set_title('Accuracy by Task (No Data)', fontweight='bold')

    # 2. Accuracy Improvement (Percentage Points)
    ax = axes[0, 1]
    valid_improvements = [
        (name, imp)
        for idx, (name, imp) in enumerate(zip(task_names, accuracy_improvements))
        if imp != 0 or baseline_accuracies[idx] > 0 or aggressive_accuracies[idx] > 0
    ]

    if valid_improvements:
        names, improvements = zip(*valid_improvements)
        colors = ['green' if imp > 0 else 'red' for imp in improvements]
        bars = ax.bar(names, improvements, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
        ax.axhline(y=0, color='black', linestyle='-', alpha=0.5)
        ax.set_title('Accuracy Improvement (Percentage Points)', fontweight='bold', color='darkred')
        ax.set_ylabel('Improvement (pp)')
        ax.set_xlabel('Task')

        # Labels sit just above positive bars and just below negative ones.
        for bar, imp in zip(bars, improvements):
            y_pos = bar.get_height() + (0.5 if imp > 0 else -0.5)
            ax.text(bar.get_x() + bar.get_width()/2, y_pos,
                    f'{imp:+.1f}pp', ha='center', va='bottom' if imp > 0 else 'top',
                    fontweight='bold', fontsize=10)

        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.grid(True, alpha=0.3, axis='y')
    else:
        ax.text(0.5, 0.5, 'No improvement data available', ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.set_title('Accuracy Improvement (No Data)', fontweight='bold')

    # 3. Best Accuracy Achieved
    ax = axes[0, 2]
    valid_best_indices = [i for i, (b, a) in enumerate(zip(baseline_best_acc, aggressive_best_acc))
                          if b > 0 or a > 0]

    if valid_best_indices:
        valid_names = [task_names[i] for i in valid_best_indices]
        valid_baseline_best = [baseline_best_acc[i] for i in valid_best_indices]
        valid_aggressive_best = [aggressive_best_acc[i] for i in valid_best_indices]
        valid_x = np.arange(len(valid_names))

        ax.bar(valid_x - width/2, valid_baseline_best, width, label='Baseline Best',
               color=CURRICULA['baseline']['color'], alpha=0.5, edgecolor='black', hatch='//')
        ax.bar(valid_x + width/2, valid_aggressive_best, width, label='Aggressive Best',
               color=CURRICULA['aggressive']['color'], alpha=0.5, edgecolor='black', hatch='//')

        ax.set_title('Best Accuracy Achieved', fontweight='bold')
        ax.set_ylabel('Best Accuracy (%)')
        ax.set_xlabel('Task')
        ax.set_xticks(valid_x)
        ax.set_xticklabels(valid_names, rotation=45, ha='right')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')
    else:
        ax.text(0.5, 0.5, 'No best accuracy data available', ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.set_title('Best Accuracy (No Data)', fontweight='bold')

    # 4. Loss Comparison (for reference); both losses must be positive for the log scale
    ax = axes[1, 0]
    valid_indices = [i for i, (b, a) in enumerate(zip(baseline_losses, aggressive_losses)) if b > 0 and a > 0]

    if valid_indices:
        valid_names = [task_names[i] for i in valid_indices]
        valid_baseline_losses = [baseline_losses[i] for i in valid_indices]
        valid_aggressive_losses = [aggressive_losses[i] for i in valid_indices]
        valid_x = np.arange(len(valid_names))

        ax.bar(valid_x - width/2, valid_baseline_losses, width, label='Baseline',
               color=CURRICULA['baseline']['color'], alpha=0.7, edgecolor='black')
        ax.bar(valid_x + width/2, valid_aggressive_losses, width, label='Aggressive',
               color=CURRICULA['aggressive']['color'], alpha=0.7, edgecolor='black')

        ax.set_title('Total Loss by Task (Lower = Better)', fontweight='bold')
        ax.set_ylabel('Total Loss')
        ax.set_xlabel('Task')
        ax.set_xticks(valid_x)
        ax.set_xticklabels(valid_names, rotation=45, ha='right')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')
        ax.set_yscale('log')
    else:
        ax.text(0.5, 0.5, 'No valid loss data', ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.set_title('Loss by Task (No Data)', fontweight='bold')

    # 5. Accuracy vs Loss Scatter (All Tasks)
    ax = axes[1, 1]
    scatter_data = []
    for i, task_name in enumerate(task_names):
        if baseline_losses[i] > 0 and baseline_accuracies[i] >= 0:
            scatter_data.append({
                'loss': baseline_losses[i],
                'accuracy': baseline_accuracies[i],
                'task': task_name,
                'curriculum': 'Baseline',
                'color': CURRICULA['baseline']['color']
            })
        if aggressive_losses[i] > 0 and aggressive_accuracies[i] >= 0:
            scatter_data.append({
                'loss': aggressive_losses[i],
                'accuracy': aggressive_accuracies[i],
                'task': task_name,
                'curriculum': 'Aggressive',
                'color': CURRICULA['aggressive']['color']
            })

    if scatter_data:
        for item in scatter_data:
            marker = 'o' if item['curriculum'] == 'Baseline' else '^'
            ax.scatter(item['loss'], item['accuracy'], s=150, c=[item['color']],
                       marker=marker, alpha=0.7, edgecolors='black', linewidth=2)
            # Label each point with the first three characters of the task name
            ax.annotate(item['task'][:3], xy=(item['loss'], item['accuracy']),
                        xytext=(2, 2), textcoords='offset points', fontsize=8)

        ax.set_title('Accuracy vs Loss Trade-off (All Tasks)', fontweight='bold')
        ax.set_xlabel('Total Loss (Lower is Better)')
        ax.set_ylabel('Accuracy (%) (Higher is Better)')
        ax.set_xscale('log')
        ax.grid(True, alpha=0.3)

        # Scatter points carry no labels, so build the legend from proxy patches
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor=CURRICULA['baseline']['color'], alpha=0.7, label='Baseline'),
            Patch(facecolor=CURRICULA['aggressive']['color'], alpha=0.7, label='Aggressive')
        ]
        ax.legend(handles=legend_elements)
    else:
        ax.text(0.5, 0.5, 'No data for scatter plot', ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.set_title('Accuracy vs Loss (No Data)', fontweight='bold')

    # 6. Overall Summary Score
    ax = axes[1, 2]

    # Create a summary score for each task (weighted: accuracy 70%, loss reduction 30%)
    task_scores = []
    score_names = []

    for task_key, improvements in task_improvements.items():
        if any(imp != 0 for imp in improvements.values()):
            # Weighted score: accuracy improvement matters more
            score = (improvements.get('accuracy', 0) * 0.7 +
                     improvements.get('total_loss', 0) * 0.3)
            task_scores.append(score)
            score_names.append(TASKS[task_key]['name'])

    if task_scores:
        colors = ['green' if score > 0 else 'red' for score in task_scores]
        bars = ax.bar(score_names, task_scores, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
        ax.axhline(y=0, color='black', linestyle='-', alpha=0.5)
        ax.set_title('Overall Performance Score\n(70% Accuracy, 30% Loss)', fontweight='bold')
        ax.set_ylabel('Composite Score')
        ax.set_xlabel('Task')

        # Add value labels
        for bar, score in zip(bars, task_scores):
            y_pos = bar.get_height() + (0.5 if score > 0 else -0.5)
            ax.text(bar.get_x() + bar.get_width()/2, y_pos,
                    f'{score:+.1f}', ha='center', va='bottom' if score > 0 else 'top', fontweight='bold')

        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.grid(True, alpha=0.3, axis='y')

        # Add overall verdict
        avg_score = np.mean(task_scores)
        verdict_color = 'green' if avg_score > 0 else 'red'
        ax.text(0.5, 0.95, f'Average Score: {avg_score:+.1f}',
                ha='center', transform=ax.transAxes, fontweight='bold', fontsize=12,
                bbox=dict(boxstyle="round,pad=0.3", facecolor=verdict_color, alpha=0.3))
    else:
        ax.text(0.5, 0.5, 'No performance data available', ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.set_title('Overall Performance (No Data)', fontweight='bold')

    plt.tight_layout()
    plt.show()

# Render one dashboard per task, then the cross-task comparison figure.
print("Generating task-specific visualizations with accuracy focus...")
for task_key in all_results:
    task_data = all_results[task_key]
    print(f"\nGenerating visualizations for {TASKS[task_key]['name']}...")
    visualize_task_comparison(task_key, task_data, task_summaries[task_key])

print("\nGenerating cross-task comparison with accuracy focus...")
create_cross_task_comparison(task_summaries, task_improvements)

print("\nAll visualizations completed!")

## Comprehensive Analysis and Results

In [None]:
def generate_comprehensive_analysis(task_summaries: Dict, task_improvements: Dict):
    """Generate detailed analysis of curriculum effectiveness with focus on accuracy.

    Prints, in order:

    - an accuracy table covering every curriculum of every task;
    - a per-task aggressive-vs-baseline breakdown (accuracy, accuracy
      convergence, loss);
    - aggregate win rates with a deployment recommendation;
    - the experimental configuration.

    Args:
        task_summaries: ``task_key -> summary DataFrame`` with one row per
            curriculum (``curriculum``, ``curriculum_name``,
            ``final_accuracy`` and ``final_total_loss`` columns; the
            ``best_accuracy``, ``accuracy_convergence_step`` and
            ``validation_samples`` columns are optional).
        task_improvements: ``task_key -> dict`` of aggressive-vs-baseline
            improvements as produced by ``calculate_improvement_metrics``:
            ``accuracy`` (percentage points), ``accuracy_relative`` (%),
            ``total_loss`` (%), ``accuracy_convergence_speed`` (%).

    Returns:
        The accuracy table as a DataFrame of pre-formatted strings.
    """

    print("\n" + "="*80)
    print("COMPREHENSIVE TASK-CURRICULUM ANALYSIS: ACCURACY FOCUS")
    print("="*80)

    print("\n🎯 PRIMARY METRIC: TASK ACCURACY PERFORMANCE")
    print("-" * 50)

    # Create accuracy-focused comparison table; non-positive values render as 'N/A'
    accuracy_results = []
    for task_key, summary_df in task_summaries.items():
        for _, row in summary_df.iterrows():
            result = {
                'Task': TASKS[task_key]['name'],
                'Curriculum': row['curriculum_name'],
                'Final Accuracy (%)': f"{row['final_accuracy']:.1f}" if row['final_accuracy'] > 0 else 'N/A',
                'Best Accuracy (%)': f"{row.get('best_accuracy', 0):.1f}" if row.get('best_accuracy', 0) > 0 else 'N/A',
                'Acc. Conv. Step': f"{int(row.get('accuracy_convergence_step', -1))}" if row.get('accuracy_convergence_step', -1) > 0 else 'N/A',
                'Validation Samples': f"{int(row.get('validation_samples', 0))}" if row.get('validation_samples', 0) > 0 else 'N/A',
                'Final Loss': f"{row['final_total_loss']:.4f}" if row['final_total_loss'] > 0 else 'N/A'
            }
            accuracy_results.append(result)

    accuracy_df = pd.DataFrame(accuracy_results)
    print(accuracy_df.to_string(index=False))

    print("\n📊 ACCURACY IMPROVEMENT ANALYSIS")
    print("-" * 50)

    # Detailed accuracy analysis for each task
    total_tasks = 0
    successful_accuracy_improvements = 0
    successful_loss_improvements = 0
    accuracy_gains = []  # accuracy change (pp) for every task that has accuracy data

    for task_key, summary_df in task_summaries.items():
        task_name = TASKS[task_key]['name']
        print(f"\n{task_name}:")
        print(f"  Task Type: {TASKS[task_key]['description']}")

        baseline_row = summary_df[summary_df['curriculum'] == 'baseline']
        aggressive_row = summary_df[summary_df['curriculum'] == 'aggressive']

        if len(baseline_row) == 0 or len(aggressive_row) == 0:
            print(f"  ⚠️  Missing training results for comparison")
            continue

        total_tasks += 1
        baseline = baseline_row.iloc[0]
        aggressive = aggressive_row.iloc[0]
        improvements = task_improvements[task_key]

        # PRIMARY: Accuracy Analysis
        accuracy_improvement = improvements.get('accuracy', 0)
        accuracy_improvement_rel = improvements.get('accuracy_relative', 0)

        print(f"\n  🎯 ACCURACY PERFORMANCE:")
        if baseline['final_accuracy'] > 0 or aggressive['final_accuracy'] > 0:
            print(f"     Baseline: {baseline['final_accuracy']:.1f}%")
            print(f"     Aggressive: {aggressive['final_accuracy']:.1f}%")

            # Record every measured change, ties included. Previously ties were
            # dropped, which biased the average/best/worst statistics below.
            accuracy_gains.append(accuracy_improvement)
            if accuracy_improvement > 0:
                successful_accuracy_improvements += 1
                print(f"     ✅ IMPROVEMENT: +{accuracy_improvement:.1f} percentage points ({accuracy_improvement_rel:+.1f}% relative)")
            elif accuracy_improvement < 0:
                print(f"     ❌ REGRESSION: {accuracy_improvement:.1f} percentage points ({accuracy_improvement_rel:.1f}% relative)")
            else:
                print(f"     ➖ NO CHANGE")

            # Best accuracy comparison (`in` on a row Series checks its index labels)
            if 'best_accuracy' in baseline and 'best_accuracy' in aggressive:
                best_improvement = aggressive['best_accuracy'] - baseline['best_accuracy']
                print(f"     Best Achieved: Baseline {baseline['best_accuracy']:.1f}% → Aggressive {aggressive['best_accuracy']:.1f}% ({best_improvement:+.1f}pp)")
        else:
            print(f"     ⚠️  No accuracy data available")

        # Accuracy convergence analysis
        if baseline.get('accuracy_convergence_step', -1) > 0 and aggressive.get('accuracy_convergence_step', -1) > 0:
            conv_improvement = improvements.get('accuracy_convergence_speed', 0)
            print(f"\n  ⏱️  ACCURACY CONVERGENCE:")
            print(f"     Baseline: {int(baseline['accuracy_convergence_step'])} steps")
            print(f"     Aggressive: {int(aggressive['accuracy_convergence_step'])} steps")
            if conv_improvement > 0:
                print(f"     ✅ {conv_improvement:.1f}% faster convergence")
            elif conv_improvement < 0:
                print(f"     ❌ {abs(conv_improvement):.1f}% slower convergence")

        # Secondary: Loss Analysis (for reference)
        loss_improvement = improvements.get('total_loss', 0)
        if loss_improvement != 0:
            if loss_improvement > 0:
                successful_loss_improvements += 1
            print(f"\n  📉 LOSS (Reference):")
            print(f"     Baseline: {baseline['final_total_loss']:.4f}")
            print(f"     Aggressive: {aggressive['final_total_loss']:.4f}")
            print(f"     {'✅' if loss_improvement > 0 else '❌'} Change: {loss_improvement:+.1f}%")

    # Overall accuracy-focused summary
    print("\n🏆 ACCURACY-FOCUSED SUMMARY")
    print("-" * 50)

    if total_tasks > 0:
        accuracy_success_rate = successful_accuracy_improvements / total_tasks * 100
        loss_success_rate = successful_loss_improvements / total_tasks * 100
        # Always bound: it is read below even when no task had accuracy data
        # (previously an UnboundLocalError in that case).
        avg_accuracy_gain = float(np.mean(accuracy_gains)) if accuracy_gains else 0.0

        print(f"📈 ACCURACY IMPROVEMENT STATISTICS:")
        print(f"   • Tasks with accuracy gains: {successful_accuracy_improvements}/{total_tasks} ({accuracy_success_rate:.0f}%)")

        if accuracy_gains:
            print(f"   • Average accuracy change: {avg_accuracy_gain:+.1f} percentage points")
            print(f"   • Best improvement: {max(accuracy_gains):+.1f} percentage points")
            print(f"   • Worst change: {min(accuracy_gains):+.1f} percentage points")

        print(f"\n📉 LOSS REDUCTION (Reference):")
        print(f"   • Tasks with loss reduction: {successful_loss_improvements}/{total_tasks} ({loss_success_rate:.0f}%)")

        # Task-specific insights
        print("\n🔍 TASK-SPECIFIC INSIGHTS:")
        for task_key, improvements in task_improvements.items():
            task_name = TASKS[task_key]['name']
            acc_imp = improvements.get('accuracy', 0)

            if task_key == 'sudoku':
                if acc_imp > 0:
                    print(f"   • {task_name}: Curriculum helps with constraint satisfaction (+{acc_imp:.1f}pp)")
                elif acc_imp < 0:
                    print(f"   • {task_name}: Baseline more stable for puzzle solving ({acc_imp:.1f}pp)")
                else:
                    print(f"   • {task_name}: No significant difference detected")

            elif task_key == 'connectivity':
                if acc_imp > 0:
                    print(f"   • {task_name}: Curriculum improves graph reasoning (+{acc_imp:.1f}pp)")
                elif acc_imp < 0:
                    print(f"   • {task_name}: Baseline better for connectivity detection ({acc_imp:.1f}pp)")
                else:
                    print(f"   • {task_name}: Similar performance across curricula")

        # Final verdict based on accuracy win rate
        print("\n🎯 ACCURACY-BASED RECOMMENDATION:")
        print("-" * 30)

        if accuracy_success_rate >= 75:
            print("✅ STRONG EVIDENCE: Aggressive curriculum significantly improves task accuracy")
            print(f"   → Deploy aggressive curriculum (Accuracy gains in {accuracy_success_rate:.0f}% of tasks)")
        elif accuracy_success_rate >= 50:
            print("⚖️  MIXED EVIDENCE: Aggressive curriculum shows moderate accuracy benefits")
            print(f"   → Consider task-specific deployment (Accuracy gains in {accuracy_success_rate:.0f}% of tasks)")
        else:
            print("❌ INSUFFICIENT EVIDENCE: Baseline performs better or equally on accuracy")
            print(f"   → Stick with baseline training (Accuracy gains in only {accuracy_success_rate:.0f}% of tasks)")

        # Additional recommendations
        print("\n💡 KEY OBSERVATIONS:")
        if avg_accuracy_gain > 5:
            print("   • Substantial accuracy improvements observed (+5pp average)")
        elif avg_accuracy_gain > 0:
            print("   • Modest accuracy improvements observed")
        else:
            print("   • No clear accuracy advantage for aggressive curriculum")

        if any(imp.get('accuracy_convergence_speed', 0) > 20 for imp in task_improvements.values()):
            print("   • Aggressive curriculum shows faster accuracy convergence in some tasks")

    else:
        print("⚠️  Insufficient data for accuracy assessment")

    def _per_task_setting(field, default=None):
        """Return the shared value if all tasks agree, else 'Name=value, ...'."""
        per_task = {cfg['name']: cfg.get(field, default) for cfg in TASKS.values()}
        distinct = set(per_task.values())
        if len(distinct) == 1:
            return distinct.pop()
        return ", ".join(f"{name}={value}" for name, value in per_task.items())

    # Experimental details. Tasks may configure different step counts, so
    # report every task's value instead of only the first task's.
    print(f"\n📋 EXPERIMENTAL CONFIGURATION:")
    print(f"   • Tasks evaluated: {', '.join([TASKS[k]['name'] for k in TASKS.keys()])}")
    print(f"   • Primary metric: Task-specific accuracy (%)")
    print(f"   • Training steps: {_per_task_setting('train_steps')} per task")
    print(f"   • Validation frequency: Every {_per_task_setting('validation_interval', 10)} steps")
    print(f"   • Diffusion steps: {COMMON_ARGS['diffusion_steps']}")

    return accuracy_df

def create_accuracy_summary_report(task_summaries: Dict, task_improvements: Dict):
    """Print a compact baseline-vs-aggressive accuracy comparison.

    Tasks are visited in ``TASKS`` order. A task contributes a table row only
    if it appears in both ``task_summaries`` and ``task_improvements`` and its
    summary frame contains at least one ``'baseline'`` row and one
    ``'aggressive'`` row. Each row's winner is decided by the sign of the
    task's ``'accuracy'`` improvement (a missing entry counts as 0, i.e. a
    tie). A win tally and an overall verdict are printed after the table.

    Args:
        task_summaries: Task key -> DataFrame with ``'curriculum'`` and
            ``'final_accuracy'`` columns; the first matching row per
            curriculum is used.
        task_improvements: Task key -> mapping of improvement values; only
            the ``'accuracy'`` entry is read.

    Returns:
        None. The report is written to stdout.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("📊 ACCURACY PERFORMANCE REPORT")
    print(rule)

    table_rows = []
    aggressive_tally = 0
    baseline_tally = 0

    for key in TASKS:
        if key not in task_summaries or key not in task_improvements:
            continue

        frame = task_summaries[key]
        task_gain = task_improvements[key]
        curriculum_col = frame['curriculum']
        base_rows = frame[curriculum_col == 'baseline']
        aggr_rows = frame[curriculum_col == 'aggressive']
        # Both curricula must be present to compare them.
        if base_rows.empty or aggr_rows.empty:
            continue

        delta = task_gain.get('accuracy', 0)
        # Positive delta -> aggressive wins, negative -> baseline wins, zero -> tie.
        if delta > 0:
            label = '🏆 Aggressive'
            aggressive_tally += 1
        elif delta < 0:
            label = '🏆 Baseline'
            baseline_tally += 1
        else:
            label = '🤝 Tie'

        table_rows.append({
            'Task': TASKS[key]['name'],
            'Baseline Acc (%)': f"{base_rows.iloc[0]['final_accuracy']:.1f}",
            'Aggressive Acc (%)': f"{aggr_rows.iloc[0]['final_accuracy']:.1f}",
            'Improvement (pp)': f"{delta:+.1f}",
            'Winner': label,
        })

    if not table_rows:
        print("No comparison data available")
    else:
        n_tasks = len(table_rows)
        print("\nACCURACY COMPARISON TABLE:")
        print(pd.DataFrame(table_rows).to_string(index=False))

        print("\n📈 OVERALL ACCURACY VERDICT:")
        print(f"   Aggressive wins: {aggressive_tally}/{n_tasks} tasks")
        print(f"   Baseline wins: {baseline_tally}/{n_tasks} tasks")

        if aggressive_tally > baseline_tally:
            print(f"\n   🏆 WINNER: Aggressive Curriculum (Better accuracy in {aggressive_tally}/{n_tasks} tasks)")
        elif baseline_tally > aggressive_tally:
            print(f"\n   🏆 WINNER: Baseline Training (Better accuracy in {baseline_tally}/{n_tasks} tasks)")
        else:
            print("\n   🤝 TIE: Equal performance across tasks")

    print("\n" + rule)

# Run both accuracy reports over the same task summaries and improvements:
# first the detailed analysis (its return value is kept in `accuracy_results`),
# then the compact comparison table with its overall verdict.
print("Generating comprehensive accuracy-focused analysis...")
accuracy_results = generate_comprehensive_analysis(
    task_summaries,
    task_improvements,
)

print("\nGenerating accuracy summary report...")
create_accuracy_summary_report(
    task_summaries=task_summaries,
    task_improvements=task_improvements,
)

print("\nComprehensive accuracy analysis completed!")

## Results Export and Final Summary

In [None]:
# Export comprehensive results.
# Writes the per-task summaries, unified results, per-metric improvements and
# training exit codes to CSV, then prints an executive summary. Depends on
# notebook state from earlier cells: TASKS, CURRICULA, COMMON_ARGS,
# task_summaries, task_improvements, training_results and unified_results.
output_dir = Path('./ired_task_curriculum_results')
output_dir.mkdir(exist_ok=True)

# Save task-specific summaries
for task_key, summary_df in task_summaries.items():
    summary_df.to_csv(output_dir / f'{task_key}_detailed_summary.csv', index=False)

# Save unified results
unified_results.to_csv(output_dir / 'unified_task_curriculum_results.csv', index=False)

# Save improvements summary: one row per (task, metric) pair.
# NOTE(review): every metric is written under 'Improvement_Percent', yet the
# accuracy comparison table labels the 'accuracy' improvement in percentage
# points (pp) -- confirm the units before relying on this column.
improvements_data = [
    {
        'Task': TASKS[task_key]['name'],
        'Task_Key': task_key,
        'Metric': metric.replace('_', ' ').title(),
        'Improvement_Percent': value,
    }
    for task_key, improvements in task_improvements.items()
    for metric, value in improvements.items()
]

improvements_df = pd.DataFrame(improvements_data)
improvements_df.to_csv(output_dir / 'curriculum_improvements_summary.csv', index=False)

# Create training results summary: one row per (task, curriculum) run.
# An exit code of 0 is recorded as a successful run.
training_summary = [
    {
        'Task': TASKS[task_key]['name'],
        'Task_Key': task_key,
        'Curriculum': CURRICULA[curriculum_key]['name'],
        'Curriculum_Key': curriculum_key,
        'Exit_Code': result_code,
        'Success': result_code == 0,
        'Training_Steps': TASKS[task_key]['train_steps'],
        'Model': TASKS[task_key]['model'],
        'Dataset': TASKS[task_key]['dataset'],
    }
    for task_key, task_results in training_results.items()
    for curriculum_key, result_code in task_results.items()
]

training_df = pd.DataFrame(training_summary)
training_df.to_csv(output_dir / 'training_execution_summary.csv', index=False)

print("\n" + "="*80)
print("RESULTS EXPORT AND FINAL SUMMARY")
print("="*80)

print(f"\n📁 Results saved to: {output_dir}")
print("   📊 Data Files:")
# List only the per-task files actually written above (in TASKS order);
# tasks without a summary frame produced no CSV.
for task_key in TASKS:
    if task_key in task_summaries:
        print(f"     • {task_key}_detailed_summary.csv - {TASKS[task_key]['name']} detailed metrics")
print("     • unified_task_curriculum_results.csv - Combined results across all tasks")
print("     • curriculum_improvements_summary.csv - Improvement percentages by metric")
print("     • training_execution_summary.csv - Training execution status")

# Generate final executive summary.
# A task counts as fully completed only if every recorded run exited with 0.
successful_tasks = sum(
    1
    for task_results in training_results.values()
    if all(r == 0 for r in task_results.values())
)
total_experiments = sum(len(task_results) for task_results in training_results.values())
successful_experiments = sum(
    1
    for task_results in training_results.values()
    for r in task_results.values()
    if r == 0
)
# Guard against division by zero when no training runs were recorded.
completion_rate_pct = (
    successful_experiments / total_experiments * 100 if total_experiments else 0.0
)

print("\n" + "="*80)
print("🧠 IRED TASK CURRICULUM ANALYSIS - EXECUTIVE SUMMARY 🧠")
print("="*80)

print("\n🎯 EXPERIMENTAL OVERVIEW:")
print(f"   • Tasks evaluated: {len(TASKS)} IRED reasoning tasks")
print(f"   • Curricula compared: {len(CURRICULA)} approaches (Baseline vs Aggressive)")
print(f"   • Total experiments: {total_experiments}")
print(f"   • Successful completions: {successful_experiments}/{total_experiments} ({completion_rate_pct:.1f}%)")
print(f"   • Fully completed tasks: {successful_tasks}/{len(TASKS)}")

# Calculate key findings.
# Zero-valued entries are excluded from the statistics below, and a missing
# metric is treated as 0 (excluded) rather than raising KeyError.
# NOTE(review): presumably 0 marks a metric that could not be computed --
# confirm against the code that builds task_improvements.
loss_improvements = [
    imp.get('total_loss', 0)
    for imp in task_improvements.values()
    if imp.get('total_loss', 0) != 0
]

accuracy_improvements = [
    imp.get('accuracy', 0)
    for imp in task_improvements.values()
    if imp.get('accuracy', 0) != 0
]

if loss_improvements:
    avg_loss_improvement = np.mean(loss_improvements)
    positive_loss_improvements = sum(1 for imp in loss_improvements if imp > 0)
    print("\n📉 TRAINING LOSS FINDINGS:")
    print(f"   • Average improvement: {avg_loss_improvement:+.1f}%")
    print(f"   • Tasks with improvement: {positive_loss_improvements}/{len(loss_improvements)}")
    print(f"   • Success rate: {positive_loss_improvements/len(loss_improvements)*100:.1f}%")

if accuracy_improvements:
    avg_accuracy_improvement = np.mean(accuracy_improvements)
    positive_accuracy_improvements = sum(1 for imp in accuracy_improvements if imp > 0)
    print("\n📈 ACCURACY FINDINGS:")
    # Reported in percentage points, matching the 'Improvement (pp)' column that
    # the accuracy comparison table uses for the same 'accuracy' value.
    print(f"   • Average improvement: {avg_accuracy_improvement:+.1f}pp")
    print(f"   • Tasks with improvement: {positive_accuracy_improvements}/{len(accuracy_improvements)}")
    print(f"   • Success rate: {positive_accuracy_improvements/len(accuracy_improvements)*100:.1f}%")

# Final recommendation: average, over the metrics that have data, the fraction
# of tasks whose improvement value for that metric is positive.
print("\n🏆 FINAL RECOMMENDATION:")
overall_success_indicators = []
if loss_improvements:
    loss_success_rate = sum(1 for imp in loss_improvements if imp > 0) / len(loss_improvements)
    overall_success_indicators.append(loss_success_rate)
if accuracy_improvements:
    accuracy_success_rate = sum(1 for imp in accuracy_improvements if imp > 0) / len(accuracy_improvements)
    overall_success_indicators.append(accuracy_success_rate)

if overall_success_indicators:
    overall_success_rate = np.mean(overall_success_indicators)
    if overall_success_rate >= 0.5:
        print("   ✅ DEPLOY AGGRESSIVE CURRICULUM: Shows consistent improvements")
    else:
        print("   ❌ STICK WITH BASELINE: Aggressive curriculum does not consistently improve performance")
    print(f"      Success rate: {overall_success_rate*100:.0f}% of metrics improved")
else:
    print("   ⚠️  INCONCLUSIVE: Insufficient data for reliable recommendation")
    print("      Consider longer training or additional validation metrics")

print("\n📋 TECHNICAL DETAILS:")
# Distinct step budgets are listed so tasks with different training lengths are
# all reported (a single value prints when every task uses the same budget).
train_steps_values = sorted({cfg['train_steps'] for cfg in TASKS.values()})
print(f"   • Training steps per task: {', '.join(str(s) for s in train_steps_values)}")
print(f"   • Diffusion steps: {COMMON_ARGS['diffusion_steps']}")
# NOTE(review): the accuracy analysis reports validation frequency from
# TASKS[...]['validation_interval'] (default 10), while this line reads
# COMMON_ARGS['csv_log_interval'] -- confirm which one is authoritative.
print(f"   • Validation interval: Every {COMMON_ARGS['csv_log_interval']} steps")
print("   • Key metrics: Total Loss, Task Accuracy")
# Sorted so the model list prints in the same order on every run.
print(f"   • Models tested: {', '.join(sorted(set(task['model'] for task in TASKS.values())))}")

print("\n📚 TASKS EVALUATED:")
for task_key, task_config in TASKS.items():
    # A task or curriculum with no recorded exit code counts as a failed run
    # instead of raising KeyError.
    task_runs = training_results.get(task_key, {})
    success = all(task_runs.get(curr) == 0 for curr in CURRICULA)
    status = "✅" if success else "❌"
    print(f"   {status} {task_config['name']}: {task_config['description']}")

print("\n" + "="*80)
print("🚀 IRED TASK CURRICULUM COMPARISON COMPLETED SUCCESSFULLY! 🚀")
print("="*80)

print("\n💡 NEXT STEPS:")
print("   1. Review detailed task-specific visualizations above")
print("   2. Examine CSV files for numerical analysis")
print("   3. Consider extending to OOD (out-of-distribution) versions if aggressive shows promise")
print("   4. Validate findings with longer training runs if needed")
print("   5. Test on additional IRED tasks if curriculum shows clear benefits")