In [9]:
# Re-import edited modules automatically before each cell runs (autoreload mode 2).
%load_ext autoreload
%autoreload 2

# nest_asyncio patches asyncio so an event loop can be re-entered; the training
# cell calls `fedargs.loop.run_until_complete(...)`, which may otherwise fail if
# the notebook kernel's own loop is already running.
import asyncio, nest_asyncio
nest_asyncio.apply()

# NOTE(review): several of these (Pool, Process, Path, optim, SummaryWriter,
# partial) are not referenced in the cells shown here.
import copy, os, socket, sys, time
from functools import partial
from multiprocessing import Pool, Process
from pathlib import Path
from tqdm import tqdm

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

# Put the directory two levels above the working directory at the front of
# sys.path so the project packages `libs` and `cfgs` can be imported.
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
# `wandb` below is the project's libs.wandb module, not the wandb package itself.
from libs import agg, data, fl, log, nn, plot, poison, resnet, sim, wandb, he
# NOTE(review): this star import is where `fedargs` comes from; an explicit
# import would make the notebook's dependencies easier to see.
from cfgs.fedargs import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# Run identification for the project's wandb helper.
project = 'fl'
name = 'fl-he'

# Define seed: use the configured seed so this cell and the device-settings cell
# (which calls torch.manual_seed(fedargs.seed)) agree on a single value instead of
# seeding with a hard-coded 1 that is immediately overridden.
torch.manual_seed(fedargs.seed)

# Define Custom CFGs: turn on the encrypted (`he`) path, which the training loop
# uses for both client-update encryption and server-side averaging.
fedargs.enc = True

# Save Logs To File (info | debug | warning | error | critical) [optional]
log.init("info")
# libs.wandb.init returns the run handle used later as `wb.log({...})`.
wb = wandb.init(name, project)

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: wandb version 0.17.8 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [11]:
# Device settings
# Device settings: reseed torch, then use CUDA only when it is both requested
# in the config and actually available on this machine.
torch.manual_seed(fedargs.seed)
use_cuda = fedargs.cuda and torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
# Extra DataLoader options (one worker process, pinned host memory) for GPU runs only.
kwargs = dict(num_workers=1, pin_memory=True) if use_cuda else {}

In [12]:
# Prepare clients
# Prepare clients: labels of the form "<hostname>(1)", "<hostname>(2)", ...
host = socket.gethostname()
clients = [f"{host}({idx})" for idx in range(1, fedargs.num_clients + 1)]

In [13]:
# Initialize Global and Client models
# Initialize the global model as an independent copy of the configured template
# (per-client copies are made when `client_details` is built).
global_model = copy.deepcopy(fedargs.model)
# Load the train and test datasets for the configured dataset name.
train_data, test_data = data.load_dataset(fedargs.dataset)

In [14]:
# Partition the training data across the client labels.
# NOTE(review): presumably returns a mapping keyed by client label, since the
# resulting loaders are later indexed as client_train_loaders[client] — confirm.
clients_data = data.split_data(train_data, clients)

In [15]:
# Per-client training loaders; the second return value is not used here.
client_train_loaders, _ = data.load_client_data(clients_data, fedargs.client_batch_size, None, **kwargs)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=fedargs.test_batch_size,
                                          shuffle=True,
                                          **kwargs)

# State for each client: its data loader, a private copy of the global model,
# and the latest (not yet produced) model update.
client_details = {}
for client in clients:
    client_details[client] = {
        "train_loader": client_train_loaders[client],
        "model": copy.deepcopy(global_model),
        "model_update": None,
    }

In [16]:
def background(f):
    """Wrap a blocking function so that calling it runs in the default executor.

    Calling the returned function schedules ``f(*args, **kwargs)`` with
    ``loop.run_in_executor(None, ...)`` on ``asyncio.get_event_loop()`` and
    immediately returns an ``asyncio.Future`` that resolves to ``f``'s return
    value, so several calls can be awaited together (e.g. ``asyncio.gather``).

    NOTE(review): the future is bound to ``asyncio.get_event_loop()``; it must be
    awaited on that same loop (the training cell uses ``fedargs.loop``, presumably
    the same loop — verify in cfgs.fedargs).
    """
    import functools

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        # run_in_executor forwards positional arguments only, so bind keyword
        # arguments with functools.partial instead of passing **kwargs through.
        call = functools.partial(f, *args, **kwargs)
        return asyncio.get_event_loop().run_in_executor(None, call)

    return wrapped

@background
def process(client, epoch, model, train_loader, fedargs, device):
    """Train one client's model locally and return its model update.

    Decorated with ``background``, so calling it returns a future that resolves
    to the update. ``fedargs.train_func`` is expected to return
    ``(model_update, model, loss)``; only the update is returned here.
    """
    train = fedargs.train_func
    update, _trained_model, train_loss = train(model,
                                               train_loader,
                                               fedargs.learning_rate,
                                               fedargs.weight_decay,
                                               fedargs.local_rounds,
                                               device)

    progress = f"Epoch {epoch} of {fedargs.epochs}"
    log.jsondebug(train_loss, f"{progress} : Federated Training loss, Client {client}")
    log.modeldebug(update, f"{progress} : Client {client} Update")

    return update

In [None]:
# `time` is already imported in the first cell.
start_time = time.time()

# Federated Training
# NOTE(review): aggregation happens at the start of the *next* epoch, so the
# client updates produced in the final epoch are never averaged or evaluated.
for epoch in tqdm(range(fedargs.epochs)):
    log.info("Federated Training Epoch {} of {}".format(epoch, fedargs.epochs))

    # Global Model Update: combine the client updates from the previous epoch
    # (epoch 0 has none yet, so it only trains).
    if epoch > 0:
        # Average with the encrypted (`he`) aggregator when encryption is on.
        aggregate = he.federated_avg if fedargs.enc else fl.federated_avg
        global_model = aggregate(client_model_updates, global_model)
        log.modeldebug(global_model, "Epoch {} of {} : Server Update".format(epoch, fedargs.epochs))

        # Test, Plot and Log
        global_test_output = fedargs.eval_func(global_model, test_loader, device)
        wb.log({"epoch": epoch, "time": time.time(), "acc": global_test_output["accuracy"], "loss": global_test_output["test_loss"]})
        log.jsoninfo(global_test_output, "Global Test Output after Epoch {} of {}".format(epoch, fedargs.epochs))

        # Update client models: every client starts the epoch from its own copy.
        for client in clients:
            client_details[client]['model'] = copy.deepcopy(global_model)

    # Clients: each call to `process` returns a future; run them all to completion.
    tasks = [process(client, epoch, client_details[client]['model'],
                     client_details[client]['train_loader'],
                     fedargs, device) for client in clients]
    try:
        updates = fedargs.loop.run_until_complete(asyncio.gather(*tasks))
    except KeyboardInterrupt:
        log.error("Caught keyboard interrupt. Canceling tasks...")
        # `tasks` is a list of futures, so cancel each one individually. This only
        # prevents work that has not started; training already running in a worker
        # thread cannot be interrupted this way.
        for task in tasks:
            task.cancel()
        # Stop the run instead of continuing with undefined or stale `updates`.
        raise

    # Store each client's update, encrypting it first when encryption is enabled.
    for client, update in zip(clients, updates):
        client_details[client]['model_update'] = he.enc_model_update(update) if fedargs.enc else update

    client_model_updates = {client: details["model_update"] for client, details in client_details.items()}

print(time.time() - start_time)

  0%|          | 0/51 [00:00<?, ?it/s]2024-09-05 19:11:26,460 - <ipython-input-17-95d18e29c38d>::<module>(l:6) : Federated Training Epoch 0 of 51 [MainProcess : MainThread (INFO)]
  2%|▏         | 1/51 [00:39<33:10, 39.80s/it]2024-09-05 19:12:05,985 - <ipython-input-17-95d18e29c38d>::<module>(l:6) : Federated Training Epoch 1 of 51 [MainProcess : MainThread (INFO)]
2024-09-05 19:12:09,524 - <ipython-input-17-95d18e29c38d>::<module>(l:17) : Global Test Outut after Epoch 1 of 51 {
    "accuracy": 59.809999999999995,
    "correct": 5981,
    "test_loss": 0.016576014828681946
} [MainProcess : MainThread (INFO)]
  4%|▍         | 2/51 [01:22<33:53, 41.49s/it]2024-09-05 19:12:48,667 - <ipython-input-17-95d18e29c38d>::<module>(l:6) : Federated Training Epoch 2 of 51 [MainProcess : MainThread (INFO)]
2024-09-05 19:12:52,275 - <ipython-input-17-95d18e29c38d>::<module>(l:17) : Global Test Outut after Epoch 2 of 51 {
    "accuracy": 77.29,
    "correct": 7729,
    "test_loss": 0.014071006071567535

 39%|███▉      | 20/51 [14:23<22:23, 43.33s/it]2024-09-05 19:25:49,669 - <ipython-input-17-95d18e29c38d>::<module>(l:6) : Federated Training Epoch 20 of 51 [MainProcess : MainThread (INFO)]
2024-09-05 19:25:52,934 - <ipython-input-17-95d18e29c38d>::<module>(l:17) : Global Test Outut after Epoch 20 of 51 {
    "accuracy": 93.08999999999999,
    "correct": 9309,
    "test_loss": 0.012125563597679138
} [MainProcess : MainThread (INFO)]
 41%|████      | 21/51 [15:06<21:36, 43.21s/it]2024-09-05 19:26:32,616 - <ipython-input-17-95d18e29c38d>::<module>(l:6) : Federated Training Epoch 21 of 51 [MainProcess : MainThread (INFO)]
2024-09-05 19:26:36,104 - <ipython-input-17-95d18e29c38d>::<module>(l:17) : Global Test Outut after Epoch 21 of 51 {
    "accuracy": 93.15,
    "correct": 9315,
    "test_loss": 0.01211085467338562
} [MainProcess : MainThread (INFO)]
 43%|████▎     | 22/51 [15:49<20:53, 43.22s/it]2024-09-05 19:27:15,855 - <ipython-input-17-95d18e29c38d>::<module>(l:6) : Federated Trainin

<h1> End </h1>