In [9]:
# Importing libs and quick parameters definition

import copy
import random
import torch
import numpy as np
from torch import nn, optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset

# Simulation settings shared by the cells below.
# number_of_clients is passed to partition_dataset and federated_training as the client count;
# number_of_rounds is passed to federated_training as the number of communication rounds.
# NOTE(review): no random seed is set in this cell, so runs are not reproducible as written.
number_of_clients = 1
number_of_rounds = 1

In [10]:
def load_cifar100(root="./data", download=True):
    """Build the CIFAR-100 train and test datasets with a shared transform.

    Args:
        root: Directory handed to ``datasets.CIFAR100`` as its data root.
            Defaults to ``"./data"`` (the previously hard-coded location).
        download: Forwarded to ``datasets.CIFAR100``; defaults to ``True``
            (the previously hard-coded value).

    Returns:
        Tuple ``(train, test)`` of ``datasets.CIFAR100`` objects. Both apply
        ``ToTensor`` followed by per-channel normalization.
    """
    # Per-channel normalization constants (CIFAR-100 mean / std deviation).
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)

    transform = transforms.Compose([
        transforms.ToTensor(),  # convert each image to a tensor
        transforms.Normalize(mean, std),
    ])
    train = datasets.CIFAR100(root=root, train=True, download=download, transform=transform)
    test = datasets.CIFAR100(root=root, train=False, download=download, transform=transform)
    return train, test

# Load both splits once; later cells read `train` and `test` as module-level globals.
train, test = load_cifar100()

In [None]:
# Next we simulate multiple clients in FL, each holding a portion of the dataset.
# mode: "iid" or "dirichlet"
# alpha: Dirichlet concentration parameter; smaller values result in a more uneven distribution (e.g. 0.5 or 0.1)
def partition_dataset(dataset, num_clients, mode = "iid", alpha = 0.5, seed = None):
    """Split the indices of ``dataset`` among ``num_clients`` simulated clients.

    Args:
        dataset: Object supporting ``len()``. For ``mode="dirichlet"`` it must
            also expose its labels as ``.targets`` or ``.labels``
            (list, numpy array or torch tensor of integer class ids).
        num_clients: Number of index lists to produce; must be >= 1.
        mode: ``"iid"`` for a uniform random split (the last client receives
            the remainder) or ``"dirichlet"`` for a label-skewed split in which
            every class is divided among clients by one Dirichlet draw.
        alpha: Dirichlet concentration parameter; smaller values give a more
            uneven (more non-IID) split. Only used in ``"dirichlet"`` mode.
        seed: Optional integer. When given, the split is reproducible and the
            global ``random`` state is left untouched. When ``None`` (default)
            the global ``random`` module and an unseeded numpy generator are
            used, exactly as before.

    Returns:
        A list of ``num_clients`` lists; each inner list holds the dataset
        indices assigned to one client. Every index appears exactly once.

    Raises:
        ValueError: if ``num_clients < 1``, if ``alpha <= 0`` in dirichlet
            mode, or if the dataset exposes no labels in dirichlet mode.
        AssertionError: if ``mode`` is neither ``"iid"`` nor ``"dirichlet"``.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    n = len(dataset)
    # A private Random instance keeps seeded runs from disturbing global state.
    py_rng = random if seed is None else random.Random(seed)

    if mode == "iid":
        idx = list(range(n))
        py_rng.shuffle(idx)  # shuffle the order to ensure randomness
        split = n // num_clients
        parts = [idx[i*split:(i+1)*split] for i in range(num_clients)]
        # last client gets remainder
        parts[-1].extend(idx[num_clients*split:])
        return parts

    # non-iid: each client receives a different, label-skewed number of samples
    assert mode == "dirichlet", "mode must be 'iid' or 'dirichlet'"
    if alpha <= 0:
        raise ValueError(f"alpha must be > 0, got {alpha}")

    # get labels (common .targets / .labels in torchvision)
    targets = getattr(dataset, "targets", None)
    if targets is None:
        targets = getattr(dataset, "labels", None)
    if targets is None:
        raise ValueError(
            "dataset exposes neither .targets nor .labels; "
            "a dirichlet partition needs per-sample labels"
        )
    if torch.is_tensor(targets):
        # detach/cpu so tensors that live on an accelerator or track grads convert too
        targets = targets.detach().cpu().numpy()
    targets = np.asarray(targets)

    parts = [[] for _ in range(num_clients)]
    if targets.size == 0:
        # nothing to distribute; targets.max() would raise on an empty array
        return parts

    num_classes = int(targets.max() + 1)
    rng = np.random.default_rng(seed)

    for c in range(num_classes):
        cls_idx = np.where(targets == c)[0]
        rng.shuffle(cls_idx)

        # one Dirichlet allocation per class
        # smaller alpha -> more non-IID (more biased towards certain clients)
        p = rng.dirichlet(alpha * np.ones(num_clients))

        # allocate integer quotas proportionally (truncation loses a few samples)
        raw = p * len(cls_idx)
        sizes = raw.astype(int)

        # largest-remainder rule: hand the lost samples to the clients with
        # the biggest fractional parts
        residue = len(cls_idx) - sizes.sum()
        if residue > 0:
            order = np.argsort(-(raw - sizes))
            sizes[order[:residue]] += 1

        # cut the shuffled class indices into consecutive per-client chunks
        boundaries = np.cumsum(sizes)[:-1]
        for j, chunk in enumerate(np.split(cls_idx, boundaries)):
            parts[j].extend(chunk.tolist())

    # shuffle each client's indices so classes are not stored in sorted runs
    for j in range(num_clients):
        py_rng.shuffle(parts[j])

    return parts  # returns a list of lists, where each inner list contains dataset indices belonging to one client.

# Build one Dirichlet partition per heterogeneity level and keep both, instead of
# overwriting the first result as soon as it has been computed.
# α = 0.5 (medium non-IID)
parts_alpha_05 = partition_dataset(train, number_of_clients, mode="dirichlet", alpha=0.5)
# α = 0.1 (strong non-IID)
parts_alpha_01 = partition_dataset(train, number_of_clients, mode="dirichlet", alpha=0.1)
# `parts` is the partition read by federated_training; it stays bound to the α = 0.1 split.
parts = parts_alpha_01

In [12]:
def local_train(model, dataset, device, epochs=1, batch_size=32, lr=0.01):
    """Run one client's local training and return the resulting weights.

    The incoming ``model`` is never modified: training happens on a deep copy
    that is moved to ``device``.

    Args:
        model: Global model whose current weights are the starting point.
        dataset: The client's local data; wrapped in a shuffling ``DataLoader``.
        device: Device the copy and every batch are moved to.
        epochs: Number of passes over ``dataset``.
        batch_size: Mini-batch size for the loader.
        lr: Learning rate for SGD (momentum is fixed at 0.9).

    Returns:
        ``state_dict()`` of the trained copy, i.e. only the trained weights.
    """
    # Each client works on its own copy of the global model.
    client_model = copy.deepcopy(model)
    client_model.to(device)
    client_model.train()

    batches = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = optim.SGD(client_model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    for _epoch in range(epochs):
        for inputs, labels in batches:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(client_model(inputs), labels)
            batch_loss.backward()
            optimizer.step()

    trained_weights = client_model.state_dict()
    return trained_weights

# FedAvg algorithm (weighted by client sample counts)
def average_weights(weight_list, sizes):
    """Return the sample-count-weighted average of several model state dicts.

    Args:
        weight_list: Non-empty list of state dicts (mapping name -> tensor)
            that all share the same keys and tensor shapes.
        sizes: Number of local samples of each client, aligned with
            ``weight_list``; each client's weight is ``sizes[i] / sum(sizes)``.

    Returns:
        A new state dict (same container type and keys as ``weight_list[0]``)
        holding the averaged tensors. Every entry keeps the dtype it has in
        ``weight_list[0]``; the input dicts are not modified.

    Raises:
        AssertionError: if the lists are empty or differ in length.
        ValueError: if the sample counts do not sum to a positive number.
    """
    assert len(weight_list) == len(sizes) and len(weight_list) >0
    total = float(sum(sizes))
    if total <= 0:
        raise ValueError("client sizes must sum to a positive number")
    fractions = [s / total for s in sizes]

    # Deep copy keeps the container type (and any attached metadata) of the
    # first state dict while leaving the caller's tensors untouched.
    avg = copy.deepcopy(weight_list[0])
    for k in avg.keys():
        ref = weight_list[0][k]
        # Multiplying by a Python float promotes integer tensors to floating
        # point, so the accumulation itself never truncates.
        acc = ref * fractions[0]
        for state, frac in zip(weight_list[1:], fractions[1:]):
            acc = acc + state[k] * frac

        if ref.is_floating_point() or ref.is_complex():
            avg[k] = acc.to(ref.dtype)
        else:
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) must come
            # back with their original dtype; round to the nearest value
            # instead of leaking a float tensor into the state dict.
            avg[k] = torch.round(acc).to(ref.dtype)
    return avg # Returns the averaged model weights.


def _evaluate_accuracy(model, dataset, device, batch_size=128):
    """Return the top-1 accuracy of ``model`` on ``dataset`` (NaN if it yields no samples)."""
    model.to(device)
    model.eval()
    correct, total = 0, 0
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return correct / total if total else float("nan")


def federated_training(num_clients=4,rounds=10,local_epochs=1,
    device="cpu",#mps for apple Metal series gpu, switch to cuda or cpu if otherwise
    q=1.0, # participation rate (0, 1]
    num_classes=100,
    batch_size=64,
    lr=0.01,
    hetero_profiles=None,   # per-client overrides (basic heterogeneity)
    train_dataset=None,     # defaults to the module-level `train`
    test_dataset=None,      # defaults to the module-level `test`
    partitions=None,        # defaults to the module-level `parts`
):
    """Run FedAvg on a freshly initialised ResNet-18 and report test accuracy.

    Each round samples a subset of clients, trains a copy of the global model
    on each sampled client's partition with ``local_train``, and replaces the
    global weights with their sample-count-weighted average
    (``average_weights``). After the last round the global model is evaluated
    on the test set and the accuracy is printed.

    Args:
        num_clients: Number of simulated clients; ids are ``0..num_clients-1``.
        rounds: Number of communication rounds.
        local_epochs: Default number of local epochs per client and round.
        device: Device used for local training and evaluation.
        q: Participation rate; ``int(q * num_clients)`` clients are sampled per
            round, clamped to the range ``[1, num_clients]``.
        num_classes: Output dimension of the ResNet-18 classifier.
        batch_size: Default local mini-batch size.
        lr: Default local learning rate.
        hetero_profiles: Optional dict ``client_id -> dict`` with any of the
            keys ``"epochs"``, ``"batch_size"``, ``"lr"``, ``"dropout_prob"``
            overriding the defaults for that client. ``"dropout_prob"`` is the
            probability that the client skips a round it was sampled for.
        train_dataset: Dataset indexed by the partitions; falls back to the
            module-level ``train`` when ``None``.
        test_dataset: Dataset used for the final evaluation; falls back to the
            module-level ``test`` when ``None``.
        partitions: List of per-client index lists; falls back to the
            module-level ``parts`` when ``None``.

    Returns:
        The trained global model (left in eval mode on ``device``).

    Raises:
        ValueError: if ``num_clients < 1`` or exceeds the number of partitions.
    """
    # Resolve data sources; the globals keep existing notebook calls working.
    train_dataset = train if train_dataset is None else train_dataset
    test_dataset = test if test_dataset is None else test_dataset
    partitions = parts if partitions is None else partitions

    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if num_clients > len(partitions):
        # Fail early instead of hitting an IndexError in the middle of a round.
        raise ValueError(
            f"num_clients={num_clients} but only {len(partitions)} partitions are available"
        )

    global_model = models.resnet18(num_classes=num_classes)

    for r in range(rounds):
        # sample m = q * num_clients participants this round; clamp so that
        # q > 1 cannot ask random.sample for more clients than exist
        m = min(num_clients, max(1, int(q * num_clients)))
        selected = random.sample(range(num_clients), k=m)

        client_states, client_sizes = [], []

        for c in selected:
            # default local hyperparams
            ep = local_epochs
            bs = batch_size
            lrn = lr
            drop_p = 0.0
            # per-client overrides if provided
            if hetero_profiles is not None and c in hetero_profiles:
                prof = hetero_profiles[c]
                ep   = prof.get("epochs", ep)
                bs   = prof.get("batch_size", bs)
                lrn  = prof.get("lr", lrn)
                drop_p = prof.get("dropout_prob", 0.0)

            # simulate availability (dropout)
            if drop_p > 0 and random.random() < drop_p:
                continue  # this client skips this round

            client_indices = partitions[c]
            if len(client_indices) == 0:
                # A skewed split can leave a client without data; it has
                # nothing to train on and would carry zero weight anyway.
                continue

            client_dataset = Subset(train_dataset, client_indices)
            local_state = local_train(
                global_model, client_dataset, device,
                epochs=ep, batch_size=bs, lr=lrn
            )
            client_states.append(local_state)
            client_sizes.append(len(client_indices))

        if len(client_states) == 0:
            print(f"Round {r+1}: no clients participated — skip aggregation")
            continue

        averaged = average_weights(client_states, client_sizes)
        global_model.load_state_dict(averaged)
        print(f"Round {r+1}/{rounds} finished. participants={len(client_states)}/{num_clients}")

    # evaluate global model against global test set
    accuracy = _evaluate_accuracy(global_model, test_dataset, device)
    print("Global test acc:", accuracy)
    return global_model
    
# Baseline run: every client uses the default hyperparameters and q keeps its
# default of 1.0. The returned model is not stored.
federated_training(num_clients=number_of_clients, rounds=number_of_rounds)

# Per-client overrides keyed by client id; keys that are missing fall back to the
# defaults of federated_training.
# e.g., make three clients use a lower lr / smaller batch, and give clients 0 and 3
# a chance ("dropout_prob") of skipping a round they were sampled for.
# NOTE(review): clients are sampled from range(num_clients), so profiles whose id is
# >= number_of_clients are never looked up.
hetero = {
    0: {"epochs": 1, "batch_size": 32, "lr": 0.005, "dropout_prob": 0.10},
    1: {"epochs": 1, "batch_size": 32, "lr": 0.005},
    2: {"epochs": 1, "batch_size": 16, "lr": 0.005},
    3: {"dropout_prob": 0.30},  # may skip rounds randomly
}

# Heterogeneous run with partial participation (q=0.5) and the profiles above.
global_model = federated_training(
    num_clients=number_of_clients,
    rounds=number_of_rounds,
    device="cpu",
    q=0.5,
    num_classes=100,
    hetero_profiles=hetero
)
# Print the module structure of the returned global model.
print(global_model)

Round 1/1 finished. participants=1/1
Global test acc: 0.2106
Round 1/1 finished. participants=1/1
Global test acc: 0.2134
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      