# iREM vs SANS-Modified Model Smoke Test (Google Colab)

This notebook runs actual training comparing:
1. **Baseline iREM** - The original iterative refinement energy model
2. **SANS-Modified Model** - With Self-Adversarial Negative Sampling (RotatE-style)

**Environment**: Google Colab with T4 GPU

**Purpose**: Quick validation of training dynamics and performance differences using actual training code

## 1. Setup Environment and Clone Repository

In [None]:
# Check GPU availability and choose a batch size for the detected hardware.
import os
import sys

# Cap the native math libraries at one thread each. These variables are
# typically read when the libraries are first loaded, so they are set here,
# before `import torch` (which may load NumPy and its BLAS backend).
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU Name: {gpu_name}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Batch size is keyed off the GPU model name; unrecognised GPUs get a
    # conservative default.
    if 'T4' in gpu_name:
        print("✓ T4 GPU detected")
        batch_size = 256
    elif 'V100' in gpu_name:
        print("✓ V100 GPU detected")
        batch_size = 512
    else:
        print(f"✓ {gpu_name} detected")
        batch_size = 128
else:
    print("⚠ No GPU detected - running on CPU (will be slower)")
    batch_size = 32

print(f"\nUsing batch size: {batch_size}")

In [None]:
# Clone the repository
!rm -rf energy-based-model
!git clone https://github.com/mdkrasnow/energy-based-model.git

# If repo is private or doesn't exist on GitHub, upload the files directly
# For now, we'll assume the files are uploaded to Colab
import os
if not os.path.exists('energy-based-model'):
    print("Repository not found. Please either:")
    print("1. Update the git clone URL with your repository")
    print("2. Upload the energy-based-model folder to Colab")
    print("\nCreating directory structure...")
    os.makedirs('energy-based-model', exist_ok=True)

# Change to the repository directory
os.chdir('energy-based-model')
print(f"Current directory: {os.getcwd()}")

In [None]:
# Install required dependencies
!pip install -q accelerate ema-pytorch einops tabulate tqdm matplotlib seaborn pandas
print("✓ Dependencies installed")

## 2. Import Training Modules and Configure Experiment

In [None]:
# Set environment variables that cap the native math libraries at one thread.
# NOTE: such variables are typically read when a library is first loaded, so
# they only affect libraries imported after this point.
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

# Import required modules (numerics, plotting, and small stdlib helpers)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import time
from typing import Dict, List, Optional

# Import training modules from the repository.
# These resolve only if the cloned repository is importable (the clone cell
# changes the working directory into it).
from diffusion_lib.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Trainer1D
from models import EBM, DiffusionWrapper
from dataset import Inverse, Addition, LowRankDataset

# For baseline iREM: same class name as the diffusion trainer, so it is
# aliased to avoid shadowing `Trainer1D` imported above.
from irem_lib.irem import Trainer1D as iREMTrainer1D

print("✓ Modules imported successfully")

# Set plotting style (applies to every figure created later in the notebook)
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 10

In [None]:
# Smoke-test configuration: a single dict shared by both training runs.
CONFIG = {
    # --- Dataset ---
    'dataset': 'inverse',
    'rank': 10,
    'ood': False,

    # --- Model ---
    'model': 'mlp',
    'diffusion_steps': 10,

    # --- Training ---
    'batch_size': batch_size,  # chosen by the GPU-detection cell
    'learning_rate': 1e-4,
    'num_steps': 5000,  # kept small for a smoke test
    'val_every': 250,
    'save_every': 1000,

    # --- SANS ---
    'sans_enabled': True,
    'sans_num_negs': 4,
    'sans_temp': 1.0,
    'sans_temp_schedule': True,
    'sans_chunk': 0,

    # --- Miscellaneous ---
    'supervise_energy_landscape': True,
    'use_innerloop_opt': False,
    'baseline': False,
    'data_workers': 2,
    'ema_decay': 0.995,
    'gradient_accumulate_every': 1,
    'amp': False,

    # --- Output locations ---
    'results_dir': 'smoke_test_results',
    'baseline_dir': 'smoke_test_results/baseline_irem',
    'sans_dir': 'smoke_test_results/sans_modified'
}

# Make sure every output directory exists before anything writes to it.
for dir_key in ('results_dir', 'baseline_dir', 'sans_dir'):
    Path(CONFIG[dir_key]).mkdir(parents=True, exist_ok=True)

# Persist the configuration next to the results for later reference.
config_path = f"{CONFIG['results_dir']}/config.json"
with open(config_path, 'w') as config_file:
    json.dump(CONFIG, config_file, indent=2)

# Echo the settings, leaving out the directory entries.
separator = "="*50
print("Experiment Configuration:")
print(separator)
for option_name, option_value in CONFIG.items():
    if option_name.endswith('_dir'):
        continue
    print(f"  {option_name}: {option_value}")
print(separator)

## 3. Setup Dataset and Models

In [None]:
# Build the training dataset by looking the configured name up in a table of
# the dataset classes imported from the repository.
dataset_name = CONFIG['dataset']
dataset_classes = {
    'inverse': Inverse,
    'addition': Addition,
    'lowrank': LowRankDataset,
}
if dataset_name not in dataset_classes:
    raise ValueError(f"Unknown dataset: {dataset_name}")
dataset = dataset_classes[dataset_name]("train", CONFIG['rank'], CONFIG['ood'])

# The same dataset object is reused for validation in this smoke test.
validation_dataset = dataset
metric = 'mse'

print(f"Dataset: {dataset_name}")
for label, size in (
    ("Input dimension", dataset.inp_dim),
    ("Output dimension", dataset.out_dim),
    ("Dataset size", len(dataset)),
):
    print(f"  {label}: {size}")

## 4. Track Training Metrics

In [None]:
# Placeholder cell: no metric tracking is set up here.
# The comparison cells further down read a metrics.csv from each results
# folder, but (per the note printed below) the trainers only log to the
# console. Until the trainers are changed to write that file, or their logs
# are parsed into it, the comparison section cannot load any data.

print("✓ Note: Trainers need to be modified to save metrics.csv or we need to parse their logs")

## 5. Run Baseline iREM Training

In [None]:
# Train the baseline iREM model with the shared smoke-test settings.
import traceback

print("="*60)
print("BASELINE iREM TRAINING")
print("="*60)
print("\n⚠️ NOTE: The iREM trainer does not save metrics.csv by default.")
print("You need to modify irem_lib/irem.py to log metrics or capture console output.")
print("For now, training will run but metrics.csv won't be created automatically.\n")

# Create model for baseline: an EBM sized to the dataset, wrapped for the trainer
baseline_model = EBM(
    inp_dim=dataset.inp_dim,
    out_dim=dataset.out_dim,
)
baseline_model = DiffusionWrapper(baseline_model)

# Create baseline trainer; hyperparameters mirror the SANS run so the two
# runs differ only in the training method.
baseline_trainer = iREMTrainer1D(
    baseline_model,
    dataset,
    train_batch_size=CONFIG['batch_size'],
    validation_batch_size=min(256, CONFIG['batch_size']),
    train_lr=CONFIG['learning_rate'],
    train_num_steps=CONFIG['num_steps'],
    gradient_accumulate_every=CONFIG['gradient_accumulate_every'],
    ema_decay=CONFIG['ema_decay'],
    data_workers=CONFIG['data_workers'],
    amp=CONFIG['amp'],
    metric=metric,
    results_folder=CONFIG['baseline_dir'],
    cond_mask=False,
    validation_dataset=validation_dataset,
    save_and_sample_every=CONFIG['save_every'],
    evaluate_first=False
)

print(f"Training Baseline iREM for {CONFIG['num_steps']} steps...")
print(f"Results will be saved to: {CONFIG['baseline_dir']}")

# Run actual training. Failures are deliberately non-fatal so the rest of the
# notebook can still run, but the full traceback is shown: `str(e)` alone
# drops the exception type and location (and is empty for some exceptions).
try:
    baseline_trainer.train()
    print("\n✓ Baseline iREM training complete")
except Exception as e:
    print(f"\n⚠ Training error: {str(e)}")
    traceback.print_exc()
    print("Continuing with partial results...")

## 6. Run SANS-Modified Training

In [None]:
# Train the SANS-modified diffusion model with the shared smoke-test settings.
import traceback

print("\n" + "="*60)
print("SANS-MODIFIED TRAINING")
print("="*60)
print("\n⚠️ NOTE: The SANS trainer does not save metrics.csv by default.")
print("You need to modify diffusion_lib/denoising_diffusion_pytorch_1d.py to log metrics.")
print("For now, training will run but metrics.csv won't be created automatically.\n")

# Create model for SANS: same EBM architecture as the baseline run
sans_model = EBM(
    inp_dim=dataset.inp_dim,
    out_dim=dataset.out_dim,
)
sans_model = DiffusionWrapper(sans_model)

# Create diffusion with SANS
sans_diffusion = GaussianDiffusion1D(
    sans_model,
    seq_length=32,
    objective='pred_noise',
    timesteps=CONFIG['diffusion_steps'],
    sampling_timesteps=CONFIG['diffusion_steps'],
    supervise_energy_landscape=CONFIG['supervise_energy_landscape'],
    use_innerloop_opt=CONFIG['use_innerloop_opt'],
    show_inference_tqdm=False,
    # SANS parameters
    sans_enabled=CONFIG['sans_enabled'],
    sans_num_negs=CONFIG['sans_num_negs'],
    sans_temp=CONFIG['sans_temp'],
    sans_temp_schedule=CONFIG['sans_temp_schedule'],
    sans_chunk=CONFIG['sans_chunk'],
    continuous=True  # For inverse dataset
)

# Create SANS trainer; hyperparameters mirror the baseline run
sans_trainer = Trainer1D(
    sans_diffusion,
    dataset,
    train_batch_size=CONFIG['batch_size'],
    validation_batch_size=min(256, CONFIG['batch_size']),
    train_lr=CONFIG['learning_rate'],
    train_num_steps=CONFIG['num_steps'],
    gradient_accumulate_every=CONFIG['gradient_accumulate_every'],
    ema_decay=CONFIG['ema_decay'],
    data_workers=CONFIG['data_workers'],
    amp=CONFIG['amp'],
    metric=metric,
    results_folder=CONFIG['sans_dir'],
    cond_mask=False,
    validation_dataset=validation_dataset,
    save_and_sample_every=CONFIG['save_every'],
    evaluate_first=False
)

print(f"Training SANS-Modified for {CONFIG['num_steps']} steps...")
print(f"SANS Configuration:")
print(f"  - Number of negatives: {CONFIG['sans_num_negs']}")
print(f"  - Temperature: {CONFIG['sans_temp']}")
print(f"  - Temperature schedule: {CONFIG['sans_temp_schedule']}")
print(f"Results will be saved to: {CONFIG['sans_dir']}")

# Run actual training. Failures are deliberately non-fatal so the rest of the
# notebook can still run, but the full traceback is shown: `str(e)` alone
# drops the exception type and location (and is empty for some exceptions).
try:
    sans_trainer.train()
    print("\n✓ SANS-Modified training complete")
except Exception as e:
    print(f"\n⚠ Training error: {str(e)}")
    traceback.print_exc()
    print("Continuing with partial results...")

## 7. Load and Compare Results

In [None]:
# Locate the per-run metrics files; stop with instructions if either is missing
import os

baseline_metrics_path = f"{CONFIG['baseline_dir']}/metrics.csv"
sans_metrics_path = f"{CONFIG['sans_dir']}/metrics.csv"

metrics_present = os.path.exists(baseline_metrics_path) and os.path.exists(sans_metrics_path)
if not metrics_present:
    missing_metrics_help = (
        "⚠️ ERROR: Metrics files not found!",
        "\nThe trainers do not automatically save metrics.csv files.",
        "\nTo fix this, you need to:",
        "1. Modify irem_lib/irem.py to save training metrics to CSV",
        "2. Modify diffusion_lib/denoising_diffusion_pytorch_1d.py to save training metrics to CSV",
        "3. OR capture console output during training and parse it",
        "\nThe metrics.csv files should have columns: step, loss, val_loss, lr, time",
    )
    for help_line in missing_metrics_help:
        print(help_line)

    # Abort the cell: nothing below can run without the metrics
    raise FileNotFoundError("Metrics files not found. See instructions above.")

# Read each run's metrics and tag the rows with the run they came from
baseline_df = pd.read_csv(baseline_metrics_path).assign(model_type='Baseline iREM')
sans_df = pd.read_csv(sans_metrics_path).assign(model_type='SANS-Modified')

# One long frame holding both runs
combined_df = pd.concat([baseline_df, sans_df], ignore_index=True)

# Relative improvement of the SANS final loss over the baseline final loss (%)
final_baseline_loss = baseline_df['loss'].iloc[-1]
final_sans_loss = sans_df['loss'].iloc[-1]
final_loss_gap = final_baseline_loss - final_sans_loss
improvement = final_loss_gap / final_baseline_loss * 100

# Calculate convergence speed
def get_convergence_step(df, threshold=0.1):
    """Return the first logged step at which the loss has mostly converged.

    "Converged" means the loss has covered a fraction ``1 - threshold`` of the
    total drop between the first and the last logged loss. With the default
    ``threshold=0.1`` that is 90% of the total reduction.

    Args:
        df: DataFrame with ``'loss'`` and ``'step'`` columns, one row per
            logged training step, in training order.
        threshold: Fraction of the total loss reduction that may still be
            remaining at the returned step.

    Returns:
        The ``'step'`` value of the first row whose loss is at or below the
        target, or the last ``'step'`` value if no row reaches it.
    """
    initial_loss = df['loss'].iloc[0]
    final_loss = df['loss'].iloc[-1]
    # Honour `threshold` instead of a hard-coded 0.9.
    target_loss = initial_loss - (1 - threshold) * (initial_loss - final_loss)
    # Select by boolean mask and position so the result is a scalar even when
    # the frame's index contains duplicate labels.
    reached_steps = df.loc[df['loss'] <= target_loss, 'step']
    if len(reached_steps) > 0:
        return reached_steps.iloc[0]
    return df['step'].iloc[-1]

# Convergence step for each run (default threshold)
baseline_conv = get_convergence_step(baseline_df)
sans_conv = get_convergence_step(sans_df)

# Headline numbers
rule = "="*60
print("\n" + rule)
print("RESULTS SUMMARY")
print(rule)

print("\nFinal Loss:")
print(f"  Baseline iREM: {final_baseline_loss:.6f}")
print(f"  SANS-Modified: {final_sans_loss:.6f}")
print(f"  Improvement: {improvement:.1f}%")

print("\nConvergence Speed (90% reduction):")
for run_label, conv_step in (("Baseline iREM", baseline_conv), ("SANS-Modified", sans_conv)):
    print(f"  {run_label}: {conv_step} steps")
# A speedup ratio is only reported when the denominator is positive
if sans_conv > 0:
    print(f"  Speedup: {baseline_conv/sans_conv:.2f}x")

print("\nTraining Time:")
for run_label, run_df in (("Baseline iREM", baseline_df), ("SANS-Modified", sans_df)):
    print(f"  {run_label}: {run_df['time'].iloc[-1]:.1f} seconds")

# Write both runs' metrics to a single CSV
combined_path = f"{CONFIG['results_dir']}/combined_metrics.csv"
combined_df.to_csv(combined_path, index=False)
print(f"\n✓ Combined metrics saved to {combined_path}")

## 8. Visualize Training Dynamics

In [None]:
# Create a 2x2 comparison figure: raw loss, validation loss, smoothed loss,
# and the relative improvement of SANS over the baseline.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('iREM vs SANS-Modified Training Comparison', fontsize=16, y=1.02)

# 1. Training Loss
ax = axes[0, 0]
ax.plot(baseline_df['step'], baseline_df['loss'], label='Baseline iREM',
        color='blue', alpha=0.8, linewidth=2)
ax.plot(sans_df['step'], sans_df['loss'], label='SANS-Modified',
        color='red', alpha=0.8, linewidth=2)
ax.set_xlabel('Training Step')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Comparison')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_yscale('log')

# 2. Validation Loss (only rows that actually carry a validation value)
ax = axes[0, 1]
baseline_val = baseline_df.dropna(subset=['val_loss'])
sans_val = sans_df.dropna(subset=['val_loss'])
if len(baseline_val) > 0:
    ax.plot(baseline_val['step'], baseline_val['val_loss'], 'o-',
            label='Baseline iREM', color='blue', alpha=0.8, markersize=6)
if len(sans_val) > 0:
    ax.plot(sans_val['step'], sans_val['val_loss'], 's-',
            label='SANS-Modified', color='red', alpha=0.8, markersize=6)
ax.set_xlabel('Training Step')
ax.set_ylabel('Validation Loss')
ax.set_title('Validation Performance')
# Both curves may be absent; only draw a legend when something was plotted.
if ax.get_legend_handles_labels()[0]:
    ax.legend()
ax.grid(True, alpha=0.3)

# 3. Loss Reduction Rate
ax = axes[1, 0]
# Clamp to at least 1: with fewer than 10 rows the integer division yields 0,
# and rolling(window=0, min_periods=1) raises a ValueError.
window = max(1, min(50, len(baseline_df) // 10))
baseline_smooth = baseline_df['loss'].rolling(window=window, min_periods=1).mean()
sans_smooth = sans_df['loss'].rolling(window=window, min_periods=1).mean()
ax.plot(baseline_df['step'], baseline_smooth, label='Baseline iREM',
        color='blue', alpha=0.8, linewidth=2)
ax.plot(sans_df['step'], sans_smooth, label='SANS-Modified',
        color='red', alpha=0.8, linewidth=2)
ax.set_xlabel('Training Step')
ax.set_ylabel('Smoothed Loss')
ax.set_title(f'Smoothed Loss (window={window})')
ax.legend()
ax.grid(True, alpha=0.3)

# 4. Relative Improvement (drawn only when both runs logged the same number
# of rows, since the curves are compared row by row)
ax = axes[1, 1]
if len(baseline_smooth) == len(sans_smooth):
    improvement_curve = (baseline_smooth - sans_smooth) / baseline_smooth * 100
    ax.plot(baseline_df['step'], improvement_curve, color='green', linewidth=2)
    ax.fill_between(baseline_df['step'], 0, improvement_curve,
                     where=(improvement_curve > 0), color='green', alpha=0.3,
                     label='SANS Better')
    ax.fill_between(baseline_df['step'], 0, improvement_curve,
                     where=(improvement_curve <= 0), color='red', alpha=0.3,
                     label='Baseline Better')
    ax.set_xlabel('Training Step')
    ax.set_ylabel('Improvement (%)')
    ax.set_title('SANS Relative Improvement')
    ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{CONFIG['results_dir']}/comparison_plots.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"\n✓ Plots saved to {CONFIG['results_dir']}/comparison_plots.png")

## 9. Performance Summary Table

In [None]:
# Create performance comparison table: one row per metric, one column per
# run, plus the relative improvement of SANS over the baseline.
summary_data = {
    'Metric': [
        'Initial Loss',
        'Final Loss',
        'Loss Reduction (%)',
        'Best Loss',
        'Convergence Step (90%)',
        'Training Time (s)',
        'Steps per Second'
    ],
    'Baseline iREM': [],
    'SANS-Modified': [],
    'Improvement': []
}

# Calculate metrics for both models (values are stored pre-formatted as text,
# in the same order as the 'Metric' list above)
for df, col_name in [(baseline_df, 'Baseline iREM'), (sans_df, 'SANS-Modified')]:
    initial_loss = df['loss'].iloc[0]
    final_loss = df['loss'].iloc[-1]
    best_loss = df['loss'].min()
    loss_reduction = (initial_loss - final_loss) / initial_loss * 100
    conv_step = get_convergence_step(df)
    train_time = df['time'].iloc[-1]
    steps_per_sec = len(df) / train_time if train_time > 0 else 0

    summary_data[col_name] = [
        f"{initial_loss:.6f}",
        f"{final_loss:.6f}",
        f"{loss_reduction:.1f}",
        f"{best_loss:.6f}",
        f"{conv_step}",
        f"{train_time:.1f}",
        f"{steps_per_sec:.1f}"
    ]

# Calculate improvements.
# The loop variables are deliberately NOT called `metric`, `baseline_val` or
# `sans_val`: those names are used by earlier cells (`metric` is passed to
# both trainers) and overwriting them would break a re-run of those cells.
lower_is_better = ['Final Loss', 'Best Loss', 'Convergence Step (90%)', 'Training Time (s)']
for i, metric_name in enumerate(summary_data['Metric']):
    try:
        baseline_num = float(summary_data['Baseline iREM'][i].replace(',', ''))
        sans_num = float(summary_data['SANS-Modified'][i].replace(',', ''))

        if metric_name in lower_is_better:
            # Lower is better
            imp = (baseline_num - sans_num) / baseline_num * 100 if baseline_num != 0 else 0
        else:
            # Higher is better
            imp = (sans_num - baseline_num) / baseline_num * 100 if baseline_num != 0 else 0

        summary_data['Improvement'].append(f"{imp:+.1f}%")
    except (ValueError, TypeError):
        # The formatted cell could not be parsed back into a number.
        summary_data['Improvement'].append("-")

# Display table
summary_df = pd.DataFrame(summary_data)
print("\n" + "="*80)
print("PERFORMANCE COMPARISON TABLE")
print("="*80)
print(summary_df.to_string(index=False))

# Save summary
summary_df.to_csv(f"{CONFIG['results_dir']}/performance_summary.csv", index=False)
print(f"\n✓ Performance summary saved to {CONFIG['results_dir']}/performance_summary.csv")

## 10. Final Conclusions

In [None]:
# Print a closing report: run settings, headline findings, and next steps.
print("\n" + "="*80)
print("SMOKE TEST CONCLUSIONS")
print("="*80)

print(f"\n📊 Dataset: {CONFIG['dataset']} (rank={CONFIG['rank']})")
print(f"🔧 Model: {CONFIG['model']}")
print(f"⚙️  Training Steps: {CONFIG['num_steps']}")
print(f"🎯 Batch Size: {CONFIG['batch_size']}")
print(f"🔄 Diffusion Steps: {CONFIG['diffusion_steps']}")

print("\n" + "-"*40)
print("KEY FINDINGS")
print("-"*40)

if improvement > 0:
    print(f"\n✅ SANS Advantages:")
    print(f"   • {improvement:.1f}% better final loss")
    # Require a positive denominator, matching the guard used when the
    # speedup is printed in the results summary.
    if 0 < sans_conv < baseline_conv:
        print(f"   • {baseline_conv/sans_conv:.2f}x faster convergence")
    print(f"   • Self-adversarial sampling improves gradient quality")
else:
    print(f"\n⚠️  Mixed Results:")
    print(f"   • Baseline performed better in this short test")
    print(f"   • SANS may need more steps to show benefits")

print(f"\n📝 SANS Configuration Used:")
print(f"   • Number of negatives (M): {CONFIG['sans_num_negs']}")
print(f"   • Temperature (α): {CONFIG['sans_temp']}")
print(f"   • Temperature schedule: {'Enabled' if CONFIG['sans_temp_schedule'] else 'Disabled'}")

print("\n" + "-"*40)
print("RECOMMENDATIONS")
print("-"*40)
print("\n1. For full evaluation:")
print("   • Run longer training (50k-100k steps)")
print("   • Test on harder datasets (sudoku, connectivity)")
print("   • Try different SANS hyperparameters")

print("\n2. For optimization:")
print("   • Increase sans_num_negs for harder problems")
print("   • Adjust temperature based on dataset difficulty")
print("   • Enable AMP for faster training on compatible GPUs")

print("\n" + "="*80)
print("SMOKE TEST COMPLETED SUCCESSFULLY")
print("="*80)

# List all output files found anywhere under the results directory
print("\n📁 Output Files:")
for file in Path(CONFIG['results_dir']).rglob('*'):
    if file.is_file():
        print(f"   • {file}")