# AML Project 2: Protocol Step 4 - Scaled Non-IID Sweep

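This notebook runs the Non-IID experiments using the **Backbone Fine-tuning Protocol**. It scales the number of communication rounds with the number of local steps $J$, so that the total number of local steps, $J \times \text{rounds}$, stays constant across runs. The computation budget is therefore the same for every run. It sweeps $N_c \in \{1, 5, 10, 50\}$ classes per client and $J \in \{4, 8, 16\}$ local steps.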
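- **Prerequisite:** `output/main/pretrained_head.pt` (produced by Protocol Step 1, `pretrain_head.ipynb`)
- **Backbone:** Unfrozen (`freeze_policy='finetune_all'`)
- **Classifier:** Frozen (`freeze_head=True`)
- **Base Rounds:** 300 (for $J = 4$)
- **Scaling Logic:** $\text{Rounds} = \lfloor (4 \cdot 300) / J \rfloor = \lfloor 1200 / J \rfloor$, i.e. 300, 150 and 75 rounds for $J = 4, 8, 16$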

In [None]:
# Clone Repository & Install Dependencies
!git clone https://github.com/emanueleR3/AML-Project-2.git
%cd AML-Project-2
!pip install -r requirements.txt
!pip install torch torchvision numpy matplotlib tqdm

In [None]:
# Imports & Setup
import sys
import os

# Make the repository root importable; this must run *before* the `src.*`
# imports below to have any effect. Guarded so re-running the cell does not
# keep appending duplicate entries to sys.path.
if '.' not in sys.path:
    sys.path.append('.')

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

from src.utils import set_seed, get_device, ensure_dir, save_metrics_json
from src.data import load_cifar100, create_dataloader, partition_non_iid
from src.model import build_model
from src.train import evaluate
from src.fedavg import run_fedavg

# Setup output dirs
OUTPUT_DIR = 'output/scaled'
ensure_dir(OUTPUT_DIR)
device = get_device()
print(f"Device: {device}")

# Global seed for this notebook (the per-run seed lives in base_config['seed']).
set_seed(42)

In [None]:
# Load Data: CIFAR-100 resized to 224x224.
print("Loading CIFAR-100...")
train_trainval, test_dataset = load_cifar100(
    data_dir='./data',
    image_size=224,
    download=True,
)

# Deterministic 90/10 train/validation split (seed 42, consistent with the
# other notebooks so every experiment sees the same validation samples).
n_trainval = len(train_trainval)
val_size = n_trainval - int(0.9 * n_trainval)
train_size = n_trainval - val_size
split_rng = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    train_trainval, [train_size, val_size], generator=split_rng
)

# Evaluation loaders keep a fixed sample order (no shuffling).
val_loader, test_loader = (
    create_dataloader(eval_ds, batch_size=64, shuffle=False)
    for eval_ds in (val_dataset, test_dataset)
)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

In [None]:
# Base Configuration for Non-IID Sweep
# FedAvg hyper-parameters shared by every run; the sweep cell overrides
# 'local_steps' (J) and 'num_rounds' for each experiment.
base_config = dict(
    num_clients=100,
    clients_per_round=0.1,   # presumably the fraction of clients sampled per round — confirm in run_fedavg
    local_steps=4,
    num_rounds=300,
    batch_size=64,
    lr=1e-4,                 # Lower LR for backbone
    weight_decay=1e-4,
    seed=42,
    eval_freq=10,
)

# Model: fine-tune the whole backbone while keeping the classifier frozen.
model_config = dict(
    model_name='dino_vits16',
    num_classes=100,
    freeze_policy='finetune_all',  # Unfreeze backbone
    freeze_head=True,              # Freeze classifier
    dropout=0.1,
    device=device,
)

# Checkpoint produced by Protocol Step 1; fail fast before any training starts.
HEAD_PATH = 'output/main/pretrained_head.pt'
if not os.path.exists(HEAD_PATH):
    raise FileNotFoundError(f"Pre-trained head not found at {HEAD_PATH}. Run pretrain_head.ipynb first.")

In [None]:
# Scaled Rounds Logic
# Keep J * rounds fixed at BASE_J * BASE_ROUNDS so that every J in the sweep
# performs the same total number of local steps (constant compute budget).
import warnings

BASE_J = 4
BASE_ROUNDS = 300
TOTAL_STEPS = BASE_J * BASE_ROUNDS  # = 1200


def get_scaled_rounds(j, total_steps=TOTAL_STEPS):
    """Return the number of FedAvg rounds to run with ``j`` local steps.

    Args:
        j: Number of local steps per round (J); must be positive.
        total_steps: Budget of local steps (J * rounds) to keep constant.
            Defaults to TOTAL_STEPS (= BASE_J * BASE_ROUNDS).

    Returns:
        ``total_steps // j``. If ``j`` does not divide ``total_steps``
        exactly, the result is floored and a warning is emitted, because the
        compute budget is then no longer exactly constant across the sweep.

    Raises:
        ValueError: If ``j`` is not positive.
    """
    if j <= 0:
        raise ValueError(f"J must be positive, got {j!r}")
    rounds, remainder = divmod(total_steps, j)
    if remainder:
        warnings.warn(
            f"J={j} does not divide the step budget {total_steps}; using "
            f"{rounds} rounds ({j * rounds} total steps instead of {total_steps})."
        )
    return rounds

# ==========================================
# SWEEP CONTROL - CONFIGURE EXECUTION HERE
# ==========================================
RUN_ALL_NC = False  # Set to True to run ALL scenarios
TARGET_NC = 50       # Use this logic to split work among team members. Options: 1, 5, 10, 50
# ==========================================

# Sweep Parameters
NC_VALUES_ALL = [1, 5, 10, 50]
J_VALUES = [4, 8, 16]

# Reject a typo in TARGET_NC instead of silently running an off-protocol scenario.
if not RUN_ALL_NC and TARGET_NC not in NC_VALUES_ALL:
    raise ValueError(f"TARGET_NC must be one of {NC_VALUES_ALL}, got {TARGET_NC!r}")

if RUN_ALL_NC:
    nc_scenarios = NC_VALUES_ALL
    print("Running FULL sweep for all Nc values.")
else:
    nc_scenarios = [TARGET_NC]
    print(f"Running PARTIAL sweep for Nc={TARGET_NC} only.")


print("\nScaled Rounds Configuration:")
for j in J_VALUES:
    rounds_j = get_scaled_rounds(j)
    print(f"  J={j:2d} → Rounds={rounds_j:3d} | Total Steps={j * rounds_j}")

In [None]:
# Run the sweep: one FedAvg run per (Nc, J) pair, with rounds scaled so that
# J * rounds stays constant (see get_scaled_rounds).

# The pre-trained checkpoint is identical for every run, so read it from disk once.
ckpt = torch.load(HEAD_PATH, map_location=device)

for nc in nc_scenarios:
    for j in J_VALUES:
        scaled_rounds = get_scaled_rounds(j)
        exp_name = f'noniid_nc{nc}_j{j}'
        print(f"\n" + "*"*40)
        print(f"Starting: Nc={nc}, J={j}, Rounds={scaled_rounds}")
        print(f"{'*'*40}")

        # 1. Configure Experiment: only J and the scaled round count change per run.
        sweep_config = base_config.copy()
        sweep_config['local_steps'] = j
        sweep_config['num_rounds'] = scaled_rounds

        # 2. Partition Data (reproducible seed). Client count, batch size and
        # seed are read from the config so they cannot drift out of sync with it.
        client_datasets = partition_non_iid(
            train_dataset, sweep_config['num_clients'], nc, sweep_config['seed']
        )
        client_loaders = [
            create_dataloader(ds, sweep_config['batch_size'], True, 0)
            for ds in client_datasets
        ]

        # 3. Build Fresh Model and load the pre-trained checkpoint into it
        model = build_model(model_config)
        model.to(device)
        model.load_state_dict(ckpt['model_state_dict'])

        # 4. Run Training
        history = run_fedavg(model, client_loaders, val_loader, test_loader, sweep_config, device)

        # 5. Save Results
        save_metrics_json(os.path.join(OUTPUT_DIR, f'{exp_name}.json'), history)
        print(f"✓ Final Test Acc: {history['test_acc'][-1]:.2f}%")