# Flower + Pytorch + FedPredict

This tutorial implements a simple example of "Flower + Pytorch + FedPredict".
For comparison, it is possible to run either the "FedAvg+FP" strategy (i.e., FedAvg using FedPredict) or plain "FedAvg" (i.e., the original solution).

To run in a Jupyter or [![Google Colab environment](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/claudiocapanema/fedpredict/examples/FedPredict-in-20-minutes/tutorial.ipynb), this example uses the method `run_simulation` from:
```Python
from flwr.simulation import run_simulation
```

However, launching experiments with the `flwr run` CLI command is the more professional approach. A similar tutorial using this command can be found [here]().

### Experiment config

In [21]:
# Experiment hyper-parameters shared by the client and server cells below.
BATCH_SIZE = 32  # batch size passed to load_data() for the train/test loaders
ALPHA = 0.1 # forwarded to DirichletPartitioner as `alpha`; suggested values: [0.1, 1.0]
STRATEGY = "FedAvg+FP" # FedAvg+FP or FedAvg (the two values client_fn accepts)
LOCAL_EPOCHS = 1  # epochs run by train() on each fit call
LEARNING_RATE = 0.1  # SGD learning rate used by train()
NUM_ROUNDS = 10  # number of server rounds (ServerConfig.num_rounds)
NUM_PARTITIONS = 10 # number of clients (also used as num_supernodes in run_simulation)

### Model definition and utils

In [22]:
"""pytorch_fedpredict_example: A Flower / PyTorch app."""

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize, ToTensor
import logging
logging.basicConfig(level=logging.INFO)  # Configure logging
logger = logging.getLogger(__name__)  # Create logger for the module


class Net(torch.nn.Module):
    """Small CNN adapted from 'PyTorch: A 60 Minute Blitz'.

    Two conv + max-pool stages feed three fully connected layers that emit
    10 logits. The flatten step assumes a 16x5x5 feature map, which is what
    a 3x32x32 input produces.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and registration order define the state_dict
        # layout that get_weights()/set_weights() rely on.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        # Convolutional feature extractor: conv -> ReLU -> pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten to (batch, 400) before the dense head.
        flat = x.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


def get_weights(net):
    """Return each ``net.state_dict()`` entry as a NumPy array, in state_dict order."""
    arrays = []
    for tensor in net.state_dict().values():
        # Move to CPU first so tensors living on an accelerator can be converted.
        arrays.append(tensor.cpu().numpy())
    return arrays


def set_weights(net, parameters):
    """Load ``parameters`` into ``net``.

    The arrays are paired with the model's state_dict keys by position, so
    ``parameters`` must follow the same order that ``get_weights`` produces.
    """
    new_state = OrderedDict()
    for key, array in zip(net.state_dict().keys(), parameters):
        new_state[key] = torch.tensor(array)
    # strict=True: fail if any key of the model is missing from new_state.
    net.load_state_dict(new_state, strict=True)


fds = None  # Module-level cache: the FederatedDataset is built on first use only


def load_data(partition_id: int, num_partitions: int, alpha: float, batch_size: int):
    """Return ``(trainloader, testloader)`` for one CIFAR10 partition.

    The first call builds a ``FederatedDataset`` whose "train" split is
    partitioned by label with a ``DirichletPartitioner``; later calls reuse
    the cached object, so ``num_partitions`` and ``alpha`` only take effect
    on that first call. The selected partition is split 80% / 20% into
    train / test with a fixed seed.
    """
    global fds
    if fds is None:
        fds = FederatedDataset(
            dataset="uoft-cs/cifar10",
            partitioners={
                "train": DirichletPartitioner(
                    num_partitions=num_partitions,
                    partition_by="label",
                    alpha=alpha,
                    min_partition_size=10,
                )
            },
        )

    # ToTensor followed by per-channel normalisation with mean 0.5 and std 0.5.
    to_normalized_tensor = Compose(
        [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    def apply_transforms(batch):
        """Convert every image of the batch with the transform pipeline."""
        batch["img"] = [to_normalized_tensor(image) for image in batch["img"]]
        return batch

    # Divide the data of this node: 80% train, 20% test.
    splits = fds.load_partition(partition_id).train_test_split(test_size=0.2, seed=42)
    splits = splits.with_transform(apply_transforms)

    trainloader = DataLoader(splits["train"], batch_size=batch_size, shuffle=True)
    testloader = DataLoader(splits["test"], batch_size=batch_size)
    return trainloader, testloader


def train(net, trainloader, valloader, epochs, learning_rate, device):
    """Train ``net`` on ``trainloader`` and report validation metrics.

    Runs ``epochs`` passes of SGD (momentum 0.9) with cross-entropy loss,
    then evaluates on ``valloader`` via ``test``.

    Returns:
        dict with keys ``"val_loss"`` and ``"val_accuracy"``.
    """
    net.to(device)  # move model to GPU if available
    loss_fn = torch.nn.CrossEntropyLoss().to(device)
    sgd = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

    net.train()
    for _epoch in range(epochs):
        for batch in trainloader:
            inputs, targets = batch["img"], batch["label"]
            sgd.zero_grad()
            batch_loss = loss_fn(net(inputs.to(device)), targets.to(device))
            batch_loss.backward()
            sgd.step()

    val_loss, val_acc = test(net, valloader, device)
    return {"val_loss": val_loss, "val_accuracy": val_acc}


def test(net, testloader, device):
    """Validate the model on the test set."""
    net.to(device)  # move model to GPU if available
    criterion = torch.nn.CrossEntropyLoss()
    correct, loss = 0, 0.0
    with torch.no_grad():
        for batch in testloader:
            images = batch["img"].to(device)
            labels = batch["label"].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
    accuracy = correct / len(testloader.dataset)
    loss = loss / len(testloader)
    return loss, accuracy

### FedAvg client

In [23]:
"""pytorch_fedpredict_example: A Flower / PyTorch / FedPredict app."""

import copy
import sys
import torch
from flwr.client import NumPyClient

import logging
logging.basicConfig(level=logging.INFO)  # Configure logging
logger = logging.getLogger(__name__)  # Create logger for the module

# Define FedAvg Client
class FedAvgClient(NumPyClient):
    """Flower ``NumPyClient`` that trains and evaluates a local ``Net`` (plain FedAvg).

    ``num_server_rounds``, ``client_id`` and ``client_state`` are accepted so
    that the constructor can be called with the same arguments as
    ``FedAvgClientFP``; this class does not use them.

    Every method logs a failure and then re-raises it. Swallowing the
    exception would make ``fit``/``evaluate`` return ``None`` instead of the
    tuple they are expected to return, hiding the original error.
    """

    def __init__(self, trainloader, valloader, local_epochs, learning_rate, num_server_rounds, client_id, client_state):
        try:
            self.local_model = Net()
            self.trainloader = trainloader
            self.valloader = valloader
            self.local_epochs = local_epochs
            self.lr = learning_rate
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        except Exception as e:
            logger.error("__init__ error")
            logger.error("""Error on line {} {} {}""".format(sys.exc_info()[-1].tb_lineno, type(e).__name__, e))
            # Do not hand back a half-initialised client.
            raise

    def fit(self, parameters, config):
        """Train the model with data of this client.

        Returns:
            Tuple ``(weights, num_train_examples, metrics)`` where ``metrics``
            is the dict produced by ``train``.
        """
        try:
            set_weights(self.local_model, parameters)
            results = train(
                self.local_model,
                self.trainloader,
                self.valloader,
                self.local_epochs,
                self.lr,
                self.device,
            )
            return get_weights(self.local_model), len(self.trainloader.dataset), results
        except Exception as e:
            logger.error("fit error")
            logger.error("""Error on line {} {} {}""".format(sys.exc_info()[-1].tb_lineno, type(e).__name__, e))
            raise

    def evaluate(self, parameters, config):
        """Evaluate the model on the data this client has.

        Returns:
            Tuple ``(loss, num_val_examples, {"accuracy": accuracy})``.
        """
        try:
            set_weights(self.local_model, parameters)
            loss, accuracy = test(self.local_model, self.valloader, self.device)
            return loss, len(self.valloader.dataset), {"accuracy": accuracy}
        except Exception as e:
            logger.error("evaluate error")
            logger.error("""Error on line {} {} {}""".format(sys.exc_info()[-1].tb_lineno, type(e).__name__, e))
            raise

### FedAvg+FP client

This class implements the required modifications to include the plugin.

In [24]:
"""pytorch_fedpredict_example: A Flower / PyTorch / FedPredict app."""

import copy

from fedpredict import fedpredict_client_torch
from flwr.common import ArrayRecord, ConfigRecord

# Define FedAvg+FP Client
class FedAvgClientFP(FedAvgClient):
    """FedAvg client extended with the FedPredict plugin.

    Compared with ``FedAvgClient`` it additionally:

    * keeps a separate ``global_model`` that receives the server parameters
      at evaluation time;
    * tracks ``lt``, the last round in which this client trained;
    * stores the locally trained model and ``lt`` in ``client_state`` after
      training and restores them before evaluation;
    * evaluates the model returned by ``fedpredict_client_torch`` instead of
      the plain global model.

    ``__init__``, ``fit`` and ``evaluate`` log a failure and re-raise it, so
    the caller never receives ``None`` in place of a result.
    """

    def __init__(self, trainloader, valloader, local_epochs, learning_rate, num_server_rounds, client_id, client_state):
        try:
            super().__init__(trainloader, valloader, local_epochs, learning_rate, num_server_rounds, client_id, client_state)
            self.global_model = copy.deepcopy(self.local_model)
            self.lt = 0  # last round the client trained
            self.num_server_rounds = num_server_rounds
            self.client_state = client_state
        except Exception as e:
            logger.error("__init__ error")
            logger.error("""Error on line {} {} {}""".format(sys.exc_info()[-1].tb_lineno, type(e).__name__, e))
            raise

    def fit(self, parameters, config):
        """Train the model with data of this client and persist the result.

        ``config`` must contain ``"server_round"``. ``lt`` and the saved state
        are only updated once training has produced a result.
        """
        try:
            t = config["server_round"]
            results = super().fit(parameters, config)
            if results is None:
                # Defensive: a parent that reports failure by returning None
                # must not be recorded as a completed training round.
                raise RuntimeError("local training returned no result")
            self.lt = t
            self._save_layer_weights_to_state()
            return results
        except Exception as e:
            logger.error("fit error")
            logger.error("""Error on line {} {} {}""".format(sys.exc_info()[-1].tb_lineno, type(e).__name__, e))
            raise

    def evaluate(self, parameters, config):
        """Evaluate the FedPredict combined model on the data this client has.

        ``config`` must contain ``"server_round"``.

        Returns:
            Tuple ``(loss, num_val_examples, {"accuracy": accuracy})``.
        """
        try:
            set_weights(self.global_model, parameters)
            t = config["server_round"]
            # Restore the locally trained model and `lt` saved by a previous fit.
            self._load_layer_weights_from_state()
            # Number of consecutive rounds the client has not been selected for training (nt).
            nt = t - self.lt
            # Get the "combined_model" from FedPredict
            combined_model = fedpredict_client_torch(local_model=self.local_model, global_model=self.global_model,
                                                     t=t, T=self.num_server_rounds, nt=nt, device=self.device)
            # Test the "combined_model"
            loss, accuracy = test(combined_model, self.valloader, self.device)
            return loss, len(self.valloader.dataset), {"accuracy": accuracy}
        except Exception as e:
            logger.error("evaluate error")
            logger.error("""Error on line {} {} {}""".format(sys.exc_info()[-1].tb_lineno, type(e).__name__, e))
            raise

    def _save_layer_weights_to_state(self):
        """Save the local model's full state_dict and ``lt`` to the client state.

        Failures are logged and not propagated, so a problem while persisting
        state does not discard the training result returned by ``fit``.
        """
        try:
            arr_record = ArrayRecord(torch_state_dict=self.local_model.state_dict())

            # Add to RecordDict (replace if already exists)
            self.client_state["model"] = arr_record
            self.client_state["lt"] = ConfigRecord(config_dict={"lt": self.lt})
        except Exception as e:
            logger.error("_save_layer_weights_to_state error")
            logger.error("""Error on line {} {} {}""".format(sys.exc_info()[-1].tb_lineno, type(e).__name__, e))

    def _load_layer_weights_from_state(self):
        """Restore the local model and ``lt`` from the client state, if present.

        Does nothing when no "model" record has been saved yet, in which case
        ``local_model`` and ``lt`` keep their current values.
        """
        if "model" not in self.client_state.array_records:
            return

        state_dict = self.client_state["model"].to_torch_state_dict()
        self.lt = self.client_state["lt"]["lt"]

        # Apply the model previously saved by this client.
        self.local_model.load_state_dict(state_dict, strict=True)

### Client App

Here we define the client app:

In [25]:
"""pytorch_fedpredict_example: A Flower / PyTorch / FedPredict app."""

import sys
from flwr.client import ClientApp
from flwr.common import ArrayRecord, Context

import logging
logging.basicConfig(level=logging.INFO)  # Configure logging
logger = logging.getLogger(__name__)  # Create logger for the module

def client_fn(context: Context):
    """Construct a Client that will be run in a ClientApp.

    Reads the ``"partition-id"`` entry of ``context.node_config``, loads that
    data partition, and builds the client class selected by ``STRATEGY``.

    Raises:
        ValueError: if ``STRATEGY`` is neither "FedAvg+FP" nor "FedAvg".
        Exception: any other failure is logged and re-raised, so the caller
            never receives ``None`` in place of a client.
    """
    try:
        # Read the node_config to fetch data partition associated to this node
        partition_id = context.node_config["partition-id"]
        # The module-level constant is NUM_PARTITIONS; no lowercase
        # `num_partitions` name exists in this scope.
        trainloader, valloader = load_data(partition_id, NUM_PARTITIONS, ALPHA, BATCH_SIZE)
        client_state = context.state

        # Return Client instance
        if STRATEGY == "FedAvg+FP":
            return FedAvgClientFP(trainloader, valloader, LOCAL_EPOCHS, LEARNING_RATE, NUM_ROUNDS, partition_id,
                                  client_state).to_client()
        elif STRATEGY == "FedAvg":
            return FedAvgClient(trainloader, valloader, LOCAL_EPOCHS, LEARNING_RATE, NUM_ROUNDS, partition_id,
                                client_state).to_client()
        else:
            raise ValueError("Unknown strategy")
    except Exception as e:
        logger.error("client_fn error")
        logger.error("""Error on line {} {} {}""".format(sys.exc_info()[-1].tb_lineno, type(e).__name__, e))
        raise

# Flower ClientApp
try:
    client_app = ClientApp(client_fn)
except Exception as e:
    logger.error("app error")
    logger.error("""Error on line {} {} {}""".format(sys.exc_info()[-1].tb_lineno, type(e).__name__, e))
    # Re-raise: otherwise `client_app` stays unbound and the failure only
    # shows up later as a NameError when the simulation is launched.
    raise


### Server App

Here we define the server app:

In [26]:
"""pytorch_fedpredict_example: A Flower / PyTorch / FedPredict app.

This basic version of FedPredict requires a small modification on the
server side: the server must communicate the current training round
number to the selected clients during each training cycle.
"""

from typing import List, Tuple

from flwr.common import Context, Metrics, Scalar, ndarrays_to_parameters
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.server.strategy import FedAvg
import logging
logging.basicConfig(level=logging.INFO)  # Configure logging
logger = logging.getLogger(__name__)  # Create logger for the module


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate per-client ``"accuracy"`` values, weighted by example count.

    Args:
        metrics: ``(num_examples, client_metrics)`` pairs; every
            ``client_metrics`` must contain an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``, or an empty dict when there is
        nothing to average (no pairs, or a total of zero examples), which
        avoids a division by zero.
    """
    total_examples = 0
    weighted_sum = 0.0
    for num_examples, client_metrics in metrics:
        # Multiply accuracy of each client by number of examples used
        weighted_sum += num_examples * client_metrics["accuracy"]
        total_examples += num_examples

    if total_examples == 0:
        return {}

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": weighted_sum / total_examples}

def on_fit_config_fn(server_round: int) -> dict[str, Scalar]:
    """Build the fit config: the current round number under ``"server_round"``."""
    return {"server_round": server_round}

def on_evaluate_config_fn(server_round: int) -> dict[str, Scalar]:
    """Build the evaluate config: the current round number under ``"server_round"``."""
    return {"server_round": server_round}

def server_fn(context: Context):
    """Construct components that set the ServerApp behaviour.

    Builds a ``FedAvg`` strategy initialised with the weights of a fresh
    ``Net`` and a ``ServerConfig`` running ``NUM_ROUNDS`` rounds. The
    strategy sends the round number to clients through the fit and evaluate
    config functions.

    Raises:
        Exception: any failure is logged and re-raised, so the caller never
            receives ``None`` in place of the components.
    """
    # Imported locally so the error handler works even if no earlier cell
    # brought `sys` into scope.
    import sys

    try:
        # Initialize model parameters
        ndarrays = get_weights(Net())
        parameters = ndarrays_to_parameters(ndarrays)

        # Define the strategy
        strategy = FedAvg(
            fraction_fit=0.3,
            fraction_evaluate=1,
            min_available_clients=2,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            evaluate_metrics_aggregation_fn=weighted_average,
            initial_parameters=parameters,
        )
        config = ServerConfig(num_rounds=NUM_ROUNDS)

        return ServerAppComponents(strategy=strategy, config=config)
    except Exception as e:
        logger.error("server_fn error")
        logger.error("""Error on line {} {} {}""".format(sys.exc_info()[-1].tb_lineno, type(e).__name__, e))
        raise

# Create the ServerApp, wiring in server_fn (which builds the strategy and ServerConfig)
server_app = ServerApp(server_fn=server_fn)


### Launching the Simulation

With both `ClientApp` and `ServerApp` ready, we can launch the simulation. Pass both apps to the `run_simulation()` function and specify the number of `supernodes` (this is a more general term used in Flower to refer to individual "nodes" or "clients"). The dataset is partitioned into `NUM_PARTITIONS` partitions (10 in this configuration), one for each supernode, so we set `num_supernodes=NUM_PARTITIONS`.

In [27]:
from flwr.simulation import run_simulation

# Launch the simulation with one supernode per data partition (NUM_PARTITIONS clients).
run_simulation(
    server_app=server_app, client_app=client_app, num_supernodes=NUM_PARTITIONS
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=10, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)
[36m(ClientAppActor pid=26420)[0m   obj.co_lnotab,  # for < python 3.10 [not counted in args]
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 10 clients (out of 10)
[36m(ClientAppActor pid=26419)[0m   obj.co_lnotab,  # for < python 3.10 [not counted in args][32m [repeated 3x across cluster][0m
[36m(ClientAppActor pid=26419)[0m   obj.co_lnotab,  # for < python 3.10 [not counted in args][32m [repeated 8x across cluster][0m
[92mINFO [0m