# Carbon Aware Federated Learning using Flower and Carbon Aware SDK
This notebook follows along the official "Introduction to Federated Learning" from the official Flower [documentation](https://flower.dev/docs/tutorial/Flower-1-Intro-to-FL-PyTorch.html) using PyTorch and the CIFAR10 dataset

The extra steps to make your federated learning carben aware using the Carbon Aware SDK are highlighted along the way.

In [1]:
import random
from collections import OrderedDict
from typing import List

import flwr as fl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10

In [2]:
DEVICE = torch.device("cpu")
NUM_CLIENTS = 10
BATCH_SIZE = 32

In [3]:
CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def load_datasets(num_clients: int):
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
      [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader

trainloaders, valloaders, testloader = load_datasets(num_clients=NUM_CLIENTS)

Files already downloaded and verified
Files already downloaded and verified


## Implementation of the neural net
using the CIFAR10 dataset following along the [Flower documentation](https://flower.dev/docs/tutorial/Flower-1-Intro-to-FL-PyTorch.html)

In [4]:
def load_datasets(num_clients: int):
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
      [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader


class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")

def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

def weighted_average(metrics):
    # Multiply accuracy of each client by number of examples used
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": sum(accuracies) / sum(examples)}

## Implementation of the Flower Client
The simple implementation of the FlowerClient follows once again the [Flower documentation](https://flower.dev/docs/tutorial/Flower-1-Intro-to-FL-PyTorch.html), however, we make it carbon aware using the lowcarb framework.
Here you have two options:
-   inherit lowcarb's Lowcarb_Client and make sure your client has a 'location' attribute,
-   if your own future client implementation needs to implement the get_properties() method, decorate it with @lowcarb and go on with your business.
```
from lowcarb import Lowcarb_Client, lowcarb

class My_Custom_Client(Lowcarb_Client, fl.client.NumPyClient)

    @lowcarb
    def get_properties(config):
        my_business = {'value1': 123, 'value2': 456, 'value3': 789}

        return my_business
```

Note:
-   the @lowcarb decorated attaches a 'location' attribute to your get_properties() return. lowcarb's Client_Manager we will meet further down needs this information for a carbon aware scheduling.
    ```
    my_custom_client.get_properties({}) -> {'value1': 123, 'value2': 456, 'value3': 789, 'location': my_custom_client.location}
    ```
-   If the locatio of the client changes througout its lifetime, update the location in your get_properties() implementation

In [5]:
from lowcarb import Lowcarb_Client

class FlowerClient(Lowcarb_Client, fl.client.NumPyClient):
    def __init__(self, net, trainloader, valloader, location, *args, **kwargs):
        super(FlowerClient, self).__init__(*args, **kwargs)
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

        self.location = location

    def get_parameters(self, config):
        return get_parameters(self.net)

    def fit(self, parameters, config):
        # local_epochs, internal_value = config["carbon_optimization"](self.location, 10)
        local_epochs, internal_value = 1, 'test inernalt value'


        # print(f'fitting with {local_epochs} local_epochs and "{internal_value}" came along')

        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=local_epochs)
        return get_parameters(self.net), len(self.trainloader), {'fit_runtime': 100}

    def evaluate(self, parameters, config):
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)

## Generating clients around the world

The client_generator generates instances of clients to be used in the simulation. It onces again follows along the tutorial in the [Flower documentation](https://flower.dev/docs/tutorial/Flower-1-Intro-to-FL-PyTorch.html), however, the location is chosen randomly from regions = ['eastus', 'westus', 'germany', 'norway', 'denmark']
If you need a more sophisticated selection of available clients feel free to edit client_generator()


IMPORTANT: this is not the selection of clients that will participate in each federated learning round, this is the function to create the pool of available clients that later on will be selected.

In [6]:
def client_generator(cid: str) -> FlowerClient:
    """Create a Flower client representing a single organization."""

    # Load model
    net = Net().to(DEVICE)

    # Load data (CIFAR-10)
    # Note: each client gets a different trainloader/valloader, so each client
    # will train and evaluate on their own unique data
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]

    regions = ['eastus', 'westus', 'germany', 'norway', 'denmark']

    location = regions[int(random.random()*len(regions))] ### weird hack because random.choice is somehow not working with Ray. I guess every spawned process gets the same random seed and starts at the same position. This somehow is not the case for random.random

    # Create a  single Flower client representing a single organization
    return FlowerClient(net, trainloader, valloader, location)

In [7]:
# Create FedAvg strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.025,  # Sample 100% of available clients for training
    fraction_evaluate=0.025,  # Sample 50% of available clients for evaluation
    min_fit_clients=10,  # Never sample less than 10 clients for training
    min_evaluate_clients=5,  # Never sample less than 5 clients for evaluation
    min_available_clients=10,  # Wait until all 10 clients are available
    evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
    # on_fit_config_fn=fit_config,  # fit_config function
    # evaluate_fn=server_evaluate
)

Here we start the federated learning simulation with our modified client_generator that gives us clients spread around the world, each with their own data.

To make it carbon aware, use lowcarb's client_manager and you are good to go.

__For some reason the simulation fails on Windows 10 after a few rounds, which is probably related to an [open Ray issue](https://github.com/ray-project/ray/issues/24361)__
Maybe it works on MacOS/Linux

In [None]:
from lowcarb import LowCarb_ClientManager
# Start simulation
fl.simulation.start_simulation(
    client_fn=client_generator,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
    client_manager=LowCarb_ClientManager(),
)

INFO flower 2022-10-13 14:00:54,045 | app.py:142 | Starting Flower simulation, config: ServerConfig(num_rounds=5, round_timeout=None)
INFO flower 2022-10-13 14:00:56,660 | app.py:176 | Flower VCE: Ray initialized with resources: {'CPU': 24.0, 'object_store_memory': 13518057062.0, 'node:127.0.0.1': 1.0, 'memory': 27036114126.0}
INFO flower 2022-10-13 14:00:56,660 | server.py:86 | Initializing global parameters
INFO flower 2022-10-13 14:00:56,661 | server.py:270 | Requesting initial parameters from one random client


_______________________________________________________________________
 Available Clients with their locations
_______________________________________________________________________
  cid location
0   0   norway
1   1  germany
2   2  denmark
3   3   eastus
4   4   norway
5   5   norway
6   6  denmark
7   7   westus
8   8  denmark
9   9   westus
                         time       value  region
0   2022-10-13 12:05:00+00:00  452.940912  westus
1   2022-10-13 12:10:00+00:00  452.987852  westus
2   2022-10-13 12:15:00+00:00  452.859186  westus
3   2022-10-13 12:20:00+00:00  452.584744  westus
4   2022-10-13 12:25:00+00:00  452.548255  westus
..                        ...         ...     ...
282 2022-10-14 11:35:00+00:00  451.697146  westus
283 2022-10-14 11:40:00+00:00  451.600899  westus
284 2022-10-14 11:45:00+00:00  451.514785  westus
285 2022-10-14 11:50:00+00:00  451.533735  westus
286 2022-10-14 11:55:00+00:00  456.039648  westus

[1722 rows x 3 columns]
__________________________

INFO flower 2022-10-13 14:01:09,071 | server.py:274 | Received initial parameters from one random client
INFO flower 2022-10-13 14:01:09,071 | server.py:88 | Evaluating initial parameters
INFO flower 2022-10-13 14:01:09,072 | server.py:101 | FL starting


_______________________________________________________________________
 Available Clients with their locations
_______________________________________________________________________
  cid location
0   0   westus
1   1   westus
2   2  germany
3   3  denmark
4   4   westus
5   5  denmark
6   6   norway
7   7   eastus
8   8  germany
9   9   westus


DEBUG flower 2022-10-13 14:01:19,148 | server.py:215 | fit_round 1: LowCarb_Strategy sampled 1 clients (out of 10)


                         time       value  region
0   2022-10-13 12:05:00+00:00  452.940912  westus
1   2022-10-13 12:10:00+00:00  452.987852  westus
2   2022-10-13 12:15:00+00:00  452.859186  westus
3   2022-10-13 12:20:00+00:00  452.584744  westus
4   2022-10-13 12:25:00+00:00  452.548255  westus
..                        ...         ...     ...
282 2022-10-14 11:35:00+00:00  451.697146  westus
283 2022-10-14 11:40:00+00:00  451.600899  westus
284 2022-10-14 11:45:00+00:00  451.514785  westus
285 2022-10-14 11:50:00+00:00  451.533735  westus
286 2022-10-14 11:55:00+00:00  456.039648  westus

[1722 rows x 3 columns]
_______________________________________________________________________
 selected low carbon clients
_______________________________________________________________________
  cid location
6   6   norway


DEBUG flower 2022-10-13 14:01:20,813 | server.py:229 | fit_round 1 received 1 results and 0 failures


[2m[36m(launch_and_fit pid=11772)[0m Epoch 1: train loss 0.06491637974977493, accuracy 0.22755555555555557
_______________________________________________________________________
 Available Clients with their locations
_______________________________________________________________________
  cid location
0   0   eastus
1   1  denmark
2   2  germany
3   3  denmark
4   4   eastus
5   5  denmark
6   6   eastus
7   7   norway
8   8  denmark
9   9   norway


DEBUG flower 2022-10-13 14:01:30,866 | server.py:165 | evaluate_round 1: LowCarb_Strategy sampled 2 clients (out of 10)


                         time       value  region
0   2022-10-13 12:05:00+00:00  452.940912  westus
1   2022-10-13 12:10:00+00:00  452.987852  westus
2   2022-10-13 12:15:00+00:00  452.859186  westus
3   2022-10-13 12:20:00+00:00  452.584744  westus
4   2022-10-13 12:25:00+00:00  452.548255  westus
..                        ...         ...     ...
282 2022-10-14 11:35:00+00:00  451.697146  westus
283 2022-10-14 11:40:00+00:00  451.600899  westus
284 2022-10-14 11:45:00+00:00  451.514785  westus
285 2022-10-14 11:50:00+00:00  451.533735  westus
286 2022-10-14 11:55:00+00:00  456.039648  westus

[1722 rows x 3 columns]
_______________________________________________________________________
 selected low carbon clients
_______________________________________________________________________
  cid location
7   7   norway
9   9   norway


DEBUG flower 2022-10-13 14:01:32,894 | server.py:179 | evaluate_round 1 received 2 results and 0 failures


_______________________________________________________________________
 Available Clients with their locations
_______________________________________________________________________
  cid location
0   0  germany
1   1  denmark
2   2   westus
3   3   norway
4   4  denmark
5   5   eastus
6   6   eastus
7   7   norway
8   8   eastus
9   9   norway


DEBUG flower 2022-10-13 14:01:42,883 | server.py:215 | fit_round 2: LowCarb_Strategy sampled 3 clients (out of 10)


                         time       value  region
0   2022-10-13 12:05:00+00:00  452.940912  westus
1   2022-10-13 12:10:00+00:00  452.987852  westus
2   2022-10-13 12:15:00+00:00  452.859186  westus
3   2022-10-13 12:20:00+00:00  452.584744  westus
4   2022-10-13 12:25:00+00:00  452.548255  westus
..                        ...         ...     ...
282 2022-10-14 11:35:00+00:00  451.697146  westus
283 2022-10-14 11:40:00+00:00  451.600899  westus
284 2022-10-14 11:45:00+00:00  451.514785  westus
285 2022-10-14 11:50:00+00:00  451.533735  westus
286 2022-10-14 11:55:00+00:00  456.039648  westus

[1722 rows x 3 columns]
_______________________________________________________________________
 selected low carbon clients
_______________________________________________________________________
  cid location
3   3   norway
7   7   norway
9   9   norway
[2m[36m(launch_and_fit pid=11772)[0m Epoch 1: train loss 0.05637894570827484, accuracy 0.33266666666666667


DEBUG flower 2022-10-13 14:01:47,785 | server.py:229 | fit_round 2 received 3 results and 0 failures


[2m[36m(launch_and_fit pid=24412)[0m Epoch 1: train loss 0.05531826242804527, accuracy 0.3393333333333333
[2m[36m(launch_and_fit pid=29100)[0m Epoch 1: train loss 0.056043900549411774, accuracy 0.3337777777777778


KeyboardInterrupt: 

2022-10-13 14:01:52,709	ERROR worker.py:501 -- print_logs: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.UNKNOWN
	details = "Stream removed"
	debug_error_string = "UNKNOWN:Error received from peer ipv4:127.0.0.1:62628 {grpc_message:"Stream removed", grpc_status:2, created_time:"2022-10-13T12:01:52.678685808+00:00"}"
>
Exception in thread ray_listen_error_messages:
Traceback (most recent call last):
  File "C:\Program Files\Python39\lib\threading.py", line 954, in _bootstrap_inner
2022-10-13 14:01:52,723	ERROR import_thread.py:76 -- ImportThread: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.UNKNOWN
	details = "Stream removed"
	debug_error_string = "UNKNOWN:Error received from peer ipv4:127.0.0.1:62628 {grpc_message:"Stream removed", grpc_status:2, created_time:"2022-10-13T12:01:52.678668028+00:00"}"
>
    self.run()
  File "C:\Program Files\Python39\lib\threading.py", line 892, in run


    _, error_data = worker.gcs_error_subscriber.poll()
  File "C:\Users\skype\Documents\Programming\Carbon_Hack_22\carbonhack22\venv\lib\site-packages\ray\_private\gcs_pubsub.py", line 317, in poll
    self._poll_locked(timeout=timeout)
  File "C:\Users\skype\Documents\Programming\Carbon_Hack_22\carbonhack22\venv\lib\site-packages\ray\_private\gcs_pubsub.py", line 249, in _poll_locked
    fut.result(timeout=1)
  File "C:\Users\skype\Documents\Programming\Carbon_Hack_22\carbonhack22\venv\lib\site-packages\grpc\_channel.py", line 744, in result
    raise self
grpc._channel._MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.UNKNOWN
	details = "Stream removed"
	debug_error_string = "UNKNOWN:Error received from peer ipv4:127.0.0.1:62628 {created_time:"2022-10-13T12:01:52.678378868+00:00", grpc_status:2, grpc_message:"Stream removed"}"
>
