In [1]:
%load_ext autoreload
%autoreload 2

import asyncio, copy, os, pickle, socket, sys, time
from functools import partial
from multiprocessing import Pool, Process
import networkx as nx
import numpy as np
from pathlib import Path
from tqdm import tqdm

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../")))
from libs import agg, data, fl, log, nn, plot, poison, resnet, sim, wandb
from libs.distributed import *
from cfgs.fedargs import *

In [2]:
# Experiment identity; reused for the TensorBoard run directory and for the
# plot / wandb trackers created below.
project = 'fl-kafka'
name = 'fedavg-cnn-mnist-na'

# Log level is one of: info | debug | warning | error | critical.
# Per the original note, logs can optionally be saved to a file by passing
# `name` as a second argument, e.g. log.init("info", name) or log.init("debug", name).
log.init("info")

# TensorBoard writer shared through fedargs; one run directory per project/name.
fedargs.tb = SummaryWriter('../out/runs/' + project + '/' + name, comment="fl")

# `plot` is rebound below from the libs.plot module to the object returned by
# init(). Going through a module alias keeps this cell safe to re-run: the
# original `plot = plot.init(...)` would, on a second run, call .init on the
# previously returned object instead of on the module.
from libs import plot as plot_lib
plot = plot_lib.init(name, project)
wb = wandb.init(name, project)

[34m[1mwandb[0m: Currently logged in as: [33mkasyah[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.5 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [3]:
# Kafka topic to publish and subscribe: the same topic name is later handed to
# dt.consume_model / dt.produce_model inside process().
fedargs.topic = 'pyflx'
# Number of clients simulated in this notebook; it determines how many
# hostname-based client identifiers are generated further down.
# NOTE(review): both assignments set attributes on the shared fedargs object,
# presumably overriding defaults from cfgs.fedargs — confirm.
fedargs.num_clients = 10

In [4]:
# CUDA is used only when the config asks for it and a device is available.
use_cuda = fedargs.cuda and torch.cuda.is_available()

# Seed torch's RNG from the config before anything else touches it.
torch.manual_seed(fedargs.seed)

# Pick the device and the matching extra DataLoader keyword arguments together.
if use_cuda:
    device = torch.device("cuda")
    kwargs = {"num_workers": 1, "pin_memory": True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [5]:
# Client identifiers are "<hostname>(<k>)" with k running from 1 to num_clients.
host = socket.gethostname()
clients = ["{}({})".format(host, number) for number in range(1, fedargs.num_clients + 1)]

In [6]:
# Distributed topology: `dt` is the object process() later uses to consume and
# produce model updates (dt.consume_model / dt.produce_model).
# NOTE(review): `Distributed` is not imported by name in this notebook; it
# presumably comes from the `from libs.distributed import *` star import — confirm.
dt = Distributed(clients)

In [7]:
# Initialize the global model as a deep copy of the configured model. Only the
# global model is created here; each client's own copy is made later from
# global_model when client_details is built.
global_model = copy.deepcopy(fedargs.model)
# Load the train and test datasets selected by fedargs.dataset. The train part
# is split across clients below; the test part backs the global test_loader.
train_data, test_data = data.load_dataset(fedargs.dataset)

In [8]:
# Split the training data among the client identifiers. The result is passed
# to data.load_client_data to build the per-client loaders; the partitioning
# strategy itself lives in data.split_data.
clients_data = data.split_data(train_data, clients)

In [9]:
# Per-client train/test loaders built from each client's share of the data.
# NOTE(review): 0.2 is presumably the held-out test fraction per client —
# confirm against data.load_client_data.
client_train_loaders, client_test_loaders = data.load_client_data(
    clients_data, fedargs.client_batch_size, 0.2, **kwargs
)

# Loader over the global test set.
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=fedargs.test_batch_size, shuffle=True, **kwargs
)

# Everything a client needs for a round, keyed by client identifier. Every
# client starts from its own deep copy of the global model.
client_details = {
    client: dict(
        train_loader=client_train_loaders[client],
        test_loader=client_test_loaders[client],
        model=copy.deepcopy(global_model),
        model_update=None,
    )
    for client in clients
}

In [10]:
def background(f):
    """Decorator that runs ``f`` in the event loop's default executor.

    Calling the decorated function does not execute ``f`` inline; it submits
    the call via ``loop.run_in_executor(None, ...)`` on the loop returned by
    ``asyncio.get_event_loop()`` and returns the resulting awaitable future.

    Args:
        f: The blocking callable to run in the background.

    Returns:
        A wrapper with the same name/docstring as ``f`` that accepts the same
        positional and keyword arguments and returns an asyncio future.
    """
    import functools  # local import keeps this definition self-contained

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        # run_in_executor() only forwards positional arguments, so keyword
        # arguments are bound up front with functools.partial.
        call = functools.partial(f, *args, **kwargs)
        return asyncio.get_event_loop().run_in_executor(None, call)

    return wrapped

@background
def process(client, epoch, dt, model, train_loader, test_loader, fedargs, device):
    """Run one federated round for a single client.

    The round consumes the peer updates published for ``epoch``, averages them
    into ``model`` when any arrived, trains locally, publishes the new update
    tagged ``epoch + 1``, and finally evaluates and logs the trained model.

    Because of ``@background``, calling this returns an awaitable; the value
    described under Returns is what that awaitable resolves to.

    Args:
        client: Identifier of the client; used as the key for its own update
            and as the tag for TensorBoard, plot and wandb records.
        epoch: The previous epoch, i.e. the one whose updates are consumed.
        dt: Object providing ``consume_model`` and ``produce_model``.
        model: The client's current model.
        train_loader: Loader passed to ``fedargs.train_func``.
        test_loader: Loader passed to ``fedargs.eval_func``.
        fedargs: Config object; this function reads ``topic``, ``train_func``,
            ``learning_rate``, ``weight_decay``, ``local_rounds``,
            ``eval_func``, ``tb`` and ``epochs``.
        device: Device handed to the train and eval functions.

    Returns:
        The model returned by ``fedargs.train_func``.
    """
    log.info("Epoch: {}, Processing Client {}".format(epoch, client))

    # Consume and Average, epoch passed is actually prev epoch, for which we want to consume updates
    client_model_updates = dt.consume_model(client, fedargs.topic, model, epoch)

    # A client must not average in its own update.
    client_model_updates.pop(client, None)

    log.info("Epoch: {}, Client {} received {} model update(s) from {}".format(
        epoch, client, len(client_model_updates), list(client_model_updates.keys())))

    if client_model_updates:
        model = fl.federated_avg(client_model_updates, model)

    # Train
    model_update, model, loss = fedargs.train_func(model, train_loader,
                                                   fedargs.learning_rate,
                                                   fedargs.weight_decay,
                                                   fedargs.local_rounds, device)

    # Publish under the next epoch; everything below reports against it too.
    next_epoch = epoch + 1
    dt.produce_model(client, fedargs.topic, model_update, next_epoch)

    log.jsondebug(loss, "Epoch {} : Federated Training loss, Client {}".format(next_epoch, client))
    log.modeldebug(model, "Epoch {}: Client {} Update".format(next_epoch, client))

    # Test, Plot and Log
    test_output = fedargs.eval_func(model, test_loader, device)
    accuracy = test_output["accuracy"]
    test_loss = test_output["test_loss"]
    # One timestamp so the plot and wandb records for this round agree.
    logged_at = time.time()

    fedargs.tb.add_scalar("Accuracy/" + client, accuracy, next_epoch)
    fedargs.tb.add_scalar("Test Loss/" + client, test_loss, next_epoch)
    plot.alog(client, {next_epoch: {"time": logged_at, "acc": accuracy, "loss": test_loss}})
    wb.log({client: {"epoch": next_epoch, "time": logged_at, "acc": accuracy, "loss": test_loss}})
    log.jsoninfo(test_output, "Test Output after Epoch {} of {} for Client {}".format(next_epoch, fedargs.epochs, client))

    return model

In [None]:
start_time = time.time()

# Federated Training
for epoch in tqdm(range(fedargs.epochs)):
    log.info("Federated Training Epoch {} of {}".format(epoch, fedargs.epochs))

    # Clients: process() is wrapped by @background, so each call returns an
    # awaitable rather than running the round inline.
    tasks = [process(client, epoch, dt, client_details[client]['model'],
                     client_details[client]['train_loader'],
                     client_details[client]['test_loader'],
                     fedargs, device) for client in clients]

    # Keep a handle on the aggregate future: `tasks` is a plain list and has
    # no .cancel(), so cancellation has to go through the gather() result.
    gathered = asyncio.gather(*tasks)
    try:
        updates = fedargs.loop.run_until_complete(gathered)
    except KeyboardInterrupt:
        log.error("Caught keyboard interrupt. Canceling tasks...")
        # Cancels work that has not started; calls already running in
        # executor threads cannot be interrupted and will run to completion.
        gathered.cancel()
        # Stop training here: `updates` for this epoch is incomplete (or left
        # over from the previous epoch), so it must not be applied.
        break

    # Results come back in the same order as `clients`.
    for client, update in zip(clients, updates):
        client_details[client]['model'] = update

print(time.time() - start_time)

  0%|          | 0/51 [00:00<?, ?it/s]2021-10-26 16:09:29,199 - <ipython-input-11-d06f8893748b>::<module>(l:6) : Federated Training Epoch 0 of 51 [MainProcess : MainThread (INFO)]
2021-10-26 16:09:29,272 - <ipython-input-10-988a6d4ae51a>::process(l:9) : Epoch: 0, Processing Client bladecluster.iitp.org(4) [MainProcess : asyncio_3 (INFO)]
2021-10-26 16:09:29,279 - <ipython-input-10-988a6d4ae51a>::process(l:9) : Epoch: 0, Processing Client bladecluster.iitp.org(5) [MainProcess : asyncio_4 (INFO)]
2021-10-26 16:09:29,292 - <ipython-input-10-988a6d4ae51a>::process(l:9) : Epoch: 0, Processing Client bladecluster.iitp.org(1) [MainProcess : asyncio_0 (INFO)]
2021-10-26 16:09:29,300 - <ipython-input-10-988a6d4ae51a>::process(l:9) : Epoch: 0, Processing Client bladecluster.iitp.org(7) [MainProcess : asyncio_6 (INFO)]
2021-10-26 16:09:29,301 - <ipython-input-10-988a6d4ae51a>::process(l:9) : Epoch: 0, Processing Client bladecluster.iitp.org(3) [MainProcess : asyncio_2 (INFO)]
2021-10-26 16:09:29,

2021-10-26 16:11:09,579 - <ipython-input-10-988a6d4ae51a>::process(l:44) : Test Outut after Epoch 1 of 51 for Client bladecluster.iitp.org(7) {
    "accuracy": 86.91666666666666,
    "correct": 1043,
    "test_loss": 0.46456432978312173
} [MainProcess : asyncio_6 (INFO)]
2021-10-26 16:11:10,492 - /home/harsh_1921cs01/hub/AgroFed/fl/libs/protobuf_producer.py::produce(l:56) : Producing user records to topic pyflx. ^C to exit. [MainProcess : asyncio_0 (INFO)]
2021-10-26 16:11:10,902 - <ipython-input-10-988a6d4ae51a>::process(l:44) : Test Outut after Epoch 1 of 51 for Client bladecluster.iitp.org(8) {
    "accuracy": 87.75,
    "correct": 1053,
    "test_loss": 0.47700483838717145
} [MainProcess : asyncio_7 (INFO)]
2021-10-26 16:11:11,012 - /home/harsh_1921cs01/hub/AgroFed/fl/libs/protobuf_producer.py::produce(l:66) : Flushing records... [MainProcess : asyncio_0 (INFO)]
2021-10-26 16:11:11,217 - <ipython-input-10-988a6d4ae51a>::process(l:44) : Test Outut after Epoch 1 of 51 for Client blad

2021-10-26 16:11:36,116 - <ipython-input-10-988a6d4ae51a>::process(l:18) : Epoch: 1, Client bladecluster.iitp.org(2) received 9 model update(s) from ['bladecluster.iitp.org(7)', 'bladecluster.iitp.org(8)', 'bladecluster.iitp.org(5)', 'bladecluster.iitp.org(9)', 'bladecluster.iitp.org(6)', 'bladecluster.iitp.org(3)', 'bladecluster.iitp.org(10)', 'bladecluster.iitp.org(1)', 'bladecluster.iitp.org(4)'] [MainProcess : asyncio_7 (INFO)]
2021-10-26 16:11:36,323 - <ipython-input-10-988a6d4ae51a>::process(l:18) : Epoch: 1, Client bladecluster.iitp.org(4) received 9 model update(s) from ['bladecluster.iitp.org(7)', 'bladecluster.iitp.org(8)', 'bladecluster.iitp.org(5)', 'bladecluster.iitp.org(2)', 'bladecluster.iitp.org(9)', 'bladecluster.iitp.org(6)', 'bladecluster.iitp.org(3)', 'bladecluster.iitp.org(10)', 'bladecluster.iitp.org(1)'] [MainProcess : asyncio_1 (INFO)]
2021-10-26 16:11:36,545 - <ipython-input-10-988a6d4ae51a>::process(l:18) : Epoch: 1, Client bladecluster.iitp.org(9) received 9 

<h1> End </h1>