# Carbon Aware Federated Learning using Flower and Carbon Aware SDK
This notebook follows the "Introduction to Federated Learning" tutorial from the official Flower [documentation](https://flower.dev/docs/tutorial/Flower-1-Intro-to-FL-PyTorch.html), using PyTorch and the CIFAR10 dataset.

The extra steps to make your federated learning carbon aware using the Carbon Aware SDK are highlighted along the way.

In [1]:
import random
from collections import OrderedDict
from typing import List

import flwr as fl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10

In [2]:
# Device that models and batches are moved to (CPU here).
DEVICE = torch.device("cpu")
# Size of the simulated client pool; also the number of training-set partitions.
NUM_CLIENTS = 20
# Mini-batch size intended for the DataLoaders (32 samples per batch).
BATCH_SIZE = 32

In [3]:
# Ten class names; presumably in CIFAR-10 label-index order — TODO confirm (not referenced by the cells visible here).
CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def load_datasets(num_clients: int, batch_size: int = 32):
    """Download CIFAR-10 and build per-client train/validation loaders.

    The training set is split into `num_clients` partitions with a fixed
    seed. Each partition is split again into 90 % train / 10 % validation.

    Args:
        num_clients: Number of partitions (simulated clients); must be >= 1.
        batch_size: Batch size used for every DataLoader (default 32).

    Returns:
        (trainloaders, valloaders, testloader): two lists with one DataLoader
        per client, plus a single DataLoader over the full test set.

    Raises:
        ValueError: If `num_clients` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # random_split requires the lengths to sum to len(trainset), so the
    # remainder of an uneven division is spread over the first partitions.
    base_size, remainder = divmod(len(trainset), num_clients)
    partition_lengths = [
        base_size + 1 if idx < remainder else base_size for idx in range(num_clients)
    ]
    partitions = random_split(
        trainset, partition_lengths, torch.Generator().manual_seed(42)
    )

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for partition in partitions:
        len_val = len(partition) // 10  # 10 % validation set
        len_train = len(partition) - len_val
        ds_train, ds_val = random_split(
            partition, [len_train, len_val], torch.Generator().manual_seed(42)
        )
        trainloaders.append(DataLoader(ds_train, batch_size=batch_size, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=batch_size))
    testloader = DataLoader(testset, batch_size=batch_size)
    return trainloaders, valloaders, testloader

trainloaders, valloaders, testloader = load_datasets(
    num_clients=NUM_CLIENTS, batch_size=BATCH_SIZE
)

Files already downloaded and verified
Files already downloaded and verified


## Implementation of the neural net
The network and its training/evaluation helpers for the CIFAR10 dataset follow the [Flower documentation](https://flower.dev/docs/tutorial/Flower-1-Intro-to-FL-PyTorch.html).

In [4]:
# NOTE(review): this re-defines load_datasets, which is already defined (and
# called) in the data-loading cell above; the later definition shadows the
# earlier one. Keep the two in sync or remove one of them.
def load_datasets(num_clients: int, batch_size: int = 32):
    """Download CIFAR-10 and build per-client train/validation loaders.

    Args:
        num_clients: Number of partitions (simulated clients); must be >= 1.
        batch_size: Batch size used for every DataLoader (default 32).

    Returns:
        (trainloaders, valloaders, testloader): two lists with one DataLoader
        per client, plus a single DataLoader over the full test set.

    Raises:
        ValueError: If `num_clients` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # The lengths handed to random_split must add up to len(trainset); give
    # the leftover samples of an uneven division to the first partitions.
    base_size, remainder = divmod(len(trainset), num_clients)
    partition_lengths = [
        base_size + 1 if idx < remainder else base_size for idx in range(num_clients)
    ]
    partitions = random_split(
        trainset, partition_lengths, torch.Generator().manual_seed(42)
    )

    # Split each partition into train/val (90 % / 10 %) and create DataLoaders
    trainloaders = []
    valloaders = []
    for partition in partitions:
        len_val = len(partition) // 10
        len_train = len(partition) - len_val
        ds_train, ds_val = random_split(
            partition, [len_train, len_val], torch.Generator().manual_seed(42)
        )
        trainloaders.append(DataLoader(ds_train, batch_size=batch_size, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=batch_size))
    testloader = DataLoader(testset, batch_size=batch_size)
    return trainloaders, valloaders, testloader


class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    The final layer emits 10 values per sample. The 16 * 5 * 5 flatten size
    matches 3-channel 32x32 inputs (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self) -> None:
        super().__init__()
        # Creation order is kept so that parameter initialisation and the
        # state_dict key order stay the same.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass and return the output of the last linear layer."""
        hidden = self.pool(F.relu(self.conv1(x)))
        hidden = self.pool(F.relu(self.conv2(hidden)))
        # Flatten the feature maps before the fully connected layers
        hidden = hidden.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc2(F.relu(self.fc1(hidden))))
        return self.fc3(hidden)

def train(net, trainloader, epochs: int):
    """Train the network on the training set.

    Uses cross-entropy loss and an Adam optimizer with default settings, and
    prints the loss and accuracy after every epoch.

    Args:
        net: The model to optimise in place.
        trainloader: DataLoader yielding (images, labels) batches.
        epochs: Number of passes over `trainloader`.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            # One forward pass serves both the loss and the accuracy metric
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics: .item() detaches the value so the autograd graph of
            # each batch is not kept alive for the whole epoch
            epoch_loss += loss.item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        # Same normalisation as test(): summed batch losses / number of samples
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")

def test(net, testloader):
    """Evaluate the network on every batch of `testloader`.

    Returns:
        (loss, accuracy): the summed per-batch cross-entropy losses divided by
        the dataset size, and the fraction of correctly classified samples.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    summed_loss = 0.0
    n_correct = 0
    n_seen = 0
    net.eval()
    with torch.no_grad():
        for batch_images, batch_labels in testloader:
            batch_images = batch_images.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)
            logits = net(batch_images)
            summed_loss += loss_fn(logits, batch_labels).item()
            # Index of the largest output per sample is the predicted class
            predicted = torch.max(logits.data, 1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()
    return summed_loss / len(testloader.dataset), n_correct / n_seen

def weighted_average(metrics):
    """Aggregate client accuracies into one example-weighted average.

    Args:
        metrics: Sequence of (num_examples, metrics_dict) pairs, where each
            metrics_dict has an "accuracy" entry.

    Returns:
        A dict with the single key "accuracy".
    """
    # Each client's accuracy counts in proportion to its number of examples
    weighted_sum = sum([count * client_metrics["accuracy"] for count, client_metrics in metrics])
    total_examples = sum([count for count, _ in metrics])
    return {"accuracy": weighted_sum / total_examples}

## Implementation of the Flower Client
The simple implementation of the FlowerClient follows once again the [Flower documentation](https://flower.dev/docs/tutorial/Flower-1-Intro-to-FL-PyTorch.html), however, we make it carbon aware using the lowcarb framework.
Here you have two options:
-   inherit lowcarb's Lowcarb_Client and make sure your client has a 'location' attribute,
-   if your own future client implementation needs to implement the get_properties() method, decorate it with @lowcarb and go on with your business.
```
from lowcarb import Lowcarb_Client, lowcarb

class My_Custom_Client(Lowcarb_Client, fl.client.NumPyClient):

    @lowcarb
    def get_properties(self, config):
        my_business = {'value1': 123, 'value2': 456, 'value3': 789}

        return my_business
```

Note:
-   the @lowcarb decorator attaches a 'location' entry to the dict returned by your get_properties(). lowcarb's client manager (LowCarb_ClientManager), which we will meet further down, needs this information for carbon-aware scheduling.
    ```
    my_custom_client.get_properties({}) -> {'value1': 123, 'value2': 456, 'value3': 789, 'location': my_custom_client.location}
    ```
-   If the location of the client changes throughout its lifetime, update the location in your get_properties() implementation

In [5]:
from lowcarb import Lowcarb_Client

class FlowerClient(Lowcarb_Client, fl.client.NumPyClient):
    """Flower NumPyClient holding one model, its local data and a location.

    Args:
        net: The local model.
        trainloader: DataLoader over this client's training split.
        valloader: DataLoader over this client's validation split.
        location: Region name stored on the client as `self.location`.
        *args, **kwargs: Forwarded unchanged to the parent initialisers.
    """

    def __init__(self, net, trainloader, valloader, location, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

        self.location = location

    def get_parameters(self, config):
        """Return the local model's weights as a list of NumPy arrays."""
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the received weights and report them back with metadata.

        Returns:
            (parameters, num_examples, metrics), where num_examples is the
            number of samples in this client's training split.
        """
        local_epochs = 1  # only used by the disabled train() call below

        set_parameters(self.net, parameters)
        # NOTE(review): local training is switched off, so the received weights
        # are returned unchanged. Re-enable the next line to train for real.
        # train(self.net, self.trainloader, epochs=local_epochs)
        # Report the sample count, not the batch count, as the example weight.
        # 'fit_runtime' is a fixed placeholder; presumably read by lowcarb —
        # TODO confirm.
        return get_parameters(self.net), len(self.trainloader.dataset), {'fit_runtime': 100}

    def evaluate(self, parameters, config):
        """Evaluate the received weights on the local validation split.

        Returns:
            (loss, num_examples, metrics), where metrics holds "accuracy" and
            num_examples is the number of validation samples.
        """
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}

def get_parameters(net) -> List[np.ndarray]:
    """Export every state_dict entry of `net` as a NumPy array, in state_dict order."""
    arrays = []
    for tensor in net.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays

def set_parameters(net, parameters: List[np.ndarray]):
    """Load a list of NumPy arrays into `net`, matched to state_dict order.

    Args:
        net: Model whose state is overwritten.
        parameters: One array per state_dict entry, in `net.state_dict()`
            key order (the layout produced by `get_parameters`).

    Raises:
        RuntimeError: Raised by `load_state_dict(strict=True)` when keys are
            missing or shapes do not match.
    """
    # torch.tensor keeps each array's dtype and accepts 0-dim arrays (for
    # example BatchNorm's num_batches_tracked), whereas the legacy
    # torch.Tensor constructor always produces float32.
    state_dict = OrderedDict(
        (key, torch.tensor(value))
        for key, value in zip(net.state_dict().keys(), parameters)
    )
    net.load_state_dict(state_dict, strict=True)

## Generating clients around the world

The client_generator generates instances of clients to be used in the simulation. It once again follows the tutorial in the [Flower documentation](https://flower.dev/docs/tutorial/Flower-1-Intro-to-FL-PyTorch.html); however, each client's location is chosen randomly from the `available_regions` list defined in the next cell (e.g. 'westus', 'uksouth', 'westeurope', 'norwayeast').
If you need a more sophisticated selection of available clients feel free to edit client_generator()


IMPORTANT: this is not the selection of clients that will participate in each federated learning round, this is the function to create the pool of available clients that later on will be selected.

In [6]:
from random import sample
# Pool of region names a simulated client can be placed in
available_regions = [
    'westcentralus', 'ukwest', 'uksouth', 'westeurope', 'westus',
    'australiacentral', 'australiaeast', 'swedencentral', 'norwaywest',
    'norwayeast', 'northeurope', 'centralus', 'francesouth', 'francecentral',
]
# One independently drawn region per client id (repeats are possible)
regions = [sample(available_regions, 1)[0] for _ in range(NUM_CLIENTS)]

def client_generator(cid: str) -> FlowerClient:
    """Build the Flower client for client id `cid`.

    `cid` is parsed as an integer index into `trainloaders`, `valloaders`
    and `regions`, so every client gets its own data split and location.
    """
    # Fresh model placed on the configured device
    model = Net().to(DEVICE)

    # Per-client data (CIFAR-10 partition) and region share the same index
    client_index = int(cid)
    return FlowerClient(
        model,
        trainloaders[client_index],
        valloaders[client_index],
        regions[client_index],
    )

In [7]:
# Create FedAvg strategy
# With the 20-client pool the 2.5 % fractions are smaller than the minimums
# below; the run log shows 10 clients sampled per fit round and 5 per
# evaluate round.
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.025,  # Sample 2.5% of available clients for training
    fraction_evaluate=0.025,  # Sample 2.5% of available clients for evaluation
    min_fit_clients=10,  # Never sample less than 10 clients for training
    min_evaluate_clients=5,  # Never sample less than 5 clients for evaluation
    min_available_clients=10,  # Minimum of 10 available clients (presumably required before a round starts — TODO confirm)
    evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
    # on_fit_config_fn=fit_config,  # fit_config function
    # evaluate_fn=server_evaluate
)

Here we start the federated learning simulation with our modified client_generator that gives us clients spread around the world, each with their own data.

To make it carbon aware, use lowcarb's client_manager and you are good to go.

__For some reason the simulation fails on Windows 10 after a few rounds, which is probably related to an [open Ray issue](https://github.com/ray-project/ray/issues/24361)__
Maybe it works on MacOS/Linux

In [8]:
from lowcarb import LowCarb_ClientManager
# Build the carbon-aware client manager, then start the simulation with it
carbon_aware_manager = LowCarb_ClientManager(
    api_host='https://carbon-aware-api.azurewebsites.net',
    workload_duration=15,
    forecast_window=12,
)
fl.simulation.start_simulation(
    client_fn=client_generator,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
    client_manager=carbon_aware_manager,
)

INFO flower 2022-10-22 12:15:22,493 | app.py:142 | Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
INFO flower 2022-10-22 12:15:25,147 | app.py:176 | Flower VCE: Ray initialized with resources: {'object_store_memory': 16171597824.0, 'CPU': 24.0, 'node:127.0.0.1': 1.0, 'memory': 32343195648.0}
INFO flower 2022-10-22 12:15:25,148 | server.py:86 | Initializing global parameters
INFO flower 2022-10-22 12:15:25,148 | server.py:270 | Requesting initial parameters from one random client


_______________________________________________________________________
 Available Clients with their locations
_______________________________________________________________________
0 westeurope
1 australiaeast
2 westeurope
3 centralus
4 centralus
5 ukwest
6 uksouth
7 francecentral
8 australiacentral
9 northeurope
10 francesouth
11 australiacentral
12 francecentral
13 swedencentral
14 westcentralus
15 northeurope
16 australiacentral
17 westcentralus
18 uksouth
19 uksouth
_______________________________________________________________________
 Available Clients with their participation
_______________________________________________________________________
0 0
1 0
2 0
3 0
4 0
5 0
6 0
7 0
8 0
9 0
10 0
11 0
12 0
13 0
14 0
15 0
16 0
17 0
18 0
19 0
_______________________________________________________________________
 selected low carbon clients
_______________________________________________________________________
3


INFO flower 2022-10-22 12:15:43,770 | server.py:274 | Received initial parameters from one random client
INFO flower 2022-10-22 12:15:43,770 | server.py:88 | Evaluating initial parameters
INFO flower 2022-10-22 12:15:43,771 | server.py:101 | FL starting
DEBUG flower 2022-10-22 12:16:00,702 | server.py:215 | fit_round 1: LowCarb_Strategy sampled 10 clients (out of 20)


_______________________________________________________________________
 Available Clients with their locations
_______________________________________________________________________
0 westeurope
1 australiaeast
2 westeurope
3 centralus
4 centralus
5 ukwest
6 uksouth
7 francecentral
8 australiacentral
9 northeurope
10 francesouth
11 australiacentral
12 francecentral
13 swedencentral
14 westcentralus
15 northeurope
16 australiacentral
17 westcentralus
18 uksouth
19 uksouth
_______________________________________________________________________
 Available Clients with their participation
_______________________________________________________________________
0 0
1 0
2 0
3 1
4 0
5 0
6 0
7 0
8 0
9 0
10 0
11 0
12 0
13 0
14 0
15 0
16 0
17 0
18 0
19 0
_______________________________________________________________________
 selected low carbon clients
_______________________________________________________________________
10
12
7
6
18
19
5
2
0
9


DEBUG flower 2022-10-22 12:16:07,954 | server.py:229 | fit_round 1 received 10 results and 0 failures
DEBUG flower 2022-10-22 12:16:24,894 | server.py:165 | evaluate_round 1: LowCarb_Strategy sampled 5 clients (out of 20)


_______________________________________________________________________
 Available Clients with their locations
_______________________________________________________________________
0 westeurope
1 australiaeast
2 westeurope
3 centralus
4 centralus
5 ukwest
6 uksouth
7 francecentral
8 australiacentral
9 northeurope
10 francesouth
11 australiacentral
12 francecentral
13 swedencentral
14 westcentralus
15 northeurope
16 australiacentral
17 westcentralus
18 uksouth
19 uksouth
_______________________________________________________________________
 Available Clients with their participation
_______________________________________________________________________
0 1
1 0
2 1
3 1
4 0
5 1
6 1
7 1
8 0
9 1
10 1
11 0
12 1
13 0
14 0
15 0
16 0
17 0
18 1
19 1
_______________________________________________________________________
 selected low carbon clients
_______________________________________________________________________
15
13
4
14
17


DEBUG flower 2022-10-22 12:16:28,159 | server.py:179 | evaluate_round 1 received 5 results and 0 failures
DEBUG flower 2022-10-22 12:16:45,553 | server.py:215 | fit_round 2: LowCarb_Strategy sampled 10 clients (out of 20)


_______________________________________________________________________
 Available Clients with their locations
_______________________________________________________________________
0 westeurope
1 australiaeast
2 westeurope
3 centralus
4 centralus
5 ukwest
6 uksouth
7 francecentral
8 australiacentral
9 northeurope
10 francesouth
11 australiacentral
12 francecentral
13 swedencentral
14 westcentralus
15 northeurope
16 australiacentral
17 westcentralus
18 uksouth
19 uksouth
_______________________________________________________________________
 Available Clients with their participation
_______________________________________________________________________
0 1
1 0
2 1
3 1
4 1
5 1
6 1
7 1
8 0
9 1
10 1
11 0
12 1
13 1
14 1
15 1
16 0
17 1
18 1
19 1
_______________________________________________________________________
 selected low carbon clients
_______________________________________________________________________
8
16
1
11
3
4
14
17
19
18


DEBUG flower 2022-10-22 12:16:51,870 | server.py:229 | fit_round 2 received 10 results and 0 failures


System error: Unknown error


[2m[36m(pid=)[0m [2022-10-22 12:17:03,306 E 19668 23332] (raylet.exe) agent_manager.cc:107: The raylet exited immediately because the Ray agent failed. The raylet fate shares with the agent. This can happen because the Ray agent was unexpectedly killed or failed. See `dashboard_agent.log` for the root cause.


System error: Unknown error
System error: Unknown error


[2m[36m(pid=)[0m [2022-10-22 12:17:05,382 E 29080 4988] (gcs_server.exe) gcs_server.cc:283: Failed to get the resource load: GrpcUnavailable: RPC Error message: failed to connect to all addresses; RPC Error details: 
[2m[36m(pid=)[0m [2022-10-22 12:17:05,382 E 29080 4988] (gcs_server.exe) gcs_server.cc:283: Failed to get the resource load: GrpcUnavailable: RPC Error message: failed to connect to all addresses; RPC Error details: 
[2m[36m(pid=)[0m [2022-10-22 12:17:05,383 E 29080 4988] (gcs_server.exe) gcs_server.cc:283: Failed to get the resource load: GrpcUnavailable: RPC Error message: failed to connect to all addresses; RPC Error details: 
[2m[36m(pid=)[0m [2022-10-22 12:17:06,358 E 29080 4988] (gcs_server.exe) gcs_server.cc:283: Failed to get the resource load: GrpcUnavailable: RPC Error message: failed to connect to all addresses; RPC Error details: 
[2m[36m(pid=)[0m [2022-10-22 12:17:07,365 E 29080 4988] (gcs_server.exe) gcs_server.cc:283: Failed to get the resourc

KeyError: '17'