# AML Project 2: Sparse Ablation Study + Extension

This notebook runs all sparse FL ablation studies:

**Ablation 1: Calibration Rounds** (fixed 50% sparsity)
- Tests: 1, 3, 5, 10 calibration rounds

**Ablation 2: Sparsity Ratio** (fixed 3 calibration rounds)  
- Tests: 20%, 50%, 90% sparsity

**Ablation 3 (Extension): Mask Rule Comparison**
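- 5 rules tested at 80% sparsity (a fixed setting, not one of the Ablation 2 levels) with 3 calibration rounds. Here "keep" means the weight is left unmasked and receives updates; masked weights stay at their pretrained values: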
  1. **Least Sensitive** - Keep weights with LOW Fisher scores (paper's approach)
  2. **Most Sensitive** - Keep weights with HIGH Fisher scores
  3. **Lowest Magnitude** - Keep weights with LOW absolute values
  4. **Highest Magnitude** - Keep weights with HIGH absolute values (traditional pruning)
  5. **Random** - Random selection baseline

**Protocol:** Backbone Fine-tuning (`finetune_all` + `freeze_head=True`)  
**Rounds:** 50 per experiment

In [None]:
# Clone Repository & Install Dependencies
# IPython cell: `!` runs a shell command and `%cd` changes the kernel's working
# directory, presumably so the `src.*` imports in the next cell resolve.
!git clone https://github.com/emanueleR3/AML-Project-2.git
%cd AML-Project-2
!pip install -r requirements.txt
# NOTE(review): possibly redundant if requirements.txt already lists these packages — confirm.
!pip install torch torchvision numpy matplotlib tqdm

In [None]:
import sys
import os
import torch
import torch.nn as nn
import numpy as np

# Make the repository root importable. This must run BEFORE the local `src`
# imports below, otherwise it cannot help them resolve.
sys.path.append('.')

from src.utils import set_seed, get_device, ensure_dir, save_metrics_json
from src.data import load_cifar100, create_dataloader, partition_iid
from src.model import build_model
from src.train import evaluate
from src.sparse_fedavg import run_fedavg_sparse_round
from src.masking import compute_fisher_diagonal, create_mask, get_mask_sparsity

# All per-experiment metric JSON files are written under this directory.
OUTPUT_DIR = 'output/sparse_ablation'
ensure_dir(OUTPUT_DIR)
device = get_device()
set_seed(42)

In [None]:
# NOTE: The mask is NOT applied to the model weights up front (i.e., no pruning).
# It only masks the gradients during training (sparse updates), so weights
# outside the mask keep their pretrained values.


In [None]:
# Load Data: CIFAR-100, with 10% of the training split held out for validation.
train_full, test_data = load_cifar100(data_dir='./data', download=True)
val_size = len(train_full) - int(0.9 * len(train_full))
train_size = len(train_full) - val_size

# Fixed-seed generator so the train/val split is identical on every run.
split_generator = torch.Generator().manual_seed(42)
train_data, val_data = torch.utils.data.random_split(
    train_full, [train_size, val_size], generator=split_generator
)

# Evaluation loaders are never shuffled.
val_loader, test_loader = (
    create_dataloader(split, batch_size=64, shuffle=False)
    for split in (val_data, test_data)
)

print(f"Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

In [None]:
# Configuration
NUM_ROUNDS = 50                 # FedAvg rounds per experiment
CALIB_ROUNDS = [1, 3, 5, 10]    # Ablation 1 sweep
SPARSITY_LEVELS = [0.2, 0.5, 0.9]  # Ablation 2 sweep

# Training hyperparams (same as dense baseline)
LR, WEIGHT_DECAY = 1e-4, 1e-4
LOCAL_STEPS = 4
NUM_CLIENTS = 100
CLIENTS_PER_ROUND = 0.1         # fraction of NUM_CLIENTS sampled each round
EVAL_FREQ = 5                   # evaluate val/test every EVAL_FREQ rounds

model_config = dict(
    model_name='dino_vits16',
    num_classes=100,
    freeze_policy='finetune_all',
    freeze_head=True,
    device=device,
)

# Checkpoint expected to hold a 'model_state_dict' entry.
BASELINE_PATH = 'output/main/pretrained_head.pt'

In [None]:
# Partition Data (IID for ablation)
client_datasets = partition_iid(train_data, NUM_CLIENTS)
client_loaders = [create_dataloader(ds, 32, True, 0) for ds in client_datasets]

# Load pretrained model state
if os.path.exists(BASELINE_PATH):
    pretrained_state = torch.load(BASELINE_PATH, map_location=device)['model_state_dict']
    print(f"✓ Loaded pretrained model from {BASELINE_PATH}")
else:
    pretrained_state = None
    print("⚠ No pretrained model found, using random initialization")


def run_sparse_experiment(exp_name, mask, num_rounds=NUM_ROUNDS, seed=42):
    """Run sparse FedAvg with SparseSGDM optimizer on a fresh model.

    A new model is built from ``model_config``. If ``pretrained_state`` is
    set, it is loaded into the model. The model is then trained for
    ``num_rounds`` rounds with ``run_fedavg_sparse_round`` using ``mask``.
    Each round samples ``max(1, int(NUM_CLIENTS * CLIENTS_PER_ROUND))``
    distinct clients. Val/test accuracy is computed every ``EVAL_FREQ``
    rounds and on the final round; other rounds record ``None``.

    Args:
        exp_name: Experiment name; metrics go to ``OUTPUT_DIR/<exp_name>.json``.
        mask: Passed unchanged to ``get_mask_sparsity`` and
            ``run_fedavg_sparse_round`` (format defined by ``src.masking``).
        num_rounds: Number of FedAvg rounds; must be >= 1.
        seed: Passed to ``set_seed`` before the model is built, so each
            experiment starts from the same RNG state regardless of which
            experiments ran before it. ``None`` skips reseeding.

    Returns:
        dict with per-round lists ``round``, ``train_acc``, ``val_acc`` and
        ``test_acc``.

    Raises:
        ValueError: if ``num_rounds < 1`` (there would be no final evaluation).
    """
    if num_rounds < 1:
        raise ValueError(f"num_rounds must be >= 1, got {num_rounds}")

    print(f"\nRunning: {exp_name}")

    # Without reseeding, client sampling depends on how much RNG earlier
    # experiments consumed, so ablation settings would be compared on different
    # client sequences. (Assumes set_seed seeds the NumPy and torch RNGs.)
    if seed is not None:
        set_seed(seed)

    # Build fresh model
    model = build_model(model_config)
    model.to(device)
    if pretrained_state is not None:
        model.load_state_dict(pretrained_state)

    # NOTE: No initial pruning applied here; the mask only restricts updates.
    print(f"  Mask active ratio: {get_mask_sparsity(mask):.2%}")

    clients_per_round = max(1, int(NUM_CLIENTS * CLIENTS_PER_ROUND))
    criterion = nn.CrossEntropyLoss()  # parameter-free; built once, reused for every evaluation
    history = {'round': [], 'train_acc': [], 'val_acc': [], 'test_acc': []}

    for r in range(1, num_rounds + 1):
        selected = np.random.choice(NUM_CLIENTS, clients_per_round, replace=False)

        # Use SparseSGDM via run_fedavg_sparse_round
        _train_loss, acc = run_fedavg_sparse_round(
            model, client_loaders, selected.tolist(),
            lr=LR, weight_decay=WEIGHT_DECAY, device=device,
            local_steps=LOCAL_STEPS, mask=mask
        )

        # Train accuracy is recorded every round.
        history['round'].append(r)
        history['train_acc'].append(acc)

        # Val/test evaluation is expensive: only periodically and on the last round,
        # which guarantees history['test_acc'][-1] is a number.
        if r % EVAL_FREQ == 0 or r == num_rounds:
            val_loss, val_acc = evaluate(model, val_loader, criterion, device, show_progress=False)
            test_loss, test_acc = evaluate(model, test_loader, criterion, device, show_progress=False)
            print(f"  Round {r}/{num_rounds} | Train: {acc:.1f}% | Val: {val_acc:.1f}% | Test: {test_acc:.1f}%")
            history['val_acc'].append(val_acc)
            history['test_acc'].append(test_acc)
        else:
            history['val_acc'].append(None)
            history['test_acc'].append(None)

    save_metrics_json(os.path.join(OUTPUT_DIR, f'{exp_name}.json'), history)
    print(f"  → Final Test Acc: {history['test_acc'][-1]:.2f}%")
    return history


# ============================================
# Shared setup for Ablations 1 and 2
# ============================================
FIXED_SPARSITY = 0.5      # sparsity used throughout Ablation 1
FIXED_CALIB_ROUNDS = 3    # calibration rounds used in Ablation 2
CALIB_SEED = 42           # reseeded before each calibration so settings are comparable


def _build_pretrained_model():
    """Return a fresh ``build_model(model_config)`` on ``device``, with ``pretrained_state`` loaded if set."""
    fresh_model = build_model(model_config)
    fresh_model.to(device)
    if pretrained_state is not None:
        fresh_model.load_state_dict(pretrained_state)
    return fresh_model


def _calibrate_fisher(model, num_rounds, num_batches=50):
    """Return diagonal Fisher scores averaged over ``num_rounds`` calibration rounds.

    Each round draws ``max(1, int(NUM_CLIENTS * CLIENTS_PER_ROUND))`` distinct
    clients with ``np.random.choice``. It sums ``compute_fisher_diagonal`` over
    those clients, covering only parameters with ``requires_grad``, and divides
    by the number of clients. The per-round averages are then averaged over
    rounds.

    Args:
        model: Model whose trainable parameters are scored.
        num_rounds: Number of calibration rounds; must be >= 1.
        num_batches: Forwarded to ``compute_fisher_diagonal``.

    Returns:
        dict mapping parameter name -> tensor shaped like that parameter.

    Raises:
        ValueError: if ``num_rounds < 1``.
    """
    if num_rounds < 1:
        raise ValueError(f"num_rounds must be >= 1, got {num_rounds}")

    clients_per_round = max(1, int(NUM_CLIENTS * CLIENTS_PER_ROUND))
    fisher_sum = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}

    for _ in range(num_rounds):
        selected = np.random.choice(NUM_CLIENTS, clients_per_round, replace=False)
        round_fisher = {n: torch.zeros_like(t) for n, t in fisher_sum.items()}

        for idx in selected:
            client_fisher = compute_fisher_diagonal(model, client_loaders[idx], device, num_batches=num_batches)
            for n in round_fisher:
                # Parameters absent from a client's Fisher dict contribute zero.
                if n in client_fisher:
                    round_fisher[n] += client_fisher[n]

        for n in fisher_sum:
            fisher_sum[n] += round_fisher[n] / len(selected)

    return {n: f / num_rounds for n, f in fisher_sum.items()}


# ============================================
# 1. Calibration Rounds Ablation (fixed 50% sparsity)
# ============================================
print("\n" + "="*60)
print(f"ABLATION 1: Calibration Rounds (sparsity={FIXED_SPARSITY:.0%})")
print("="*60)

calib_results = {}

for num_calib in CALIB_ROUNDS:
    print(f"\n--- Calibration Rounds = {num_calib} ---")

    # Same RNG state for every setting, so the settings differ only in the
    # number of calibration rounds and each run is reproducible on its own.
    set_seed(CALIB_SEED)
    model = _build_pretrained_model()
    cumulative_fisher = _calibrate_fisher(model, num_calib)

    # LEAST SENSITIVE rule: keep (leave trainable) the weights with low Fisher scores
    mask = create_mask(cumulative_fisher, model, sparsity_ratio=FIXED_SPARSITY, rule='least_sensitive')

    history = run_sparse_experiment(f'ablation_calib{num_calib}', mask)
    calib_results[num_calib] = history['test_acc'][-1]

print("\n--- Calibration Rounds Summary ---")
for k, v in calib_results.items():
    print(f"  {k} rounds: {v:.2f}%")
best_calib = max(calib_results, key=calib_results.get)
print(f"  → Best: {best_calib} calibration rounds")


# ============================================
# 2. Sparsity Ratio Ablation (fixed 3 calibration rounds)
# ============================================
print("\n" + "="*60)
print(f"ABLATION 2: Sparsity Ratio (calibration={FIXED_CALIB_ROUNDS})")
print("="*60)

# `model` and `cumulative_fisher` are kept as globals: Ablation 3 reuses them.
set_seed(CALIB_SEED)
model = _build_pretrained_model()
cumulative_fisher = _calibrate_fisher(model, FIXED_CALIB_ROUNDS)

sparsity_results = {}

for sparsity in SPARSITY_LEVELS:
    print(f"\n--- Sparsity = {sparsity:.0%} ---")

    # Create mask with LEAST SENSITIVE rule
    mask = create_mask(cumulative_fisher, model, sparsity_ratio=sparsity, rule='least_sensitive')

    history = run_sparse_experiment(f'ablation_sparsity{int(sparsity*100)}', mask)
    sparsity_results[sparsity] = history['test_acc'][-1]

print("\n--- Sparsity Ratio Summary ---")
for k, v in sparsity_results.items():
    print(f"  {k:.0%}: {v:.2f}%")
best_sparsity = max(sparsity_results, key=sparsity_results.get)
print(f"  → Best: {best_sparsity:.0%} sparsity")


# ============================================
# 3. EXTENSION: Mask Rule Comparison
# ============================================
# 80% sparsity is a fixed choice for this comparison. It is not one of the
# SPARSITY_LEVELS swept in Ablation 2, so it is not taken from best_sparsity.
RULE_SPARSITY = 0.8

print("\n" + "="*60)
print("ABLATION 3: Mask Rule Comparison (Extension)")
print(f"  Using fixed settings: calibration=3, sparsity={RULE_SPARSITY:.0%}")
print("="*60)

# All 5 mask rules to compare. "Keep" = left trainable by the mask.
# NOTE(review): exact semantics live in src.masking.create_mask — confirm.
MASK_RULES = [
    'least_sensitive',    # Paper's approach: keep the LOWEST-Fisher weights
    'most_sensitive',     # Opposite: keep the HIGHEST-Fisher weights
    'lowest_magnitude',   # Keep the weights with the LOWEST |w|
    'highest_magnitude',  # Keep the weights with the HIGHEST |w| (what magnitude pruning keeps)
    'random',             # Baseline: random selection
]

# Reuses `cumulative_fisher` and `model` from Ablation 2 (3 calibration rounds).
rule_results = {}

for rule in MASK_RULES:
    print(f"\n--- Mask Rule: {rule.replace('_', ' ').title()} ---")

    mask = create_mask(cumulative_fisher, model, sparsity_ratio=RULE_SPARSITY, rule=rule)

    history = run_sparse_experiment(f'ablation_rule_{rule}', mask)
    rule_results[rule] = history['test_acc'][-1]

print("\n" + "="*60)
print("MASK RULE COMPARISON SUMMARY")
print("="*60)

# Sort by final test accuracy, best first
sorted_rules = sorted(rule_results.items(), key=lambda x: x[1], reverse=True)
for i, (rule, acc) in enumerate(sorted_rules):
    marker = "★ BEST" if i == 0 else ""
    print(f"  {rule.replace('_', ' ').title():20s}: {acc:.2f}% {marker}")

# Key comparison: the '+' format spec prints the sign for both gains and losses.
ls_acc = rule_results.get('least_sensitive', 0)
rnd_acc = rule_results.get('random', 0)
print(f"\n  Least Sensitive vs Random: {ls_acc - rnd_acc:+.2f} pp")