# IRED Energy-Based Model Training on Google Colab (FREE T4 GPU)

This notebook implements the **IRED (Iterative Reasoning through Energy Diffusion)** paper's training setup on Google Colab's free T4 GPU.

## Paper Reference
Du et al. trained for **100,000 iterations on a single NVIDIA RTX 2080 with batch size 512** using the Adam optimizer.

## Key Settings for IRED (not the diffusion baseline):
1. **DO NOT** use the `--diffusion_steps` flag
2. **DO** use `--use-innerloop-opt True` and `--supervise-energy-landscape True`
3. Set `--data-workers 2` to avoid DataLoader stalls
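The flags in item 2 take an explicit `True` value rather than acting as bare switches. With Python's `argparse`, `type=bool` would treat any non-empty string (including `"False"`) as true, so scripts like this typically parse such flags with a small string-to-bool converter. The sketch below shows that pattern; `str2bool` and this parser are hypothetical illustrations, not code taken from the repository's `train.py`.

```python
import argparse


def str2bool(value: str) -> bool:
    """Convert common truthy/falsy strings to a bool (hypothetical helper)."""
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


parser = argparse.ArgumentParser()
# Hypothetical mirror of the IRED flags; the defaults here are placeholders
parser.add_argument("--use-innerloop-opt", type=str2bool, default=False)
parser.add_argument("--supervise-energy-landscape", type=str2bool, default=False)
parser.add_argument("--data-workers", type=int, default=0)

# argparse turns dashes into underscores for the attribute names
args = parser.parse_args(
    ["--use-innerloop-opt", "True", "--supervise-energy-landscape", "True", "--data-workers", "2"]
)
print(args.use_innerloop_opt, args.supervise_energy_landscape, args.data_workers)
# prints: True True 2
```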

## Setup Instructions:
1. Upload this notebook to [Google Colab](https://colab.research.google.com)
2. Go to Runtime → Change runtime type and select **T4 GPU** as the hardware accelerator
3. Run the cells below

## Training Time Estimates:
- Paper's 100k iterations @ batch 512 = 51.2M examples seen
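- With batch 2048 (4x larger): ~25k iterations see the same number of examples, though with 4x fewer optimizer updates, so results are not guaranteed to match the paper's run
- **Estimated time on T4: roughly 1.3 hours**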
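The arithmetic above can be checked in a few lines. This sketch matches only the total number of examples seen, using the numbers from this section; it says nothing about learning rate or optimizer-step count.

```python
PAPER_ITERATIONS = 100_000
PAPER_BATCH_SIZE = 512
COLAB_BATCH_SIZE = 2048


def iterations_for_same_examples(paper_iters: int, paper_batch: int, new_batch: int) -> int:
    """Iterations needed at new_batch to see as many examples as the paper's run."""
    total_examples = paper_iters * paper_batch
    # Ceiling division so the new run never sees fewer examples than the paper's
    return -(-total_examples // new_batch)


total = PAPER_ITERATIONS * PAPER_BATCH_SIZE
iters = iterations_for_same_examples(PAPER_ITERATIONS, PAPER_BATCH_SIZE, COLAB_BATCH_SIZE)
print(f"Examples seen: {total / 1e6:.1f}M")                   # prints: Examples seen: 51.2M
print(f"Iterations at batch {COLAB_BATCH_SIZE}: {iters:,}")   # prints: Iterations at batch 2048: 25,000
```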

In [None]:
# Check GPU availability
!nvidia-smi

In [None]:
# Mount Google Drive for persistent storage
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!rm -rf energy-based-model
!git clone https://github.com/mdkrasnow/energy-based-model.git

In [None]:
# Install dependencies
!pip install -q accelerate==1.10.1
!pip install -q einops==0.8.1
!pip install -q ema_pytorch==0.7.7
!pip install -q tabulate==0.9.0
!pip install -q tqdm==4.67.1
!pip install -q wandb  # Optional for logging

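Because `-q` hides most of pip's output, a quick way to confirm that the pinned versions are what actually ended up installed is to query `importlib.metadata`. A minimal sketch; the `PINNED` dictionary simply restates the pins from the cell above, and `check_pins` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins restated from the install cell above
PINNED = {
    "accelerate": "1.10.1",
    "einops": "0.8.1",
    "ema_pytorch": "0.7.7",
    "tabulate": "0.9.0",
    "tqdm": "4.67.1",
}


def check_pins(pins: dict) -> dict:
    """Return {package: installed_version_or_None} for every pin that does not match."""
    mismatches = {}
    for package, expected in pins.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None  # not installed at all
        if installed != expected:
            mismatches[package] = installed
    return mismatches


problems = check_pins(PINNED)
print("All pins satisfied" if not problems else f"Mismatched or missing: {problems}")
```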
In [None]:
# Verify PyTorch and CUDA
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

In [None]:
# Download data if needed
!mkdir -p data
# Add your data download commands here
# Example: !wget -O data/dataset.tar.gz https://example.com/dataset.tar.gz
# !tar -xzf data/dataset.tar.gz -C data/

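If your dataset does arrive as a tarball, a slightly safer alternative to a bare `tar -xzf` is to extract it in Python and reject archive members that would land outside `data/`. A minimal sketch using only the standard library; `safe_extract` is a hypothetical helper, and the archive path in the usage line is a placeholder, just like the example URL above.

```python
import os
import tarfile


def safe_extract(archive_path: str, dest_dir: str) -> list:
    """Extract a .tar/.tar.gz into dest_dir, refusing members that escape it."""
    dest_root = os.path.realpath(dest_dir)
    os.makedirs(dest_root, exist_ok=True)
    with tarfile.open(archive_path, "r:*") as tar:
        members = tar.getmembers()
        for member in members:
            target = os.path.realpath(os.path.join(dest_root, member.name))
            # Block absolute paths and "../" traversal
            if os.path.commonpath([dest_root, target]) != dest_root:
                raise ValueError(f"Unsafe path in archive: {member.name}")
            # Links could also point outside dest_dir; reject them for simplicity
            if member.issym() or member.islnk():
                raise ValueError(f"Links are not allowed: {member.name}")
        tar.extractall(dest_root, members=members)
    return [m.name for m in members]


# Usage (placeholder path, mirroring the commented example above):
# safe_extract("data/dataset.tar.gz", "data/")
```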
In [None]:
# Set up Drive-backed checkpoint/log directories inside the repo so they persist across sessions
import os
import shutil
import subprocess

REPO_DIR = '/content/energy-based-model'
DRIVE_DIRS = {
    'checkpoints': '/content/drive/MyDrive/ebm_checkpoints',
    'logs': '/content/drive/MyDrive/ebm_logs',
}

# Work from the repo, where train.py runs, so ./checkpoints and ./logs point at Drive
os.chdir(REPO_DIR)

for local_name, drive_path in DRIVE_DIRS.items():
    os.makedirs(drive_path, exist_ok=True)
    local_path = os.path.join(REPO_DIR, local_name)

    # On re-runs, never delete an existing mount or link: rm -rf would wipe the Drive copy
    if os.path.ismount(local_path) or os.path.islink(local_path):
        print(f"{local_name}: already linked to Drive")
        continue
    if os.path.exists(local_path):
        shutil.rmtree(local_path)
    os.makedirs(local_path)

    # Prefer a bind mount; fall back to a symlink if the runtime does not allow mount
    if subprocess.run(['mount', '--bind', drive_path, local_path]).returncode != 0:
        os.rmdir(local_path)
        os.symlink(drive_path, local_path)
    print(f"{local_name} -> {drive_path}")

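Before starting a long run, it is worth confirming that the checkpoint and log directories really resolve to Google Drive, since anything written to a plain local directory is lost when the runtime resets. A minimal sketch; `is_backed_by` is a hypothetical helper that treats either a bind mount or a symlink into the Drive folder as persistent, and the relative names resolve against the current working directory.

```python
import os

DRIVE_DIRS = {
    "checkpoints": "/content/drive/MyDrive/ebm_checkpoints",
    "logs": "/content/drive/MyDrive/ebm_logs",
}


def is_backed_by(local_path: str, drive_path: str) -> bool:
    """True if local_path is the same directory as drive_path (symlink or bind mount)."""
    if not (os.path.isdir(local_path) and os.path.isdir(drive_path)):
        return False
    # samefile compares device and inode, so it sees through symlinks and bind mounts
    return os.path.samefile(local_path, drive_path)


for name, target in DRIVE_DIRS.items():
    status = "OK" if is_backed_by(name, target) else "NOT on Drive"
    print(f"./{name}: {status}")
```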
In [None]:
# Training configuration - IRED paper settings
DATASET = "inverse"  # Matrix inverse, one of the tasks evaluated in the paper
MODEL = "mlp"  # Default model architecture
BATCH_SIZE = 2048  # 4x the paper's 512: fewer iterations for the same number of examples
RANK = 20  # Rank for matrix datasets
NUM_WORKERS = 2  # Max DataLoader workers suggested by PyTorch's warning on Colab

# IMPORTANT: For true IRED (not diffusion baseline), we:
# 1. DO NOT set diffusion_steps (omit it)
# 2. DO set use_innerloop_opt and supervise_energy_landscape flags

print("Configuration:")
print(f"  Dataset: {DATASET}")
print(f"  Model: {MODEL}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Workers: {NUM_WORKERS}")
print("  Mode: IRED (with inner loop optimization)")
print("")
print("Training time estimates:")
print("  Paper's 100k iterations @ batch 512 = 51.2M examples")
print("  With batch 2048: ~25k iterations to see the same number of examples")
print("  Rough estimate: ~1.3 hours on T4 GPU")

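The same configuration can also be assembled into an argument list in Python, which makes it easy to assert that the diffusion-baseline flag never sneaks in before launching. A minimal sketch; `build_ired_command` is a hypothetical helper, and its flag names are copied from the training cell below rather than read from `train.py` itself.

```python
import shlex


def build_ired_command(dataset: str, model: str, batch_size: int, rank: int, workers: int) -> list:
    """Argument list for IRED training (inner-loop optimization on, no --diffusion_steps)."""
    cmd = [
        "python", "train.py",
        "--dataset", dataset,
        "--model", model,
        "--batch_size", str(batch_size),
        "--rank", str(rank),
        "--data-workers", str(workers),
        "--use-innerloop-opt", "True",
        "--supervise-energy-landscape", "True",
    ]
    # Guard: passing --diffusion_steps would switch to the diffusion baseline
    assert "--diffusion_steps" not in cmd
    return cmd


cmd = build_ired_command("inverse", "mlp", 2048, 20, 2)
print(shlex.join(cmd))
# To launch from Python instead of the shell cell (run from the repo directory):
# import subprocess
# subprocess.run(cmd, check=True, cwd="/content/energy-based-model")
```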
In [None]:
# Run IRED training (matching the paper's approach)
# This runs the actual IRED algorithm, NOT the diffusion baseline
%cd /content/energy-based-model
!python train.py \
    --dataset {DATASET} \
    --model {MODEL} \
    --batch_size {BATCH_SIZE} \
    --rank {RANK} \
    --data-workers {NUM_WORKERS} \
    --use-innerloop-opt True \
    --supervise-energy-landscape True

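# Note: We intentionally DO NOT include --diffusion_steps for IRED.
# The paper trains for 100k iterations at batch 512; with batch 2048 (4x larger),
# ~25k iterations see the same number of examples. This command sets no iteration
# limit, so run length follows train.py's own defaults; check its arguments or stop
# the run once you reach your target.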

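The ~1.3-hour figure depends on throughput, and because the command above sets no iteration limit, it helps to re-estimate once the training output reports a steady iterations-per-second rate (for example from a tqdm progress bar). A minimal sketch; `estimate_hours` is a hypothetical helper, the 25,000-iteration target comes from the estimate earlier in this notebook, and the rate is whatever you observe.

```python
def estimate_hours(target_iters: int, iters_per_sec: float, done_iters: int = 0) -> float:
    """Remaining wall-clock hours to reach target_iters at a measured rate."""
    if iters_per_sec <= 0:
        raise ValueError("iters_per_sec must be positive")
    remaining = max(target_iters - done_iters, 0)
    return remaining / iters_per_sec / 3600


# The notebook's ~1.3 h estimate for 25k iterations implies roughly 5.3 it/s
print(f"{estimate_hours(25_000, 5.3):.2f} h")  # prints: 1.31 h
```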
In [None]:
# Alternative training configurations

# Option 1: IRED with wrapper script (cleaner interface)
# Commented out like the other options so running every cell does not start a second training run
# !python train_wrapper.py \
#     --dataset {DATASET} \
#     --model {MODEL} \
#     --batch_size {BATCH_SIZE} \
#     --num_workers {NUM_WORKERS} \
#     --use_innerloop_opt \
#     --supervise_energy_landscape \
#     --checkpoint_dir ./checkpoints \
#     --log_dir ./logs

# Option 2: Paper's exact settings (batch 512, longer training)
# !python train.py \
#     --dataset inverse \
#     --model mlp \
#     --batch_size 512 \
#     --rank 20 \
#     --data-workers 2 \
#     --use-innerloop-opt True \
#     --supervise-energy-landscape True

# Option 3: IRED with ANM for enhanced performance
# !python train_wrapper.py \
#     --dataset {DATASET} \
#     --model {MODEL} \
#     --batch_size {BATCH_SIZE} \
#     --num_workers {NUM_WORKERS} \
#     --use_innerloop_opt \
#     --supervise_energy_landscape \
#     --use_anm \
#     --anm_steps 10 \
#     --anm_loss_weight 0.5

# Option 4: Diffusion baseline (for comparison)
# !python train.py \
#     --dataset {DATASET} \
#     --model {MODEL} \
#     --batch_size {BATCH_SIZE} \
#     --diffusion_steps 10 \
#     --rank {RANK} \
#     --data-workers {NUM_WORKERS}

In [None]:
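# Check GPU usage. Cells run one at a time, so this reports usage only after the
# training cell finishes; for live usage, watch Colab's Resources panel instead.
!nvidia-smi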

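For a compact reading instead of the full table, `nvidia-smi` can emit selected fields as CSV via `--query-gpu` and `--format=csv,noheader,nounits`. A minimal sketch that runs that query and parses the result; `gpu_snapshot` and `parse_gpu_csv` are hypothetical helper names. Like the cell above, it only reports what is happening at the moment it runs.

```python
import shutil
import subprocess

QUERY_FIELDS = ["name", "utilization.gpu", "memory.used", "memory.total"]


def parse_gpu_csv(text: str) -> list:
    """Parse `nvidia-smi --format=csv,noheader,nounits` output into dicts (one per GPU)."""
    rows = []
    for line in text.strip().splitlines():
        values = [v.strip() for v in line.split(",")]
        rows.append(dict(zip(QUERY_FIELDS, values)))
    return rows


def gpu_snapshot() -> list:
    """Query the current GPU name, utilization (%), and memory (MiB)."""
    out = subprocess.run(
        ["nvidia-smi", f"--query-gpu={','.join(QUERY_FIELDS)}", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_csv(out)


if shutil.which("nvidia-smi"):
    for gpu in gpu_snapshot():
        print(f"{gpu['name']}: {gpu['utilization.gpu']}% util, "
              f"{gpu['memory.used']}/{gpu['memory.total']} MiB")
else:
    print("nvidia-smi not found (is a GPU runtime attached?)")
```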
In [None]:
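# Download results to local machine
import os
import zipfile
from google.colab import files

# Bundle both checkpoints and logs (shutil.make_archive with base_dir covers only one folder)
os.chdir('/content/energy-based-model')
with zipfile.ZipFile('training_results.zip', 'w', zipfile.ZIP_DEFLATED) as zf:
    for folder in ('checkpoints', 'logs'):
        for root, _, filenames in os.walk(folder):
            for name in filenames:
                zf.write(os.path.join(root, name))
files.download('training_results.zip')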

## Tips for Colab:

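1. **Session Time Limits**: Free Colab sessions run for at most 12 hours, often less depending on availability. Save checkpoints frequently!
2. **GPU Limits**: Free-tier GPU usage limits are dynamic and not published, so available GPU time varies from day to day, and heavy use can temporarily block GPU access
3. **Persistent Storage**: Always save important files to Google Drive; the local disk is wiped when the runtime resets
4. **Idle Timeout**: Colab disconnects idle sessions, commonly after about 90 minutes without interaction
5. **Keep Alive**: Some users run this JavaScript in the browser console to reduce idle disconnections. It cannot extend the 12-hour maximum, and the element selector may stop matching if Colab changes its UI: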
```javascript
function ClickConnect() {
    console.log("Keeping alive...");
    const button = document.querySelector("colab-connect-button");
    if (button) button.click();  // skip quietly if the element is not on the page
}
setInterval(ClickConnect, 60000);
```
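Because a session can end mid-run, it helps to know which file in the Drive checkpoint folder is the newest before resuming. A minimal sketch that picks the most recently modified file; `latest_checkpoint` is a hypothetical helper, the file suffixes are assumptions, and nothing here assumes the repository's checkpoint naming or that `train.py` supports resuming, so check the script's arguments before relying on it.

```python
import os
from typing import Optional


def latest_checkpoint(directory: str, suffixes=(".pt", ".pth", ".ckpt")) -> Optional[str]:
    """Return the most recently modified checkpoint-like file in directory, or None."""
    if not os.path.isdir(directory):
        return None
    candidates = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(suffixes) and os.path.isfile(os.path.join(directory, name))
    ]
    # Newest by modification time; None when nothing matches
    return max(candidates, key=os.path.getmtime, default=None)


print(latest_checkpoint("/content/drive/MyDrive/ebm_checkpoints"))
```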

## Alternative: Colab Pro
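- About $10/month (pricing varies by region) for a monthly allotment of compute units, usable on faster GPUs with longer runtimes and more RAM
- Resources are still not guaranteed: usage is metered in compute units, and GPU availability varies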