# Flower tutorial
## Building a strategy

A custom strategy enables granular control over client node configuration, result aggregation, and more. To define a custom strategy, you only have to overwrite the abstract methods of the (abstract) base class `Strategy`. To make custom strategies even more powerful, you can pass custom functions to the constructor of your new class `(__init__)` and then call these functions whenever needed.

[tutorial link](https://flower.dev/docs/tutorial/Flower-3-Building-a-Strategy-PyTorch.html)

In [1]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10

import flwr as fl

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(
    f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)

Training on cpu using PyTorch 2.0.0 and Flower 1.5.0.dev20230427


# Data loading

In [2]:
NUM_CLIENTS = 10


def load_datasets(num_clients: int):
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader


trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)

Files already downloaded and verified
Files already downloaded and verified


# Model training/evaluation
Let’s continue with the usual model definition (including `set_parameters` and `get_parameters`), training and test functions:

In [3]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

# Flower client
To implement the Flower client, we (again) create a subclass of `flwr.client.NumPyClient` and implement the three methods `get_parameters`, `fit`, and `evaluate`. Here, we also pass the `cid` to the client and use it log additional details:

In [4]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

Let’s test what we have so far before we continue:

In [5]:
# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if DEVICE.type == "cuda":
    client_resources = {"num_gpus": 1}

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=3),
    client_resources=client_resources,
)

INFO flwr 2023-05-23 13:21:26,024 | app.py:146 | Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)
2023-05-23 13:21:27,757	INFO worker.py:1625 -- Started a local Ray instance.
INFO flwr 2023-05-23 13:21:28,339 | app.py:180 | Flower VCE: Ray initialized with resources: {'node:127.0.0.1': 1.0, 'CPU': 8.0, 'object_store_memory': 2147483648.0, 'memory': 8582786253.0}
INFO flwr 2023-05-23 13:21:28,339 | server.py:86 | Initializing global parameters
INFO flwr 2023-05-23 13:21:28,339 | server.py:273 | Requesting initial parameters from one random client
INFO flwr 2023-05-23 13:21:29,501 | server.py:277 | Received initial parameters from one random client
INFO flwr 2023-05-23 13:21:29,502 | server.py:88 | Evaluating initial parameters
INFO flwr 2023-05-23 13:21:29,502 | server.py:101 | FL starting
DEBUG flwr 2023-05-23 13:21:29,502 | server.py:218 | fit_round 1: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_get_parameters pid=7878)[0m [Client 0] get_parameters
[2m[36m(launch_and_fit pid=7878)[0m [Client 1] fit, config: {}
[2m[36m(launch_and_fit pid=7878)[0m Epoch 1: train loss 0.06487391144037247, accuracy 0.21911111111111112


DEBUG flwr 2023-05-23 13:21:32,476 | server.py:232 | fit_round 1 received 2 results and 0 failures
DEBUG flwr 2023-05-23 13:21:32,479 | server.py:168 | evaluate_round 1: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-05-23 13:21:33,594 | server.py:182 | evaluate_round 1 received 2 results and 0 failures
DEBUG flwr 2023-05-23 13:21:33,595 | server.py:218 | fit_round 2: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_evaluate pid=7878)[0m [Client 1] evaluate, config: {}


DEBUG flwr 2023-05-23 13:21:35,914 | server.py:232 | fit_round 2 received 2 results and 0 failures
DEBUG flwr 2023-05-23 13:21:35,917 | server.py:168 | evaluate_round 2: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_fit pid=7880)[0m [Client 0] fit, config: {}[32m [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)[0m


DEBUG flwr 2023-05-23 13:21:37,050 | server.py:182 | evaluate_round 2 received 2 results and 0 failures
DEBUG flwr 2023-05-23 13:21:37,051 | server.py:218 | fit_round 3: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_fit pid=7880)[0m Epoch 1: train loss 0.056883227080106735, accuracy 0.3417777777777778[32m [repeated 3x across cluster][0m


DEBUG flwr 2023-05-23 13:21:39,404 | server.py:232 | fit_round 3 received 2 results and 0 failures
DEBUG flwr 2023-05-23 13:21:39,406 | server.py:168 | evaluate_round 3: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_evaluate pid=7880)[0m [Client 0] evaluate, config: {}[32m [repeated 3x across cluster][0m


DEBUG flwr 2023-05-23 13:21:40,504 | server.py:182 | evaluate_round 3 received 2 results and 0 failures
INFO flwr 2023-05-23 13:21:40,504 | server.py:147 | FL finished in 11.002069499983918
INFO flwr 2023-05-23 13:21:40,505 | app.py:218 | app_fit: losses_distributed [(1, 0.06242328441143036), (2, 0.05558965194225311), (3, 0.05388106870651245)]
INFO flwr 2023-05-23 13:21:40,505 | app.py:219 | app_fit: metrics_distributed_fit {}
INFO flwr 2023-05-23 13:21:40,505 | app.py:220 | app_fit: metrics_distributed {}
INFO flwr 2023-05-23 13:21:40,506 | app.py:221 | app_fit: losses_centralized []
INFO flwr 2023-05-23 13:21:40,506 | app.py:222 | app_fit: metrics_centralized {}


History (loss, distributed):
	round 1: 0.06242328441143036
	round 2: 0.05558965194225311
	round 3: 0.05388106870651245

# Build a Strategy from scratch
Let’s overwrite the `configure_fit` method such that it passes a higher learning rate (potentially also other hyperparameters) to the optimizer of a fraction of the clients. We will keep the sampling of the clients as it is in `FedAvg` and then change the configuration dictionary (one of the `FitIns` attributes).

In [6]:
from typing import Callable, Union

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg


class FedCustom(fl.server.strategy.Strategy):
    def __init__(
        self,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
    ) -> None:
        super().__init__()
        self.fraction_fit = fraction_fit
        self.fraction_evaluate = fraction_evaluate
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients

    def __repr__(self) -> str:
        return "FedCustom"

    def initialize_parameters(
        self, client_manager: ClientManager
    ) -> Optional[Parameters]:
        """Initialize global model parameters."""
        net = Net()
        ndarrays = get_parameters(net)
        return fl.common.ndarrays_to_parameters(ndarrays)

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training."""

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Create custom configs
        n_clients = len(clients)
        half_clients = n_clients // 2
        standard_config = {"lr": 0.001}
        higher_lr_config = {"lr": 0.003}
        fit_configurations = []
        for idx, client in enumerate(clients):
            if idx < half_clients:
                fit_configurations.append((client, FitIns(parameters, standard_config)))
            else:
                fit_configurations.append(
                    (client, FitIns(parameters, higher_lr_config))
                )
        return fit_configurations

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using weighted average."""

        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]
        parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))
        metrics_aggregated = {}
        return parameters_aggregated, metrics_aggregated

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of evaluation."""
        if self.fraction_evaluate == 0.0:
            return []
        config = {}
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using weighted average."""

        if not results:
            return None, {}

        loss_aggregated = weighted_loss_avg(
            [
                (evaluate_res.num_examples, evaluate_res.loss)
                for _, evaluate_res in results
            ]
        )
        metrics_aggregated = {}
        return loss_aggregated, metrics_aggregated

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate global model parameters using an evaluation function."""

        # Let's assume we won't perform the global model evaluation on the server side.
        return None

    def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Return sample size and required number of clients."""
        num_clients = int(num_available_clients * self.fraction_fit)
        return max(num_clients, self.min_fit_clients), self.min_available_clients

    def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
        """Use a fraction of available clients for evaluation."""
        num_clients = int(num_available_clients * self.fraction_evaluate)
        return max(num_clients, self.min_evaluate_clients), self.min_available_clients

The only thing left is to use the newly created custom Strategy FedCustom when starting the experiment:

In [7]:
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=FedCustom(),  # <-- pass the new strategy here
    client_resources=client_resources,
)

INFO flwr 2023-05-23 13:21:40,526 | app.py:146 | Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)


[2m[36m(launch_and_fit pid=7880)[0m [Client 1] fit, config: {}[32m [repeated 2x across cluster][0m
[2m[36m(launch_and_fit pid=7880)[0m Epoch 1: train loss 0.052589356899261475, accuracy 0.3864444444444444[32m [repeated 2x across cluster][0m
[2m[36m(launch_and_evaluate pid=7880)[0m [Client 1] evaluate, config: {}[32m [repeated 2x across cluster][0m


2023-05-23 13:21:44,955	INFO worker.py:1625 -- Started a local Ray instance.
INFO flwr 2023-05-23 13:21:45,586 | app.py:180 | Flower VCE: Ray initialized with resources: {'memory': 8625946624.0, 'node:127.0.0.1': 1.0, 'CPU': 8.0, 'object_store_memory': 2147483648.0}
INFO flwr 2023-05-23 13:21:45,586 | server.py:86 | Initializing global parameters
INFO flwr 2023-05-23 13:21:45,588 | server.py:269 | Using initial parameters provided by strategy
INFO flwr 2023-05-23 13:21:45,588 | server.py:88 | Evaluating initial parameters
INFO flwr 2023-05-23 13:21:45,589 | server.py:101 | FL starting
DEBUG flwr 2023-05-23 13:21:45,589 | server.py:218 | fit_round 1: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_fit pid=7902)[0m [Client 0] fit, config: {'lr': 0.001}


DEBUG flwr 2023-05-23 13:21:48,614 | server.py:232 | fit_round 1 received 2 results and 0 failures
DEBUG flwr 2023-05-23 13:21:48,617 | server.py:168 | evaluate_round 1: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_fit pid=7902)[0m Epoch 1: train loss 0.06622076779603958, accuracy 0.22577777777777777


DEBUG flwr 2023-05-23 13:21:49,729 | server.py:182 | evaluate_round 1 received 2 results and 0 failures
DEBUG flwr 2023-05-23 13:21:49,729 | server.py:218 | fit_round 2: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_evaluate pid=7902)[0m [Client 1] evaluate, config: {}


DEBUG flwr 2023-05-23 13:21:52,068 | server.py:232 | fit_round 2 received 2 results and 0 failures
DEBUG flwr 2023-05-23 13:21:52,071 | server.py:168 | evaluate_round 2: strategy sampled 2 clients (out of 2)
DEBUG flwr 2023-05-23 13:21:53,175 | server.py:182 | evaluate_round 2 received 2 results and 0 failures
DEBUG flwr 2023-05-23 13:21:53,175 | server.py:218 | fit_round 3: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_fit pid=7900)[0m [Client 0] fit, config: {'lr': 0.001}[32m [repeated 3x across cluster][0m
[2m[36m(launch_and_fit pid=7900)[0m Epoch 1: train loss 0.058432046324014664, accuracy 0.31377777777777777[32m [repeated 3x across cluster][0m


DEBUG flwr 2023-05-23 13:21:55,497 | server.py:232 | fit_round 3 received 2 results and 0 failures
DEBUG flwr 2023-05-23 13:21:55,500 | server.py:168 | evaluate_round 3: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_evaluate pid=7900)[0m [Client 1] evaluate, config: {}[32m [repeated 3x across cluster][0m


DEBUG flwr 2023-05-23 13:21:56,601 | server.py:182 | evaluate_round 3 received 2 results and 0 failures
INFO flwr 2023-05-23 13:21:56,601 | server.py:147 | FL finished in 11.012399083003402
INFO flwr 2023-05-23 13:21:56,602 | app.py:218 | app_fit: losses_distributed [(1, 0.06430582904815674), (2, 0.05756994104385375), (3, 0.054374657511711125)]
INFO flwr 2023-05-23 13:21:56,602 | app.py:219 | app_fit: metrics_distributed_fit {}
INFO flwr 2023-05-23 13:21:56,602 | app.py:220 | app_fit: metrics_distributed {}
INFO flwr 2023-05-23 13:21:56,602 | app.py:221 | app_fit: losses_centralized []
INFO flwr 2023-05-23 13:21:56,602 | app.py:222 | app_fit: metrics_centralized {}


History (loss, distributed):
	round 1: 0.06430582904815674
	round 2: 0.05756994104385375
	round 3: 0.054374657511711125