# Milestone M4 — FedAvg IID (Baseline FL)

**Goal**: Implement FedAvg with K=100, C=0.1, J=4 on an IID data distribution.

## Configuration
- **K** = 100 clients
- **C** = 0.1 (fraction of clients sampled per round, i.e. 10 clients per round)
- **J** = 4 local steps

In [None]:
import os
import json
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

# Import utilities
from src.utils import set_seed, get_device, ensure_dir, save_checkpoint, save_metrics_json
from src.data import load_cifar100, create_dataloader, partition_iid
from src.model import build_model
from src.train import evaluate
from src.fedavg import run_fedavg

## 1. Configuration

In [None]:
# Experiment settings for the FedAvg IID baseline, kept in one flat dict
config = dict(
    # Run identity, seeding and filesystem locations
    exp_name='fedavg_iid',
    seed=42,
    data_dir='./data',
    output_dir='./outputs',

    # Model
    model_name='dino_vits16',
    num_classes=100,
    freeze_policy='head_only',
    dropout=0.0,

    # Federated learning schedule
    num_clients=100,          # K
    clients_per_round=0.1,    # C (0.1 of K=100 corresponds to 10 clients per round)
    local_steps=4,            # J
    num_rounds=100,

    # Optimisation and data loading
    batch_size=32,
    lr=0.01,
    weight_decay=1e-4,
    num_workers=0,
)

# Seed first, then resolve the device the rest of the notebook runs on
set_seed(config['seed'])
device = get_device()

# Report the device and the K / C / J triple in one call (one line each)
print(
    f"Device: {device}",
    f"K={config['num_clients']}, C={config['clients_per_round']}, J={config['local_steps']}",
    sep="\n",
)

## 2. Setup Directories

In [None]:
# Resolve every output location beneath the configured output root
_output_root = config['output_dir']
log_dir = os.path.join(_output_root, 'logs', config['exp_name'])
figures_dir = os.path.join(_output_root, 'figures')
checkpoint_dir = os.path.join(_output_root, 'checkpoints')

# Make sure each location exists, in the same order it is reported below
for _target_dir in (log_dir, figures_dir, checkpoint_dir):
    ensure_dir(_target_dir)

print(
    f"Logs: {log_dir}",
    f"Figures: {figures_dir}",
    f"Checkpoints: {checkpoint_dir}",
    sep="\n",
)

## 3. Load and Partition Data (IID)

In [None]:
print("Loading CIFAR-100...")
train_dataset, test_dataset = load_cifar100(data_dir=config['data_dir'], image_size=224)

# Hold out 10% of the training set for validation; the seeded generator keeps
# the split reproducible across runs.
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_subset, val_subset = torch.utils.data.random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(config['seed'])
)

print(f"Train: {len(train_subset)}, Val: {len(val_subset)}, Test: {len(test_dataset)}")

# Partition ONLY the training split across the K clients. Partitioning the full
# `train_dataset` would hand the held-out validation samples to the clients, so
# validation metrics would be measured on data the model was trained on.
# NOTE(review): assumes partition_iid only needs len()/indexing and therefore
# accepts a torch Subset — confirm against src.data.
print(f"\nPartitioning data into {config['num_clients']} clients (IID)...")
client_datasets = partition_iid(train_subset, num_clients=config['num_clients'], seed=config['seed'])

# Report the per-client size of the data that was actually partitioned
print(f"Samples per client: ~{len(train_subset) // config['num_clients']}")

In [None]:
# Create data loaders
# One shuffled loader per client partition. The worker count comes from the
# shared config (previously hard-coded to 0 here) so that all loaders in the
# notebook are controlled by the same setting.
client_loaders = [
    create_dataloader(
        ds,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
    )
    for ds in client_datasets
]

# Evaluation loaders are not shuffled
val_loader = create_dataloader(val_subset, batch_size=config['batch_size'], shuffle=False, num_workers=config['num_workers'])
test_loader = create_dataloader(test_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=config['num_workers'])

print(f"Created {len(client_loaders)} client loaders")
print(f"Val batches: {len(val_loader)}, Test batches: {len(test_loader)}")

## 4. Build Model

In [None]:
from src.utils import count_parameters

# Construct the network from the experiment config and move it to the chosen device
model = build_model(config)
model.to(device)

# Parameter counts: all parameters first, then only the trainable ones
total_params = count_parameters(model, trainable_only=False)
trainable_params = count_parameters(model, trainable_only=True)

print(
    f"Model: {config['model_name']}",
    f"Total parameters: {total_params:,}",
    f"Trainable parameters: {trainable_params:,}",
    sep="\n",
)

## 5. Run FedAvg

In [None]:
# Opening banner: experiment settings framed by two 60-character rules
print(
    f"\n{'='*60}",
    f"Starting FedAvg Training",
    f"K={config['num_clients']}, C={config['clients_per_round']}, J={config['local_steps']}",
    f"Rounds: {config['num_rounds']}",
    f"{'='*60}\n",
    sep="\n",
)

# Run federated training; the returned history is read below and by the
# saving / plotting cells.
history = run_fedavg(
    global_model=model,
    client_loaders=client_loaders,
    val_loader=val_loader,
    test_loader=test_loader,
    config=config,
    device=device,
)

# Closing banner: best validation accuracy and the last recorded test accuracy
print(
    f"\n{'='*60}",
    f"Training complete!",
    f"Best Val Accuracy: {history['best_val_acc']:.2f}%",
    f"Final Test Accuracy: {history['test_acc'][-1]:.2f}%",
    f"{'='*60}",
    sep="\n",
)

## 6. Save Results

In [None]:
# Persist the training history as JSON inside this experiment's log directory
metrics_path = os.path.join(log_dir, 'metrics.json')
save_metrics_json(metrics_path, history)
print(f"Saved metrics to: {metrics_path}")

# Persist the model weights together with the best validation accuracy and config.
# NOTE(review): the file is named "best", but the weights stored are whatever
# `model` holds at this point — confirm run_fedavg leaves the best weights loaded.
checkpoint_path = os.path.join(checkpoint_dir, 'fedavg_iid_best.pt')
save_checkpoint(
    dict(
        model_state_dict=model.state_dict(),
        best_val_acc=history['best_val_acc'],
        config=config,
    ),
    checkpoint_path,
)
print(f"Saved checkpoint to: {checkpoint_path}")

## 7. Visualization

In [None]:
# Draw loss and accuracy curves side by side, one point per recorded round
rounds = history['round']

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# One entry per panel: (y-axis label, title, curves), where each curve is
# (history key, matplotlib line format, legend label).
_panels = [
    ('Loss', 'FedAvg IID: Loss', [
        ('train_loss', 'b-', 'Train'),
        ('val_loss', 'r-', 'Val'),
    ]),
    ('Accuracy (%)', 'FedAvg IID: Accuracy', [
        ('train_acc', 'b-', 'Train'),
        ('val_acc', 'r-', 'Val'),
        ('test_acc', 'g--', 'Test'),
    ]),
]

for ax, (_ylabel, _title, _curves) in zip(axes, _panels):
    for _key, _fmt, _legend in _curves:
        ax.plot(rounds, history[_key], _fmt, label=_legend, linewidth=2)
    ax.set_xlabel('Round', fontsize=12)
    ax.set_ylabel(_ylabel, fontsize=12)
    ax.set_title(_title, fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()

# Write the figure to disk before showing it
figure_path = os.path.join(figures_dir, 'fedavg_iid_curves.png')
plt.savefig(figure_path, dpi=150, bbox_inches='tight')
print(f"Saved figure to: {figure_path}")
plt.show()