# AML Project 2: Complete Sparse Federated Learning

This notebook runs all sparse FedAvg experiments:

**Phase 1: Ablation Studies (50 rounds each)**
1. Calibration rounds sweep: 1, 3, 5, 10 rounds
2. Sparsity ratio sweep: 60%, 70%, 80%, 90%

**Phase 2: Final Experiments (100 rounds each)**
- All 5 mask rules on Non-IID (Nc=1) data:
  1. Least Sensitive
  2. Most Sensitive
  3. Lowest Magnitude
  4. Highest Magnitude
  5. Random
- Plus: IID + Least Sensitive baseline

**Estimated runtime: ~10.5 hours on T4 GPU**

In [None]:
import os
import sys

# Clone repo if needed
if not os.path.exists('AML-Project-2') and not os.path.exists('src'):
    !git clone https://github.com/emanueleR3/AML-Project-2.git
    %cd AML-Project-2

!pip install -r requirements.txt
!pip install torch torchvision numpy matplotlib tqdm pandas

In [None]:
import sys

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import copy
from tqdm import tqdm

# Make the project root importable *before* the `from src...` imports below;
# appending it after those imports could not help them resolve.
sys.path.append('.')

from src.utils import set_seed, get_device, ensure_dir, save_checkpoint, save_metrics_json, AverageMeter
from src.data import load_cifar100, create_dataloader, partition_iid, partition_non_iid
from src.model import build_model
from src.train import evaluate
from src.optim import SparseSGDM
from src.sparse_fedavg import run_fedavg_sparse_round
from src.masking import compute_sensitivity_scores, compute_fisher_diagonal, create_mask, save_mask, get_mask_sparsity

# Output directory for every artifact of this notebook (metrics JSON, masks, plots)
OUTPUT_DIR = 'output/sparse'
ensure_dir(OUTPUT_DIR)
device = get_device()
print(f"Device: {device}")

# Model / run configuration shared by every experiment below
config = {
    'model_name': 'dino_vits16',
    'num_classes': 100,
    'freeze_policy': 'head_only',
    'dropout': 0.1,
    'device': device,
    'seed': 42
}

set_seed(config['seed'])

## Load Data & Pretrained Model

In [None]:
# Load CIFAR-100 and carve a fixed 90/10 train/validation split out of the
# official training set (the seeded generator gives the same split every run).
train_full, test_data = load_cifar100(data_dir='./data', download=True)

n_full = len(train_full)
train_size = int(0.9 * n_full)
val_size = n_full - train_size
split_gen = torch.Generator().manual_seed(42)
train_data, val_data = torch.utils.data.random_split(
    train_full,
    [train_size, val_size],
    generator=split_gen,
)

# Evaluation loaders are not shuffled.
val_loader, test_loader = (
    create_dataloader(eval_ds, batch_size=64, shuffle=False)
    for eval_ds in (val_data, test_data)
)

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

In [None]:
# Optionally warm-start every experiment from the centralized baseline
# checkpoint; `pretrained_state` stays None when it is missing.
baseline_path = 'output/main/central_baseline.pt'
pretrained_state = None

if not os.path.exists(baseline_path):
    print("⚠ No pretrained model found, using random initialization")
else:
    ckpt = torch.load(baseline_path, map_location=device)
    pretrained_state = ckpt['model_state_dict']
    print(f"✓ Loaded pretrained model from {baseline_path}")

## Helper Functions

In [None]:
def compute_multi_round_fisher(
    model, client_loaders, device,
    num_calibration_rounds=3,
    clients_per_round=0.1,
    local_steps=4,
    lr=0.01,
    num_batches=50,
    calib_sparsity=0.8,
    calib_rule='least_sensitive',
):
    """Multi-round Fisher calibration (Paper [15] Sec. 4.2).

    Estimates a diagonal Fisher score per trainable parameter by averaging the
    client Fisher estimates of several simulated rounds. Between rounds, a
    deep copy of ``model`` is advanced by one sparse FedAvg round using a
    temporary mask built from the Fisher accumulated so far. ``model`` itself
    is never modified.

    Args:
        model: Model to calibrate (deep-copied, left untouched).
        client_loaders: One data loader per client.
        device: Device the model copy is moved to.
        num_calibration_rounds: Number of Fisher-estimation rounds; must be >= 1.
        clients_per_round: Fraction of clients sampled per round (at least one).
        local_steps: Local steps per client in the intermediate FedAvg rounds.
        lr: Learning rate of the intermediate FedAvg rounds.
        num_batches: Batches per client passed to ``compute_fisher_diagonal``.
        calib_sparsity: Sparsity ratio of the temporary mask used between rounds.
        calib_rule: Mask rule of the temporary mask used between rounds.

    Returns:
        Dict mapping parameter name -> Fisher tensor, averaged over rounds.

    Raises:
        ValueError: If ``num_calibration_rounds`` < 1.
    """
    if num_calibration_rounds < 1:
        raise ValueError(
            f"num_calibration_rounds must be >= 1, got {num_calibration_rounds}"
        )

    num_clients = len(client_loaders)
    m = max(1, int(num_clients * clients_per_round))

    cumulative_fisher = None
    # Work on a copy so the caller's model is not trained by the calibration.
    model_copy = copy.deepcopy(model)
    model_copy.to(device)

    for round_idx in range(1, num_calibration_rounds + 1):
        selected = np.random.choice(num_clients, m, replace=False)

        # Average the Fisher estimates of this round's sampled clients.
        round_fisher = {
            n: torch.zeros_like(p)
            for n, p in model_copy.named_parameters()
            if p.requires_grad
        }
        for idx in selected:
            client_fisher = compute_fisher_diagonal(
                model_copy, client_loaders[idx], device, num_batches=num_batches
            )
            for n in round_fisher:
                if n in client_fisher:
                    round_fisher[n] += client_fisher[n]
        for n in round_fisher:
            round_fisher[n] /= len(selected)

        # Accumulate across rounds.
        if cumulative_fisher is None:
            cumulative_fisher = {n: f.clone() for n, f in round_fisher.items()}
        else:
            for n in cumulative_fisher:
                cumulative_fisher[n] += round_fisher[n]

        # Advance the copy with one sparse FedAvg round before the next
        # estimate (skipped after the last round, whose update is never used).
        if round_idx < num_calibration_rounds:
            temp_mask = create_mask(
                cumulative_fisher, model_copy,
                sparsity_ratio=calib_sparsity, rule=calib_rule,
            )
            run_fedavg_sparse_round(
                model_copy, client_loaders, selected.tolist(),
                lr=lr, weight_decay=1e-4, device=device,
                local_steps=local_steps, mask=temp_mask,
            )

    # Turn the sum over rounds into a mean.
    for n in cumulative_fisher:
        cumulative_fisher[n] /= num_calibration_rounds

    return cumulative_fisher


def run_sparse_training(
    client_loaders, mask, num_rounds,
    eval_freq=10, exp_name="exp",
    lr=0.01, local_steps=4, clients_per_round=0.1, weight_decay=1e-4,
):
    """Run sparse FedAvg from a freshly built model.

    The model is built from the global ``config`` and loaded with
    ``pretrained_state`` when that is not None. Each round samples
    ``clients_per_round`` of the clients (at least one, without replacement)
    and runs one ``run_fedavg_sparse_round`` with ``mask``. Every
    ``eval_freq`` rounds, and after the last round, validation and test
    accuracy are computed and logged.

    Args:
        client_loaders: One data loader per client.
        mask: Gradient mask forwarded to ``run_fedavg_sparse_round``.
        num_rounds: Number of federated rounds.
        eval_freq: Evaluate every this many rounds; must be >= 1.
        exp_name: Basename of the metrics JSON written to ``OUTPUT_DIR``.
        lr, local_steps, weight_decay: Forwarded to ``run_fedavg_sparse_round``.
        clients_per_round: Fraction of clients sampled per round.

    Returns:
        History dict with lists 'round', 'train_acc', 'val_acc', 'test_acc'
        (one entry per evaluated round). It is also saved to
        ``OUTPUT_DIR/<exp_name>.json``.

    Raises:
        ValueError: If ``eval_freq`` < 1 (it is used as a modulus).
    """
    if eval_freq < 1:
        raise ValueError(f"eval_freq must be >= 1, got {eval_freq}")

    model = build_model(config)
    model.to(device)
    if pretrained_state is not None:
        model.load_state_dict(pretrained_state)

    num_clients = len(client_loaders)
    m = max(1, int(num_clients * clients_per_round))
    criterion = nn.CrossEntropyLoss()

    history = {'round': [], 'train_acc': [], 'val_acc': [], 'test_acc': []}

    for r in range(1, num_rounds + 1):
        selected = np.random.choice(num_clients, m, replace=False)

        _, acc = run_fedavg_sparse_round(
            model, client_loaders, selected.tolist(),
            lr=lr, weight_decay=weight_decay, device=device,
            local_steps=local_steps, mask=mask
        )

        if r % eval_freq == 0 or r == num_rounds:
            _, val_acc = evaluate(model, val_loader, criterion, device, show_progress=False)
            _, test_acc = evaluate(model, test_loader, criterion, device, show_progress=False)
            print(f"Round {r}/{num_rounds} | Train: {acc:.1f}% | Val: {val_acc:.1f}% | Test: {test_acc:.1f}%")

            history['round'].append(r)
            history['train_acc'].append(acc)
            history['val_acc'].append(val_acc)
            history['test_acc'].append(test_acc)

    save_metrics_json(os.path.join(OUTPUT_DIR, f'{exp_name}.json'), history)
    return history

print("Helper functions defined.")

---

# Phase 1: Ablation Studies (50 rounds each)

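Find the best number of calibration rounds and the best sparsity ratio. Both sweeps use IID clients and the Least Sensitive mask rule, and each configuration is trained for 50 rounds. The calibration sweep fixes sparsity at 80%; the sparsity sweep fixes 3 calibration rounds.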

In [None]:
# IID split of the training set over 100 clients (used by the Phase-1
# ablations and by the IID + Least Sensitive baseline).
num_clients = 100
client_datasets_iid = partition_iid(train_data, num_clients)
client_loaders_iid = [
    create_dataloader(client_ds, batch_size=32, shuffle=True)
    for client_ds in client_datasets_iid
]
print(f"Created {num_clients} IID client loaders")

## 1.1 Calibration Rounds Ablation

In [None]:
# Ablation 1.1: sweep the number of Fisher calibration rounds at a fixed 80%
# sparsity (Least Sensitive rule, IID clients, 50 training rounds each).
CALIB_ROUNDS = [1, 3, 5, 10]
ABLATION_ROUNDS = 50
FIXED_SPARSITY = 0.8

calib_results = {}

for num_calib in CALIB_ROUNDS:
    print(f"\n{'='*50}")
    print(f"Calibration Rounds = {num_calib}")
    print(f"{'='*50}")

    # Reseed per configuration so each run is reproducible on its own instead
    # of depending on RNG draws made by earlier configurations (assumes
    # set_seed seeds NumPy's global RNG, which drives client sampling).
    set_seed(config['seed'])

    # Fresh model for calibration
    model = build_model(config)
    model.to(device)
    if pretrained_state is not None:
        model.load_state_dict(pretrained_state)

    # Multi-round calibration
    fisher = compute_multi_round_fisher(
        model, client_loaders_iid, device,
        num_calibration_rounds=num_calib
    )

    # Create mask
    mask = create_mask(fisher, model, sparsity_ratio=FIXED_SPARSITY, rule='least_sensitive')
    print(f"Mask active ratio: {get_mask_sparsity(mask):.2%}")

    # Reseed again so every configuration trains on the same client-sampling
    # sequence, however many draws its calibration consumed.
    set_seed(config['seed'])
    history = run_sparse_training(
        client_loaders_iid, mask, ABLATION_ROUNDS,
        eval_freq=10, exp_name=f'ablation_calib{num_calib}'
    )

    calib_results[num_calib] = history['test_acc'][-1]
    print(f"→ Final Test Acc: {calib_results[num_calib]:.2f}%")

In [None]:
print("\nCalibration Rounds Ablation Results:")
print("-" * 30)
for n_rounds, final_acc in calib_results.items():
    print(f"  {n_rounds} rounds: {final_acc:.2f}%")

# Setting with the highest final test accuracy (the first one wins ties)
best_calib = max(calib_results.keys(), key=lambda n: calib_results[n])
print(f"\n→ Best: {best_calib} calibration rounds")

## 1.2 Sparsity Ratio Ablation

In [None]:
# Ablation 1.2: sweep the sparsity ratio with a fixed number of calibration
# rounds. The Fisher scores are computed once and reused for every ratio.
SPARSITY_RATIOS = [0.6, 0.7, 0.8, 0.9]
FIXED_CALIB = 3  # fixed (not best_calib) so this sweep is independent of 1.1

sparsity_results = {}

model = build_model(config)
model.to(device)
if pretrained_state is not None:
    model.load_state_dict(pretrained_state)

# Reseed so calibration does not depend on RNG state left by earlier cells
# (assumes set_seed seeds NumPy's global RNG, which drives client sampling).
set_seed(config['seed'])
print(f"Computing Fisher with {FIXED_CALIB} calibration rounds...")
fisher = compute_multi_round_fisher(
    model, client_loaders_iid, device,
    num_calibration_rounds=FIXED_CALIB
)

for sparsity in SPARSITY_RATIOS:
    print(f"\n{'='*50}")
    print(f"Sparsity Ratio = {sparsity:.0%}")
    print(f"{'='*50}")

    mask = create_mask(fisher, model, sparsity_ratio=sparsity, rule='least_sensitive')
    print(f"Mask active ratio: {get_mask_sparsity(mask):.2%}")

    # Same client-sampling sequence for every sparsity ratio.
    set_seed(config['seed'])
    history = run_sparse_training(
        client_loaders_iid, mask, ABLATION_ROUNDS,
        eval_freq=10, exp_name=f'ablation_sparsity{int(sparsity*100)}'
    )

    sparsity_results[sparsity] = history['test_acc'][-1]
    print(f"→ Final Test Acc: {sparsity_results[sparsity]:.2f}%")

In [None]:
print("\nSparsity Ratio Ablation Results:")
print("-" * 30)
for ratio, final_acc in sparsity_results.items():
    print(f"  {ratio:.0%}: {final_acc:.2f}%")

# Ratio with the highest final test accuracy (the first one wins ties)
best_sparsity = max(sparsity_results.keys(), key=lambda s: sparsity_results[s])
print(f"\n→ Best: {best_sparsity:.0%} sparsity")

---

# Phase 2: Final Experiments (100 rounds each)

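Compare all 5 mask rules on Non-IID (Nc=1) data. These runs use the best calibration rounds and sparsity ratio from Phase 1, or 3 rounds and 80% sparsity if Phase 1 was not run. An IID + Least Sensitive run serves as the baseline.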

In [None]:
# Phase-2 hyper-parameters: use the Phase-1 winners when those cells ran in
# this session, otherwise fall back to 3 calibration rounds / 80% sparsity.
_ns = globals()
OPTIMAL_CALIB = _ns.get('best_calib', 3)
OPTIMAL_SPARSITY = _ns.get('best_sparsity', 0.8)

print("Optimal Parameters:")
print(f"  Calibration Rounds: {OPTIMAL_CALIB}")
print(f"  Sparsity Ratio: {OPTIMAL_SPARSITY:.0%}")

# Every mask rule evaluated in the Non-IID comparison
MASK_RULES = [
    'least_sensitive',
    'most_sensitive',
    'lowest_magnitude',
    'highest_magnitude',
    'random',
]

In [None]:
# Create Non-IID partition (Nc=1)
client_datasets_noniid = partition_non_iid(train_data, num_clients, num_classes_per_client=1, seed=42)
client_loaders_noniid = [create_dataloader(ds, batch_size=32, shuffle=True) for ds in client_datasets_noniid]
print(f"Created {num_clients} Non-IID client loaders (Nc=1)")

In [None]:
# Compute Fisher on Non-IID data (needed for sensitivity-based masks)
print("Computing Fisher scores on Non-IID data...")
model = build_model(config)
model.to(device)
if pretrained_state is not None:
    model.load_state_dict(pretrained_state)

fisher_noniid = compute_multi_round_fisher(
    model, client_loaders_noniid, device,
    num_calibration_rounds=OPTIMAL_CALIB
)
print("Fisher computation complete.")

## Experiment 1: IID + Least Sensitive (Baseline)

In [None]:
print("\n" + "="*60)
print("Experiment 1: IID + Least Sensitive Mask")
print("="*60)

# Calibrate on IID
model = build_model(config)
model.to(device)
if pretrained_state is not None:
    model.load_state_dict(pretrained_state)

# Reseed so this experiment is reproducible regardless of earlier cells
# (assumes set_seed seeds NumPy's global RNG, which drives client sampling).
set_seed(config['seed'])
fisher_iid = compute_multi_round_fisher(
    model, client_loaders_iid, device,
    num_calibration_rounds=OPTIMAL_CALIB
)

mask_iid_ls = create_mask(fisher_iid, model, sparsity_ratio=OPTIMAL_SPARSITY, rule='least_sensitive')
save_mask(mask_iid_ls, os.path.join(OUTPUT_DIR, 'mask_iid_ls.pt'))
print(f"Mask active ratio: {get_mask_sparsity(mask_iid_ls):.2%}")

# Same client-sampling sequence as the Non-IID mask-rule runs.
set_seed(config['seed'])
hist_iid_ls = run_sparse_training(
    client_loaders_iid, mask_iid_ls, 100,
    eval_freq=10, exp_name='exp_iid_ls'
)
print(f"\n→ Final Test Accuracy: {hist_iid_ls['test_acc'][-1]:.2f}%")

## Experiments 2-6: Non-IID + All Mask Rules

In [None]:
# Experiments 2-6: one 100-round Non-IID run per mask rule. All runs share
# the same Fisher scores (fisher_noniid) and sparsity ratio.
mask_results = {}

for rule_idx, rule in enumerate(MASK_RULES, start=2):
    print(f"\n{'='*60}")
    print(f"Experiment {rule_idx}: Non-IID (Nc=1) + {rule.replace('_', ' ').title()}")
    print(f"{'='*60}")

    # NOTE(review): `model` is whichever model an earlier cell built last.
    # With pretrained weights those models are presumably identical to the
    # one run_sparse_training builds, so magnitude rules see the same weights.
    # Reseed so the 'random' rule draws a reproducible mask.
    set_seed(config['seed'])
    mask = create_mask(fisher_noniid, model, sparsity_ratio=OPTIMAL_SPARSITY, rule=rule)
    save_mask(mask, os.path.join(OUTPUT_DIR, f'mask_noniid_{rule}.pt'))
    print(f"Mask active ratio: {get_mask_sparsity(mask):.2%}")

    # Reseed so every rule is trained on the same client-sampling sequence
    # and only the mask differs between runs.
    set_seed(config['seed'])
    history = run_sparse_training(
        client_loaders_noniid, mask, 100,
        eval_freq=10, exp_name=f'exp_noniid_{rule}'
    )

    mask_results[rule] = history['test_acc'][-1]
    print(f"\n→ Final Test Accuracy: {mask_results[rule]:.2f}%")

---

# Results Summary

In [None]:
# Consolidated report of every experiment run in this session. Phase-1
# results are optional (Phase 2 falls back to default parameters when the
# ablations are skipped), so they are looked up defensively.
_calib_res = globals().get('calib_results', {})
_best_calib = globals().get('best_calib')
_sparsity_res = globals().get('sparsity_results', {})
_best_sparsity = globals().get('best_sparsity')

print("\n" + "="*60)
print("FINAL RESULTS SUMMARY")
print("="*60)

print("\n--- Ablation Studies (50 rounds) ---")
print("\nCalibration Rounds:")
if not _calib_res:
    print("  (not run)")
for k, v in _calib_res.items():
    marker = "← best" if k == _best_calib else ""
    print(f"  {k} rounds: {v:.2f}% {marker}")

print("\nSparsity Ratios:")
if not _sparsity_res:
    print("  (not run)")
for k, v in _sparsity_res.items():
    marker = "← best" if k == _best_sparsity else ""
    print(f"  {k:.0%}: {v:.2f}% {marker}")

print("\n--- Final Experiments (100 rounds) ---")
print(f"\n  IID + Least Sensitive: {hist_iid_ls['test_acc'][-1]:.2f}%")
print("\n  Non-IID (Nc=1) Mask Rule Comparison:")

# Rank mask rules by final test accuracy (best first, marked with ★).
sorted_masks = sorted(mask_results.items(), key=lambda x: x[1], reverse=True)
for i, (rule, acc) in enumerate(sorted_masks):
    rank = "★" if i == 0 else " "
    print(f"  {rank} {rule.replace('_', ' ').title()}: {acc:.2f}%")

# Headline comparison. The "+" format flag prints the sign itself, so a
# negative gap shows as "-x" rather than "+-x". The line is skipped when
# either run is missing, instead of comparing against a fake 0.
ls_acc = mask_results.get('least_sensitive')
rnd_acc = mask_results.get('random')
if ls_acc is not None and rnd_acc is not None:
    print(f"\n  Least Sensitive vs Random: {ls_acc - rnd_acc:+.2f} pp")

In [None]:
# Save complete summary
summary = {
    'optimal_params': {
        'calibration_rounds': OPTIMAL_CALIB,
        'sparsity_ratio': OPTIMAL_SPARSITY
    },
    'ablation_calibration': calib_results,
    'ablation_sparsity': {str(k): v for k, v in sparsity_results.items()},
    'final_experiments': {
        'iid_ls': hist_iid_ls['test_acc'][-1],
        **{f'noniid_{k}': v for k, v in mask_results.items()}
    }
}

save_metrics_json(os.path.join(OUTPUT_DIR, 'complete_summary.json'), summary)
print(f"\nResults saved to {OUTPUT_DIR}/")

In [None]:
# Visualization
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Plot 1: Calibration rounds
ax1 = axes[0]
ax1.bar(range(len(calib_results)), list(calib_results.values()), color='steelblue')
ax1.set_xticks(range(len(calib_results)))
ax1.set_xticklabels(list(calib_results.keys()))
ax1.set_xlabel('Calibration Rounds')
ax1.set_ylabel('Test Accuracy (%)')
ax1.set_title('Effect of Calibration Rounds')
ax1.grid(axis='y', alpha=0.3)

# Plot 2: Sparsity ratio
ax2 = axes[1]
ax2.plot(list(sparsity_results.keys()), list(sparsity_results.values()), 'o-', color='darkorange', linewidth=2, markersize=8)
ax2.set_xlabel('Sparsity Ratio')
ax2.set_ylabel('Test Accuracy (%)')
ax2.set_title('Effect of Sparsity Ratio')
ax2.grid(True, alpha=0.3)

# Plot 3: Mask rule comparison
ax3 = axes[2]
rules = list(mask_results.keys())
accs = list(mask_results.values())
colors = ['#2ecc71' if r == 'least_sensitive' else '#e74c3c' if r == 'random' else '#3498db' for r in rules]
bars = ax3.bar(range(len(rules)), accs, color=colors)
ax3.set_xticks(range(len(rules)))
ax3.set_xticklabels([r.replace('_', '\n') for r in rules], fontsize=9)
ax3.set_ylabel('Test Accuracy (%)')
ax3.set_title('Mask Rule Comparison (Non-IID)')
ax3.grid(axis='y', alpha=0.3)

# Add value labels
for bar, acc in zip(bars, accs):
    ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, f'{acc:.1f}', ha='center', fontsize=9)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'all_results.pdf'), bbox_inches='tight')
plt.savefig(os.path.join(OUTPUT_DIR, 'all_results.png'), dpi=150, bbox_inches='tight')
plt.show()