In [None]:
# Start from a fresh checkout: delete any previous clone, clone the repository,
# then make the checkout the notebook's working directory (the later cells use
# relative paths such as train.py and results/).
# NOTE(review): re-running this cell after the %cd has taken effect presumably
# clones a nested copy inside the existing checkout - confirm before re-running.
!rm -rf energy-based-model-2
!git clone https://github.com/mdkrasnow/energy-based-model-2
%cd energy-based-model-2

In [None]:
# Quietly install the third-party packages (presumably the dependencies of the
# cloned repository's train.py - verify against the repository).
!pip install -q torch torchvision einops accelerate tqdm tabulate matplotlib numpy pandas ema-pytorch ipdb seaborn scikit-learn

In [None]:
import os
import sys
import subprocess
import argparse
import json
from pathlib import Path
import time
import re

# Hyperparameters from the paper (Appendix A)
BATCH_SIZE = 2048  # forwarded to train.py as --batch_size
LEARNING_RATE = 1e-4  # only printed and stored in the results metadata; not passed to train.py
TRAIN_ITERATIONS = 1000  # forwarded to train.py as --train-steps
DIFFUSION_STEPS = 10  # forwarded as --diffusion_steps; also part of the result directory name
RANK = 20  # For 20x20 matrices

# Tasks to run: dataset names passed to train.py via --dataset. The results
# table has display names for 'addition', 'lowrank' and 'inverse'.
TASKS = ['addition']

class ExperimentRunner:
    """Train and evaluate IRED model variants by shelling out to ``train.py``.

    Every run is a ``python train.py ...`` subprocess started from the current
    working directory; its combined stdout/stderr is echoed line by line.
    Checkpoints are looked up under ``results/ds_<dataset>/...`` (see
    ``get_result_dir``) and the collected MSE values are written as JSON into
    ``base_dir``.

    The class reads the module-level constants ``BATCH_SIZE``,
    ``LEARNING_RATE``, ``TRAIN_ITERATIONS``, ``DIFFUSION_STEPS``, ``RANK`` and
    ``TASKS``. ``LEARNING_RATE`` is only printed and stored in the results
    metadata; it is not forwarded to ``train.py``.
    """

    # Model variants, in the order they are trained, evaluated and reported.
    MODEL_TYPES = ('baseline', 'anm', 'anm_curriculum')

    # Extra train.py flags for each non-baseline variant.
    _MODEL_FLAGS = {
        'anm': (
            '--use-anm',
            '--anm-adversarial-steps', '5',
            '--anm-distance-penalty', '0.1',
        ),
        'anm_curriculum': (
            '--use-anm',
            '--use-curriculum',
            '--anm-adversarial-steps', '5',
            '--anm-distance-penalty', '0.1',
        ),
    }

    # (model type, row label) pairs for the comparison table.
    _TABLE_LABELS = (
        ('baseline', 'IRED (baseline)'),
        ('anm', 'IRED + ANM'),
        ('anm_curriculum', 'IRED + ANM + Curriculum'),
    )

    # (model type, label) pairs for the relative-improvement section.
    _IMPROVEMENT_LABELS = (
        ('anm', 'ANM'),
        ('anm_curriculum', 'ANM+Curriculum'),
    )

    _RULE = '=' * 80
    _HASH_RULE = '#' * 80

    def __init__(self, base_dir='experiments'):
        """Create the runner and make sure ``base_dir`` exists.

        Args:
            base_dir: Directory that receives the JSON results file. Missing
                parent directories are created as well.
        """
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.results = {}

    def get_result_dir(self, dataset, model_type='baseline'):
        """Get the results directory for a given dataset and model type.

        Any model type other than 'anm' / 'anm_curriculum' maps to the
        un-suffixed (baseline) directory.
        """
        base = f'results/ds_{dataset}/model_mlp_diffsteps_{DIFFUSION_STEPS}'
        if model_type == 'anm':
            base += '_anm'
        elif model_type == 'anm_curriculum':
            base += '_anm_curriculum'
        return base

    @classmethod
    def _model_flags(cls, model_type):
        """Return a fresh list of the extra train.py flags for ``model_type``."""
        return list(cls._MODEL_FLAGS.get(model_type, ()))

    @staticmethod
    def _format_mse(value):
        """Format an MSE for the table; ``None`` (missing) becomes 'N/A'."""
        return f"{value:.4f}" if value is not None else "N/A"

    @staticmethod
    def _relative_improvement(baseline, value):
        """Return the percentage reduction of ``value`` relative to ``baseline``.

        Returns ``None`` when either number is missing or when the baseline is
        zero (the ratio is undefined). A ``value`` of exactly 0.0 is a valid
        score and yields a 100% improvement.
        """
        if baseline is None or value is None or baseline == 0:
            return None
        return ((baseline - value) / baseline) * 100

    def _base_command(self, dataset):
        """Build the train.py arguments shared by training and evaluation."""
        return [
            'python', 'train.py',
            '--dataset', dataset,
            '--model', 'mlp',
            '--batch_size', str(BATCH_SIZE),
            '--diffusion_steps', str(DIFFUSION_STEPS),
            '--rank', str(RANK),
            '--train-steps', str(TRAIN_ITERATIONS),
        ]

    def _run_streaming(self, cmd, capture=False):
        """Run ``cmd`` and echo its combined stdout/stderr line by line.

        Args:
            cmd: Argument list handed to ``subprocess.Popen``.
            capture: When true, also collect the echoed lines.

        Returns:
            ``(returncode, lines)``; ``lines`` stays empty unless ``capture``.
        """
        lines = []
        # The context manager closes the pipe and waits for the child even
        # when echoing is interrupted by an exception.
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
            bufsize=1,
        ) as process:
            for line in process.stdout:
                print(line.rstrip())
                sys.stdout.flush()
                if capture:
                    lines.append(line)
            returncode = process.wait()
        return returncode, lines

    def train_model(self, dataset, model_type='baseline', force_retrain=False):
        """Train a model for a specific dataset and model type.

        Args:
            dataset: Dataset name
            model_type: One of 'baseline', 'anm', 'anm_curriculum'
            force_retrain: Force retraining even if model exists

        Returns:
            True when a checkpoint already exists (and ``force_retrain`` is
            false) or the training subprocess exits with code 0; False when it
            exits non-zero or cannot be run.
        """
        result_dir = self.get_result_dir(dataset, model_type)

        # Check if model already exists
        if not force_retrain and os.path.exists(f'{result_dir}/model-1.pt'):
            print(f"\n{self._RULE}")
            print(f"Model for {dataset} ({model_type}) already exists. Skipping training.")
            print("Use --force to retrain.")
            print(f"{self._RULE}\n")
            sys.stdout.flush()
            return True

        print(f"\n{self._RULE}")
        print(f"Training IRED ({model_type.upper()}) on {dataset.upper()} task")
        print(f"{self._RULE}")
        print(f"Model Type: {model_type}")
        print(f"Batch size: {BATCH_SIZE}")
        print(f"Learning rate: {LEARNING_RATE}")
        print(f"Training iterations: {TRAIN_ITERATIONS}")
        print(f"Diffusion steps: {DIFFUSION_STEPS}")
        print(f"Matrix rank: {RANK}")
        print(f"Result directory: {result_dir}")

        if model_type == 'anm_curriculum':
            print(f"\nCurriculum Schedule (% of {TRAIN_ITERATIONS} steps):")
            print("  Warmup (0-10%): Clean samples only")
            print("  Easy (10-30%): 50% clean, 30% adversarial, 20% gaussian")
            print("  Medium (30-60%): 30% clean, 50% adversarial, 20% gaussian")
            print("  Hard (60-100%): 10% clean, 80% adversarial, 10% gaussian")

        print(f"{self._RULE}\n")
        sys.stdout.flush()

        # Warmup / curriculum stage lengths are not passed here; only the
        # variant flags are appended to the shared arguments.
        cmd = self._base_command(dataset) + self._model_flags(model_type)

        try:
            start_time = time.time()
            returncode, _ = self._run_streaming(cmd)
            elapsed = time.time() - start_time
        except Exception as e:
            # Best-effort boundary: report the failure and let the caller
            # carry on with the remaining jobs.
            print(f"\n{self._RULE}")
            print(f"ERROR: Training failed for {dataset} ({model_type}): {e}")
            print(f"{self._RULE}\n")
            sys.stdout.flush()
            return False

        print(f"\n{self._RULE}")
        if returncode == 0:
            print(f"Training completed for {dataset} ({model_type}) in {elapsed/60:.2f} minutes")
        else:
            print(f"ERROR: Training failed for {dataset} ({model_type}) with exit code {returncode}")
        print(f"{self._RULE}\n")
        sys.stdout.flush()
        return returncode == 0

    def evaluate_model(self, dataset, model_type='baseline', ood=False):
        """Evaluate a trained model on same or harder difficulty.

        Args:
            dataset: Dataset name
            model_type: One of 'baseline', 'anm', 'anm_curriculum'
            ood: When true, ``--ood`` is appended to the command
                (reported as "Harder Difficulty (OOD)").

        Returns:
            The MSE parsed from the subprocess output, or None when no
            checkpoint exists, the subprocess fails, or no MSE can be parsed.
        """
        result_dir = self.get_result_dir(dataset, model_type)

        # Check if model exists
        if not os.path.exists(f'{result_dir}/model-1.pt'):
            print(f"\n{self._RULE}")
            print(f"ERROR: No trained model found for {dataset} ({model_type})")
            print(f"Expected location: {result_dir}/model-1.pt")
            print("Please train the model first.")
            print(f"{self._RULE}\n")
            sys.stdout.flush()
            return None

        difficulty = "Harder Difficulty (OOD)" if ood else "Same Difficulty"
        print(f"\n{self._RULE}")
        print(f"Evaluating IRED ({model_type.upper()}) on {dataset.upper()} - {difficulty}")
        print(f"{self._RULE}\n")
        sys.stdout.flush()

        cmd = self._base_command(dataset)
        cmd.extend(['--load-milestone', '1', '--evaluate'])
        cmd.extend(self._model_flags(model_type))
        if ood:
            cmd.append('--ood')

        try:
            # Keep the echoed lines so the MSE can be parsed afterwards.
            returncode, output_lines = self._run_streaming(cmd, capture=True)
        except Exception as e:
            # Best-effort boundary: report and continue with other evaluations.
            print(f"\n{self._RULE}")
            print(f"ERROR: Evaluation failed for {dataset} ({model_type}) - {difficulty}: {e}")
            print(f"{self._RULE}\n")
            sys.stdout.flush()
            return None

        if returncode != 0:
            print(f"\n{self._RULE}")
            print(f"ERROR: Evaluation failed for {dataset} ({model_type}) - {difficulty} with exit code {returncode}")
            print(f"{self._RULE}\n")
            sys.stdout.flush()
            return None

        mse = self._parse_mse_from_output(''.join(output_lines), '')

        print(f"\n{self._RULE}")
        print(f"Evaluation completed for {dataset} ({model_type}) - {difficulty}")
        if mse is not None:
            print(f"MSE: {mse:.4f}")
        else:
            # Make the otherwise silent "N/A" in the table traceable.
            print("WARNING: could not parse an MSE value from the evaluation output")
        print(f"{self._RULE}\n")
        sys.stdout.flush()

        return mse

    def _parse_mse_from_output(self, stdout, stderr=''):
        """Parse MSE from training/evaluation output.

        Args:
            stdout: Captured output text.
            stderr: Optional additional text, appended to ``stdout``.

        Returns:
            The last value found on a table row of the form ``mse  <float>``;
            if there is none, the last float following a token containing
            'mse' on a line mentioning 'mse_error'; otherwise None.
        """
        lines = ((stdout or '') + (stderr or '')).split('\n')

        mse_value = None
        for line in lines:
            # A table row: starts with 'mse' and has column padding.
            if not (line.startswith('mse') and '  ' in line):
                continue
            parts = line.split()
            if len(parts) >= 2 and parts[0] == 'mse':
                try:
                    # Keep scanning so the most recent row wins.
                    mse_value = float(parts[1])
                except ValueError:
                    pass

        if mse_value is not None:
            return mse_value

        # Alternative format, e.g. "mse_error  0.635722".
        for line in lines:
            if 'mse_error' not in line.lower():
                continue
            parts = line.split()
            for idx, part in enumerate(parts[:-1]):
                if 'mse' in part.lower():
                    try:
                        mse_value = float(parts[idx + 1])
                    except ValueError:
                        pass

        return mse_value

    def train_all(self, force_retrain=False):
        """Train all models (baseline and ANM variants).

        Returns:
            True only if every (task, model type) training job succeeded.
        """
        print(f"\n{self._HASH_RULE}")
        print("# TRAINING ALL CONTINUOUS TASKS WITH BASELINE AND ANM MODELS")
        print(f"# Tasks: {', '.join(TASKS)}")
        print("# Model Types: baseline, anm, anm_curriculum")
        print(f"# Training Steps: {TRAIN_ITERATIONS}")
        print("# Curriculum: Percentage-based stage transitions")
        print(f"{self._HASH_RULE}\n")
        sys.stdout.flush()

        success = {}
        for dataset in TASKS:
            for model_type in self.MODEL_TYPES:
                success[f"{dataset}_{model_type}"] = self.train_model(
                    dataset, model_type, force_retrain
                )

        print(f"\n{self._HASH_RULE}")
        print("# TRAINING SUMMARY")
        print(f"{self._HASH_RULE}")
        for dataset in TASKS:
            print(f"\n{dataset.upper()}:")
            for model_type in self.MODEL_TYPES:
                ok = success.get(f"{dataset}_{model_type}", False)
                status_str = "✓ SUCCESS" if ok else "✗ FAILED"
                print(f"  {model_type:20s}: {status_str}")
        print(f"{self._HASH_RULE}\n")
        sys.stdout.flush()

        return all(success.values())

    def evaluate_all(self):
        """Evaluate all models on both same and harder difficulty.

        Stores the results on ``self.results``, prints the comparison table,
        saves the JSON file and returns the results dictionary
        (``results[dataset][model_type][difficulty] -> MSE or None``).
        """
        print(f"\n{self._HASH_RULE}")
        print("# EVALUATING ALL CONTINUOUS TASKS")
        print(f"# Tasks: {', '.join(TASKS)}")
        print("# Model Types: baseline, anm, anm_curriculum")
        print(f"{self._HASH_RULE}\n")
        sys.stdout.flush()

        results = {}
        for dataset in TASKS:
            results[dataset] = {}
            for model_type in self.MODEL_TYPES:
                results[dataset][model_type] = {
                    'same_difficulty': self.evaluate_model(dataset, model_type, ood=False),
                    'harder_difficulty': self.evaluate_model(dataset, model_type, ood=True),
                }

        self.results = results
        self._print_results_table()
        self._save_results()

        return results

    def _print_results_table(self):
        """Print results in a comparison table format.

        Prints one row per model variant for every task, then the relative
        improvement of each ANM variant over the baseline, then the paper's
        reported numbers for reference.
        """
        print(f"\n{self._HASH_RULE}")
        print("# RESULTS COMPARISON TABLE")
        print(f"# Training Steps: {TRAIN_ITERATIONS}")
        print(f"{self._HASH_RULE}\n")

        # Print header
        print(f"{'Task':<20s} {'Method':<25s} {'Same Difficulty':>15s} {'Harder Difficulty':>17s}")
        print(f"{'-'*20} {'-'*25} {'-'*15} {'-'*17}")

        # Task name mapping for display
        task_display = {
            'addition': 'Addition',
            'lowrank': 'Matrix Completion',
            'inverse': 'Matrix Inverse'
        }

        for dataset in TASKS:
            task_name = task_display.get(dataset, dataset)
            per_model = self.results.get(dataset, {})
            for row, (model_type, label) in enumerate(self._TABLE_LABELS):
                scores = per_model.get(model_type, {})
                same_str = self._format_mse(scores.get('same_difficulty'))
                harder_str = self._format_mse(scores.get('harder_difficulty'))
                # Only the first row of a task carries the task name.
                first_col = task_name if row == 0 else ''
                print(f"{first_col:<20s} {label:<25s} {same_str:>15s} {harder_str:>17s}")
            print()  # Blank line between tasks

        print(f"\n{self._HASH_RULE}")
        print("# RELATIVE IMPROVEMENTS vs BASELINE")
        print(f"{self._HASH_RULE}\n")

        for dataset in TASKS:
            task_name = task_display.get(dataset, dataset)
            per_model = self.results.get(dataset, {})
            baseline = per_model.get('baseline', {})
            baseline_same = baseline.get('same_difficulty')
            baseline_harder = baseline.get('harder_difficulty')

            # A missing or zero baseline cannot serve as the denominator.
            if not baseline_same or not baseline_harder:
                continue

            print(f"{task_name}:")
            for model_type, label in self._IMPROVEMENT_LABELS:
                scores = per_model.get(model_type, {})
                same_imp = self._relative_improvement(
                    baseline_same, scores.get('same_difficulty'))
                harder_imp = self._relative_improvement(
                    baseline_harder, scores.get('harder_difficulty'))
                # Only a missing score (None) suppresses the line; an MSE of
                # exactly 0.0 is a real result and is reported.
                if same_imp is not None and harder_imp is not None:
                    print(f"  {label}: {same_imp:+.1f}% (same), {harder_imp:+.1f}% (harder)")

        print(f"\n{self._HASH_RULE}")
        print("# Paper's reported IRED results for comparison:")
        print(f"{self._HASH_RULE}")
        print(f"{'Addition':<20s} {'IRED (paper)':<25s} {'0.0002':>15s} {'0.0020':>17s}")
        print(f"{'Matrix Completion':<20s} {'IRED (paper)':<25s} {'0.0174':>15s} {'0.2054':>17s}")
        print(f"{'Matrix Inverse':<20s} {'IRED (paper)':<25s} {'0.0095':>15s} {'0.2063':>17s}")
        print(f"{self._HASH_RULE}\n")
        sys.stdout.flush()

    def _save_results(self):
        """Save results and run metadata to a JSON file inside ``base_dir``."""
        results_file = self.base_dir / 'continuous_results_with_anm.json'

        # Add metadata
        results_with_meta = {
            'metadata': {
                'batch_size': BATCH_SIZE,
                'learning_rate': LEARNING_RATE,
                'train_iterations': TRAIN_ITERATIONS,
                'diffusion_steps': DIFFUSION_STEPS,
                'rank': RANK,
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                'model_types': list(self.MODEL_TYPES),
                'curriculum': {
                    'warmup': '0-10%',
                    'easy': '10-30%',
                    'medium': '30-60%',
                    'hard': '60-100%'
                }
            },
            'results': self.results
        }

        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(results_with_meta, f, indent=2)

        print(f"Results saved to: {results_file}\n")
        sys.stdout.flush()

In [None]:
# Notebook entry point: a Namespace stands in for parsed command-line options
# (results directory and the force-retrain switch).
args = argparse.Namespace(base_dir='experiments', force=False)
runner = ExperimentRunner(base_dir=args.base_dir)

# Train every task / model-type combination first.
success = runner.train_all(force_retrain=args.force)

# Evaluation only runs when all training jobs reported success.
if not success:
    print("\nSome training jobs failed. Skipping evaluation.")
    sys.stdout.flush()
else:
    runner.evaluate_all()