In [19]:
%load_ext autoreload
%autoreload 2

import copy
import functools
import os, sys
import pickle
import socket
from pathlib import Path

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../")))
from libs import fl, nn, data, log

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
# Logging setup. The first argument is the level: info | debug | warning | error | critical.
# Judging by the alternatives below, an optional second argument saves logs to a file
# (TODO confirm against libs.log), e.g.:
#   log.init("info", "federated.log")
#   log.init("debug", "flkafka.log")
log.init("info")

In [21]:
class FedArgs():
    """Configuration and shared TensorBoard writer for the federated MNIST run.

    Every setting is a keyword argument. Each default equals the value this
    notebook was run with, so ``FedArgs()`` is unchanged. ``FedArgs(epochs=10)``
    and similar calls override settings without editing the class.
    """

    def __init__(self, num_clients=5, epochs=5, local_rounds=1,
                 client_batch_size=32, test_batch_size=128,
                 learning_rate=1e-4, weight_decay=1e-5, cuda=False, seed=1,
                 topic='pyflx', log_dir='../out/runs/federated'):
        self.num_clients = num_clients              # number of simulated clients
        self.epochs = epochs                        # number of federated rounds
        self.local_rounds = local_rounds            # forwarded to fl.client_update
        self.client_batch_size = client_batch_size  # per-client training batch size
        self.test_batch_size = test_batch_size      # batch size of the shared test loader
        self.learning_rate = learning_rate          # forwarded to fl.client_update
        self.weight_decay = weight_decay            # forwarded to fl.client_update
        self.cuda = cuda                            # CUDA is used only if also available
        self.seed = seed                            # passed to torch.manual_seed
        self.topic = topic                          # unused in the cells shown; presumably a Kafka topic
        # SummaryWriter ignores `comment` whenever `log_dir` is given, so a comment
        # argument here would be a no-op; the run is identified by log_dir alone.
        self.tb = SummaryWriter(log_dir)


fedargs = FedArgs()

In [22]:
# Seed torch's global RNG before any model or data loader is created below.
torch.manual_seed(fedargs.seed)

# CUDA is used only when requested in fedargs AND actually available.
use_cuda = fedargs.cuda and torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    # Extra DataLoader options, forwarded to data.load_client_data.
    kwargs = {"num_workers": 1, "pin_memory": True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [23]:
# One name per simulated client: "<hostname>(1)", "<hostname>(2)", ...
host = socket.gethostname()
clients = [f"{host}({idx})" for idx in range(1, fedargs.num_clients + 1)]

In [24]:
# Load MNIST, split the training set among `clients`, and build one loader per
# client plus the shared test loader.
train_data, test_data = data.load_mnist_dataset()
clients_data = data.split_data(train_data, clients)
client_loaders, test_loader = data.load_client_data(
    clients_data,
    fedargs.client_batch_size,
    test_data,
    fedargs.test_batch_size,
    **kwargs,
)

# Initial models: every client starts from an identical copy of the global model.
# For independent random initialisation per client, use nn.ModelMNIST() instead
# of copy.deepcopy(global_model) below.
global_model = nn.ModelMNIST().to(device)
clients_info = {}
for client in clients:
    clients_info[client] = {
        "model": copy.deepcopy(global_model).to(device),
        "loss": {},
        "data_loader": client_loaders[client],
    }

In [25]:
import asyncio
import time

def background(f):
    """Decorator: run the blocking function ``f`` in the event loop's default executor.

    Calling the decorated function does not run ``f`` in the caller. It schedules
    ``f(*args, **kwargs)`` on the loop's default executor and returns an
    ``asyncio.Future``, which can be awaited, e.g. via ``asyncio.gather``.

    ``loop.run_in_executor`` forwards only positional arguments. Keyword arguments
    are therefore bound with ``functools.partial`` first; passing them to
    ``run_in_executor`` directly raises ``TypeError``.
    """
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        loop = asyncio.get_event_loop()
        return loop.run_in_executor(None, functools.partial(f, *args, **kwargs))

    return wrapped

@background
def process(client, client_models, epoch):
    """Train one client for one federated round and publish its updated model.

    Because of ``@background``, calling this returns an awaitable; the work runs
    in the event loop's default executor.

    Args:
        client: key into ``clients_info`` identifying the client.
        client_models: dict shared across clients; the trained model is stored
            under ``client``.
        epoch: 1-based federated round, used for TensorBoard tags and log messages.
    """
    info = clients_info[client]

    # Local training; fl.client_update returns the updated model and a dict of losses
    # (presumably one entry per local round -- TODO confirm in libs.fl).
    info['model'], info['loss'] = fl.client_update(info['model'],
                                                   info['data_loader'],
                                                   fedargs.learning_rate,
                                                   fedargs.weight_decay,
                                                   fedargs.local_rounds,
                                                   device)

    client_models[client] = info['model']

    # One curve per federated round under "Training Loss/<client>", stepped by the
    # 1-based position of each loss; add_scalars expects an integer global_step.
    for local_epoch, loss in enumerate(info['loss'].values(), start=1):
        fedargs.tb.add_scalars("Training Loss/" + client,
                               {str(epoch): loss},
                               local_epoch)

    log.jsondebug(info['loss'],
                  "Epoch {} of {} : Federated Training loss, Client {}".format(epoch,
                                                                             fedargs.epochs,
                                                                             client))
    log.modeldebug(info['model'],
                   "Epoch {} of {} : Client {} Update".format(epoch,
                                                              fedargs.epochs,
                                                              client))

In [None]:
import time
start_time = time.time()

# Federated Training
for _epoch in tqdm(range(fedargs.epochs)):
    epoch = _epoch + 1
    log.info("Federated Training Epoch {} of {}".format(epoch, fedargs.epochs))

    # Train
    client_models = {}
    tasks = [process(client, client_models, epoch) for client in clients]
    await asyncio.wait(tasks)

    # Average the client updates
    global_model = fl.federated_avg(client_models)
    log.modeldebug(global_model, "Epoch {} of {} : Global Model".format(epoch, fedargs.epochs))
    
    # Test Epoch
    test_output = fl.eval(global_model, test_loader, device)
    log.jsoninfo(test_output, "Test Outut after Epoch {} of {}".format(epoch, fedargs.epochs))
    
print(time.time() - start_time)

  0%|          | 0/5 [00:00<?, ?it/s]2021-07-27 23:57:32,335 - <ipython-input-26-9b49d7e383a8>::<module>(l:7) : Federated Training Epoch 1 of 5
2021-07-27 23:58:02,645 - <ipython-input-26-9b49d7e383a8>::<module>(l:20) : Test Outut after Epoch 1 of 5 {
    "accuracy": 92.44,
    "attack": {
        "attack_success_count": 0,
        "attack_success_rate": 0,
        "instances": 0,
        "misclassification_rate": 0,
        "misclassifications": 0
    },
    "correct": 9244,
    "test_loss": 0.3118038007259369
}
 20%|██        | 1/5 [00:30<02:01, 30.32s/it]2021-07-27 23:58:02,656 - <ipython-input-26-9b49d7e383a8>::<module>(l:7) : Federated Training Epoch 2 of 5
2021-07-27 23:58:34,289 - <ipython-input-26-9b49d7e383a8>::<module>(l:20) : Test Outut after Epoch 2 of 5 {
    "accuracy": 94.66,
    "attack": {
        "attack_success_count": 0,
        "attack_success_rate": 0,
        "instances": 0,
        "misclassification_rate": 0,
        "misclassifications": 0
    },
    "correct"