# AML Project 2: Protocol Step 2 - Central Baseline (Backbone Fine-tuning)

This notebook trains the central baseline according to the new protocol.
- **Prerequisite:** `output/main/pretrained_head.pt`, produced by `pretrain_head.ipynb` (Protocol Step 1)
- **Backbone:** Unfrozen (`finetune_all`)
- **Classifier:** Frozen (`freeze_head=True`)
- **Epochs:** 40
- **Output:** `output/main/central_baseline.pt`

In [None]:
# Clone Repository & Install Dependencies
!git clone https://github.com/emanueleR3/AML-Project-2.git
%cd AML-Project-2
!pip install -r requirements.txt
!pip install torch torchvision numpy matplotlib tqdm

In [None]:
# Imports & Setup
import sys
import os

# Make the current directory importable *before* the project-local `src`
# imports below are resolved; appending it afterwards cannot help them.
# The membership check keeps re-runs of this cell from adding duplicates.
if '.' not in sys.path:
    sys.path.append('.')

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

from src.utils import set_seed, get_device, ensure_dir, save_checkpoint, save_metrics_json, count_parameters
from src.data import load_cifar100, create_dataloader
from src.model import build_model
from src.train import evaluate, train_one_epoch

# Setup output dirs: every artifact of this notebook is read from / written to OUTPUT_DIR.
OUTPUT_DIR = 'output/main'
ensure_dir(OUTPUT_DIR)

# Device used for the model and passed to the train/evaluate helpers.
device = get_device()
print(f"Device: {device}")

# Set seed for reproducibility
set_seed(42)

In [None]:
# Load Data
print("Loading CIFAR-100...")
train_trainval, test_dataset = load_cifar100(data_dir='./data', image_size=224, download=True)

# Split Train/Val: 90% of the training pool is used for training, the
# remainder for validation. A dedicated, fixed-seed generator drives the split.
n_trainval = len(train_trainval)
train_size = int(0.9 * n_trainval)
val_size = n_trainval - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    train_trainval,
    [train_size, val_size],
    generator=split_generator,
)

# Create loaders: only the training loader is shuffled.
train_loader, val_loader, test_loader = (
    create_dataloader(split, batch_size=64, shuffle=should_shuffle)
    for split, should_shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

In [None]:
# Model Configuration
config = dict(
    model_name='dino_vits16',
    num_classes=100,
    freeze_policy='finetune_all',  # Unfreeze backbone
    freeze_head=True,              # Freeze classifier (Protocol Step 2)
    dropout=0.1,
    device=device,
)

model = build_model(config)
model.to(device)

# Load Pre-trained Head
# The checkpoint from the head pre-training notebook is mandatory: abort
# early with a pointer to that notebook when it is missing.
head_path = os.path.join(OUTPUT_DIR, 'pretrained_head.pt')
if not os.path.exists(head_path):
    raise FileNotFoundError("Please run pretrain_head.ipynb first!")

ckpt = torch.load(head_path, map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
print(f"✓ Loaded pre-trained head from {head_path}")

# Report model size: all parameters, then only the trainable ones.
total_params = count_parameters(model)
print(f"Total parameters: {total_params:,}")
trainable_params = count_parameters(model, trainable_only=True)
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# DIAGNOSTIC: Verify pretrained model accuracy BEFORE training
# Evaluates the freshly loaded model on the validation split so that a broken
# or mismatched checkpoint is noticed before the long training run starts.
print("\n=== SANITY CHECK: Pretrained Model Performance ===")
criterion = nn.CrossEntropyLoss()

model.eval()
with torch.no_grad():
    val_loss, val_acc = evaluate(model, val_loader, criterion, device, show_progress=False)

# Result followed by hints on how to interpret it.
for report_line in (
    f"Pretrained Model Val Accuracy: {val_acc:.2f}%",
    "(Should be ~75-85% if pretrained head and DINO backbone are working)",
    "(If ~1%, the pretrained_head.pt may be corrupted or mismatched)",
    "=" * 50,
):
    print(report_line)

In [None]:
# Hyperparameters
epochs = 40
eval_freq = 5  # validate every `eval_freq` epochs (and always on the last epoch)

# SGD learning rate for backbone fine-tuning.
# NOTE(review): the comment previously here quoted lr=1e-4 while the code
# passed 1e-5. The value actually used (1e-5) is kept — confirm which was intended.
learning_rate = 1e-5

optimizer = torch.optim.SGD(model.get_trainable_params(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)
# Cosine decay of the learning rate over the full run; stepped once per epoch below.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
criterion = nn.CrossEntropyLoss()

# The checkpoint with the best validation accuracy is (over)written here.
checkpoint_path = os.path.join(OUTPUT_DIR, 'central_baseline.pt')

best_acc = 0.0
# train_* lists get one entry per epoch; val_* lists only get an entry on
# epochs where validation runs (see `eval_freq`), so they are shorter.
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

print(f"Starting Central Backbone Training for {epochs} epochs...")

for epoch in range(epochs):
    # Re-assert training mode at the start of every epoch: a preceding
    # `evaluate` call (sanity check or periodic validation) may have left the
    # model in eval mode — presumably `train_one_epoch` handles this itself,
    # but that is not visible from this notebook, and the call is harmless.
    model.train()
    loss, acc = train_one_epoch(model, train_loader, optimizer, criterion, device, show_progress=False)

    # Validation logic
    current_epoch = epoch + 1
    if current_epoch % eval_freq == 0 or current_epoch == epochs:
        val_loss, val_acc = evaluate(model, val_loader, criterion, device, show_progress=False)

        # Keep only the weights of the best-performing model seen so far.
        if val_acc > best_acc:
            best_acc = val_acc
            save_checkpoint({'model_state_dict': model.state_dict()}, checkpoint_path)

        print(f"Epoch {current_epoch}/{epochs} | Train Acc: {acc:.2f}% | Val Acc: {val_acc:.2f}% | Best: {best_acc:.2f}%")

        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

    else:
        print(f"Epoch {current_epoch}/{epochs} | Train Acc: {acc:.2f}% | (Skipping Eval)")

    scheduler.step()

    history['train_loss'].append(loss)
    history['train_acc'].append(acc)

print(f"\nBaseline finished. Best Val Acc: {best_acc:.2f}%")

save_metrics_json(os.path.join(OUTPUT_DIR, 'central_baseline_metrics.json'), history)

In [None]:
# Final Test Evaluation
print("\nEvaluating on Test Set...")

# Load best model: restore the weights stored in the checkpoint file that the
# training loop saves on each validation-accuracy improvement.
best_ckpt_path = os.path.join(OUTPUT_DIR, 'central_baseline.pt')
ckpt = torch.load(best_ckpt_path, map_location=device)
best_weights = ckpt['model_state_dict']
model.load_state_dict(best_weights)

# Score the restored model on the held-out test split.
test_loss, test_acc = evaluate(
    model,
    test_loader,
    criterion,
    device,
    show_progress=True,
)
print(f"\n✓ Final Test Accuracy: {test_acc:.2f}%")