In [1]:
# Imports and simulation defaults

import copy
import random
import torch
import numpy as np
from torch import nn, optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset

# Compute device: one of 'cpu', 'mps', or 'cuda'.
# Prefer Apple-Silicon "mps" (the value this notebook was originally written with),
# then "cuda", and fall back to "cpu" so a fresh kernel on any machine can run all cells
# instead of failing on a hard-coded backend that is not present.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on older torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

number_of_clients = 12          # number of shards produced for each Dirichlet partition
number_of_rounds = 5            # federated rounds per experiment
local_epochs = 1                # default local epochs per client per round
baseline_participation = 0.75   # fraction of clients sampled per round (baseline run)
hetero_participation = 0.6      # fraction of clients sampled per round (heterogeneous run)
dirichlet_alphas = [0.5, 0.1]   # concentration values for which partitions are built
rng_seed = 42
baseline_batch_size = 64
baseline_lr = 0.01

# Seed every RNG the notebook draws from: `random` (client sampling / dropout),
# NumPy, and torch (weight init, DataLoader shuffling).
random.seed(rng_seed)
np.random.seed(rng_seed)
torch.manual_seed(rng_seed);  # trailing ';' keeps the returned Generator repr out of the cell output


<torch._C.Generator at 0x1230a14f0>

In [2]:
def load_cifar100(data_root="./data"):
    """Return the CIFAR-100 ``(train, test)`` datasets with tensor conversion and normalization.

    Args:
        data_root: Directory passed to torchvision as the dataset root; the data
            is downloaded there when missing (``download=True``).

    Returns:
        A ``(train, test)`` tuple of ``datasets.CIFAR100`` objects sharing one transform.
    """
    # Per-channel mean / std handed to Normalize (presumably the CIFAR-100
    # training-set statistics -- confirm against the data source).
    channel_mean = (0.5071, 0.4867, 0.4408)
    channel_std = (0.2675, 0.2565, 0.2761)
    preprocessing = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
    )

    # Build the training split first, then the test split.
    train_split, test_split = (
        datasets.CIFAR100(
            root=data_root,
            train=is_train,
            download=True,
            transform=preprocessing,
        )
        for is_train in (True, False)
    )
    return train_split, test_split

# Load both CIFAR-100 splits once (downloads into ./data when missing); the cells below
# partition `train_dataset` across clients and evaluate on `test_dataset`.
train_dataset, test_dataset = load_cifar100()


In healthcare federated learning each site treats a distinct patient cohort, so their data distributions diverge.
Assuming IID shards would imply identical caseloads everywhere and can hide clinically critical edge cases.
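We therefore rely exclusively on Dirichlet-driven non-IID splits throughout this simulator.
The Dirichlet concentration parameter α controls the skew: the smaller α is, the more unevenly each class is spread across clients (we use α = 0.5 and α = 0.1).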


In [3]:
def _get_targets(dataset):
    """Return the dataset's per-sample class labels as a NumPy array.

    Looks for a ``targets`` attribute first and falls back to ``labels``.

    Raises:
        ValueError: If the dataset exposes neither attribute.
    """
    targets = getattr(dataset, "targets", None)
    if targets is None:
        targets = getattr(dataset, "labels", None)
    if targets is None:
        raise ValueError("Dataset does not expose targets or labels for partitioning")
    if torch.is_tensor(targets):
        # detach + cpu so tensors that live on an accelerator or track gradients
        # convert cleanly; a bare .numpy() raises for those.
        targets = targets.detach().cpu().numpy()
    return np.asarray(targets)


def partition_dirichlet(dataset, num_clients, alpha=0.5, seed=None):
    """Split sample indices across clients with per-class Dirichlet proportions.

    For every class, the class's indices are shuffled and divided among the
    clients according to one draw from ``Dirichlet(alpha, ..., alpha)``.
    Fractional shares are floored and the leftover samples go to the clients
    with the largest fractional remainders, so every index is assigned exactly
    once. Individual clients may end up with an empty shard.

    Args:
        dataset: Object exposing ``targets`` or ``labels`` (one label per sample).
        num_clients: Number of shards to produce; must be positive.
        alpha: Dirichlet concentration; must be positive.
        seed: Seed for ``np.random.default_rng``; the same seed reproduces the
            same partition.

    Returns:
        A list of ``num_clients`` lists of integer sample indices, each shuffled.

    Raises:
        ValueError: If ``num_clients`` or ``alpha`` is not positive, or the
            dataset exposes no labels.
    """
    if num_clients <= 0:
        raise ValueError("num_clients must be a positive integer.")
    if alpha <= 0:
        raise ValueError("Dirichlet alpha must be positive.")

    targets = _get_targets(dataset)
    rng = np.random.default_rng(seed)
    partitions = [[] for _ in range(num_clients)]

    for cls in np.unique(targets):
        cls_indices = np.where(targets == cls)[0]
        if cls_indices.size == 0:
            continue
        rng.shuffle(cls_indices)
        allocation = rng.dirichlet(np.full(num_clients, alpha))

        # Largest-remainder rounding: floor every share, then hand the leftover
        # samples to the clients with the biggest fractional parts.
        expected = allocation * cls_indices.size
        counts = expected.astype(int)
        residue = cls_indices.size - counts.sum()
        if residue > 0:
            order = np.argsort(-(expected - counts))
            counts[order[:residue]] += 1

        # Consecutive slices of the shuffled class indices, one per client.
        offsets = np.concatenate(([0], np.cumsum(counts)))
        for client_id in range(num_clients):
            chunk = cls_indices[offsets[client_id]:offsets[client_id + 1]]
            partitions[client_id].extend(chunk.tolist())

    # Mix classes within each shard so a client's samples are not grouped by class.
    for shard in partitions:
        rng.shuffle(shard)

    return partitions

# One client partition per Dirichlet concentration, keyed by the alpha value itself.
# Each alpha gets its own seed (rng_seed offset by alpha * 1000, truncated to an int)
# so the splits are reproducible but not identical draws.
dirichlet_partitions = dict(
    (
        concentration,
        partition_dirichlet(
            train_dataset,
            number_of_clients,
            seed=rng_seed + int(concentration * 1000),
            alpha=concentration,
        ),
    )
    for concentration in dirichlet_alphas
)


In [4]:
# Layer types whose training-mode forward pass cannot handle a single-sample batch.
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def local_train(model, dataset, device, epochs=1, batch_size=32, lr=0.01):
    """Train a private copy of ``model`` on one client's shard.

    The incoming model is deep-copied, so the caller's model is left untouched.
    Training uses SGD (momentum 0.9) with cross-entropy loss over shuffled
    mini-batches. When a batch holds fewer than two samples, the BatchNorm
    layers are switched to eval mode for that step only, because training-mode
    BatchNorm raises on a single value per channel.

    Args:
        model: Module to copy and train.
        dataset: Client dataset yielding ``(input, label)`` pairs; must be non-empty.
        device: Device the copy and the batches are moved to.
        epochs: Number of passes over the shard.
        batch_size: Requested batch size, clamped to ``[1, len(dataset)]``.
        lr: SGD learning rate.

    Returns:
        The trained copy's state dict with every tensor detached and on CPU.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    client_model = copy.deepcopy(model).to(device)

    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError("Client dataset shard is empty.")

    client_model.train()
    batch_norms = [mod for mod in client_model.modules() if isinstance(mod, BN_TYPES)]

    effective_batch = max(1, min(batch_size, num_samples))
    batches = DataLoader(dataset, batch_size=effective_batch, shuffle=True)
    optimizer = optim.SGD(client_model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    for _epoch in range(epochs):
        for inputs, labels in batches:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            # Single-sample batch: run BatchNorm on its running statistics for this step.
            frozen = []
            if inputs.size(0) <= 1:
                frozen = [bn for bn in batch_norms if bn.training]
                for bn in frozen:
                    bn.eval()

            loss = criterion(client_model(inputs), labels)
            loss.backward()
            optimizer.step()

            # Put the temporarily frozen layers back into training mode.
            for bn in frozen:
                bn.train()

    trained_state = client_model.state_dict()
    return {name: tensor.detach().cpu() for name, tensor in trained_state.items()}

def average_weights(weight_list, sizes):
    """Sample-size-weighted average of client state dicts (FedAvg aggregation).

    Every entry is averaged with weight ``sizes[i] / sum(sizes)``. Floating
    tensors are returned as computed. Integer and boolean tensors (for example
    BatchNorm's ``num_batches_tracked``) are averaged in floating point, then
    rounded and cast back to their original dtype, so the result keeps the
    dtypes of the inputs and a value such as 9.999... is not truncated to 9
    when it is later loaded into an integer buffer.

    Args:
        weight_list: Non-empty list of state dicts sharing the keys of the first one.
        sizes: Per-client sample counts, same length as ``weight_list``.

    Returns:
        A new dict with the averaged tensors; the inputs are not modified.

    Raises:
        ValueError: If ``weight_list`` is empty, the lengths differ, or the
            total sample count is not positive.
    """
    if not weight_list:
        raise ValueError("No client weights provided for averaging.")
    if len(weight_list) != len(sizes):
        raise ValueError("Number of weights and sizes must match.")
    total = float(sum(sizes))
    if total <= 0:
        raise ValueError("Total client sample size must be positive.")

    scales = [size / total for size in sizes]
    first_state = weight_list[0]

    avg = {}
    for key, first in first_state.items():
        # Out-of-place arithmetic: never touches the callers' tensors.
        acc = first * scales[0]
        for state, scale in zip(weight_list[1:], scales[1:]):
            acc = acc + state[key] * scale

        # Multiplying an integer tensor by a float promotes it to float; restore the dtype.
        if torch.is_tensor(first) and not (first.is_floating_point() or first.is_complex()):
            acc = acc.round().to(first.dtype)
        avg[key] = acc
    return avg

def _resolve_client_config(client_id, hetero_profiles, epochs, batch_size, lr):
    """Return ``(epochs, batch_size, lr, dropout_prob)`` for one client, applying profile overrides."""
    dropout_prob = 0.0
    if hetero_profiles:
        profile = hetero_profiles.get(client_id, {})
        epochs = profile.get("epochs", epochs)
        batch_size = profile.get("batch_size", batch_size)
        lr = profile.get("lr", lr)
        dropout_prob = profile.get("dropout_prob", dropout_prob)
    return epochs, batch_size, lr, dropout_prob


def _evaluate_accuracy(model, dataset, device, batch_size=128):
    """Return the top-1 accuracy of ``model`` on ``dataset`` (0.0 when the dataset is empty)."""
    correct = 0
    total = 0
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total if total else 0.0


def federated_training(
    train_dataset,
    test_dataset,
    partitions,
    rounds=10,
    local_epochs=1,
    device="cpu",
    q=1.0,
    num_classes=100,
    batch_size=64,
    lr=0.01,
    hetero_profiles=None,
):
    """Run FedAvg on a fresh ResNet-18 and report test accuracy.

    Each round samples a fraction ``q`` of the clients with the global
    ``random`` module, trains a copy of the global model on every sampled
    client that has data and does not drop out, and replaces the global model
    with the sample-size-weighted average of the returned states. A round in
    which no client contributes leaves the global model unchanged.

    Args:
        train_dataset: Dataset indexed by the shards in ``partitions``.
        test_dataset: Dataset used for the final accuracy evaluation.
        partitions: One list of sample indices per client.
        rounds: Number of federated rounds.
        local_epochs: Default local epochs per client per round.
        device: Device used for local training and evaluation.
        q: Participation rate in ``(0, 1]``; at least one client is sampled.
        num_classes: Output size of the ResNet-18 classifier.
        batch_size: Default local batch size.
        lr: Default local learning rate.
        hetero_profiles: Optional ``{client_id: {...}}`` overrides with any of
            the keys ``"epochs"``, ``"batch_size"``, ``"lr"``, ``"dropout_prob"``
            (probability that a sampled client skips the round).

    Returns:
        The trained global model, moved to ``device`` and set to eval mode.

    Raises:
        ValueError: If ``partitions`` is empty or ``q`` is outside ``(0, 1]``.
    """
    num_clients = len(partitions)
    if num_clients == 0:
        raise ValueError("No clients available for federated training.")
    if not (0 < q <= 1):
        raise ValueError("Client participation rate q must be in (0, 1].")

    global_model = models.resnet18(num_classes=num_classes)

    # The epsilon stops float error (e.g. 0.58 * 100 == 57.99999999999999) from
    # flooring the participant count one below the intended value.
    clients_per_round = max(1, min(num_clients, int(q * num_clients + 1e-9)))

    for r in range(rounds):
        selected = random.sample(range(num_clients), k=clients_per_round)
        client_states = []
        client_sizes = []

        for client_id in selected:
            shard = partitions[client_id]
            if not shard:
                # A Dirichlet split can leave a client without any samples.
                continue

            epochs, batch, client_lr, dropout_prob = _resolve_client_config(
                client_id, hetero_profiles, local_epochs, batch_size, lr
            )

            # Simulated straggler: the client was sampled but returns no update.
            if dropout_prob > 0 and random.random() < dropout_prob:
                continue

            local_state = local_train(
                global_model,
                Subset(train_dataset, shard),
                device,
                epochs=epochs,
                batch_size=batch,
                lr=client_lr,
            )
            client_states.append(local_state)
            client_sizes.append(len(shard))

        if not client_states:
            print(f"Round {r + 1}: no clients contributed updates; skipping aggregation.")
            continue

        averaged_state = average_weights(client_states, client_sizes)
        global_model.load_state_dict(averaged_state)
        print(f"Round {r + 1}/{rounds} finished. participants={len(client_states)}/{num_clients}")

    global_model.to(device)
    global_model.eval()
    accuracy = _evaluate_accuracy(global_model, test_dataset, device)
    print("Global test acc:", accuracy)
    return global_model

# Baseline experiment: the alpha=0.5 Dirichlet partition, `baseline_participation` of the
# clients sampled each round, and the same epochs / batch size / learning rate for every
# client (no hetero_profiles). This run draws from the global `random` state first, so the
# client sampling in the next experiment depends on it having been executed.
model_alpha_05 = federated_training(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    partitions=dirichlet_partitions[0.5],
    rounds=number_of_rounds,
    local_epochs=local_epochs,
    device=device,
    q=baseline_participation,
    num_classes=100,
    batch_size=baseline_batch_size,
    lr=baseline_lr,
)

# Per-client overrides for the heterogeneous run, keyed by client id.
# Recognised keys: "epochs", "batch_size", "lr", and "dropout_prob" (probability that a
# sampled client skips the round). A missing key falls back to the value passed to
# federated_training; clients without an entry (ids 6-11 here) use those defaults entirely.
hetero_profiles = {
    0: {"epochs": 2, "batch_size": 48, "lr": 0.0075, "dropout_prob": 0.10},
    1: {"epochs": 1, "batch_size": 32, "lr": 0.005},
    2: {"epochs": 1, "batch_size": 24, "lr": 0.006},
    3: {"dropout_prob": 0.25},
    4: {"epochs": 2, "batch_size": 64, "lr": 0.012},
    5: {"dropout_prob": 0.15, "lr": 0.008},
}

# Heterogeneous experiment: the more skewed alpha=0.1 partition, a lower participation rate,
# and `local_epochs + 1` as the default epoch count for clients whose profile sets no "epochs".
hetero_model = federated_training(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    partitions=dirichlet_partitions[0.1],
    rounds=number_of_rounds,
    local_epochs=local_epochs + 1,
    device=device,
    q=hetero_participation,
    num_classes=100,
    batch_size=baseline_batch_size,
    lr=baseline_lr,
    hetero_profiles=hetero_profiles,
)
# Prints the returned model's module repr (the full layer listing).
print(hetero_model)


Round 1/5 finished. participants=9/12
Round 2/5 finished. participants=9/12
Round 3/5 finished. participants=9/12
Round 4/5 finished. participants=9/12
Round 5/5 finished. participants=9/12
Global test acc: 0.1614


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])