# AML Project 2: Milestones 6-8 (Sparse Federated Learning)

This notebook covers the advanced stages of the project:
- **M6**: Mask Calibration (Fisher Information)
- **M7**: SparseSGDM Optimizer Verification
- **M8**: Sparse FedAvg Training (with IID/Non-IID and various Mask Rules)

**Prerequisites:**
- This notebook assumes **M3 (Central Baseline)** has been run and a checkpoint `central_baseline.pt` is available in the outputs directory.
- Alternatively, it can run without the checkpoint. Sensitivity scores are then computed on a freshly built model, so the resulting masks will likely be less informative.

**Note:** This notebook is configured for a **Real Experiment** scenario (100 Clients, 20 Rounds, 4 Local Steps).

In [None]:
# 1. Setup & Dependencies
# Ensure we are in the project root if running locally/on Kaggle
import os
import sys

# If running on Kaggle, clone repo if not present
if not os.path.exists('AML-Project-2') and not os.path.exists('src'):
    !git clone https://github.com/emanueleR3/AML-Project-2.git
    %cd AML-Project-2

# Install requirements
!pip install -r requirements.txt
!pip install torch torchvision numpy matplotlib tqdm

In [None]:
# 2. Imports & Configuration
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy
from pathlib import Path

# Add src to path
sys.path.append('.')

from src.utils import set_seed, get_device, ensure_dir, save_checkpoint, load_checkpoint, save_metrics_json, count_parameters, AverageMeter
from src.data import load_cifar100, create_dataloader, partition_iid, partition_non_iid
from src.model import build_model
from src.train import evaluate, local_train
from src.fedavg import fedavg_aggregate
from src.optim import SparseSGDM
from src.masking import compute_sensitivity_scores, create_mask, save_mask, load_mask, get_mask_sparsity

# Output directory for checkpoints, masks, metrics and figures.
OUTPUT_DIR = 'outputs'
ensure_dir(OUTPUT_DIR)

device = get_device()
print(f"Device: {device}")

# Global configuration consumed by build_model(...) and for seeding.
config = dict(
    model_name='dino_vits16',
    num_classes=100,
    freeze_policy='head_only',
    dropout=0.1,
    device=device,
    seed=42,
)

set_seed(config['seed'])

## Helper Functions for Sparse FedAvg (M8 Logic)

We define `client_update_sparse` and `run_fedavg_sparse_round` here since they integrate `SparseSGDM`.

In [None]:
def client_update_sparse(
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    lr: float,
    weight_decay: float,
    device: torch.device,
    local_steps: int,
    mask: dict,
    criterion=None
):
    """Train a private copy of ``model`` on one client's data with SparseSGDM.

    The incoming model is deep-copied, so the caller's (global) model is not
    modified. The copy is moved to ``device``, switched to train mode, and
    trained for ``local_steps`` via ``local_train`` with a ``SparseSGDM``
    optimizer built over ``get_trainable_params()`` using ``mask``.

    Args:
        model: Global model; left untouched.
        train_loader: DataLoader over this client's samples.
        lr: Learning rate for SparseSGDM.
        weight_decay: Weight decay passed to SparseSGDM together with
            ``apply_wd_to_masked_only=True``.
        device: Device to train on.
        local_steps: Number of local steps forwarded to ``local_train``.
        mask: Gradient mask forwarded to SparseSGDM.
        criterion: Loss function; ``nn.CrossEntropyLoss()`` when None.

    Returns:
        ``(state_dict, avg_loss, avg_acc, n_samples)``: the trained copy's
        ``state_dict()`` followed by the three values returned by ``local_train``.
    """
    loss_fn = nn.CrossEntropyLoss() if criterion is None else criterion

    client_model = copy.deepcopy(model)
    client_model.to(device)
    client_model.train()

    sparse_opt = SparseSGDM(
        client_model.get_trainable_params(),
        lr=lr,
        momentum=0.9,
        weight_decay=weight_decay,
        mask=mask,
        apply_wd_to_masked_only=True,
    )

    avg_loss, avg_acc, n_samples = local_train(
        client_model, train_loader, sparse_opt, loss_fn, device, local_steps
    )
    return client_model.state_dict(), avg_loss, avg_acc, n_samples


def run_fedavg_sparse_round(
    global_model: nn.Module,
    client_loaders: list,
    selected_clients: list,
    lr: float,
    weight_decay: float,
    device: torch.device,
    local_steps: int,
    mask: dict,
    criterion=None
):
    """Run a single FedAvg round with sparse (masked) local training.

    Each selected client trains a copy of ``global_model`` via
    ``client_update_sparse``. The resulting state dicts are then passed to
    ``fedavg_aggregate`` together with each client's ``n_samples``, which
    updates ``global_model``.

    Args:
        global_model: Model that receives the aggregated update.
        client_loaders: Per-client DataLoaders, indexed by client id.
        selected_clients: Client ids participating this round (list or
            NumPy integer array).
        lr, weight_decay, device, local_steps, mask, criterion: Forwarded
            unchanged to ``client_update_sparse``.

    Returns:
        ``(avg_loss, avg_acc)`` over participating clients, each weighted by
        the client's ``n_samples``.

    Raises:
        ValueError: If ``selected_clients`` is empty (nothing to aggregate).
    """
    # Use len() rather than truthiness: selected_clients is typically a NumPy
    # array, whose truth value is ambiguous when it has more than one element.
    if len(selected_clients) == 0:
        raise ValueError(
            "run_fedavg_sparse_round: selected_clients is empty; nothing to aggregate."
        )

    client_state_dicts = []
    client_weights = []
    round_loss = AverageMeter()
    round_acc = AverageMeter()

    for client_idx in selected_clients:
        state_dict, loss, acc, n_samples = client_update_sparse(
            global_model, client_loaders[client_idx], lr, weight_decay,
            device, local_steps, mask, criterion
        )

        client_state_dicts.append(state_dict)
        # Per-client aggregation weight passed to fedavg_aggregate.
        client_weights.append(n_samples)
        round_loss.update(loss, n_samples)
        round_acc.update(acc, n_samples)

    # Aggregate client updates into the global model.
    fedavg_aggregate(global_model, client_state_dicts, client_weights)

    return round_loss.avg, round_acc.avg

## Load Data
We load CIFAR-100 once and reuse it for creating partitions.

In [None]:
# Load full dataset
train_full, test_data = load_cifar100(data_dir='./data', download=True)

# Hold out part of the official training set for validation; the official
# test set stays separate and is only used for reporting.
TRAIN_FRACTION = 0.8
EVAL_BATCH_SIZE = 64

train_size = int(TRAIN_FRACTION * len(train_full))
val_size = len(train_full) - train_size

# The split is seeded from config['seed'] (not a separate literal), so the
# notebook has a single source of randomness to change.
# NOTE(review): random_split returns Subsets of train_full, so val_data shares
# train_full's transform. If load_cifar100 attaches training-time augmentation
# to train_full, validation is measured on augmented images. Confirm.
train_data, val_data = torch.utils.data.random_split(
    train_full, [train_size, val_size],
    generator=torch.Generator().manual_seed(config['seed'])
)

val_loader = create_dataloader(val_data, batch_size=EVAL_BATCH_SIZE, shuffle=False)
test_loader = create_dataloader(test_data, batch_size=EVAL_BATCH_SIZE, shuffle=False)

print(f"Train size: {len(train_data)}, Val size: {len(val_data)}, Test size: {len(test_data)}")

# Milestone 6: Mask Calibration
We estimate per-parameter sensitivity with the diagonal of the Fisher Information and use these scores to build the sparse training masks.

In [None]:
# 1. Load Baseline Model
# Use the central baseline checkpoint when available; otherwise calibrate on a
# freshly built model.
baseline_path = os.path.join(OUTPUT_DIR, 'central_baseline.pt')

model = build_model(config)
model.to(device)

if os.path.exists(baseline_path):
    print(f"Loading baseline from {baseline_path}...")
    ckpt = torch.load(baseline_path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
else:
    print("Warning: central_baseline.pt not found. Using random initialization for calibration (suboptimal).")

# 2. Calibration data: shuffled batches drawn from the training split.
calib_loader = create_dataloader(train_data, batch_size=32, shuffle=True)

# 3. Compute per-parameter sensitivity scores.
print("Computing sensitivity scores...")
NUM_CALIB_BATCHES = 100  # more batches -> a more stable estimate, at higher cost
scores = compute_sensitivity_scores(model, calib_loader, device, num_batches=NUM_CALIB_BATCHES)

# 4. Generate & save one mask per selection rule.
sparsity_ratio = 0.8

rules = ['least_sensitive', 'random', 'highest_magnitude']
masks = {}

for rule in rules:
    mask = create_mask(scores, model, sparsity_ratio=sparsity_ratio, rule=rule)
    masks[rule] = mask

    # Derive the filename from sparsity_ratio so it always matches the ratio
    # actually used to build the mask.
    path = os.path.join(OUTPUT_DIR, f'mask_{rule}_{sparsity_ratio}.pt')
    save_mask(mask, path)

    # NOTE(review): reported as the active (mask == 1) fraction. Confirm that
    # get_mask_sparsity returns that, not the fraction of zeros.
    active_ratio = get_mask_sparsity(mask)
    print(f"Rule: {rule} | Active params: {active_ratio:.4f}")

# Milestone 7: SparseSGDM Verification
We verify that the `SparseSGDM` optimizer correctly updates only the unmasked parameters.
**Expectation**: Masked parameters (mask=0) should NOT change. Active parameters (mask=1) SHOULD change.

In [None]:
print("--- Testing SparseSGDM ---")
# Tiny linear layer (10 weights + 1 bias). The first 5 weights are active
# (mask=1), the last 5 are masked (mask=0), and the bias is fully active.
dummy_model = nn.Linear(10, 1)
dummy_mask = {
    'weight': torch.cat([torch.ones(1, 5), torch.zeros(1, 5)], dim=1), # First 5 active, last 5 masked (0)
    'bias': torch.ones(1)
}

optimizer = SparseSGDM(dummy_model.parameters(), lr=0.1, mask=dummy_mask)

initial_weight = dummy_model.weight.data.clone()

# Artificial all-ones gradient, so every unmasked entry should move on step().
dummy_model.weight.grad = torch.ones_like(dummy_model.weight)
dummy_model.bias.grad = torch.ones_like(dummy_model.bias)

optimizer.step()

# Snapshot after the step (clone so later changes cannot alias into the check).
updated_weight = dummy_model.weight.data.clone()
diff = (updated_weight - initial_weight).abs()

print(f"Initial Weights (first 5 active): {initial_weight[0, :5]}")
print(f"Initial Weights (last 5 masked):  {initial_weight[0, 5:]}")
print("-" * 30)
print(f"Updated Weights (first 5 active): {updated_weight[0, :5]}")
print(f"Updated Weights (last 5 masked):  {updated_weight[0, 5:]}")
print("-" * 30)
print(f"Difference (first 5): {diff[0, :5]}")
print(f"Difference (last 5):  {diff[0, 5:]}")

# Verification: every masked entry must be unchanged AND every active entry
# must have moved. Checking `all` rather than `any` also catches a partially
# applied update.
masked_frozen = bool(torch.all(diff[0, 5:] == 0))
active_updated = bool(torch.all(diff[0, :5] > 0))

if masked_frozen and active_updated:
    print("\n✅ PASS: Masked parameters remained unchanged, active parameters updated.")
else:
    print("\n❌ FAIL: Optimization behavior incorrect.")

# Fail loudly, so Restart & Run All stops before the M8 experiments run with a
# broken optimizer.
assert masked_frozen and active_updated, (
    f"SparseSGDM check failed: masked_frozen={masked_frozen}, active_updated={active_updated}"
)

# Milestone 8: Sparse FedAvg Experiments (Real Experiment)

We run 3 experiments to compare performance. 
**Configuration (Real Experiment):**
- `num_clients`: 100
- `clients_per_round`: 10 (0.1)
- `local_steps`: 4 (J=4)
- `num_rounds`: 20 (20-50 rounds are recommended for convergence; 20 is used here to keep the Kaggle runtime feasible)

**Experiments:**
1. **IID** with `least_sensitive` mask
2. **Non-IID (Nc=1)** with `least_sensitive` mask
3. **Non-IID (Nc=1)** with `random` mask (Extension)

In [None]:
# Common Training Loop Function
def run_experiment(exp_name, is_iid, nc, mask_rule, num_rounds=20):
    print(f"\n>>> Starting Experiment: {exp_name} <<<")
    print(f"Settings: IID={is_iid}, Nc={nc}, Mask={mask_rule}, Rounds={num_rounds}")
    
    # 1. Partition Data (Real Scale: K=100)
    num_clients = 100 
    if is_iid:
        client_datasets = partition_iid(train_data, num_clients)
    else:
        client_datasets = partition_non_iid(train_data, num_clients, nc=nc)
        
    client_loaders = [create_dataloader(ds, batch_size=32, shuffle=True) for ds in client_datasets]
    
    # 2. Setup Model & Mask
    model = build_model(config)
    model.to(device)
    
    # Load appropriate mask
    mask = masks[mask_rule] # From previous cell
    
    # 3. Training Loop
    clients_per_round = 0.1 # 10 clients
    m = max(1, int(num_clients * clients_per_round))
    
    history = {'round': [], 'val_acc': [], 'test_acc': []}
    
    for r in range(1, num_rounds + 1):
        selected_clients = np.random.choice(num_clients, m, replace=False)
        
        loss, acc = run_fedavg_sparse_round(
            model, client_loaders, selected_clients,
            lr=0.01, weight_decay=1e-4, device=device,
            local_steps=4, # J=4 (Real standard)
            mask=mask
        )
        
        # Validate
        if r % 2 == 0 or r == num_rounds: # Eval every 2 rounds
            val_loss, val_acc = evaluate(model, val_loader, nn.CrossEntropyLoss(), device, show_progress=False)
            test_loss, test_acc = evaluate(model, test_loader, nn.CrossEntropyLoss(), device, show_progress=False)
            print(f"Round {r}: TrainAcc={acc:.2f}%, ValAcc={val_acc:.2f}%, TestAcc={test_acc:.2f}%")
            
            history['round'].append(r)
            history['val_acc'].append(val_acc)
            history['test_acc'].append(test_acc)
        else:
            print(f"Round {r}: TrainAcc={acc:.2f}%")
        
    # Save metrics
    save_metrics_json(os.path.join(OUTPUT_DIR, f'{exp_name}_metrics.json'), history)
    return history

In [None]:
# Experiment 1: IID partition, 'least_sensitive' mask rule
hist_iid = run_experiment(
    'exp_iid_ls',
    is_iid=True,
    nc=None,
    mask_rule='least_sensitive',
    num_rounds=20,
)

In [None]:
# Experiment 2: non-IID partition (Nc=1), 'least_sensitive' mask rule
hist_niid_ls = run_experiment(
    'exp_niid_ls',
    is_iid=False,
    nc=1,
    mask_rule='least_sensitive',
    num_rounds=20,
)

In [None]:
# Experiment 3 (extension): non-IID partition (Nc=1), 'random' mask rule
hist_niid_rnd = run_experiment(
    'exp_niid_rnd',
    is_iid=False,
    nc=1,
    mask_rule='random',
    num_rounds=20,
)

## Results Visualization
Comparison of Validation Accuracy across different settings.

In [None]:
# Overlay the validation-accuracy curves of the three sparse FedAvg runs.
curves = [
    (hist_iid, dict(marker='o', label='IID (Least Sens)')),
    (hist_niid_ls, dict(marker='s', label='Non-IID Nc=1 (Least Sens)')),
    (hist_niid_rnd, dict(marker='^', linestyle='--', label='Non-IID Nc=1 (Random)')),
]

fig, ax = plt.subplots(figsize=(10, 6))
for hist, line_style in curves:
    ax.plot(hist['round'], hist['val_acc'], **line_style)

ax.set_title('Sparse FedAvg Performance Comparison')
ax.set_xlabel('Round')
ax.set_ylabel('Validation Accuracy (%)')
ax.grid(True)
ax.legend()
fig.savefig(os.path.join(OUTPUT_DIR, 'sparse_fedavg_comparison.png'))
plt.show()