# IRED vs IRED+ANM Evaluation on Google Colab

This notebook evaluates the **IRED (Iterative Reasoning through Energy Diffusion)** base implementation against the enhanced **IRED+ANM (Adversarial Negative Mining)** version.

## Key Features:
- Side-by-side comparison of IRED base vs IRED+ANM
- Configurable number of training iterations
- Comprehensive metrics tracking (margins, hardness, energy landscape)
- Visualization of training dynamics and performance

## Paper Reference
Du et al. trained for **100,000 iterations on a single NVIDIA RTX 2080 with batch size 512** using the Adam optimizer.

## Setup Instructions:
1. Upload this notebook to [Google Colab](https://colab.research.google.com)
2. Go to Runtime → Change runtime type → GPU → T4
3. Configure NUM_ITERATIONS below
4. Run all cells

In [None]:
# ============= TRAINING CONFIGURATION =============
# Easily change the number of iterations here
NUM_ITERATIONS = 2000  # Options: 1000 (quick test), 5000 (fast eval), 25000 (medium), 100000 (full)

# Dataset and model configuration
DATASET = "inverse"  # Paper uses inverse task
MODEL = "mlp"  # Default model architecture
BATCH_SIZE = 2048  # 4x paper's 512 for efficiency
RANK = 20  # Rank for matrix datasets
NUM_WORKERS = 2  # DataLoader workers

# ANM-specific parameters
ANM_STEPS = 10  # Number of adversarial optimization steps
ANM_LOSS_WEIGHT = 0.5  # Weight for energy loss when ANM is active
ANM_STEP_MULT = 1.0  # Step size multiplier for ANM
ANM_ADAPTIVE = False  # Enable timestep-aware ANM

print(f"Training configuration:")
print(f"  Number of iterations: {NUM_ITERATIONS}")
print(f"  Dataset: {DATASET}")
print(f"  Model: {MODEL}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Estimated time: {NUM_ITERATIONS/100000 * 5:.1f} hours on T4 GPU")
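
To put the default settings in perspective, the number of training samples a run processes is iterations × batch size. The short calculation below compares the paper's reported budget with this notebook's defaults (2,000 iterations at batch size 2048). It is only arithmetic on those figures and says nothing about wall-clock time, which depends on the GPU.

```python
# Samples processed = iterations * batch size.
paper_samples = 100_000 * 512      # budget reported for the paper
default_samples = 2_000 * 2048     # this notebook's default NUM_ITERATIONS and BATCH_SIZE

fraction = default_samples / paper_samples
print(f"Paper:   {paper_samples:,} samples")
print(f"Default: {default_samples:,} samples ({fraction:.1%} of the paper's budget)")
```

A default run therefore sees well under a tenth of the data the paper's run did, so its results are best read as a smoke test rather than a replication.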

## 1. Environment Setup

In [None]:
# Check GPU availability
!nvidia-smi

In [None]:
# Mount Google Drive for persistent storage
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone the repository
!rm -rf energy-based-model
!git clone https://github.com/mdkrasnow/energy-based-model.git

In [None]:
# Install dependencies
!pip install -q accelerate==1.10.1
!pip install -q einops==0.8.1
!pip install -q ema_pytorch==0.7.7
!pip install -q tabulate==0.9.0
!pip install -q tqdm==4.67.1
!pip install -q wandb  # Optional for logging
!pip install -q matplotlib seaborn pandas  # For visualization

In [None]:
# Verify PyTorch and CUDA
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import json
import os
from datetime import datetime

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

## 2. Set Up Experiment Directories

In [None]:
# Create experiment directories with timestamps
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
experiment_name = f"ired_vs_anm_{NUM_ITERATIONS}iters_{timestamp}"

# Base directories in Google Drive
base_dir = f'/content/drive/MyDrive/ebm_experiments/{experiment_name}'
ired_base_dir = f'{base_dir}/ired_base'
ired_anm_dir = f'{base_dir}/ired_anm'
comparison_dir = f'{base_dir}/comparison'

# Create all directories
for dir_path in [ired_base_dir, ired_anm_dir, comparison_dir]:
    os.makedirs(f'{dir_path}/checkpoints', exist_ok=True)
    os.makedirs(f'{dir_path}/logs', exist_ok=True)
    os.makedirs(f'{dir_path}/metrics', exist_ok=True)

print(f"Experiment: {experiment_name}")
print(f"Results will be saved to: {base_dir}")

## 3. ANM Documentation

### What is Adversarial Negative Mining (ANM)?

ANM enhances IRED by generating **harder negative samples** through adversarial optimization. Instead of using random corruptions, ANM:
1. Starts with a noisy sample
2. Optimizes it to minimize energy (making it "plausible but wrong")
3. Uses these hard negatives to better shape the energy landscape
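
The first two steps can be illustrated with a toy example. This is a minimal sketch, assuming a hand-written quadratic energy with an analytic gradient in place of the learned model; the names `energy`, `energy_grad`, and `mine_negative` are hypothetical and do not come from the repository, whose ANM implementation may differ.

```python
import random

def energy(y, target):
    """Toy quadratic energy: zero at the correct answer, growing with distance."""
    return sum((yi - ti) ** 2 for yi, ti in zip(y, target))

def energy_grad(y, target):
    """Analytic gradient of the toy energy with respect to y."""
    return [2.0 * (yi - ti) for yi, ti in zip(y, target)]

def mine_negative(target, noise_scale=1.0, steps=10, step_size=0.02, seed=0):
    """Start from a noisy sample and take a few descent steps on the energy."""
    rng = random.Random(seed)
    # Step 1: noisy starting point around the correct answer.
    y = [ti + rng.gauss(0.0, noise_scale) for ti in target]
    initial = energy(y, target)
    # Step 2: lower the energy without fully converging to the answer.
    for _ in range(steps):
        grad = energy_grad(y, target)
        y = [yi - step_size * gi for yi, gi in zip(y, grad)]
    return y, initial, energy(y, target)

target = [0.5, -1.0, 2.0]
negative, e_start, e_end = mine_negative(target)
print(f"energy before mining: {e_start:.4f}, after: {e_end:.4f}")
```

Because only a few small steps are taken, the mined sample ends up with lower energy than the random corruption but is still not the correct answer. In training, such a sample would then be passed to the contrastive loss (step 3), which is not shown here.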

### Key Metrics to Track:

- **Margin** ($m = E_{\theta}(x, \tilde{y}^-) - E_{\theta}(x, \tilde{y}^+)$): Should increase over training
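- **Hardness** ($p_{\text{hard}} = P[E_{\theta}(x, \tilde{y}^-) < E_{\theta}(x, \tilde{y}^+)]$): Fraction of negatives that receive lower energy than the positive; should start high and decrease over training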
- **Gradient Share Ratio**: Balance between MSE and contrastive loss gradients
- **Energy Reduction**: How much ANM reduces energy during optimization
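
As a concrete reading of the margin and hardness definitions, the sketch below computes both from paired energy values. The function name `margin_and_hardness` and the sample numbers are illustrative assumptions; in practice the energies would come from evaluating the model on positive and negative samples.

```python
def margin_and_hardness(pos_energies, neg_energies):
    """Return (mean margin, fraction of hard negatives) for paired energies."""
    if len(pos_energies) != len(neg_energies) or not pos_energies:
        raise ValueError("need two non-empty lists of equal length")
    # Margin per pair: E(negative) - E(positive).
    margins = [e_neg - e_pos for e_pos, e_neg in zip(pos_energies, neg_energies)]
    mean_margin = sum(margins) / len(margins)
    # A negative is "hard" when the model scores it below the positive.
    p_hard = sum(m < 0 for m in margins) / len(margins)
    return mean_margin, p_hard

pos = [0.10, 0.20, 0.15, 0.30]   # made-up energies of positives
neg = [0.50, 0.10, 0.40, 0.90]   # made-up energies of negatives
mean_margin, p_hard = margin_and_hardness(pos, neg)
print(f"mean margin = {mean_margin:.4f}, p_hard = {p_hard:.2f}")
```

Here one of the four negatives scores below its positive, so a quarter of the pairs count as hard.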

### Expected Behavior:

If ANM is working correctly:
- Faster early improvement in loss
- Higher final margins on validation set
- More stable energy landscape
- Better generalization performance

## 4. Modified Training Script with Metrics Tracking

We'll create a modified training script that logs the metrics we need for evaluation.
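
The script stores metrics in JSON Lines format: one JSON object per line, appended as training proceeds. The sketch below shows that pattern in isolation, using a temporary file and made-up records. The helper names `append_metrics` and `read_metrics` are hypothetical.

```python
import json
import os
import tempfile

def append_metrics(path, record):
    """Append one metrics record as a single JSON line."""
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")

def read_metrics(path):
    """Read all records back, skipping blank lines."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]

# Write to a temporary directory so the example leaves no files behind.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "training_metrics.jsonl")
    for step in range(3):
        append_metrics(path, {"iteration": step, "loss": 1.0 / (step + 1)})
    records = read_metrics(path)

print(records[-1])
```

Appending line by line means a run that is interrupted partway still leaves every completed record readable.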

In [None]:
# Create a metrics tracking wrapper script
metrics_script = '''
import sys
import os
sys.path.append('/content/energy-based-model')
os.chdir('/content/energy-based-model')

import torch
import json
import time
from pathlib import Path
import argparse

# Import the training components
from train import *

def track_metrics(trainer, metrics_file):
    """Hook to track additional metrics during training"""
    metrics = {
        'iteration': trainer.step,
        'timestamp': time.time(),
    }
    
    # Get loss values if available
    if hasattr(trainer, 'last_loss'):
        metrics['loss'] = float(trainer.last_loss)
    
    # Track ANM-specific metrics if available
    if hasattr(trainer.diffusion, 'last_anm_stats'):
        metrics.update(trainer.diffusion.last_anm_stats)
    
    # Save metrics
    with open(metrics_file, 'a') as f:
        f.write(json.dumps(metrics) + '\\n')

def run_training_with_metrics(config_name, use_anm, num_iterations, output_dir):
    """Run training with metrics tracking"""
    
    # Set up FLAGS
    FLAGS = argparse.Namespace(
        dataset='{DATASET}',
        model='{MODEL}',
        batch_size={BATCH_SIZE},
        rank={RANK},
        data_workers={NUM_WORKERS},
        diffusion_steps=10,
        supervise_energy_landscape=True,
        use_innerloop_opt=True,
        use_anm=use_anm,
        anm_steps={ANM_STEPS} if use_anm else 0,
        anm_step_mult={ANM_STEP_MULT} if use_anm else 1.0,
        anm_loss_weight={ANM_LOSS_WEIGHT} if use_anm else 0.0,
        anm_adaptive={ANM_ADAPTIVE} if use_anm else False,
        cond_mask=False,
        evaluate=False,
        latent=False,
        ood=False,
        baseline=False,
        load_milestone=None
    )
    
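    # Load dataset and model (same as train.py). Only the inverse task and the
    # EBM model are wired up here; FLAGS.dataset and FLAGS.model are recorded
    # but do not change which classes are instantiated.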
    dataset = Inverse("train", FLAGS.rank, FLAGS.ood)
    validation_dataset = dataset
    metric = 'mse'
    
    model = EBM(
        inp_dim=dataset.inp_dim,
        out_dim=dataset.out_dim,
    )
    model = DiffusionWrapper(model)
    
    # Set up diffusion with ANM if enabled
    kwargs = {'continuous': True}
    if use_anm:
        kwargs.update({
            'use_anm': True,
            'anm_steps': FLAGS.anm_steps,
            'anm_step_mult': FLAGS.anm_step_mult,
            'anm_loss_weight': FLAGS.anm_loss_weight,
            'anm_adaptive': FLAGS.anm_adaptive
        })
    
    diffusion = GaussianDiffusion1D(
        model,
        seq_length=32,
        objective='pred_noise',
        timesteps=FLAGS.diffusion_steps,
        sampling_timesteps=FLAGS.diffusion_steps,
        supervise_energy_landscape=FLAGS.supervise_energy_landscape,
        use_innerloop_opt=FLAGS.use_innerloop_opt,
        show_inference_tqdm=False,
        **kwargs
    )
    
    # Set up trainer with custom number of iterations
    trainer = Trainer1D(
        diffusion,
        dataset,
        train_batch_size=FLAGS.batch_size,
        validation_batch_size=256,
        train_lr=1e-4,
        train_num_steps=num_iterations,  # Use our configurable iterations
        gradient_accumulate_every=1,
        ema_decay=0.995,
        data_workers=FLAGS.data_workers,
        amp=False,
        metric=metric,
        results_folder=f'{output_dir}/checkpoints',
        cond_mask=FLAGS.cond_mask,
        validation_dataset=validation_dataset,
        extra_validation_datasets={},
        extra_validation_every_mul=10,
        save_and_sample_every=max(1000, num_iterations // 10),
        evaluate_first=False,
        latent=False,
        autoencode_model=None
    )
    
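    # Path for per-step metrics. NOTE: track_metrics() above is never called in
    # this script, so this file is only written if the training loop is patched
    # to invoke track_metrics(trainer, metrics_file).
    metrics_file = f'{output_dir}/metrics/training_metrics.jsonl'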
    
    # Train the model
    print(f"Starting {config_name} training for {num_iterations} iterations...")
    trainer.train()
    
    print(f"Training complete! Results saved to {output_dir}")
    return trainer

if __name__ == "__main__":
    import sys
    config_name = sys.argv[1]
    use_anm = sys.argv[2] == 'True'
    num_iterations = int(sys.argv[3])
    output_dir = sys.argv[4]
    
    run_training_with_metrics(config_name, use_anm, num_iterations, output_dir)
'''

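# Save the script. str.format() cannot be used here because the template contains
# literal braces (dict literals, f-strings), so fill each placeholder explicitly.
placeholders = {
    'DATASET': DATASET,
    'MODEL': MODEL,
    'BATCH_SIZE': BATCH_SIZE,
    'RANK': RANK,
    'NUM_WORKERS': NUM_WORKERS,
    'ANM_STEPS': ANM_STEPS,
    'ANM_LOSS_WEIGHT': ANM_LOSS_WEIGHT,
    'ANM_STEP_MULT': ANM_STEP_MULT,
    'ANM_ADAPTIVE': ANM_ADAPTIVE,
}
script_text = metrics_script
for name, value in placeholders.items():
    script_text = script_text.replace('{' + name + '}', str(value))

with open('/content/train_with_metrics.py', 'w') as f:
    f.write(script_text)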

print("Metrics tracking script created successfully!")
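
A general caveat when filling placeholders in generated Python source: `str.format` treats every `{...}` in the template as a replacement field, including dict literals and f-string fields that are meant to stay as they are. The sketch below uses a small stand-in template (not the notebook's script) to show the failure and one way around it, replacing only the known placeholder.

```python
template = '''
config = {'continuous': True}
batch_size = {BATCH_SIZE}
print(f"batch size: {batch_size}")
'''

# str.format treats the dict literal's braces as a replacement field and fails.
try:
    template.format(BATCH_SIZE=2048)
    format_failed = False
except (KeyError, IndexError, ValueError):
    format_failed = True

# Replacing only the known placeholder leaves every other brace untouched.
rendered = template.replace('{BATCH_SIZE}', str(2048))

print(format_failed)
print(rendered)
```

An alternative is to double every literal brace in the template (`{{` and `}}`), which keeps `str.format` usable but makes the embedded code harder to read.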

## 5. Train IRED Base Model

In [None]:
# Train IRED base model (without ANM)
print("="*60)
print("Training IRED Base Model (without ANM)")
print(f"Iterations: {NUM_ITERATIONS}")
print("="*60)

# Change to the repository directory
%cd /content/energy-based-model

# Run training
!python /content/train_with_metrics.py "IRED_Base" False {NUM_ITERATIONS} "{ired_base_dir}"

## 6. Train IRED+ANM Model

In [None]:
# Train IRED+ANM model
print("="*60)
print("Training IRED+ANM Model (with Adversarial Negative Mining)")
print(f"Iterations: {NUM_ITERATIONS}")
print(f"ANM Steps: {ANM_STEPS}")
print(f"ANM Loss Weight: {ANM_LOSS_WEIGHT}")
print("="*60)

# Change to the repository directory
%cd /content/energy-based-model

# Run training
!python /content/train_with_metrics.py "IRED_ANM" True {NUM_ITERATIONS} "{ired_anm_dir}"

## 7. Load and Analyze Results

In [None]:
# Function to load metrics from jsonl files
def load_metrics(metrics_file):
    metrics = []
    if os.path.exists(metrics_file):
        with open(metrics_file, 'r') as f:
            for line in f:
                try:
                    metrics.append(json.loads(line))
                except json.JSONDecodeError:
                    # Skip partially written or corrupt lines
                    pass
    return pd.DataFrame(metrics)

# Load metrics for both models
ired_base_metrics = load_metrics(f'{ired_base_dir}/metrics/training_metrics.jsonl')
ired_anm_metrics = load_metrics(f'{ired_anm_dir}/metrics/training_metrics.jsonl')

print(f"IRED Base: {len(ired_base_metrics)} metric entries")
print(f"IRED+ANM: {len(ired_anm_metrics)} metric entries")

# Display sample metrics
if len(ired_base_metrics) > 0:
    print("\nIRED Base - Last 5 entries:")
    print(ired_base_metrics.tail())

if len(ired_anm_metrics) > 0:
    print("\nIRED+ANM - Last 5 entries:")
    print(ired_anm_metrics.tail())

## 8. Visualization and Comparison

In [None]:
# Create comparison plots
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# Plot 1: Training Loss Comparison
if 'loss' in ired_base_metrics.columns and 'loss' in ired_anm_metrics.columns:
    ax = axes[0, 0]
    ax.plot(ired_base_metrics['iteration'], ired_base_metrics['loss'], 
            label='IRED Base', alpha=0.8, linewidth=2)
    ax.plot(ired_anm_metrics['iteration'], ired_anm_metrics['loss'], 
            label='IRED+ANM', alpha=0.8, linewidth=2)
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Plot 2: ANM Energy Reduction (if available)
if 'anm_energy_reduction' in ired_anm_metrics.columns:
    ax = axes[0, 1]
    ax.plot(ired_anm_metrics['iteration'], ired_anm_metrics['anm_energy_reduction'], 
            color='orange', alpha=0.8, linewidth=2)
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Energy Reduction')
    ax.set_title('ANM Energy Reduction During Mining')
    ax.grid(True, alpha=0.3)

# Plot 3: ANM Optimization Movement
if 'anm_optimization_movement' in ired_anm_metrics.columns:
    ax = axes[0, 2]
    ax.plot(ired_anm_metrics['iteration'], ired_anm_metrics['anm_optimization_movement'], 
            color='green', alpha=0.8, linewidth=2)
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Movement')
    ax.set_title('ANM Optimization Movement')
    ax.grid(True, alpha=0.3)

# Plot 4: Loss Smoothed Comparison
if 'loss' in ired_base_metrics.columns and 'loss' in ired_anm_metrics.columns:
    ax = axes[1, 0]
    # Apply rolling mean for smoother curves
    window = max(10, NUM_ITERATIONS // 100)
    ired_base_smooth = ired_base_metrics['loss'].rolling(window=window, min_periods=1).mean()
    ired_anm_smooth = ired_anm_metrics['loss'].rolling(window=window, min_periods=1).mean()
    
    ax.plot(ired_base_metrics['iteration'], ired_base_smooth, 
            label='IRED Base (smoothed)', alpha=0.8, linewidth=2)
    ax.plot(ired_anm_metrics['iteration'], ired_anm_smooth, 
            label='IRED+ANM (smoothed)', alpha=0.8, linewidth=2)
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Loss (smoothed)')
    ax.set_title(f'Smoothed Loss (window={window})')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Plot 5: Final Energy Comparison
if 'anm_final_energy' in ired_anm_metrics.columns and 'anm_initial_energy' in ired_anm_metrics.columns:
    ax = axes[1, 1]
    ax.plot(ired_anm_metrics['iteration'], ired_anm_metrics['anm_initial_energy'], 
            label='Initial Energy', alpha=0.6, linewidth=1)
    ax.plot(ired_anm_metrics['iteration'], ired_anm_metrics['anm_final_energy'], 
            label='Final Energy (after ANM)', alpha=0.8, linewidth=2)
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Energy')
    ax.set_title('ANM Energy Before/After Optimization')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Plot 6: Training Progress
ax = axes[1, 2]
progress_data = pd.DataFrame({
    'Model': ['IRED Base', 'IRED+ANM'],
    'Iterations': [len(ired_base_metrics), len(ired_anm_metrics)],
    'Target': [NUM_ITERATIONS, NUM_ITERATIONS]
})
x = range(len(progress_data))
width = 0.35
ax.bar([i - width/2 for i in x], progress_data['Iterations'], width, label='Completed', color='green', alpha=0.7)
ax.bar([i + width/2 for i in x], progress_data['Target'] - progress_data['Iterations'], width, 
       bottom=progress_data['Iterations'], label='Remaining', color='gray', alpha=0.3)
ax.set_xticks(x)
ax.set_xticklabels(progress_data['Model'])
ax.set_ylabel('Iterations')
ax.set_title('Training Progress')
ax.legend()

plt.tight_layout()
plt.savefig(f'{comparison_dir}/training_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Plots saved to {comparison_dir}/training_comparison.png")

## 9. Summary Statistics

In [None]:
# Calculate summary statistics
summary = {}

# IRED Base statistics
if len(ired_base_metrics) > 0 and 'loss' in ired_base_metrics.columns:
    summary['IRED Base'] = {
        'Final Loss': ired_base_metrics['loss'].iloc[-1] if len(ired_base_metrics) > 0 else None,
        'Min Loss': ired_base_metrics['loss'].min(),
        'Mean Loss': ired_base_metrics['loss'].mean(),
        'Loss Std': ired_base_metrics['loss'].std(),
        'Total Iterations': len(ired_base_metrics)
    }

# IRED+ANM statistics
if len(ired_anm_metrics) > 0 and 'loss' in ired_anm_metrics.columns:
    summary['IRED+ANM'] = {
        'Final Loss': ired_anm_metrics['loss'].iloc[-1] if len(ired_anm_metrics) > 0 else None,
        'Min Loss': ired_anm_metrics['loss'].min(),
        'Mean Loss': ired_anm_metrics['loss'].mean(),
        'Loss Std': ired_anm_metrics['loss'].std(),
        'Total Iterations': len(ired_anm_metrics)
    }
    
    # Add ANM-specific statistics
    if 'anm_energy_reduction' in ired_anm_metrics.columns:
        summary['IRED+ANM']['Mean Energy Reduction'] = ired_anm_metrics['anm_energy_reduction'].mean()
        summary['IRED+ANM']['Max Energy Reduction'] = ired_anm_metrics['anm_energy_reduction'].max()
    
    if 'anm_optimization_movement' in ired_anm_metrics.columns:
        summary['IRED+ANM']['Mean ANM Movement'] = ired_anm_metrics['anm_optimization_movement'].mean()

# Display summary
summary_df = pd.DataFrame(summary).T
print("\n" + "="*60)
print("TRAINING SUMMARY")
print("="*60)
print(summary_df.to_string())

# Save summary to file
summary_df.to_csv(f'{comparison_dir}/summary_statistics.csv')
with open(f'{comparison_dir}/summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print(f"\nSummary saved to {comparison_dir}/")

## 10. Performance Comparison

In [None]:
# Compare performance metrics
print("\n" + "="*60)
print("PERFORMANCE COMPARISON")
print("="*60)

if 'IRED Base' in summary and 'IRED+ANM' in summary:
    base_final = summary['IRED Base'].get('Final Loss', float('inf'))
    anm_final = summary['IRED+ANM'].get('Final Loss', float('inf'))
    
    if base_final and anm_final and base_final != float('inf') and anm_final != float('inf'):
        improvement = (base_final - anm_final) / base_final * 100
        
        print(f"\nFinal Loss Comparison:")
        print(f"  IRED Base:  {base_final:.6f}")
        print(f"  IRED+ANM:   {anm_final:.6f}")
        print(f"  Improvement: {improvement:.2f}%")
        
        if improvement > 0:
            print(f"\n✅ ANM shows {improvement:.2f}% improvement in final loss!")
        elif improvement < 0:
            print(f"\n⚠️ ANM shows {-improvement:.2f}% degradation in final loss.")
        else:
            print(f"\n➡️ ANM shows similar performance to base IRED.")
    
    # Compare convergence speed
    if len(ired_base_metrics) > 100 and len(ired_anm_metrics) > 100:
        # Check loss at 25%, 50%, 75% of training
        checkpoints = [0.25, 0.5, 0.75]
        
        print(f"\nConvergence Speed (Loss at checkpoints):")
        for checkpoint in checkpoints:
            idx = int(len(ired_base_metrics) * checkpoint)
            base_loss = ired_base_metrics['loss'].iloc[min(idx, len(ired_base_metrics)-1)]
            anm_loss = ired_anm_metrics['loss'].iloc[min(idx, len(ired_anm_metrics)-1)]
            print(f"  At {checkpoint*100:.0f}% training:")
            print(f"    IRED Base: {base_loss:.6f}")
            print(f"    IRED+ANM:  {anm_loss:.6f}")
            print(f"    Difference: {(base_loss - anm_loss):.6f}")

print("\n" + "="*60)

## 11. Export Results

In [None]:
# Create a comprehensive report
report = {
    'experiment': experiment_name,
    'configuration': {
        'num_iterations': NUM_ITERATIONS,
        'dataset': DATASET,
        'model': MODEL,
        'batch_size': BATCH_SIZE,
        'rank': RANK,
        'anm_steps': ANM_STEPS,
        'anm_loss_weight': ANM_LOSS_WEIGHT,
        'anm_step_mult': ANM_STEP_MULT,
        'anm_adaptive': ANM_ADAPTIVE
    },
    'summary': summary,
    'timestamp': datetime.now().isoformat()
}

# Save report
with open(f'{comparison_dir}/experiment_report.json', 'w') as f:
    json.dump(report, f, indent=2)

print(f"Experiment report saved to {comparison_dir}/experiment_report.json")

# Zip results for download
import shutil
zip_path = f'/content/{experiment_name}_results'
shutil.make_archive(zip_path, 'zip', base_dir)

print(f"\nResults archived at: {zip_path}.zip")
print(f"Size: {os.path.getsize(zip_path + '.zip') / 1e6:.2f} MB")

In [None]:
# Download results
from google.colab import files

print("Downloading results...")
files.download(f'{zip_path}.zip')

## 12. Quick Test Commands

For manual testing, you can also run these commands directly:

In [None]:
# Quick test commands (uncomment to run)

# Test IRED Base (no iteration flag is passed, so train.py's default step count applies)
# !cd /content/energy-based-model && python train.py --dataset inverse --model mlp --batch_size 2048 --rank 20 --data-workers 2 --use-innerloop-opt True --supervise-energy-landscape True

# Test IRED+ANM (no iteration flag is passed, so train.py's default step count applies)
# !cd /content/energy-based-model && python train.py --dataset inverse --model mlp --batch_size 2048 --rank 20 --data-workers 2 --use-innerloop-opt True --supervise-energy-landscape True --use-anm True --anm-steps 10 --anm-loss-weight 0.5

## Tips for Colab:

1. **Session Time Limits**: Free Colab sessions run for at most about 12 hours. Save checkpoints frequently!
2. **GPU Limits**: Free-tier GPU quotas are not published and vary with demand; roughly 8-12 hours per day is a common experience, not a guarantee
3. **Persistent Storage**: Always save important files to Google Drive
4. **Idle Timeout**: Colab disconnects idle sessions, typically after about 90 minutes of inactivity
5. **Keep Alive**: Use this JavaScript in browser console to prevent disconnection:
```javascript
function ClickConnect(){
    console.log("Keeping alive...");
    document.querySelector("colab-connect-button").click()
}
setInterval(ClickConnect, 60000)
```

## Interpreting Results:

### Good ANM Performance Indicators:
- **Lower final loss** compared to base IRED
- **Faster early convergence** (steeper loss curve at beginning)
- **Consistent energy reduction** during ANM optimization
- **Moderate optimization movement** (not too small, not too large)

### Warning Signs:
- **Flat or increasing loss** → Check learning rate, may need adjustment
- **Very small energy reduction** → ANM may not be finding hard negatives
- **Unstable loss curves** → Reduce ANM step size or loss weight
- **No difference from base** → Check if ANM is actually being used
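
The "flat or increasing loss" check can be automated by comparing the start and end of a loss curve. The sketch below is one simple way to do that, assuming a plain list of loss values. The function names, window size, and tolerance are illustrative choices, not tuned values.

```python
def window_mean(values, start, end):
    """Mean of values[start:end]."""
    chunk = values[start:end]
    return sum(chunk) / len(chunk)

def diagnose_loss(losses, window=5, flat_tolerance=0.01):
    """Label a loss curve by comparing its first and last windows."""
    if len(losses) < 2 * window:
        return "too few points"
    first = window_mean(losses, 0, window)
    last = window_mean(losses, len(losses) - window, len(losses))
    # Relative change from the first window to the last one.
    relative_change = (last - first) / abs(first) if first != 0 else 0.0
    if relative_change > flat_tolerance:
        return "increasing"
    if relative_change > -flat_tolerance:
        return "flat"
    return "decreasing"

healthy = [1.0 / (1 + 0.1 * i) for i in range(50)]   # synthetic decaying loss
stalled = [0.5 for _ in range(50)]                   # synthetic constant loss
print(diagnose_loss(healthy), diagnose_loss(stalled))
```

The same function could be applied to the `loss` column of each metrics DataFrame (for example via `.tolist()`), once metrics are actually being logged.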

## Next Steps:

1. **Tune ANM parameters** if results are not satisfactory:
   - Increase `ANM_STEPS` for more thorough optimization
   - Adjust `ANM_LOSS_WEIGHT` to balance MSE vs contrastive loss
   - Try `ANM_ADAPTIVE=True` for timestep-aware optimization

2. **Run longer experiments** for more conclusive results
3. **Test on different datasets** to verify generalization
4. **Implement additional metrics** from the documentation
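
For the parameter tuning in step 1, a small grid makes the search space explicit before any GPU time is spent. The sketch below enumerates combinations of the three ANM settings. The candidate values are hypothetical, and each configuration would still need to be copied into the configuration cell and run separately.

```python
from itertools import product

# Hypothetical candidate values; choose ranges that suit your compute budget.
grid = {
    "ANM_STEPS": [5, 10, 20],
    "ANM_LOSS_WEIGHT": [0.25, 0.5, 1.0],
    "ANM_ADAPTIVE": [False, True],
}

# One dict per combination, keyed by the notebook's configuration names.
configs = [dict(zip(grid.keys(), values)) for values in product(*grid.values())]

print(f"{len(configs)} configurations")
for cfg in configs[:3]:
    print(cfg)
```

Since every configuration costs a full training run, it may be more practical to vary one parameter at a time around the defaults than to run the whole grid.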