In [1]:
!pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.7/66.7 MB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m30.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m26.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, FedAdagrad
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr.common import ndarrays_to_parameters, NDArrays, Scalar, Context

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

Training on cpu
Flower 1.18.0 / PyTorch 2.6.0+cu124


In [3]:
NUM_PARTITIONS = 5
BATCH_SIZE = 32


def load_datasets(partition_id: int, num_partitions: int):
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform)
        # we will use this function to dataset.with_transform(apply_transforms)
        # The transforms object is exactly the same
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=BATCH_SIZE, shuffle=True
    )
    valloader = DataLoader(partition_train_test["test"], batch_size=BATCH_SIZE)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloader, valloader, testloader

In [4]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

In [5]:
class FlowerClient(NumPyClient):
    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    net = Net().to(DEVICE)

    # Read the node_config to fetch data partition associated to this node
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]

    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerClient(partition_id, net, trainloader, valloader).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

In [6]:
# Create an instance of the model and get the parameters
params = get_parameters(Net())

In [7]:
# The `evaluate` function will be called by Flower after every round
def evaluate(
    server_round: int,
    parameters: NDArrays,
    config: Dict[str, Scalar],
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
    net = Net().to(DEVICE)
    _, _, testloader = load_datasets(0, NUM_PARTITIONS)
    set_parameters(net, parameters)  # Update model with the latest parameters
    loss, accuracy = test(net, testloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    return loss, {"accuracy": accuracy}

## Server-side parameter **evaluation**

Flower can evaluate the aggregated model on the server-side or on the client-side. Client-side and server-side evaluation are similar in some ways, but different in others.

**Centralized Evaluation** (or *server-side evaluation*) is conceptually simple: it works the same way that evaluation in centralized machine learning does. If there is a server-side dataset that can be used for evaluation purposes, then that's great. We can evaluate the newly aggregated model after each round of training without having to send the model to clients. We're also fortunate in the sense that our entire evaluation dataset is available at all times.

**Federated Evaluation** (or *client-side evaluation*) is more complex, but also more powerful: it doesn't require a centralized dataset and allows us to evaluate models over a larger set of data, which often yields more realistic evaluation results. In fact, many scenarios require us to use **Federated Evaluation** if we want to get representative evaluation results at all. But this power comes at a cost: once we start to evaluate on the client side, we should be aware that our evaluation dataset can change over consecutive rounds of learning if those clients are not always available. Moreover, the dataset held by each client can also change over consecutive rounds. This can lead to evaluation results that are not stable, so even if we would not change the model, we'd see our evaluation results fluctuate over consecutive rounds.

We've seen how federated evaluation works on the client side (i.e., by implementing the `evaluate` method in `FlowerClient`). Now let's see how we can evaluate aggregated model parameters on the server-side:

In [11]:
def server_fn(context: Context) -> ServerAppComponents:
    # Create the FedAvg strategy
    strategy = FedAvg(
        fraction_fit=1,
        fraction_evaluate=1,
        min_fit_clients=5,
        min_evaluate_clients=5,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(params),
        evaluate_fn=evaluate,  # Pass the evaluation function
    )
    # Configure the server for 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)

In [12]:
backend_config = {"client_resources": None}
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_gpus": 1}}
# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[36m(pid=5182)[0m 2025-05-03 16:24:22.963286: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=5182)[0m E0000 00:00:1746289463.198262    5182 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=5182)[0m E0000 00:00:1746289463.227014    5182 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[92mINFO [0m:      initial parameters (loss, other metrics): 0.07216868712902069, {'accuracy':

Server-side evaluation loss 0.07216868712902069 / accuracy 0.0615


[36m(ClientAppActor pid=5182)[0m see the appropriate new directories, set the environment variable
[36m(ClientAppActor pid=5182)[0m `JUPYTER_PLATFORM_DIRS=1` and then run `jupyter --paths`.
[36m(ClientAppActor pid=5182)[0m The use of platformdirs will be the default in `jupyter_core` v6
[36m(ClientAppActor pid=5182)[0m   from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write


[36m(ClientAppActor pid=5182)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.061832185834646225, accuracy 0.26525
[36m(ClientAppActor pid=5182)[0m [Client 1] fit, config: {}
[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.061396535485982895, accuracy 0.27975
[36m(ClientAppActor pid=5182)[0m [Client 2] fit, config: {}
[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.06178712844848633, accuracy 0.26125
[36m(ClientAppActor pid=5182)[0m [Client 3] fit, config: {}
[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.06135924533009529, accuracy 0.275125
[36m(ClientAppActor pid=5182)[0m [Client 4] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures


[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.062053706496953964, accuracy 0.261875


[92mINFO [0m:      fit progress: (1, 0.05965306313037872, {'accuracy': 0.3536}, 66.95670645299992)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


Server-side evaluation loss 0.05965306313037872 / accuracy 0.3536
[36m(ClientAppActor pid=5182)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=5182)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=5182)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=5182)[0m [Client 3] evaluate, config: {}
[36m(ClientAppActor pid=5182)[0m [Client 4] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=5182)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.05342862010002136, accuracy 0.36825
[36m(ClientAppActor pid=5182)[0m [Client 1] fit, config: {}
[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.053195852786302567, accuracy 0.370875
[36m(ClientAppActor pid=5182)[0m [Client 2] fit, config: {}
[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.05401067063212395, accuracy 0.359125
[36m(ClientAppActor pid=5182)[0m [Client 3] fit, config: {}
[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.05330224707722664, accuracy 0.375375
[36m(ClientAppActor pid=5182)[0m [Client 4] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures


[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.053560275584459305, accuracy 0.374


[92mINFO [0m:      fit progress: (2, 0.04942217069864273, {'accuracy': 0.4264}, 147.5841882499999)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


Server-side evaluation loss 0.04942217069864273 / accuracy 0.4264
[36m(ClientAppActor pid=5182)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=5182)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=5182)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=5182)[0m [Client 3] evaluate, config: {}
[36m(ClientAppActor pid=5182)[0m [Client 4] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=5182)[0m [Client 0] fit, config: {}
[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.049831755459308624, accuracy 0.41375
[36m(ClientAppActor pid=5182)[0m [Client 1] fit, config: {}
[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.04969844967126846, accuracy 0.419
[36m(ClientAppActor pid=5182)[0m [Client 2] fit, config: {}
[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.05007021501660347, accuracy 0.40975
[36m(ClientAppActor pid=5182)[0m [Client 3] fit, config: {}
[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.049229834228754044, accuracy 0.419375
[36m(ClientAppActor pid=5182)[0m [Client 4] fit, config: {}


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures


[36m(ClientAppActor pid=5182)[0m Epoch 1: train loss 0.04988604411482811, accuracy 0.421125


[92mINFO [0m:      fit progress: (3, 0.04650962634086609, {'accuracy': 0.4575}, 222.536352159)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


Server-side evaluation loss 0.04650962634086609 / accuracy 0.4575
[36m(ClientAppActor pid=5182)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=5182)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=5182)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=5182)[0m [Client 3] evaluate, config: {}
[36m(ClientAppActor pid=5182)[0m [Client 4] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 239.19s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.06022637792825698
[92mINFO [0m:      		round 2: 0.050080990052223204
[92mINFO [0m:      		round 3: 0.047179259955883034
[92mINFO [0m:      	History (loss, centralized):
[92mINFO [0m:      		round 0: 0.07216868712902069
[92mINFO [0m:      		round 1: 0.05965306313037872
[92mINFO [0m:      		round 2: 0.04942217069864273
[92mINFO [0m:      		round 3: 0.04650962634086609
[92mINFO [0m:      	History (metrics, centralized):
[92mINFO [0m:      	{'accuracy': [(0, 0.0615), (1, 0.3536), (2, 0.4264), (3, 0.4575)]}
[92mINFO [0m:      
