In [1]:
import argparse
import flwr as fl
from flwr.common.typing import Scalar
import ray
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import VisionDataset, Food101

import numpy as np
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Callable, Optional, Tuple, List, Any

In [2]:
import os
import multiprocessing

data_path = os.path.join(os.getcwd(), 'data', 'food-101')
# Leave one core for the notebook/OS, but never go below 1: on a single-core
# machine `cpu_count() - 1` would be 0 (zero Ray CPUs / DataLoader workers).
cpu_count = max(1, multiprocessing.cpu_count() - 1)  # set as you like!

In [3]:

class LeNet5(nn.Module):
    """LeNet-5 style CNN for 3-channel 32x32 images.

    Two conv/ReLU/max-pool stages followed by a two-layer MLP head. With a
    32x32 input the feature map after the second pool is 50 x 5 x 5, which is
    where the 1250 input features of the first linear layer come from.

    Args:
        num_classes: Width of the output layer. Defaults to 101 so the model
            matches Food-101, the dataset this notebook trains on.
    """

    def __init__(self, num_classes: int = 101):
        super().__init__()

        self.convolutional_layer = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=20, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.linear_layer = nn.Sequential(
            nn.Linear(in_features=50 * 5 * 5, out_features=500),
            nn.ReLU(),
            nn.Linear(in_features=500, out_features=num_classes),
        )

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes).

        No softmax is applied here: CrossEntropyLoss (used by train/test)
        applies log-softmax itself, so a softmax here would be applied twice.
        argmax over logits equals argmax over probabilities, so predicted
        classes are unaffected.
        """
        x = self.convolutional_layer(x)
        x = torch.flatten(x, 1)
        return self.linear_layer(x)


# borrowed from Pytorch quickstart example
def train(net, trainloader, epochs, device: str):
    """Fit `net` to `trainloader` for `epochs` full passes.

    Uses SGD (lr=0.01, momentum=0.9) on the cross-entropy loss and updates the
    network's parameters in place. The network is switched to training mode;
    this function does not move `net` itself, only each batch, to `device`.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    sgd = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    net.train()
    for _epoch in range(epochs):
        for batch_images, batch_labels in trainloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            sgd.zero_grad()
            batch_loss = loss_fn(net(batch_images), batch_labels)
            batch_loss.backward()
            sgd.step()


# borrowed from Pytorch quickstart example
def test(net, testloader, device: str):
    """Validate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch[0].to(device), batch[1].to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return loss, accuracy


In [4]:
# Flower client, adapted from Pytorch quickstart example
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPyClient that trains and evaluates a LeNet5 on one partition.

    Args:
        cid: Client id assigned by Flower.
        fed_dir_data: Path to this client's partition file, i.e. a dataset
            object written with `torch.save` by the data-preparation cell.
    """

    def __init__(self, cid: str, fed_dir_data: str):
        self.cid = cid
        # NOTE: despite the name, this is a single file, not a directory.
        self.fed_dir = Path(fed_dir_data)
        self.properties: Dict[str, Scalar] = {"tensor_type": "numpy.ndarray"}

        # Instantiate model
        self.net = LeNet5()

        # Determine device
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def get_parameters(self, config):
        return get_params(self.net)

    def fit(self, parameters, config):
        set_params(self.net, parameters)

        # Load this client's data (shuffled for training)
        trainloader = self._get_dataloader(is_train=True, batch_size=config["batch_size"])

        # Send model to device and train
        self.net.to(self.device)
        train(self.net, trainloader, epochs=config["epochs"], device=self.device)

        # Return local model and statistics
        return get_params(self.net), len(trainloader.dataset), {}

    def evaluate(self, parameters, config):
        set_params(self.net, parameters)

        # NOTE(review): the saved partitions contain no held-out split, so this
        # evaluates on the client's own training partition.
        valloader = self._get_dataloader(is_train=False, batch_size=50)

        # Send model to device and evaluate
        self.net.to(self.device)
        loss, accuracy = test(self.net, valloader, device=self.device)

        # Return statistics
        return float(loss), len(valloader.dataset), {"accuracy": float(accuracy)}

    def _get_dataloader(self, is_train: bool, batch_size: int):
        """Build a DataLoader over the partition stored at `self.fed_dir`.

        Replaces a call to an undefined `get_dataloader` helper. Shuffles only
        when `is_train` is True, and uses as many worker processes as CPUs
        Ray assigned to this client (0 if none are reported).
        """
        num_workers = int(ray.get_runtime_context().get_assigned_resources().get("CPU", 0))
        # The file holds a pickled dataset object written by this notebook, so
        # full unpickling is needed (torch >= 2.6 defaults to weights_only=True).
        # Only load files you created yourself: unpickling runs arbitrary code.
        try:
            dataset = torch.load(self.fed_dir, weights_only=False)
        except TypeError:  # older torch without the `weights_only` keyword
            dataset = torch.load(self.fed_dir)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=num_workers,
        )


def fit_config(server_round: int) -> Dict[str, Scalar]:
    """Build the fit configuration sent to every sampled client.

    The configuration is identical for every round (`server_round` is not
    used): 5 local epochs with mini-batches of 64 samples.
    """
    return dict(epochs=5, batch_size=64)


def get_params(model: torch.nn.Module) -> List[np.ndarray]:
    """Return the model's state_dict tensors as a list of NumPy arrays.

    The order follows `model.state_dict()`, which is the order `set_params`
    consumes. Tensors are moved to CPU first so this also works for CUDA
    models; for CPU models the arrays share memory with the model tensors.
    """
    return [tensor.cpu().numpy() for tensor in model.state_dict().values()]


def set_params(model: torch.nn.Module, params: List[np.ndarray]):
    """Load `params` into `model`, matched positionally to `state_dict()` keys.

    Raises:
        ValueError: if the number of arrays differs from the number of
            state_dict entries (zip() would otherwise silently drop extras).
        RuntimeError: from `load_state_dict` when a shape does not match.
    """
    keys = list(model.state_dict().keys())
    params = list(params)
    if len(params) != len(keys):
        raise ValueError(
            f"expected {len(keys)} parameter arrays, got {len(params)}"
        )
    # torch.from_numpy shares memory with its input; copy so the resulting
    # tensors are decoupled from the caller's arrays.
    state_dict = OrderedDict(
        (key, torch.from_numpy(np.copy(value))) for key, value in zip(keys, params)
    )
    model.load_state_dict(state_dict, strict=True)


def get_evaluate_fn(
    testset: torch.utils.data.Dataset,
    batch_size: int = 64,
) -> Callable[
    [int, fl.common.NDArrays, Dict[str, Scalar]],
    Optional[Tuple[float, Dict[str, Scalar]]],
]:
    """Return a server-side (centralised) evaluation function for the strategy.

    Args:
        testset: Held-out dataset yielding (image, label) pairs; in this
            notebook, the Food-101 test split.
        batch_size: Mini-batch size used for evaluation.

    Returns:
        `evaluate(server_round, parameters, config)`, which loads `parameters`
        into a fresh LeNet5 and returns `(loss, {"accuracy": accuracy})` as
        computed by `test`.
    """
    # Built once and reused every round. No shuffling: evaluation does not
    # need it, and a fixed order keeps results reproducible.
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    def evaluate(
        server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar]
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate the current global model on the whole `testset`."""
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        model = LeNet5()
        set_params(model, parameters)
        model.to(device)

        loss, accuracy = test(model, testloader, device=device)
        return loss, {"accuracy": accuracy}

    return evaluate

In [5]:
class TorchVision_FL(VisionDataset):
    """In-memory (data, targets) dataset; a trimmed-down torchvision.datasets.MNIST.

    Use this class by either passing `path_to_data`, a path to a torch file
    (.pt) containing a `(data, targets)` tuple, or by passing `data` and
    `targets` directly. Tensor / NumPy items are converted to PIL images
    before `transform` is applied, consistent with other torchvision datasets.
    """

    def __init__(
        self,
        path_to_data=None,
        data=None,
        targets=None,
        transform: Optional[Callable] = None,
    ) -> None:
        # Accept str as well as Path: `.parent` only exists on Path objects.
        path_to_data = Path(path_to_data) if path_to_data else None
        root = path_to_data.parent if path_to_data else None
        super().__init__(root, transform=transform)
        self.transform = transform

        if path_to_data:
            # load data and targets (path_to_data points to a specific .pt file)
            self.data, self.targets = torch.load(path_to_data)
        else:
            self.data = data
            self.targets = targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        img, target = self.data[index], int(self.targets[index])

        # Return a PIL Image so behaviour matches the other torchvision
        # datasets. PIL's `Image` is not imported in this notebook, so the
        # conversion goes through torchvision's helper instead.
        if isinstance(img, torch.Tensor):
            img = img.numpy()
        if isinstance(img, np.ndarray):
            img = transforms.functional.to_pil_image(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)



In [6]:
pool_size = 100  # number of dataset partitions (= number of total clients)
SEED = 42  # makes the client partitioning reproducible

# Each client is allocated `cpu_count` CPUs, so on a single machine Ray can
# run only about one client at a time.
client_resources = {
        "num_cpus": cpu_count
}

transformations = transforms.Compose([
    transforms.Resize((32, 32)),  # LeNet5's 1250-feature linear layer assumes 32x32 inputs
    transforms.ToTensor(),
])


# Download the dataset only if it is not on disk yet. Food101 raises
# RuntimeError when the files are missing and download=False; anything else
# (including KeyboardInterrupt) should propagate.
try:
    train_data = Food101(data_path, transform=transformations)
except RuntimeError:
    train_data = Food101(data_path, transform=transformations, download=True)
test_data = Food101(data_path, split='test', transform=transformations)

# Unequal-size random split: Dirichlet(1, ..., 1) proportions turned into
# integer sizes that sum exactly to len(train_data) via largest-remainder
# rounding (instead of re-sampling until rounded sizes happen to add up).
# NOTE(review): a tiny proportion can still round to a 0-sample partition.
rng = np.random.default_rng(SEED)
proportions = rng.dirichlet(np.ones(pool_size))
raw_sizes = proportions * len(train_data)
lengths = np.floor(raw_sizes).astype(int)
shortfall = len(train_data) - int(lengths.sum())
lengths[np.argsort(raw_sizes - lengths)[::-1][:shortfall]] += 1
lengths = lengths.tolist()

client_datasets = torch.utils.data.random_split(
    train_data, lengths, generator=torch.Generator().manual_seed(SEED)
)

for i, x in enumerate(client_datasets):
    torch.save(x, os.path.join(data_path, 'client_%s' % str(i+1))) #human-readable clients starting w/client_1

In [7]:
# Number of Food-101 labels; the model's output layer must be this wide for
# CrossEntropyLoss to accept every label.
num_food_classes = len(test_data.classes)
num_food_classes

101

In [8]:
# FedAvg: each round samples 10% of the clients (at least 10) for fitting and
# for evaluation, and the global model is also evaluated centrally on the
# Food-101 test split.
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.1,
    min_fit_clients=10,
    fraction_evaluate=0.1,
    min_evaluate_clients=10,
    min_available_clients=pool_size,  # wait until every client is available
    on_fit_config_fn=fit_config,
    evaluate_fn=get_evaluate_fn(test_data),
)

def client_fn(cid: str):
    """Create the FlowerClient for simulation client `cid`.

    Flower's simulation numbers clients "0" .. str(pool_size - 1), while the
    partitions were saved as client_1 .. client_{pool_size}. Shift by one so
    cid "0" loads client_1 and the last cid loads the last partition.
    """
    return FlowerClient(cid, os.path.join(data_path, 'client_%d' % (int(cid) + 1)))

# Passed explicitly below (it was previously defined but unused).
# ignore_reinit_error lets this cell be re-run in the same kernel.
ray_init_args = {"ignore_reinit_error": True, "include_dashboard": False}

# start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=pool_size,
    client_resources=client_resources,
    config=fl.server.ServerConfig(num_rounds=100),
    strategy=strategy,
    ray_init_args=ray_init_args,
)

INFO flwr 2023-02-19 18:22:24,536 | app.py:145 | Starting Flower simulation, config: ServerConfig(num_rounds=100, round_timeout=None)
INFO flwr 2023-02-19 18:22:26,646 | app.py:179 | Flower VCE: Ray initialized with resources: {'memory': 5845494990.0, 'node:192.168.1.171': 1.0, 'object_store_memory': 2922747494.0, 'CPU': 8.0}
INFO flwr 2023-02-19 18:22:26,648 | server.py:86 | Initializing global parameters
INFO flwr 2023-02-19 18:22:26,649 | server.py:270 | Requesting initial parameters from one random client
2023-02-19 18:22:27,359	ERROR serialization.py:342 -- 
Traceback (most recent call last):
  File "/home/katharine/.virtualenvs/py3data/lib/python3.8/site-packages/ray/serialization.py", line 201, in _deserialize_msgpack_data
    obj = MessagePackSerializer.loads(msgpack_data, _python_deserializer)
  File "python/ray/includes/serialization.pxi", line 191, in ray._raylet.MessagePackSerializer.loads
  File "python/ray/includes/serialization.pxi", line 192, in ray._raylet.MessagePackS

RaySystemError: System error: 
traceback: Traceback (most recent call last):
  File "/home/katharine/.virtualenvs/py3data/lib/python3.8/site-packages/ray/serialization.py", line 201, in _deserialize_msgpack_data
    obj = MessagePackSerializer.loads(msgpack_data, _python_deserializer)
  File "python/ray/includes/serialization.pxi", line 191, in ray._raylet.MessagePackSerializer.loads
  File "python/ray/includes/serialization.pxi", line 192, in ray._raylet.MessagePackSerializer.loads
  File "msgpack/_unpacker.pyx", line 161, in msgpack._unpacker.unpackb
TypeError: unpackb() got an unexpected keyword argument 'strict_map_key'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/katharine/.virtualenvs/py3data/lib/python3.8/site-packages/ray/serialization.py", line 340, in deserialize_objects
    obj = self._deserialize_object(data, metadata, object_ref)
  File "/home/katharine/.virtualenvs/py3data/lib/python3.8/site-packages/ray/serialization.py", line 259, in _deserialize_object
    obj = self._deserialize_msgpack_data(data, metadata_fields)
  File "/home/katharine/.virtualenvs/py3data/lib/python3.8/site-packages/ray/serialization.py", line 203, in _deserialize_msgpack_data
    raise DeserializationError()
ray.serialization.DeserializationError


In [10]:
import flwr

In [11]:
# `fl` is the same module object as `flwr` (imported as `fl` in the first cell).
fl.__version__

'1.3.0'

In [19]:
!pip install --upgrade flwr

Collecting flwr
  Downloading flwr-1.3.0-py3-none-any.whl (139 kB)
[K     |████████████████████████████████| 139 kB 741 kB/s eta 0:00:01
Installing collected packages: flwr
  Attempting uninstall: flwr
    Found existing installation: flwr 1.0.0
    Uninstalling flwr-1.0.0:
      Successfully uninstalled flwr-1.0.0
Successfully installed flwr-1.3.0


## Challenges

- Adjust the fit and evaluate settings and see how the performance changes.
- Try out another [Flower tutorial](https://flower.dev/docs/quickstart-pytorch.html).
- Get a group of several people together and try running Flower in a distributed setup. Document what you learn and share it in the reader-contributions!