# Federated Averaging (FedAvg) baseline on CIFAR-100
- Dirichlet non-IID partitioning
- Partial client participation
- Optional heterogeneity (per-client batch size / epochs / lr)
- Safe, self-contained, single-file version
## Imports and simulation defaults

In [1]:
import copy
import math
import random
import numpy as np
from collections import defaultdict
from typing import Dict, List, Tuple, Optional

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models

# Reproducibility: seed the three RNGs this file draws from -- Python's
# `random` (client sampling, per-client shuffles), NumPy (class shuffles,
# Dirichlet draws) and PyTorch.
rng_seed = 42
random.seed(rng_seed)
np.random.seed(rng_seed)
torch.manual_seed(rng_seed)
# CUDA generators can additionally be seeded explicitly when a GPU is present:
# torch.cuda.manual_seed_all(rng_seed)

# Device helper: honours an explicit preference when it can be satisfied,
# otherwise auto-detects (CUDA first, then MPS, then CPU).
def get_device(prefer: Optional[str] = None) -> torch.device:
    """Pick the torch device to run on.

    Args:
        prefer: ``"cuda"``, ``"mps"`` or ``"cpu"`` (case-insensitive) to
            request a backend. ``None`` or any other value (e.g. ``"auto"``)
            means auto-detect.

    Returns:
        The requested device when it is available (CPU always is); otherwise
        the first available of CUDA, MPS, CPU.
    """
    choice = prefer.strip().lower() if isinstance(prefer, str) else None

    # An explicit CPU request must win even when an accelerator is present;
    # previously "cpu" fell through to auto-detection.
    if choice == "cpu":
        return torch.device("cpu")

    # `torch.backends.mps` does not exist on older PyTorch builds.
    mps_backend = getattr(torch.backends, "mps", None)
    available = {
        "cuda": torch.cuda.is_available(),
        "mps": mps_backend is not None and mps_backend.is_available(),
    }

    # Requested accelerator, if usable.
    if available.get(choice, False):
        return torch.device(choice)

    # Auto-detect (also the fallback when the requested backend is missing).
    for name in ("cuda", "mps"):
        if available[name]:
            return torch.device(name)
    return torch.device("cpu")

## Data: CIFAR-100 loaders (train/test)

In [2]:
def load_cifar100(data_root: str = "./data"):
    """Build the CIFAR-100 train and test datasets, downloading if needed.

    The train split is augmented with a padded random crop and a random
    horizontal flip; both splits are converted to tensors and normalized
    with the same fixed per-channel mean/std constants.

    Args:
        data_root: Directory passed to torchvision as the dataset root.

    Returns:
        ``(train, test)`` tuple of ``datasets.CIFAR100`` objects.
    """
    # Steps shared by both pipelines.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        (0.5071, 0.4867, 0.4408),
        (0.2675, 0.2565, 0.2761),
    )

    # Train-only augmentation.
    augment = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]

    # Keyed by the `train` flag handed to the dataset constructor.
    pipeline_for = {
        True: transforms.Compose([*augment, to_tensor, normalize]),
        False: transforms.Compose([to_tensor, normalize]),
    }

    train, test = (
        datasets.CIFAR100(
            root=data_root,
            train=is_train,
            download=True,
            transform=pipeline_for[is_train],
        )
        for is_train in (True, False)
    )
    return train, test


# %%
# Label accessor: supports datasets exposing either `.targets` or `.labels`
def _get_targets(dataset) -> np.ndarray:
    """Return the dataset's labels as a NumPy array.

    ``dataset.targets`` is used when present and not ``None``; otherwise
    ``dataset.labels`` is tried.

    Raises:
        AttributeError: If neither attribute provides a value.
    """
    for attr_name in ("targets", "labels"):
        found = getattr(dataset, attr_name, None)
        if found is not None:
            return np.array(found)
    raise AttributeError("Dataset has no 'targets' or 'labels'.")

## Dirichlet non-IID split (returns dict: client_id -> list of indices)

In [3]:
def dirichlet_noniid_indices(dataset, num_clients: int, alpha: float, min_per_client: int = 10) -> Dict[int, List[int]]:
    """Partition dataset indices across clients with per-class Dirichlet skew.

    For every class, a Dirichlet(alpha) draw over the clients decides which
    fraction of that class's shuffled samples each client receives. Clients
    that end up with fewer than ``min_per_client`` samples are topped up by
    moving samples from the largest clients, so the partition stays disjoint
    whenever the dataset holds at least ``num_clients * min_per_client``
    samples.

    Uses the global ``np.random`` and ``random`` generators, so seed those
    for reproducible partitions.

    Args:
        dataset: Object whose labels are readable by ``_get_targets``.
        num_clients: Number of partitions to create (>= 1).
        alpha: Dirichlet concentration (> 0); lower means more label skew.
        min_per_client: Minimum number of samples each client should hold.

    Returns:
        Mapping ``client_id -> list of dataset indices`` for ids
        ``0 .. num_clients - 1``; each list is shuffled.

    Raises:
        ValueError: If ``num_clients < 1``, ``alpha <= 0`` or the dataset
            has no labels.
    """
    if num_clients < 1:
        raise ValueError("num_clients must be at least 1.")
    if alpha <= 0:
        raise ValueError("alpha must be positive.")

    y = _get_targets(dataset)
    if y.size == 0:
        raise ValueError("Dataset has no samples to partition.")
    num_classes = int(y.max()) + 1

    # Shuffle all classes first, then draw proportions, so the order in which
    # the global NumPy RNG is consumed stays fixed.
    idx_by_class = {c: np.where(y == c)[0] for c in range(num_classes)}
    for idx_c in idx_by_class.values():
        np.random.shuffle(idx_c)

    # Each class's indices are split among clients with Dirichlet(alpha).
    client_indices: List[List[int]] = [[] for _ in range(num_clients)]
    for c in range(num_classes):
        idx_c = idx_by_class[c]
        if len(idx_c) == 0:
            continue
        p = np.random.dirichlet([alpha] * num_clients)
        # Proportions -> integer cut points; the last cumulative value is
        # dropped so np.split yields exactly num_clients shards.
        cuts = (np.cumsum(p) * len(idx_c)).astype(int)[:-1]
        for shard_owner, shard in zip(client_indices, np.split(idx_c, cuts)):
            shard_owner.extend(shard.tolist())

    # Enforce the per-client minimum.
    for i in range(num_clients):
        mine = client_indices[i]

        # Preferred: take surplus from the currently largest client, never
        # pushing the donor below the minimum. No sample is duplicated.
        while len(mine) < min_per_client:
            donor = max(range(num_clients), key=lambda j: len(client_indices[j]))
            spare = len(client_indices[donor]) - min_per_client
            if spare <= 0:
                break  # nobody has surplus: the dataset is too small
            take = min(spare, min_per_client - len(mine))
            mine.extend(client_indices[donor][-take:])
            del client_indices[donor][-take:]

        # Last resort when the dataset cannot cover every client: borrow
        # samples the client does not already hold (these overlap with other
        # clients), capped at what is actually available.
        shortfall = min_per_client - len(mine)
        if shortfall > 0:
            held = set(mine)
            candidates = [j for j in range(len(y)) if j not in held]
            take = min(shortfall, len(candidates))
            if take > 0:
                mine.extend(np.random.choice(candidates, size=take, replace=False).tolist())

    # Shuffle each client's order so batches are not grouped by class.
    for indices in client_indices:
        random.shuffle(indices)
    return dict(enumerate(client_indices))

## Model: ResNet-18 with a CIFAR-100 classification head, plus evaluation and local-training helpers

In [4]:
def build_model(num_classes: int = 100) -> nn.Module:
    """Create a torchvision ResNet-18 whose classifier has ``num_classes`` outputs.

    Args:
        num_classes: Output width of the replacement fully connected layer.

    Returns:
        The ResNet-18 module with a new ``nn.Linear`` as its ``fc`` layer.
    """
    # weights=None: no pretrained checkpoint, so nothing has to be downloaded.
    net = models.resnet18(weights=None)
    # The torchvision architecture is kept as-is even though CIFAR images are
    # 3x32x32; only the final FC layer is swapped, reusing its input width.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


# %%
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Tuple[float, float]:
    """Compute mean cross-entropy loss and accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and gradients are disabled.

    Args:
        model: Network producing class logits.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``, both averaged per sample; ``(0.0, 0.0)``
        when the loader yields nothing.
    """
    loss_fn = nn.CrossEntropyLoss()
    seen = 0
    hits = 0
    running_loss = 0.0

    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            n_batch = inputs.size(0)
            # The criterion returns the batch mean; re-weight by batch size
            # so the final figure is a true per-sample average.
            running_loss += loss_fn(outputs, labels).item() * n_batch
            hits += (outputs.argmax(dim=1) == labels).sum().item()
            seen += n_batch

    # Guard against division by zero on an empty loader.
    denom = max(1, seen)
    return running_loss / denom, hits / denom


# %%
def local_train(
    global_model: nn.Module,
    subset: Subset,
    device: torch.device,
    epochs: int = 1,
    batch_size: int = 64,
    lr: float = 0.01,
    num_workers: int = 2,
) -> Dict[str, torch.Tensor]:
    """Train a private copy of the global model on one client's data.

    Args:
        global_model: Current global model; it is deep-copied, never mutated.
        subset: The client's dataset (anything ``DataLoader`` accepts).
        device: Device the local copy is trained on.
        epochs: Local epochs; values below 1 are treated as 1.
        batch_size: Mini-batch size for the local loader.
        lr: SGD learning rate (momentum 0.9, weight decay 5e-4).
        num_workers: Worker processes for the local ``DataLoader``.

    Returns:
        The trained ``state_dict`` with every tensor detached and on the CPU.
        A client without data gets a CPU copy of the global weights.
    """
    # No data means nothing to train; hand back an independent copy of the
    # global weights instead of building a loader over an empty dataset.
    if len(subset) == 0:
        return {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}

    model = copy.deepcopy(global_model).to(device)
    loader = DataLoader(subset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=False)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

    # In train mode BatchNorm cannot compute batch statistics from a single
    # value per channel, so a leftover batch of one sample (client size
    # % batch_size == 1) would abort the whole round. Such batches are skipped.
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    has_batchnorm = any(isinstance(m, batchnorm_types) for m in model.modules())

    model.train()
    for _ in range(max(1, epochs)):
        for x, y in loader:
            if has_batchnorm and x.size(0) < 2:
                continue
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

    # Return CPU weights to simulate the uplink to the server.
    return {k: v.detach().cpu() for k, v in model.state_dict().items()}

## Weighted model averaging (FedAvg) with checks

In [5]:
def average_weights(weight_list, sizes):
    """Size-weighted average of client state dicts (FedAvg aggregation).

    Args:
        weight_list: Sequence of ``state_dict``-like mappings sharing the
            keys of the first entry.
        sizes: Per-client sample counts, aligned with ``weight_list``; each
            client's weight is ``size / sum(sizes)``.

    Returns:
        A new dict with the keys of ``weight_list[0]``. Floating-point
        tensors are averaged and keep their original dtype; non-floating
        tensors (e.g. integer counters) are copied from the last client.

    Raises:
        ValueError: If no weights are given, the lengths differ, a size is
            negative, or the sizes sum to zero.
    """
    if not weight_list:
        raise ValueError("No client weights provided.")
    if len(weight_list) != len(sizes):
        raise ValueError("weights and sizes mismatch")
    if any(s < 0 for s in sizes):
        raise ValueError("sizes must be non-negative")
    total = float(sum(sizes))
    if total <= 0:
        raise ValueError("sizes must sum to a positive value")

    coeffs = [s / total for s in sizes]
    avg = {}
    for k, ref in weight_list[0].items():
        if not ref.dtype.is_floating_point:
            # Integer buffers cannot be meaningfully averaged; keep the
            # value reported by the last client.
            avg[k] = weight_list[-1][k].clone()
            continue

        # Accumulate in at least float32 so half-precision weights do not
        # lose accuracy and float64 weights are not downcast, then restore
        # the original dtype.
        acc_dtype = torch.promote_types(ref.dtype, torch.float32)
        acc = torch.zeros_like(ref, dtype=acc_dtype)
        for wi, coeff in zip(weight_list, coeffs):
            acc += wi[k].to(acc_dtype) * coeff
        avg[k] = acc.to(ref.dtype)

    return avg

# %%
def federated_training(
    train_dataset,
    test_dataset,
    partitions: Dict[int, List[int]],
    rounds: int = 10,
    local_epochs: int = 1,
    device: str = "cpu",
    q: float = 1.0,  # participation rate per round
    num_classes: int = 100,
    batch_size: int = 64,
    lr: float = 0.01,
    hetero_profiles: Optional[Dict[int, Dict[str, float]]] = None,  # per-client overrides
):
    """Train a global model with FedAvg over ``rounds`` communication rounds.

    Each round samples a fraction ``q`` of the clients, trains a copy of the
    global model on every sampled client's partition, replaces the global
    weights with the size-weighted average and evaluates on the test set.

    Args:
        train_dataset: Dataset indexed by the partition index lists.
        test_dataset: Dataset used for the per-round global evaluation.
        partitions: Mapping ``client_id -> list of train indices``.
        rounds: Number of communication rounds.
        local_epochs: Default local epochs per client.
        device: Device preference forwarded to ``get_device``.
        q: Fraction of clients sampled per round; at least one and at most
            all clients are selected.
        num_classes: Output size of the model built by ``build_model``.
        batch_size: Default local batch size.
        lr: Default local learning rate.
        hetero_profiles: Optional per-client overrides with any of the keys
            ``"epochs"``, ``"batch_size"``, ``"lr"``.

    Returns:
        ``(global_model, history)`` where ``history`` holds the lists
        ``"round"``, ``"test_loss"``, ``"test_acc"`` and ``"selected"``.

    Raises:
        ValueError: If ``partitions`` is empty.
    """
    # Sample from the ids that actually exist, so partitions need not be
    # keyed 0..n-1. Sorting makes the draw independent of dict order and
    # identical to sampling from range(n) for the usual 0..n-1 keys.
    client_ids = sorted(partitions)
    num_clients = len(client_ids)
    if num_clients == 0:
        raise ValueError("No clients available for federated training.")
    device = get_device(device)

    # Build global model
    global_model = build_model(num_classes=num_classes).to(device)

    # Test loader (global test set)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2)

    history = {"round": [], "test_loss": [], "test_acc": [], "selected": []}

    # Clamp to [1, num_clients]: q > 1 would otherwise ask random.sample for
    # more clients than exist.
    m = min(num_clients, max(1, math.ceil(q * num_clients)))

    for r in range(1, rounds + 1):
        # Client sampling
        selected = sorted(random.sample(client_ids, m))

        # Local updates
        local_weights = []
        local_sizes = []
        for cid in selected:
            idxs = partitions[cid]
            if len(idxs) == 0:
                # A client without data has nothing to contribute this round.
                continue
            subset = Subset(train_dataset, idxs)

            # Per-client heterogeneity overrides if provided
            ep = local_epochs
            bs = batch_size
            lr_i = lr
            if hetero_profiles and cid in hetero_profiles:
                profile = hetero_profiles[cid]
                ep = int(profile.get("epochs", ep))
                bs = int(profile.get("batch_size", bs))
                lr_i = float(profile.get("lr", lr_i))

            wi = local_train(global_model, subset, device, epochs=ep, batch_size=bs, lr=lr_i)
            local_weights.append(wi)
            local_sizes.append(len(subset))

        # FedAvg; when every sampled client was empty the global model is
        # simply carried over to the next round.
        if local_weights:
            new_state = average_weights(local_weights, local_sizes)
            global_model.load_state_dict(new_state)

        # Evaluate global model
        test_loss, test_acc = evaluate(global_model, test_loader, device)
        history["round"].append(r)
        history["test_loss"].append(test_loss)
        history["test_acc"].append(test_acc)
        history["selected"].append(selected)

        print(f"[Round {r:03d}] Test loss: {test_loss:.4f} | Test acc: {test_acc*100:.2f}% | Clients: {selected}")

    return global_model, history

## Final: configure and run the full federated experiment

In [None]:
if __name__ == "__main__":
    # End-to-end run: load CIFAR-100, split it across clients with a
    # Dirichlet non-IID partition, then train with FedAvg.

    # --- Config ---
    number_of_clients = 10
    number_of_rounds = 20
    local_epochs = 1            # start small for smoke tests; raise later
    participation = 0.75        # fraction of clients per round
    dirichlet_alpha = 0.5       # non-IID strength (lower => more skew)
    batch_size = 64
    lr = 0.01
    device_pref = None          # "cuda" | "mps" | "cpu" | None (auto)

    # --- Data ---
    train_dataset, test_dataset = load_cifar100()

    # --- Partitioning ---
    partitions = dirichlet_noniid_indices(
        train_dataset, num_clients=number_of_clients, alpha=dirichlet_alpha
    )

    # Optional per-client overrides, keyed by client id; each entry may set
    # "epochs", "batch_size" and/or "lr". Example:
    # hetero_profiles = {
    #     0: {"epochs": 2, "batch_size": 32, "lr": 0.02},
    #     3: {"epochs": 1, "batch_size": 128, "lr": 0.005},
    # }
    hetero_profiles = None

    # --- Federated training ---
    # A device_pref of None is passed on as "auto", which get_device does not
    # match as an explicit backend and therefore auto-detects.
    model, hist = federated_training(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        partitions=partitions,
        rounds=number_of_rounds,
        local_epochs=local_epochs,
        device=device_pref or "auto",
        q=participation,
        num_classes=100,
        batch_size=batch_size,
        lr=lr,
        hetero_profiles=hetero_profiles,
    )