In [1]:
from collections import OrderedDict
from typing import List, Optional, Tuple

import flwr as fl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from flwr.server.strategy import FedAvg
from flwr.server.client_proxy import ClientProxy
from flwr.common import FitRes, parameters_to_ndarrays, ndarrays_to_parameters
from torchvision.models._api import Weights

from tinysmpc import VirtualMachine, PrivateScalar
from tinysmpc.fixed_point import fixed_point, float_point

DEVICE = torch.device("cpu")  # Try "cuda" to train on GPU
print(f"Training on {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}")

Training on cpu using PyTorch 1.13.1+cpu and Flower 1.1.0


In [2]:
NUM_CLIENTS = 10


def load_datasets(num_clients: int):
    # Download and transform CIFAR-10 (train and test)
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
    testset = CIFAR10("./dataset", train=False, download=True, transform=transform)

    # Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = len(trainset) // num_clients
    lengths = [partition_size] * num_clients
    datasets = random_split(trainset, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
    testloader = DataLoader(testset, batch_size=32)
    return trainloaders, valloaders, testloader


trainloaders, valloaders, testloader = load_datasets(NUM_CLIENTS)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch + 1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

In [4]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}


def client_fn(cid) -> FlowerClient:
    net = Net().to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

In [5]:
Q = 2657003489534545107915232808830590043
fixedPoint = np.vectorize(fixed_point)
floatPoint = np.vectorize(float_point)


class FedAvgSmc(FedAvg):
    def aggregate_fit(
            self,
            rnd: int,
            results: List[Tuple[fl.server.client_proxy.ClientProxy, FitRes]],
            failures: List[BaseException],
    ) -> Optional[Weights]:
        if not results:
            return None

        clients = []
        fit_results = []

        for client, fit_result in results:
            clients.append(client)
            fit_results.append(fit_result)

        fit_res_ndarray_parameters = [parameters_to_ndarrays(result.parameters) for result in
                                      fit_results]  # list of clients, each with a list of layers(each layer representing weights)

        fl_nodes = [VirtualMachine(f"Client: {client.cid}") for client in clients]

        num_layers = len(fit_results[0].parameters.tensors)

        layers_weights = {}

        layers_weights_smc = {}

        for layer in range(num_layers):  #loop through number of layers
            layers_weights[f"layer_{layer}"] = []
            for weights in fit_res_ndarray_parameters:
                layers_weights[f"layer_{layer}"].append(fixedPoint(weights[layer]))

        for layer in range(num_layers):
            fl_node_values = [PrivateScalar(tensor, node) for tensor, node in
                              zip(layers_weights[f'layer_{layer}'], fl_nodes)]

            fl_exchanged_shares = []
            fl_exchanged_shares_list = []
            layers_weights_smc[f"layer_{layer}"] = []

            for value in fl_node_values:
                fl_exchanged_shares.append(value.share_tensor(fl_nodes, Q))

            for client_shares in fl_exchanged_shares:
                fl_exchanged_shares_list = [share.value for share in client_shares.shares]

            layers_weights_smc[f'layer_{layer}'] = floatPoint(fl_exchanged_shares_list)

        for i, client in enumerate(clients):
            client_weights = []
            for layer in layers_weights_smc.values():
                client_weights.append(np.array(layer[i]))
            fit_results[i].parameters = ndarrays_to_parameters(client_weights)

        # results = tuple(zip(clients, fit_results))

        return super().aggregate_fit(rnd, results, failures)


strategy = FedAvgSmc(fraction_fit=0.3, fraction_evaluate=0.3, min_fit_clients=3, min_evaluate_clients=3,
                     min_available_clients=NUM_CLIENTS)
fl.simulation.start_simulation(client_fn=client_fn, num_clients=2, config=fl.server.ServerConfig(num_rounds=2),
                               strategy=FedAvgSmc(), )

INFO flower 2023-01-24 23:12:16,665 | app.py:140 | Starting Flower simulation, config: ServerConfig(num_rounds=2, round_timeout=None)
2023-01-24 23:12:18,067	INFO worker.py:1518 -- Started a local Ray instance.
INFO flower 2023-01-24 23:12:21,008 | app.py:174 | Flower VCE: Ray initialized with resources: {'node:127.0.0.1': 1.0, 'CPU': 4.0, 'object_store_memory': 2303655936.0, 'memory': 4607311872.0}
INFO flower 2023-01-24 23:12:21,009 | server.py:86 | Initializing global parameters
INFO flower 2023-01-24 23:12:21,011 | server.py:270 | Requesting initial parameters from one random client
INFO flower 2023-01-24 23:12:24,013 | server.py:274 | Received initial parameters from one random client
INFO flower 2023-01-24 23:12:24,014 | server.py:88 | Evaluating initial parameters
INFO flower 2023-01-24 23:12:24,014 | server.py:101 | FL starting
DEBUG flower 2023-01-24 23:12:24,015 | server.py:215 | fit_round 1: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_get_parameters pid=13384)[0m [Client 0] get_parameters
[2m[36m(launch_and_fit pid=13384)[0m [Client 0] fit, config: {}
[2m[36m(launch_and_fit pid=3728)[0m [Client 1] fit, config: {}
[2m[36m(launch_and_fit pid=13384)[0m Epoch 1: train loss 0.06490003317594528, accuracy 0.2331111111111111


DEBUG flower 2023-01-24 23:12:30,734 | server.py:229 | fit_round 1 received 2 results and 0 failures


[2m[36m(launch_and_fit pid=3728)[0m Epoch 1: train loss 0.06493470817804337, accuracy 0.23044444444444445


DEBUG flower 2023-01-24 23:12:31,067 | server.py:165 | evaluate_round 1: strategy sampled 2 clients (out of 2)
DEBUG flower 2023-01-24 23:12:33,375 | server.py:179 | evaluate_round 1 received 2 results and 0 failures
DEBUG flower 2023-01-24 23:12:33,376 | server.py:215 | fit_round 2: strategy sampled 2 clients (out of 2)


[2m[36m(launch_and_evaluate pid=13384)[0m [Client 0] evaluate, config: {}
[2m[36m(launch_and_evaluate pid=3728)[0m [Client 1] evaluate, config: {}
[2m[36m(launch_and_fit pid=13384)[0m [Client 1] fit, config: {}
[2m[36m(launch_and_fit pid=3728)[0m [Client 0] fit, config: {}


DEBUG flower 2023-01-24 23:12:38,179 | server.py:229 | fit_round 2 received 2 results and 0 failures


nan
[2m[36m(launch_and_fit pid=13384)[0m Epoch 1: train loss nan, accuracy 0.09733333333333333
[2m[36m(launch_and_fit pid=3728)[0m Epoch 1: train loss nan, accuracy 0.10533333333333333


ValueError: cannot convert float NaN to integer