In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import copy

In [2]:
# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
%reload_ext autoreload
%autoreload 2

from src.constants import csv_file, directory, INPUT_SHAPE, YEAR, ext

from src.utils import (
    get_dataloader,
    sample_iid,
    split_val,
)

from src.model import Model_Retinopathy
from src.server_scaffold import Server_Scaffold

from src.constants import EPOCHS, BATCH_SIZE, LEARNING_RATE, UPDATES
from src.constants import K_CLIENTS, C, rounds, clients
from src.constants import optimizer_fn

In [4]:
# Data preparation.
# Read the CSV at `csv_file`, split it with split_val, then subsample the
# remaining frame with sample_iid and renumber its index from zero.
# NOTE(review): 0.05 and 0.5 are presumably the validation fraction and the
# IID sampling fraction -- confirm against src.utils.
df, val_df = split_val(pd.read_csv(csv_file), 0.05)
df = sample_iid(df, 0.5)
df = df.reset_index(drop=True)

# get_dataloader returns a pair; only its second element is kept and used as
# the validation loader.
_, val_loader = get_dataloader(
    val_df,
    ext,
    directory,
    BATCH_SIZE,
)

In [5]:
# Baseline: bytes reported by torch.cuda.memory_allocated() before the server
# object exists, printed in MB (bytes / 2**20).
gpu_memory = torch.cuda.memory_allocated()
print(gpu_memory / 2**20, "MB")

# Build the Server_Scaffold instance and move it to `device`.
server_scaffold = Server_Scaffold(
    K_CLIENTS,
    optimizer_fn,
    df,
    val_loader,
).to(device)

# Report only the increase in allocated memory since the baseline above.
print((torch.cuda.memory_allocated() - gpu_memory) / 2**20, "MB")

0.0 MB
1486.9326171875 MB


In [6]:
# Start training through Server_Scaffold.train_loop, passing `rounds` and
# `EPOCHS` (both imported from src.constants).
server_scaffold.train_loop(rounds, EPOCHS)

GLOBAL Loss = 1.5713680287202199, Acc = 0.1092896174863388
1495.0576171875 MB
Round 0
--------------------------------------------------
1495.0576171875 MB
Client 0:
1582.638671875 MB
Epoch 0: Loss: 1.5053349634011586, Accuracy: 0.14207650273224043
1724.25048828125 MB
Epoch 1: Loss: 1.320368488629659, Accuracy: 0.5792349726775956
1724.25048828125 MB
Epoch 2: Loss: 1.2039523621400197, Accuracy: 0.6502732240437158
1724.25048828125 MB
Epoch 3: Loss: 1.1201504866282146, Accuracy: 0.6830601092896175
1638.16845703125 MB
Client 1:
1728.24951171875 MB
Epoch 0: Loss: 1.497679164012273, Accuracy: 0.15300546448087432
1863.923828125 MB
Epoch 1: Loss: 1.3308786749839783, Accuracy: 0.5792349726775956
1863.923828125 MB
Epoch 2: Loss: 1.1962503393491108, Accuracy: 0.6338797814207651
1863.923828125 MB
Epoch 3: Loss: 1.1442056745290756, Accuracy: 0.6830601092896175
1775.529296875 MB
Client 2:
1864.7353515625 MB
Epoch 0: Loss: 1.5232085684935253, Accuracy: 0.13114754098360656
2000.47216796875 MB
Epoch 1:

In [7]:
def fed_avg():
    """Run `rounds` rounds of federated averaging plus independent baselines.

    Each round:
      1. samples ``int(max(1, K_CLIENTS * C))`` entries of ``clients`` without
         replacement and sorts them;
      2. for every sampled client, trains a fresh ``Model_Retinopathy``
         initialised from the current global weights ("Federated"), and also
         trains that client's own model ``models_ind[client]``
         ("Independent") on the same loader;
      3. scales each federated client's weights by
         ``len(datasets[client]) / m_t``, where ``m_t`` is the total number of
         samples held by the sampled clients, and sums them into the new
         global weights;
      4. appends validation loss/accuracy of the global model to ``losses`` /
         ``accuracies`` and of each sampled independent model to
         ``losses_clients[client]`` / ``accuracies_clients[client]``, then
         calls ``save_plots()``.

    The function takes no arguments and returns nothing; it relies on these
    notebook-level names already being defined when it is called:
    ``global_model``, ``models_ind``, ``datasets``, ``loaders``,
    ``val_loader``, ``train_loop``, ``save_plots``, ``losses``,
    ``accuracies``, ``losses_clients``, ``accuracies_clients``.

    Raises:
        ValueError: if the sampled clients hold zero samples in total, since
            the weighting step would otherwise divide by zero.
    """
    m_clients = int(max(1, K_CLIENTS * C))
    for t in range(rounds):
        print("ROUND: ", t)
        print("-------------------------------" * 3)
        weight_global = global_model.get_weights()
        selected_clients = np.random.choice(clients, m_clients, replace=False)
        selected_clients.sort()

        # Total sample count across this round's clients; the denominator of
        # every client's averaging coefficient.
        m_t = sum(len(datasets[i]) for i in selected_clients)
        if m_t == 0:
            raise ValueError(
                f"Round {t}: selected clients {list(selected_clients)} hold no "
                "samples, cannot compute federated-averaging coefficients"
            )

        # Keyed by client id so only sampled clients have an entry and ids do
        # not have to be valid list positions.
        scaled_weights = {}
        for client in selected_clients:
            local_model = Model_Retinopathy().to(device)
            local_model.set_weights(global_model.get_weights())

            print("TRAINING CLIENT", client, "Federated")
            train_loop(
                local_model,
                loaders[client],
                val_loader,
                epochs=EPOCHS,
                lr=LEARNING_RATE,
                verbose=False,
            )
            print("TRAINING CLIENT", client, "Independent")
            train_loop(
                models_ind[client],
                loaders[client],
                val_loader,
                epochs=EPOCHS,
                lr=LEARNING_RATE,
                verbose=False,
            )

            # Pre-multiply by the client's share of the data so aggregation
            # below is a plain sum.
            client_weights = local_model.get_weights()
            n_samples = len(datasets[client])
            for layer in client_weights:
                client_weights[layer] = client_weights[layer] * n_samples / m_t
            scaled_weights[client] = client_weights

            losses_clients[client].append(models_ind[client].get_loss(val_loader))
            accuracies_clients[client].append(
                models_ind[client].get_accuracy(val_loader)
            )

        # Sum the already-scaled client weights layer by layer. The first
        # client seeds the total, so the previous global value of the layer
        # never contributes.
        first_client = selected_clients[0]
        for layer in weight_global:
            total = scaled_weights[first_client][layer]
            for client in selected_clients[1:]:
                total = scaled_weights[client][layer] + total
            weight_global[layer] = total

        global_model.set_weights(weight_global)
        loss = global_model.get_loss(val_loader)
        losses.append(loss)
        acc = global_model.get_accuracy(val_loader)
        accuracies.append(acc)
        print(f"Loss = {loss}, Acc = {acc}")

        save_plots()