# AML Project 2: Sparse Federated Learning 

- Mask Calibration (Fisher Information)
- SparseSGDM Optimizer Verification
- Sparse FedAvg Training (with IID/Non-IID and various Mask Rules)

**Prerequisites:**
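- This notebook assumes the **Central Baseline** notebook has been run and its checkpoint `central_baseline.pt` is available (the mask-calibration cell loads it from `output/main/central_baseline.pt`).
- Alternatively, it can run without the checkpoint, using a freshly built model. Mask calibration still runs, but the sensitivity scores then come from a model without the baseline's training and are likely less informative.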


In [None]:
import os
import sys

# If running on Kaggle, clone repo if not present
if not os.path.exists('AML-Project-2') and not os.path.exists('src'):
    !git clone https://github.com/emanueleR3/AML-Project-2.git
    %cd AML-Project-2

# Install requirements
!pip install -r requirements.txt
!pip install torch torchvision numpy matplotlib tqdm

In [None]:

import os
import sys

# Make the repository root importable BEFORE importing the project package;
# appending it after the `from src ...` imports (as before) had no effect on them.
# The guard avoids piling up duplicate entries when the cell is re-run.
if '.' not in sys.path:
    sys.path.append('.')

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from src.utils import set_seed, get_device, ensure_dir, save_checkpoint, load_checkpoint, save_metrics_json, count_parameters, AverageMeter
from src.data import load_cifar100, create_dataloader, partition_iid, partition_non_iid
from src.model import build_model
from src.train import evaluate
from src.optim import SparseSGDM
from src.sparse_fedavg import client_update_sparse, run_fedavg_sparse_round
from src.masking import compute_sensitivity_scores, create_mask, save_mask, load_mask, get_mask_sparsity

# Setup output dirs: metrics JSON files and calibrated masks are written here.
OUTPUT_DIR = 'outputs'
ensure_dir(OUTPUT_DIR)
device = get_device()
print(f"Device: {device}")

# Global Config, passed to `build_model` and reused for seeding.
config = {
    'model_name': 'dino_vits16',
    'num_classes': 100,
    'freeze_policy': 'head_only',
    'dropout': 0.1,
    'device': device,
    'seed': 42
}

set_seed(config['seed'])

## Load Data
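We load CIFAR-100 once and split its training set 80/20 into train/validation subsets. The training subset is then reused to create the client partitions for each experiment.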

In [None]:

# Load CIFAR-100 from ./data (download=True is passed through to the loader).
train_full, test_data = load_cifar100(data_dir='./data', download=True)

# Split Train/Val (80/20). The split generator is seeded from the global
# config (previously a hard-coded 42), so the validation subset is fixed for a
# given `config['seed']` and stays consistent with the rest of the seeding.
train_size = int(0.8 * len(train_full))
val_size = len(train_full) - train_size
train_data, val_data = torch.utils.data.random_split(
    train_full, [train_size, val_size],
    generator=torch.Generator().manual_seed(config['seed']),
)

# Evaluation loaders; client training loaders are built per experiment from `train_data`.
val_loader = create_dataloader(val_data, batch_size=64, shuffle=False)
test_loader = create_dataloader(test_data, batch_size=64, shuffle=False)

print(f"Train size: {len(train_data)}, Val size: {len(val_data)}, Test size: {len(test_data)}")

# Mask Calibration

In [None]:
# Candidate locations of the centralized-baseline checkpoint: the path this
# notebook has always used, and the notebook's OUTPUT_DIR (where the
# prerequisites say the checkpoint lives). The first existing one is used.
baseline_candidates = [
    'output/main/central_baseline.pt',
    os.path.join(OUTPUT_DIR, 'central_baseline.pt'),
]
baseline_path = next((p for p in baseline_candidates if os.path.exists(p)), None)

model = build_model(config)
model.to(device)

if baseline_path is not None:
    print(f"Loading baseline from {baseline_path}...")
    ckpt = torch.load(baseline_path, map_location=device)
    # Accept either a checkpoint dict holding 'model_state_dict' or a bare state_dict.
    state_dict = ckpt.get('model_state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
    model.load_state_dict(state_dict)
else:
    print("Warning: central_baseline.pt not found. Using random initialization for calibration (suboptimal).")

calib_loader = create_dataloader(train_data, batch_size=32, shuffle=True)

print("Computing sensitivity scores...")
scores = compute_sensitivity_scores(model, calib_loader, device, num_batches=100)

sparsity_ratio = 0.8

rules = ['least_sensitive', 'random', 'highest_magnitude']
masks = {}  # rule name -> mask; consumed by `run_experiment` via `mask_rule`

for rule in rules:
    mask = create_mask(scores, model, sparsity_ratio=sparsity_ratio, rule=rule)
    masks[rule] = mask

    # The filename encodes the ratio actually used (it was hard-coded to 0.8,
    # which would mislabel masks if `sparsity_ratio` were changed).
    path = os.path.join(OUTPUT_DIR, f'mask_{rule}_{sparsity_ratio}.pt')
    save_mask(mask, path)

    # NOTE(review): printed as "Active params" although the helper is named
    # get_mask_sparsity -- confirm whether it returns the kept or the pruned fraction.
    active_ratio = get_mask_sparsity(mask)
    print(f"Rule: {rule} | Active params: {active_ratio:.4f}")

# SparseSGDM Verification


In [None]:
print("--- Testing SparseSGDM ---")
# Tiny layer (10 weights + 1 bias) so the masked/active split is easy to inspect.
dummy_model = nn.Linear(10, 1)

# First 5 weights active (1), last 5 masked (0); the bias is fully active.
dummy_mask = {
    'weight': torch.cat([torch.ones(1, 5), torch.zeros(1, 5)], dim=1),
    'bias': torch.ones(1)
}

optimizer = SparseSGDM(dummy_model.parameters(), lr=0.1, mask=dummy_mask)

initial_weight = dummy_model.weight.data.clone()
initial_bias = dummy_model.bias.data.clone()

# Artificial all-ones gradient: every unmasked entry is expected to move.
dummy_model.weight.grad = torch.ones_like(dummy_model.weight)
dummy_model.bias.grad = torch.ones_like(dummy_model.bias)

optimizer.step()

updated_weight = dummy_model.weight.data
diff = (updated_weight - initial_weight).abs()
bias_diff = (dummy_model.bias.data - initial_bias).abs()

print(f"Initial Weights (first 5 active): {initial_weight[0, :5]}")
print(f"Initial Weights (last 5 masked):  {initial_weight[0, 5:]}")
print("-" * 30)
print(f"Updated Weights (first 5 active): {updated_weight[0, :5]}")
print(f"Updated Weights (last 5 masked):  {updated_weight[0, 5:]}")
print("-" * 30)
print(f"Difference (first 5): {diff[0, :5]}")
print(f"Difference (last 5):  {diff[0, 5:]}")
print(f"Difference (bias):    {bias_diff}")

# Verification Logic: EVERY masked weight must be unchanged and EVERY active
# parameter (the first 5 weights and the unmasked bias) must have moved.
# The previous check used torch.any on the active weights, which passed even
# if only one of them was updated.
masked_frozen = bool(torch.all(diff[0, 5:] == 0))
active_updated = bool(torch.all(diff[0, :5] > 0)) and bool(torch.all(bias_diff > 0))
if masked_frozen and active_updated:
    print("\n✅ PASS: Masked parameters remained unchanged, active parameters updated.")
else:
    print("\n❌ FAIL: Optimization behavior incorrect.")

# Sparse FedAvg Experiments (Real Experiment)

In [None]:
# Common Training Loop Function
def run_experiment(exp_name, is_iid, nc, mask_rule, num_rounds=100, *,
                   num_clients=100, client_fraction=0.1, local_steps=4,
                   lr=0.01, weight_decay=1e-4, batch_size=32, eval_every=2,
                   seed=None):
    """Run one sparse-FedAvg experiment and save its evaluation history.

    Partitions ``train_data`` across ``num_clients`` clients and builds a fresh
    model from ``config``. Each round, it samples clients and trains them with
    ``run_fedavg_sparse_round`` using the pre-computed mask
    ``masks[mask_rule]``. Validation and test accuracy are evaluated every
    ``eval_every`` rounds and on the final round. The history is written to
    ``OUTPUT_DIR/<exp_name>_metrics.json``.

    Args:
        exp_name: Prefix of the metrics JSON file.
        is_iid: If True use ``partition_iid``; otherwise ``partition_non_iid``.
        nc: Classes per client for the non-IID partition (ignored when IID).
        mask_rule: Key into the ``masks`` dict built in the calibration cell.
        num_rounds: Number of communication rounds.
        num_clients: Total number of clients.
        client_fraction: Fraction of clients sampled per round.
        local_steps: Local steps per selected client, passed through.
        lr: Client learning rate, passed through.
        weight_decay: Client weight decay, passed through.
        batch_size: Batch size of each client's dataloader.
        eval_every: Evaluation period in rounds (must be >= 1).
        seed: Seed for ``set_seed`` and client sampling. Defaults to
            ``config['seed']``, so each experiment is reproducible on its own,
            independent of which cells/experiments ran before it.

    Returns:
        dict with equally long lists ``round``, ``val_acc``, ``test_acc``,
        ``train_loss`` and ``train_acc`` (one entry per evaluated round).

    Raises:
        ValueError: if ``nc`` is missing for a non-IID run, ``mask_rule`` has
            no calibrated mask, or ``eval_every`` < 1.
    """
    if not is_iid and nc is None:
        raise ValueError("nc (classes per client) is required for a non-IID partition")
    if mask_rule not in masks:
        raise ValueError(f"Unknown mask_rule {mask_rule!r}; available: {sorted(masks)}")
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    if seed is None:
        seed = config['seed']
    # Reseed so partitions, model init and client sampling do not depend on how
    # much global RNG state earlier cells/experiments consumed.
    set_seed(seed)
    rng = np.random.default_rng(seed)

    if is_iid:
        client_datasets = partition_iid(train_data, num_clients)
    else:
        client_datasets = partition_non_iid(train_data, num_clients, num_classes_per_client=nc)

    client_loaders = [create_dataloader(ds, batch_size=batch_size, shuffle=True) for ds in client_datasets]

    model = build_model(config)
    model.to(device)

    mask = masks[mask_rule]

    # Clients per round: at least one, at most all. round() avoids float
    # truncation such as int(100 * 0.29) == 28.
    m = min(num_clients, max(1, round(num_clients * client_fraction)))

    criterion = nn.CrossEntropyLoss()  # hoisted: reused for every evaluation
    history = {'round': [], 'val_acc': [], 'test_acc': [], 'train_loss': [], 'train_acc': []}

    for r in range(1, num_rounds + 1):
        selected_clients = rng.choice(num_clients, m, replace=False)

        loss, acc = run_fedavg_sparse_round(
            model, client_loaders, selected_clients,
            lr=lr, weight_decay=weight_decay, device=device,
            local_steps=local_steps,
            mask=mask
        )

        # Evaluate every `eval_every` rounds and always on the last round.
        if r % eval_every == 0 or r == num_rounds:
            _, val_acc = evaluate(model, val_loader, criterion, device, show_progress=False)
            _, test_acc = evaluate(model, test_loader, criterion, device, show_progress=False)
            print(f"Round {r}: TrainAcc={acc:.2f}%, ValAcc={val_acc:.2f}%, TestAcc={test_acc:.2f}%")

            history['round'].append(r)
            history['val_acc'].append(val_acc)
            history['test_acc'].append(test_acc)
            # float() keeps the JSON serializable even if 0-d tensors are returned.
            history['train_loss'].append(float(loss))
            history['train_acc'].append(float(acc))
        else:
            print(f"Round {r}: TrainAcc={acc:.2f}%")

    save_metrics_json(os.path.join(OUTPUT_DIR, f'{exp_name}_metrics.json'), history)
    return history

In [None]:
# Experiment 1: IID client partition, 'least_sensitive' mask rule, 100 rounds.
hist_iid = run_experiment(
    'exp_iid_ls',
    is_iid=True,
    nc=None,  # classes-per-client is unused for the IID partition
    mask_rule='least_sensitive',
    num_rounds=100,
)

In [None]:
# Experiment 2: non-IID partition with one class per client (Nc=1),
# 'least_sensitive' mask rule, 100 rounds.
hist_niid_ls = run_experiment(
    'exp_niid_ls',
    is_iid=False,
    nc=1,
    mask_rule='least_sensitive',
    num_rounds=100,
)

In [None]:
# Experiment 3 (extension): non-IID partition with one class per client
# (Nc=1), 'random' mask rule as a comparison point, 100 rounds.
hist_niid_rnd = run_experiment(
    'exp_niid_rnd',
    is_iid=False,
    nc=1,
    mask_rule='random',
    num_rounds=100,
)