# Use a federated learning strategy

Welcome to the next part of the federated learning tutorial. In previous parts of this tutorial, we introduced federated learning with PyTorch and Flower ([part 1](https://flower.ai/docs/framework/tutorial-get-started-with-flower-pytorch.html)).

In this notebook, we'll begin to customize the federated learning system we built in the introductory notebook, again using the Flower framework, Flower Datasets, and PyTorch.

> [Star Flower on GitHub](https://github.com/adap/flower) ⭐️ and join the Flower community on Flower Discuss and the Flower Slack to connect, ask questions, and get help:
> - [Join Flower Discuss](https://discuss.flower.ai/) We'd love to hear from you in the `Introduction` topic! If anything is unclear, post in `Flower Help - Beginners`.
> - [Join Flower Slack](https://flower.ai/join-slack) We'd love to hear from you in the `#introductions` channel! If anything is unclear, head over to the `#questions` channel.

Let's move beyond FedAvg with Flower strategies! 🌼

## Preparation

Before we begin with the actual code, let's make sure that we have everything we need.

### Installing dependencies

First, we install the necessary packages:

```python
!pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision
```

Now that we have all dependencies installed, we can import everything we need for this tutorial:

```python
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, FedAdagrad
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr.common import ndarrays_to_parameters, NDArrays, Scalar, Context

# Train on the GPU if one is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")
```

```text
Training on cuda
Flower 1.14.0 / PyTorch 2.5.1+cu121
```


It is possible to switch to a runtime that has GPU acceleration enabled (on Google Colab: `Runtime > Change runtime type > Hardware accelerator: GPU > Save`). Note, however, that Google Colab is not always able to offer GPU acceleration. If you see an error related to GPU availability in one of the following sections, consider switching back to CPU-based execution by setting `DEVICE = torch.device("cpu")`. If the runtime has GPU acceleration enabled, you should see the output `Training on cuda`; otherwise, it will say `Training on cpu`.

### Data loading

Let's now load the CIFAR-10 training and test sets, partition the training set into ten smaller datasets (each split into a training and a validation set), and wrap each of them in its own `DataLoader`. We introduce a new parameter `num_partitions`, which allows us to call `load_datasets` with different numbers of partitions.

```python
NUM_PARTITIONS = 10
BATCH_SIZE = 32


def load_datasets(partition_id: int, num_partitions: int):
    fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
    partition = fds.load_partition(partition_id)
    # Divide data on each node: 80% train, 20% test
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
    pytorch_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        # Instead of passing transforms to CIFAR10(..., transform=transform)
        # we will use this function to dataset.with_transform(apply_transforms)
        # The transforms object is exactly the same
        batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
        return batch

    partition_train_test = partition_train_test.with_transform(apply_transforms)
    trainloader = DataLoader(
        partition_train_test["train"], batch_size=BATCH_SIZE, shuffle=True
    )
    valloader = DataLoader(partition_train_test["test"], batch_size=BATCH_SIZE)
    testset = fds.load_split("test").with_transform(apply_transforms)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE)
    return trainloader, valloader, testloader
```
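For a sense of the sizes involved: CIFAR-10 has 50,000 training images, so with `NUM_PARTITIONS = 10` each partition holds 5,000 images, and the 80/20 split leaves 4,000 for training and 1,000 for validation on each node (the full 10,000-image test split is loaded by everyone). The plain-Python sketch below reproduces that arithmetic, assuming IID partitioning, which is what an integer value in `partitioners` selects by default. `iid_partition` and `train_val_split` are hypothetical helpers that only illustrate the resulting sizes; they are not how Flower Datasets shuffles or splits internally.

```python
import random
from typing import List, Tuple


def iid_partition(num_examples: int, num_partitions: int, seed: int = 42) -> List[List[int]]:
    """Shuffle example indices and deal them round-robin into IID shards.

    Hypothetical helper for illustration; Flower Datasets has its own implementation.
    """
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)
    # Dealing with a stride keeps shard sizes within one example of each other
    return [indices[i::num_partitions] for i in range(num_partitions)]


def train_val_split(indices: List[int], test_size: float = 0.2) -> Tuple[List[int], List[int]]:
    """Split one shard into a training part and a validation part (80/20 by default)."""
    num_val = round(len(indices) * test_size)
    return indices[num_val:], indices[:num_val]


# CIFAR-10 has 50,000 training images; NUM_PARTITIONS = 10 in this notebook
shards = iid_partition(50_000, 10)
train_idx, val_idx = train_val_split(shards[0])
print(len(shards[0]), len(train_idx), len(val_idx))  # 5000 4000 1000
```

The 4,000 training images per node also explain the accuracy values in the training logs further down, which are multiples of 1/4000.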

### Model training/evaluation

Let's continue with the usual model definition (including `set_parameters` and `get_parameters`), training and test functions:

```python
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)  # Reuse the forward pass from above
            loss.backward()
            optimizer.step()
            # Metrics (.item() keeps a Python float instead of a tensor with autograd history)
            epoch_loss += loss.item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy
```
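The only non-obvious number in `Net` is the input size of `fc1` (and the matching `x.view(-1, 16 * 5 * 5)`). It follows from tracing a 32x32 CIFAR-10 image through the layers with the usual output-size formula for convolution and pooling layers, floor((size + 2 * padding - kernel) / stride) + 1. The sketch below does that bookkeeping in plain Python; the helper names are our own and not part of PyTorch.

```python
def conv_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a square convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def net_flattened_features(image_size: int = 32, out_channels: int = 16) -> int:
    """Trace one CIFAR-10 image through Net's conv/pool stack and count fc1's inputs."""
    size = conv_output_size(image_size, kernel=5)       # conv1: 32 -> 28
    size = conv_output_size(size, kernel=2, stride=2)   # pool:  28 -> 14
    size = conv_output_size(size, kernel=5)             # conv2: 14 -> 10
    size = conv_output_size(size, kernel=2, stride=2)   # pool:  10 -> 5
    return out_channels * size * size                   # 16 channels * 5 * 5


print(net_flattened_features())  # 400, i.e. 16 * 5 * 5
```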

### Flower client

To implement the Flower client, we (again) create a subclass of `flwr.client.NumPyClient` and implement the three methods `get_parameters`, `fit`, and `evaluate`. Here, we also pass the `partition_id` to the client and use it to log additional details. We then create an instance of `ClientApp` and pass it the `client_fn`.

```python
class FlowerClient(NumPyClient):
    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.partition_id}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.partition_id}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.partition_id}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    net = Net().to(DEVICE)

    # Read the node_config to fetch data partition associated to this node
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]

    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerClient(partition_id, net, trainloader, valloader).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)
```
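`fit` returns a tuple of the updated parameters, a count, and a metrics dictionary, and FedAvg uses the count as the weight of each client's update when it averages them. Note that `len(self.trainloader)` counts batches (125 batches of 32 for the 4,000 training images per partition), not examples. Because every partition has the same size here, the weights stay proportional and the average does not change. The sketch below shows this with plain Python lists instead of NumPy arrays; `fedavg_sketch` is a hypothetical helper that mirrors the idea of weighted averaging, not Flower's implementation.

```python
from typing import List, Tuple

# One client result: (flattened parameters, weight reported as num_examples by fit)
ClientResult = Tuple[List[float], int]


def fedavg_sketch(results: List[ClientResult]) -> List[float]:
    """Average parameters across clients, weighting each client by its reported count."""
    total = sum(weight for _, weight in results)
    num_params = len(results[0][0])
    return [
        sum(params[i] * weight for params, weight in results) / total
        for i in range(num_params)
    ]


# Three hypothetical clients, each with 4,000 training images (125 batches of 32)
client_params = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
by_examples = fedavg_sketch([(p, 4000) for p in client_params])
by_batches = fedavg_sketch([(p, 4000 // 32) for p in client_params])
print(by_examples, by_batches)  # [2.0, 3.0] [2.0, 3.0]
```

With unequal partition sizes the two choices would diverge slightly, because the last, partially filled batch is counted as a full one.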

## Strategy customization

So far, everything should look familiar if you've worked through the introductory notebook. With that, we're ready to introduce a number of new features.

### Server-side parameter **initialization**

Flower, by default, initializes the global model by asking one random client for the initial parameters. In many cases, though, we want more control over parameter initialization. Flower therefore allows you to pass the initial parameters directly to the strategy. We create an instance of `Net()` and get the parameters as follows:

```python
# Create an instance of the model and get the parameters
params = get_parameters(Net())
```

Next, we create a `server_fn` that returns the components needed for the server. Within `server_fn`, we create a Strategy that uses the initial parameters.

```python
def server_fn(context: Context) -> ServerAppComponents:
    # Create FedAvg strategy
    strategy = FedAvg(
        fraction_fit=0.3,
        fraction_evaluate=0.3,
        min_fit_clients=3,
        min_evaluate_clients=3,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(params),  # Pass initial model parameters
    )

    # Configure the server for 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(strategy=strategy, config=config)
```

Passing `initial_parameters` to the `FedAvg` strategy prevents Flower from asking one of the clients for the initial parameters. In `server_fn`, we pass this new `strategy` and a `ServerConfig` for defining the number of federated learning rounds (`num_rounds`).
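With these settings, the number of clients sampled per round follows from two parameters: a fraction of the available clients, but never fewer than the configured minimum. The helper below is a hypothetical sketch of that rule, assumed to mirror FedAvg's arithmetic (the fraction is truncated to an integer); it reproduces the `sampled 3 clients (out of 10)` lines in the logs. `min_available_clients=NUM_PARTITIONS` should additionally make the server wait until all ten clients are available before it samples.

```python
def num_sampled_clients(num_available: int, fraction: float, min_clients: int) -> int:
    """Clients sampled per round: a fraction of those available, but at least min_clients.

    Sketch of the rule, assumed to mirror FedAvg's sampling arithmetic.
    """
    return max(int(num_available * fraction), min_clients)


# fraction_fit=0.3, min_fit_clients=3 with NUM_PARTITIONS=10 available clients
print(num_sampled_clients(10, 0.3, 3))   # 3, matching "sampled 3 clients (out of 10)"
print(num_sampled_clients(10, 0.1, 3))   # 3: the minimum wins
print(num_sampled_clients(100, 0.5, 3))  # 50
```

The same rule applies to evaluation, with `fraction_evaluate` and `min_evaluate_clients` in place of the fit parameters.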

Similar to the `ClientApp`, we now create the `ServerApp` using the `server_fn`:

```python
# Create ServerApp
server = ServerApp(server_fn=server_fn)
```

```python
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
```

```text
Using device: cuda
```


```python
# Specify the resources each client needs: 2 CPUs, plus 1 full GPU when CUDA is available
backend_config = {
    "client_resources": {
        "num_cpus": 2,
        "num_gpus": 1.0 if DEVICE.type == "cuda" else 0.0,
    }
}

# Run the simulation with the resource allocation above
try:
    run_simulation(
        server_app=server,
        client_app=client,
        num_supernodes=NUM_PARTITIONS,
        backend_config=backend_config,
    )
except KeyError as e:
    print(f"KeyError encountered: {e}. Please check your backend configuration.")
except RuntimeError as e:
    print(f"RuntimeError encountered: {e}. The simulation engine might have crashed.")
```

```text
INFO :      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
INFO :
INFO :      [INIT]
INFO :      Using initial global parameters provided by strategy
INFO :      Starting evaluation of initial global parameters
INFO :      Evaluation returned no results (`None`)
INFO :
INFO :      [ROUND 1]
INFO :      configure_fit: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=4502) [Client 0] fit, config: {}
(ClientAppActor pid=4502) Epoch 1: train loss 0.06487863510847092, accuracy 0.221
(ClientAppActor pid=4502) [Client 3] fit, config: {}
(ClientAppActor pid=4502) Epoch 1: train loss 0.06423641741275787, accuracy 0.2325
(ClientAppActor pid=4502) [Client 4] fit, config: {}
INFO :      aggregate_fit: received 3 results and 0 failures
INFO :      configure_evaluate: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=4502) Epoch 1: train loss 0.06505753099918365, accuracy 0.2245
(ClientAppActor pid=4502) [Client 2] evaluate, config: {}
(ClientAppActor pid=4502) [Client 6] evaluate, config: {}
(ClientAppActor pid=4502) [Client 7] evaluate, config: {}
INFO :      aggregate_evaluate: received 3 results and 0 failures
INFO :
INFO :      [ROUND 2]
INFO :      configure_fit: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=4502) [Client 2] fit, config: {}
(ClientAppActor pid=4502) Epoch 1: train loss 0.05622417479753494, accuracy 0.329
(ClientAppActor pid=4502) [Client 3] fit, config: {}
(ClientAppActor pid=4502) Epoch 1: train loss 0.05519891157746315, accuracy 0.3375
(ClientAppActor pid=4502) [Client 7] fit, config: {}
INFO :      aggregate_fit: received 3 results and 0 failures
INFO :      configure_evaluate: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=4502) Epoch 1: train loss 0.05661184713244438, accuracy 0.32525
(ClientAppActor pid=4502) [Client 2] evaluate, config: {}
(ClientAppActor pid=4502) [Client 5] evaluate, config: {}
(ClientAppActor pid=4502) [Client 8] evaluate, config: {}
INFO :      aggregate_evaluate: received 3 results and 0 failures
INFO :
INFO :      [ROUND 3]
INFO :      configure_fit: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=4502) [Client 4] fit, config: {}
(ClientAppActor pid=4502) Epoch 1: train loss 0.0533718504011631, accuracy 0.3635
(ClientAppActor pid=4502) [Client 7] fit, config: {}
(ClientAppActor pid=4502) Epoch 1: train loss 0.05301301181316376, accuracy 0.37
(ClientAppActor pid=4502) [Client 8] fit, config: {}
INFO :      aggregate_fit: received 3 results and 0 failures
INFO :      configure_evaluate: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=4502) Epoch 1: train loss 0.052811428904533386, accuracy 0.366
(ClientAppActor pid=4502) [Client 5] evaluate, config: {}
(ClientAppActor pid=4502) [Client 6] evaluate, config: {}
(ClientAppActor pid=4502) [Client 7] evaluate, config: {}
INFO :      aggregate_evaluate: received 3 results and 0 failures
INFO :
INFO :      [SUMMARY]
INFO :      Run finished 3 round(s) in 126.49s
INFO :          History (loss, distributed):
INFO :                  round 1: 0.06285407102108002
INFO :                  round 2: 0.05398792048295339
INFO :                  round 3: 0.05214744730790457
INFO :
```


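In the cells above, we specified the resources for each client and ran the simulation.

If we look closely at the logs, we can see that they do not show any calls to the `FlowerClient.get_parameters` method: because the strategy already holds the initial parameters, the server never has to ask a client for them.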
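`backend_config` asks for 2 CPUs per client and, on a GPU runtime, one whole GPU. Since a client can only start once its requested resources are free, these numbers cap how many clients train in parallel; with a single GPU and `num_gpus: 1.0`, clients run one after another, which fits the logs above, where every client runs in the same `ClientAppActor` process. The arithmetic is sketched below under that assumption. `max_concurrent_clients` is a hypothetical helper, and the actual scheduling is done by the simulation backend, so treat its result as an upper bound rather than a guarantee.

```python
import math


def max_concurrent_clients(total_cpus: float, total_gpus: float,
                           num_cpus: float, num_gpus: float) -> int:
    """Upper bound on simultaneously running clients for given per-client resources.

    Assumes every client reserves num_cpus (> 0) and num_gpus while it runs, and a
    new client can only start when enough of each resource is still free.
    """
    limits = [math.floor(total_cpus / num_cpus)]
    if num_gpus > 0:
        limits.append(math.floor(total_gpus / num_gpus))
    return min(limits)


# Hypothetical machines
print(max_concurrent_clients(total_cpus=2, total_gpus=1, num_cpus=2, num_gpus=1.0))   # 1
print(max_concurrent_clients(total_cpus=8, total_gpus=1, num_cpus=2, num_gpus=0.25))  # 4
```

Requesting a fraction of a GPU, as in the second call, lets several clients share one device, provided their models fit into its memory together.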

### Starting with a customized strategy

We've seen the function `run_simulation` before. It accepts a number of arguments, among them `server_app`, which wraps the strategy and the number of training rounds; `client_app`, which wraps the `client_fn` used to create `FlowerClient` instances; and `num_supernodes`, the number of clients to simulate.

The strategy encapsulates the federated learning approach/algorithm, for example, `FedAvg` or `FedAdagrad`. Let's try to use a different strategy this time:

```python
def server_fn(context: Context) -> ServerAppComponents:
    # Create FedAdagrad strategy
    strategy = FedAdagrad(
        fraction_fit=0.3,
        fraction_evaluate=0.3,
        min_fit_clients=3,
        min_evaluate_clients=3,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(params),
    )
    # Configure the server for 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)
```

```text
INFO :      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
INFO :
INFO :      [INIT]
INFO :      Using initial global parameters provided by strategy
INFO :      Starting evaluation of initial global parameters
INFO :      Evaluation returned no results (`None`)
INFO :
INFO :      [ROUND 1]
INFO :      configure_fit: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=5667) [Client 3] fit, config: {}
(ClientAppActor pid=5667) Epoch 1: train loss 0.06357608735561371, accuracy 0.23525
(ClientAppActor pid=5667) [Client 5] fit, config: {}
(ClientAppActor pid=5667) Epoch 1: train loss 0.06433352828025818, accuracy 0.225
(ClientAppActor pid=5667) [Client 8] fit, config: {}
INFO :      aggregate_fit: received 3 results and 0 failures
INFO :      configure_evaluate: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=5667) Epoch 1: train loss 0.06448725610971451, accuracy 0.22375
(ClientAppActor pid=5667) [Client 2] evaluate, config: {}
(ClientAppActor pid=5667) [Client 6] evaluate, config: {}
(ClientAppActor pid=5667) [Client 7] evaluate, config: {}
INFO :      aggregate_evaluate: received 3 results and 0 failures
INFO :
INFO :      [ROUND 2]
INFO :      configure_fit: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=5667) [Client 6] fit, config: {}
(ClientAppActor pid=5667) Epoch 1: train loss 0.8559237122535706, accuracy 0.30575
(ClientAppActor pid=5667) [Client 7] fit, config: {}
(ClientAppActor pid=5667) Epoch 1: train loss 0.8274834156036377, accuracy 0.2915
(ClientAppActor pid=5667) [Client 9] fit, config: {}
INFO :      aggregate_fit: received 3 results and 0 failures
INFO :      configure_evaluate: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=5667) Epoch 1: train loss 0.8296621441841125, accuracy 0.30875
(ClientAppActor pid=5667) [Client 4] evaluate, config: {}
(ClientAppActor pid=5667) [Client 6] evaluate, config: {}
(ClientAppActor pid=5667) [Client 9] evaluate, config: {}
INFO :      aggregate_evaluate: received 3 results and 0 failures
INFO :
INFO :      [ROUND 3]
INFO :      configure_fit: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=5667) [Client 4] fit, config: {}
(ClientAppActor pid=5667) Epoch 1: train loss 0.09926921129226685, accuracy 0.14575
(ClientAppActor pid=5667) [Client 7] fit, config: {}
(ClientAppActor pid=5667) Epoch 1: train loss 0.0998692661523819, accuracy 0.1405
(ClientAppActor pid=5667) [Client 8] fit, config: {}
INFO :      aggregate_fit: received 3 results and 0 failures
INFO :      configure_evaluate: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=5667) Epoch 1: train loss 0.09942536056041718, accuracy 0.13775
(ClientAppActor pid=5667) [Client 4] evaluate, config: {}
(ClientAppActor pid=5667) [Client 5] evaluate, config: {}
(ClientAppActor pid=5667) [Client 9] evaluate, config: {}
INFO :      aggregate_evaluate: received 3 results and 0 failures
INFO :
INFO :      [SUMMARY]
INFO :      Run finished 3 round(s) in 119.36s
INFO :          History (loss, distributed):
INFO :                  round 1: 5.222560839335124
INFO :                  round 2: 0.475978541692098
INFO :                  round 3: 0.13511009939511617
INFO :
```

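FedAdagrad first averages the client models like FedAvg, but then treats the difference between that average and the current global model as a pseudo-gradient and applies an Adagrad-style server update (Reddi et al., *Adaptive Federated Optimization*). The pure-Python sketch below assumes a server learning rate `eta=0.1`, `tau=1e-9`, and no momentum term; these values are illustrative, not necessarily Flower's defaults. It shows one property of the update: because each step is normalized by the accumulated squared updates, the first round moves every weight by roughly `eta`, however small the clients' change was. That is worth keeping in mind when comparing the distributed losses of the FedAvg and FedAdagrad runs, and it is one reason the server-side hyperparameters of adaptive strategies often need tuning.

```python
import math
from typing import List, Tuple

ClientUpdate = Tuple[List[float], int]  # (client parameters, num_examples)


def fedavg_average(results: List[ClientUpdate]) -> List[float]:
    """FedAvg step: weighted average of client parameters."""
    total = sum(n for _, n in results)
    return [sum(p[i] * n for p, n in results) / total for i in range(len(results[0][0]))]


class AdagradServerSketch:
    """Simplified FedAdagrad-style server update without a momentum term.

    eta (server learning rate) and tau (adaptivity constant) are assumed values
    for illustration; check the FedAdagrad API reference for Flower's defaults.
    """

    def __init__(self, weights: List[float], eta: float = 0.1, tau: float = 1e-9) -> None:
        self.weights = list(weights)
        self.eta = eta
        self.tau = tau
        self.v = [0.0] * len(self.weights)  # running sum of squared pseudo-gradients

    def step(self, results: List[ClientUpdate]) -> List[float]:
        averaged = fedavg_average(results)
        # Pseudo-gradient: how far the averaged client model moved from the global model
        delta = [a - w for a, w in zip(averaged, self.weights)]
        self.v = [v + d * d for v, d in zip(self.v, delta)]
        # Adagrad-style step: normalize each coordinate by its accumulated update size
        self.weights = [
            w + self.eta * d / (math.sqrt(v) + self.tau)
            for w, d, v in zip(self.weights, delta, self.v)
        ]
        return self.weights


# One weight with global value 0.0; the only client nudges it to 0.001
opt = AdagradServerSketch([0.0])
fedavg_next = fedavg_average([([0.001], 1)])
adagrad_next = opt.step([([0.001], 1)])
print(fedavg_next, adagrad_next)  # FedAvg moves by 0.001, the Adagrad-style step by about 0.1

# A second identical nudge produces a smaller step, because v keeps growing
second = opt.step([([adagrad_next[0] + 0.001], 1)])
print(second[0] - adagrad_next[0])  # about 0.1 / sqrt(2)
```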

### Server-side parameter **evaluation**

Flower can evaluate the aggregated model on the server-side or on the client-side. Client-side and server-side evaluation are similar in some ways, but different in others.

**Centralized Evaluation** (or *server-side evaluation*) is conceptually simple: it works the same way that evaluation in centralized machine learning does. If there is a server-side dataset that can be used for evaluation purposes, then that's great. We can evaluate the newly aggregated model after each round of training without having to send the model to clients. We're also fortunate in the sense that our entire evaluation dataset is available at all times.

**Federated Evaluation** (or *client-side evaluation*) is more complex, but also more powerful: it doesn't require a centralized dataset and allows us to evaluate models over a larger set of data, which often yields more realistic evaluation results. In fact, many scenarios require us to use **Federated Evaluation** if we want to get representative evaluation results at all. But this power comes at a cost: once we start to evaluate on the client side, our evaluation dataset can change over consecutive rounds of learning if those clients are not always available. Moreover, the dataset held by each client can also change over consecutive rounds. This can lead to unstable evaluation results: even if the model did not change, we would see our evaluation results fluctuate over consecutive rounds.
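To see how client sampling alone can make federated evaluation fluctuate, the sketch below fixes a single model, assigns made-up validation accuracies to the ten clients, and enumerates every way a strategy could sample three of them for evaluation (as `fraction_evaluate=0.3` does in this notebook). The accuracies are hypothetical; only the spread between samples matters. Because all validation sets have the same size here, the weighted average reduces to a plain mean.

```python
from itertools import combinations
from statistics import mean

# Hypothetical per-client validation accuracies of ONE fixed model
client_accuracy = {0: 0.30, 1: 0.34, 2: 0.36, 3: 0.38, 4: 0.40,
                   5: 0.41, 6: 0.42, 7: 0.44, 8: 0.46, 9: 0.50}

# Accuracy a round would report for every possible sample of 3 out of 10 clients
round_results = [mean(client_accuracy[c] for c in sample)
                 for sample in combinations(client_accuracy, 3)]

print(f"possible reported accuracies: {min(round_results):.3f} .. {max(round_results):.3f}")
print(f"accuracy over all clients:    {mean(client_accuracy.values()):.3f}")
```

The model never changes, yet depending on which three clients are sampled, a round could report anything in that range.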

We've seen how federated evaluation works on the client side (i.e., by implementing the `evaluate` method in `FlowerClient`). Now let's see how we can evaluate aggregated model parameters on the server-side:

```python
# The `evaluate` function will be called by Flower after every round
def evaluate(
    server_round: int,
    parameters: NDArrays,
    config: Dict[str, Scalar],
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
    net = Net().to(DEVICE)
    _, _, testloader = load_datasets(0, NUM_PARTITIONS)
    set_parameters(net, parameters)  # Update model with the latest parameters
    loss, accuracy = test(net, testloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    return loss, {"accuracy": accuracy}
```
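`evaluate` calls `load_datasets` every time it runs, so the `FederatedDataset` and the test `DataLoader` are rebuilt for each evaluation (four times for three rounds plus the evaluation of the initial parameters). If that becomes slow, one option is to build the test loader once and reuse it, for example with `functools.lru_cache`; in this notebook, the cached function would wrap `load_datasets(0, NUM_PARTITIONS)[2]`. The demo below replaces the data loading with a placeholder and a call counter (both hypothetical) to show that the body of the cached function runs only once.

```python
from functools import lru_cache

build_calls = 0  # counts how often the (expensive) loading code actually runs


@lru_cache(maxsize=None)
def cached_test_data(num_partitions: int):
    """Hypothetical stand-in for load_datasets(0, num_partitions)[2]."""
    global build_calls
    build_calls += 1
    return list(range(10))  # placeholder for the real test DataLoader


# evaluate() runs once for the initial parameters (round 0) and once per training round
for _round in range(4):
    data = cached_test_data(10)

print(build_calls)  # 1
```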

We create a `FedAvg` strategy and pass `evaluate_fn` to it. Then, we create a `ServerApp` that uses this strategy.

```python
def server_fn(context: Context) -> ServerAppComponents:
    # Create the FedAvg strategy
    strategy = FedAvg(
        fraction_fit=0.3,
        fraction_evaluate=0.3,
        min_fit_clients=3,
        min_evaluate_clients=3,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(params),
        evaluate_fn=evaluate,  # Pass the evaluation function
    )
    # Configure the server for 3 rounds of training
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)
```

Finally, we run the simulation.

```python
# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)
```

```text
INFO :      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
INFO :
INFO :      [INIT]
INFO :      Using initial global parameters provided by strategy
INFO :      Starting evaluation of initial global parameters
Server-side evaluation loss 0.0721279774427414 / accuracy 0.1
(ClientAppActor pid=6722) [Client 4] fit, config: {}
(ClientAppActor pid=6722) Epoch 1: train loss 0.06411401182413101, accuracy 0.2485
(ClientAppActor pid=6722) [Client 7] fit, config: {}
(ClientAppActor pid=6722) Epoch 1: train loss 0.06486368179321289, accuracy 0.21375
(ClientAppActor pid=6722) [Client 9] fit, config: {}
INFO :      aggregate_fit: received 3 results and 0 failures
(ClientAppActor pid=6722) Epoch 1: train loss 0.06465820223093033, accuracy 0.23275
INFO :      fit progress: (1, 0.06082897651195526, {'accuracy': 0.2771}, 36.675614015000065)
INFO :      configure_evaluate: strategy sampled 3 clients (out of 10)
Server-side evaluation loss 0.06082897651195526 / accuracy 0.2771
(ClientAppActor pid=6722) [Client 3] evaluate, config: {}
(ClientAppActor pid=6722) [Client 4] evaluate, config: {}
(ClientAppActor pid=6722) [Client 5] evaluate, config: {}
INFO :      aggregate_evaluate: received 3 results and 0 failures
INFO :
INFO :      [ROUND 2]
INFO :      configure_fit: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=6722) [Client 0] fit, config: {}
(ClientAppActor pid=6722) Epoch 1: train loss 0.05637635290622711, accuracy 0.324
(ClientAppActor pid=6722) [Client 7] fit, config: {}
(ClientAppActor pid=6722) Epoch 1: train loss 0.05631326138973236, accuracy 0.32
(ClientAppActor pid=6722) [Client 8] fit, config: {}
INFO :      aggregate_fit: received 3 results and 0 failures
(ClientAppActor pid=6722) Epoch 1: train loss 0.05606072023510933, accuracy 0.32475
INFO :      fit progress: (2, 0.05305157095193863, {'accuracy': 0.3612}, 79.23173681499998)
INFO :      configure_evaluate: strategy sampled 3 clients (out of 10)
Server-side evaluation loss 0.05305157095193863 / accuracy 0.3612
(ClientAppActor pid=6722) [Client 2] evaluate, config: {}
(ClientAppActor pid=6722) [Client 4] evaluate, config: {}
(ClientAppActor pid=6722) [Client 6] evaluate, config: {}
INFO :      aggregate_evaluate: received 3 results and 0 failures
INFO :
INFO :      [ROUND 3]
INFO :      configure_fit: strategy sampled 3 clients (out of 10)
(ClientAppActor pid=6722) [Client 7] fit, config: {}
(ClientAppActor pid=6722) Epoch 1: train loss 0.052959784865379333, accuracy 0.36175
(ClientAppActor pid=6722) [Client 8] fit, config: {}
(ClientAppActor pid=6722) Epoch 1: train loss 0.05251153185963631, accuracy 0.36775
(ClientAppActor pid=6722) [Client 9] fit, config: {}
INFO :      aggregate_fit: received 3 results and 0 failures
(ClientAppActor pid=6722) Epoch 1: train loss 0.052268996834754944, accuracy 0.38525
INFO :      fit progress: (3, 0.050012921643257144, {'accuracy': 0.4091}, 121.37425202000009)
INFO :      configure_evaluate: strategy sampled 3 clients (out of 10)
Server-side evaluation loss 0.050012921643257144 / accuracy 0.4091
(ClientAppActor pid=6722) [Client 3] evaluate, config: {}
(ClientAppActor pid=6722) [Client 5] evaluate, config: {}
(ClientAppActor pid=6722) [Client 7] evaluate, config: {}
INFO :      aggregate_evaluate: received 3 results and 0 failures
INFO :
INFO :      [SUMMARY]
INFO :      Run finished 3 round(s) in 137.21s
INFO :          History (loss, distributed):
INFO :                  round 1: 0.062190063397089636
INFO :                  round 2: 0.054067747712135315
INFO :                  round 3: 0.0519946782986323
INFO :          History (loss, centralized):
INFO :                  round 0: 0.0721279774427414
INFO :                  round 1: 0.06082897651195526
INFO :                  round 2: 0.05305157095193863
INFO :                  round 3: 0.050012921643257144
INFO :          History (metrics, centralized):
INFO :          {'accuracy': [(0, 0.1), (1, 0.2771), (2, 0.3612), (3, 0.4091)]}
INFO :
```

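The centralized history makes the scale of the reported losses easier to interpret. At round 0 the model is untrained, and its accuracy of 0.1 is chance level for CIFAR-10's ten classes. A model that assigns equal probability to all classes has a cross-entropy of ln(10) ≈ 2.30 per example, yet `test` reports about 0.072. The difference comes from how `test` normalizes: it adds up the *mean* loss of each batch and then divides by the number of examples, so the result is roughly the per-example loss divided by the batch size. The arithmetic below checks this, under the assumption that the untrained model's predictions are close to uniform.

```python
import math

num_classes = 10
num_test_examples = 10_000   # CIFAR-10 test split
batch_size = 32              # BATCH_SIZE in this notebook

# A uniform prediction over 10 classes costs ln(10) cross-entropy per example
per_example_loss = math.log(num_classes)

# test() sums one batch-mean loss per batch and divides by the number of examples
num_batches = math.ceil(num_test_examples / batch_size)  # 313 batches, the last one has 16 images
reported_loss = per_example_loss * num_batches / num_test_examples

print(f"{per_example_loss:.4f} per example -> {reported_loss:.4f} as reported by test()")
```

The estimate lands very close to the logged round-0 value of 0.0721, which supports reading all losses in this notebook as per-example losses scaled down by roughly the batch size.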

### Sending/receiving arbitrary values to/from clients

In some situations, we want to configure client-side execution (training, evaluation) from the server side. One example is the server asking the clients to train for a certain number of local epochs. Flower provides a way to send configuration values from the server to the clients using a dictionary. Let's look at an example where the clients receive values from the server through the `config` parameter in `fit` (`config` is also available in `evaluate`). The `fit` method reads `server_round` and `local_epochs` from this dictionary and uses those values to improve the logging and to configure the number of local training epochs:

In [18]:
class FlowerClient(NumPyClient):
    def __init__(self, pid, net, trainloader, valloader):
        self.pid = pid  # partition ID of a client
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.pid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        # Read values from config
        server_round = config["server_round"]
        local_epochs = config["local_epochs"]

        # Use values provided by the config
        print(f"[Client {self.pid}, round {server_round}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=local_epochs)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.pid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    net = Net().to(DEVICE)
    partition_id = context.node_config["partition-id"]
    num_partitions = context.node_config["num-partitions"]
    trainloader, valloader, _ = load_datasets(partition_id, num_partitions)
    return FlowerClient(partition_id, net, trainloader, valloader).to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)
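The `fit` method above indexes `config` directly, so it raises a `KeyError` if the server does not send `server_round` or `local_epochs`. If you want clients that still run in that case, one option is to read each key with a fallback value. The helper below is a hypothetical sketch: the name `read_fit_config`, the defaults, and the `Scalar` alias (a local stand-in for the value types we assume Flower allows in config dictionaries) are our own, not part of Flower.

```python
from typing import Dict, Tuple, Union

# Local stand-in (assumption) for the value types a config dictionary may hold
Scalar = Union[bool, bytes, float, int, str]


def read_fit_config(config: Dict[str, Scalar]) -> Tuple[int, int]:
    """Hypothetical helper: extract the round number and local epochs with fallbacks."""
    # Fall back to round 0 if the server did not send the current round
    server_round = int(config.get("server_round", 0))
    # Fall back to a single local epoch if no value was provided
    local_epochs = int(config.get("local_epochs", 1))
    if local_epochs < 1:
        raise ValueError(f"local_epochs must be >= 1, got {local_epochs}")
    return server_round, local_epochs


print(read_fit_config({"server_round": 2, "local_epochs": 2}))  # (2, 2)
print(read_fit_config({}))  # (0, 1)
```

Inside `fit`, the two lines that read from `config` could then be replaced by `server_round, local_epochs = read_fit_config(config)`.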

So how can we send this config dictionary from the server to the clients? The built-in Flower strategies provide a way to do this, and it works similarly to server-side evaluation: we provide a function to the strategy, and the strategy calls this function in every round of federated learning:

In [19]:
def fit_config(server_round: int):
    """Return training configuration dict for each round.

    Perform one round of training with one local epoch, then increase to two
    local epochs from the second round onwards.
    """
    config = {
        "server_round": server_round,  # The current round of federated learning
        "local_epochs": 1 if server_round < 2 else 2,
    }
    return config
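To see which values each round produces, the sketch below calls a copy of `fit_config` the way a strategy would: once per round, handing the resulting dictionary to every client sampled in that round. The function `simulate_config_schedule` and the client IDs are illustrative stand-ins; Flower's real `configure_fit` additionally samples the clients and packages the parameters and config together for each of them.

```python
from typing import Dict, List


def fit_config(server_round: int) -> Dict[str, int]:
    """Copy of the function above: one local epoch in round 1, two afterwards."""
    return {
        "server_round": server_round,
        "local_epochs": 1 if server_round < 2 else 2,
    }


def simulate_config_schedule(
    num_rounds: int, sampled_clients: List[int]
) -> Dict[int, Dict[int, Dict[str, int]]]:
    """Hypothetical stand-in for a strategy: one config per round, sent to every sampled client."""
    schedule: Dict[int, Dict[int, Dict[str, int]]] = {}
    for server_round in range(1, num_rounds + 1):
        config = fit_config(server_round)  # called once per round, like on_fit_config_fn
        schedule[server_round] = {cid: config for cid in sampled_clients}
    return schedule


schedule = simulate_config_schedule(num_rounds=3, sampled_clients=[2, 3, 7])
for server_round, per_client in schedule.items():
    print(server_round, per_client[2]["local_epochs"])
# 1 1
# 2 2
# 3 2
```

Because `server_round < 2` is only true in round 1, the second and third rounds both use two local epochs, which matches the client logs of the run below.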

Next, we'll pass this function to the FedAvg strategy before starting the simulation:

In [20]:
def server_fn(context: Context) -> ServerAppComponents:
    # Create FedAvg strategy
    strategy = FedAvg(
        fraction_fit=0.3,
        fraction_evaluate=0.3,
        min_fit_clients=3,
        min_evaluate_clients=3,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(params),
        evaluate_fn=evaluate,
        on_fit_config_fn=fit_config,  # Pass the fit_config function
    )
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[36m(pid=7970)[0m 2025-01-21 08:14:00.462211: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=7970)[0m 2025-01-21 08:14:00.485232: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=7970)[0m 2025-01-21 08:14:00.491977: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[92mINFO [0m:      ini

Server-side evaluation loss 0.0721279774427414 / accuracy 0.1


[36m(ClientAppActor pid=7970)[0m see the appropriate new directories, set the environment variable
[36m(ClientAppActor pid=7970)[0m `JUPYTER_PLATFORM_DIRS=1` and then run `jupyter --paths`.
[36m(ClientAppActor pid=7970)[0m The use of platformdirs will be the default in `jupyter_core` v6
[36m(ClientAppActor pid=7970)[0m   from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write


[36m(ClientAppActor pid=7970)[0m [Client 2, round 1] fit, config: {'server_round': 1, 'local_epochs': 1}
[36m(ClientAppActor pid=7970)[0m Epoch 1: train loss 0.06503771990537643, accuracy 0.22175
[36m(ClientAppActor pid=7970)[0m [Client 3, round 1] fit, config: {'server_round': 1, 'local_epochs': 1}
[36m(ClientAppActor pid=7970)[0m Epoch 1: train loss 0.06413747370243073, accuracy 0.231
[36m(ClientAppActor pid=7970)[0m [Client 7, round 1] fit, config: {'server_round': 1, 'local_epochs': 1}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures


[36m(ClientAppActor pid=7970)[0m Epoch 1: train loss 0.06427401304244995, accuracy 0.22425


[92mINFO [0m:      fit progress: (1, 0.0593475800037384, {'accuracy': 0.3005}, 35.01486138600012)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


Server-side evaluation loss 0.0593475800037384 / accuracy 0.3005
[36m(ClientAppActor pid=7970)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=7970)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=7970)[0m [Client 3] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=7970)[0m [Client 0, round 2] fit, config: {'server_round': 2, 'local_epochs': 2}
[36m(ClientAppActor pid=7970)[0m Epoch 1: train loss 0.0560138076543808, accuracy 0.335
[36m(ClientAppActor pid=7970)[0m Epoch 2: train loss 0.05282445624470711, accuracy 0.37125
[36m(ClientAppActor pid=7970)[0m [Client 6, round 2] fit, config: {'server_round': 2, 'local_epochs': 2}
[36m(ClientAppActor pid=7970)[0m Epoch 1: train loss 0.055216964334249496, accuracy 0.336
[36m(ClientAppActor pid=7970)[0m Epoch 2: train loss 0.051443662494421005, accuracy 0.38875
[36m(ClientAppActor pid=7970)[0m [Client 9, round 2] fit, config: {'server_round': 2, 'local_epochs': 2}
[36m(ClientAppActor pid=7970)[0m Epoch 1: train loss 0.05557616427540779, accuracy 0.3485


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures


[36m(ClientAppActor pid=7970)[0m Epoch 2: train loss 0.05158259719610214, accuracy 0.401


[92mINFO [0m:      fit progress: (2, 0.0511463111281395, {'accuracy': 0.4008}, 83.30867298400017)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


Server-side evaluation loss 0.0511463111281395 / accuracy 0.4008
[36m(ClientAppActor pid=7970)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=7970)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=7970)[0m [Client 5] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=7970)[0m [Client 0, round 3] fit, config: {'server_round': 3, 'local_epochs': 2}
[36m(ClientAppActor pid=7970)[0m Epoch 1: train loss 0.051159217953681946, accuracy 0.39825
[36m(ClientAppActor pid=7970)[0m Epoch 2: train loss 0.04912010207772255, accuracy 0.4185
[36m(ClientAppActor pid=7970)[0m [Client 8, round 3] fit, config: {'server_round': 3, 'local_epochs': 2}
[36m(ClientAppActor pid=7970)[0m Epoch 1: train loss 0.05124373361468315, accuracy 0.38875
[36m(ClientAppActor pid=7970)[0m Epoch 2: train loss 0.04875548928976059, accuracy 0.419
[36m(ClientAppActor pid=7970)[0m [Client 9, round 3] fit, config: {'server_round': 3, 'local_epochs': 2}
[36m(ClientAppActor pid=7970)[0m Epoch 1: train loss 0.05055130645632744, accuracy 0.408


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures


[36m(ClientAppActor pid=7970)[0m Epoch 2: train loss 0.04799516126513481, accuracy 0.44425


[92mINFO [0m:      fit progress: (3, 0.04826355242729187, {'accuracy': 0.4248}, 130.613215202)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


Server-side evaluation loss 0.04826355242729187 / accuracy 0.4248
[36m(ClientAppActor pid=7970)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=7970)[0m [Client 5] evaluate, config: {}
[36m(ClientAppActor pid=7970)[0m [Client 9] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 146.44s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.060336949507395425
[92mINFO [0m:      		round 2: 0.05257604813575745
[92mINFO [0m:      		round 3: 0.04941874321301778
[92mINFO [0m:      	History (loss, centralized):
[92mINFO [0m:      		round 0: 0.0721279774427414
[92mINFO [0m:      		round 1: 0.0593475800037384
[92mINFO [0m:      		round 2: 0.0511463111281395
[92mINFO [0m:      		round 3: 0.04826355242729187
[92mINFO [0m:      	History (metrics, centralized):
[92mINFO [0m:      	{'accuracy': [(0, 0.1), (1, 0.3005), (2, 0.4008), (3, 0.4248)]}
[92mINFO [0m:      


As we can see, the client logs now include the current round of federated learning (which they read from the `config` dictionary). The number of local epochs also follows the schedule defined in `fit_config`: clients train for one epoch in the first round and for two epochs in the second and third rounds.

Clients can also return arbitrary values to the server. To do so, they return a dictionary from `fit` and/or `evaluate`. We have seen and used this concept throughout this notebook without mentioning it explicitly: our `FlowerClient` returns a dictionary containing a custom key/value pair as the third return value in `evaluate`.
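On the server, the dictionaries returned by clients arrive as a list of `(num_examples, metrics)` pairs. `FedAvg` accepts an `evaluate_metrics_aggregation_fn` (and a `fit_metrics_aggregation_fn` for values returned from `fit`) to combine them into a single dictionary. A minimal sketch of such a function, computing an example-weighted average of every metric key, could look like this; the name `weighted_average` is our own, and the sketch assumes all metric values are numeric.

```python
from typing import Dict, List, Tuple


def weighted_average(metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    """Combine per-client metrics, weighting each client by its number of examples."""
    aggregated: Dict[str, float] = {}
    keys = {key for _, client_metrics in metrics for key in client_metrics}
    for key in sorted(keys):
        # Only clients that reported this key contribute to its average
        reporting = [(n, m[key]) for n, m in metrics if key in m]
        weight = sum(n for n, _ in reporting)
        if weight > 0:
            aggregated[key] = sum(n * value for n, value in reporting) / weight
    return aggregated


# Three clients reporting validation accuracy on different amounts of data
print(weighted_average([(10, {"accuracy": 0.5}), (30, {"accuracy": 0.7}), (10, {"accuracy": 0.3})]))
```

It could then be passed to the strategy as `FedAvg(..., evaluate_metrics_aggregation_fn=weighted_average)`, so that the accuracy values returned by `evaluate` are aggregated across clients each round.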

## Scaling federated learning

As a last step in this notebook, let's see how we can use Flower to experiment with a large number of clients.

In [21]:
NUM_PARTITIONS = 1000



Note that we can reuse the same `ClientApp` for a different number of partitions: the `num-partitions` value that `client_fn` reads from `context.node_config` is set by the `num_supernodes` argument of `run_simulation()`.

We now have 1000 partitions, each holding 40 training and 10 validation examples. Given that each client has so few training examples, we should probably train the model a bit longer, so we configure the clients to perform 3 local training epochs. We should also adjust the fraction of clients selected for training in each round (we don't want all 1000 clients participating in every round), so we set `fraction_fit` to `0.025`, which means that only 2.5% of the available clients (25 clients) will be selected for training each round:
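The number of clients sampled per round follows from the strategy settings. The sketch below assumes the rule `max(int(num_available * fraction), min_clients)`, which is how we understand `FedAvg` to size its samples (check the `FedAvg` source of your Flower version), and it assumes all 1000 clients are available when the round is configured. The function name `num_sampled_clients` is hypothetical.

```python
def num_sampled_clients(num_available: int, fraction: float, min_clients: int) -> int:
    """Assumed sampling rule: a fraction of the available clients, but never fewer than min_clients."""
    return max(int(num_available * fraction), min_clients)


NUM_PARTITIONS = 1000
print(num_sampled_clients(NUM_PARTITIONS, fraction=0.025, min_clients=20))  # training: 25
print(num_sampled_clients(NUM_PARTITIONS, fraction=0.05, min_clients=40))  # evaluation: 50
```

Under this rule the count depends on how many clients are available at that moment, so it can come out below 25 if not every simulated node is available yet; that may be why the first round of the run below sampled 21 clients.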


In [22]:
def fit_config(server_round: int):
    config = {
        "server_round": server_round,
        "local_epochs": 3,
    }
    return config


def server_fn(context: Context) -> ServerAppComponents:
    # Create FedAvg strategy
    strategy = FedAvg(
        fraction_fit=0.025,  # Train on 25 clients (each round)
        fraction_evaluate=0.05,  # Evaluate on 50 clients (each round)
        min_fit_clients=20,
        min_evaluate_clients=40,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(params),
        on_fit_config_fn=fit_config,
    )
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(strategy=strategy, config=config)


# Create the ServerApp
server = ServerApp(server_fn=server_fn)

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 21 clients (out of 1000)
[36m(pid=9048)[0m 2025-01-21 08:16:40.637993: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=9048)[0m 2025-01-21 08:16:40.665012: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=9048)[0m 2025-01-21 08:16:40.671447: E external/local_x

[36m(ClientAppActor pid=9048)[0m [Client 157, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=9048)[0m Epoch 1: train loss 0.11441856622695923, accuracy 0.1
[36m(ClientAppActor pid=9048)[0m Epoch 2: train loss 0.11407946795225143, accuracy 0.125
[36m(ClientAppActor pid=9048)[0m Epoch 3: train loss 0.11248277872800827, accuracy 0.3
[36m(ClientAppActor pid=9048)[0m [Client 216, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=9048)[0m Epoch 1: train loss 0.11592202633619308, accuracy 0.075
[36m(ClientAppActor pid=9048)[0m Epoch 2: train loss 0.11406165361404419, accuracy 0.075
[36m(ClientAppActor pid=9048)[0m Epoch 3: train loss 0.11426541954278946, accuracy 0.125
[36m(ClientAppActor pid=9048)[0m [Client 236, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=9048)[0m Epoch 1: train loss 0.11393915861845016, accuracy 0.175
[36m(ClientAppActor pid=9048)[0m Epoch

[92mINFO [0m:      aggregate_fit: received 21 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 50 clients (out of 1000)


[36m(ClientAppActor pid=9048)[0m [Client 207, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=9048)[0m Epoch 1: train loss 0.11543109267950058, accuracy 0.125
[36m(ClientAppActor pid=9048)[0m Epoch 2: train loss 0.11405813694000244, accuracy 0.125
[36m(ClientAppActor pid=9048)[0m Epoch 3: train loss 0.11320360004901886, accuracy 0.125
[36m(ClientAppActor pid=9048)[0m [Client 11] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 24] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 71] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 276] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 278] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 998] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 852] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 914] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 452] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 50 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 25 clients (out of 1000)


[36m(ClientAppActor pid=9048)[0m [Client 488] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 59, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=9048)[0m Epoch 1: train loss 0.11534950882196426, accuracy 0.1
[36m(ClientAppActor pid=9048)[0m Epoch 2: train loss 0.11319249123334885, accuracy 0.1
[36m(ClientAppActor pid=9048)[0m Epoch 3: train loss 0.11337079852819443, accuracy 0.15
[36m(ClientAppActor pid=9048)[0m [Client 94, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=9048)[0m Epoch 1: train loss 0.11570998281240463, accuracy 0.05
[36m(ClientAppActor pid=9048)[0m Epoch 2: train loss 0.1143534928560257, accuracy 0.2
[36m(ClientAppActor pid=9048)[0m Epoch 3: train loss 0.11291766166687012, accuracy 0.375
[36m(ClientAppActor pid=9048)[0m [Client 147, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=9048)[0m Epoch 1: train loss 0.113978125154

[92mINFO [0m:      aggregate_fit: received 25 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 50 clients (out of 1000)


[36m(ClientAppActor pid=9048)[0m [Client 870, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=9048)[0m Epoch 1: train loss 0.11510872840881348, accuracy 0.15
[36m(ClientAppActor pid=9048)[0m Epoch 2: train loss 0.11241376399993896, accuracy 0.15
[36m(ClientAppActor pid=9048)[0m Epoch 3: train loss 0.1119036003947258, accuracy 0.15
[36m(ClientAppActor pid=9048)[0m [Client 65] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 216] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 591] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 712] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 809] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 897] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 975] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 286] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 939] evaluate, config: {}
[

[92mINFO [0m:      aggregate_evaluate: received 50 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 25 clients (out of 1000)


[36m(ClientAppActor pid=9048)[0m [Client 10] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 326, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=9048)[0m Epoch 1: train loss 0.11473257839679718, accuracy 0.075
[36m(ClientAppActor pid=9048)[0m Epoch 2: train loss 0.11253738403320312, accuracy 0.175
[36m(ClientAppActor pid=9048)[0m Epoch 3: train loss 0.10651582479476929, accuracy 0.275
[36m(ClientAppActor pid=9048)[0m [Client 358, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=9048)[0m Epoch 1: train loss 0.11251883953809738, accuracy 0.175
[36m(ClientAppActor pid=9048)[0m Epoch 2: train loss 0.11026220768690109, accuracy 0.3
[36m(ClientAppActor pid=9048)[0m Epoch 3: train loss 0.11065860092639923, accuracy 0.3
[36m(ClientAppActor pid=9048)[0m [Client 537, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=9048)[0m Epoch 1: train loss 0.114365

[92mINFO [0m:      aggregate_fit: received 25 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 50 clients (out of 1000)


[36m(ClientAppActor pid=9048)[0m [Client 633, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=9048)[0m Epoch 1: train loss 0.11572638899087906, accuracy 0.1
[36m(ClientAppActor pid=9048)[0m Epoch 2: train loss 0.1115238219499588, accuracy 0.175
[36m(ClientAppActor pid=9048)[0m Epoch 3: train loss 0.11029378324747086, accuracy 0.275
[36m(ClientAppActor pid=9048)[0m [Client 92] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 167] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 341] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 616] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 866] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 985] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 266] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 426] evaluate, config: {}
[36m(ClientAppActor pid=9048)[0m [Client 38] evaluate, config: {}
[

[92mINFO [0m:      aggregate_evaluate: received 50 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 1134.62s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.22963387870788574
[92mINFO [0m:      		round 2: 0.22858363437652585
[92mINFO [0m:      		round 3: 0.22559631681442266
[92mINFO [0m:      


[36m(ClientAppActor pid=9048)[0m [Client 543] evaluate, config: {}


## Recap

In this notebook, we've seen how we can gradually enhance our system by customizing the strategy, initializing parameters on the server side, choosing a different strategy, and evaluating models on the server side. That's quite a bit of flexibility with so little code, right?

In the later sections, we've seen how we can communicate arbitrary values between server and clients to fully customize client-side execution. We also built a large-scale federated learning simulation using the Flower Virtual Client Engine and ran an experiment involving 1000 clients, all in a Jupyter Notebook!

## Next steps

Before you continue, make sure to join the Flower community on Flower Discuss ([Join Flower Discuss](https://discuss.flower.ai)) and on Slack ([Join Slack](https://flower.ai/join-slack/)).

There's a dedicated `#questions` channel if you need help, but we'd also love to hear who you are in `#introductions`!

The [Flower Federated Learning Tutorial - Part 3](https://flower.ai/docs/framework/tutorial-build-a-strategy-from-scratch-pytorch.html) shows how to build a fully custom `Strategy` from scratch.

In [23]:
# Specify the resources each of your clients need
# If not specified, each client is allocated 2 CPUs and 0 GPUs by default
backend_config = {"client_resources": {"num_cpus": 2, "num_gpus": 0}}  # Default 2 CPUs, no GPU
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_cpus": 2, "num_gpus": 1}}  # Use 1 GPU if CUDA is available

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 25 clients (out of 1000)
[36m(pid=15778)[0m 2025-01-21 08:35:56.082452: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=15778)[0m 2025-01-21 08:35:56.102290: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=15778)[0m 2025-01-21 08:35:56.112579: E external/loca

[36m(ClientAppActor pid=15778)[0m [Client 285, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=15778)[0m Epoch 1: train loss 0.115158312022686, accuracy 0.025
[36m(ClientAppActor pid=15778)[0m Epoch 2: train loss 0.11396624892950058, accuracy 0.175
[36m(ClientAppActor pid=15778)[0m Epoch 3: train loss 0.1135617271065712, accuracy 0.15
[36m(ClientAppActor pid=15778)[0m [Client 498, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=15778)[0m Epoch 1: train loss 0.115766242146492, accuracy 0.025
[36m(ClientAppActor pid=15778)[0m Epoch 2: train loss 0.11441149562597275, accuracy 0.05
[36m(ClientAppActor pid=15778)[0m Epoch 3: train loss 0.11389360576868057, accuracy 0.2
[36m(ClientAppActor pid=15778)[0m [Client 500, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=15778)[0m Epoch 1: train loss 0.11511470377445221, accuracy 0.1
[36m(ClientAppActor pid=15778)[0m E

[92mINFO [0m:      aggregate_fit: received 25 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 50 clients (out of 1000)


[36m(ClientAppActor pid=15778)[0m [Client 901, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=15778)[0m Epoch 1: train loss 0.1155947670340538, accuracy 0.1
[36m(ClientAppActor pid=15778)[0m Epoch 2: train loss 0.11471480131149292, accuracy 0.1
[36m(ClientAppActor pid=15778)[0m Epoch 3: train loss 0.11364848911762238, accuracy 0.1
[36m(ClientAppActor pid=15778)[0m [Client 400] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 513] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 577] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 655] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 682] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 981] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 376] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 372] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 416] evaluate, co

[92mINFO [0m:      aggregate_evaluate: received 50 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 25 clients (out of 1000)


[36m(ClientAppActor pid=15778)[0m [Client 150] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 78, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=15778)[0m Epoch 1: train loss 0.11513068526983261, accuracy 0.1
[36m(ClientAppActor pid=15778)[0m Epoch 2: train loss 0.11291127651929855, accuracy 0.125
[36m(ClientAppActor pid=15778)[0m Epoch 3: train loss 0.11247352510690689, accuracy 0.2
[36m(ClientAppActor pid=15778)[0m [Client 171, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=15778)[0m Epoch 1: train loss 0.11457901448011398, accuracy 0.2
[36m(ClientAppActor pid=15778)[0m Epoch 2: train loss 0.1128481850028038, accuracy 0.2
[36m(ClientAppActor pid=15778)[0m Epoch 3: train loss 0.11109552532434464, accuracy 0.2
[36m(ClientAppActor pid=15778)[0m [Client 575, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=15778)[0m Epoch 1: train loss 0.11

[92mINFO [0m:      aggregate_fit: received 25 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 50 clients (out of 1000)


[36m(ClientAppActor pid=15778)[0m [Client 958, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=15778)[0m Epoch 1: train loss 0.11479492485523224, accuracy 0.025
[36m(ClientAppActor pid=15778)[0m Epoch 2: train loss 0.11293592303991318, accuracy 0.25
[36m(ClientAppActor pid=15778)[0m Epoch 3: train loss 0.11087167263031006, accuracy 0.225
[36m(ClientAppActor pid=15778)[0m [Client 104] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 374] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 460] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 712] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 945] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 86] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 947] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 298] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 77] evaluate

[92mINFO [0m:      aggregate_evaluate: received 50 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 25 clients (out of 1000)


[36m(ClientAppActor pid=15778)[0m [Client 610] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 84, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=15778)[0m Epoch 1: train loss 0.11452563107013702, accuracy 0.075
[36m(ClientAppActor pid=15778)[0m Epoch 2: train loss 0.11293976753950119, accuracy 0.275
[36m(ClientAppActor pid=15778)[0m Epoch 3: train loss 0.1087932363152504, accuracy 0.25
[36m(ClientAppActor pid=15778)[0m [Client 172, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=15778)[0m Epoch 1: train loss 0.11506728082895279, accuracy 0.15
[36m(ClientAppActor pid=15778)[0m Epoch 2: train loss 0.11138360947370529, accuracy 0.275
[36m(ClientAppActor pid=15778)[0m Epoch 3: train loss 0.10979592800140381, accuracy 0.275
[36m(ClientAppActor pid=15778)[0m [Client 476, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=15778)[0m Epoch 1: train l

[92mINFO [0m:      aggregate_fit: received 25 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 50 clients (out of 1000)


[36m(ClientAppActor pid=15778)[0m [Client 37, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=15778)[0m Epoch 1: train loss 0.11549980938434601, accuracy 0.05
[36m(ClientAppActor pid=15778)[0m Epoch 2: train loss 0.11224949359893799, accuracy 0.2
[36m(ClientAppActor pid=15778)[0m Epoch 3: train loss 0.11173105239868164, accuracy 0.2
[36m(ClientAppActor pid=15778)[0m [Client 47] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 133] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 327] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 544] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 788] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 967] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 614] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 513] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 498] evaluate, co

[92mINFO [0m:      aggregate_fit: received 0 results and 3 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=15778)[0m [Client 599] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 355] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 29] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 919] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 374] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 179] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 40] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 669] evaluate, config: {}
[36m(ClientAppActor pid=15778)[0m [Client 433] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 50 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 1041.42s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.22989180183410649
[92mINFO [0m:      		round 2: 0.22868901252746576
[92mINFO [0m:      		round 3: 0.22595208024978639
[92mINFO [0m:      


[36m(ClientAppActor pid=15778)[0m [Client 304] evaluate, config: {}
