# Step 3 - FedAvg IID (Backbone Fine-tuning)

This notebook runs the FedAvg training on IID-partitioned data.
- **Pre-requisite:** `pretrained_head.pt` (Step 1)
- **Backbone:** Unfrozen (`finetune_all`)
- **Classifier:** Frozen (`freeze_head=True`)
- **Rounds:** 200
- **Output:** `output/main/fedavg_iid_best.pt`

In [None]:
# Clone Repository & Install Dependencies
!git clone https://github.com/emanueleR3/AML-Project-2.git
%cd AML-Project-2
!pip install -r requirements.txt
!pip install torch torchvision numpy matplotlib tqdm

In [None]:
# Imports & Setup
import sys
import os

# Make the current directory importable. This has to run before the `src.*`
# imports below to have any effect on them; the membership test keeps repeated
# runs of this cell from stacking duplicate entries onto sys.path.
# Assumes the working directory is the repository root — TODO confirm.
if '.' not in sys.path:
    sys.path.append('.')

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

from src.utils import set_seed, get_device, ensure_dir, save_checkpoint, save_metrics_json, count_parameters
from src.data import load_cifar100, create_dataloader, partition_iid
from src.model import build_model
from src.train import evaluate
from src.fedavg import run_fedavg

# Setup output dirs
OUTPUT_DIR = 'output/main'
ensure_dir(OUTPUT_DIR)
device = get_device()
print(f"Device: {device}")

# Set seed for reproducibility
set_seed(42)

In [None]:
# Load Data
print("Loading CIFAR-100...")
train_trainval, test_dataset = load_cifar100(
    data_dir='./data',
    image_size=224,
    download=True,
)

# 90/10 split of the training pool into train and validation subsets.
# The split uses its own generator seeded with 42, so it does not depend on
# the state of the global RNG.
train_size = int(0.9 * len(train_trainval))
val_size = len(train_trainval) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    train_trainval,
    [train_size, val_size],
    generator=split_generator,
)

# Create loaders for evaluation (fixed order, no shuffling)
val_loader, test_loader = (
    create_dataloader(split, batch_size=64, shuffle=False)
    for split in (val_dataset, test_dataset)
)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

In [None]:
# Model Configuration
# Settings handed to build_model; how each key is interpreted is defined in
# src.model, which is not visible from this notebook.
config = {
    'model_name': 'dino_vits16',
    'num_classes': 100,
    'freeze_policy': 'finetune_all',
    'freeze_head': True,
    'dropout': 0.1,
    'device': device
}

model = build_model(config)
model.to(device)

# Load Pre-trained Head
# Fail fast when the checkpoint is missing instead of training from an
# uninitialised head.
head_path = os.path.join(OUTPUT_DIR, 'pretrained_head.pt')
if not os.path.exists(head_path):
    raise FileNotFoundError("Please run pretrain_head.ipynb first!")

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. The checkpoint is expected to hold a 'model_state_dict' entry.
ckpt = torch.load(head_path, map_location=device)
model.load_state_dict(ckpt['model_state_dict'])
print(f"✓ Loaded pre-trained head from {head_path}")

print(f"Trainable parameters: {count_parameters(model, trainable_only=True):,}")

In [None]:
# FedAvg IID Configuration
# Hyper-parameters passed to run_fedavg. 'clients_per_round' is a fraction of
# 'num_clients' here (the run cell reports int(num_clients * clients_per_round)).
iid_config = {
    'num_clients': 100,
    'clients_per_round': 0.1,
    'local_steps': 4,
    'num_rounds': 200,
    'batch_size': 64,
    'lr': 1e-4,
    'weight_decay': 1e-4,
    'seed': 42,
    'eval_freq': 10
}

# Partition IID
print("Partitioning IID...")
client_datasets = partition_iid(train_dataset, iid_config['num_clients'], iid_config['seed'])

# One dataloader per client partition.
# NOTE(review): the positional `True, 0` are presumably shuffle and num_workers —
# confirm against create_dataloader's signature.
client_loaders = []
for client_ds in client_datasets:
    client_loaders.append(create_dataloader(client_ds, iid_config['batch_size'], True, 0))

print(f"Created {len(client_loaders)} client dataloaders")
approx_per_client = len(train_dataset) // iid_config['num_clients']
print(f"Samples per client: ~{approx_per_client}")

In [None]:
# Run FedAvg IID
# Summarise the run settings before the (long) training call.
sampled_per_round = int(iid_config['num_clients'] * iid_config['clients_per_round'])
print(f"\nStarting FedAvg IID Training (Backbone FT) for {iid_config['num_rounds']} rounds...")
print(f"  Clients per round: {sampled_per_round}")
print(f"  Local steps: {iid_config['local_steps']}")
print(f"  Learning rate: {iid_config['lr']}")

history = run_fedavg(model, client_loaders, val_loader, test_loader, iid_config, device)

# Save
# NOTE(review): the checkpoint is written from `model` as it stands after
# run_fedavg returns; whether that is the best-validation state the file name
# suggests depends on run_fedavg — confirm.
metrics_path = os.path.join(OUTPUT_DIR, 'fedavg_iid_metrics.json')
checkpoint_path = os.path.join(OUTPUT_DIR, 'fedavg_iid_best.pt')
save_metrics_json(metrics_path, history)
save_checkpoint({'model_state_dict': model.state_dict()}, checkpoint_path)

print(f"\n✓ Training complete!")
final_test_acc = history['test_acc'][-1]
print(f"  Final Test Accuracy: {final_test_acc:.2f}%")