In [None]:
!rm -rf energy-based-model-2
!git clone https://github.com/mdkrasnow/energy-based-model-2
!cd energy-based-model-2

In [None]:
import os
import sys
import subprocess
import argparse
import json
from pathlib import Path
import time
import re

# Hyperparameters from the paper (Appendix A).
# BATCH_SIZE, DIFFUSION_STEPS and RANK are forwarded to train.py on the
# command line. LEARNING_RATE and TRAIN_ITERATIONS are only printed and
# stored in the saved results metadata by the code in this notebook; they are
# not passed to train.py, so changing them here does not alter training.
BATCH_SIZE = 2048
LEARNING_RATE = 1e-4
# Deliberately far below the paper's ~50,000 iterations (see note above about
# this value not being forwarded to train.py).
TRAIN_ITERATIONS = 1000  # Paper states ~50,000 iterations
DIFFUSION_STEPS = 10
RANK = 20  # For 20x20 matrices

# Tasks to run. Each entry is passed verbatim as train.py's --dataset value.
# The results table in this notebook has display labels for 'addition',
# 'lowrank' and 'inverse'.
TASKS = ['addition']

class ExperimentRunner:
    """Train and evaluate IRED models by shelling out to ``train.py``.

    The runner launches ``train.py`` (resolved relative to the current
    working directory) once per task, streams the child's output to the
    notebook, extracts the reported MSE from evaluation runs, and writes a
    JSON summary into ``base_dir``.

    It relies on the module-level constants ``BATCH_SIZE``,
    ``LEARNING_RATE``, ``TRAIN_ITERATIONS``, ``DIFFUSION_STEPS``, ``RANK``
    and ``TASKS``.
    """

    def __init__(self, base_dir='experiments'):
        """Create the runner and make sure ``base_dir`` exists.

        Args:
            base_dir: Directory that receives ``continuous_results.json``.
                Missing parent directories are created as well.
        """
        self.base_dir = Path(base_dir)
        # parents=True so a nested base_dir such as "runs/2024/exp1" works.
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.results = {}

    def get_result_dir(self, dataset):
        """Get the results directory for a given dataset"""
        return f'results/ds_{dataset}/model_mlp_diffsteps_{DIFFUSION_STEPS}'

    def _checkpoint_path(self, dataset):
        """Return the path of the milestone-1 checkpoint for ``dataset``."""
        return f'{self.get_result_dir(dataset)}/model-1.pt'

    @staticmethod
    def _banner(*lines, char='='):
        """Print ``lines`` framed by 80-character rules, then flush stdout."""
        rule = char * 80
        print(f"\n{rule}")
        for text in lines:
            print(text)
        print(f"{rule}\n")
        sys.stdout.flush()

    @staticmethod
    def _format_mse(value):
        """Format an MSE for the results table; ``None`` becomes ``N/A``."""
        return f"{value:.4f}" if value is not None else "N/A"

    def _base_command(self, dataset):
        """Build the train.py arguments shared by training and evaluation."""
        return [
            # Prefer the interpreter running this code so the child process
            # sees the same installed packages; fall back to PATH lookup.
            sys.executable or 'python', 'train.py',
            '--dataset', dataset,
            '--model', 'mlp',
            '--batch_size', str(BATCH_SIZE),
            '--diffusion_steps', str(DIFFUSION_STEPS),
            '--rank', str(RANK),
        ]

    def _run_streaming(self, cmd, capture=False):
        """Run ``cmd``, echoing its combined stdout/stderr line by line.

        Args:
            cmd: Argument list handed to ``subprocess.Popen``.
            capture: When true, also collect the echoed lines.

        Returns:
            ``(returncode, lines)``; ``lines`` is empty unless ``capture``.

        Raises:
            OSError: If the process cannot be started.
        """
        # A Python child block-buffers stdout when it is a pipe, which would
        # defeat line-by-line streaming; force unbuffered output instead.
        env = dict(os.environ, PYTHONUNBUFFERED='1')
        collected = []
        # The context manager closes the pipe and reaps the child even if
        # the loop below is interrupted.
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
            errors='replace',  # undecodable bytes must not abort the run
            bufsize=1,
            env=env,
        ) as process:
            for line in process.stdout:
                print(line.rstrip())
                sys.stdout.flush()
                if capture:
                    collected.append(line)
            returncode = process.wait()
        return returncode, collected

    def train_model(self, dataset, force_retrain=False):
        """Train a model for a specific dataset.

        Training is skipped when the milestone-1 checkpoint already exists,
        unless ``force_retrain`` is true. ``LEARNING_RATE`` and
        ``TRAIN_ITERATIONS`` are reported in the header but are not
        forwarded to train.py.

        Args:
            dataset: Task name passed to train.py as ``--dataset``.
            force_retrain: Retrain even if a checkpoint is present.

        Returns:
            True if training succeeded or was skipped, False otherwise.
        """
        result_dir = self.get_result_dir(dataset)

        if not force_retrain and os.path.exists(self._checkpoint_path(dataset)):
            self._banner(
                f"Model for {dataset} already exists. Skipping training.",
                "Use --force to retrain.",
            )
            return True

        self._banner(
            f"Training IRED on {dataset.upper()} task",
            '=' * 80,
            f"Batch size: {BATCH_SIZE}",
            f"Learning rate: {LEARNING_RATE}",
            f"Training iterations: {TRAIN_ITERATIONS}",
            f"Diffusion steps: {DIFFUSION_STEPS}",
            f"Matrix rank: {RANK}",
            f"Result directory: {result_dir}",
        )

        cmd = self._base_command(dataset)
        start_time = time.time()
        try:
            returncode, _ = self._run_streaming(cmd)
        except Exception as e:
            # Boundary handler: callers expect a boolean, never an exception.
            self._banner(f"ERROR: Training failed for {dataset}: {e}")
            return False
        elapsed = time.time() - start_time

        if returncode != 0:
            self._banner(
                f"ERROR: Training failed for {dataset} with exit code {returncode}"
            )
            return False

        self._banner(
            f"Training completed for {dataset} in {elapsed/3600:.2f} hours"
        )
        return True

    def evaluate_model(self, dataset, ood=False):
        """Evaluate a trained model on same or harder difficulty.

        Args:
            dataset: Task name passed to train.py as ``--dataset``.
            ood: When true, append ``--ood`` to evaluate on the harder
                (out-of-distribution) setting.

        Returns:
            The parsed MSE as a float, or None when no checkpoint exists,
            the evaluation process fails, or no MSE is found in its output.
        """
        checkpoint = self._checkpoint_path(dataset)
        if not os.path.exists(checkpoint):
            self._banner(
                f"ERROR: No trained model found for {dataset}",
                f"Expected location: {checkpoint}",
                "Please train the model first.",
            )
            return None

        difficulty = "Harder Difficulty (OOD)" if ood else "Same Difficulty"
        self._banner(f"Evaluating IRED on {dataset.upper()} - {difficulty}")

        cmd = self._base_command(dataset) + ['--load-milestone', '1', '--evaluate']
        if ood:
            cmd.append('--ood')

        try:
            # Output is captured as well as echoed so the MSE can be parsed.
            returncode, output_lines = self._run_streaming(cmd, capture=True)
        except Exception as e:
            # Boundary handler: callers expect a value or None, never an exception.
            self._banner(
                f"ERROR: Evaluation failed for {dataset} - {difficulty}: {e}"
            )
            return None

        if returncode != 0:
            self._banner(
                f"ERROR: Evaluation failed for {dataset} - {difficulty} "
                f"with exit code {returncode}"
            )
            return None

        mse = self._parse_mse_from_output(''.join(output_lines), '')
        summary = [f"Evaluation completed for {dataset} - {difficulty}"]
        if mse is not None:
            summary.append(f"MSE: {mse:.4f}")
        self._banner(*summary)
        return mse

    def _parse_mse_from_output(self, stdout, stderr):
        """Parse MSE from training/evaluation output.

        The primary format is a two-column table row such as::

            ---  --------
            mse  0.636725
            ---  --------

        When several rows are present the last one (most recent) wins. If
        no such row exists, rows whose label contains ``mse_error`` (for
        example ``mse_error  0.635722``) are tried instead.

        Args:
            stdout: Captured standard output text.
            stderr: Captured standard error text (may be empty or None).

        Returns:
            The MSE as a float, or None if nothing could be parsed.
        """
        lines = ((stdout or '') + (stderr or '')).splitlines()

        mse_value = None
        for line in lines:
            # Tokenising first tolerates indentation and any whitespace
            # separator (single space, tab) between label and value.
            parts = line.split()
            if len(parts) >= 2 and parts[0] == 'mse':
                try:
                    mse_value = float(parts[1])
                except ValueError:
                    pass  # not a numeric row; keep scanning

        if mse_value is not None:
            return mse_value

        # Fallback: the value follows a token containing "mse" on any line
        # that mentions mse_error.
        for line in lines:
            if 'mse_error' not in line.lower():
                continue
            parts = line.split()
            for idx, token in enumerate(parts[:-1]):
                if 'mse' in token.lower():
                    try:
                        mse_value = float(parts[idx + 1])
                    except ValueError:
                        pass

        return mse_value

    def train_all(self, force_retrain=False):
        """Train all models.

        Args:
            force_retrain: Forwarded to ``train_model`` for every task.

        Returns:
            True only if every task in ``TASKS`` trained (or was skipped)
            successfully.
        """
        self._banner(
            "# TRAINING ALL CONTINUOUS TASKS",
            f"# Tasks: {', '.join(TASKS)}",
            char='#',
        )

        success = {
            dataset: self.train_model(dataset, force_retrain)
            for dataset in TASKS
        }

        summary = ["# TRAINING SUMMARY", '#' * 80]
        for dataset, status in success.items():
            status_str = "✓ SUCCESS" if status else "✗ FAILED"
            summary.append(f"{dataset:20s}: {status_str}")
        self._banner(*summary, char='#')

        return all(success.values())

    def evaluate_all(self):
        """Evaluate all models on both same and harder difficulty.

        Stores the outcome in ``self.results``, prints the comparison table
        and writes the JSON summary.

        Returns:
            Mapping of task name to a dict with ``same_difficulty`` and
            ``harder_difficulty`` MSE values (None where unavailable).
        """
        self._banner(
            "# EVALUATING ALL CONTINUOUS TASKS",
            f"# Tasks: {', '.join(TASKS)}",
            char='#',
        )

        results = {}
        for dataset in TASKS:
            results[dataset] = {
                'same_difficulty': self.evaluate_model(dataset, ood=False),
                'harder_difficulty': self.evaluate_model(dataset, ood=True),
            }

        self.results = results
        self._print_results_table()
        self._save_results()

        return results

    def _print_results_table(self):
        """Print results in a table format like Table 1"""
        self._banner(
            "# RESULTS TABLE (Replicating Table 1 from IRED Paper)",
            char='#',
        )

        print(f"{'Task':<20s} {'Method':<15s} {'Same Difficulty':>15s} {'Harder Difficulty':>17s}")
        print(f"{'-'*20} {'-'*15} {'-'*15} {'-'*17}")

        # Task name mapping for display
        task_display = {
            'addition': 'Addition',
            'lowrank': 'Matrix Completion',
            'inverse': 'Matrix Inverse',
        }

        for dataset in TASKS:
            task_name = task_display.get(dataset, dataset)
            task_results = self.results.get(dataset, {})
            same_str = self._format_mse(task_results.get('same_difficulty'))
            harder_str = self._format_mse(task_results.get('harder_difficulty'))
            print(f"{task_name:<20s} {'IRED (ours)':<15s} {same_str:>15s} {harder_str:>17s}")

        # Reference numbers hard-coded from the paper for side-by-side reading.
        self._banner(
            "# Paper's reported IRED results for comparison:",
            '#' * 80,
            f"{'Addition':<20s} {'IRED (paper)':<15s} {'0.0002':>15s} {'0.0020':>17s}",
            f"{'Matrix Completion':<20s} {'IRED (paper)':<15s} {'0.0174':>15s} {'0.2054':>17s}",
            f"{'Matrix Inverse':<20s} {'IRED (paper)':<15s} {'0.0095':>15s} {'0.2063':>17s}",
            char='#',
        )

    def _save_results(self):
        """Save results to JSON file"""
        results_file = self.base_dir / 'continuous_results.json'

        # Record the settings alongside the numbers so a run is self-describing.
        results_with_meta = {
            'metadata': {
                'batch_size': BATCH_SIZE,
                'learning_rate': LEARNING_RATE,
                'train_iterations': TRAIN_ITERATIONS,
                'diffusion_steps': DIFFUSION_STEPS,
                'rank': RANK,
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
            },
            'results': self.results,
        }

        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(results_with_meta, f, indent=2)

        print(f"Results saved to: {results_file}\n")
        sys.stdout.flush()

In [None]:
# Notebook stand-in for command-line options: where results are written and
# whether existing checkpoints should be retrained.
args = argparse.Namespace(base_dir='experiments', force=False)
runner = ExperimentRunner(base_dir=args.base_dir)

# Training comes first; evaluation only makes sense if every task trained.
success = runner.train_all(force_retrain=args.force)

if not success:
    print("\nSome training jobs failed. Skipping evaluation.")
    sys.stdout.flush()
else:
    # Same-difficulty and harder-difficulty evaluation for every task.
    runner.evaluate_all()