# Install Dependencies

In [None]:
# Quietly (-q) install Flower with its `simulation` extra, plus PyTorch, torchvision and matplotlib.
!pip install -q flwr[simulation] torch torchvision matplotlib

In [2]:
from collections import OrderedDict
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10, MNIST
import flwr as fl
from flwr.common import Metrics


Reconstructing neural networks from papers: FedAvg, DFedAvgM, FedProx

In [3]:
class CNN(nn.Module):
    """Convolutional classifier for 1-channel 28x28 images with 10 output classes.

    Layout: two 5x5 'same'-padded convolutions (32 then 64 channels), each
    followed by 2x2 max pooling, a 512-unit fully connected layer with ReLU,
    and a final 10-way linear layer.

    ``forward`` returns raw, unnormalised logits. The notebook trains with
    ``torch.nn.CrossEntropyLoss``, which applies log-softmax internally, so
    applying softmax here as well would squash the gradients. Call
    ``F.softmax(logits, dim=1)`` on the result if probabilities are needed;
    the arg-max prediction is the same either way.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, padding='same') #  32*28*28
        self.pool1 = nn.MaxPool2d(2, 2)                  #  32*14*14
        self.conv2 = nn.Conv2d(32,64, 5, padding='same') #  64*14*14
        self.pool2 = nn.MaxPool2d(2,2)                   #  64*7*7
        self.fc1 = nn.Linear(64* 7* 7, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to a ``(batch, 10)`` tensor of class logits."""
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        # Flatten the 64x7x7 feature maps into one vector per sample.
        x = x.view(-1, 64* 7* 7)
        x = F.relu(self.fc1(x))
        # No softmax: the loss function expects logits.
        return self.fc2(x)

class MLP(nn.Module):
  """Multi-layer perceptron with two 200-unit ReLU hidden layers and 10 outputs.

  :param input_size: number of features per flattened sample
      (defaults to 28*28*1, i.e. one-channel 28x28 images).

  ``forward`` returns raw, unnormalised logits. The notebook trains with
  ``torch.nn.CrossEntropyLoss``, which applies log-softmax internally, so no
  softmax is applied here. Use ``F.softmax(logits, dim=1)`` for probabilities.
  """
  def __init__(self, input_size:int = 28*28*1):
    super(MLP, self).__init__()
    # Remembered so that forward() flattens to the size fc1 was built for.
    self.input_size = input_size
    self.fc1 = nn.Linear(input_size, 200)
    self.fc2 = nn.Linear(200,200)
    self.fc3 = nn.Linear(200, 10)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Flatten each sample to ``input_size`` features and return ``(batch, 10)`` logits."""
    x = x.reshape(-1, self.input_size)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    # No softmax: the loss function expects logits.
    return self.fc3(x)

class LSTM1(nn.Module):
    """LSTM classifier: stacked LSTM -> ReLU -> Linear(hidden, 128) -> ReLU -> Linear(128, num_classes).

    :param num_classes: size of the output layer
    :param input_size: number of features per time step
    :param hidden_size: size of the LSTM hidden state
    :param num_layers: number of stacked LSTM layers
    :param seq_length: stored on the instance; not used by ``forward``
    """
    def __init__(self, num_classes, input_size, hidden_size, num_layers, seq_length):
        super(LSTM1, self).__init__()
        self.num_classes = num_classes #number of classes
        self.num_layers = num_layers   #number of layers
        self.input_size = input_size   #input size
        self.hidden_size = hidden_size #hidden state
        self.seq_length = seq_length   #sequence length
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True) #lstm
        self.fc_1 =  nn.Linear(hidden_size, 128) #fully connected 1
        self.fc = nn.Linear(128, num_classes) #fully connected last layer

    def forward(self,x):
        """Return ``(batch, num_classes)`` scores for a ``(batch, seq, input_size)`` input."""
        # Omitting (h_0, c_0) makes nn.LSTM start from zero states allocated on
        # the same device/dtype as its input, so the model also works on GPU.
        _, (hn, _) = self.lstm(x)
        # hn is (num_layers, batch, hidden). Keep only the top layer's final
        # state so the batch dimension is preserved for any num_layers.
        out = F.relu(hn[-1])
        out = self.fc_1(out) #first Dense
        out = F.relu(out) #relu
        out = self.fc(out) #Final Output
        return out


Check that the parameter counts match those reported in the FedAvg paper

In [4]:
# Build each architecture once and confirm its total number of parameters
# equals the expected value; the assert stops the notebook on a mismatch.
net1 = MLP()
total_params1 = sum([weights.numel() for weights in net1.parameters()])
print(f"Number of parameters MLP: {total_params1}")
assert total_params1==199210

net2 = CNN()
total_params2 = sum([weights.numel() for weights in net2.parameters()])
print(f"Number of parameters CNN: {total_params2}")
assert total_params2==1663370

Number of parameters MLP: 199210
Number of parameters CNN: 1663370


Create DatasetNode object to store dataset and model parameters

In [5]:

class DatasetNode():
  """Bundle a torchvision dataset class with the metadata needed to model it.

  For a supported dataset the instance exposes:
    name        -- the dataset name passed in
    dataset     -- the torchvision dataset class (not an instance)
    archs       -- names of the architectures usable with it
    input_size  -- sample shape as (channels, height, width)
    classes     -- tuple of class labels

  :param dataset: 'CIFAR10' or 'MNIST'
  :raises NotImplementedError: for 'FEMNIST', which torchvision does not provide
  :raises ValueError: for any other dataset name
  """
  def __init__(self, dataset:str):
    self.name = dataset
    if dataset=='CIFAR10':
      self.dataset = torchvision.datasets.CIFAR10
      self.archs = ['MLP','CNN']
      self.input_size = (3, 32, 32)
      self.classes = (
          "plane",
          "car",
          "bird",
          "cat",
          "deer",
          "dog",
          "frog",
          "horse",
          "ship",
          "truck",)

    elif dataset=='MNIST':
      self.dataset = torchvision.datasets.MNIST
      self.archs= ['MLP','CNN']
      self.input_size = (1,28, 28)
      # MNIST labels are the ten digits 0..9.
      self.classes = tuple(range(10))

    elif dataset=='FEMNIST':
      # torchvision.datasets has no FEMNIST class, so fail with an explicit
      # message instead of an AttributeError.
      raise NotImplementedError(
          "FEMNIST is not available in torchvision.datasets; "
          "a custom dataset loader is required.")

    else:
      raise ValueError(
          f"Unknown dataset {dataset!r}; expected 'CIFAR10', 'MNIST' or 'FEMNIST'.")

# Simulation Input variables

In [7]:
#Federated variables
NUM_CLIENTS = 20             # number of simulated clients (= number of data partitions)
NUM_ROUNDS = 3               # number of federated rounds
CLIENT_FRAC = 1.0            # passed as fraction_fit / fraction_evaluate to the strategy
MIN_FIT_CLIENTS = 10
MIN_EVALUABLE_CLIENTS = 5
MIN_AVAILABLE_CLIENTS = 10

#Data loading and training variables
DATASET = MNIST
SHAPE = (1,28,28)            # (channels, height, width); SHAPE[0] sizes the normaliser
# Called as split_fn(dataset, lengths, generator) and must return one subset per length.
SPLIT_FN = random_split
SEED = 42
VAL_SPLIT = 0.1              # fraction of each client partition held out for validation
BATCH_SIZE = 32
EPOCHS = 5

#Evaluation variables
LOSS_FN = torch.nn.CrossEntropyLoss()

#Device Variables
NUM_GPU = torch.cuda.device_count()
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Single print: the former print(print(...)) also emitted a stray "None".
print(
    f"Training on {NUM_GPU} X {DEVICE} using PyTorch {torch.__version__} and Flower {fl.__version__}"
)

# Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
client_resources = None
if DEVICE.type == "cuda":
    client_resources = {"num_gpus": NUM_GPU}

Training on 1 X cuda using PyTorch 2.0.1 and Flower 1.4.0
None


# Local Training Pipeline

In [8]:
def load_datasets(dataset = DATASET, batch_size = BATCH_SIZE, split_fn = SPLIT_FN, val_split = VAL_SPLIT ,seed:int = SEED):
    '''
    Download a dataset, partition its training split across NUM_CLIENTS clients
    and wrap everything in DataLoaders.

    :param dataset: dataset class called as dataset(root, train=..., download=..., transform=...)
    :param batch_size: int, batch size used by every returned DataLoader
    :param split_fn: callable invoked as split_fn(dataset, lengths, generator), returning one subset per length
    :param val_split: float, fraction of each client partition held out for validation
    :param seed: int, seed of the generators passed to split_fn
    :return: (trainloaders, valloaders, testloader) -- two lists with one
        DataLoader per client, and a single DataLoader over the test split
    '''
    # Download and transform dataset (train and test)
    n_channels = SHAPE[0]
    normalizer=transforms.Normalize([0.5]*n_channels, [0.5]*n_channels)
    transform = transforms.Compose(
        [transforms.ToTensor(),
         normalizer,
         ]
    )
    trainset = dataset("./dataset", train=True, download=True, transform=transform)
    testset  = dataset("./dataset", train=False,download=True, transform=transform)

    # Split training set into partitions to simulate the individual dataset.
    # The lengths must sum to len(trainset), so when the size is not divisible
    # by NUM_CLIENTS the first `remainder` partitions receive one extra sample.
    base_size, remainder = divmod(len(trainset), NUM_CLIENTS)
    partition_lengths = [base_size + 1 if i < remainder else base_size
                         for i in range(NUM_CLIENTS)]
    datasets = split_fn(trainset, partition_lengths, torch.Generator().manual_seed(seed))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []

    for ds in datasets:
        len_val = int( len(ds) * val_split )
        len_train = len(ds) - len_val
        ds_train, ds_val = split_fn(ds, [len_train, len_val], torch.Generator().manual_seed(seed))
        # Honour the batch_size argument rather than the module-level constant.
        trainloaders.append(DataLoader(ds_train, batch_size=batch_size, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=batch_size))
    testloader = DataLoader(testset, batch_size=batch_size)

    return trainloaders, valloaders, testloader


In [9]:
# Build the per-client train/validation loaders and the shared test loader using the default arguments.
trainloaders, valloaders, testloader =load_datasets()

0.3%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting ./dataset/MNIST/raw/train-images-idx3-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100.0%
6.0%

Extracting ./dataset/MNIST/raw/train-labels-idx1-ubyte.gz to ./dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting ./dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to ./dataset/MNIST/raw


100.0%


Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz





Extracting ./dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./dataset/MNIST/raw



In [10]:
def train(net, trainloader, epochs: int = EPOCHS, loss_fn = LOSS_FN,verbose=False):
    """Train the network on the training set.

    A fresh Adam optimizer (default hyper-parameters) is created on every call.

    :param net: model to optimise in place; expected to already live on DEVICE
    :param trainloader: iterable yielding (images, labels) batches
    :param epochs: number of passes over ``trainloader``
    :param loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor
    :param verbose: when True, print the loss and accuracy after each epoch
    """
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            n_samples = labels.size(0)
            # .item() detaches the value, so the autograd graph of each batch
            # is not kept alive by the running total. The batch loss is
            # weighted by its size so the epoch value is a per-sample average
            # (assumes loss_fn uses mean reduction, as the default does).
            epoch_loss += loss.item() * n_samples
            total += n_samples
            correct += (outputs.argmax(dim=1) == labels).sum().item()
        # Guard against an empty loader instead of dividing by zero.
        epoch_loss = epoch_loss / total if total else 0.0
        epoch_acc = correct / total if total else 0.0
        if verbose:
            print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


In [11]:
def test(net, testloader, loss_fn=LOSS_FN):
    """Evaluate the network on the entire test set.

    :param net: model to evaluate; expected to already live on DEVICE
    :param testloader: iterable yielding (images, labels) batches
    :param loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor
    :return: ``(loss, accuracy)`` -- average loss per sample and fraction of
        correct predictions; ``(0.0, 0.0)`` when the loader yields no samples
    """
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            n_samples = labels.size(0)
            # Weight each batch loss by its size so that dividing by the
            # sample count yields a true per-sample average (assumes loss_fn
            # uses mean reduction, as the default does).
            loss += loss_fn(outputs, labels).item() * n_samples
            predicted = outputs.argmax(dim=1)
            total += n_samples
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Nothing to evaluate: avoid a ZeroDivisionError.
        return 0.0, 0.0
    loss /= total
    accuracy = correct / total
    return loss, accuracy

Federated Client

In [12]:
def get_parameters(net) -> List[np.ndarray]:
    """Return every entry of ``net.state_dict()`` as a NumPy array, in state-dict order."""
    arrays = []
    for tensor in net.state_dict().values():
        # Move to host memory first; .numpy() only works on CPU tensors.
        arrays.append(tensor.cpu().numpy())
    return arrays

def set_parameters(net, parameters: List[np.ndarray]):
    """Load a list of NumPy arrays into ``net``, matching them to its state dict by position.

    :param net: module whose state dict is overwritten
    :param parameters: one array per state-dict entry, in state-dict order
        (the format produced by ``get_parameters``)
    :raises ValueError: if the number of arrays differs from the number of
        state-dict entries
    """
    keys = list(net.state_dict().keys())
    if len(parameters) != len(keys):
        # zip() would silently drop surplus arrays, so fail explicitly.
        raise ValueError(
            f"Expected {len(keys)} parameter arrays, got {len(parameters)}")
    # torch.tensor copies the data and keeps each array's dtype and shape, so
    # integer and 0-dim buffers (e.g. BatchNorm's num_batches_tracked) survive
    # the round trip; torch.Tensor(v) would force everything to float32.
    state_dict = OrderedDict(
        (k, torch.tensor(np.asarray(v))) for k, v in zip(keys, parameters))
    net.load_state_dict(state_dict, strict=True)

class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client wrapping one model with its train and validation loaders.

    The second element returned by ``fit`` and ``evaluate`` is reported as the
    number of examples the client used. The size of the underlying dataset is
    returned rather than ``len(loader)``, which would be the number of batches.
    """
    def __init__(self, net, trainloader, valloader):
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current model weights as a list of NumPy arrays."""
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the given weights, train for one local epoch and return the updated weights."""
        set_parameters(self.net, parameters)
        # One local epoch per round; the module-level EPOCHS default of
        # train() is deliberately overridden here.
        train(self.net, self.trainloader, epochs=1)
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the given weights and return the loss and accuracy on the validation loader."""
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}

In [13]:
def client_fn(cid: str) -> FlowerClient:
    """Build the Flower client for the client id ``cid``.

    ``cid`` is converted to an integer and used to index the module-level
    ``trainloaders`` and ``valloaders`` lists; each client gets a fresh MLP
    placed on DEVICE.
    """
    model = MLP().to(DEVICE)  # <-- pass the Model
    index = int(cid)
    return FlowerClient(model, trainloaders[index], valloaders[index])

Metrics aggregators

In [14]:
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate per-client accuracies into one example-weighted average.

    :param metrics: list of ``(num_examples, client_metrics)`` pairs, where
        each ``client_metrics`` mapping contains an ``"accuracy"`` entry
    :return: ``{"accuracy": weighted mean}``; the accuracy is 0.0 when the
        clients report no examples at all
    """
    total_examples = 0
    weighted_sum = 0.0
    for num_examples, client_metrics in metrics:
        # Multiply accuracy of each client by number of examples used
        weighted_sum += num_examples * client_metrics["accuracy"]
        total_examples += num_examples

    if total_examples == 0:
        # No examples reported: avoid a ZeroDivisionError.
        return {"accuracy": 0.0}

    # Aggregate and return custom metric (weighted average)
    return {"accuracy": weighted_sum / total_examples}

Create FedAvg strategy

In [15]:
# Configure the FedAvg strategy from the constants defined in the
# simulation-variables cell. The same CLIENT_FRAC is used for both the fit and
# the evaluate fraction.
strategy = fl.server.strategy.FedAvg(
    fraction_fit= CLIENT_FRAC,
    fraction_evaluate= CLIENT_FRAC,
    min_fit_clients= MIN_FIT_CLIENTS,
    min_evaluate_clients=  MIN_EVALUABLE_CLIENTS,
    min_available_clients= MIN_AVAILABLE_CLIENTS,
    evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
)


Start simulation

In [16]:
# Run the federated simulation: NUM_CLIENTS virtual clients built by client_fn,
# NUM_ROUNDS server rounds, using the strategy and client resources defined above.
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
    client_resources=client_resources,
)


INFO flwr 2023-07-10 12:56:01,897 | app.py:146 | Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)
2023-07-10 12:56:04,437	ERROR services.py:1207 -- Failed to start the dashboard , return code 1
2023-07-10 12:56:04,439	ERROR services.py:1232 -- Error should be written to 'dashboard.log' or 'dashboard.err'. We are printing the last 20 lines for you. See 'https://docs.ray.io/en/master/ray-observability/ray-logging.html#logging-directory-structure' to find where the log file is.
2023-07-10 12:56:04,440	ERROR services.py:1276 -- 
The last 20 lines of /tmp/ray/session_2023-07-10_12-56-01_927618_3475/logs/dashboard.log (it contains the error message from the dashboard): 
  File "/home/eduardburlacu/miniconda3/envs/torch/lib/python3.9/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1030, in _gcd_import
  File "<frozen importlib._bootstrap>", line 100

AttributeError: module 'pydantic.fields' has no attribute 'ModelField'

# Replicate Experiments on FedAvg paper