# Sparse Non-IID Sweep (Masking Rule Comparison)

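This notebook runs **Sparse FedAvg** on the **3 most interesting Non-IID configurations** (J = local steps per client per round). The number of rounds is scaled as `TOTAL_STEPS // J`, so every configuration uses the same total budget of local steps. The Dense FedAvg accuracy on the same setting is given in parentheses for reference:
1. **(Nc=5, J=4)** - High heterogeneity; Dense FedAvg still works well (81.25%)
2. **(Nc=10, J=8)** - Moderate heterogeneity, more local steps per round (81.19%)
3. **(Nc=5, J=8)** - Degraded case: Dense FedAvg drops to 70.60%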

**Mask rules to compare** (set `MASK_RULE` in the configuration cell; one rule per run, so several rules can be run in parallel):
- `least_sensitive` (default)
- `most_sensitive`
- `lowest_magnitude`
- `highest_magnitude`
- `random`

In [None]:
# Clone Repository & Install Dependencies
# IPython cell: `!` runs a shell command and `%cd` changes the kernel's working
# directory, so the later `from src...` imports are resolved inside the cloned repo.
# NOTE(review): this cell is not idempotent. After a successful clone, re-running it
# makes `git clone` fail (the directory exists), and `%cd AML-Project-2` then looks
# for a nested directory. Restart the kernel or skip this cell instead.
!git clone https://github.com/emanueleR3/AML-Project-2.git
%cd AML-Project-2
!pip install -r requirements.txt
# Presumably a safeguard in case requirements.txt omits any of these packages — TODO confirm.
!pip install torch torchvision numpy matplotlib tqdm

In [None]:
import sys
import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Make the current directory (the repo root after `%cd`) importable *before*
# importing the local `src` package. Appending it after those imports, as
# originally done, cannot help them resolve. The guard avoids piling up
# duplicate entries when the cell is re-run.
if '.' not in sys.path:
    sys.path.append('.')

from src.utils import set_seed, get_device, ensure_dir, save_metrics_json
from src.data import load_cifar100, create_dataloader, partition_non_iid, partition_iid
from src.model import build_model
from src.train import evaluate
from src.sparse_fedavg import run_fedavg_sparse_round
from src.masking import compute_fisher_diagonal, create_mask, get_mask_sparsity

# Shared device for every model/tensor in the notebook, plus a global seed.
device = get_device()
set_seed(42)

In [None]:
# ==========================================
# CONFIGURATION - Change MASK_RULE to parallelize
# ==========================================
# Each run uses exactly one masking rule. Launch several copies of the notebook
# with different MASK_RULE values to compare the rules in parallel.
MASK_RULE = 'most_sensitive'  # Options: 'least_sensitive', 'most_sensitive', 'lowest_magnitude', 'highest_magnitude', 'random'
# ==========================================

# Validate the rule up front. A typo would otherwise only surface when the mask
# is created, i.e. after the (slow) multi-round Fisher calibration.
VALID_MASK_RULES = ('least_sensitive', 'most_sensitive', 'lowest_magnitude', 'highest_magnitude', 'random')
if MASK_RULE not in VALID_MASK_RULES:
    raise ValueError(f"Unknown MASK_RULE {MASK_RULE!r}; expected one of {VALID_MASK_RULES}")

# Fixed hyperparameters (from ablation study)
SPARSITY = 0.5       # passed to create_mask as sparsity_ratio
CALIB_ROUNDS = 3     # Fisher-calibration rounds averaged before building the mask

# FedAvg Settings
NUM_CLIENTS = 100
CLIENTS_PER_ROUND = 0.1   # fraction of NUM_CLIENTS sampled per round (not a count)
LR = 1e-4
WEIGHT_DECAY = 1e-4
EVAL_FREQ = 10            # evaluate every EVAL_FREQ rounds (and always on the last one)

# Scaling Logic
# The total number of local steps is held fixed across configurations:
# rounds = TOTAL_STEPS // J, so a larger J means fewer communication rounds.
BASE_J = 4
BASE_ROUNDS = 200
TOTAL_STEPS = BASE_J * BASE_ROUNDS

def get_scaled_rounds(j, total_steps=None):
    """Return the number of FedAvg rounds to run with ``j`` local steps per round.

    Rounds are scaled so that ``rounds * j`` stays within a fixed budget of
    local steps, which keeps configurations with different J comparable.

    Args:
        j: local steps per client per round; must be a positive integer.
        total_steps: local-step budget. Defaults to the module-level
            ``TOTAL_STEPS`` when omitted.

    Returns:
        ``total_steps // j``. The integer division drops any remainder.

    Raises:
        ValueError: if ``j`` is not positive, or if ``j`` exceeds the budget.
            In the second case zero rounds would be scheduled and nothing
            would ever be trained or evaluated.
    """
    if total_steps is None:
        total_steps = TOTAL_STEPS
    if j <= 0:
        raise ValueError(f"j (local steps) must be positive, got {j}")
    rounds = total_steps // j
    if rounds < 1:
        raise ValueError(f"j={j} exceeds the local-step budget of {total_steps}; no rounds would run")
    return rounds

# Non-IID settings to sweep. Each trailing comment gives the Dense FedAvg
# accuracy reached on the same setting, for reference.
CONFIGURATIONS = [
    dict(nc=5, j=4),    # High heterogeneity, Dense works (81.25%)
    dict(nc=10, j=8),   # Moderate heterogeneity (81.19%)
    dict(nc=5, j=8),    # Degraded case (70.60%)
]

# Keyword configuration handed to build_model for every model in this notebook.
model_config = dict(
    model_name='dino_vits16',
    num_classes=100,
    freeze_policy='finetune_all',
    freeze_head=True,
    device=device,
)

# Starting checkpoint (loaded if present), plus a per-rule output directory so
# that parallel runs with different MASK_RULE values do not overwrite each other.
BASELINE_PATH = 'output/main/pretrained_head.pt'
OUTPUT_DIR = f'output/sparse_noniid_{MASK_RULE}'
ensure_dir(OUTPUT_DIR)

# Run banner.
_bar = '=' * 60
for _text in (
    f"\n{_bar}",
    f"MASK RULE: {MASK_RULE}",
    f"SPARSITY: {SPARSITY*100:.0f}%",
    f"CALIB_ROUNDS: {CALIB_ROUNDS}",
    f"OUTPUT_DIR: {OUTPUT_DIR}",
    _bar,
):
    print(_text)

In [None]:
# Load Data: hold out 10% of the official CIFAR-100 training set for
# validation; the official test set is evaluated untouched.
train_full, test_data = load_cifar100(data_dir='./data', download=True)

n_total = len(train_full)
train_size = int(0.9 * n_total)
val_size = n_total - train_size

# A dedicated, seeded generator makes the split reproducible regardless of the
# global RNG state.
split_rng = torch.Generator().manual_seed(42)
train_data, val_data = torch.utils.data.random_split(
    train_full, [train_size, val_size], generator=split_rng
)

# Evaluation loaders keep a fixed order (no shuffling).
val_loader, test_loader = (
    create_dataloader(split, batch_size=64, shuffle=False)
    for split in (val_data, test_data)
)

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

In [None]:
# Load pretrained model (if a checkpoint exists) so every run starts from the same weights.
if os.path.exists(BASELINE_PATH):
    pretrained_state = torch.load(BASELINE_PATH, map_location=device)['model_state_dict']
    print(f"✓ Loaded pretrained model from {BASELINE_PATH}")
else:
    pretrained_state = None
    print("⚠ No pretrained model found, using random initialization")


def _average_fisher(model, loaders, num_clients, clients_per_round, num_rounds):
    """Average the diagonal Fisher estimate over several rounds of sampled clients.

    In each round, ``clients_per_round`` distinct clients are drawn with the
    global NumPy RNG. Their ``compute_fisher_diagonal`` estimates are averaged,
    and the per-round means are then averaged over ``num_rounds``.

    Returns:
        dict mapping each trainable parameter name to its averaged Fisher tensor.

    Raises:
        ValueError: if ``num_rounds`` < 1, since there would be nothing to average.
    """
    if num_rounds < 1:
        raise ValueError(f"CALIB_ROUNDS must be >= 1, got {num_rounds}")

    cumulative = None
    for cal_round in range(1, num_rounds + 1):
        print(f"  Calibration round {cal_round}/{num_rounds}")
        selected = np.random.choice(num_clients, clients_per_round, replace=False)
        round_fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}

        for idx in selected:
            client_fisher = compute_fisher_diagonal(model, loaders[idx], device, num_batches=50)
            for n in round_fisher:
                # Parameters absent from a client's estimate contribute nothing.
                if n in client_fisher:
                    round_fisher[n] += client_fisher[n]

        for n in round_fisher:
            round_fisher[n] /= len(selected)

        if cumulative is None:
            cumulative = {n: f.clone() for n, f in round_fisher.items()}
        else:
            for n in cumulative:
                cumulative[n] += round_fisher[n]

    for n in cumulative:
        cumulative[n] /= num_rounds
    return cumulative


# Calibrate mask with multi-round Fisher on IID data
print(f"\nCalibrating mask ({CALIB_ROUNDS} rounds, rule: {MASK_RULE})...")

iid_datasets = partition_iid(train_data, NUM_CLIENTS)
iid_loaders = [create_dataloader(ds, 32, True, 0) for ds in iid_datasets]

# Build model for calibration
model = build_model(model_config)
model.to(device)
if pretrained_state is not None:
    model.load_state_dict(pretrained_state)

# round() rather than int(): int() truncates float products such as
# 100 * 0.29 == 28.999..., which would silently drop a client.
m = max(1, round(NUM_CLIENTS * CLIENTS_PER_ROUND))
cumulative_fisher = _average_fisher(model, iid_loaders, NUM_CLIENTS, m, CALIB_ROUNDS)

# Create mask with selected rule
mask = create_mask(cumulative_fisher, model, sparsity_ratio=SPARSITY, rule=MASK_RULE)
print(f"✓ Mask created: {get_mask_sparsity(mask):.2%} active weights (rule: {MASK_RULE})")

In [None]:
# Run the 3 configurations
criterion = nn.CrossEntropyLoss()
_EVAL_KEYS = ('val_acc', 'val_loss', 'test_acc', 'test_loss')

for config in CONFIGURATIONS:
    nc = config['nc']
    j = config['j']
    scaled_rounds = get_scaled_rounds(j)
    exp_name = f'sparse_nc{nc}_j{j}_{MASK_RULE}'

    stars = "*" * 60
    print("\n" + stars)
    print(f"SPARSE: Nc={nc}, J={j}, Rounds={scaled_rounds}, Rule={MASK_RULE}")
    print(stars)

    if scaled_rounds < 1:
        # With zero rounds nothing is evaluated, so there is no accuracy to report.
        print(f"⚠ Skipping {exp_name}: no rounds to run for J={j}")
        continue

    # Re-seed per configuration. Client sampling (and any other seeded RNG use
    # during training) then does not depend on which configurations ran earlier,
    # or on RNG consumed by calibration/mask creation. That consumption differs
    # between MASK_RULE values (e.g. 'random'), which would otherwise break a
    # like-for-like comparison across runs.
    set_seed(42)

    # 1. Partition (Non-IID). The explicit seed keeps partitions identical across runs.
    client_datasets = partition_non_iid(train_data, NUM_CLIENTS, nc, 42)
    client_loaders = [create_dataloader(ds, 32, True, 0) for ds in client_datasets]

    # 2. Build a fresh model from the same starting weights for every configuration.
    model = build_model(model_config)
    model.to(device)
    if pretrained_state is not None:
        model.load_state_dict(pretrained_state)

    # 3. Federated training with the shared sparse mask.
    # round() rather than int(): int() truncates float products such as 100 * 0.29.
    m = max(1, round(NUM_CLIENTS * CLIENTS_PER_ROUND))
    history = {'round': [], 'train_acc': [], 'val_acc': [], 'test_acc': [], 'val_loss': [], 'test_loss': []}
    best_val_acc = 0.0
    test_acc_at_best_val = None

    for r in range(1, scaled_rounds + 1):
        selected = np.random.choice(NUM_CLIENTS, m, replace=False)

        _, acc = run_fedavg_sparse_round(
            model, client_loaders, selected.tolist(),
            lr=LR, weight_decay=WEIGHT_DECAY, device=device,
            local_steps=j, mask=mask
        )

        history['round'].append(r)
        history['train_acc'].append(acc)

        # Evaluate every EVAL_FREQ rounds and always on the last round. Other rounds
        # get None placeholders so every history list stays aligned with 'round'.
        if r % EVAL_FREQ != 0 and r != scaled_rounds:
            for key in _EVAL_KEYS:
                history[key].append(None)
            continue

        val_loss, val_acc = evaluate(model, val_loader, criterion, device, show_progress=False)
        test_loss, test_acc = evaluate(model, test_loader, criterion, device, show_progress=False)
        print(f"  Round {r}/{scaled_rounds} | Train: {acc:.1f}% | Val: {val_acc:.1f}% | Test: {test_acc:.1f}%")
        history['val_acc'].append(val_acc)
        history['val_loss'].append(val_loss)
        history['test_acc'].append(test_acc)
        history['test_loss'].append(test_loss)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Test accuracy of the checkpoint selected by validation accuracy.
            test_acc_at_best_val = test_acc

    history['best_val_acc'] = best_val_acc
    history['test_acc_at_best_val'] = test_acc_at_best_val

    # Save. The last round is always evaluated, so test_acc[-1] is a real value here.
    save_metrics_json(os.path.join(OUTPUT_DIR, f'{exp_name}.json'), history)
    print(f"✓ Saved: {exp_name}.json | Final Test: {history['test_acc'][-1]:.2f}% | Best Val: {best_val_acc:.2f}%")

print(f"\n{'='*60}")
print(f"All experiments complete! Results in: {OUTPUT_DIR}")
print(f"{'='*60}")