This example contains the necessary bits of code to run the federated training with homomorphic encryption (he).

In [1]:
import os
import sys
import requests
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import jsonpickle as jpk
import time
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm.notebook import tqdm
# Federated imports
import forcast_federated_learning as ffl

# Parameters
num_clients        = 10
com_rounds         = 40
seed               = 0
batch_size         = 1
noise_multiplier   = 0.3
max_grad_norm      = 0.5

# Metrics
df_metrics = pd.DataFrame(dict(zip(['round', 'rmse', 'r2_score', 'epsilon', 'delta'], [int,[],[],[],[]])))

# Load local train data
X, y, df_data, target_names = ffl.datasets.load_scikit_iris()

# Split the database in train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=seed)  

# Create custom pytorch datasers for train and testing
traindata = ffl.datasets.StructuredDataset(X_train, y_train, categorical=True)
testdata  = ffl.datasets.StructuredDataset(X_test, y_test, categorical=True)

In [2]:
# Split the train data and use only a fraction
traindata_split = ffl.data.random_split(traindata, num_clients=num_clients, seed=seed)

# Get data loader
train_loaders = [ffl.utils.DataLoader(traindata, batch_size=batch_size, shuffle=True, seed=seed)   for traindata in traindata_split]
test_loader  = ffl.utils.DataLoader(testdata, batch_size=len(testdata), shuffle=True, seed=seed)

In [3]:
# Train params
delta              = 10**-np.ceil(np.log10(len(traindata))) # delta < 1/len(dataset)
security_params    = {'noise_multiplier': noise_multiplier, 'max_grad_norm': max_grad_norm, 'batch_size': batch_size, 'sample_size': len(traindata), 'target_delta': delta, 'secure_rng': True} 
optimizer_params   = {'lr': 0.01}
train_params       = {'epochs': 4}

local_models       = []
for _ in range(num_clients):
    # Create federated model based on a pytorch model
    num_features, num_classes  = 4, 3
    model                      = ffl.models.NN(input_dim=num_features, output_dim=num_classes) # pytorch model
    loss_fn                    = nn.CrossEntropyLoss() # classification
    local_model                = ffl.LocalModel(model, model_type = 'nn', loss_fn=loss_fn, train_params=train_params)
    local_model.optimizer      = ffl.optim.Adam(local_model.parameters(), **optimizer_params)
    local_model.privacy_engine = ffl.security.PrivacyEngine(local_model, **security_params)
    local_model.privacy_engine.attach(local_model.optimizer)
    
    local_models.append(local_model)

In [4]:
model           = local_model.model # pytorch model
fed_model       = ffl.FederatedModel(model, model_type='nn')
public_context, secret_key = ffl.encryption.get_context()

As a coment in practice, when deploying the public_context object need to be serialized to be shared with the clients with `context = public_context.serialize()`, and then each client needs to load it onto a python object with `context = ffl.encryption.load_context(context)`.

In [5]:
for com_round in tqdm(range(com_rounds)):
    for local_model, train_loader in zip(local_models, train_loaders):
        local_model.step(train_loader)
    
    client_weights = []
    for local_model in local_models:
        state_dict      = local_model.state_dict()
        enc_state_dict  = ffl.encryption.EncStateDict(state_dict)
        enc_state_dict  = enc_state_dict.encrypt(public_context)
        client_weights.append(enc_state_dict)
    client_lens    = [len(traindata) for traindata in traindata_split]
    
    ## Server aggregate
    fed_model.server_agregate(client_weights, client_lens, secret_key=secret_key)
    weights = fed_model.state_dict()
    
    for local_model in local_models:
        local_model.load_state_dict(weights)
    
    acc, _ = local_model.test(test_loader)
    if local_model.privacy_engine: # privacy spent
        epsilon, best_alpha = local_model.privacy_engine.get_privacy_spent(delta)
        print(f'Test accuracy: {acc:.2f} - Privacy spent: (ε = {epsilon:.2f}, δ = {delta:.2f})')
    else: 
        print(f'Test accuracy: {acc:.2f}')

HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))

Test accuracy: 32.00 - Privacy spent: (ε = 13.59, δ = 0.01)
Test accuracy: 32.00 - Privacy spent: (ε = 16.83, δ = 0.01)
Test accuracy: 32.00 - Privacy spent: (ε = 19.49, δ = 0.01)
Test accuracy: 32.00 - Privacy spent: (ε = 22.14, δ = 0.01)
Test accuracy: 42.00 - Privacy spent: (ε = 23.95, δ = 0.01)
Test accuracy: 64.00 - Privacy spent: (ε = 25.68, δ = 0.01)
Test accuracy: 62.00 - Privacy spent: (ε = 27.40, δ = 0.01)
Test accuracy: 62.00 - Privacy spent: (ε = 29.12, δ = 0.01)
Test accuracy: 62.00 - Privacy spent: (ε = 30.84, δ = 0.01)
Test accuracy: 62.00 - Privacy spent: (ε = 32.56, δ = 0.01)
Test accuracy: 62.00 - Privacy spent: (ε = 34.28, δ = 0.01)
Test accuracy: 62.00 - Privacy spent: (ε = 36.00, δ = 0.01)
Test accuracy: 62.00 - Privacy spent: (ε = 37.72, δ = 0.01)
Test accuracy: 62.00 - Privacy spent: (ε = 39.44, δ = 0.01)
Test accuracy: 62.00 - Privacy spent: (ε = 40.66, δ = 0.01)
Test accuracy: 62.00 - Privacy spent: (ε = 41.83, δ = 0.01)
Test accuracy: 62.00 - Privacy spent: (ε