In [1]:
%load_ext autoreload
%autoreload 2

import asyncio, copy, os, socket, sys, time
from functools import partial
from multiprocessing import Pool, Process
from pathlib import Path
from tqdm import tqdm

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../")))
from libs import agg, data, fl, log, nn, plot, poison, resnet, sim, wandb
from libs.distributed import *
from cfgs.fedargs import *

In [2]:
# Run identity: used for log file names, the TensorBoard run directory, and the plot/wandb trackers.
project = 'fl-kafka'
fedargs.num_clients = 5
name = 'fedavg-cnn-mnist-na'

# Save Logs To File (info | debug | warning | error | critical) [optional]
log.init("info")
#log.init("info", name)
#log.init("debug", name)

fedargs.tb = SummaryWriter('../out/runs/' + project + '/' + name, comment="fl")

# `from libs import plot` bound the *module* to `plot`; assigning `plot = plot.init(...)`
# replaced the module with the tracker, so re-running this cell failed with AttributeError.
# Re-import the module under a private alias so the cell is idempotent, while still exposing
# the tracker as `plot` (that is the name `process` uses for `plot.alog`).
from libs import plot as _plot_module
plot = _plot_module.init(name, project)
wb = wandb.init(name, project)

[34m[1mwandb[0m: Currently logged in as: [33mkasyah[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.6 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [3]:
# Use CUDA only when the config requests it *and* a GPU is visible to torch.
use_cuda = fedargs.cuda and torch.cuda.is_available()
torch.manual_seed(fedargs.seed)
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Extra DataLoader options (one worker, pinned memory), applied only when running on CUDA.
kwargs = dict(num_workers=1, pin_memory=True) if use_cuda else dict()

In [4]:
host = socket.gethostname()
# Client ids look like "<hostname>(1)", ..., "<hostname>(N)" (1-based).
clients = [f"{host}({idx})" for idx in range(1, fedargs.num_clients + 1)]

In [5]:
# Distributed topology: a single object shared by every client's `process` call.
dt = Distributed(
    clients,
    fedargs.broker_ip,        # broker IP from the run config
    fedargs.schema_ip,        # schema service IP from the run config
    fedargs.wait_to_consume,  # presumably how long to wait when consuming — TODO confirm in libs.distributed
)

In [6]:
# Global model: an independent copy of the configured architecture.
# (Per-client models are deep-copied from it later, when client_details is built.)
global_model = copy.deepcopy(fedargs.model)

# Raw train / test sets for the configured dataset.
train_data, test_data = data.load_dataset(fedargs.dataset)

In [7]:
# Split the training set among the clients. Presumably this returns a per-client mapping
# keyed by client id, which data.load_client_data consumes — TODO confirm in libs.data.
clients_data = data.split_data(train_data, clients)

In [8]:
# Per-client loaders. The positional 0.2 is presumably the fraction of each client's shard
# held out as its local test split — TODO confirm in data.load_client_data.
client_train_loaders, client_test_loaders = data.load_client_data(
    clients_data, fedargs.client_batch_size, 0.2, **kwargs
)
# Global test set. NOTE(review): the training loop in this notebook evaluates each client on its
# own client test loader, not on this one.
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=fedargs.test_batch_size, shuffle=True, **kwargs
)

# Per-client state: local loaders, a private model copy of the global model, and an update slot.
client_details = {
    client: dict(
        train_loader=client_train_loaders[client],
        test_loader=client_test_loaders[client],
        model=copy.deepcopy(global_model),
        model_update=None,
    )
    for client in clients
}

In [9]:
def background(f):
    """Decorator: run a blocking function in the event loop's default thread-pool executor.

    Calling the decorated function schedules ``f(*args, **kwargs)`` on the default executor
    of ``asyncio.get_event_loop()`` and immediately returns an ``asyncio.Future``. Wait on it
    with ``loop.run_until_complete`` / ``asyncio.gather`` to get ``f``'s return value.

    ``loop.run_in_executor`` forwards only positional arguments, so the call is bound with
    ``functools.partial`` first. The previous version passed ``**kwargs`` straight to
    ``run_in_executor`` and raised ``TypeError`` whenever a keyword argument was supplied.

    NOTE(review): the returned Future belongs to ``asyncio.get_event_loop()``. The caller must
    drive the same loop (presumably ``fedargs.loop`` is that loop — verify).
    """
    from functools import partial, wraps

    @wraps(f)
    def wrapped(*args, **kwargs):
        return asyncio.get_event_loop().run_in_executor(None, partial(f, *args, **kwargs))

    return wrapped

@background
def process(client, epoch, dt, model, train_loader, test_loader, fedargs, device):
    """Run one federated round for a single client (executed in a worker thread via @background).

    Steps:
    1. Consume the peers' updates published for ``epoch``.
    2. Average them into the local model.
    3. Train locally.
    4. Publish this client's update tagged ``epoch + 1``.
    5. Evaluate on ``test_loader`` and log to TensorBoard, ``plot`` and ``wb``.

    Args:
        client: client id (the key used in ``client_details``).
        epoch: index of the *previous* round, whose updates are consumed; the update produced
            here is published under ``epoch + 1``.
        dt: the ``Distributed`` topology used to consume/produce model updates.
        model: this client's current model.
        train_loader: the client's local training loader.
        test_loader: the client's local test loader used for evaluation.
        fedargs: run configuration (topic, train/eval functions, hyper-parameters, tb writer).
        device: torch device for training and evaluation.

    Returns:
        The averaged-and-trained model, delivered through the Future created by ``@background``.
    """
    log.info("Epoch: {}, Processing Client {}".format(epoch, client))

    # Consume and average. `epoch` is the previous round here: we want the updates published for it.
    client_model_updates = dt.consume_model(client, fedargs.topic, model, epoch)

    # Drop this client's own update; only peers' updates are averaged in.
    client_model_updates.pop(client, None)
    # NOTE(review): the saved output shows updates from senders such as "...(20)" and "...(50)",
    # although this run has only fedargs.num_clients = 5 clients. Stale messages left on the shared
    # topic appear to be averaged in. Consider filtering to the current run's client ids — TODO confirm.

    log.info("Epoch: {}, Client {} received {} model update(s) from {}".format(epoch, client,
                                                                               len(client_model_updates),
                                                                               list(client_model_updates.keys())))

    if client_model_updates:
        model = fl.federated_avg(client_model_updates, model)

    # Train
    model_update, model, loss = fedargs.train_func(model, train_loader,
                                                   fedargs.learning_rate,
                                                   fedargs.weight_decay,
                                                   fedargs.local_rounds, device)

    # Publish under the next round's index.
    epoch = epoch + 1
    dt.produce_model(client, fedargs.topic, model_update, epoch)

    log.jsondebug(loss, "Epoch {} : Federated Training loss, Client {}".format(epoch, client))
    log.modeldebug(model, "Epoch {}: Client {} Update".format(epoch, client))

    # Test, Plot and Log
    test_output = fedargs.eval_func(model, test_loader, device)
    accuracy, test_loss = test_output["accuracy"], test_output["test_loss"]
    fedargs.tb.add_scalar("Accuracy/" + client, accuracy, epoch)
    fedargs.tb.add_scalar("Test Loss/" + client, test_loss, epoch)
    # One timestamp for both sinks, so the plot and wandb records of this round line up exactly.
    now = time.time()
    plot.alog(client, {epoch: {"time": now, "acc": accuracy, "loss": test_loss}})
    wb.log({client: {"epoch": epoch, "time": now, "acc": accuracy, "loss": test_loss}})
    log.jsoninfo(test_output, "Test Output after Epoch {} of {} for Client {}".format(epoch, fedargs.epochs, client))

    return model

In [None]:
# Federated training. In each epoch, every client's `process` runs concurrently in the default
# thread-pool executor, and the loop waits for all of them before starting the next epoch.
# (`time` is already imported in the first cell.)
start_time = time.time()

for epoch in tqdm(range(fedargs.epochs)):
    log.info("Federated Training Epoch {} of {}".format(epoch, fedargs.epochs))

    # Clients: one Future per client, all scheduled before any of them is awaited.
    tasks = [process(client, epoch, dt, client_details[client]['model'],
                     client_details[client]['train_loader'],
                     client_details[client]['test_loader'],
                     fedargs, device) for client in clients]
    round_future = asyncio.gather(*tasks)
    try:
        updates = fedargs.loop.run_until_complete(round_future)
    except KeyboardInterrupt:
        # The old handler called `.cancel()` / `.exception()` on the *list* `tasks`
        # (AttributeError) and then `run_forever()`, which never returns. If it had returned,
        # `updates` below would be undefined or left over from the previous epoch.
        # Cancelling the gathered future cancels every client Future; then stop the run.
        # Executor threads already inside `process` cannot be interrupted and run to completion
        # in the background (including publishing their updates).
        log.error("Caught keyboard interrupt. Canceling tasks...")
        round_future.cancel()
        raise

    for client, update in zip(clients, updates):
        client_details[client]['model'] = update

print(time.time() - start_time)

  0%|          | 0/51 [00:00<?, ?it/s]2021-11-06 13:25:44,101 - <ipython-input-10-d06f8893748b>::<module>(l:6) : Federated Training Epoch 0 of 51 [MainProcess : MainThread (INFO)]
2021-11-06 13:25:44,197 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 0, Processing Client bladecluster.iitp.org(4) [MainProcess : asyncio_3 (INFO)]
2021-11-06 13:25:44,198 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 0, Processing Client bladecluster.iitp.org(5) [MainProcess : asyncio_4 (INFO)]
2021-11-06 13:25:44,198 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 0, Processing Client bladecluster.iitp.org(3) [MainProcess : asyncio_2 (INFO)]
2021-11-06 13:25:44,202 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 0, Processing Client bladecluster.iitp.org(2) [MainProcess : asyncio_1 (INFO)]
2021-11-06 13:25:44,205 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 0, Processing Client bladecluster.iitp.org(1) [MainProcess : asyncio_0 (INFO)]
2021-11-06 13:25:54,869 -

2021-11-06 13:26:27,937 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 1, Processing Client bladecluster.iitp.org(1) [MainProcess : asyncio_0 (INFO)]
2021-11-06 13:26:27,954 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 1, Processing Client bladecluster.iitp.org(4) [MainProcess : asyncio_2 (INFO)]
2021-11-06 13:26:27,969 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 1, Processing Client bladecluster.iitp.org(5) [MainProcess : asyncio_1 (INFO)]
2021-11-06 13:26:27,970 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 1, Processing Client bladecluster.iitp.org(3) [MainProcess : asyncio_4 (INFO)]
2021-11-06 13:26:27,973 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 1, Processing Client bladecluster.iitp.org(2) [MainProcess : asyncio_3 (INFO)]
2021-11-06 13:26:38,277 - <ipython-input-9-988a6d4ae51a>::process(l:18) : Epoch: 1, Client bladecluster.iitp.org(5) received 13 model update(s) from ['bladecluster.iitp.org(20)', 'bladecluster.iitp.org(8)'

2021-11-06 13:27:24,848 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 2, Processing Client bladecluster.iitp.org(3) [MainProcess : asyncio_4 (INFO)]
2021-11-06 13:27:24,857 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 2, Processing Client bladecluster.iitp.org(4) [MainProcess : asyncio_1 (INFO)]
2021-11-06 13:27:24,873 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 2, Processing Client bladecluster.iitp.org(5) [MainProcess : asyncio_3 (INFO)]
2021-11-06 13:27:24,874 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 2, Processing Client bladecluster.iitp.org(2) [MainProcess : asyncio_2 (INFO)]
2021-11-06 13:27:34,989 - <ipython-input-9-988a6d4ae51a>::process(l:18) : Epoch: 2, Client bladecluster.iitp.org(2) received 13 model update(s) from ['bladecluster.iitp.org(35)', 'bladecluster.iitp.org(49)', 'bladecluster.iitp.org(36)', 'bladecluster.iitp.org(50)', 'bladecluster.iitp.org(44)', 'bladecluster.iitp.org(8)', 'bladecluster.iitp.org(19)', 'bladecluste

2021-11-06 13:28:20,895 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 3, Processing Client bladecluster.iitp.org(5) [MainProcess : asyncio_4 (INFO)]
2021-11-06 13:28:20,897 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 3, Processing Client bladecluster.iitp.org(2) [MainProcess : asyncio_1 (INFO)]
2021-11-06 13:28:31,580 - <ipython-input-9-988a6d4ae51a>::process(l:18) : Epoch: 3, Client bladecluster.iitp.org(2) received 13 model update(s) from ['bladecluster.iitp.org(15)', 'bladecluster.iitp.org(13)', 'bladecluster.iitp.org(6)', 'bladecluster.iitp.org(7)', 'bladecluster.iitp.org(1)', 'bladecluster.iitp.org(22)', 'bladecluster.iitp.org(28)', 'bladecluster.iitp.org(14)', 'bladecluster.iitp.org(21)', 'bladecluster.iitp.org(20)', 'bladecluster.iitp.org(30)', 'bladecluster.iitp.org(10)', 'bladecluster.iitp.org(3)'] [MainProcess : asyncio_1 (INFO)]
2021-11-06 13:28:31,594 - <ipython-input-9-988a6d4ae51a>::process(l:18) : Epoch: 3, Client bladecluster.iitp.org(1) received 1

2021-11-06 13:29:17,005 - <ipython-input-9-988a6d4ae51a>::process(l:9) : Epoch: 4, Processing Client bladecluster.iitp.org(5) [MainProcess : asyncio_4 (INFO)]
2021-11-06 13:29:27,431 - <ipython-input-9-988a6d4ae51a>::process(l:18) : Epoch: 4, Client bladecluster.iitp.org(2) received 12 model update(s) from ['bladecluster.iitp.org(12)', 'bladecluster.iitp.org(11)', 'bladecluster.iitp.org(17)', 'bladecluster.iitp.org(32)', 'bladecluster.iitp.org(23)', 'bladecluster.iitp.org(24)', 'bladecluster.iitp.org(26)', 'bladecluster.iitp.org(25)', 'bladecluster.iitp.org(27)', 'bladecluster.iitp.org(31)', 'bladecluster.iitp.org(34)', 'bladecluster.iitp.org(37)'] [MainProcess : asyncio_2 (INFO)]
2021-11-06 13:29:27,437 - <ipython-input-9-988a6d4ae51a>::process(l:18) : Epoch: 4, Client bladecluster.iitp.org(5) received 16 model update(s) from ['bladecluster.iitp.org(10)', 'bladecluster.iitp.org(3)', 'bladecluster.iitp.org(12)', 'bladecluster.iitp.org(11)', 'bladecluster.iitp.org(17)', 'bladecluster.ii

2021-11-06 13:30:20,658 - <ipython-input-9-988a6d4ae51a>::process(l:18) : Epoch: 5, Client bladecluster.iitp.org(3) received 14 model update(s) from ['bladecluster.iitp.org(25)', 'bladecluster.iitp.org(27)', 'bladecluster.iitp.org(31)', 'bladecluster.iitp.org(34)', 'bladecluster.iitp.org(37)', 'bladecluster.iitp.org(35)', 'bladecluster.iitp.org(40)', 'bladecluster.iitp.org(42)', 'bladecluster.iitp.org(38)', 'bladecluster.iitp.org(46)', 'bladecluster.iitp.org(36)', 'bladecluster.iitp.org(33)', 'bladecluster.iitp.org(41)', 'bladecluster.iitp.org(44)'] [MainProcess : asyncio_2 (INFO)]
2021-11-06 13:30:20,659 - <ipython-input-9-988a6d4ae51a>::process(l:18) : Epoch: 5, Client bladecluster.iitp.org(2) received 14 model update(s) from ['bladecluster.iitp.org(35)', 'bladecluster.iitp.org(40)', 'bladecluster.iitp.org(42)', 'bladecluster.iitp.org(38)', 'bladecluster.iitp.org(46)', 'bladecluster.iitp.org(36)', 'bladecluster.iitp.org(33)', 'bladecluster.iitp.org(41)', 'bladecluster.iitp.org(44)', 

2021-11-06 13:31:11,992 - <ipython-input-9-988a6d4ae51a>::process(l:18) : Epoch: 6, Client bladecluster.iitp.org(3) received 14 model update(s) from ['bladecluster.iitp.org(39)', 'bladecluster.iitp.org(45)', 'bladecluster.iitp.org(47)', 'bladecluster.iitp.org(48)', 'bladecluster.iitp.org(43)', 'bladecluster.iitp.org(49)', 'bladecluster.iitp.org(50)', 'bladecluster.iitp.org(23)', 'bladecluster.iitp.org(22)', 'bladecluster.iitp.org(9)', 'bladecluster.iitp.org(10)', 'bladecluster.iitp.org(11)', 'bladecluster.iitp.org(15)', 'bladecluster.iitp.org(5)'] [MainProcess : asyncio_2 (INFO)]
2021-11-06 13:31:12,009 - <ipython-input-9-988a6d4ae51a>::process(l:18) : Epoch: 6, Client bladecluster.iitp.org(2) received 14 model update(s) from ['bladecluster.iitp.org(49)', 'bladecluster.iitp.org(50)', 'bladecluster.iitp.org(3)', 'bladecluster.iitp.org(23)', 'bladecluster.iitp.org(22)', 'bladecluster.iitp.org(9)', 'bladecluster.iitp.org(10)', 'bladecluster.iitp.org(11)', 'bladecluster.iitp.org(15)', 'bla

<h1> End </h1>