In [1]:
%load_ext autoreload
%autoreload 2

import asyncio, nest_asyncio
nest_asyncio.apply()

import copy, os, socket, sys, time
from functools import partial
from multiprocessing import Pool, Process
from pathlib import Path
from tqdm import tqdm

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../")))
from libs import agg, data, fl, log, nn, plot, poison, resnet, sim, wandb
from cfgs.fedargs import *

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'wandb'

In [8]:
project = 'fl'
name = 'fedavg-cnn-mnist-na'

#Define Custom CFGs

# Save Logs To File (info | debug | warning | error | critical) [optional]
log.init("info")
wb = wandb.init(name, project)





In [9]:
# Device settings
use_cuda = fedargs.cuda and torch.cuda.is_available()
torch.manual_seed(fedargs.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

In [10]:
# Prepare clients
host = socket.gethostname()
clients = [host + "(" + str(client + 1) + ")" for client in range(fedargs.num_clients)]

In [11]:
# Initialize Global and Client models
global_model = copy.deepcopy(fedargs.model)
# Load Data to clients
train_data, test_data = data.load_dataset(fedargs.dataset)

In [12]:
clients_data = data.split_data(train_data, clients)

In [13]:
client_train_loaders, _ = data.load_client_data(clients_data, fedargs.client_batch_size, None, **kwargs)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=fedargs.test_batch_size, shuffle=True, **kwargs)

client_details = {
        client: {"train_loader": client_train_loaders[client],
                 "model": copy.deepcopy(global_model),
                 "model_update": None}
        for client in clients
    }

In [14]:
def background(f):
    def wrapped(*args, **kwargs):
        return asyncio.get_event_loop().run_in_executor(None, f, *args, **kwargs)

    return wrapped

@background
def process(client, epoch, model, train_loader, fedargs, device):
    # Train
    model_update, model, loss = fedargs.train_func(model, train_loader, 
                                                   fedargs.learning_rate,
                                                   fedargs.weight_decay,
                                                   fedargs.local_rounds, device)

    log.jsondebug(loss, "Epoch {} of {} : Federated Training loss, Client {}".format(epoch, fedargs.epochs, client))
    log.modeldebug(model_update, "Epoch {} of {} : Client {} Update".format(epoch, fedargs.epochs, client))
    
    return model_update

In [18]:
import time
start_time = time.time()
    
# Federated Training
for epoch in tqdm(range(fedargs.epochs)):
    log.info("Federated Training Epoch {} of {}".format(epoch, fedargs.epochs))

    # Global Model Update
    if epoch > 0:     
        # Average
        global_model = fl.federated_avg(client_model_updates, global_model)
        log.modeldebug(global_model, "Epoch {} of {} : Server Update".format(epoch, fedargs.epochs))
        
        # Test, Plot and Log
        global_test_output = fedargs.eval_func(global_model, test_loader, device)
        wb.log({"epoch": epoch, "time": time.time(), "acc": global_test_output["accuracy"], "loss": global_test_output["test_loss"]})
        log.jsoninfo(global_test_output, "Global Test Outut after Epoch {} of {}".format(epoch, fedargs.epochs))
        
        # Update client models
        for client in clients:
            client_details[client]['model'] = copy.deepcopy(global_model)

    # Clients
    tasks = [process(client, epoch, client_details[client]['model'],
                     client_details[client]['train_loader'],
                     fedargs, device) for client in clients]
    try:
        updates = fedargs.loop.run_until_complete(asyncio.gather(*tasks))
    except KeyboardInterrupt as e:
        log.error("Caught keyboard interrupt. Canceling tasks...")
        tasks.cancel()
        fedargs.loop.run_forever()
        tasks.exception()

    for client, update in zip(clients, updates):
        client_details[client]['model_update'] = update
    client_model_updates = {client: details["model_update"] for client, details in client_details.items()}

print(time.time() - start_time)

  0%|                                                    | 0/51 [00:00<?, ?it/s]2023-10-15 08:55:54,024 - /var/folders/bt/rxysq_fs7ggf93mf5dc5250r0000gr/T/ipykernel_10083/2384706191.py::<module>(l:6) : Federated Training Epoch 0 of 51 [MainProcess : MainThread (INFO)]
  2%|▊                                           | 1/51 [00:41<34:18, 41.18s/it]2023-10-15 08:56:35,197 - /var/folders/bt/rxysq_fs7ggf93mf5dc5250r0000gr/T/ipykernel_10083/2384706191.py::<module>(l:6) : Federated Training Epoch 1 of 51 [MainProcess : MainThread (INFO)]
2023-10-15 08:56:37,506 - /var/folders/bt/rxysq_fs7ggf93mf5dc5250r0000gr/T/ipykernel_10083/2384706191.py::<module>(l:17) : Global Test Outut after Epoch 1 of 51 {
    "accuracy": 44.43,
    "correct": 4443,
    "test_loss": -0.12358662796020507
} [MainProcess : MainThread (INFO)]
  4%|█▋                                          | 2/51 [01:25<35:10, 43.07s/it]2023-10-15 08:57:19,594 - /var/folders/bt/rxysq_fs7ggf93mf5dc5250r0000gr/T/ipykernel_10083/2384706191

AttributeError: 'list' object has no attribute 'cancel'

<h1> End </h1>