# 02 - Train Models

This notebook trains the four models needed for the experiment. All of them are trained with ERM (empirical risk minimization); they differ only in training data and seed:

1. **Model A1** (Spurious, seed 1): ERM on Env A only
2. **Model A2** (Spurious, seed 2): ERM on Env A only
3. **Model R1** (Robust, seed 1): ERM on mixed Env A + Env B
4. **Model R2** (Robust, seed 2): ERM on mixed Env A + Env B

## Expected behavior
- Spurious models (A1, A2) should achieve high in-distribution (ID) accuracy but low out-of-distribution (OOD) accuracy.
- Robust models (R1, R2) should achieve similar accuracy on the ID and OOD test sets.

In [None]:
import sys
from pathlib import Path

# Add project root to path.
# Walk upward from the working directory to the nearest folder that contains
# `src/`, so the `src.*` imports below resolve whether the kernel was started
# in the notebook's own folder or in the repository root. If no such folder is
# found, fall back to the parent of the working directory.
_cwd = Path.cwd()
PROJECT_ROOT = next(
    (candidate for candidate in (_cwd, *_cwd.parents) if (candidate / "src").is_dir()),
    _cwd.parent,
)
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import torch
import json

from src.config import (
    get_config, set_seed, get_device,
    CHECKPOINTS_DIR, FIGURES_DIR, METRICS_DIR
)

# Load the experiment configuration and pick the torch device once; every
# later cell reuses these two objects.
config, device = get_config(), get_device()

print(
    f"Device: {device}",
    f"Checkpoints will be saved to: {CHECKPOINTS_DIR}",
    sep="\n",
)

In [None]:
from src.data import (
    create_env_a_dataset,
    create_no_patch_dataset,
    create_mixed_env_dataset,
    get_dataloaders,
)
from src.models import create_model, count_parameters
from src.train import Trainer, save_training_history
from src.plotting import plot_training_curves, plot_multiple_training_curves, save_figure

## 1. Create Datasets and DataLoaders

In [None]:
# Evaluation sets are built first and shared by all four models.
test_id = create_env_a_dataset(config=config, train=False)      # ID test (Env A)
test_ood = create_no_patch_dataset(config=config, train=False)  # OOD test (no patch)

# Training sets: Env A only for A1/A2, a 50/50 Env A mix for R1/R2.
train_spurious = create_env_a_dataset(config=config, train=True)
train_robust = create_mixed_env_dataset(config=config, train=True, env_a_fraction=0.5)

print("Datasets created:")
for _label, _dataset in (
    ("Spurious training (Env A)", train_spurious),
    ("Robust training (Mixed)", train_robust),
    ("ID test", test_id),
    ("OOD test", test_ood),
):
    print(f"  {_label}: {len(_dataset)} samples")

In [None]:
# Build one loader bundle per model family.
training_cfg = config['training']
batch_size, num_workers = training_cfg['batch_size'], training_cfg['num_workers']

# Both families are evaluated on the same ID/OOD test sets; only the
# training set passed as the first argument differs.
shared_eval_args = (test_id, test_ood, config)
loaders_spurious = get_dataloaders(train_spurious, *shared_eval_args)
loaders_robust = get_dataloaders(train_robust, *shared_eval_args)

print(f"\nDataLoaders created with batch_size={batch_size}")

## 2. Training Function

In [None]:
def train_model(name, seed, dataloaders, config, device):
    """
    Train a single model and save checkpoint + history.

    The checkpoint is written to ``CHECKPOINTS_DIR / model_<name>.pt`` (the path
    is handed to ``Trainer.train``) and the history to
    ``METRICS_DIR / history_<name>.json`` (via ``save_training_history``).
    Both parent directories are created before training starts, so a missing
    output folder cannot make the run fail at save time.

    Args:
        name: Model name (e.g., 'A1', 'R1'); used in the output file names
        seed: Random seed for this model, passed to ``set_seed`` before the
            model is created
        dataloaders: Dictionary with 'train', 'test_id', 'test_ood' loaders
        config: Configuration dictionary; ``config['training']['num_epochs']``
            is read here, and the whole dict is forwarded to ``create_model``
            and ``Trainer``
        device: Torch device the model is moved to

    Returns:
        model: Trained model
        history: Training history dictionary (its 'id_acc' and 'ood_acc'
            entries are indexed with ``[-1]`` for the final report)
    """
    print(f"\n{'='*60}")
    print(f"Training Model {name} (seed={seed})")
    print(f"{'='*60}")

    # Resolve output locations and make sure their folders exist up front.
    checkpoint_path = CHECKPOINTS_DIR / f"model_{name}.pt"
    history_path = METRICS_DIR / f"history_{name}.json"
    for output_path in (checkpoint_path, history_path):
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Set seed for reproducibility (before model creation so that
    # initialization is covered by the seed as well)
    set_seed(seed)

    # Create model
    model = create_model(config)
    model = model.to(device)

    print(f"Model parameters: {count_parameters(model):,}")

    # Create trainer
    trainer = Trainer(model, device, config)

    # Train
    history = trainer.train(
        train_loader=dataloaders['train'],
        test_id_loader=dataloaders['test_id'],
        test_ood_loader=dataloaders['test_ood'],
        num_epochs=config['training']['num_epochs'],
        verbose=True,
        checkpoint_path=checkpoint_path,
    )

    # Save history
    save_training_history(history, history_path)

    print(f"\nModel {name} training complete!")
    print(f"  Final ID accuracy: {history['id_acc'][-1]*100:.2f}%")
    print(f"  Final OOD accuracy: {history['ood_acc'][-1]*100:.2f}%")
    print(f"  Checkpoint saved to: {checkpoint_path}")
    print(f"  History saved to: {history_path}")

    return model, history

## 3. Train Spurious Model A1

In [None]:
# First spurious run: uses the Env-A-only loaders and the seed configured for A1.
model_A1, history_A1 = train_model(
    'A1', config['seeds']['model_A1'], loaders_spurious, config=config, device=device,
)

## 4. Train Spurious Model A2

In [None]:
# Second spurious run: same Env-A-only loaders as A1, with the seed configured for A2.
model_A2, history_A2 = train_model(
    'A2', config['seeds']['model_A2'], loaders_spurious, config=config, device=device,
)

## 5. Train Robust Model R1

In [None]:
# First robust run: uses the mixed-environment loaders and the seed configured for R1.
model_R1, history_R1 = train_model(
    'R1', config['seeds']['model_R1'], loaders_robust, config=config, device=device,
)

## 6. Train Robust Model R2

In [None]:
# Second robust run: same mixed-environment loaders as R1, with the seed configured for R2.
model_R2, history_R2 = train_model(
    'R2', config['seeds']['model_R2'], loaders_robust, config=config, device=device,
)

## 7. Visualize Training Curves

In [None]:
# Individual training curves for each model
import matplotlib.pyplot as plt

# One figure per model, in training order; each is saved under
# 'training_<model>' by plot_training_curves and then displayed.
for _tag, _family, _seed_index, _history in (
    ('A1', 'Spurious', 1, history_A1),
    ('A2', 'Spurious', 2, history_A2),
    ('R1', 'Robust', 1, history_R1),
    ('R2', 'Robust', 2, history_R2),
):
    fig = plot_training_curves(
        _history,
        title=f"Model {_tag} ({_family}, seed {_seed_index})",
        save_name=f'training_{_tag}',
    )
    plt.show()

In [None]:
# Comparison plots: all four histories overlaid, one figure per metric.
histories = dict(zip(
    ('A1 (Spurious)', 'A2 (Spurious)', 'R1 (Robust)', 'R2 (Robust)'),
    (history_A1, history_A2, history_R1, history_R2),
))

# (history key, figure title, file name) -- ID accuracy first, then OOD.
for _metric, _title, _save_name in (
    ('id_acc', 'ID Test Accuracy Comparison', 'comparison_id_acc'),
    ('ood_acc', 'OOD Test Accuracy Comparison', 'comparison_ood_acc'),
):
    fig = plot_multiple_training_curves(
        histories,
        metric=_metric,
        title=_title,
        save_name=_save_name,
    )
    plt.show()

In [None]:
# Combined comparison figure: ID accuracy on the left panel, OOD on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Colour separates the four models; solid lines mark seed-1 runs (A1, R1)
# and dashed lines mark seed-2 runs (A2, R2).
colors = {'A1': 'tab:red', 'A2': 'tab:orange', 'R1': 'tab:blue', 'R2': 'tab:green'}
linestyles = {'A1': '-', 'A2': '--', 'R1': '-', 'R2': '--'}
_runs = (('A1', history_A1), ('A2', history_A2), ('R1', history_R1), ('R2', history_R2))
_panels = (('id_acc', 'ID Test Accuracy'), ('ood_acc', 'OOD Test Accuracy'))

for _ax, (_metric, _panel_title) in zip(axes, _panels):
    for _tag, _history in _runs:
        # The epoch axis is derived from the ID-accuracy series for both panels.
        _epochs = range(1, len(_history['id_acc']) + 1)
        _percent = [value * 100 for value in _history[_metric]]
        _ax.plot(
            _epochs, _percent,
            label=_tag, color=colors[_tag], linestyle=linestyles[_tag], linewidth=2,
        )
    _ax.set_xlabel('Epoch')
    _ax.set_ylabel('Accuracy (%)')
    _ax.set_title(_panel_title)
    _ax.legend()
    _ax.grid(True, alpha=0.3)

plt.suptitle('Training Progress: Spurious vs Robust Models', fontsize=14, y=1.02)
plt.tight_layout()
save_figure(fig, 'training_comparison_combined')
plt.show()

## 8. Summary Statistics

In [None]:
# Compute and display summary statistics
summary = {}

for name, history in [('A1', history_A1), ('A2', history_A2), ('R1', history_R1), ('R2', history_R2)]:
    # Coerce to built-in floats: `summary` is later written with json.dump,
    # which cannot serialize numpy or torch scalar types.
    final_id_acc = float(history['id_acc'][-1])
    final_ood_acc = float(history['ood_acc'][-1])
    # Positive drop means the model is less accurate OOD than ID.
    ood_drop = final_id_acc - final_ood_acc

    summary[name] = {
        'id_acc': final_id_acc,
        'ood_acc': final_ood_acc,
        'ood_drop': ood_drop,
        'final_train_loss': float(history['train_loss'][-1]),
    }

print("\nFinal Model Performance:")
print("=" * 70)
print(f"{'Model':<10} {'ID Acc':<12} {'OOD Acc':<12} {'OOD Drop':<12} {'Train Loss':<12}")
print("-" * 70)

for name, stats in summary.items():
    model_type = "Spurious" if name.startswith('A') else "Robust"
    # Pad the label to the header's 10-character 'Model' column so the
    # numeric columns line up under their headings.
    row_label = f"{name} ({model_type[0]})"
    print(f"{row_label:<10} {stats['id_acc']*100:>6.2f}%      {stats['ood_acc']*100:>6.2f}%      "
          f"{stats['ood_drop']*100:>+6.2f}%      {stats['final_train_loss']:.4f}")

print("=" * 70)

In [None]:
# Verify expected behavior: the spurious family should lose clearly more
# accuracy from ID to OOD than the robust family.
print("\nVerification:")
print("-" * 50)

_drop = {tag: stats['ood_drop'] for tag, stats in summary.items()}
spurious_avg_drop = (_drop['A1'] + _drop['A2']) / 2
robust_avg_drop = (_drop['R1'] + _drop['R2']) / 2

print(f"Average OOD drop (Spurious A1, A2): {spurious_avg_drop*100:.2f}%")
print(f"Average OOD drop (Robust R1, R2):   {robust_avg_drop*100:.2f}%")

# Pass only if the spurious drop exceeds the robust drop by more than 0.1
# (accuracies are fractions, so that is 10 percentage points).
_verdict = (
    "\n[PASS] Spurious models show significantly larger OOD drop!"
    if spurious_avg_drop > robust_avg_drop + 0.1
    else
    "\n[WARNING] OOD drop difference is smaller than expected.\n"
    "          This may affect the semantic barrier analysis."
)
print(_verdict)

In [None]:
# Persist the per-model summary as JSON alongside the training histories
# in METRICS_DIR.
summary_path = METRICS_DIR / 'training_summary.json'
with open(summary_path, 'w') as summary_file:
    json.dump(summary, summary_file, indent=2)
print(f"\nSummary saved to: {summary_path}")

## 9. Final Summary

In [None]:
# Closing report: banner, per-model accuracies, and where the artifacts live.
_rule = "=" * 60
print(f"\n{_rule}\nMODEL TRAINING COMPLETE\n{_rule}")

_model_tags = ('A1', 'A2', 'R1', 'R2')

# One "  - <model>: ID=..%, OOD=..%, Drop=..%" line per model.
_perf_line = {
    tag: "  - {}: ID={:.1f}%, OOD={:.1f}%, Drop={:+.1f}%".format(
        tag,
        summary[tag]['id_acc'] * 100,
        summary[tag]['ood_acc'] * 100,
        summary[tag]['ood_drop'] * 100,
    )
    for tag in _model_tags
}

# The empty strings at both ends reproduce the blank line before and after
# the report body.
_report_lines = [
    "",
    "Models trained:",
    "",
    "Spurious Models (trained on Env A only):",
    _perf_line['A1'],
    _perf_line['A2'],
    "",
    "Robust Models (trained on mixed Env A + B):",
    _perf_line['R1'],
    _perf_line['R2'],
    "",
    "Checkpoints saved:",
    *["  - {}".format(CHECKPOINTS_DIR / f"model_{tag}.pt") for tag in _model_tags],
    "",
    "Training histories saved:",
    *["  - {}".format(METRICS_DIR / f"history_{tag}.json") for tag in _model_tags],
    "",
    "Figures saved:",
    "  - Training curves for each model",
    "  - Comparison plots",
    "",
    "Next: Run 03_mechanism_verification.ipynb to quantify spurious reliance.",
    "",
]
print("\n".join(_report_lines))