# AML Project 2: Protocol Step 1 - Pre-train Linear Head

This notebook implements the first step of the experimental protocol:
- **Goal:** Train a linear classifier on top of the frozen DINO backbone.
- **Backbone:** Frozen (`freeze_policy='head_only'`)
- **Classifier:** Unfrozen (`freeze_head=False`)
- **Epochs:** 20
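- **Output:** `output/main/pretrained_head.pt` (checkpoint with the best validation accuracy) and `output/main/pretrained_head_metrics.json` (per-epoch training history)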

In [None]:
# Clone Repository & Install Dependencies
!git clone https://github.com/emanueleR3/AML-Project-2.git
%cd AML-Project-2
!pip install -r requirements.txt
!pip install torch torchvision numpy matplotlib tqdm

In [None]:
import sys
import os

# The repository root is expected to be the current working directory (the setup cell
# changes into it). Register it on the import path *before* importing from `src` --
# doing it afterwards has no effect on those imports. Guarded so that re-running
# the cell does not keep appending duplicate entries.
if '.' not in sys.path:
    sys.path.append('.')

import torch
import torch.nn as nn

from src.utils import set_seed, get_device, ensure_dir, save_checkpoint, save_metrics_json, count_parameters
from src.data import load_cifar100, create_dataloader
from src.model import build_model
from src.train import evaluate, train_one_epoch

# Run-wide setup: where the checkpoint and metrics are written, the compute device,
# and a fixed seed applied before any data splitting or model construction.
OUTPUT_DIR = 'output/main'
ensure_dir(OUTPUT_DIR)
device = get_device()
set_seed(42)

In [None]:
# Load CIFAR-100 and carve a validation split out of the training portion.
print("Loading CIFAR-100...")
train_trainval, test_dataset = load_cifar100(
    data_dir='./data',
    image_size=224,
    download=True,
)

# 90% of the samples go to training; validation takes the remainder, so the two
# sizes always add up to the full set even when the 90% figure is truncated.
n_trainval = len(train_trainval)
train_size = int(0.9 * n_trainval)
val_size = n_trainval - train_size

# A dedicated, explicitly seeded generator keeps the split membership identical
# across runs, independent of the state of the global RNG.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    train_trainval,
    [train_size, val_size],
    generator=split_generator,
)

# Only the training loader shuffles; validation and test are read in a fixed order.
train_loader = create_dataloader(train_dataset, batch_size=64, shuffle=True)
val_loader, test_loader = (
    create_dataloader(eval_split, batch_size=64, shuffle=False)
    for eval_split in (val_dataset, test_dataset)
)

In [None]:
# Configuration for Head Pre-training: the backbone stays frozen and only the
# classification head is left trainable.
config = dict(
    model_name='dino_vits16',
    num_classes=100,
    freeze_policy='head_only',  # Freeze backbone
    freeze_head=False,          # Train head
    dropout=0.1,
    device=device,
)

model = build_model(config)
model.to(device)

# Sanity check on the freeze policy: the trainable count should be a tiny fraction
# of the total. Expect ~38k trainable params (384*100 + bias).
total_params = count_parameters(model)
print(f"Total parameters: {total_params:,}")
trainable_params = count_parameters(model, trainable_only=True)
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Optimisation setup: SGD with momentum over whatever `model.get_trainable_params()`
# returns, with the learning rate cosine-annealed across the run (the scheduler is
# stepped once per epoch, so T_max equals the number of epochs).
epochs = 20
optimizer = torch.optim.SGD(model.get_trainable_params(), lr=0.01, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
criterion = nn.CrossEntropyLoss()

# Output locations, resolved once instead of on every checkpoint save.
checkpoint_path = os.path.join(OUTPUT_DIR, 'pretrained_head.pt')
metrics_path = os.path.join(OUTPUT_DIR, 'pretrained_head_metrics.json')

# Start below any attainable accuracy so the first epoch always writes a checkpoint.
# Starting at 0.0 with a strict `>` would leave no checkpoint on disk if validation
# accuracy never rose above zero.
best_acc = float('-inf')

# Per-epoch curves. 'train_loss' and 'val_acc' are the original keys; 'train_acc' and
# 'val_loss' are recorded as well since they are computed every epoch anyway.
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

print(f"Starting Head Pre-training for {epochs} epochs...")

for epoch in range(epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device, show_progress=False)

    # Validation
    val_loss, val_acc = evaluate(model, val_loader, criterion, device, show_progress=False)

    # Keep only the weights from the epoch with the best validation accuracy so far.
    if val_acc > best_acc:
        best_acc = val_acc
        # Save as PRETRAINED HEAD
        save_checkpoint({'model_state_dict': model.state_dict()}, checkpoint_path)

    print(f"Epoch {epoch+1}/{epochs} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}% | Best: {best_acc:.2f}%")

    # Advance the LR schedule after the epoch's optimizer updates.
    scheduler.step()
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

print(f"Pre-training finished. Best Val Acc: {best_acc:.2f}%")
save_metrics_json(metrics_path, history)