In [1]:
%load_ext autoreload
%autoreload 2

import asyncio, copy, os, socket, sys, time
from functools import partial
from multiprocessing import Pool, Process
from pathlib import Path
from tqdm import tqdm

import torch
from torch import optim

import shap

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
from libs import agg, data, fl, helper, log, nn, poison, resnet, sim, wandb
from libs.distributed import *
from cfgs.fedargs import *

In [2]:
# Experiment identity handed to the tracking wrapper below.
project = "fl-shap"
name = f"fedavg-cnn-mnist{fedargs.name}"

# Custom config override: this notebook drives a single client.
fedargs.num_clients = 1

# Save logs to file at the chosen level (info | debug | warning | error | critical) [optional]
log.init("info")
# `wandb` here is the project wrapper imported from libs; process() logs
# per-client metrics through the handle it returns.
wb = wandb.init(name, project)

[34m[1mwandb[0m: Currently logged in as: [33mkasyah[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.10 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [3]:
# Device settings: use the GPU only when it is both requested and available.
use_cuda = fedargs.cuda and torch.cuda.is_available()
torch.manual_seed(fedargs.seed)
if use_cuda:
    device = torch.device("cuda")
    # Extra DataLoader options applied only on the CUDA path.
    kwargs = {"num_workers": 1, "pin_memory": True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [4]:
# The notebook runs exactly one client, labelled "<hostname>: <configured name>".
host = socket.gethostname()
clients = [f"{host}: {fedargs.name}"]

In [5]:
# Distributed topology
# `dt` is the handle process() uses to consume and produce model updates.
# `Distributed` is not imported by name; presumably it comes from the star
# import of libs.distributed — confirm.
# NOTE(review): broker_ip / schema_ip look like message-broker and
# schema-registry endpoints, and wait_to_consume like a polling delay — verify
# against libs.distributed before relying on this.
dt = Distributed(clients, fedargs.broker_ip, fedargs.schema_ip, fedargs.wait_to_consume)

In [6]:
# Initialize the global model as a deep copy of the configured model, so that
# later changes to it do not touch fedargs.model. The per-client copies are
# made from this object when client_details is built.
global_model = copy.deepcopy(fedargs.model)
# Load the train and test datasets selected by fedargs.dataset; the training
# part is distributed to the clients in the following cells.
train_data, test_data = data.load_dataset(fedargs.dataset)

In [7]:
# Presumably partitions the training set among the listed clients; the result
# is passed to data.load_client_data to build the per-client loaders.
# NOTE(review): the split strategy is decided inside libs.data — confirm there.
clients_data = data.split_data(train_data, clients)

In [8]:
# Per-client train/test loaders, indexed by client id below.
# NOTE(review): 0.2 is presumably the fraction held out for each client's test
# loader — confirm in libs.data.
client_train_loaders, client_test_loaders = data.load_client_data(
    clients_data, fedargs.client_batch_size, 0.2, **kwargs
)
# Loader over the global test set.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=fedargs.test_batch_size,
    shuffle=True,
    **kwargs
)

# State kept for every client: its loaders, its own deep copy of the global
# model, and a slot for its latest update (empty until one exists).
client_details = {
    cid: dict(
        train_loader=client_train_loaders[cid],
        test_loader=client_test_loaders[cid],
        model=copy.deepcopy(global_model),
        model_update=None,
    )
    for cid in clients
}

In [9]:
def process(client, epoch, dt, model, train_loader, test_loader, fedargs, device):
    """Run one federated round for ``client`` and return its trained model.

    The round consumes the updates published for ``epoch`` (the previous
    round), averages the peers' updates into ``model`` when there are any,
    trains locally, publishes the resulting update tagged ``epoch + 1``, and
    finally evaluates and logs the trained model.

    Besides its arguments, the function reads the notebook-level names
    ``log``, ``fl``, ``time`` and ``wb`` (the tracking handle).

    Args:
        client: Identifier of the client being processed; also the key under
            which its own update may appear among the consumed updates.
        epoch: Round whose published updates are consumed. Everything
            published or logged after training is tagged ``epoch + 1``.
        dt: Object exposing ``consume_model`` and ``produce_model``.
        model: The client's current model.
        train_loader: Data passed to ``fedargs.train_func``.
        test_loader: Data passed to ``fedargs.eval_func``.
        fedargs: Configuration object; this function reads ``topic``,
            ``train_func``, ``learning_rate``, ``weight_decay``,
            ``local_rounds``, ``eval_func`` and ``epochs``.
        device: Passed through to the train and eval functions.

    Returns:
        The model returned by ``fedargs.train_func``.
    """
    log.info("Epoch: {}, Processing Client {}".format(epoch, client))

    # Consume and average: the epoch passed in is actually the previous epoch,
    # i.e. the one whose updates we want to consume.
    client_model_updates = dt.consume_model(client, fedargs.topic, model, epoch)

    # Drop one's own update so that only the peers' updates are averaged.
    client_model_updates.pop(client, None)

    log.info("Epoch: {}, Client {} received {} model update(s) from {}".format(
        epoch, client, len(client_model_updates), list(client_model_updates.keys())))

    # Without any peer update the local model is trained as it is.
    if client_model_updates:
        model = fl.federated_avg(client_model_updates, model)

    # Train
    model_update, model, loss = fedargs.train_func(model, train_loader,
                                                   fedargs.learning_rate,
                                                   fedargs.weight_decay,
                                                   fedargs.local_rounds, device)

    # Publish under the next round number; a separate name keeps the consumed
    # round (`epoch`) and the produced round apart.
    next_epoch = epoch + 1
    dt.produce_model(client, fedargs.topic, model_update, next_epoch)

    log.jsondebug(loss, "Epoch {} : Federated Training loss, Client {}".format(next_epoch, client))
    log.modeldebug(model, "Epoch {}: Client {} Update".format(next_epoch, client))

    # Test, Plot and Log
    test_output = fedargs.eval_func(model, test_loader, device)
    wb.log({client: {"epoch": next_epoch, "time": time.time(),
                     "acc": test_output["accuracy"], "loss": test_output["test_loss"]}})
    log.jsoninfo(test_output, "Test Output after Epoch {} of {} for Client {}".format(
        next_epoch, fedargs.epochs, client))

    return model

In [10]:
# Federated training driver: for every epoch, run one round per client and keep
# the model returned by process() as that client's current model.
# (`time` is already imported in the notebook's import cell, so it is not
# imported again here.)
start_time = time.time()
fedargs.epochs = 1  # override: this run performs a single federated round
# Federated Training
for epoch in tqdm(range(fedargs.epochs)):
    log.info("Federated Training Epoch {} of {}".format(epoch, fedargs.epochs))

    for client in clients:
        client_details[client]['model'] = process(client, epoch, dt, client_details[client]['model'],
                                                  client_details[client]['train_loader'],
                                                  client_details[client]['test_loader'],
                                                  fedargs, device)
# Wall-clock seconds spent in the whole training loop.
print(time.time() - start_time)

  0%|          | 0/1 [00:00<?, ?it/s]2022-02-07 11:29:51,017 - <ipython-input-10-32c122a5cc8b>::<module>(l:6) : Federated Training Epoch 0 of 1 [MainProcess : MainThread (INFO)]
2022-02-07 11:29:51,043 - <ipython-input-9-e20d4c96e1f4>::process(l:2) : Epoch: 0, Processing Client bladecluster.iitp.org: client-x [MainProcess : MainThread (INFO)]
2022-02-07 11:30:01,098 - <ipython-input-9-e20d4c96e1f4>::process(l:11) : Epoch: 0, Client bladecluster.iitp.org: client-x received 0 model update(s) from [] [MainProcess : MainThread (INFO)]
2022-02-07 11:30:34,114 - /home/harsh_1921cs01/hub/AgroFed/fl/libs/protobuf_producer.py::produce(l:56) : Producing user records to topic pyflx. ^C to exit. [MainProcess : MainThread (INFO)]
2022-02-07 11:30:34,165 - /home/harsh_1921cs01/hub/AgroFed/fl/libs/protobuf_producer.py::produce(l:64) : Exception raised [MainProcess : MainThread (ERROR)]
2022-02-07 11:30:34,197 - /home/harsh_1921cs01/hub/AgroFed/fl/libs/protobuf_producer.py::produce(l:66) : Flushing re

47.10743498802185



