# Federated Averaging (FedAvg) baseline on CIFAR-100
- Dirichlet non-IID partitioning
- Partial client participation
- Optional heterogeneity (per-client batch size / epochs / lr)
## Imports and simulation defaults

In [None]:
import copy
import math
import random
import numpy as np
from collections import defaultdict
from typing import Dict, List, Tuple, Optional

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models

# Reproducibility
rng_seed = 42
random.seed(rng_seed)
np.random.seed(rng_seed)
torch.manual_seed(rng_seed)  
# (we can also seed CUDA later if present: torch.cuda.manual_seed_all(rng_seed))

# Device helper (prefers CUDA, then MPS, else CPU)
def get_device(prefer: Optional[str] = None) -> torch.device:
    if prefer == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")
    if prefer == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

: 

### Data: CIFAR-100 loaders (train/test)

In [None]:
def load_cifar100(data_root: str = "./data"):
    # Mild augmentation on train; standard normalization
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)

    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train = datasets.CIFAR100(root=data_root, train=True, download=True, transform=train_tf)
    test = datasets.CIFAR100(root=data_root, train=False, download=True, transform=test_tf)
    return train, test  # mirrors your original intent. 


# %%
# Targets accessor (handles .targets/.labels)
def _get_targets(dataset) -> np.ndarray:
    targets = getattr(dataset, "targets", None)
    if targets is None:
        targets = getattr(dataset, "labels", None)
    if targets is None:
        raise AttributeError("Dataset has no 'targets' or 'labels'.")
    return np.array(targets)

: 

## Dirichlet non-IID split (returns dict: client_id -> list of indices)

In [None]:
def dirichlet_noniid_indices(dataset, num_clients: int, alpha: float, min_per_client: int = 10) -> Dict[int, List[int]]:
    y = _get_targets(dataset)
    num_classes = int(y.max()) + 1
    idx_by_class = {c: np.where(y == c)[0] for c in range(num_classes)}
    for c in idx_by_class:
        np.random.shuffle(idx_by_class[c])

    # Each class's indices are split among clients with Dirichlet(alpha)
    client_indices = [[] for _ in range(num_clients)]
    for c in range(num_classes):
        idx_c = idx_by_class[c]
        if len(idx_c) == 0:
            continue
        # Dirichlet draw for this class across clients
        p = np.random.dirichlet([alpha] * num_clients)
        # Proportions -> integer split (rounding by cumulative sums)
        cuts = (np.cumsum(p) * len(idx_c)).astype(int)[:-1]
        split = np.split(idx_c, cuts)
        for i, shard in enumerate(split):
            client_indices[i].extend(shard.tolist())

    # Ensure minimum per client (fallback to random fill if some are tiny)
    # This usually isn't necessary for reasonable alpha, but keeps loaders happy
    pool = list(range(len(dataset)))
    for i in range(num_clients):
        if len(client_indices[i]) < min_per_client:
            need = min_per_client - len(client_indices[i])
            extra = np.random.choice(pool, size=need, replace=False).tolist()
            client_indices[i].extend(extra)

    # Shuffle each client's order (nicer batching)
    for i in range(num_clients):
        random.shuffle(client_indices[i])
    return {i: client_indices[i] for i in range(num_clients)}

: 

## Model: ResNet18 head for CIFAR-100

In [None]:
def build_model(num_classes: int = 100) -> nn.Module:
    model = models.resnet18(weights=None)  # no pretrained to avoid download in restricted envs
    # CIFAR images are 3x32x32; torchvision ResNet expects 224x224,
    # but it's fine—ResNet is fully conv except FC. It still works on 32x32.
    # Replace final FC layer to match number of classes
    in_feats = model.fc.in_features
    model.fc = nn.Linear(in_feats, num_classes)
    return model


# %%
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Tuple[float, float]:
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = criterion(logits, y)
            loss_sum += loss.item() * x.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    avg_loss = loss_sum / max(1, total)
    acc = correct / max(1, total)
    return avg_loss, acc


# %%
def local_train(
    global_model: nn.Module,
    subset: Subset,
    device: torch.device,
    epochs: int = 1,
    batch_size: int = 64,
    lr: float = 0.01,
) -> Dict[str, torch.Tensor]:
    model = copy.deepcopy(global_model).to(device)
    loader = DataLoader(subset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=False)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

    model.train()
    for _ in range(max(1, epochs)):
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

    # Return CPU weights to simulate uplink
    return {k: v.detach().cpu() for k, v in model.state_dict().items()}

: 

## Weighted model averaging (FedAvg) with checks

In [None]:
def average_weights(weight_list, sizes):
    if not weight_list:
        raise ValueError("No client weights provided.")
    if len(weight_list) != len(sizes):
        raise ValueError("weights and sizes mismatch")
    total = float(sum(sizes))
    avg = {k: torch.zeros_like(v) for k, v in weight_list[0].items()}

    for wi, si in zip(weight_list, sizes):
        w = si / total
        for k in avg.keys():
            if avg[k].dtype.is_floating_point:      
                avg[k] += wi[k].float() * w
            else:
                avg[k] = wi[k].clone()              

    return avg # mirrors safety we intend. :contentReference[oaicite:2]{index=2}

# %%
def federated_training(
    train_dataset,
    test_dataset,
    partitions: Dict[int, List[int]],
    rounds: int = 10,
    local_epochs: int = 1,
    device: str = "cpu",
    q: float = 1.0,  # participation rate per round
    num_classes: int = 100,
    batch_size: int = 64,
    lr: float = 0.01,
    hetero_profiles: Optional[Dict[int, Dict[str, float]]] = None,  # per-client overrides
):
    """
    Train the global model using FedAvg over 'rounds' communication rounds.
    """
    # Validate partitions
    num_clients = len(partitions)
    if num_clients == 0:
        raise ValueError("No clients available for federated training.")
    device = get_device(device)

    # Build global model
    global_model = build_model(num_classes=num_classes).to(device)

    # Test loader (global test set)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2)

    history = {"round": [], "test_loss": [], "test_acc": [], "selected": []}

    for r in range(1, rounds + 1):
        # Client sampling
        m = max(1, int(math.ceil(q * num_clients)))
        selected = sorted(random.sample(range(num_clients), m))

        # Local updates
        local_weights = []
        local_sizes = []
        for cid in selected:
            idxs = partitions[cid]
            subset = Subset(train_dataset, idxs)

            # Per-client heterogeneity overrides if provided
            ep = local_epochs
            bs = batch_size
            lr_i = lr
            if hetero_profiles and cid in hetero_profiles:
                ep = int(hetero_profiles[cid].get("epochs", ep))
                bs = int(hetero_profiles[cid].get("batch_size", bs))
                lr_i = float(hetero_profiles[cid].get("lr", lr_i))

            wi = local_train(global_model, subset, device, epochs=ep, batch_size=bs, lr=lr_i)
            local_weights.append(wi)
            local_sizes.append(len(subset))

        # FedAvg
        new_state = average_weights(local_weights, local_sizes)
        global_model.load_state_dict(new_state)

        # Evaluate global model
        test_loss, test_acc = evaluate(global_model, test_loader, device)
        history["round"].append(r)
        history["test_loss"].append(test_loss)
        history["test_acc"].append(test_acc)
        history["selected"].append(selected)

        print(f"[Round {r:03d}] Test loss: {test_loss:.4f} | Test acc: {test_acc*100:.2f}% | Clients: {selected}")

    return global_model, history  # same signature goal you had. :contentReference[oaicite:3]{index=3}

: 

## Final 

In [8]:
if __name__ == "__main__":
    # --- Config (kept close to your defaults) ---
    number_of_clients = 10
    number_of_rounds = 20
    local_epochs = 1            # start small for smoke tests; raise later
    participation = 0.75        # fraction of clients per round
    dirichlet_alpha = 0.5       # non-IID strength (lower => more skew)
    batch_size = 64
    lr = 0.01
    device_pref = None          # "cuda" | "mps" | "cpu" | None (auto)

    # (These reflect the same values/types you were using.) :contentReference[oaicite:4]{index=4}

    # --- Data ---
    train_dataset, test_dataset = load_cifar100()

    # --- Partitioning ---
    partitions = dirichlet_noniid_indices(
        train_dataset, num_clients=number_of_clients, alpha=dirichlet_alpha
    )

    # Optional: example heterogeneity profile
    # hetero_profiles = {
    #     0: {"epochs": 2, "batch_size": 32, "lr": 0.02},
    #     3: {"epochs": 1, "batch_size": 128, "lr": 0.005},
    # }
    hetero_profiles = None

    # --- Federated training ---
    model, hist = federated_training(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        partitions=partitions,
        rounds=number_of_rounds,
        local_epochs=local_epochs,
        device=device_pref or "auto",
        q=participation,
        num_classes=100,
        batch_size=batch_size,
        lr=lr,
        hetero_profiles=hetero_profiles,
    )

[Round 001] Test loss: 4.7050 | Test acc: 1.00% | Clients: [0, 2, 3, 4, 5, 6, 7, 9]
[Round 002] Test loss: 4.5497 | Test acc: 1.52% | Clients: [0, 2, 3, 4, 5, 6, 7, 9]
[Round 003] Test loss: 4.1629 | Test acc: 5.92% | Clients: [0, 1, 3, 4, 5, 6, 8, 9]
[Round 004] Test loss: 3.8902 | Test acc: 9.79% | Clients: [1, 2, 4, 5, 6, 7, 8, 9]
[Round 005] Test loss: 3.6892 | Test acc: 13.83% | Clients: [1, 2, 3, 4, 5, 7, 8, 9]
[Round 006] Test loss: 3.5107 | Test acc: 17.14% | Clients: [0, 1, 2, 3, 4, 5, 6, 7]
[Round 007] Test loss: 3.4494 | Test acc: 18.49% | Clients: [1, 2, 3, 4, 5, 6, 7, 8]
[Round 008] Test loss: 3.3218 | Test acc: 20.47% | Clients: [0, 1, 3, 4, 5, 7, 8, 9]
[Round 009] Test loss: 3.2474 | Test acc: 21.49% | Clients: [0, 1, 2, 4, 5, 6, 8, 9]
[Round 010] Test loss: 3.1447 | Test acc: 23.30% | Clients: [0, 1, 3, 4, 6, 7, 8, 9]
[Round 011] Test loss: 3.1076 | Test acc: 24.12% | Clients: [0, 1, 2, 4, 5, 6, 7, 9]
[Round 012] Test loss: 3.0439 | Test acc: 25.74% | Clients: [0, 3, 4,

## Metrics to export 

metrics_to_export = {
    'round': round_num,
    'participating_clients': client_ids,
    
    # Per-client exports
    'client_metrics': {
        client_id: {
            'gradient_norm': float,
            'per_layer_norms': dict,
            'batch_size': int,
            'local_epochs': int,
            'learning_rate': float,
            'num_samples': int,
            'class_distribution': dict,
            'local_loss': float,
        }
    },
    
    # Global state
    'global_model_state': state_dict,  # or checkpoint path
    'global_accuracy': float,
    'global_loss': float,
    
    # For reconstruction
    'raw_gradients': {client_id: gradient_dict},  # KEY for breaching
    'model_updates': {client_id: update_dict},
}

In [None]:
import torch
import copy
from torch.utils.data import Subset

# reconstruct final state and global metrics
global_loss, global_acc = evaluate(model, DataLoader(test_dataset, batch_size=256, shuffle=False), get_device())

# container for metrics
metrics_to_export = {
    "round": number_of_rounds,
    "participating_clients": hist["selected"][-1],
    "client_metrics": {},
    "global_model_state": copy.deepcopy(model.state_dict()),  # could also save to disk
    "global_accuracy": float(global_acc),
    "global_loss": float(global_loss),
    "raw_gradients": {},
    "model_updates": {},
}

# --- compute per-client metrics on the final round participants ---
device = get_device()
global_state = copy.deepcopy(model.state_dict())

criterion = torch.nn.CrossEntropyLoss()

for cid in hist["selected"][-1]:
    idxs = partitions[cid]
    subset = Subset(train_dataset, idxs)
    loader = DataLoader(subset, batch_size=batch_size, shuffle=True, num_workers=2)

    # clone model to compute client gradients from final global weights
    client_model = build_model(num_classes=100).to(device)
    client_model.load_state_dict(global_state)
    client_model.train()

    optimizer = torch.optim.SGD(client_model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

    # one batch to approximate gradient norm
    x, y = next(iter(loader))
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    logits = client_model(x)
    loss = criterion(logits, y)
    loss.backward()

    grad_dict = {n: p.grad.detach().cpu().clone() for n, p in client_model.named_parameters() if p.grad is not None}
    grad_norm = torch.sqrt(sum(torch.norm(g)**2 for g in grad_dict.values())).item()
    per_layer_norms = {n: torch.norm(g).item() for n, g in grad_dict.items()}

    # simulate a local update (1 epoch) to get model delta
    local_weights = local_train(model, subset, device, epochs=1, batch_size=batch_size, lr=lr)
    update_dict = {
    k: local_weights[k].cpu() - global_state[k].detach().cpu()
    for k in global_state.keys()
    }
    # class distribution for this client
    y_subset = [train_dataset.targets[i] for i in idxs]
    class_counts = dict(zip(*np.unique(y_subset, return_counts=True)))

    metrics_to_export["client_metrics"][cid] = {
        "gradient_norm": float(grad_norm),
        "per_layer_norms": per_layer_norms,
        "batch_size": batch_size,
        "local_epochs": local_epochs,
        "learning_rate": lr,
        "num_samples": len(subset),
        "class_distribution": class_counts,
        "local_loss": float(loss.item()),
    }
    metrics_to_export["raw_gradients"][cid] = grad_dict
    metrics_to_export["model_updates"][cid] = update_dict

print("Metrics extracted for", len(metrics_to_export["client_metrics"]), "clients.")

In [None]:
print("Global accuracy:", metrics_to_export["global_accuracy"])
print("Global loss:", metrics_to_export["global_loss"])
print("Round:", metrics_to_export["round"])
print("Clients:", metrics_to_export["participating_clients"])

Global accuracy: 0.3223
Global loss: 2.683524996185303
Round: 20
Clients: [0, 1, 2, 3, 4, 5, 7, 9]


: 

In [None]:
for cid, info in metrics_to_export["client_metrics"].items():
    print(f"\nClient {cid}:")
    print(f"  samples = {info['num_samples']}")
    print(f"  local_loss = {info['local_loss']:.4f}")
    print(f"  grad_norm = {info['gradient_norm']:.2f}")
    print(f"  top-3 per-layer norms = {list(info['per_layer_norms'].items())[:3]}")


Client 0:
  samples = 4860
  local_loss = 3.2197
  grad_norm = 5.99
  top-3 per-layer norms = [('conv1.weight', 1.4087834358215332), ('bn1.weight', 0.12101278454065323), ('bn1.bias', 0.06460622698068619)]

Client 1:
  samples = 5014
  local_loss = 2.4702
  grad_norm = 4.61
  top-3 per-layer norms = [('conv1.weight', 0.9140346050262451), ('bn1.weight', 0.10412313789129257), ('bn1.bias', 0.05682973563671112)]

Client 2:
  samples = 5582
  local_loss = 2.8508
  grad_norm = 5.30
  top-3 per-layer norms = [('conv1.weight', 1.2020173072814941), ('bn1.weight', 0.11466034501791), ('bn1.bias', 0.07469813525676727)]

Client 3:
  samples = 4299
  local_loss = 2.4702
  grad_norm = 5.40
  top-3 per-layer norms = [('conv1.weight', 1.3166335821151733), ('bn1.weight', 0.10946671664714813), ('bn1.bias', 0.0636722594499588)]

Client 4:
  samples = 4640
  local_loss = 2.5860
  grad_norm = 5.30
  top-3 per-layer norms = [('conv1.weight', 1.2379850149154663), ('bn1.weight', 0.11968004703521729), ('bn1.bia

: 

In [None]:
import pprint
pprint.pprint(metrics_to_export["client_metrics"])

{0: {'batch_size': 64,
     'class_distribution': {np.int64(0): np.int64(5),
                            np.int64(1): np.int64(160),
                            np.int64(3): np.int64(50),
                            np.int64(4): np.int64(21),
                            np.int64(5): np.int64(78),
                            np.int64(6): np.int64(90),
                            np.int64(7): np.int64(267),
                            np.int64(8): np.int64(29),
                            np.int64(9): np.int64(26),
                            np.int64(10): np.int64(96),
                            np.int64(11): np.int64(74),
                            np.int64(12): np.int64(24),
                            np.int64(13): np.int64(56),
                            np.int64(14): np.int64(50),
                            np.int64(15): np.int64(3),
                            np.int64(16): np.int64(27),
                            np.int64(17): np.int64(57),
                            np.int

: 

Saving the full object locally (so we can load it later for analysis or privacy-attack experiments)

In [None]:
import pickle
with open("fed_metrics_round20.pkl", "wb") as f:
    pickle.dump(metrics_to_export, f)

: 

Then we can reload later with:

"

with open("fed_metrics_round20.pkl", "rb") as f:
    m = pickle.load(f)

"