# Lesson 3: Tuning

Welcome to Lesson 3!

To access the `requirements.txt` and `utils3.py` files for this course, go to `File` and click `Open`.

#### 1. Load imports

In [1]:
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

from utils3 import *

#### 2. Prepare the datasets

* Prepare data using Flower Datasets.

Use `flwr-datasets`, which provides a Federated Dataset abstraction.

In [2]:
def load_data(partition_id, num_partitions=5, batch_size=64, test_size=0.2, seed=42):
    """Build the train and test loaders for one partition of MNIST.

    Args:
        partition_id: Index of the partition to load.
        num_partitions: Number of partitions the "train" split is divided
            into (5 by default).
        batch_size: Batch size used by both loaders (64 by default).
        test_size: Value forwarded to ``train_test_split`` as ``test_size``
            (0.2 by default).
        seed: Seed forwarded to ``train_test_split`` so the local split is
            reproducible (42 by default).

    Returns:
        A ``(trainloader, testloader)`` tuple. Only the train loader shuffles.
    """
    fds = FederatedDataset(
        dataset="mnist",
        partitioners={"train": num_partitions},
    )
    partition = fds.load_partition(partition_id)

    # Split the partition locally, then attach the transform to both halves.
    # `normalize` is not defined in this cell -- presumably provided by utils3.
    traintest = partition.train_test_split(test_size=test_size, seed=seed)
    traintest = traintest.with_transform(normalize)
    trainset, testset = traintest["train"], traintest["test"]

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(testset, batch_size=batch_size)
    return trainloader, testloader

#### 3. Clients configuration

* Define fit_config.

Flower can send configuration values to clients.

In [3]:
def fit_config(server_round: int):
    """Return the fit configuration sent to clients for ``server_round``.

    Rounds before the third use 2 local epochs; round 3 and later use 5.
    """
    if server_round < 3:
        local_epochs = 2
    else:
        local_epochs = 5
    return {"local_epochs": local_epochs}

* The FedAvg strategy in the Server Function.

In [4]:
# Initial global parameters: the weights of a freshly constructed model,
# converted to the Parameters form handed to the strategy.
net = SimpleModel()
params = ndarrays_to_parameters(get_weights(net))

def server_fn(context: Context):
    """Build the server-side components: a FedAvg strategy and a run config.

    Args:
        context: Context object supplied by the caller; not read here.

    Returns:
        ServerAppComponents holding the strategy and a 3-round ServerConfig.
    """
    strategy = FedAvg(
        min_fit_clients=5,
        # FedAvg's min_available_clients defaults to 2. Keep it in line with
        # min_fit_clients so a round is only configured once all 5 clients
        # are connected; otherwise sampling 5 clients can come up short.
        min_available_clients=5,
        fraction_evaluate=0.0,  # no clients are selected for evaluation
        initial_parameters=params,
        on_fit_config_fn=fit_config,  # <- NEW
    )
    config = ServerConfig(num_rounds=3)
    return ServerAppComponents(
        strategy=strategy,
        config=config,
    )

* Define an instance of ServerApp.

In [5]:
# Wrap server_fn in a ServerApp; this object is passed to run_simulation as server_app.
server = ServerApp(server_fn=server_fn)

* Define FlowerClient.

The client side receives the configuration dictionary in the `fit` method.

In [6]:
class FlowerClient(NumPyClient):
    """Client that trains and evaluates a local model on its own loaders.

    The number of local epochs is read from the config dictionary passed to
    ``fit`` under the ``"local_epochs"`` key.
    """

    def __init__(self, net, trainloader, testloader):
        """Store the model and the train/test loaders used by this client."""
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader

    def fit(self, parameters, config):
        """Load ``parameters``, train for ``config["local_epochs"]`` epochs.

        Returns:
            ``(weights, num_examples, metrics)`` where ``num_examples`` is the
            number of training samples and ``metrics`` is empty.
        """
        set_weights(self.net, parameters)

        epochs = config["local_epochs"]
        log(INFO, f"client trains for {epochs} epochs")
        train_model(self.net, self.trainloader, epochs)

        # len(loader) is the number of batches; the strategy weights each
        # client's update by its number of examples, so report the dataset
        # size. Assumes the loader exposes `.dataset` (torch DataLoader does).
        return get_weights(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and evaluate on the local test loader.

        Returns:
            ``(loss, num_examples, metrics)`` with accuracy in ``metrics``.
        """
        set_weights(self.net, parameters)
        loss, accuracy = evaluate_model(self.net, self.testloader)
        # Report the sample count, not the batch count (see fit).
        return loss, len(self.testloader.dataset), {"accuracy": accuracy}

* Create the Client Function and the Client App.

In [7]:
def client_fn(context: Context) -> Client:
    """Create the client for the partition named in ``context.node_config``.

    Reads ``"partition-id"`` from the node config, loads that partition's
    loaders, and returns a FlowerClient converted with ``to_client()``.
    """
    model = SimpleModel()
    pid = int(context.node_config["partition-id"])
    train_dl, test_dl = load_data(partition_id=pid)
    flower_client = FlowerClient(model, train_dl, test_dl)
    return flower_client.to_client()


# Wrap client_fn in a ClientApp; this object is passed to run_simulation as client_app.
client = ClientApp(client_fn)

* Run Client and Server apps.

In [8]:
# Run the server and client apps together in a simulation with 5 supernodes.
# `backend_setup` is not defined in this notebook -- presumably provided by utils3.
run_simulation(server_app=server,
               client_app=client,
               num_supernodes=5,
               backend_config=backend_setup
               )

[92mINFO [0m: Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m: 
[92mINFO [0m: [INIT]
[92mINFO [0m: Using initial global parameters provided by strategy
[92mINFO [0m: Evaluating initial global parameters
[92mINFO [0m: 
[92mINFO [0m: [ROUND 1]
[92mINFO [0m: configure_fit: strategy sampled 5 clients (out of 5)
Downloading builder script: 100%|██████████| 3.98k/3.98k [00:00<00:00, 24.2MB/s]
Downloading readme: 100%|██████████| 6.83k/6.83k [00:00<00:00, 37.0MB/s]
Downloading data:   0%|          | 0.00/9.91M [00:00<?, ?B/s]
Downloading data: 100%|██████████| 9.91M/9.91M [00:00<00:00, 70.6MB/s]
Downloading data: 100%|██████████| 28.9k/28.9k [00:00<00:00, 10.2MB/s]
Downloading data:   0%|          | 0.00/1.65M [00:00<?, ?B/s]
Downloading data: 100%|██████████| 1.65M/1.65M [00:00<00:00, 58.7MB/s]
Downloading data: 100%|██████████| 4.54k/4.54k [00:00<00:00, 22.1MB/s]
Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]
Generatin

[2m[36m(ClientAppActor pid=520)[0m [92mINFO [0m: client trains for 2 epochs[32m [repeated 3x across cluster][0m
[92mINFO [0m: aggregate_fit: received 5 results and 0 failures
[92mINFO [0m: configure_evaluate: no clients selected, skipping evaluation
[92mINFO [0m: 
[92mINFO [0m: [ROUND 3]
[92mINFO [0m: configure_fit: strategy sampled 5 clients (out of 5)
[2m[36m(ClientAppActor pid=522)[0m [92mINFO [0m: client trains for 5 epochs[32m [repeated 2x across cluster][0m
[2m[36m(ClientAppActor pid=522)[0m [92mINFO [0m: client trains for 5 epochs[32m [repeated 3x across cluster][0m
[92mINFO [0m: aggregate_fit: received 5 results and 0 failures
[92mINFO [0m: configure_evaluate: no clients selected, skipping evaluation
[92mINFO [0m: 
[92mINFO [0m: [SUMMARY]
[92mINFO [0m: Run finished 3 round(s) in 67.44s
[92mINFO [0m: 
[2m[36m(ClientAppActor pid=521)[0m [92mINFO [0m: client trains for 5 epochs
