In [1]:
# Importing libs and quick parameters definition

import copy
import random
import torch
from torch import nn, optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset

number_of_clients = 5
number_of_rounds = 3

In [2]:
def load_cifar100():
    transform = transforms.Compose([
        transforms.ToTensor(), #convert each image to a tensor
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761))  # CIFAR-100 mean/std deviation
    ])
    train = datasets.CIFAR100(root="./data", train=True, download=True, transform=transform)
    test = datasets.CIFAR100(root="./data", train=False, download=True, transform=transform)
    return train, test

train, test = load_cifar100()

In [3]:
# Next we try to simulate multiple clients in FL, each with a portion of the dataset
def partition_dataset(dataset, num_clients, iid=True): # Defaults to independent & identically distributed
    n = len(dataset)
    indices = list(range(n))
    if iid: # then every client gets a random, fair sample of the data
        random.shuffle(indices)
        split = n // num_clients
        parts = [indices[i*split:(i+1)*split] for i in range(num_clients)]
        # last client gets remainder
        parts[-1].extend(indices[num_clients*split:])
    else:
        # non-iid: group by class (simple approach) - could add later ?
        raise NotImplementedError("Non-iid partition not implemented yet")
    return parts # Returns a list of lists, where each inner list contains dataset indices belonging to one client.

parts = partition_dataset(train, number_of_clients) # 4 for 4 clients
print([len(p) for p in parts])  

[10000, 10000, 10000, 10000, 10000]


In [4]:
def local_train(model, dataset, device, epochs=1, batch_size=32, lr=0.01):
    model = copy.deepcopy(model) # each client gets its own copy of the global model
    model.to(device)
    model.train() # Train it on its local partition of the dataset
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            opt.zero_grad()
            out = model(x)
            loss = loss_fn(out, y)
            loss.backward()
            opt.step()
    return model.state_dict() 
    # Then after training we return only the trained weights : state_dict

# FedAvg algorithm
def average_weights(weight_list):
    avg = copy.deepcopy(weight_list[0])
    for k in avg.keys():
        # for each parameter, it computes the element-wise mean across all clients
        for i in range(1, len(weight_list)):
            avg[k] += weight_list[i][k]
        avg[k] = torch.div(avg[k], len(weight_list))
    return avg # Returns the averaged model weights.

def federated_training(num_clients=4, rounds=10, local_epochs=1, device="mps"): # mps for apple Metal series gpu, switch to cuda or cpu if otherwise
    global_model = models.resnet18(num_classes=100) # to initialize a global Initializes a global ResNet18 model
    global_state = global_model.state_dict()

    for r in range(rounds): # Each round = 1 global communication cycle 
        selected = list(range(num_clients))  # simple: all clients
        client_states = []
        for c in selected:
            client_dataset = Subset(train, parts[c])
            local_state = local_train(global_model, client_dataset, device,
                                      epochs=local_epochs, batch_size=64, lr=0.01)
            # All clients download the global model
            # Each client trains locally on its partition for n local_epochs and produces new weights
            # server collects all local weights
            client_states.append(local_state)
        # FedAvg
        averaged = average_weights(client_states)
            # Server averages them
        global_model.load_state_dict(averaged)
            # The average model becomes the new global one for the next round
        print(f"Round {r+1}/{rounds} finished.")

    # evaluate global model against global test set
    global_model.to(device)
    global_model.eval()
    # test accuracy results:
    correct = 0
    total = 0
    loader = DataLoader(test, batch_size=128)
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            out = global_model(x)
            pred = out.argmax(dim=1)
            correct += (pred==y).sum().item()
            total += y.size(0)
    print("Global test acc:", correct/total)
    return global_model

federated_training(num_clients=number_of_clients, rounds=number_of_rounds)

Round 1/3 finished.
Round 2/3 finished.
Round 3/3 finished.
Global test acc: 0.1712


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  