# Milestone M5 — Non-IID + Sweep on Nc and J

**Goal**: Study the effect of non-IID data heterogeneity and local steps on FedAvg performance.

## Experiment Configuration
- **K** = 100 clients
- **C** = 0.1 (10 clients per round)
- **Nc** ∈ {1, 5, 10, 50} (classes per client)
- **J** ∈ {4, 8, 16} (local steps)
- Fixed number of rounds (50) for every (Nc, J) pair, for a fair comparison

In [None]:
import os
import json
import copy
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from itertools import product
from tqdm import tqdm

# Import utilities
from src.utils import set_seed, get_device, ensure_dir, save_checkpoint, save_metrics_json
from src.data import load_cifar100, create_dataloader, partition_non_iid, get_transforms
from src.model import build_model
from src.train import evaluate
from src.fedavg import run_fedavg

## 1. Configuration

In [None]:
# Sweep grid: every (Nc, J) combination below is run as its own experiment.
NC_VALUES = [1, 5, 10, 50]   # Classes per client
J_VALUES = [4, 8, 16]        # Local steps

# Settings shared by every experiment in the sweep.
base_config = dict(
    exp_name='fedavg_noniid_sweep',
    seed=42,
    data_dir='./data',
    output_dir='./outputs',

    # Model
    model_name='dino_vits16',
    num_classes=100,
    freeze_policy='head_only',
    dropout=0.0,

    # FL settings
    num_clients=100,          # K
    clients_per_round=0.1,    # C (10 of the 100 clients per round)
    num_rounds=50,            # Same round budget for every (Nc, J) pair

    # Training
    batch_size=32,
    lr=0.01,
    weight_decay=1e-4,
    num_workers=0,
)

# Seed the RNGs before anything else touches them, then pick the device.
set_seed(base_config['seed'])
device = get_device()

num_experiments = len(NC_VALUES) * len(J_VALUES)
print(f"Device: {device}")
print(f"Sweep: Nc ∈ {NC_VALUES}, J ∈ {J_VALUES}")
print(f"Total experiments: {num_experiments}")

## 2. Setup Output Directories

In [None]:
# Directory layout under the configured output root:
#   logs/<exp_name>/  per-experiment metrics and the sweep summary
#   figures/          saved plots
#   checkpoints/      created here (not written to by the cells shown)
output_root = base_config['output_dir']
sweep_dir = os.path.join(output_root, 'logs', base_config['exp_name'])
figures_dir = os.path.join(output_root, 'figures')
checkpoint_dir = os.path.join(output_root, 'checkpoints')

for directory in (sweep_dir, figures_dir, checkpoint_dir):
    ensure_dir(directory)

print(f"Sweep results: {sweep_dir}")
print(f"Figures: {figures_dir}")

## 3. Load Dataset

In [None]:
print("Loading CIFAR-100...")
train_dataset, test_dataset = load_cifar100(data_dir=base_config['data_dir'], image_size=224)

# Hold out 10% of the training set for validation. The split uses its own
# seeded generator, so the same split is produced on every run.
num_train_total = len(train_dataset)
train_size = int(0.9 * num_train_total)
val_size = num_train_total - train_size
split_generator = torch.Generator().manual_seed(base_config['seed'])
train_subset, val_subset = torch.utils.data.random_split(
    train_dataset, [train_size, val_size], generator=split_generator
)

# Evaluation loaders are built once and shared across all experiments.
eval_loader_kwargs = {
    'batch_size': base_config['batch_size'],
    'shuffle': False,
    'num_workers': base_config['num_workers'],
}
val_loader = create_dataloader(val_subset, **eval_loader_kwargs)
test_loader = create_dataloader(test_dataset, **eval_loader_kwargs)

print(f"Train: {len(train_subset)}, Val: {len(val_subset)}, Test: {len(test_dataset)}")

## 4. Run Sweep Experiments

In [None]:
def run_experiment(nc: int, j: int, train_dataset, val_loader, test_loader, base_config, device):
    """
    Run a single FedAvg experiment with given Nc and J.

    Args:
        nc: Number of classes per client, forwarded to ``partition_non_iid``.
        j: Number of local steps, passed to ``run_fedavg`` as
            ``config['local_steps']``.
        train_dataset: Dataset that is partitioned across the clients.
        val_loader: Validation loader forwarded to ``run_fedavg``.
        test_loader: Test loader forwarded to ``run_fedavg``.
        base_config: Shared settings. Must provide ``seed``, ``num_clients``
            and ``batch_size``, plus whatever ``build_model`` and
            ``run_fedavg`` read; ``num_workers`` is optional (defaults to 0).
        device: Device the freshly built model is moved to.

    Returns:
        The history object returned by ``run_fedavg``, with ``'nc'``,
        ``'local_steps'`` and ``'config'`` entries added.
    """
    print(f"\n{'='*60}")
    print(f"Running experiment: Nc={nc}, J={j}")
    print(f"{'='*60}")

    # Reset seed so every (Nc, J) run starts from the same RNG state
    set_seed(base_config['seed'])

    # Partition data (non-IID)
    print(f"Partitioning data: {base_config['num_clients']} clients, {nc} classes each...")
    client_datasets = partition_non_iid(
        train_dataset,
        num_clients=base_config['num_clients'],
        num_classes_per_client=nc,
        seed=base_config['seed']
    )

    # Create client loaders. The worker count follows the shared config so it
    # matches the val/test loaders instead of being hard-coded here.
    num_workers = base_config.get('num_workers', 0)
    client_loaders = [
        create_dataloader(
            ds,
            batch_size=base_config['batch_size'],
            shuffle=True,
            num_workers=num_workers,
        )
        for ds in client_datasets
    ]

    # Build fresh model so no weights carry over between experiments
    model = build_model(base_config)
    model.to(device)

    # Per-experiment config: shared settings plus the two swept values
    config = {
        **base_config,
        'local_steps': j,
        'nc': nc,
    }

    history = run_fedavg(
        global_model=model,
        client_loaders=client_loaders,
        val_loader=val_loader,
        test_loader=test_loader,
        config=config,
        device=device
    )

    # Add experiment info to history so saved results are self-describing
    history['nc'] = nc
    history['local_steps'] = j
    history['config'] = config

    return history

In [None]:
# Run the full (Nc, J) grid, keeping every history in memory and also
# writing each one to its own JSON file as soon as it finishes.
all_results = {}

for nc, j in product(NC_VALUES, J_VALUES):
    exp_key = f"nc{nc}_j{j}"

    # NOTE(review): the full train_dataset is partitioned here, not
    # train_subset, so client data may include samples that are also in
    # val_subset -- confirm whether that overlap is intended.
    history = run_experiment(
        nc=nc,
        j=j,
        train_dataset=train_dataset,
        val_loader=val_loader,
        test_loader=test_loader,
        base_config=base_config,
        device=device,
    )
    all_results[exp_key] = history

    # Dict-valued entries get their keys converted to strings before saving;
    # everything else is passed through unchanged.
    save_data = {}
    for field, value in history.items():
        if isinstance(value, dict):
            value = {str(sub_key): sub_value for sub_key, sub_value in value.items()}
        save_data[field] = value

    exp_path = os.path.join(sweep_dir, f'{exp_key}_metrics.json')
    save_metrics_json(exp_path, save_data)
    print(f"Saved results to {exp_path}")

banner = '=' * 60
print(f"\n{banner}")
print("All experiments completed!")
print(f"{banner}")

## 5. Results Table (Nc × J)

In [None]:
# Results tables (rows: Nc, columns: J).
#
# Every cell is 9 characters wide followed by a '|' separator, for header,
# value and N/A cells alike, so the column separators line up.
header = "Nc\\J   |" + "".join(f"   J={j:2d}   |" for j in J_VALUES)


def _format_cell(value):
    """Return one table cell (with trailing '|'); None is rendered as N/A."""
    if value is None:
        return "    N/A  |"
    return f" {value:6.2f}% |"


def _print_accuracy_table(title, metric_fn):
    """Print one Nc x J table and return its values as a list of rows.

    metric_fn receives the history of a finished experiment and returns the
    number to display. Experiments missing from all_results yield None.
    """
    print("\n" + "="*70)
    print(title)
    print("="*70)
    print(header)
    print("-" * len(header))

    matrix = []
    for nc in NC_VALUES:
        row = []
        for j in J_VALUES:
            history = all_results.get(f"nc{nc}_j{j}")
            row.append(None if history is None else metric_fn(history))
        matrix.append(row)
        print(f"Nc={nc:2d}  |" + "".join(_format_cell(value) for value in row))

    print("="*70)
    return matrix


# Final test accuracy: the last entry of each run's test_acc series.
results_matrix = _print_accuracy_table(
    "RESULTS TABLE: Final Test Accuracy (%)",
    lambda history: history['test_acc'][-1],
)

# Best validation accuracy: the stored 'best_val_acc' if the history has one,
# otherwise the maximum of the val_acc series.
_print_accuracy_table(
    "RESULTS TABLE: Best Validation Accuracy (%)",
    lambda history: history.get('best_val_acc', max(history['val_acc'])),
)

## 6. Visualization

In [None]:
# Plot 1: Test accuracy curves for all experiments (one panel per Nc value,
# one line per J value).
# squeeze=False keeps `axes` a 2-D array even when NC_VALUES has a single
# entry, so indexing by panel below always works.
fig, axes = plt.subplots(1, len(NC_VALUES), figsize=(16, 4), sharey=True, squeeze=False)
axes = axes[0]

colors = plt.cm.viridis(np.linspace(0, 0.8, len(J_VALUES)))

for ax_idx, nc in enumerate(NC_VALUES):
    ax = axes[ax_idx]
    for j_idx, j in enumerate(J_VALUES):
        exp_key = f"nc{nc}_j{j}"
        if exp_key in all_results:
            rounds = all_results[exp_key]['round']
            test_acc = all_results[exp_key]['test_acc']
            ax.plot(rounds, test_acc, label=f'J={j}', color=colors[j_idx], linewidth=2)

    ax.set_title(f'Nc = {nc}', fontsize=12, fontweight='bold')
    ax.set_xlabel('Round', fontsize=10)
    if ax_idx == 0:
        # The y axis is shared, so only the first panel needs the label.
        ax.set_ylabel('Test Accuracy (%)', fontsize=10)
    # Only draw a legend when at least one curve was plotted on this panel.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)

plt.suptitle('FedAvg Test Accuracy: Effect of Nc and J', fontsize=14, fontweight='bold')
plt.tight_layout()

curve_path = os.path.join(figures_dir, 'fedavg_noniid_curves.png')
plt.savefig(curve_path, dpi=150, bbox_inches='tight')
print(f"Saved curves to: {curve_path}")
plt.show()

In [None]:
# Plot 2: Heatmap of final test accuracy (rows: Nc, columns: J)
import numpy as np

# Build matrix. Cells start as NaN so an experiment that is missing from
# all_results is shown as missing rather than as 0% accuracy.
acc_matrix = np.full((len(NC_VALUES), len(J_VALUES)), np.nan)
for nc_idx, nc in enumerate(NC_VALUES):
    for j_idx, j in enumerate(J_VALUES):
        # The key uses the J *value*; j_idx is only the column position.
        exp_key = f"nc{nc}_j{j}"
        if exp_key in all_results:
            acc_matrix[nc_idx, j_idx] = all_results[exp_key]['test_acc'][-1]

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(acc_matrix, cmap='YlGnBu', aspect='auto')

# Labels
ax.set_xticks(range(len(J_VALUES)))
ax.set_xticklabels([f'J={j}' for j in J_VALUES])
ax.set_yticks(range(len(NC_VALUES)))
ax.set_yticklabels([f'Nc={nc}' for nc in NC_VALUES])
ax.set_xlabel('Local Steps (J)', fontsize=12)
ax.set_ylabel('Classes per Client (Nc)', fontsize=12)
ax.set_title('Final Test Accuracy (%)', fontsize=14, fontweight='bold')

# Add values to cells (x is the column index, y is the row index)
for nc_idx in range(len(NC_VALUES)):
    for j_idx in range(len(J_VALUES)):
        cell_value = acc_matrix[nc_idx, j_idx]
        cell_label = 'N/A' if np.isnan(cell_value) else f'{cell_value:.1f}'
        ax.text(j_idx, nc_idx, cell_label,
                ha='center', va='center', color='black', fontsize=11)

plt.colorbar(im, ax=ax, label='Accuracy (%)')
plt.tight_layout()

heatmap_path = os.path.join(figures_dir, 'fedavg_noniid_heatmap.png')
plt.savefig(heatmap_path, dpi=150, bbox_inches='tight')
print(f"Saved heatmap to: {heatmap_path}")
plt.show()

In [None]:
# Plot 3: final test accuracy against Nc (one line per J) and against J
# (one line per Nc), side by side.
def _final_test_acc(nc, j):
    """Return the last test accuracy of run (nc, j), or NaN if it is missing."""
    exp_key = f"nc{nc}_j{j}"
    if exp_key not in all_results:
        return np.nan
    return all_results[exp_key]['test_acc'][-1]


fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: accuracy as a function of Nc, one line per J value.
ax = axes[0]
for j in J_VALUES:
    accs = [_final_test_acc(nc, j) for nc in NC_VALUES]
    ax.plot(NC_VALUES, accs, marker='o', label=f'J={j}', linewidth=2, markersize=8)

ax.set_xlabel('Classes per Client (Nc)', fontsize=12)
ax.set_ylabel('Final Test Accuracy (%)', fontsize=12)
ax.set_title('Effect of Non-IID Severity', fontsize=12, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)
# Log-scaled x axis, with ticks placed and labelled at the swept Nc values.
ax.set_xscale('log')
ax.set_xticks(NC_VALUES)
ax.set_xticklabels(NC_VALUES)

# Right panel: accuracy as a function of J, one line per Nc value.
ax = axes[1]
for nc in NC_VALUES:
    accs = [_final_test_acc(nc, j) for j in J_VALUES]
    ax.plot(J_VALUES, accs, marker='s', label=f'Nc={nc}', linewidth=2, markersize=8)

ax.set_xlabel('Local Steps (J)', fontsize=12)
ax.set_ylabel('Final Test Accuracy (%)', fontsize=12)
ax.set_title('Effect of Local Steps', fontsize=12, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()

effect_path = os.path.join(figures_dir, 'fedavg_noniid_effects.png')
plt.savefig(effect_path, dpi=150, bbox_inches='tight')
print(f"Saved effects plot to: {effect_path}")
plt.show()

## 7. Save Summary Results

In [None]:
# Condense every run to a few headline numbers and write them to one file.
summary = {
    'nc_values': NC_VALUES,
    'j_values': J_VALUES,
    'experiments': {
        exp_key: {
            'nc': history['nc'],
            'local_steps': history['local_steps'],
            'final_test_acc': history['test_acc'][-1],
            # Use the stored best value if present, else the max of the series.
            'best_val_acc': history.get('best_val_acc', max(history['val_acc'])),
            'final_val_acc': history['val_acc'][-1],
        }
        for exp_key, history in all_results.items()
    },
}

summary_path = os.path.join(sweep_dir, 'sweep_summary.json')
save_metrics_json(summary_path, summary)
print(f"Saved summary to: {summary_path}")

closing_rule = "=" * 60
print("\n" + closing_rule)
print("SWEEP COMPLETE!")
print(closing_rule)
print(f"Results saved to: {sweep_dir}")
print(f"Figures saved to: {figures_dir}")