In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Subset

from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torchvision.models import vgg11

from sklearn.model_selection import KFold

import numpy as np
import matplotlib.pyplot as plt

from lib.models import model
from lib.server import Server
from lib.client import Client
from lib.data_helper import *
from lib.train_helper import *
from lib.plots import plot_loss_epoch

# Device

In [2]:
# Pick the compute device: Apple Metal (MPS) first, then CUDA, else the CPU.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

# Manual Seeding

In [3]:
# Seed the random number generators before any data split or model creation so the
# run is reproducible. `seed_generators` comes from the lib star-imports; which RNGs it
# seeds (torch / numpy / random?) and with what seed is not visible here — TODO confirm.
seed_generators()

# Data

In [4]:
# CIFAR-10 training split (50,000 images), stored under ./data and downloaded on first use.
# ToTensor converts each image to a float tensor; no normalization is applied.
train_dataset = CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
train_dataset

Files already downloaded and verified


Dataset CIFAR10
    Number of datapoints: 50000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()

In [5]:
# CIFAR-10 test split (10,000 images), same location and transform as the training split.
test_dataset = CIFAR10(
    root='data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset

Files already downloaded and verified


Dataset CIFAR10
    Number of datapoints: 10000
    Root location: data
    Split: Test
    StandardTransform
Transform: ToTensor()

# Server & Clients

In [6]:
# Federated setup: number of simulated clients and number of CIFAR-10 classes.
num_clients, num_classes = 8, 10

## Data partition among clients

In [7]:
# Generate how each class's samples are shared among the clients (helper from the lib
# star-imports). Judging from the printed output it returns one length-`num_clients`
# tensor per class whose entries sum to ~1 — assumption from the values, confirm in lib.
# Highly skewed entries make the client data strongly non-IID.
proportions = generate_proportions(num_clients,num_classes)
proportions

[tensor([1.7277e-03, 1.4681e-03, 1.7119e-01, 6.7152e-08, 5.6816e-13, 3.8499e-02,
         1.0243e-05, 7.8711e-01]),
 tensor([1.0118e-01, 3.6152e-06, 1.2661e-02, 7.9172e-01, 5.1877e-21, 6.2118e-12,
         9.4433e-02, 3.0271e-15]),
 tensor([2.6149e-01, 2.5464e-01, 3.9212e-11, 8.3872e-11, 1.3594e-10, 3.5875e-01,
         1.1811e-01, 7.0036e-03]),
 tensor([9.6308e-04, 1.4171e-06, 6.7539e-05, 4.1170e-02, 8.3330e-04, 3.0664e-03,
         2.2048e-01, 7.3342e-01]),
 tensor([5.0231e-01, 2.2462e-11, 4.8773e-01, 1.6541e-03, 2.9940e-04, 9.7107e-06,
         6.0056e-03, 1.9903e-03]),
 tensor([6.4948e-06, 1.6067e-10, 9.1099e-09, 9.7064e-05, 1.0678e-03, 1.6058e-08,
         3.3956e-20, 9.9883e-01]),
 tensor([4.1385e-03, 8.2548e-01, 3.1079e-04, 2.0520e-11, 1.6359e-01, 5.3383e-07,
         3.7362e-05, 6.4420e-03]),
 tensor([8.4674e-01, 1.1771e-08, 8.1328e-03, 4.3720e-11, 7.8426e-02, 2.2480e-03,
         6.4372e-02, 7.9591e-05]),
 tensor([1.3971e-03, 7.3353e-18, 3.2855e-02, 1.9014e-01, 2.6827e-01, 5.9

# Cross Validation

In [8]:
def run_FedAvg(server,clients,device,train_dataset,valid_dataset,rounds,epochs,batch_size,optimizer,lr,**kwargs):
    """Run federated training for ``rounds`` rounds and track the server model's metrics.

    The server weights are reset once at the start. In every round, each client's
    ``train`` is handed the current server ``state_dict``. The server then merges the
    clients via ``server.merge(clients)``. Finally the merged server model is evaluated
    on both datasets.

    Args:
        server: object exposing ``net`` (an ``nn.Module``), ``reset_weights()`` and
            ``merge(clients)``.
        clients: iterable of objects exposing
            ``train(device, params, epochs, batch_size, optimizer, lr, **kwargs)``.
        device: device the evaluation loaders move batches to; also passed to clients.
        train_dataset: dataset the server is evaluated on as "training" metrics.
        valid_dataset: dataset the server is evaluated on as "validation" metrics.
        rounds: number of communication rounds (0 yields an empty history).
        epochs, batch_size, optimizer, lr, **kwargs: forwarded unchanged to
            ``client.train``. ``batch_size`` is also used for the evaluation loaders.

    Returns:
        list with one ``[train_loss, valid_loss, train_acc, valid_acc]`` entry per round.
    """
    # These loaders are only used for evaluation. Sample order does not change the
    # aggregated loss/accuracy, so there is no reason to shuffle them.
    train_loader = DeviceDataLoader(DataLoader(train_dataset, batch_size, shuffle=False), device)
    valid_loader = DeviceDataLoader(DataLoader(valid_dataset, batch_size, shuffle=False), device)
    history = []

    # Start every run (e.g. every CV fold) from freshly reset server weights.
    server.reset_weights()
    for i in range(rounds):
        print(f'>>> Round {i+1} ...')
        server_params = server.net.state_dict()
        for client in clients:
            client.train(device,server_params,epochs,batch_size, optimizer,lr,**kwargs)
        server.merge(clients)

        train_loss, train_acc = evaluate(server.net,train_loader)
        valid_loss, valid_acc = evaluate(server.net,valid_loader)

        print(f'''
Server :
    - training loss = {train_loss:.4f}
    - training accuracy = {train_acc:.4f}
    - validation loss = {valid_loss:.4f}
    - validation accuracy = {valid_acc:.4f}
        ''')

        history.append([train_loss, valid_loss, train_acc, valid_acc])
    return history

In [9]:
def run_CV(n_splits,device,rounds,epochs,batch_size,model,optimizer,lr,
           dataset=None,client_proportions=None,random_state=42,**kwargs):
    """K-fold cross-validation of federated training.

    For each fold, the procedure is:
      1. The training part is partitioned by class (``partition_by_class``).
      2. The partition is split among freshly built clients (``split``) using
         ``client_proportions``.
      3. A fresh server is built and ``run_FedAvg`` is run on the fold.
      4. The per-round history is plotted, and the last round's validation
         accuracy is collected.

    Args:
        n_splits: number of KFold folds.
        device, rounds, epochs, batch_size, optimizer, lr, **kwargs: forwarded to
            ``run_FedAvg``.
        model: zero-argument factory returning a fresh network. It is called once
            per client and once for the server in every fold.
        dataset: dataset to cross-validate on. Defaults to the notebook-level
            ``train_dataset`` (the original hard-wired behavior).
        client_proportions: per-class client proportions handed to ``split``.
            Defaults to the notebook-level ``proportions``.
        random_state: seed for the KFold shuffle. The default 42 reproduces the
            original folds.

    Returns:
        0-d tensor: mean over folds of the final-round validation accuracy.

    Raises:
        ValueError: if ``rounds < 1``, since there would be no final accuracy.
    """
    if rounds < 1:
        raise ValueError(f"rounds must be >= 1, got {rounds}")
    # Fall back to the notebook globals so existing calls keep working unchanged.
    if dataset is None:
        dataset = train_dataset
    if client_proportions is None:
        client_proportions = proportions

    folds = KFold(n_splits=n_splits,shuffle=True,random_state=random_state).split(np.arange(len(dataset)))

    fold_accs = []
    for fold, (train_idx,valid_idx) in enumerate(folds):
        print("="*25)
        print(f"Fold #{fold}")
        print("="*25)

        # create train and validation subsets
        train_subset = Subset(dataset, train_idx)
        valid_subset = Subset(dataset, valid_idx)

        # split data between users; clients and server are rebuilt so no state leaks across folds
        partition = partition_by_class(train_subset)
        clients = [Client(i,d,model().to(device)) for i,d in enumerate(split(partition,client_proportions),start=1)]
        server = Server(model().to(device))

        for c in clients:
            print(f"client {c.client_id} : {len(c.dataset)} samples")
        print()

        history = run_FedAvg(server,clients,device,train_subset,valid_subset,rounds,epochs,batch_size,optimizer,lr,**kwargs)

        plot_loss_epoch(history)

        # run_FedAvg rows are [train_loss, valid_loss, train_acc, valid_acc]: keep last-round valid_acc
        fold_accs.append(history[-1][-1])

    return torch.tensor(fold_accs).mean()

In [None]:
%%time
n_splits = 5
rounds = 40
batch_size = 128
epochs = 5
lrs = [1e-1,1e-2,1e-3,1e-4]
optimizer = optim.SGD

results = []
for lr in lrs:
    res = run_CV(n_splits,device,rounds,epochs,batch_size,model,optimizer,lr)
    results.append((lr,res))
    
print("="*25)
print(f"RESULTS")
print("="*25)
for lr,res in results:
    print(f"lr = {lr:.4f} : Accuracy = {res}")
print()


Fold #0
client 1 : 6914 samples
client 2 : 4322 samples
client 3 : 2887 samples
client 4 : 8071 samples
client 5 : 2047 samples
client 6 : 1624 samples
client 7 : 4006 samples
client 8 : 10129 samples

>>> Round 1 ...
client 1 : Loss = 0.9771, Accuracy = 0.6602
client 2 : Loss = 0.1448, Accuracy = 0.9441
client 3 : Loss = 0.2537, Accuracy = 0.9030
client 4 : Loss = 0.4797, Accuracy = 0.8256
client 5 : Loss = 0.1411, Accuracy = 0.9492
client 6 : Loss = 0.1178, Accuracy = 0.9557
client 7 : Loss = 0.4629, Accuracy = 0.8283
client 8 : Loss = 0.5101, Accuracy = 0.7762

Server :
    - training loss = 2.3044
    - training accuracy = 0.2431
    - validation loss = 2.3544
    - validation accuracy = 0.2372
        
>>> Round 2 ...
client 1 : Loss = 0.9586, Accuracy = 0.6911
client 2 : Loss = 0.1129, Accuracy = 0.9569
client 3 : Loss = 0.1390, Accuracy = 0.9562
client 4 : Loss = 0.4244, Accuracy = 0.8588
client 5 : Loss = 0.0875, Accuracy = 0.9751
client 6 : Loss = 0.0859, Accuracy = 0.9707
cli