In [1]:
# !pip install --upgrade tensorflow # install tensor flow to download mnist dataset.

Load the parameters, libraries and datasets needed for the federated training

In [2]:
import os
import sys
import requests
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import time
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tqdm.notebook import tqdm
# Federated imports
import forcast_federated_learning as ffl

# Parameters
num_clients        = 10
com_rounds         = 20
seed               = 0
batch_size         = 200
noise_multiplier   = 0.2
max_grad_norm      = 0.5
epochs             = 2
lr                 = 0.005
device             = 'cuda' # 'cpu'

# Metrics
df_metrics = pd.DataFrame(dict(zip(['round', 'accuracy', 'loss', 'epsilon', 'delta'], [int,[],[],[],[]])))

# Load local train data
import tensorflow as tf
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0

# Create custom pytorch datasers for train and testing
traindata = ffl.datasets.ImageDataset(X_train, y_train, categorical=True)
testdata  = ffl.datasets.ImageDataset(X_test, y_test, categorical=True)

Split the train data between the clients

In [3]:
# Split the train data and use only a fraction
traindata_split = ffl.data.random_split(traindata, num_clients=num_clients, seed=seed)

# Get data loader
train_loaders = [ffl.utils.DataLoader(traindata, batch_size=batch_size, shuffle=True, seed=seed)   for traindata in traindata_split]
test_loader  = ffl.utils.DataLoader(testdata, batch_size=len(testdata), shuffle=True, seed=seed)

In [4]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d( 1, 64, kernel_size=3)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.bn1   = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2, 2)
        # flatten
        self.bn2   = nn.BatchNorm1d(128 * 4 * 4)
        self.fc1   = nn.Linear(128 * 4 * 4, 512)
        self.fc2   = nn.Linear(512, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool1(x)
        x = self.bn1(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool2(x)
        x = x.view(-1, 128 * 4 * 4)
        x = self.bn2(x)
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        
        return x

Initialize every model for each client. The model topology and optimizer is shared between the clients model and the server or global model.

In [5]:
# Train params
optimizer_params   = {'lr': lr}
train_params       = {'epochs': epochs}

local_models       = []
for _ in range(num_clients):
    # Create federated model based on a pytorch model
    num_features, num_classes  = 4, 3
    model                      = CNN() # pytorch model
    loss_fn                    = nn.CrossEntropyLoss() # classification
    local_model                = ffl.LocalModel(model, model_type = 'nn', loss_fn=loss_fn, train_params=train_params)
    local_model.optimizer      = ffl.optim.Adam(local_model.parameters(), **optimizer_params)
    
    local_models.append(local_model)

Initialize the global model  and in case is needed the encryption parameters

In [6]:
model           = local_model.model # pytorch model
fed_model       = ffl.FederatedModel(model, model_type='nn')
public_context, secret_key = ffl.encryption.get_context()
encryption      = False

Each round of communication every client individually train their respective model using private data. Then aggregate those models onto the global model. That model is then shared to the clients and the cicle repeats until a model with sufficient accuracy is obtained or the a fixed number of communication rounds in reached.

In [7]:
for com_round in tqdm(range(com_rounds)):
    for local_model, train_loader in zip(local_models, train_loaders):
        local_model.step(train_loader, device=device)
    
    client_weights = []
    for local_model in local_models:
        state_dict      = local_model.state_dict()
        if encryption:
            # Each client encrypts the their model parameters (state_dict)
            # The library handles internally the encrypted data so the functions don't change much
            enc_state_dict  = ffl.encryption.EncStateDict(state_dict)
            enc_state_dict  = enc_state_dict.encrypt(public_context)
            client_weights.append(enc_state_dict)
        else:
            client_weights.append(state_dict)
    client_lens    = [len(traindata) for traindata in traindata_split]
    
    ## Server aggregate
    if encryption:
        fed_model.server_agregate(client_weights, client_lens, secret_key=secret_key)
    else:
        fed_model.server_agregate(client_weights, client_lens)
    weights = fed_model.state_dict()
    
    for local_model in local_models:
        local_model.load_state_dict(weights)
    
    acc, loss = local_model.test(test_loader) # local model now is the same as the global model
    print(f'Test accuracy: {acc:.2f}')
    df_aux       = pd.DataFrame({'round': [com_round+1], 'accuracy': [acc], 'loss': [loss], 'epsilon': [None], 'delta':[None] })
        
    # Save metrics
    df_metrics   = pd.concat([df_metrics, df_aux], axis=0)

  0%|          | 0/20 [00:00<?, ?it/s]

Test accuracy: 11.35
Test accuracy: 71.65
Test accuracy: 86.37
Test accuracy: 93.85
Test accuracy: 94.31
Test accuracy: 97.09
Test accuracy: 96.69
Test accuracy: 97.61
Test accuracy: 97.71
Test accuracy: 98.06
Test accuracy: 98.38
Test accuracy: 97.53
Test accuracy: 98.08
Test accuracy: 98.59
Test accuracy: 98.46
Test accuracy: 98.38
Test accuracy: 98.67
Test accuracy: 98.36
Test accuracy: 98.44
Test accuracy: 98.72


In [8]:
# # Save metrics onto csv file
# df_metrics.to_csv('./sim_mnist.csv', index=False)