# AML Project 2: Sparse Non-IID Sweep

This notebook runs the **Sparse** Non-IID experiments using the Scaled Rounds logic.
- **Protocol:** Backbone Fine-tuning (Frozen Head)
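- **Input Mask:** calibrated in this notebook from multi-round Fisher scores on IID client data (least-sensitive rule); `output/masks/mask_fisher.pt` is not loaded here
- **Base Rounds:** 300 (at J = 4 local steps)
- **Scaling Logic:** Constant computation: rounds × J is held fixed at 1200 local steps (R = 1200 / J, i.e. 300, 150 and 75 rounds for J = 4, 8, 16)
- **Config:** Team-distributed: each member sets `TARGET_NC` (or `RUN_ALL_NC = True` to sweep every Nc)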

In [None]:
# Clone Repository & Install Dependencies
!git clone https://github.com/emanueleR3/AML-Project-2.git
%cd AML-Project-2
!pip install -r requirements.txt
!pip install torch torchvision numpy matplotlib tqdm

In [None]:
import sys
import os

# Make the repository root importable. This must run BEFORE the `from src...`
# imports below; appending to sys.path afterwards cannot affect them.
sys.path.append('.')

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from src.utils import set_seed, get_device, ensure_dir, save_metrics_json
from src.data import load_cifar100, create_dataloader, partition_non_iid
from src.model import build_model
from src.train import evaluate
from src.sparse_fedavg import run_fedavg_sparse_round
from src.masking import compute_fisher_diagonal, create_mask, get_mask_sparsity

# Output location for per-experiment metric files, compute device, and a
# global seed for everything that runs before the sweep.
OUTPUT_DIR = 'output/scaled_sparse'
ensure_dir(OUTPUT_DIR)
device = get_device()
set_seed(42)

In [None]:
# Configuration (mirrors the dense baseline so the two are comparable)
NUM_CLIENTS = 100        # total number of simulated clients
CLIENTS_PER_ROUND = 0.1  # fraction of clients sampled in each round
LR = 1e-4
WEIGHT_DECAY = 1e-4
EVAL_FREQ = 10           # evaluate val/test every EVAL_FREQ rounds
SPARSITY = 0.8           # TODO update -- passed to create_mask as sparsity_ratio
CALIB_ROUNDS = 3         # TODO update -- number of Fisher calibration rounds

# Model: DINO ViT-S/16 with 100 output classes; all backbone weights are
# trainable while the classification head is frozen.
model_config = dict(
    model_name='dino_vits16',
    num_classes=100,
    freeze_policy='finetune_all',
    freeze_head=True,
    device=device,
)

# Checkpoint produced by the dense pipeline (loaded if present).
BASELINE_PATH = 'output/main/pretrained_head.pt'

In [None]:
# Load Data
# CIFAR-100, with a fixed 90/10 train/validation split carved out of the
# official training set. The seeded generator makes the split reproducible.
train_full, test_data = load_cifar100(data_dir='./data', download=True)
n_full = len(train_full)
train_size = int(0.9 * n_full)
val_size = n_full - train_size
split_gen = torch.Generator().manual_seed(42)
train_data, val_data = torch.utils.data.random_split(
    train_full, [train_size, val_size], generator=split_gen
)

# Evaluation loaders: batch size 64, fixed order.
val_loader, test_loader = (
    create_dataloader(ds, batch_size=64, shuffle=False) for ds in (val_data, test_data)
)

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

In [None]:
# Load pretrained model
# Fisher averaging below needs at least one calibration round. Fail early with
# a clear message instead of crashing later on an empty accumulator.
if CALIB_ROUNDS < 1:
    raise ValueError(f"CALIB_ROUNDS must be >= 1, got {CALIB_ROUNDS}")

# Use the 'model_state_dict' entry of the checkpoint at BASELINE_PATH when it exists.
if os.path.exists(BASELINE_PATH):
    pretrained_state = torch.load(BASELINE_PATH, map_location=device)['model_state_dict']
    print(f"✓ Loaded pretrained model from {BASELINE_PATH}")
else:
    pretrained_state = None
    print("⚠ No pretrained model found, using random initialization")

FISHER_BATCHES = 50  # num_batches passed to compute_fisher_diagonal for each client

# Calibrate mask with multi-round Fisher on IID data
print(f"\nCalibrating mask with Fisher ({CALIB_ROUNDS} rounds)...")
from src.data import partition_iid

iid_datasets = partition_iid(train_data, NUM_CLIENTS)
iid_loaders = [create_dataloader(ds, 32, True, 0) for ds in iid_datasets]

# Build model for calibration
model = build_model(model_config)
model.to(device)
if pretrained_state is not None:
    model.load_state_dict(pretrained_state)


def _zeros_like_trainable(net):
    """Return {param_name: zero tensor} for every parameter with requires_grad."""
    return {n: torch.zeros_like(p) for n, p in net.named_parameters() if p.requires_grad}


# Multi-round Fisher calibration: each round samples m clients, averages their
# Fisher diagonals, and adds that average to a running total. The total is
# divided by CALIB_ROUNDS at the end, giving the mean of per-round averages.
m = max(1, int(NUM_CLIENTS * CLIENTS_PER_ROUND))
cumulative_fisher = _zeros_like_trainable(model)

for cal_round in range(1, CALIB_ROUNDS + 1):
    print(f"  Calibration round {cal_round}/{CALIB_ROUNDS}")
    selected = np.random.choice(NUM_CLIENTS, m, replace=False)
    round_fisher = _zeros_like_trainable(model)

    for idx in selected:
        client_fisher = compute_fisher_diagonal(
            model, iid_loaders[idx], device, num_batches=FISHER_BATCHES
        )
        # Accumulate trainable parameters only; extra keys in client_fisher are ignored.
        for n, fisher_sum in round_fisher.items():
            if n in client_fisher:
                fisher_sum += client_fisher[n]  # in-place: updates round_fisher[n]

    # Client average for this round, folded into the running total.
    for n, fisher_sum in round_fisher.items():
        cumulative_fisher[n] += fisher_sum / len(selected)

for n in cumulative_fisher:
    cumulative_fisher[n] /= CALIB_ROUNDS

# Create mask with LEAST SENSITIVE rule
# NOTE(review): the message assumes get_mask_sparsity returns the fraction of
# active (kept) weights -- confirm against src.masking.
mask = create_mask(cumulative_fisher, model, sparsity_ratio=SPARSITY, rule='least_sensitive')
print(f"✓ Mask created: {get_mask_sparsity(mask):.2%} active weights")

In [None]:
# ==========================================
# SWEEP CONTROL
# ==========================================
RUN_ALL_NC = False  # True: sweep every Nc in NC_VALUES_ALL; False: run TARGET_NC only
TARGET_NC = 1       # Options: 1, 5, 10, 50
# ==========================================

# Constant computation budget: every J gets the same number of local steps
# (rounds * J), anchored at BASE_ROUNDS rounds of BASE_J local steps.
BASE_J, BASE_ROUNDS = 4, 300
TOTAL_STEPS = BASE_ROUNDS * BASE_J

def get_scaled_rounds(j, total_steps=None):
    """Return the number of FL rounds that keeps rounds * j within the step budget.

    Args:
        j: Local steps per client per round; must be positive.
        total_steps: Budget of local steps. Defaults to the module-level
            ``TOTAL_STEPS``, looked up at call time.

    Returns:
        ``total_steps // j``. When ``j`` does not divide the budget, the result
        is floored, so the run uses slightly fewer than ``total_steps`` steps.

    Raises:
        ValueError: If ``j`` is not positive.
    """
    if j <= 0:
        raise ValueError(f"j must be a positive number of local steps, got {j}")
    if total_steps is None:
        total_steps = TOTAL_STEPS
    return total_steps // j

# Params
NC_VALUES_ALL = [1, 5, 10, 50]
J_VALUES = [4, 8, 16]
SEED = 42  # same value as the global set_seed(42) and the non-IID partition seed

nc_scenarios = NC_VALUES_ALL if RUN_ALL_NC else [TARGET_NC]
# Clients sampled per round; identical for every experiment, so computed once.
m = max(1, int(NUM_CLIENTS * CLIENTS_PER_ROUND))

for nc in nc_scenarios:
    for j in J_VALUES:
        scaled_rounds = get_scaled_rounds(j)
        exp_name = f'sparse_noniid_nc{nc}_j{j}'
        print("\n" + "*" * 50)
        print(f"SPARSE: Nc={nc}, J={j}, Rounds={scaled_rounds}")
        print("*" * 50)

        # A J larger than the step budget gives zero rounds: nothing to train or
        # report, so skip instead of crashing on an empty history.
        if scaled_rounds < 1:
            print(f"⚠ Skipping {exp_name}: J={j} exceeds the budget of {TOTAL_STEPS} local steps")
            continue

        # Reseed per experiment so each (Nc, J) run is reproducible on its own,
        # whichever experiments or calibration ran earlier in this kernel. Client
        # sampling below draws from the global NumPy RNG.
        # NOTE(review): assumes set_seed seeds NumPy and torch -- confirm.
        set_seed(SEED)

        # 1. Partition (Non-IID)
        client_datasets = partition_non_iid(train_data, NUM_CLIENTS, nc, SEED)
        client_loaders = [create_dataloader(ds, 32, True, 0) for ds in client_datasets]

        # 2. Build a fresh model, loading pretrained_state when available
        model = build_model(model_config)
        model.to(device)
        if pretrained_state is not None:
            model.load_state_dict(pretrained_state)

        # NOTE: the mask is deliberately NOT applied to the weights up front;
        # doing so (apply_mask_to_model) caused the 1% accuracy bug. The mask is
        # only passed to run_fedavg_sparse_round below.

        # 3. Federated training with SparseSGDM via run_fedavg_sparse_round
        history = {'round': [], 'train_acc': [], 'val_acc': [], 'test_acc': []}

        for r in range(1, scaled_rounds + 1):
            selected = np.random.choice(NUM_CLIENTS, m, replace=False)

            loss, acc = run_fedavg_sparse_round(
                model, client_loaders, selected.tolist(),
                lr=LR, weight_decay=WEIGHT_DECAY, device=device,
                local_steps=j, mask=mask
            )

            # Train accuracy is recorded every round.
            history['round'].append(r)
            history['train_acc'].append(acc)

            # Val/test evaluation is expensive. Run it every EVAL_FREQ rounds and
            # on the final round; other rounds get None so all lists stay aligned.
            if r % EVAL_FREQ == 0 or r == scaled_rounds:
                val_acc = evaluate(model, val_loader, device)
                test_acc = evaluate(model, test_loader, device)
                print(f"  Round {r}/{scaled_rounds} | Train: {acc:.1f}% | Val: {val_acc:.1f}% | Test: {test_acc:.1f}%")
                history['val_acc'].append(val_acc)
                history['test_acc'].append(test_acc)
            else:
                history['val_acc'].append(None)
                history['test_acc'].append(None)

        # Save. The final round is always evaluated, so test_acc[-1] is a number.
        save_metrics_json(os.path.join(OUTPUT_DIR, f'{exp_name}.json'), history)
        print(f"✓ Final Test Acc: {history['test_acc'][-1]:.2f}%")