# 🚀 Setup (Run Every Session)

Run this cell every time you open the notebook to:
- Mount Google Drive
- Install dependencies
- Clone/update the code repository
- Set up paths

In [2]:
# === SETUP (Run Every Session) ===
from google.colab import drive
import os
import sys

# 1. Mount Google Drive
drive.mount('/content/drive')

# 2. Install dependencies
!pip install -q wilds tqdm scikit-learn

# 3. Clone or update repo
REPO_DIR = '/content/repo'
if os.path.exists(REPO_DIR):
    print("Repo exists, pulling latest...")
    !cd {REPO_DIR} && git pull
else:
    print("Cloning repo...")
    !git clone https://github.com/dat-tran05/robust-ensemble-kd.git {REPO_DIR}

# 4. Add to Python path
sys.path.insert(0, f'{REPO_DIR}/light-code')

# 5. Define paths (EDIT THIS IF YOUR DRIVE PATH IS DIFFERENT)
DRIVE_ROOT = '/content/drive/MyDrive/MIT/MIT Junior Year (2025-2026)/Fall Semester/6.7960/6.7960 Final Project/robust-ensemble-kd'
DATA_DIR = f'{DRIVE_ROOT}/data/waterbirds_v1.0'
TEACHER_DIR = f'{DRIVE_ROOT}/teacher_checkpoints'
CHECKPOINT_DIR = f'{DRIVE_ROOT}/checkpoints'

print("\n‚úÖ Setup complete!")
print(f"   Data: {DATA_DIR}")
print(f"   Teachers: {TEACHER_DIR}")

Mounted at /content/drive
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m126.2/126.2 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m78.8/78.8 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25hCloning repo...
Cloning into '/content/repo'...
remote: Enumerating objects: 36, done.[K
remote: Counting objects: 100% (36/36), done.[K
remote: Compressing objects: 100% (27/27), done.[K
remote: Total 36 (delta 13), reused 31 (delta 8), pack-reused 0 (from 0)[K
Receiving objects: 100% (36/36), 59.44 KiB | 9.91 MiB/s, done.
Resolving deltas: 100% (13/13), done.

‚úÖ Setup complete!
   Data: /content/drive/MyDrive/MIT/MIT Junior Year (2025-2026)/Fall Semester/6.7960/6.7960 Final Project/robust-ensemble-kd/data/waterbirds_v1.0
   Teachers: /content/drive/MyDrive

# 📊 Load Data & Verify

Load the Waterbirds dataset and verify that teacher checkpoints exist.

In [None]:
# === LOAD DATA & VERIFY ===
# Builds the Waterbirds data loaders, lists the .pt files in TEACHER_DIR and
# runs the first one on a single test batch as a smoke test.
# Relies on `os`, DATA_DIR and TEACHER_DIR from the setup cell.
import torch
from data import get_waterbirds_loaders
from models import get_teacher_model, load_teacher_checkpoint
from eval import compute_group_accuracies, print_results

# Load data
print("Loading Waterbirds data...")
loaders = get_waterbirds_loaders(DATA_DIR, batch_size=32, num_workers=4)

# Check teacher checkpoints
print(f"\nTeacher checkpoints in {TEACHER_DIR}:")
ckpts = sorted(f for f in os.listdir(TEACHER_DIR) if f.endswith('.pt'))
for f in ckpts:
    print(f"  - {f}")

# Fail with an actionable message instead of a bare IndexError on ckpts[0].
if not ckpts:
    raise FileNotFoundError(
        f"No .pt checkpoints found in {TEACHER_DIR}. "
        "Check DRIVE_ROOT in the setup cell and that the teacher checkpoints were uploaded."
    )

# Quick sanity check
# NOTE: ckpts[0] is simply the alphabetically first .pt file in the folder.
print("\nQuick sanity check...")
teacher = get_teacher_model('resnet50', num_classes=2, pretrained=False)
load_teacher_checkpoint(teacher, os.path.join(TEACHER_DIR, ckpts[0]))
teacher.cuda().eval()

batch = next(iter(loaders['test']))
with torch.no_grad():
    preds = teacher(batch['image'].cuda()).argmax(dim=1)
    # Compare on CPU, where the batch labels live.
    acc = (preds.cpu() == batch['label']).float().mean()
print(f"Test batch accuracy: {acc*100:.1f}%")
print("‚úÖ Everything loaded correctly!")

# 📈 Evaluate Baseline Teacher

Evaluate the original ERM teacher on the full test set to see baseline performance.
Expected: ~73% WGA (before DFR debiasing).

In [None]:
# === EVALUATE BASELINE TEACHER ===
# Uses `teacher`, `ckpts` and `loaders` created by the "Load Data & Verify" cell.
baseline_ckpt = ckpts[0]
test_loader = loaders['test']

print("Evaluating baseline (biased) teacher on full test set...")
print(f"Using: {baseline_ckpt}\n")

# Per-group evaluation over the whole test split, followed by the report.
baseline_results = compute_group_accuracies(teacher, test_loader, device='cuda')
print_results(baseline_results, f"Baseline Teacher: {baseline_ckpt}")

baseline_wga_pct = baseline_results['wga'] * 100
print(f"\nüìä Baseline WGA: {baseline_wga_pct:.2f}%")
print("   (This should improve to ~90%+ after DFR)")

Evaluating baseline (biased) teacher on full test set...
Using: erm_seed1.pt

  Evaluating 5794 samples (182 batches)...


                                                            

  Evaluation complete (416.6s) - WGA: 73.8%

 Baseline Teacher: erm_seed1.pt

Per-group accuracy:
  Landbird + Land (majority): 99.51% (n=2255)
  Landbird + Water (minority): 87.89% (n=2255)
  Waterbird + Land (minority, hardest): 73.83% (n=642)
  Waterbird + Water (majority): 96.42% (n=642)

Aggregate metrics:
  Worst-Group Accuracy (WGA): 73.83%
  Average Accuracy: 91.80%
  Accuracy Gap: 25.68%
  Worst Group: 2


üìä Baseline WGA: 73.83%
   (This should improve to ~90%+ after DFR)




# 🔧 Apply DFR to Teachers

Apply Deep Feature Reweighting to create debiased versions of the ERM teachers.
- Takes ~30 minutes for 5 teachers
- Only need to run this ONCE (results saved to Drive)

In [None]:
# === APPLY DFR TO ALL TEACHERS ===
# Runs the project's teacher-preparation routine over the checkpoints in
# TEACHER_DIR, using the dataset in DATA_DIR (both defined in the setup cell).
from prepare_teachers import colab_prepare_teachers

print("Applying DFR to all ERM teachers...")
print(f"This will create debiased versions in: {TEACHER_DIR}\n")

# Collect the arguments first so they are easy to inspect or tweak.
dfr_kwargs = dict(
    checkpoint_dir=TEACHER_DIR,
    data_dir=DATA_DIR,
    num_teachers=5,
)
results = colab_prepare_teachers(**dfr_kwargs)

print("\n‚úÖ DFR complete! Debiased teachers saved.")

Applying DFR to all ERM teachers...
This will create debiased versions in: /content/drive/MyDrive/MIT/MIT Junior Year (2025-2026)/Fall Semester/6.7960/6.7960 Final Project/robust-ensemble-kd/teacher_checkpoints


PREPARING 5 TEACHERS

Found 5 ERM checkpoints:
  - erm_seed1.pt (seed 1)
  - erm_seed2.pt (seed 2)
  - erm_seed3.pt (seed 3)
  - erm_seed4.pt (seed 4)
  - erm_seed5.pt (seed 5)

Loading data...
Loaded train split: 4795 samples
  Group counts: {0: np.int64(3498), 1: np.int64(184), 2: np.int64(56), 3: np.int64(1057)}
  Worst group: 2 with 56 samples
Loaded val split: 1199 samples
  Group counts: {0: np.int64(467), 1: np.int64(466), 2: np.int64(133), 3: np.int64(133)}
  Worst group: 2 with 133 samples
Loaded test split: 5794 samples
  Group counts: {0: np.int64(2255), 1: np.int64(2255), 2: np.int64(642), 3: np.int64(642)}
  Worst group: 2 with 642 samples
Data loaded (0.1s)

[1/5] Processing erm_seed1.pt
  [1/5] Loading model...
Loaded checkpoint from /content/drive/MyDrive/MIT/M

                                                                

            Features: (1199, 2048) (9.9s)
      [2/3] Creating balanced subset...
            532 samples (0.0s)
      [3/3] Training new classifier...
            Done (0.1s)
        Done (10.0s)
  [4/5] Evaluating debiased (DFR) model...




        WGA: 93.3% (43.7s)
  [5/5] Saving checkpoint...
        Saved: teacher_1_debiased.pt (0.3s)

  Summary: 73.8% -> 93.3% (+19.5%) in 97.6s

[1/5] Done in 97.6s (elapsed: 1.6min, remaining: ~6.5min)

[2/5] Processing erm_seed2.pt
  [1/5] Loading model...
Loaded checkpoint from /content/drive/MyDrive/MIT/MIT Junior Year (2025-2026)/Fall Semester/6.7960/6.7960 Final Project/robust-ensemble-kd/teacher_checkpoints/erm_seed2.pt
  All keys matched!
        Done (0.6s)
  [2/5] Evaluating biased (ERM) model...
        WGA: 71.2% (44.5s)
  [3/5] Applying DFR (retraining last layer)...
      [1/3] Extracting features from 1199 samples...




            Features: (1199, 2048) (10.1s)
      [2/3] Creating balanced subset...
            532 samples (0.0s)
      [3/3] Training new classifier...
            Done (0.1s)
        Done (10.2s)
  [4/5] Evaluating debiased (DFR) model...
        WGA: 92.1% (43.9s)
  [5/5] Saving checkpoint...
        Saved: teacher_2_debiased.pt (0.2s)

  Summary: 71.2% -> 92.1% (+20.9%) in 99.4s

[2/5] Done in 99.4s (elapsed: 3.3min, remaining: ~4.9min)

[3/5] Processing erm_seed3.pt
  [1/5] Loading model...
Loaded checkpoint from /content/drive/MyDrive/MIT/MIT Junior Year (2025-2026)/Fall Semester/6.7960/6.7960 Final Project/robust-ensemble-kd/teacher_checkpoints/erm_seed3.pt
  All keys matched!
        Done (3.1s)
  [2/5] Evaluating biased (ERM) model...
        WGA: 69.6% (44.2s)
  [3/5] Applying DFR (retraining last layer)...
      [1/3] Extracting features from 1199 samples...




            Features: (1199, 2048) (9.1s)
      [2/3] Creating balanced subset...
            532 samples (0.0s)
      [3/3] Training new classifier...
            Done (0.2s)
        Done (9.3s)
  [4/5] Evaluating debiased (DFR) model...
        WGA: 91.9% (43.4s)
  [5/5] Saving checkpoint...
        Saved: teacher_3_debiased.pt (0.2s)

  Summary: 69.6% -> 91.9% (+22.3%) in 100.2s

[3/5] Done in 100.2s (elapsed: 5.0min, remaining: ~3.3min)

[4/5] Processing erm_seed4.pt
  [1/5] Loading model...
Loaded checkpoint from /content/drive/MyDrive/MIT/MIT Junior Year (2025-2026)/Fall Semester/6.7960/6.7960 Final Project/robust-ensemble-kd/teacher_checkpoints/erm_seed4.pt
  All keys matched!
        Done (3.8s)
  [2/5] Evaluating biased (ERM) model...
        WGA: 70.9% (43.4s)
  [3/5] Applying DFR (retraining last layer)...
      [1/3] Extracting features from 1199 samples...




            Features: (1199, 2048) (9.6s)
      [2/3] Creating balanced subset...
            532 samples (0.0s)
      [3/3] Training new classifier...
            Done (0.2s)
        Done (9.8s)
  [4/5] Evaluating debiased (DFR) model...
        WGA: 93.6% (43.9s)
  [5/5] Saving checkpoint...
        Saved: teacher_4_debiased.pt (0.3s)

  Summary: 70.9% -> 93.6% (+22.7%) in 101.2s

[4/5] Done in 101.2s (elapsed: 6.6min, remaining: ~1.7min)

[5/5] Processing erm_seed5.pt
  [1/5] Loading model...
Loaded checkpoint from /content/drive/MyDrive/MIT/MIT Junior Year (2025-2026)/Fall Semester/6.7960/6.7960 Final Project/robust-ensemble-kd/teacher_checkpoints/erm_seed5.pt
  All keys matched!
        Done (3.9s)
  [2/5] Evaluating biased (ERM) model...
        WGA: 68.7% (43.4s)
  [3/5] Applying DFR (retraining last layer)...
      [1/3] Extracting features from 1199 samples...


                                                                

            Features: (1199, 2048) (9.7s)
      [2/3] Creating balanced subset...
            532 samples (0.0s)
      [3/3] Training new classifier...
            Done (0.1s)
        Done (9.8s)
  [4/5] Evaluating debiased (DFR) model...




        WGA: 93.8% (42.9s)
  [5/5] Saving checkpoint...
        Saved: teacher_5_debiased.pt (0.2s)

  Summary: 68.7% -> 93.8% (+25.1%) in 100.2s

[5/5] Done in 100.2s (elapsed: 8.3min, remaining: ~0.0min)

TEACHER PREPARATION COMPLETE (8.3 min)

All teachers in: /content/drive/MyDrive/MIT/MIT Junior Year (2025-2026)/Fall Semester/6.7960/6.7960 Final Project/robust-ensemble-kd/teacher_checkpoints

Biased (ERM):
  erm_seed1.pt -> WGA=73.8%
  erm_seed2.pt -> WGA=71.2%
  erm_seed3.pt -> WGA=69.6%
  erm_seed4.pt -> WGA=70.9%
  erm_seed5.pt -> WGA=68.7%

Debiased (DFR):
  teacher_1_debiased.pt -> WGA=93.3%
  teacher_2_debiased.pt -> WGA=92.1%
  teacher_3_debiased.pt -> WGA=91.9%
  teacher_4_debiased.pt -> WGA=93.6%
  teacher_5_debiased.pt -> WGA=93.8%

Average WGA improvement: +22.1%
Summary saved to /content/drive/MyDrive/MIT/MIT Junior Year (2025-2026)/Fall Semester/6.7960/6.7960 Final Project/robust-ensemble-kd/teacher_checkpoints/preparation_summary.pt

‚úÖ DFR complete! Debiased teache

# ✅ Summarize Teachers

Load the saved preparation summary (`preparation_summary.pt`) and verify that each teacher's WGA improved from ~69–74% (biased ERM) to ~90%+ (debiased with DFR).

In [4]:
import torch

# Read the summary file stored next to the teacher checkpoints
# (TEACHER_DIR comes from the setup cell).
summary_path = f'{TEACHER_DIR}/preparation_summary.pt'
summary = torch.load(summary_path)
per_seed = summary['results']

print("Teacher Results:")
print("-" * 60)
# One entry per seed, each holding the metrics before ('biased') and
# after ('debiased') DFR.
for seed in sorted(per_seed):
    res = per_seed[seed]
    biased, debiased = res['biased'], res['debiased']
    print(f"Seed {seed}:")
    print(f"  Biased:   WGA={biased['wga']*100:.1f}%, Avg={biased['avg_acc']*100:.1f}%")
    print(f"  Debiased: WGA={debiased['wga']*100:.1f}%, Avg={debiased['avg_acc']*100:.1f}%")


Teacher Results:
------------------------------------------------------------
Seed 1:
  Biased:   WGA=73.8%, Avg=91.8%
  Debiased: WGA=93.3%, Avg=94.4%
Seed 2:
  Biased:   WGA=71.2%, Avg=91.2%
  Debiased: WGA=92.1%, Avg=94.9%
Seed 3:
  Biased:   WGA=69.6%, Avg=91.4%
  Debiased: WGA=91.9%, Avg=94.9%
Seed 4:
  Biased:   WGA=70.9%, Avg=91.9%
  Debiased: WGA=93.6%, Avg=94.9%
Seed 5:
  Biased:   WGA=68.7%, Avg=90.6%
  Debiased: WGA=93.8%, Avg=94.7%


# 📥 One-Time: Download Waterbirds Dataset

**Skip this section if you've already downloaded the data.**

This downloads the Waterbirds dataset (~500MB) from Stanford NLP.

In [None]:
# === DOWNLOAD WATERBIRDS (One-time only) ===
# Downloads the Waterbirds archive into {DRIVE_ROOT}/data, extracts it and
# renames the extracted folder to waterbirds_v1.0. Requires DRIVE_ROOT from the
# setup cell. Re-running is a no-op when the dataset is already in place.
import os
import shutil
import ssl
import tarfile
import urllib.request

# TLS certificate verification stays ON by default. Set to False only if the
# download fails with a certificate error; the unverified context is then used
# for this single request instead of being patched in for the whole kernel.
VERIFY_TLS = True
# Set to True to re-download and replace an existing dataset folder.
FORCE_DOWNLOAD = False

# Paths
dataset_dir = f'{DRIVE_ROOT}/data'
waterbirds_dir = f'{dataset_dir}/waterbirds_v1.0'
os.makedirs(dataset_dir, exist_ok=True)

# Download
url = "https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz"
tar_path = f'{dataset_dir}/waterbirds.tar.gz'

if os.path.exists(f'{waterbirds_dir}/metadata.csv') and not FORCE_DOWNLOAD:
    # Avoid a ~500MB re-download and the deletion of an intact dataset.
    print(f"Dataset already present at: {waterbirds_dir} (skipping download)")
else:
    if VERIFY_TLS:
        tls_context = ssl.create_default_context()
    else:
        # SECURITY: disables certificate checks for this request only.
        tls_context = ssl._create_unverified_context()

    try:
        print("Downloading Waterbirds dataset (~500MB)...")
        with urllib.request.urlopen(url, context=tls_context) as response, \
                open(tar_path, 'wb') as tar_file:
            shutil.copyfileobj(response, tar_file)
        print("Download complete!")

        # Extract
        print("Extracting...")
        with tarfile.open(tar_path, 'r:gz') as tar:
            if hasattr(tarfile, 'data_filter'):
                # Rejects members that would be written outside dataset_dir.
                tar.extractall(path=dataset_dir, filter='data')
            else:
                # Older Python without extraction filters.
                tar.extractall(path=dataset_dir)
    finally:
        # Cleanup: remove the archive even if the download or extraction failed,
        # so a partial file is never left behind on Drive.
        if os.path.exists(tar_path):
            os.remove(tar_path)

    # Rename extracted folder
    extracted = f'{dataset_dir}/waterbird_complete95_forest2water2'
    if os.path.exists(extracted):
        if os.path.exists(waterbirds_dir):
            shutil.rmtree(waterbirds_dir)
        shutil.move(extracted, waterbirds_dir)

# Verify
if os.path.exists(f'{waterbirds_dir}/metadata.csv'):
    print(f"‚úÖ Dataset ready at: {waterbirds_dir}")
else:
    print("‚ùå Warning: metadata.csv not found")