# AML Project 2: Sparse FedAvg IID (Comparison)

This notebook runs the **Sparse** FedAvg IID experiment to compare against the Dense IID Baseline.
- **Protocol:** Backbone Fine-tuning (Frozen Head)
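- **Mask:** computed inline via multi-round Fisher calibration (rule: `least_sensitive`) and saved to `output/comparison/mask_iid_ls.pt`; requires the pretrained head checkpoint `output/main/pretrained_head.pt`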
- **Rounds:** 300
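- **Output:** `output/comparison/sparse_iid_best.pt` (best checkpoint by validation accuracy) and `output/comparison/sparse_iid_metrics.json` (training history)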

In [None]:
# Clone Repository & Install Dependencies
!git clone https://github.com/emanueleR3/AML-Project-2.git
%cd AML-Project-2
!pip install -r requirements.txt
!pip install torch torchvision numpy matplotlib tqdm

In [None]:
import sys
import os
import torch
import torch.nn as nn
import numpy as np

# The repository root must be on sys.path *before* the `src` package is
# imported; appending it afterwards has no effect on those imports.
# The membership check keeps sys.path from growing when the cell is re-run.
if '.' not in sys.path:
    sys.path.append('.')

from src.utils import set_seed, get_device, ensure_dir, save_metrics_json, save_checkpoint
from src.data import load_cifar100, create_dataloader, partition_iid
from src.model import build_model
from src.train import evaluate
from src.sparse_fedavg import run_fedavg_sparse_round
from src.masking import compute_fisher_diagonal, create_mask, get_mask_sparsity

# All artifacts of this notebook (mask, best checkpoint, metrics) go here.
OUTPUT_DIR = 'output/comparison'
ensure_dir(OUTPUT_DIR)
device = get_device()
# NOTE(review): client sampling below uses np.random.choice, so reproducibility
# relies on set_seed also seeding NumPy's global RNG — confirm in src.utils.
set_seed(42)

In [None]:
# Load Data
# CIFAR-100 train/test sets; 10% of the training set is held out for validation.
train_full, test_data = load_cifar100(data_dir='./data', download=True)
train_size = int(0.9 * len(train_full))
val_size = len(train_full) - train_size
train_data, val_data = torch.utils.data.random_split(
    train_full,
    lengths=[train_size, val_size],
    # Fixed generator seed so the train/val split is the same on every run.
    generator=torch.Generator().manual_seed(42),
)

# Evaluation loaders: validation first, then test, both unshuffled.
val_loader, test_loader = (
    create_dataloader(split, batch_size=64, shuffle=False)
    for split in (val_data, test_data)
)

# Configuration (same as dense baseline)
NUM_ROUNDS = 300          # federated training rounds
NUM_CLIENTS = 100         # number of client partitions
CLIENTS_PER_ROUND = 0.1   # fraction of NUM_CLIENTS sampled each round
LOCAL_STEPS = 4           # passed as local_steps to each sparse FedAvg round
LR = 1e-4
WEIGHT_DECAY = 1e-4
EVAL_FREQ = 10            # evaluate on val/test every EVAL_FREQ rounds
SPARSITY = 0.8  # TODO update (passed to create_mask as sparsity_ratio)
CALIB_ROUNDS = 3  # TODO update (number of Fisher calibration rounds)

# Keyword arguments handed to build_model.
model_config = dict(
    model_name='dino_vits16',
    num_classes=100,
    freeze_policy='finetune_all',
    freeze_head=True,
    device=device,
)

# Checkpoint containing the model with the pretrained head.
BASELINE_PATH = 'output/main/pretrained_head.pt'

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

In [None]:
# Partition Data (IID)
# Split the training set into NUM_CLIENTS shards and wrap each in its own loader.
client_datasets = partition_iid(train_data, NUM_CLIENTS)

client_loaders = []
for shard in client_datasets:
    # Positional arguments are presumably (dataset, batch_size=32, shuffle=True,
    # num_workers=0) — TODO confirm against create_dataloader's signature.
    client_loaders.append(create_dataloader(shard, 32, True, 0))

print(f"Created {NUM_CLIENTS} IID client loaders")

In [None]:
# ============================================================
# Step 0: Load Pretrained Head Model
# ============================================================
# Fail before building the model if the checkpoint is absent, and say which
# path was expected, instead of surfacing a bare error from torch.load.
if not os.path.isfile(BASELINE_PATH):
    raise FileNotFoundError(
        f"Pretrained head checkpoint not found at '{BASELINE_PATH}'. "
        "Generate it before running the sparse experiment."
    )

model = build_model(model_config)
model.to(device)
# The checkpoint is a dict; the weights live under 'model_state_dict'.
checkpoint = torch.load(BASELINE_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded pretrained head from {BASELINE_PATH}")

# ============================================================
# Step 1: Multi-Round Fisher Calibration (compute mask inline)
# ============================================================
def _accumulate_fisher(total, update):
    """Add the per-parameter scores in ``update`` into the running sum ``total``.

    Args:
        total: dict of accumulated tensors, or None when nothing has been
            accumulated yet.
        update: dict of tensors with the same keys as ``total``.

    Returns:
        The accumulator. On the first call (``total is None``) this is a dict
        of clones of ``update`` so later in-place additions never modify the
        caller's tensors; afterwards ``total`` is updated in place and returned.
    """
    if total is None:
        return {k: v.clone() for k, v in update.items()}
    for k in total:
        total[k] += update[k]
    return total


# With zero calibration rounds there are no scores to build a mask from;
# report that explicitly rather than failing later on a None accumulator.
if CALIB_ROUNDS < 1:
    raise ValueError(f"CALIB_ROUNDS must be >= 1, got {CALIB_ROUNDS}")

print(f"\n=== Fisher Calibration ({CALIB_ROUNDS} rounds) ===")

# Calibration uses a subset of clients each round
calib_clients_per_round = max(1, int(NUM_CLIENTS * CLIENTS_PER_ROUND))

# Aggregate Fisher scores across calibration rounds
aggregated_fisher = None
for calib_round in range(CALIB_ROUNDS):
    # Select random clients (without replacement)
    selected_clients = np.random.choice(NUM_CLIENTS, calib_clients_per_round, replace=False)

    # Compute Fisher for each selected client and sum the results
    round_fisher = None
    for client_idx in selected_clients:
        client_loader = client_loaders[client_idx]
        fisher = compute_fisher_diagonal(model, client_loader, device, num_batches=5)
        round_fisher = _accumulate_fisher(round_fisher, fisher)

    # Average over selected clients
    for k in round_fisher:
        round_fisher[k] /= len(selected_clients)

    # Aggregate across rounds
    aggregated_fisher = _accumulate_fisher(aggregated_fisher, round_fisher)

    print(f"  Calibration round {calib_round + 1}/{CALIB_ROUNDS} complete")

# Average across calibration rounds
for k in aggregated_fisher:
    aggregated_fisher[k] /= CALIB_ROUNDS

# Create mask using LEAST SENSITIVE parameters (low Fisher = can be pruned)
mask = create_mask(aggregated_fisher, model, sparsity_ratio=SPARSITY, rule='least_sensitive')
actual_sparsity = get_mask_sparsity(mask)
print(f"Created mask with {actual_sparsity*100:.1f}% sparsity (rule: least_sensitive)")

# Save mask for reference
MASK_OUTPUT = os.path.join(OUTPUT_DIR, 'mask_iid_ls.pt')
torch.save(mask, MASK_OUTPUT)
print(f"Saved mask to {MASK_OUTPUT}")

# ============================================================
# Step 2: Sparse Fine-tuning with SparseSGDM
# ============================================================
print(f"\n=== Sparse IID Training ({NUM_ROUNDS} rounds) ===")

# NOTE: Removed initial pruning (apply_mask_to_model). Only gradient masking is used.
# This ensures model starts with valid pretrained weights.

# Training history. Every list has one entry per round; val/test hold None
# for rounds in which no evaluation was run.
history = {'round': [], 'train_acc': [], 'val_acc': [], 'test_acc': []}
best_val_acc = 0.0

# Loop-invariant values, computed once instead of on every round.
num_selected = max(1, int(NUM_CLIENTS * CLIENTS_PER_ROUND))
BEST_MODEL_PATH = os.path.join(OUTPUT_DIR, 'sparse_iid_best.pt')
METRICS_PATH = os.path.join(OUTPUT_DIR, 'sparse_iid_metrics.json')

for round_idx in range(NUM_ROUNDS):
    # Select clients (without replacement)
    selected_clients = np.random.choice(NUM_CLIENTS, num_selected, replace=False)

    # Run sparse federated round (uses SparseSGDM internally).
    # The returned loss is currently not recorded in the history.
    loss, train_acc = run_fedavg_sparse_round(
        model, client_loaders, selected_clients.tolist(),
        lr=LR, weight_decay=WEIGHT_DECAY, device=device,
        local_steps=LOCAL_STEPS, mask=mask
    )

    # Record training metrics every round
    history['round'].append(round_idx + 1)
    history['train_acc'].append(train_acc)

    # Evaluate every EVAL_FREQ rounds and always on the final round
    if (round_idx + 1) % EVAL_FREQ == 0 or round_idx == NUM_ROUNDS - 1:
        val_acc = evaluate(model, val_loader, device)
        test_acc = evaluate(model, test_loader, device)
        history['val_acc'].append(val_acc)
        history['test_acc'].append(test_acc)

        print(f"Round {round_idx + 1}/{NUM_ROUNDS}: Train={train_acc:.2f}%, Val={val_acc:.2f}%, Test={test_acc:.2f}%")

        # Save best model (selected on validation accuracy)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_checkpoint({'model_state_dict': model.state_dict()}, BEST_MODEL_PATH)

        # Persist the history at every evaluation so an interrupted run
        # still leaves the metrics collected so far on disk.
        save_metrics_json(METRICS_PATH, history)
    else:
        history['val_acc'].append(None)
        history['test_acc'].append(None)

# Save final metrics
save_metrics_json(METRICS_PATH, history)
print(f"\n=== Done! Best Val Acc: {best_val_acc:.2f}% ===")