# Robust Ensemble Knowledge Distillation Experiments

This notebook runs all experiments for extending AGRE-KD with:
1. **Experiment 1**: Class labels (α < 1, γ = 0)
2. **Experiment 2**: Feature distillation (α = 1, γ > 0)
3. **Experiment 3**: Combined (α < 1, γ > 0)

**Resume Support**: Each experiment checkpoints every 5 epochs. If Colab disconnects, re-run Cells 1-5 to restore the setup, then re-run the interrupted experiment cell; it continues from its latest checkpoint.

## Cell 1: Setup & Mount Drive

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Install dependencies
!pip install -q wilds tqdm

# Verify GPU
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## Cell 2: Configuration & Paths

In [None]:
import os
import sys
import json

# ============================================================
# CONFIGURE THESE PATHS
# ============================================================
# Everything is stored under a single Google Drive folder.
DRIVE_ROOT = '/content/drive/MyDrive/robust-ensemble-kd'
CODE_DIR = f'{DRIVE_ROOT}/light-code'
DATA_DIR = f'{DRIVE_ROOT}/data/waterbirds_v1.0'
TEACHER_DIR = f'{DRIVE_ROOT}/teacher_checkpoints'
CHECKPOINT_DIR = f'{DRIVE_ROOT}/checkpoints'
LOG_DIR = f'{DRIVE_ROOT}/logs'

# Output folders are created on demand
for out_dir in (CHECKPOINT_DIR, LOG_DIR):
    os.makedirs(out_dir, exist_ok=True)

# Put the project code first on the import path
sys.path.insert(0, CODE_DIR)

# Input folders must already exist; stop at the first one that is missing
print("Checking paths...")
for label, required_dir in (("Code", CODE_DIR), ("Data", DATA_DIR), ("Teachers", TEACHER_DIR)):
    assert os.path.exists(required_dir), f"{label} not found: {required_dir}"
print("All paths verified!")

# Report every .pt / .pth file found in the teacher folder
ckpts = [name for name in os.listdir(TEACHER_DIR) if name.endswith(('.pt', '.pth'))]
print(f"\nFound {len(ckpts)} teacher checkpoints: {ckpts}")

## Cell 3: Load Data

In [None]:
from data import get_waterbirds_loaders

# Build the Waterbirds data loaders from DATA_DIR
print("Loading Waterbirds dataset...")
loaders = get_waterbirds_loaders(root_dir=DATA_DIR, batch_size=32, num_workers=2, augment=True)

# Sanity check: pull one training batch and report the shape of each field.
# `batch` is kept as a notebook variable; a later cell reuses its images.
batch = next(iter(loaders['train']))
print("\nBatch shapes:")
for title, field in (("Images", 'image'), ("Labels", 'label'), ("Groups", 'group')):
    print(f"  {title}: {batch[field].shape}")

## Cell 4: Load Teachers

In [None]:
from models import get_teacher_model, load_dfr_checkpoint, load_teachers_from_dir

# Option 1: Load all teachers from directory
print("Loading teacher models...")
teachers = load_teachers_from_dir(
    checkpoint_dir=TEACHER_DIR,
    # Factory for an uninitialised 2-class ResNet-50; weights come from the checkpoints
    model_fn=lambda: get_teacher_model('resnet50', num_classes=2, pretrained=False),
    num_teachers=5,  # Adjust based on how many you have
    device='cuda'
)

print(f"\nLoaded {len(teachers)} teachers")
# Fail here with a clear message instead of an IndexError on teachers[0] below
assert len(teachers) > 0, f"No teachers were loaded from {TEACHER_DIR}"

# Test forward pass on four images from the batch fetched in the data-loading cell
# NOTE(review): assumes load_teachers_from_dir returns the models in eval mode — TODO confirm
with torch.no_grad():
    test_batch = batch['image'][:4].cuda()
    test_out = teachers[0](test_batch)
    print(f"Test forward pass: {test_out.shape}")

# Use first teacher as biased reference model
biased_model = teachers[0]
print(f"Using teachers[0] as biased reference model")

## Cell 5: Define Experiment Runner

In [None]:
from config import Config
from train import train_student
from eval import print_results

def run_experiment(exp_name, alpha, gamma, use_agre=True):
    """
    Run a single experiment with resume support.

    The experiment is identified by `exp_name` alone:

    * If `student_{exp_name}_results.pt` exists in CHECKPOINT_DIR, nothing is
      trained. The saved test results are loaded, printed, written to the
      JSON log and returned.
    * Otherwise `train_student` is called. If `student_{exp_name}_latest.pt`
      exists in CHECKPOINT_DIR it is passed as the checkpoint to resume from.

    Because lookups use only the name, calling this again with the same
    `exp_name` but different `alpha`/`gamma` returns whatever was saved under
    that name.

    Args:
        exp_name: Name for saving (e.g., 'baseline', 'exp1_alpha07')
        alpha: Weight for KD vs CE (1.0 = pure KD)
        gamma: Weight for feature distillation (0.0 = no features)
        use_agre: Use gradient-based teacher weighting

    Returns:
        The test results: the third value returned by `train_student`, or the
        'test_results' entry of the saved results file when the run is cached.
    """
    print(f"\n{'='*60}")
    print(f"EXPERIMENT: {exp_name}")
    print(f"  alpha={alpha}, gamma={gamma}, use_agre={use_agre}")
    print(f"{'='*60}\n")

    # Check if already completed
    results_path = os.path.join(CHECKPOINT_DIR, f'student_{exp_name}_results.pt')
    if os.path.exists(results_path):
        print(f"Experiment already completed! Loading results...")
        # map_location='cpu' keeps cached results loadable on a runtime without a GPU
        results = torch.load(results_path, map_location='cpu')
        cached_results = results['test_results']
        print_results(cached_results, f"{exp_name} (cached)")
        # Log the cached run as well: if the session died after the results
        # file was saved but before logging, this is the only chance to get
        # the experiment into the JSON log that the analysis cells read.
        log_experiment(exp_name, alpha, gamma, cached_results)
        return cached_results

    # Check for resume checkpoint
    resume_path = os.path.join(CHECKPOINT_DIR, f'student_{exp_name}_latest.pt')
    if os.path.exists(resume_path):
        print(f"Found checkpoint, will resume...")
    else:
        # None tells train_student there is nothing to resume from
        resume_path = None

    # Create config
    config = Config(
        data_dir=DATA_DIR,
        checkpoint_dir=CHECKPOINT_DIR,
        alpha=alpha,
        gamma=gamma,
        epochs=30,
        lr=0.001,
        batch_size=32,
    )

    # Train; only the test results are used here (student and history are discarded)
    _student, _history, test_results = train_student(
        config=config,
        teachers=teachers,
        biased_model=biased_model,
        exp_name=exp_name,
        use_agre=use_agre,
        checkpoint_path=resume_path,
    )

    # Save to log
    log_experiment(exp_name, alpha, gamma, test_results)

    return test_results


def _to_jsonable(value):
    """Recursively convert `value` into plain Python objects that `json` can encode."""
    if isinstance(value, dict):
        return {
            (key.tolist() if hasattr(key, 'tolist') else key): _to_jsonable(item)
            for key, item in value.items()
        }
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    if hasattr(value, 'tolist'):
        # NumPy scalars/arrays and torch tensors expose tolist(), which
        # yields built-in Python numbers (or nested lists of them)
        return value.tolist()
    return value


def log_experiment(exp_name, alpha, gamma, results, log_dir=None):
    """
    Log experiment results to JSON file.

    The entry for `exp_name` is added to (or replaces the one in)
    `experiment_results.json`; entries of other experiments are kept.

    The new content is fully serialized before the file is touched and is
    then swapped in with `os.replace`, so a value that cannot be encoded
    raises without damaging the results that were already logged.

    Args:
        exp_name: Key under which the entry is stored.
        alpha: Alpha value recorded with the entry.
        gamma: Gamma value recorded with the entry.
        results: Mapping providing 'wga', 'avg_acc', 'group_accs' and 'acc_gap'.
        log_dir: Folder holding the log file; defaults to LOG_DIR.

    Raises:
        KeyError: If `results` lacks one of the required keys.
        TypeError: If a value cannot be represented in JSON.
    """
    if log_dir is None:
        log_dir = LOG_DIR
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, 'experiment_results.json')

    if os.path.exists(log_path):
        with open(log_path, 'r') as f:
            all_results = json.load(f)
    else:
        all_results = {}

    all_results[exp_name] = _to_jsonable({
        'alpha': alpha,
        'gamma': gamma,
        'wga': results['wga'],
        'avg_acc': results['avg_acc'],
        'group_accs': results['group_accs'],
        'acc_gap': results['acc_gap'],
    })

    # Serialize first: if this raises, the existing log is left untouched
    payload = json.dumps(all_results, indent=2)

    # Write next to the target and swap it in, so the log is never left half-written
    tmp_path = f"{log_path}.tmp"
    with open(tmp_path, 'w') as f:
        f.write(payload)
    os.replace(tmp_path, log_path)

    print(f"Results saved to {log_path}")


# Confirmation that this cell ran and the helper functions above are defined
print("Experiment runner ready!")

---
# Run Experiments

Run these cells one at a time. Each experiment takes ~2 hours.

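If Colab disconnects, re-run Cells 1-5, then re-run the interrupted experiment cell: it resumes from its latest checkpoint, and experiments that already finished are loaded from their saved results instead of being retrained.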

## Cell 6: Baseline (alpha=1.0, gamma=0.0)

In [None]:
# Baseline run: pure AGRE-KD, i.e. alpha=1.0 (no class-label term) and gamma=0.0 (no feature term)
baseline_results = run_experiment('baseline', alpha=1.0, gamma=0.0, use_agre=True)

## Cell 7: Exp1-a (alpha=0.5, gamma=0.0)

In [None]:
# Experiment 1a: class labels added with alpha=0.5; feature term off (gamma=0.0)
exp1a_results = run_experiment('exp1_alpha05', alpha=0.5, gamma=0.0, use_agre=True)

## Cell 8: Exp1-b (alpha=0.7, gamma=0.0)

In [None]:
# Experiment 1b: class labels added with alpha=0.7; feature term off (gamma=0.0)
exp1b_results = run_experiment('exp1_alpha07', alpha=0.7, gamma=0.0, use_agre=True)

## Cell 9: Exp1-c (alpha=0.9, gamma=0.0)

In [None]:
# Experiment 1c: class labels added with alpha=0.9; feature term off (gamma=0.0)
exp1c_results = run_experiment('exp1_alpha09', alpha=0.9, gamma=0.0, use_agre=True)

## Cell 10: Exp2-a (alpha=1.0, gamma=0.1)

In [None]:
# Experiment 2a: feature distillation with gamma=0.1; alpha stays at 1.0 (pure KD)
exp2a_results = run_experiment('exp2_gamma01', alpha=1.0, gamma=0.1, use_agre=True)

## Cell 11: Exp2-b (alpha=1.0, gamma=0.25)

In [None]:
# Experiment 2b: feature distillation with gamma=0.25; alpha stays at 1.0 (pure KD)
exp2b_results = run_experiment('exp2_gamma025', alpha=1.0, gamma=0.25, use_agre=True)

## Cell 12: Exp3-a (alpha=0.7, gamma=0.1)

In [None]:
# Experiment 3a: combined setting, class labels (alpha=0.7) plus feature distillation (gamma=0.1)
exp3a_results = run_experiment('exp3_a07_g01', alpha=0.7, gamma=0.1, use_agre=True)

## Cell 13: Exp3-b (alpha=0.9, gamma=0.1)

In [None]:
# Experiment 3b: combined setting, class labels (alpha=0.9) plus feature distillation (gamma=0.1)
exp3b_results = run_experiment('exp3_a09_g01', alpha=0.9, gamma=0.1, use_agre=True)

---
# Results Analysis

## Cell 14: Compile Results

In [None]:
import pandas as pd

# Read the JSON experiment log from LOG_DIR
log_path = os.path.join(LOG_DIR, 'experiment_results.json')
with open(log_path, 'r') as f:
    all_results = json.load(f)

# One table row per logged experiment; the three metrics are shown as
# percentages with two decimals
rows = []
for exp_name, data in all_results.items():
    row = {'Experiment': exp_name, 'Alpha': data['alpha'], 'Gamma': data['gamma']}
    for column, metric in (('WGA (%)', 'wga'), ('Avg Acc (%)', 'avg_acc'), ('Gap (%)', 'acc_gap')):
        row[column] = f"{data[metric]*100:.2f}"
    rows.append(row)

df = pd.DataFrame(rows)
rule = "="*70
print("\n" + rule)
print("EXPERIMENT RESULTS SUMMARY")
print(rule)
print(df.to_string(index=False))
print(rule)

# (name, entry) pair of the experiment with the highest worst-group accuracy
best_exp = max(all_results.items(), key=lambda item: item[1]['wga'])
print(f"\nBest WGA: {best_exp[0]} with {best_exp[1]['wga']*100:.2f}%")

## Cell 15: Visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Load results
log_path = os.path.join(LOG_DIR, 'experiment_results.json')
with open(log_path, 'r') as f:
    all_results = json.load(f)

# Two panels: WGA per experiment (left), per-group accuracy of the best run (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# --- Left panel: WGA comparison bar chart ---
names = list(all_results.keys())
wgas = [all_results[n]['wga'] * 100 for n in names]
best_wga = max(wgas)  # computed once instead of once per bar

# The best experiment is drawn in green, all others in blue
colors = ['#2ecc71' if w == best_wga else '#3498db' for w in wgas]
# Bars sit at explicit numeric positions and the ticks are pinned to them
# before labelling: set_xticklabels is only reliable once tick positions are fixed.
exp_positions = np.arange(len(names))
axes[0].bar(exp_positions, wgas, color=colors)
axes[0].set_xticks(exp_positions)
axes[0].set_xticklabels(names, rotation=45, ha='right')
axes[0].set_ylabel('Worst-Group Accuracy (%)')
axes[0].set_title('WGA by Experiment')
# NOTE(review): the 85% reference value is hard-coded — confirm it matches the AGRE-KD baseline being compared against
axes[0].axhline(y=85, color='r', linestyle='--', label='AGRE-KD baseline (85%)')
axes[0].legend()

# --- Right panel: per-group accuracy for best experiment ---
best_name = max(all_results.keys(), key=lambda k: all_results[k]['wga'])
best = all_results[best_name]
group_names = ['Landbird+Land', 'Landbird+Water', 'Waterbird+Land', 'Waterbird+Water']
# Keys read back from JSON are strings, hence str(i)
group_accs = [best['group_accs'][str(i)] * 100 for i in range(4)]
worst_group_acc = min(group_accs)

# The worst group is drawn in red, all others in blue
colors = ['#e74c3c' if g == worst_group_acc else '#3498db' for g in group_accs]
group_positions = np.arange(len(group_names))
axes[1].bar(group_positions, group_accs, color=colors)
axes[1].set_xticks(group_positions)
axes[1].set_xticklabels(group_names, rotation=45, ha='right')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title(f'Per-Group Accuracy: {best_name}')

plt.tight_layout()
# Save before show() so the written file contains the rendered figure
plt.savefig(os.path.join(LOG_DIR, 'results_visualization.png'), dpi=150)
plt.show()

print(f"\nFigure saved to {LOG_DIR}/results_visualization.png")