In [4]:
%load_ext autoreload
%autoreload 2

import asyncio, nest_asyncio
nest_asyncio.apply()

import copy, os, socket, sys, time
from functools import partial
from multiprocessing import Pool, Process
from pathlib import Path
from tqdm import tqdm

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
from libs import agg, data, fl, log, nn, plot, poison, resnet, sim, wandb
from cfgs.fedargs import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# Run identity: both values are handed to the project's wandb wrapper below.
project, name = 'fl', 'fl-he'

# Define custom CFGs here (this notebook does not override any).

# Logger setup; accepted levels: info | debug | warning | error | critical [optional]
log.init("info")
# Handle used later in the training loop to log per-round metrics.
wb = wandb.init(name, project)

2024-09-05 11:14:14,723 - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving. [MainProcess : MainThread (ERROR)]
[34m[1mwandb[0m: Currently logged in as: [33mkasyah[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
# Device settings
# CUDA is used only when it is both requested in the config and available.
use_cuda = fedargs.cuda and torch.cuda.is_available()
# Seed PyTorch's RNG from the config for repeatable runs.
torch.manual_seed(fedargs.seed)
if use_cuda:
    device = torch.device("cuda")
    # Extra DataLoader options applied only on the GPU path.
    kwargs = {"num_workers": 1, "pin_memory": True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [7]:
# Prepare clients
# Each client id has the form "<hostname>(<k>)", k = 1 .. fedargs.num_clients.
host = socket.gethostname()
clients = [f"{host}({index})" for index in range(1, fedargs.num_clients + 1)]

In [8]:
# Initialize the global model as an independent deep copy of the configured
# fedargs.model (per-client copies are created later, in client_details).
global_model = copy.deepcopy(fedargs.model)
# Load the train/test datasets selected by fedargs.dataset
train_data, test_data = data.load_dataset(fedargs.dataset)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /Users/hkasyap/Desktop/ATI/fl/libs/../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|████████████████████████████| 9912422/9912422 [00:01<00:00, 7352513.01it/s]


Extracting /Users/hkasyap/Desktop/ATI/fl/libs/../data/MNIST/raw/train-images-idx3-ubyte.gz to /Users/hkasyap/Desktop/ATI/fl/libs/../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /Users/hkasyap/Desktop/ATI/fl/libs/../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|█████████████████████████████████| 28881/28881 [00:00<00:00, 336682.57it/s]


Extracting /Users/hkasyap/Desktop/ATI/fl/libs/../data/MNIST/raw/train-labels-idx1-ubyte.gz to /Users/hkasyap/Desktop/ATI/fl/libs/../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /Users/hkasyap/Desktop/ATI/fl/libs/../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|████████████████████████████| 1648877/1648877 [00:00<00:00, 3089433.44it/s]


Extracting /Users/hkasyap/Desktop/ATI/fl/libs/../data/MNIST/raw/t10k-images-idx3-ubyte.gz to /Users/hkasyap/Desktop/ATI/fl/libs/../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /Users/hkasyap/Desktop/ATI/fl/libs/../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|███████████████████████████████████| 4542/4542 [00:00<00:00, 670887.76it/s]

Extracting /Users/hkasyap/Desktop/ATI/fl/libs/../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /Users/hkasyap/Desktop/ATI/fl/libs/../data/MNIST/raw






In [9]:
# Distribute the training set among the clients. The partitioning scheme is
# whatever data.split_data implements (presumably one shard per client name
# -- TODO confirm against libs/data).
clients_data = data.split_data(train_data, clients)

In [10]:
# Per-client training loaders; the second return value is not used here.
client_train_loaders, _ = data.load_client_data(
    clients_data,
    fedargs.client_batch_size,
    None,
    **kwargs
)
# One shared loader over the held-out test set for evaluating the global model.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=fedargs.test_batch_size,
    shuffle=True,
    **kwargs
)

# Per-client state: its loader, a private copy of the global model, and a
# slot for the most recent model update (empty until the first round).
client_details = dict(
    (
        client,
        {
            "train_loader": client_train_loaders[client],
            "model": copy.deepcopy(global_model),
            "model_update": None,
        },
    )
    for client in clients
)

In [11]:
def background(f):
    """Decorator: run ``f`` in the event loop's default executor.

    Calling the decorated function does not execute ``f`` inline. It schedules
    ``f(*args, **kwargs)`` on the default executor of the current asyncio
    event loop and returns the resulting awaitable future.

    Args:
        f: A blocking callable to offload.

    Returns:
        A wrapper with the same call signature as ``f`` that returns an
        asyncio future resolving to ``f``'s return value.
    """
    from functools import wraps

    @wraps(f)  # keep f's __name__/__doc__ on the wrapper
    def wrapped(*args, **kwargs):
        # run_in_executor() only forwards positional arguments, so keyword
        # arguments must be bound up front with functools.partial.
        call = partial(f, *args, **kwargs)
        return asyncio.get_event_loop().run_in_executor(None, call)

    return wrapped

@background
def process(client, epoch, model, train_loader, fedargs, device):
    """Run one round of local training for a single client.

    Because of ``@background``, calling this schedules the work on an
    executor and returns a future; the future resolves to the model update
    produced by ``fedargs.train_func``.

    Args:
        client: Client identifier, used only in log messages.
        epoch: Current federated round, used only in log messages.
        model: The client's model passed to ``fedargs.train_func``.
        train_loader: The client's training data loader.
        fedargs: Config object supplying ``train_func``, ``learning_rate``,
            ``weight_decay``, ``local_rounds`` and ``epochs``.
        device: Device handed through to ``fedargs.train_func``.

    Returns:
        The first element of the triple returned by ``fedargs.train_func``.
    """
    # Local training; the trainer returns (update, trained model, loss).
    outcome = fedargs.train_func(
        model,
        train_loader,
        fedargs.learning_rate,
        fedargs.weight_decay,
        fedargs.local_rounds,
        device,
    )
    model_update, _trained_model, loss = outcome

    # Debug-level traces of this client's loss and update for the round.
    loss_title = "Epoch {} of {} : Federated Training loss, Client {}".format(epoch, fedargs.epochs, client)
    log.jsondebug(loss, loss_title)
    update_title = "Epoch {} of {} : Client {} Update".format(epoch, fedargs.epochs, client)
    log.modeldebug(model_update, update_title)

    return model_update

In [None]:
# Federated training driver.
#
# Each round: (1) from round 1 onwards, fold the previous round's client
# updates into the global model, evaluate it, and push a fresh copy to every
# client; (2) launch local training for all clients concurrently and collect
# their updates. Because aggregation happens at the start of the *next*
# round, the updates produced in the final round are never aggregated.
start_time = time.time()

# Federated Training
for epoch in tqdm(range(fedargs.epochs)):
    log.info("Federated Training Epoch {} of {}".format(epoch, fedargs.epochs))

    # Global Model Update (skipped in round 0: no client updates exist yet)
    if epoch > 0:
        # Average
        global_model = fl.federated_avg(client_model_updates, global_model)
        log.modeldebug(global_model, "Epoch {} of {} : Server Update".format(epoch, fedargs.epochs))

        # Test, Plot and Log
        global_test_output = fedargs.eval_func(global_model, test_loader, device)
        wb.log({"epoch": epoch, "time": time.time(), "acc": global_test_output["accuracy"], "loss": global_test_output["test_loss"]})
        log.jsoninfo(global_test_output, "Global Test Outut after Epoch {} of {}".format(epoch, fedargs.epochs))

        # Update client models
        for client in clients:
            client_details[client]['model'] = copy.deepcopy(global_model)

    # Clients: each process() call returns a future (see @background).
    tasks = [process(client, epoch, client_details[client]['model'],
                     client_details[client]['train_loader'],
                     fedargs, device) for client in clients]
    # Keep a handle on the aggregate future so it can be cancelled; the
    # plain list `tasks` has no cancel()/exception() methods.
    pending_updates = asyncio.gather(*tasks)
    try:
        updates = fedargs.loop.run_until_complete(pending_updates)
    except KeyboardInterrupt:
        log.error("Caught keyboard interrupt. Canceling tasks...")
        pending_updates.cancel()
        # Stop training: `updates` was not produced for this round, so
        # continuing would either fail or reuse stale values.
        break

    for client, update in zip(clients, updates):
        client_details[client]['model_update'] = update
    client_model_updates = {client: details["model_update"] for client, details in client_details.items()}

# Total wall-clock time in seconds.
print(time.time() - start_time)

  0%|                                                    | 0/51 [00:00<?, ?it/s]2024-09-05 11:14:46,157 - /var/folders/bt/rxysq_fs7ggf93mf5dc5250r0000gr/T/ipykernel_74004/2384706191.py::<module>(l:6) : Federated Training Epoch 0 of 51 [MainProcess : MainThread (INFO)]
  2%|▊                                           | 1/51 [00:39<32:49, 39.39s/it]2024-09-05 11:15:25,516 - /var/folders/bt/rxysq_fs7ggf93mf5dc5250r0000gr/T/ipykernel_74004/2384706191.py::<module>(l:6) : Federated Training Epoch 1 of 51 [MainProcess : MainThread (INFO)]
2024-09-05 11:15:27,554 - /var/folders/bt/rxysq_fs7ggf93mf5dc5250r0000gr/T/ipykernel_74004/2384706191.py::<module>(l:17) : Global Test Outut after Epoch 1 of 51 {
    "accuracy": 48.55,
    "correct": 4855,
    "test_loss": -0.14279726257324218
} [MainProcess : MainThread (INFO)]


<h1> End </h1>