In [1]:
%load_ext autoreload
%autoreload 2

import copy, os, socket, sys, time
from pathlib import Path
from tqdm import tqdm

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../")))
from libs import agg, data, fl, log, nn, plot, poison, resnet, sim, wandb
from libs.distributed import *
from cfgs.fedargs import *

In [2]:
project = 'fl-kafka-client'
fedargs.num_clients = 1
name = 'fedavg-cnn-mnist-na-' + fedargs.name

# Save logs to file (info | debug | warning | error | critical) [optional].
# To also write to a per-run file, pass the run name, e.g. log.init("info", name)
# or log.init("debug", name).
log.init("info")

# TensorBoard writer for this run, stored under ../out/runs/<project>/<name>.
fedargs.tb = SummaryWriter(os.path.join('../out/runs', project, name), comment="fl")

# `process` logs through the notebook-global name `plot`, so that name is rebound
# to the object returned by init(). Initialise it from an explicit module alias:
# calling `plot.init` directly would, on a re-run of this cell, call `.init` on
# the returned object instead of on the libs.plot module.
from libs import plot as plot_lib
plot = plot_lib.init(name, project)
wb = wandb.init(name, project)

[34m[1mwandb[0m: Currently logged in as: [33mkasyah[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.6 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [3]:
# Seed torch's RNG up front so model init and data shuffling are reproducible.
torch.manual_seed(fedargs.seed)

# Use the GPU only when the config asks for it AND one is actually present.
use_cuda = fedargs.cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Extra DataLoader options that are only applied when running on CUDA.
kwargs = dict(num_workers=1, pin_memory=True) if use_cuda else {}

In [4]:
# This notebook drives exactly one participant, named "<hostname>: <fedargs.name>".
host = socket.gethostname()
clients = [": ".join((host, fedargs.name))]

In [5]:
# Distributed topology
# Distributed topology for the participants in `clients`; connection endpoints
# (broker_ip, schema_ip) and the consume wait all come from fedargs.
# NOTE(review): presumably a Kafka broker + schema registry given the project
# name 'fl-kafka-client' — confirm in libs.distributed.
dt = Distributed(
    clients,
    fedargs.broker_ip,
    fedargs.schema_ip,
    fedargs.wait_to_consume,
)

In [6]:
# Initialize Global and Client models
global_model = copy.deepcopy(fedargs.model)
# Load Data to clients
train_data, test_data = data.load_dataset(fedargs.dataset)

In [7]:
# Partition the training set among the participants in `clients` (a single
# local client in this notebook).
# NOTE(review): presumably returns one shard per client name — confirm in libs.data.
clients_data = data.split_data(train_data, clients)

In [8]:
# Per-client train/test loaders built from each client's shard.
# NOTE(review): 0.2 is presumably the fraction of each shard held out as the
# client's local test set — TODO confirm in libs.data.load_client_data.
client_train_loaders, client_test_loaders = data.load_client_data(
    clients_data, fedargs.client_batch_size, 0.2, **kwargs
)

# Loader over the dataset's held-out test split. The training loop below
# evaluates each client on its own test loader instead.
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=fedargs.test_batch_size, shuffle=True, **kwargs
)

# Per-client state: its loaders, a private copy of the global model, and a
# `model_update` slot that starts out as None.
client_details = {
    client_name: dict(
        train_loader=client_train_loaders[client_name],
        test_loader=client_test_loaders[client_name],
        model=copy.deepcopy(global_model),
        model_update=None,
    )
    for client_name in clients
}

In [9]:
def process(client, epoch, dt, model, train_loader, test_loader, fedargs, device):
    """Run one federated round for a single client.

    Consumes peers' updates for ``epoch``, averages them into ``model`` when any
    arrived, trains locally, publishes the resulting update tagged ``epoch + 1``,
    then evaluates the trained model and records its metrics.

    Args:
        client: This client's name; its identity on the topic and the key used
            for metrics and logs.
        epoch: Round whose updates are consumed. The published update and all
            recorded metrics are tagged ``epoch + 1``.
        dt: Distributed topology providing ``consume_model`` / ``produce_model``.
        model: The client's current model.
        train_loader: Data passed to ``fedargs.train_func``.
        test_loader: Data passed to ``fedargs.eval_func``.
        fedargs: Run configuration (``topic``, ``train_func``, ``eval_func``,
            ``learning_rate``, ``weight_decay``, ``local_rounds``, ``epochs``
            and the TensorBoard writer ``tb``).
        device: Device forwarded to the train/eval functions.

    Returns:
        The model after aggregation and local training.

    Note:
        Also records metrics through the notebook-level ``plot`` and ``wb`` objects.
    """
    log.info("Epoch: {}, Processing Client {}".format(epoch, client))

    # Updates are published tagged epoch + 1 (see below), so consuming with the
    # incoming `epoch` fetches what peers produced in the previous round.
    peer_updates = dt.consume_model(client, fedargs.topic, model, epoch)
    # Ignore our own published update; only peers are averaged in.
    peer_updates.pop(client, None)

    log.info("Epoch: {}, Client {} received {} model update(s) from {}".format(
        epoch, client, len(peer_updates), list(peer_updates.keys())))

    if peer_updates:
        model = fl.federated_avg(peer_updates, model)

    # Local training.
    model_update, model, loss = fedargs.train_func(model, train_loader,
                                                   fedargs.learning_rate,
                                                   fedargs.weight_decay,
                                                   fedargs.local_rounds, device)

    # Publish for the next round; everything recorded below uses this tag too.
    next_epoch = epoch + 1
    dt.produce_model(client, fedargs.topic, model_update, next_epoch)

    log.jsondebug(loss, "Epoch {} : Federated Training loss, Client {}".format(next_epoch, client))
    log.modeldebug(model, "Epoch {}: Client {} Update".format(next_epoch, client))

    # Evaluate and record. Take one timestamp so the plot and wandb rows agree.
    test_output = fedargs.eval_func(model, test_loader, device)
    accuracy = test_output["accuracy"]
    test_loss = test_output["test_loss"]
    now = time.time()

    fedargs.tb.add_scalar("Accuracy/" + client, accuracy, next_epoch)
    fedargs.tb.add_scalar("Test Loss/" + client, test_loss, next_epoch)
    plot.alog(client, {next_epoch: {"time": now, "acc": accuracy, "loss": test_loss}})
    wb.log({client: {"epoch": next_epoch, "time": now, "acc": accuracy, "loss": test_loss}})
    log.jsoninfo(test_output, "Test Output after Epoch {} of {} for Client {}".format(
        next_epoch, fedargs.epochs, client))

    return model

In [None]:
# Federated training: every epoch runs one round of `process` (consume peers'
# updates -> aggregate -> train -> publish -> evaluate) for each local client.
# `time` is already imported in the first cell.
start_time = time.perf_counter()

for epoch in tqdm(range(fedargs.epochs)):
    log.info("Federated Training Epoch {} of {}".format(epoch, fedargs.epochs))

    for client in clients:
        details = client_details[client]
        details['model'] = process(client, epoch, dt, details['model'],
                                   details['train_loader'],
                                   details['test_loader'],
                                   fedargs, device)

# perf_counter is monotonic, so the measured duration is unaffected by
# wall-clock adjustments during a long run.
print("Federated training took {:.1f} s".format(time.perf_counter() - start_time))

  0%|          | 0/51 [00:00<?, ?it/s]2021-11-06 13:26:20,964 - <ipython-input-10-387a89c474b9>::<module>(l:6) : Federated Training Epoch 0 of 51 [MainProcess : MainThread (INFO)]
2021-11-06 13:26:20,993 - <ipython-input-9-a04f38dd6afa>::process(l:2) : Epoch: 0, Processing Client bladecluster.iitp.org: client-x [MainProcess : MainThread (INFO)]
2021-11-06 13:26:31,176 - <ipython-input-9-a04f38dd6afa>::process(l:11) : Epoch: 0, Client bladecluster.iitp.org: client-x received 50 model update(s) from ['bladecluster.iitp.org(4)', 'bladecluster.iitp.org(5)', 'bladecluster.iitp.org(6)', 'bladecluster.iitp.org(1)', 'bladecluster.iitp.org(18)', 'bladecluster.iitp.org(3)', 'bladecluster.iitp.org(31)', 'bladecluster.iitp.org(17)', 'bladecluster.iitp.org(9)', 'bladecluster.iitp.org(11)', 'bladecluster.iitp.org(32)', 'bladecluster.iitp.org(19)', 'bladecluster.iitp.org(21)', 'bladecluster.iitp.org(2)', 'bladecluster.iitp.org(15)', 'bladecluster.iitp.org(14)', 'bladecluster.iitp.org(13)', 'bladeclus

2021-11-06 13:29:21,625 - <ipython-input-9-a04f38dd6afa>::process(l:2) : Epoch: 3, Processing Client bladecluster.iitp.org: client-x [MainProcess : MainThread (INFO)]
2021-11-06 13:29:31,688 - <ipython-input-9-a04f38dd6afa>::process(l:11) : Epoch: 3, Client bladecluster.iitp.org: client-x received 5 model update(s) from ['bladecluster.iitp.org(2)', 'bladecluster.iitp.org(3)', 'bladecluster.iitp.org(4)', 'bladecluster.iitp.org(1)', 'bladecluster.iitp.org(5)'] [MainProcess : MainThread (INFO)]
2021-11-06 13:30:14,562 - /home/harsh_1921cs01/hub/AgroFed/fl/libs/protobuf_producer.py::produce(l:56) : Producing user records to topic pyflx. ^C to exit. [MainProcess : MainThread (INFO)]
2021-11-06 13:30:14,726 - /home/harsh_1921cs01/hub/AgroFed/fl/libs/protobuf_producer.py::produce(l:66) : Flushing records... [MainProcess : MainThread (INFO)]
2021-11-06 13:30:15,744 - /home/harsh_1921cs01/hub/AgroFed/fl/libs/protobuf_producer.py::delivery_report(l:50) : User record b'bladecluster.iitp.org: clie

<h1> End </h1>