# IRED Tasks: Baseline vs Aggressive Curriculum Comparison

Comparative analysis of baseline training vs aggressive adversarial curriculum on core IRED reasoning tasks:
- **Connectivity**: Graph connectivity reasoning (12x12 graphs, 0.1 edge probability)
- **Sudoku**: Sudoku completion task (standard difficulty)

This notebook evaluates whether aggressive curriculum learning lowers identity error on these reasoning tasks within a limited number of training steps.

**Key Research Question**: Does aggressive curriculum learning outperform baseline training on identity error metrics for IRED reasoning tasks?

**Tasks Evaluated**: The easier (non-OOD) versions of the connectivity and sudoku tasks, as described in the IRED paper by Yilun Du.

In [None]:
# Clone repository and setup
# Remove any previous checkout first so the clone below starts from a clean directory.
!rm -rf energy-based-model
!git clone https://github.com/mdkrasnow/energy-based-model.git
# Switch the notebook's working directory into the checkout: later cells use
# relative paths such as `python train.py` and `./csv_logs_*`.
%cd energy-based-model

In [None]:
# Install dependencies
# `-q` keeps pip's output quiet so the cell log stays short.
!pip install -q torch torchvision einops accelerate tqdm tabulate matplotlib numpy pandas ema-pytorch ipdb seaborn scikit-learn

In [None]:
# Import required libraries
import subprocess
import sys
import glob
import os
from pathlib import Path
from typing import Dict, Optional, List, Tuple
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from math import pi

# Set plotting style
# Start from matplotlib's default style, then switch seaborn's colour cycle to "husl".
# Both calls change global plotting state, so they affect every figure drawn later.
plt.style.use('default')
sns.set_palette("husl")

# Simple marker that the import cell ran to completion.
print("All libraries imported successfully!")

## Task and Curriculum Configuration

In [None]:
# Task-specific configurations based on IRED paper
#
# One entry per task, keyed by the short task id used in log-directory names.
# Keys read by the code in this notebook:
#   name, description  - human-readable labels for console output and plot titles
#   dataset, model     - forwarded to train.py as --dataset / --model
#   batch_size         - forwarded as --batch_size
#   train_steps        - forwarded as --train-num-steps
#   extra_args         - extra CLI tokens appended verbatim to the training command
# NOTE(review): 'color' and 'expected_metric' are not read by any code visible in
# this part of the notebook (plots use the curriculum colours) - confirm they are
# still needed.
TASKS = {
    'connectivity': {
        'name': 'Graph Connectivity',
        'description': 'Graph connectivity reasoning (12x12, p=0.1)',
        'dataset': 'connectivity',
        'model': 'gnn-conv-1d-v2',
        'batch_size': 64,  # From train.py validation_batch_size
        'train_steps': 150,  # Longer for reasoning tasks
        'extra_args': [],
        'color': '#2E8B57',
        'expected_metric': 'identity_error'
    },
    'sudoku': {
        'name': 'Sudoku Completion', 
        'description': 'Sudoku puzzle completion task',
        'dataset': 'sudoku',
        'model': 'sudoku',
        'batch_size': 64,  # From train.py validation_batch_size  
        'train_steps': 150,
        'extra_args': ['--cond_mask', 'True'],  # Required for sudoku
        'color': '#4169E1',
        'expected_metric': 'sudoku_accuracy'
    }
}

# Only 2 curricula to compare: baseline vs aggressive
#
# Keys read by the code in this notebook:
#   name, description - labels for console output, legends and summary tables
#   args              - CLI tokens appended verbatim to the training command
#   color             - line/bar colour used for this curriculum in the plots
# The dictionary keys ('baseline', 'aggressive') are also matched literally when
# improvements are computed, so renaming them requires updating that code too.
CURRICULA = {
    'baseline': {
        'name': 'Baseline',
        'description': 'No curriculum learning',
        'args': ['--disable-curriculum', 'True'],
        'color': '#1f77b4'
    },
    'aggressive': {
        'name': 'Aggressive Curriculum', 
        'description': 'Rapid adversarial escalation',
        'args': ['--curriculum-config', 'aggressive'],
        'color': '#d62728'
    }
}

# Common training parameters
# Shared by every task/curriculum run; the values are interpolated into the
# train.py command line.
# NOTE(review): each task trains for 150 steps, so a csv_log_interval of 100
# presumably yields only a couple of logged rows per run. calculate_convergence_step
# returns -1 for logs with fewer than 5 rows - confirm train.py's logging cadence
# or lower the interval if convergence steps are wanted.
COMMON_ARGS = {
    'diffusion_steps': 10,
    'supervise_energy_landscape': 'True',  # passed as the literal string 'True' on the CLI
    'save_csv_logs': True,
    'csv_log_interval': 100
}

# Echo the experiment matrix so the run log records what was configured.
print(f"Testing {len(TASKS)} IRED reasoning tasks:")
for config in TASKS.values():
    print(f"  • {config['name']}: {config['description']}")
    print(f"    Model: {config['model']}, Steps: {config['train_steps']}")

print(f"\nComparing {len(CURRICULA)} curriculum approaches:")
for config in CURRICULA.values():
    print(f"  • {config['name']}: {config['description']}")

## Training Command Builder and Utility Functions

In [None]:
def build_task_training_command(task_key: str, curriculum_key: str) -> str:
    """Build the train.py shell command for one task/curriculum combination.

    Args:
        task_key: Key into ``TASKS`` selecting dataset, model, batch size,
            step count and any task-specific extra CLI tokens.
        curriculum_key: Key into ``CURRICULA`` selecting the curriculum CLI tokens.

    Returns:
        A single-line command string. CSV logs are directed to
        ``./csv_logs_<task_key>_<curriculum_key>``.

    Raises:
        KeyError: If either key (or a required config field) is missing.
    """
    task = TASKS[task_key]
    curriculum = CURRICULA[curriculum_key]

    parts = [
        "python train.py",
        f"--dataset {task['dataset']}",
        f"--model {task['model']}",
        f"--batch_size {task['batch_size']}",
        f"--diffusion_steps {COMMON_ARGS['diffusion_steps']}",
        f"--supervise-energy-landscape {COMMON_ARGS['supervise_energy_landscape']}",
        f"--train-num-steps {task['train_steps']}",
    ]

    # Honour the shared config switch instead of always forcing CSV logging on.
    # Defaults to True so a config without the key keeps the previous behaviour.
    if COMMON_ARGS.get('save_csv_logs', True):
        parts.append("--save-csv-logs")

    parts.append(f"--csv-log-interval {COMMON_ARGS['csv_log_interval']}")
    parts.append(f"--csv-log-dir ./csv_logs_{task_key}_{curriculum_key}")

    # Task-specific tokens first, then curriculum tokens; a missing or None
    # list simply contributes nothing.
    parts.extend(task.get('extra_args') or ())
    parts.extend(curriculum.get('args') or ())

    # Parts are joined with a run of spaces; the shell treats any run of
    # whitespace as a single separator.
    return (" " * 9).join(parts)

def load_csv_data(csv_path: Path) -> Optional[pd.DataFrame]:
    """Read ``csv_path`` into a DataFrame, returning ``None`` on any failure.

    A missing file produces a warning message; any exception raised while
    checking or parsing the file produces an error message. Neither case raises.
    """
    try:
        if not csv_path.exists():
            print(f"Warning: {csv_path} not found")
            return None
        return pd.read_csv(csv_path)
    except Exception as exc:
        # Best-effort loader: report the problem and let the caller treat the
        # data as unavailable.
        print(f"Error loading {csv_path}: {exc}")
        return None

def safe_get_final_value(df: pd.DataFrame, column: str, default: float = 0.0) -> float:
    """Return the last entry of ``df[column]`` as a float.

    ``default`` is returned when ``df`` is ``None``, lacks ``column``, or has no rows.
    """
    has_data = df is not None and column in df.columns and len(df) != 0
    if not has_data:
        return default
    last_entry = df[column].iloc[-1]
    return float(last_entry)

def safe_get_best_value(df: pd.DataFrame, column: str, minimize: bool = True, default: float = 0.0) -> float:
    """Return the best entry of ``df[column]`` as a float.

    "Best" is the minimum when ``minimize`` is true and the maximum otherwise.
    ``default`` is returned when ``df`` is ``None``, lacks ``column``, or has no rows.
    """
    has_data = df is not None and column in df.columns and len(df) != 0
    if not has_data:
        return default
    series = df[column]
    if minimize:
        best = series.min()
    else:
        best = series.max()
    return float(best)

def calculate_convergence_step(df: pd.DataFrame, column: str, threshold_pct: float = 0.9) -> int:
    """Return the first logged step at which ``column`` has made most of its final improvement.

    The initial and final levels are the means of the first and last five rows.
    The target is the initial level minus ``threshold_pct`` of the drop between
    them; the result is the ``step`` of the first row at or below that target.

    Args:
        df: Log table containing ``column`` and a ``step`` column.
        column: Name of the metric column (expected to decrease over training).
        threshold_pct: Fraction of the initial-to-final drop that must be reached.

    Returns:
        The matching step, the last logged step if no row reaches the target,
        or -1 when the input is unusable (``None``, missing ``column`` or
        ``step`` column, or fewer than five rows).
    """
    # A missing 'step' column is treated like any other unusable input rather
    # than surfacing as a KeyError further down.
    if (
        df is None
        or column not in df.columns
        or 'step' not in df.columns
        or len(df) < 5
    ):
        return -1

    values = df[column]
    initial_val = values.iloc[:5].mean()
    final_val = values.iloc[-5:].mean()
    target_val = initial_val - threshold_pct * (initial_val - final_val)

    below_target = df[values <= target_val]
    if not below_target.empty:
        return int(below_target['step'].iloc[0])
    # No row reached the target (e.g. NaN-only comparisons): report the last step.
    return int(df['step'].iloc[-1])

print("Utility functions defined successfully!")
print("\nExample command structure:")
# Sanity check: show only the first 200 characters of one generated command.
print(build_task_training_command('connectivity', 'baseline')[:200] + "...")

## Multi-Task Multi-Curriculum Training

In [None]:
# Train all task-curriculum combinations
# training_results maps task key -> {curriculum key -> process exit code};
# -1 marks a run whose launch or output streaming raised an exception.
training_results = {}

for task_key, task_config in TASKS.items():
    print(f"\n{'='*80}")
    print(f"TRAINING TASK: {task_config['name']}")
    print(f"Description: {task_config['description']}")
    print(f"Dataset: {task_config['dataset']}, Model: {task_config['model']}")
    print(f"Training Steps: {task_config['train_steps']}, Batch Size: {task_config['batch_size']}")
    print(f"{'='*80}")

    task_results = {}

    for curriculum_key, curriculum_config in CURRICULA.items():
        print(f"\n{'-'*60}")
        print(f"Starting {task_config['name']} with {curriculum_config['name']}")
        print(f"Curriculum Description: {curriculum_config['description']}")
        print(f"{'-'*60}")

        # Build training command
        cmd = build_task_training_command(task_key, curriculum_key)
        print(f"\nCommand: {cmd}")
        print("\nTraining output:")
        print("-" * 40)

        # Execute training with real-time output
        try:
            # The context manager closes the stdout pipe and reaps the child even
            # if streaming is interrupted, so no file descriptor is left open.
            # stderr is merged into stdout so both streams appear in order.
            with subprocess.Popen(
                cmd,
                shell=True,  # cmd is built from the notebook's own config, not external input
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                universal_newlines=True,
                bufsize=1  # line-buffered so output appears as it is produced
            ) as process:
                # readline returns '' only at end-of-stream, which ends the loop.
                for line in iter(process.stdout.readline, ''):
                    print(line.rstrip())
                    sys.stdout.flush()

                # Wait for process to complete
                result = process.wait()

            task_results[curriculum_key] = result

            if result == 0:
                print(f"✓ {task_config['name']} with {curriculum_config['name']} completed successfully")
            else:
                print(f"✗ {task_config['name']} with {curriculum_config['name']} failed with exit code {result}")

        except Exception as e:
            # Keep going with the remaining combinations; record the failure.
            print(f"✗ Error during {task_config['name']} with {curriculum_config['name']} training: {e}")
            task_results[curriculum_key] = -1

        print("-" * 40)

    training_results[task_key] = task_results

print(f"\n{'='*80}")
print("ALL TASK TRAINING COMPLETED!")
print(f"{'='*80}")

# Print comprehensive summary
# A run counts as successful when its recorded exit code is exactly 0.
print("\nTRAINING SUMMARY:")
print("=" * 40)
total_successful = 0
total_experiments = 0

for task_key, task_results in training_results.items():
    task_name = TASKS[task_key]['name']
    total = len(task_results)
    successful = len([result for result in task_results.values() if result == 0])
    total_experiments += total
    total_successful += successful

    print(f"\n{task_name}: {successful}/{total} completed successfully")
    for curriculum_key, result in task_results.items():
        curriculum_name = CURRICULA[curriculum_key]['name']
        if result == 0:
            status = "✓ Success"
        else:
            status = "✗ Failed"
        print(f"  {curriculum_name}: {status}")

# Guard against an empty experiment matrix before computing the percentage.
if total_experiments > 0:
    success_rate = total_successful / total_experiments * 100
else:
    success_rate = 0
print(f"\nOVERALL SUCCESS RATE: {total_successful}/{total_experiments} ({success_rate:.1f}%)")

# Verdict bands: below 50% / 50% up to 75% / 75% and above.
if success_rate < 50:
    print("❌ Significant training issues detected. Check configurations and logs.")
elif success_rate < 75:
    print("⚠️ Some training issues detected, but proceeding with available data...")
else:
    print("🎯 Training pipeline working well! Proceeding to analysis...")

## Task-Specific Data Loading and Processing

In [None]:
def load_all_task_results() -> Dict[str, Dict[str, Dict[str, pd.DataFrame]]]:
    """Collect the CSV logs of every task/curriculum run.

    For each combination the directory ``./csv_logs_<task>_<curriculum>`` is
    searched for five kinds of metric files. When several files match a kind,
    the one with the newest modification time is loaded.

    Returns:
        ``{task_key: {curriculum_key: {kind: DataFrame or None}}}`` where
        ``kind`` is one of 'training', 'validation', 'energy', 'curriculum',
        'robustness'; ``None`` marks a kind with no file (or an unreadable one).
    """
    file_patterns = {
        'training': 'training_metrics_*.csv',
        'validation': 'validation_metrics_*.csv', 
        'energy': 'energy_metrics_*.csv',
        'curriculum': 'curriculum_metrics_*.csv',
        'robustness': 'robustness_metrics_*.csv'
    }

    def _modified_at(filename):
        # Files that vanished between globbing and stat sort as oldest.
        candidate = Path(filename)
        return candidate.stat().st_mtime if candidate.exists() else 0

    results = {}
    for task_key in TASKS:
        print(f"Loading {task_key} results...")
        task_data = {}

        for curriculum_key in CURRICULA:
            csv_dir = Path(f"./csv_logs_{task_key}_{curriculum_key}")
            curriculum_data = {}

            for kind, pattern in file_patterns.items():
                matches = glob.glob(str(csv_dir / pattern))
                if not matches:
                    curriculum_data[kind] = None
                    continue
                newest = max(matches, key=_modified_at)
                curriculum_data[kind] = load_csv_data(Path(newest))

            available = len([df for df in curriculum_data.values() if df is not None])
            print(f"  {curriculum_key}: {available}/5 data files found")
            task_data[curriculum_key] = curriculum_data

        results[task_key] = task_data

    return results

def extract_task_metrics(task_key: str, task_data: Dict) -> pd.DataFrame:
    """Summarise one task's runs as a DataFrame with one row per curriculum.

    Args:
        task_key: Key into ``TASKS``.
        task_data: ``{curriculum_key: {'training': df, 'validation': df,
            'energy': df, ...}}``; any of the frames may be ``None``.

    Returns:
        A DataFrame whose columns cover identifiers, final/best losses,
        identity error, accuracy (in percent), mean ``nn_time``, convergence
        step and energy-landscape values. Missing identity error is reported
        as ``inf``; other missing values fall back to 0.0 (or -1 for the
        convergence step).
    """
    def _rows_for(validation_df, metric_name):
        # The validation log is in long format: one row per (step, metric_name).
        return validation_df[validation_df['metric_name'] == metric_name]

    summary_rows = []

    for curriculum_key, data in task_data.items():
        curriculum_config = CURRICULA[curriculum_key]
        training_df = data.get('training')
        validation_df = data.get('validation')
        energy_df = data.get('energy')

        # Sentinels used when the validation log has nothing for a metric.
        identity_error = best_identity_error = float('inf')
        final_accuracy = best_accuracy = 0.0

        if validation_df is not None:
            identity_rows = _rows_for(validation_df, 'identity_error')
            if not identity_rows.empty:
                identity_values = identity_rows['metric_value']
                identity_error = identity_values.iloc[-1]
                best_identity_error = identity_values.min()

            # Generic accuracy first; for sudoku the task-specific metric,
            # when present, overrides it.
            accuracy_metric_names = ['accuracy']
            if task_key == 'sudoku':
                accuracy_metric_names.append('sudoku_accuracy')
            for metric_name in accuracy_metric_names:
                accuracy_rows = _rows_for(validation_df, metric_name)
                if not accuracy_rows.empty:
                    accuracy_values = accuracy_rows['metric_value']
                    final_accuracy = accuracy_values.iloc[-1] * 100  # fraction -> percent
                    best_accuracy = accuracy_values.max() * 100

        has_total_loss = training_df is not None and 'total_loss' in training_df.columns
        if has_total_loss:
            convergence_step = calculate_convergence_step(training_df, 'total_loss')
        else:
            convergence_step = -1

        if training_df is not None and 'nn_time' in training_df.columns:
            training_time = training_df['nn_time'].mean()
        else:
            training_time = 0.0

        if energy_df is not None and 'curriculum_weight' in energy_df.columns:
            curriculum_weight = safe_get_best_value(energy_df, 'curriculum_weight', minimize=False)
        else:
            curriculum_weight = 0.0

        summary_rows.append({
            'task': task_key,
            'task_name': TASKS[task_key]['name'],
            'curriculum': curriculum_key,
            'curriculum_name': curriculum_config['name'],
            'color': curriculum_config['color'],

            # Training metrics
            'final_total_loss': safe_get_final_value(training_df, 'total_loss'),
            'best_total_loss': safe_get_best_value(training_df, 'total_loss', minimize=True),
            'final_energy_loss': safe_get_final_value(training_df, 'loss_energy'),
            'final_denoise_loss': safe_get_final_value(training_df, 'loss_denoise'),

            # Key reasoning metrics
            'final_identity_error': identity_error,
            'best_identity_error': best_identity_error,
            'final_accuracy': final_accuracy,
            'best_accuracy': best_accuracy,

            # Performance metrics
            'training_time': training_time,
            'convergence_step': convergence_step,

            # Energy landscape metrics
            'final_energy_margin': safe_get_final_value(energy_df, 'energy_margin'),
            'curriculum_weight': curriculum_weight,
        })

    return pd.DataFrame(summary_rows)

def calculate_improvement_metrics(task_summary: pd.DataFrame) -> Dict[str, float]:
    """Compute percentage improvements of the aggressive curriculum over baseline.

    Every value is relative to the baseline and signed so that a positive
    number means the aggressive curriculum did better (lower error, loss,
    time and convergence step; higher accuracy).

    Args:
        task_summary: Per-curriculum summary with a ``curriculum`` column
            containing 'baseline' and 'aggressive' rows, plus the columns
            ``final_identity_error``, ``final_total_loss``, ``final_accuracy``,
            ``training_time`` and ``convergence_step``.

    Returns:
        A dict with keys 'identity_error', 'total_loss', 'accuracy',
        'training_speed' and 'convergence_speed' (in that order). A metric
        whose inputs are unusable is reported as 0.0. The dict is empty when
        either the baseline or the aggressive row is missing.
    """
    baseline_row = task_summary[task_summary['curriculum'] == 'baseline']
    aggressive_row = task_summary[task_summary['curriculum'] == 'aggressive']

    if len(baseline_row) == 0 or len(aggressive_row) == 0:
        return {}

    baseline = baseline_row.iloc[0]
    aggressive = aggressive_row.iloc[0]

    def _pct(reference, candidate, lower_is_better=True):
        # Relative change against the reference, oriented so positive = better.
        delta = (reference - candidate) if lower_is_better else (candidate - reference)
        return float(delta / reference * 100)

    improvements = {}

    # Identity error (lower is better). Both values must be finite - inf is the
    # "no data" sentinel - and the baseline must be positive: dividing by a
    # zero baseline would otherwise yield inf/NaN or raise.
    base_err = baseline['final_identity_error']
    aggr_err = aggressive['final_identity_error']
    if np.isfinite(base_err) and np.isfinite(aggr_err) and base_err > 0:
        improvements['identity_error'] = _pct(base_err, aggr_err)
    else:
        improvements['identity_error'] = 0.0

    # Loss (lower is better); 0 is the "no data" default for both runs.
    if baseline['final_total_loss'] > 0 and aggressive['final_total_loss'] > 0:
        improvements['total_loss'] = _pct(baseline['final_total_loss'], aggressive['final_total_loss'])
    else:
        improvements['total_loss'] = 0.0

    # Accuracy (higher is better); only the baseline needs to be positive.
    if baseline['final_accuracy'] > 0:
        improvements['accuracy'] = _pct(
            baseline['final_accuracy'], aggressive['final_accuracy'], lower_is_better=False
        )
    else:
        improvements['accuracy'] = 0.0

    # Training speed (lower time is better).
    if baseline['training_time'] > 0 and aggressive['training_time'] > 0:
        improvements['training_speed'] = _pct(baseline['training_time'], aggressive['training_time'])
    else:
        improvements['training_speed'] = 0.0

    # Convergence speed (fewer steps is better); -1 marks "unknown".
    if baseline['convergence_step'] > 0 and aggressive['convergence_step'] > 0:
        improvements['convergence_speed'] = _pct(baseline['convergence_step'], aggressive['convergence_step'])
    else:
        improvements['convergence_speed'] = 0.0

    return improvements

# Load all results
print("Loading all task-curriculum results...")
all_results = load_all_task_results()

# Process each task separately  
# task_summaries:    task key -> per-curriculum summary DataFrame
# task_improvements: task key -> {metric: percent improvement of aggressive vs baseline}
task_summaries = {}
task_improvements = {}

print("\nProcessing task data and calculating improvements...")
for task_key, task_data in all_results.items():
    task_summaries[task_key] = extract_task_metrics(task_key, task_data)
    task_improvements[task_key] = calculate_improvement_metrics(task_summaries[task_key])
    
    print(f"\n{TASKS[task_key]['name']} Summary:")
    display_cols = ['curriculum_name', 'final_total_loss', 'final_identity_error', 'final_accuracy', 'convergence_step']
    # `display` is not imported in this notebook; presumably the IPython/Jupyter
    # built-in - this cell will not run as a plain Python script.
    display(task_summaries[task_key][display_cols].round(4))
    
    # Print improvements
    improvements = task_improvements[task_key]
    print(f"\nImprovements (Aggressive vs Baseline):")
    for metric, improvement in improvements.items():
        # Negative values already carry their own "-" sign; only positives get a "+".
        sign = "+" if improvement > 0 else ""
        print(f"  {metric.replace('_', ' ').title()}: {sign}{improvement:.1f}%")

print(f"\nLoaded and processed data for {len(task_summaries)} tasks")
print("Data loading and processing complete!")

## Task-Specific Visualization Functions

In [None]:
def visualize_task_comparison(task_key: str, task_data: Dict, task_summary: pd.DataFrame):
    """Draw a 2x2 figure comparing the curricula on a single task.

    Panels: (top-left) smoothed training loss over steps, (top-right) identity
    error over steps, (bottom-left) final identity error per curriculum,
    (bottom-right) final accuracy per curriculum, falling back to final total
    loss when no curriculum has a positive accuracy. Panels without usable
    data show a placeholder message instead.

    Args:
        task_key: Key into ``TASKS``; only the task's display name is used.
        task_data: ``{curriculum_key: {...}}`` where each inner dict is read
            for its 'training' and 'validation' DataFrames (either may be None).
        task_summary: Per-curriculum summary with the columns
            'curriculum_name', 'final_identity_error', 'color',
            'final_accuracy' and 'final_total_loss'.

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    task_config = TASKS[task_key]
    
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle(f'{task_config["name"]} - Baseline vs Aggressive Curriculum', fontsize=16, fontweight='bold')
    
    # 1. Training Loss Curves
    ax = axes[0, 0]
    for curriculum_key, data in task_data.items():
        training_df = data.get('training')
        if training_df is not None and 'total_loss' in training_df.columns:
            curriculum_config = CURRICULA[curriculum_key]
            # Apply smoothing for better visualization
            # Window is 1/20 of the log length, capped at 10 rows; logs shorter
            # than 40 rows give a window <= 1 and are plotted unsmoothed.
            window = min(10, len(training_df) // 20)
            if window > 1:
                smoothed_loss = training_df['total_loss'].rolling(window=window, min_periods=1).mean()
            else:
                smoothed_loss = training_df['total_loss']
            
            ax.plot(training_df['step'], smoothed_loss,
                    color=curriculum_config['color'], label=curriculum_config['name'], 
                    linewidth=2, alpha=0.8)
    
    ax.set_title('Training Loss Evolution (Smoothed)', fontweight='bold')
    ax.set_xlabel('Training Step')
    ax.set_ylabel('Total Loss')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')  # Log scale for better loss visualization
    
    # 2. Identity Error Evolution (KEY METRIC)
    ax = axes[0, 1] 
    # Tracks whether any curriculum contributed a curve, to choose between the
    # real axis decoration and the placeholder text below.
    data_found = False
    for curriculum_key, data in task_data.items():
        validation_df = data.get('validation')
        if validation_df is not None:
            curriculum_config = CURRICULA[curriculum_key]
            # Validation logs are long-format: filter the rows for this metric.
            identity_df = validation_df[validation_df['metric_name'] == 'identity_error']
            if not identity_df.empty:
                ax.plot(identity_df['step'], identity_df['metric_value'],
                        color=curriculum_config['color'], label=curriculum_config['name'],
                        linewidth=3, alpha=0.9, marker='o', markersize=4)
                data_found = True
    
    if data_found:
        ax.set_title('Identity Error Evolution (Lower = Better)', fontweight='bold')
        ax.set_xlabel('Training Step')
        # NOTE(review): this label describes a matrix-product residual; confirm it
        # matches how train.py defines identity_error for these tasks.
        ax.set_ylabel('||Pred @ Input - I||²')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_yscale('log')  # Log scale for identity error
    else:
        ax.text(0.5, 0.5, 'No identity error data available\n(Run longer training for validation)', 
                ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_title('Identity Error (No Data)', fontweight='bold')
    
    # 3. Final Identity Error Comparison (Bar Chart)
    ax = axes[1, 0]
    curricula_names = task_summary['curriculum_name'].tolist()
    identity_errors = task_summary['final_identity_error'].tolist()
    colors = task_summary['color'].tolist()
    
    # Filter out inf values
    # inf is the "no data" sentinel; non-positive values are dropped as well,
    # presumably because the axis below is log-scaled.
    valid_data = [(name, error, color) for name, error, color in zip(curricula_names, identity_errors, colors) 
                  if error != float('inf') and error > 0]
    
    if valid_data:
        names, errors, cols = zip(*valid_data)
        bars = ax.bar(names, errors, color=cols, alpha=0.7, edgecolor='black', linewidth=2)
        ax.set_title('Final Identity Error Comparison', fontweight='bold')
        ax.set_ylabel('Identity Error')
        ax.set_yscale('log')  # Log scale for better comparison
        
        # Add improvement annotation
        # Only when both curricula survived the filter. The baseline bar is
        # identified by its display name containing "baseline".
        if len(errors) == 2:
            baseline_idx = 0 if 'baseline' in names[0].lower() else 1
            aggressive_idx = 1 - baseline_idx
            
            baseline_error = errors[baseline_idx]
            aggressive_error = errors[aggressive_idx]
            # Positive = aggressive has the lower error. baseline_error > 0 is
            # guaranteed by the filter above.
            improvement = (baseline_error - aggressive_error) / baseline_error * 100
            
            color = 'green' if improvement > 0 else 'red'
            ax.text(0.5, 0.8, f'Aggressive Improvement: {improvement:+.1f}%', 
                    ha='center', transform=ax.transAxes, fontweight='bold', fontsize=12,
                    bbox=dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.3))
        
        # Add value labels
        # Multiplicative offset keeps the label just above the bar on a log axis.
        for bar, error in zip(bars, errors):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() * 1.1,
                    f'{error:.2e}', ha='center', va='bottom', fontweight='bold', rotation=45)
    else:
        ax.text(0.5, 0.5, 'No valid identity error data\nfor comparison', 
                ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_title('Identity Error Comparison (No Data)', fontweight='bold')
    
    # 4. Task-Specific Metric (Accuracy or Loss)
    ax = axes[1, 1]
    accuracies = task_summary['final_accuracy'].tolist()
    
    # Accuracy defaults to 0.0 when it was never logged, so "any value > 0"
    # doubles as the "accuracy data exists" check.
    if max(accuracies) > 0:
        bars = ax.bar(curricula_names, accuracies, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
        ax.set_title('Final Accuracy Comparison', fontweight='bold')
        ax.set_ylabel('Accuracy (%)')
        ax.set_ylim(0, max(accuracies) * 1.2)  # headroom for labels and annotation
        
        # Add value labels
        for bar, acc in zip(bars, accuracies):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(accuracies)*0.02,
                    f'{acc:.1f}%', ha='center', va='bottom', fontweight='bold')
        
        # Add improvement annotation for accuracy
        if len(accuracies) == 2:
            baseline_idx = 0 if 'baseline' in curricula_names[0].lower() else 1
            aggressive_idx = 1 - baseline_idx
            
            # Skip the annotation when the baseline accuracy is zero (relative
            # change is undefined).
            if accuracies[baseline_idx] > 0:
                acc_improvement = (accuracies[aggressive_idx] - accuracies[baseline_idx]) / accuracies[baseline_idx] * 100
                color = 'green' if acc_improvement > 0 else 'red'
                ax.text(0.5, 0.9, f'Accuracy Improvement: {acc_improvement:+.1f}%', 
                        ha='center', transform=ax.transAxes, fontweight='bold',
                        bbox=dict(boxstyle="round,pad=0.3", facecolor=color, alpha=0.3))
    else:
        # Show final loss instead
        final_losses = task_summary['final_total_loss'].tolist()
        if max(final_losses) > 0:
            bars = ax.bar(curricula_names, final_losses, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
            ax.set_title('Final Total Loss Comparison', fontweight='bold')
            ax.set_ylabel('Total Loss')
            ax.set_yscale('log')
            
            for bar, loss in zip(bars, final_losses):
                ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() * 1.1,
                        f'{loss:.2e}', ha='center', va='bottom', fontweight='bold', rotation=45)
        else:
            ax.text(0.5, 0.5, 'No accuracy or loss data available', ha='center', va='center', 
                    transform=ax.transAxes, fontsize=12)
            ax.set_title('Performance Metrics (No Data)', fontweight='bold')
    
    plt.tight_layout()
    plt.show()

def _draw_signed_bar_panel(ax, labels, values, *, title, ylabel, label_format,
                           up_offset, down_offset, empty_title, empty_message):
    """Draw green/red bars for signed values on ``ax``, or a placeholder when ``values`` is empty."""
    if not values:
        ax.text(0.5, 0.5, empty_message, ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.set_title(empty_title, fontweight='bold')
        return

    bar_colors = ['green' if value > 0 else 'red' for value in values]
    bars = ax.bar(labels, values, color=bar_colors, alpha=0.7, edgecolor='black')
    ax.axhline(y=0, color='black', linestyle='-', alpha=0.5)
    ax.set_title(title, fontweight='bold')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Task')

    # Value labels sit above positive bars and below negative ones.
    for bar, value in zip(bars, values):
        y_pos = bar.get_height() + (up_offset if value > 0 else down_offset)
        ax.text(bar.get_x() + bar.get_width()/2, y_pos,
                label_format.format(value), ha='center',
                va='bottom' if value > 0 else 'top', fontweight='bold')

    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')


def create_cross_task_comparison(task_summaries: Dict, task_improvements: Dict):
    """Draw a 2x2 figure comparing aggressive vs baseline curriculum across tasks.

    Panels: (top-left) final identity error per task for both curricula,
    (top-right) identity-error improvement in percent, (bottom-left) total-loss
    improvement in percent, (bottom-right) a composite score weighting
    identity error 50%, loss 30% and accuracy 20%. Panels without usable data
    show a placeholder message.

    Args:
        task_summaries: ``{task_key: summary DataFrame}`` with 'curriculum'
            and 'final_identity_error' columns.
        task_improvements: ``{task_key: {metric: percent improvement}}``.

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Cross-Task Performance: Aggressive vs Baseline Curriculum', fontsize=16, fontweight='bold')

    # Extract data for comparison (all lists stay aligned with task_names).
    task_names = []
    baseline_errors = []
    aggressive_errors = []
    identity_improvements = []
    loss_improvements = []

    for task_key, summary_df in task_summaries.items():
        task_names.append(TASKS[task_key]['name'])

        baseline_row = summary_df[summary_df['curriculum'] == 'baseline']
        aggressive_row = summary_df[summary_df['curriculum'] == 'aggressive']

        if len(baseline_row) > 0 and len(aggressive_row) > 0:
            baseline_error = baseline_row['final_identity_error'].iloc[0]
            aggressive_error = aggressive_row['final_identity_error'].iloc[0]

            # inf is the "no data" sentinel; store None so it can be filtered out.
            baseline_errors.append(baseline_error if baseline_error != float('inf') else None)
            aggressive_errors.append(aggressive_error if aggressive_error != float('inf') else None)

            # Get improvements from pre-calculated data
            improvements = task_improvements[task_key]
            identity_improvements.append(improvements.get('identity_error', 0))
            loss_improvements.append(improvements.get('total_loss', 0))
        else:
            baseline_errors.append(None)
            aggressive_errors.append(None)
            identity_improvements.append(0)
            loss_improvements.append(0)

    # 1. Identity Error Comparison
    ax = axes[0, 0]
    width = 0.35

    # Keep only tasks where BOTH curricula have a value. Filtering the two
    # lists independently would misalign bars with task labels (or fail on a
    # length mismatch) when only one of the two runs produced data.
    paired = [(name, base, aggr)
              for name, base, aggr in zip(task_names, baseline_errors, aggressive_errors)
              if base is not None and aggr is not None]

    if paired:
        valid_names, valid_baseline, valid_aggressive = (list(column) for column in zip(*paired))
        valid_x = np.arange(len(valid_names))
        ax.bar(valid_x - width/2, valid_baseline, width, label='Baseline',
               color=CURRICULA['baseline']['color'], alpha=0.7, edgecolor='black')
        ax.bar(valid_x + width/2, valid_aggressive, width, label='Aggressive',
               color=CURRICULA['aggressive']['color'], alpha=0.7, edgecolor='black')

        ax.set_title('Identity Error by Task (Lower = Better)', fontweight='bold')
        ax.set_ylabel('Identity Error')
        ax.set_xlabel('Task')
        ax.set_xticks(valid_x)
        ax.set_xticklabels(valid_names, rotation=45, ha='right')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')
        ax.set_yscale('log')
    else:
        ax.text(0.5, 0.5, 'No valid identity error data\nfor cross-task comparison',
                ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_title('Identity Error by Task (No Data)', fontweight='bold')

    # 2. Identity Error Improvement Percentage
    # An improvement of exactly 0 is the "not computable" default and is skipped.
    identity_pairs = [(name, imp) for name, imp in zip(task_names, identity_improvements) if imp != 0]
    _draw_signed_bar_panel(
        axes[0, 1],
        [name for name, _ in identity_pairs],
        [imp for _, imp in identity_pairs],
        title='Identity Error Improvement (%)',
        ylabel='Improvement (%)',
        label_format='{:+.1f}%',
        up_offset=1,
        down_offset=-3,
        empty_title='Identity Error Improvement (No Data)',
        empty_message='No improvement data available',
    )

    # 3. Loss Improvement Comparison
    loss_pairs = [(name, imp) for name, imp in zip(task_names, loss_improvements) if imp != 0]
    _draw_signed_bar_panel(
        axes[1, 0],
        [name for name, _ in loss_pairs],
        [imp for _, imp in loss_pairs],
        title='Total Loss Improvement (%)',
        ylabel='Improvement (%)',
        label_format='{:+.1f}%',
        up_offset=0.5,
        down_offset=-1.5,
        empty_title='Loss Improvement (No Data)',
        empty_message='No loss improvement data available',
    )

    # 4. Overall Performance Summary
    # Create a summary score for each task that has at least one non-zero improvement.
    task_scores = []
    score_names = []

    for task_key, improvements in task_improvements.items():
        if any(imp != 0 for imp in improvements.values()):
            # Weighted score: identity error (50%), loss (30%), accuracy (20%)
            score = (improvements.get('identity_error', 0) * 0.5 +
                     improvements.get('total_loss', 0) * 0.3 +
                     improvements.get('accuracy', 0) * 0.2)
            task_scores.append(score)
            score_names.append(TASKS[task_key]['name'])

    _draw_signed_bar_panel(
        axes[1, 1],
        score_names,
        task_scores,
        title='Overall Performance Score\n(Weighted: Identity 50%, Loss 30%, Accuracy 20%)',
        ylabel='Composite Score (%)',
        label_format='{:+.1f}',
        up_offset=1,
        down_offset=-3,
        empty_title='Overall Performance Score (No Data)',
        empty_message='No performance data available\nfor composite scoring',
    )

    plt.tight_layout()
    plt.show()

# Generate visualizations for each task
# One 2x2 figure per task, followed by a single cross-task figure built from
# the per-task summaries and improvement percentages.
print("Generating task-specific visualizations...")
for task_key, task_data in all_results.items():
    print(f"\nGenerating visualizations for {TASKS[task_key]['name']}...")
    visualize_task_comparison(task_key, task_data, task_summaries[task_key])

print("\nGenerating cross-task comparison...")
create_cross_task_comparison(task_summaries, task_improvements)

print("\nAll visualizations completed!")

## Comprehensive Analysis and Results

In [None]:
def generate_comprehensive_analysis(task_summaries: Dict, task_improvements: Dict):
    """Print a baseline-vs-aggressive curriculum report and return the results table.

    The report has four parts: a table with one row per (task, curriculum), a
    per-task breakdown of each improvement metric, aggregate success rates and
    averages followed by a verdict, and a recap of the experimental setup read
    from the module-level ``TASKS`` and ``COMMON_ARGS`` dictionaries.

    Args:
        task_summaries: Maps task key -> DataFrame with one row per curriculum.
            The columns read are ``curriculum``, ``curriculum_name``,
            ``final_total_loss``, ``final_identity_error``, ``final_accuracy``,
            ``training_time`` and ``convergence_step``.
        task_improvements: Maps task key -> dict of improvement percentages
            keyed by ``identity_error``, ``total_loss``, ``accuracy``,
            ``training_speed`` and ``convergence_speed``. A missing task or
            metric is read as 0, which this report treats as "no valid
            comparison data".

    Returns:
        DataFrame of display-formatted strings, one row per (task, curriculum).
    """

    print("\n" + "="*80)
    print("COMPREHENSIVE TASK-CURRICULUM ANALYSIS")
    print("="*80)

    print("\n📊 TASK PERFORMANCE SUMMARY")
    print("-" * 50)

    # One display row per (task, curriculum); sentinel values are shown as 'N/A'.
    table_rows = []
    for task_key, summary_df in task_summaries.items():
        for _, row in summary_df.iterrows():
            identity_error = row['final_identity_error']
            # Besides the +inf sentinel, NaN/None must not be printed as if it were a measurement.
            has_identity_error = pd.notna(identity_error) and identity_error != float('inf')
            table_rows.append({
                'Task': TASKS[task_key]['name'],
                'Curriculum': row['curriculum_name'],
                'Final Loss': f"{row['final_total_loss']:.4f}" if row['final_total_loss'] > 0 else 'N/A',
                'Identity Error': f"{identity_error:.2e}" if has_identity_error else 'N/A',
                'Accuracy (%)': f"{row['final_accuracy']:.1f}" if row['final_accuracy'] > 0 else 'N/A',
                'Training Time (s)': f"{row['training_time']:.3f}" if row['training_time'] > 0 else 'N/A',
                'Convergence Step': f"{row['convergence_step']}" if row['convergence_step'] > 0 else 'N/A',
            })

    unified_df = pd.DataFrame(table_rows)
    print(unified_df.to_string(index=False))

    print("\n🎯 CURRICULUM EFFECTIVENESS ANALYSIS")
    print("-" * 50)

    # (improvement key, printed label, summary column, format of the raw value)
    metric_specs = [
        ('identity_error', 'IDENTITY ERROR', 'final_identity_error', '{:.2e}'),
        ('total_loss', 'TOTAL LOSS', 'final_total_loss', '{:.4f}'),
        ('accuracy', 'ACCURACY', 'final_accuracy', '{:.1f}%'),
    ]
    success_counts = {metric_key: 0 for metric_key, _, _, _ in metric_specs}
    total_tasks = 0

    for task_key, summary_df in task_summaries.items():
        task_name = TASKS[task_key]['name']
        print(f"\n{task_name}:")
        print(f"  Dataset: {TASKS[task_key]['dataset']}, Model: {TASKS[task_key]['model']}")

        baseline_row = summary_df[summary_df['curriculum'] == 'baseline']
        aggressive_row = summary_df[summary_df['curriculum'] == 'aggressive']

        if len(baseline_row) == 0 or len(aggressive_row) == 0:
            print(f"  ⚠️  Missing training results for comparison")
            continue

        total_tasks += 1
        baseline = baseline_row.iloc[0]
        aggressive = aggressive_row.iloc[0]
        # A task without an improvements entry is reported as "no valid data"
        # for every metric instead of aborting the whole report.
        improvements = task_improvements.get(task_key, {})

        for metric_key, label, column, value_format in metric_specs:
            change = improvements.get(metric_key, 0)
            if change > 0:
                success_counts[metric_key] += 1
                print(f"  ✅ {label}: Improved by {change:.1f}%")
            elif change < 0:
                print(f"  ❌ {label}: Worsened by {abs(change):.1f}%")
            else:
                print(f"  ⚠️  {label}: No valid comparison data")
                continue
            # Raw values are only shown when there is an actual comparison.
            print(f"     Baseline: {value_format.format(baseline[column])}"
                  f" → Aggressive: {value_format.format(aggressive[column])}")

        # Training efficiency
        speed_improvement = improvements.get('training_speed', 0)
        if speed_improvement != 0:
            print(f"  🏃 TRAINING SPEED: {'Faster' if speed_improvement > 0 else 'Slower'} by {abs(speed_improvement):.1f}%")

        # Convergence analysis
        conv_improvement = improvements.get('convergence_speed', 0)
        if conv_improvement != 0:
            print(f"  🎯 CONVERGENCE: {'Faster' if conv_improvement > 0 else 'Slower'} by {abs(conv_improvement):.1f}%")

    # Overall statistical summary
    print("\n🏆 OVERALL STATISTICAL SUMMARY")
    print("-" * 50)

    if total_tasks > 0:
        successful_identity_improvements = success_counts['identity_error']
        successful_loss_improvements = success_counts['total_loss']
        successful_accuracy_improvements = success_counts['accuracy']

        identity_success_rate = successful_identity_improvements / total_tasks * 100
        loss_success_rate = successful_loss_improvements / total_tasks * 100
        accuracy_success_rate = successful_accuracy_improvements / total_tasks * 100

        print(f"📈 IMPROVEMENT SUCCESS RATES:")
        print(f"   • Identity Error: {successful_identity_improvements}/{total_tasks} tasks ({identity_success_rate:.0f}%)")
        print(f"   • Total Loss: {successful_loss_improvements}/{total_tasks} tasks ({loss_success_rate:.0f}%)")
        print(f"   • Accuracy: {successful_accuracy_improvements}/{total_tasks} tasks ({accuracy_success_rate:.0f}%)")

        def _average_valid(metric_key):
            """Format the mean improvement over tasks that have a real comparison.

            Entries equal to 0 (the "no valid comparison data" placeholder) and
            non-finite entries are left out so they cannot dilute the average.
            """
            values = [task_entry.get(metric_key, 0) for task_entry in task_improvements.values()]
            valid = [value for value in values if value != 0 and np.isfinite(value)]
            return f"{np.mean(valid):+.1f}%" if valid else 'N/A'

        print(f"\n📊 AVERAGE IMPROVEMENTS:")
        print(f"   • Identity Error: {_average_valid('identity_error')}")
        print(f"   • Total Loss: {_average_valid('total_loss')}")
        print(f"   • Accuracy: {_average_valid('accuracy')}")

        # Overall recommendation
        print("\n🎯 RECOMMENDATIONS:")
        print("-" * 30)

        if identity_success_rate >= 50:
            print("✅ IDENTITY ERROR: Aggressive curriculum shows promise for IRED reasoning tasks")
        else:
            print("❌ IDENTITY ERROR: Baseline training competitive with aggressive curriculum")

        if loss_success_rate >= 50:
            print("✅ TRAINING LOSS: Aggressive curriculum improves optimization")
        else:
            print("❌ TRAINING LOSS: Baseline optimization competitive")

        if accuracy_success_rate >= 50:
            print("✅ TASK ACCURACY: Aggressive curriculum enhances task performance")
        else:
            print("❌ TASK ACCURACY: Baseline performance competitive")

        # Final verdict: unweighted mean of the three per-metric success rates.
        overall_success_rate = (identity_success_rate + loss_success_rate + accuracy_success_rate) / 3
        print(f"\n🏆 OVERALL VERDICT:")
        if overall_success_rate >= 66:
            print(f"   🎯 STRONG RECOMMENDATION: Deploy aggressive curriculum (Success: {overall_success_rate:.0f}%)")
        elif overall_success_rate >= 33:
            print(f"   ⚖️  MIXED RESULTS: Task-dependent benefits (Success: {overall_success_rate:.0f}%)")
        else:
            print(f"   🎯 RECOMMENDATION: Stick with baseline training (Success: {overall_success_rate:.0f}%)")

    else:
        print("⚠️  Insufficient data for overall assessment")

    # Experimental details
    print(f"\n📋 EXPERIMENTAL SETUP:")
    print(f"   • Tasks tested: {', '.join(task_cfg['name'] for task_cfg in TASKS.values())}")
    # Step counts are configured per task, so only claim a single value when
    # every task really shares it; otherwise list them individually.
    step_counts = {task_cfg['train_steps'] for task_cfg in TASKS.values()}
    if len(step_counts) == 1:
        print(f"   • Training steps: {next(iter(step_counts))} per task")
    else:
        per_task_steps = ', '.join(f"{task_cfg['name']}: {task_cfg['train_steps']}"
                                   for task_cfg in TASKS.values())
        print(f"   • Training steps: {per_task_steps}")
    print(f"   • Diffusion steps: {COMMON_ARGS['diffusion_steps']}")
    print(f"   • Focus metric: Identity error (||Pred @ Input - I||²)")
    print(f"   • Validation interval: Every {COMMON_ARGS['csv_log_interval']} steps")

    return unified_df

def create_performance_radar_chart(task_summaries: Dict, task_improvements: Dict):
    """Draw a polar chart of normalized baseline vs aggressive scores.

    For every task that has both a 'baseline' and an 'aggressive' row, each
    metric is rescaled against the larger of the two values so that higher
    always means better; the per-task scores are then averaged per metric.
    The chart is drawn only when at least three metrics have data, otherwise
    a notice is printed. ``task_improvements`` is accepted but not read.
    """
    axis_names = ('Identity Error', 'Total Loss', 'Accuracy',
                  'Training Speed', 'Convergence Speed')
    baseline_metrics = {axis: [] for axis in axis_names}
    aggressive_metrics = {axis: [] for axis in axis_names}

    def record(axis, base_score, aggr_score):
        # Both curricula always receive a score together, keeping the lists aligned.
        baseline_metrics[axis].append(base_score)
        aggressive_metrics[axis].append(aggr_score)

    def record_inverted(axis, base_value, aggr_value):
        # Lower-is-better metric: score = 1 - value / max, only when both are positive.
        if base_value > 0 and aggr_value > 0:
            ceiling = max(base_value, aggr_value)
            record(axis, 1 - base_value / ceiling, 1 - aggr_value / ceiling)

    for summary_df in task_summaries.values():
        base_rows = summary_df[summary_df['curriculum'] == 'baseline']
        aggr_rows = summary_df[summary_df['curriculum'] == 'aggressive']
        if len(base_rows) == 0 or len(aggr_rows) == 0:
            continue  # nothing to compare for this task

        base = base_rows.iloc[0]
        aggr = aggr_rows.iloc[0]

        # Identity error: inverted, with a small epsilon guarding a zero maximum.
        base_err = base['final_identity_error']
        aggr_err = aggr['final_identity_error']
        if base_err != float('inf') and aggr_err != float('inf'):
            worst_err = max(base_err, aggr_err)
            record('Identity Error',
                   1 - base_err / (worst_err + 1e-8),
                   1 - aggr_err / (worst_err + 1e-8))

        record_inverted('Total Loss', base['final_total_loss'], aggr['final_total_loss'])

        # Accuracy: already higher-is-better; one positive value is enough to compare.
        base_acc = base['final_accuracy']
        aggr_acc = aggr['final_accuracy']
        if base_acc > 0 or aggr_acc > 0:
            best_acc = max(base_acc, aggr_acc)
            record('Accuracy',
                   base_acc / (best_acc + 1e-8),
                   aggr_acc / (best_acc + 1e-8))

        record_inverted('Training Speed', base['training_time'], aggr['training_time'])
        record_inverted('Convergence Speed', base['convergence_step'], aggr['convergence_step'])

    # Average the per-task scores, keeping only metrics that have data on both sides.
    categories = []
    baseline_scores = []
    aggressive_scores = []
    for category, base_values in baseline_metrics.items():
        aggr_values = aggressive_metrics[category]
        if base_values and aggr_values:
            categories.append(category)
            baseline_scores.append(np.mean(base_values))
            aggressive_scores.append(np.mean(aggr_values))

    if len(categories) < 3:
        print("Insufficient data for radar chart (need at least 3 metrics with valid comparisons)")
        return

    # One spoke per metric; the first point is repeated to close each polygon.
    spoke_count = len(categories)
    angles = [n / float(spoke_count) * 2 * pi for n in range(spoke_count)]
    angles.append(angles[0])

    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={'projection': 'polar'})
    ax.set_title('Multi-Dimensional Performance Comparison\n(Higher = Better)',
                 fontsize=14, fontweight='bold', pad=20)

    series = (
        ('Baseline', 'baseline', baseline_scores),
        ('Aggressive', 'aggressive', aggressive_scores),
    )
    for legend_label, curriculum_key, scores in series:
        closed_scores = [*scores, scores[0]]
        series_color = CURRICULA[curriculum_key]['color']
        ax.plot(angles, closed_scores, 'o-', linewidth=3, label=legend_label,
                color=series_color, alpha=0.8)
        ax.fill(angles, closed_scores, alpha=0.15, color=series_color)

    # Axis cosmetics
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories, fontweight='bold')
    ax.set_ylim(0, 1)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_yticklabels(['0.2', '0.4', '0.6', '0.8', '1.0'], alpha=0.7)
    ax.grid(True, alpha=0.3)

    plt.legend(loc='upper right', bbox_to_anchor=(1.2, 1.0))
    plt.tight_layout()
    plt.show()

# Text report first; its returned table is kept in `unified_results`.
print("Generating comprehensive analysis...")
unified_results = generate_comprehensive_analysis(
    task_summaries=task_summaries,
    task_improvements=task_improvements,
)

# Then the radar chart over the same inputs.
print("\nGenerating performance radar chart...")
create_performance_radar_chart(
    task_summaries=task_summaries,
    task_improvements=task_improvements,
)

print("\nComprehensive analysis completed!")

## Results Export and Final Summary

In [None]:
# Export comprehensive results
output_dir = Path('./ired_task_curriculum_results')
output_dir.mkdir(exist_ok=True)

# Save task-specific summaries
for task_key, summary_df in task_summaries.items():
    summary_df.to_csv(output_dir / f'{task_key}_detailed_summary.csv', index=False)

# Save unified results
unified_results.to_csv(output_dir / 'unified_task_curriculum_results.csv', index=False)

# Save improvements summary in long format: one row per (task, metric) pair.
improvements_data = [
    {
        'Task': TASKS[task_key]['name'],
        'Task_Key': task_key,
        'Metric': metric.replace('_', ' ').title(),
        'Improvement_Percent': value,
    }
    for task_key, improvements in task_improvements.items()
    for metric, value in improvements.items()
]

# Explicit columns keep the CSV header even when there are no rows to write.
improvements_df = pd.DataFrame(
    improvements_data,
    columns=['Task', 'Task_Key', 'Metric', 'Improvement_Percent'],
)
improvements_df.to_csv(output_dir / 'curriculum_improvements_summary.csv', index=False)

# Create training results summary: one row per (task, curriculum) run with its exit code.
training_summary = [
    {
        'Task': TASKS[task_key]['name'],
        'Task_Key': task_key,
        'Curriculum': CURRICULA[curriculum_key]['name'],
        'Curriculum_Key': curriculum_key,
        'Exit_Code': result_code,
        'Success': result_code == 0,
        'Training_Steps': TASKS[task_key]['train_steps'],
        'Model': TASKS[task_key]['model'],
        'Dataset': TASKS[task_key]['dataset'],
    }
    for task_key, task_results in training_results.items()
    for curriculum_key, result_code in task_results.items()
]

training_df = pd.DataFrame(
    training_summary,
    columns=['Task', 'Task_Key', 'Curriculum', 'Curriculum_Key', 'Exit_Code',
             'Success', 'Training_Steps', 'Model', 'Dataset'],
)
training_df.to_csv(output_dir / 'training_execution_summary.csv', index=False)

print("\n" + "="*80)
print("RESULTS EXPORT AND FINAL SUMMARY")
print("="*80)

print(f"\n📁 Results saved to: {output_dir}")
print("   📊 Data Files:")
# List only the per-task files that were actually written above.
for task_key in task_summaries:
    print(f"     • {task_key}_detailed_summary.csv - {TASKS[task_key]['name']} detailed metrics")
print("     • unified_task_curriculum_results.csv - Combined results across all tasks")
print("     • curriculum_improvements_summary.csv - Improvement percentages by metric")
print("     • training_execution_summary.csv - Training execution status")

# Generate final executive summary
# A task is "fully completed" only when every configured curriculum has a
# recorded run with exit code 0. A bare all() over the recorded runs would also
# count a task with no runs (or with a curriculum missing) as completed.
successful_tasks = sum(
    1 for task_results in training_results.values()
    if task_results and all(task_results.get(curriculum_key) == 0 for curriculum_key in CURRICULA)
)
total_experiments = sum(len(task_results) for task_results in training_results.values())
successful_experiments = sum(
    sum(1 for exit_code in task_results.values() if exit_code == 0)
    for task_results in training_results.values()
)
# Guard the percentage so an empty training_results does not divide by zero.
completion_rate = successful_experiments / total_experiments * 100 if total_experiments else 0.0

print("\n" + "="*80)
print("🧠 IRED TASK CURRICULUM ANALYSIS - EXECUTIVE SUMMARY 🧠")
print("="*80)

print("\n🎯 EXPERIMENTAL OVERVIEW:")
print(f"   • Tasks evaluated: {len(TASKS)} IRED reasoning tasks")
print(f"   • Curricula compared: {len(CURRICULA)} approaches (Baseline vs Aggressive)")
print(f"   • Total experiments: {total_experiments}")
print(f"   • Successful completions: {successful_experiments}/{total_experiments} ({completion_rate:.1f}%)")
print(f"   • Fully completed tasks: {successful_tasks}/{len(TASKS)}")

# Calculate key findings
# Entries equal to 0 are excluded: generate_comprehensive_analysis reports a 0
# improvement as "No valid comparison data". .get() keeps a task whose
# improvements dict lacks a metric from raising KeyError here.
identity_improvements = [
    value
    for value in (improvements.get('identity_error', 0) for improvements in task_improvements.values())
    if value != 0
]

loss_improvements = [
    value
    for value in (improvements.get('total_loss', 0) for improvements in task_improvements.values())
    if value != 0
]

if identity_improvements:
    avg_identity_improvement = np.mean(identity_improvements)
    positive_identity_improvements = sum(1 for imp in identity_improvements if imp > 0)
    # Fraction (0-1) of comparable tasks whose identity error improved; reused
    # by the final recommendation below.
    identity_success_rate = positive_identity_improvements / len(identity_improvements)
    print("\n🔍 IDENTITY ERROR FINDINGS:")
    print(f"   • Average improvement: {avg_identity_improvement:+.1f}%")
    print(f"   • Tasks with improvement: {positive_identity_improvements}/{len(identity_improvements)}")
    print(f"   • Success rate: {identity_success_rate*100:.1f}%")

if loss_improvements:
    avg_loss_improvement = np.mean(loss_improvements)
    positive_loss_improvements = sum(1 for imp in loss_improvements if imp > 0)
    print("\n📉 TRAINING LOSS FINDINGS:")
    print(f"   • Average improvement: {avg_loss_improvement:+.1f}%")
    print(f"   • Tasks with improvement: {positive_loss_improvements}/{len(loss_improvements)}")
    print(f"   • Success rate: {positive_loss_improvements/len(loss_improvements)*100:.1f}%")

# Final recommendation: driven by the identity-error success rate alone.
print("\n🏆 FINAL RECOMMENDATION:")
if identity_improvements:
    if identity_success_rate >= 0.5:
        print("   ✅ DEPLOY AGGRESSIVE CURRICULUM: Shows consistent identity error improvements")
    else:
        print("   ❌ STICK WITH BASELINE: Aggressive curriculum does not consistently improve performance")
    print(f"      Success rate: {identity_success_rate*100:.0f}% of tasks improved")
else:
    print("   ⚠️  INCONCLUSIVE: Insufficient data for reliable recommendation")
    print("      Consider longer training or additional validation metrics")

print("\n📋 TECHNICAL DETAILS:")
# Step counts are configured per task, so report a single number only when all
# tasks share it; otherwise list each task's own value.
step_counts = {task['train_steps'] for task in TASKS.values()}
if len(step_counts) == 1:
    print(f"   • Training steps per task: {next(iter(step_counts))}")
else:
    per_task_steps = ', '.join(f"{task['name']}: {task['train_steps']}" for task in TASKS.values())
    print(f"   • Training steps per task: {per_task_steps}")
print(f"   • Diffusion steps: {COMMON_ARGS['diffusion_steps']}")
print(f"   • Validation interval: Every {COMMON_ARGS['csv_log_interval']} steps")
print("   • Key metric: Identity error ||Pred @ Input - I||²")
# dict.fromkeys de-duplicates while keeping TASKS order, so the line is stable across runs.
print(f"   • Models tested: {', '.join(dict.fromkeys(task['model'] for task in TASKS.values()))}")

print("\n📚 TASKS EVALUATED:")
for task_key, task_config in TASKS.items():
    # A task or curriculum with no recorded run counts as failed rather than raising KeyError.
    task_runs = training_results.get(task_key, {})
    success = all(task_runs.get(curr) == 0 for curr in CURRICULA)
    status = "✅" if success else "❌"
    print(f"   {status} {task_config['name']}: {task_config['description']}")

print("\n" + "="*80)
print("🚀 IRED TASK CURRICULUM COMPARISON COMPLETED SUCCESSFULLY! 🚀")
print("="*80)

print("\n💡 NEXT STEPS:")
print("   1. Review detailed task-specific visualizations above")
print("   2. Examine CSV files for numerical analysis")
print("   3. Consider extending to OOD (out-of-distribution) versions if aggressive shows promise")
print("   4. Validate findings with longer training runs if needed")
print("   5. Test on additional IRED tasks if curriculum shows clear benefits")