In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import DataLoader
import numpy as np 

from models import MnistNet
from utils import train, test
from create_dataset import create_iid_dataset
from utils import load_partition, load_data
import copy

In [3]:
DATA_DIR = './datasets/fl_mnist_noniid/'
num_clients = 2
B = 32  # local mini-batch size for every client
SEED = 0  # makes client shuffling (and the model init in a later cell) reproducible

torch.manual_seed(SEED)

# One training partition per client; load_partition ids start at 1.
train_sets = [load_partition(cid=cid + 1, data_dir=DATA_DIR) for cid in range(num_clients)]
# Each client gets its own seeded generator so its shuffle order is reproducible
# regardless of how many batches the other clients draw.
train_loaders = [
    DataLoader(
        dataset,
        batch_size=B,
        shuffle=True,
        generator=torch.Generator().manual_seed(SEED + cid),
    )
    for cid, dataset in enumerate(train_sets)
]
# NOTE(review): the test set comes from load_data(), not from DATA_DIR. Confirm both
# helpers apply the same preprocessing; in the logged run the averaged model stays
# near chance accuracy on this loader.
_, test_loader = load_data()

  X_train = torch.tensor(X_train).type(torch.FloatTensor)


In [4]:
list(map(len, train_sets))  # number of training samples held by each client

[30000, 30000]

In [5]:
# Pick the best available accelerator instead of hard-coding Apple's MPS backend,
# so the notebook also runs on CUDA machines or on plain CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Global (server) model. No .to(device) is applied here; the training loop moves
# per-client deep copies to `device`.
net_glob = MnistNet()
net_glob.train()
w_glob = net_glob.state_dict()

In [6]:
# Bookkeeping placeholders. The names suggest loss history, cross-validation
# metrics and best-model tracking.
# NOTE(review): the federated training loop does not read any of these, and it
# re-creates val_acc_list itself.
loss_train = []
cv_loss = []
cv_acc = []
val_loss_pre = 0
counter = 0
net_best = None
best_loss = None
val_acc_list = []
net_list = []

In [7]:
round, local_epoch, lr = (
    10,    # communication rounds; NOTE(review): the name `round` shadows the builtin round()
    5,     # local epochs each client trains per round
    0.01,  # base learning rate; the training loop divides it by (1 + round index)
)

In [8]:
def FedAvg(w, weights=None):
    """Average a list of model state dicts (server-side FedAvg aggregation).

    Args:
        w: Non-empty list of state dicts with identical keys and tensor shapes,
            e.g. each client model's ``state_dict()`` after local training.
        weights: Optional per-client weights, typically each client's number of
            training samples as in the original FedAvg algorithm. When ``None``
            every client counts equally (plain arithmetic mean, original behavior).

    Returns:
        A new state dict (a deep copy of ``w[0]``'s container) holding the
        averaged tensors. The input dicts are not modified.

    Raises:
        ValueError: if ``w`` is empty, if ``weights`` has a different length than
            ``w``, or if the weights do not sum to a positive value.
    """
    if len(w) == 0:
        raise ValueError("FedAvg needs at least one state dict")

    # Deep copy so the accumulation below never mutates a client's tensors, and
    # so the returned container keeps w[0]'s type and attributes.
    w_avg = copy.deepcopy(w[0])

    if weights is None:
        # Unweighted mean: sum every client's tensor, then divide by the client count.
        for k in w_avg.keys():
            for i in range(1, len(w)):
                w_avg[k] += w[i][k]
            w_avg[k] = torch.div(w_avg[k], len(w))
        return w_avg

    if len(weights) != len(w):
        raise ValueError(f"got {len(weights)} weights for {len(w)} state dicts")
    total = float(sum(weights))
    if total <= 0:
        raise ValueError("FedAvg weights must sum to a positive value")

    for k in w_avg.keys():
        ref = w[0][k]
        # Accumulate in at least float32 so integer/half tensors average correctly,
        # then cast back so each entry keeps its original dtype.
        acc_dtype = torch.promote_types(ref.dtype, torch.float32)
        acc = torch.zeros_like(ref, dtype=acc_dtype)
        for w_i, coef in zip(w, weights):
            acc += w_i[k].to(acc_dtype) * (float(coef) / total)
        if not torch.is_floating_point(ref):
            acc = torch.round(acc)  # integer buffers: round instead of truncating
        w_avg[k] = acc.to(ref.dtype)
    return w_avg

In [9]:
# Federated training: each round, every client trains a copy of the global model
# on its own partition, then the server averages the client weights.
w_locals = [w_glob for _ in range(num_clients)]  # placeholders, overwritten every round
val_acc_list, val_loss_list = [], []
for rnd in range(round):
    idxs_clients = list(range(num_clients))  # every client participates in every round
    # Learning-rate decay (lr / (1 + round index)); this is not weight decay.
    round_lr = lr / (1 + rnd)
    for idx in idxs_clients:
        print(f'Client {idx} is training')
        train_local_loader = train_loaders[idx]
        # Each client starts from the current global weights, on the training device.
        net_local = copy.deepcopy(net_glob).to(device)
        train(net_local, train_local_loader, local_epoch, device, lr=round_lr)
        w_locals[idx] = net_local.state_dict()
        print('===============')

    # Server step: average the client weights and load them into the global model.
    w_glob = FedAvg(w_locals)
    net_glob.load_state_dict(w_glob)
    print(f'Round {rnd + 1} - Centralize Evaluation')
    # NOTE(review): in the logged run, clients reach ~0.9 local accuracy while the
    # averaged model stays around 0.1-0.15 on the test set. Check that the test
    # loader uses the same preprocessing as the client partitions.
    loss, acc, total = test(net_glob, test_loader, device)
    val_loss_list.append(loss)
    val_acc_list.append(acc)
    print(f'Num samples {total} | Loss {loss} | Acc {acc}')

Client 0 is training


 20%|██        | 1/5 [00:11<00:47, 11.84s/it]

Epoch 1: train loss inf, accuracy 0.3257333333333333


 40%|████      | 2/5 [00:23<00:35, 11.70s/it]

Epoch 2: train loss 0.03894336133797963, accuracy 0.398


 60%|██████    | 3/5 [00:35<00:23, 11.64s/it]

Epoch 3: train loss 0.03755306511521339, accuracy 0.4384


 80%|████████  | 4/5 [00:47<00:11, 11.80s/it]

Epoch 4: train loss 0.035680377423763274, accuracy 0.5334333333333333


100%|██████████| 5/5 [00:59<00:00, 11.82s/it]


Epoch 5: train loss 0.03605516842206319, accuracy 0.46413333333333334
Client 1 is training


 20%|██        | 1/5 [00:11<00:47, 11.87s/it]

Epoch 1: train loss 0.054709232779343926, accuracy 0.19786666666666666


 40%|████      | 2/5 [00:23<00:35, 11.88s/it]

Epoch 2: train loss 0.05273698762257894, accuracy 0.20313333333333333


 60%|██████    | 3/5 [00:35<00:23, 11.94s/it]

Epoch 3: train loss 0.052653967384497326, accuracy 0.1997


 80%|████████  | 4/5 [00:46<00:11, 11.46s/it]

Epoch 4: train loss 0.05261373739242554, accuracy 0.19873333333333335


100%|██████████| 5/5 [00:55<00:00, 11.09s/it]

Epoch 5: train loss 0.05257239559491476, accuracy 0.20056666666666667
Round 1 - Centralize Evaluation





Num samples 10000 | Loss 0.08366061129570007 | Acc 0.0982
Client 0 is training


 20%|██        | 1/5 [00:09<00:36,  9.20s/it]

Epoch 1: train loss 0.041857781573136646, accuracy 0.3862


 40%|████      | 2/5 [00:18<00:27,  9.24s/it]

Epoch 2: train loss 0.035772474626700086, accuracy 0.4844


 60%|██████    | 3/5 [00:28<00:18,  9.44s/it]

Epoch 3: train loss 0.027214984729886055, accuracy 0.6718666666666666


 80%|████████  | 4/5 [00:40<00:10, 10.49s/it]

Epoch 4: train loss 0.016231320372968913, accuracy 0.8307666666666667


100%|██████████| 5/5 [00:52<00:00, 10.46s/it]


Epoch 5: train loss 0.01423501363073786, accuracy 0.8559
Client 1 is training


 20%|██        | 1/5 [00:12<00:49, 12.36s/it]

Epoch 1: train loss 0.04634051198164622, accuracy 0.40846666666666664


 40%|████      | 2/5 [00:24<00:36, 12.20s/it]

Epoch 2: train loss 0.031189100904266038, accuracy 0.6542


 60%|██████    | 3/5 [00:36<00:24, 12.09s/it]

Epoch 3: train loss 0.023168710786600908, accuracy 0.7646


 80%|████████  | 4/5 [00:48<00:12, 12.06s/it]

Epoch 4: train loss 0.019745724546412626, accuracy 0.8047


100%|██████████| 5/5 [01:00<00:00, 12.07s/it]

Epoch 5: train loss 0.018526387937863667, accuracy 0.8199333333333333
Round 2 - Centralize Evaluation





Num samples 10000 | Loss 0.08056090676784515 | Acc 0.0983
Client 0 is training


 20%|██        | 1/5 [00:12<00:48, 12.15s/it]

Epoch 1: train loss 0.015247550216068825, accuracy 0.8550666666666666


 40%|████      | 2/5 [00:24<00:36, 12.17s/it]

Epoch 2: train loss 0.012136178851375978, accuracy 0.8854333333333333


 60%|██████    | 3/5 [00:36<00:24, 12.16s/it]

Epoch 3: train loss 0.011785851792742809, accuracy 0.8883


 80%|████████  | 4/5 [00:48<00:12, 12.11s/it]

Epoch 4: train loss 0.010806424609944224, accuracy 0.9001666666666667


100%|██████████| 5/5 [01:00<00:00, 12.11s/it]


Epoch 5: train loss 0.010525804562742512, accuracy 0.9026666666666666
Client 1 is training


 20%|██        | 1/5 [00:11<00:46, 11.63s/it]

Epoch 1: train loss 0.023615058527886867, accuracy 0.7632


 40%|████      | 2/5 [00:21<00:31, 10.35s/it]

Epoch 2: train loss 0.018106598128378393, accuracy 0.8197333333333333


 60%|██████    | 3/5 [00:30<00:19,  9.89s/it]

Epoch 3: train loss 0.017071339802940688, accuracy 0.8324333333333334


 80%|████████  | 4/5 [00:39<00:09,  9.61s/it]

Epoch 4: train loss 0.016011383214592934, accuracy 0.8406


100%|██████████| 5/5 [00:48<00:00,  9.76s/it]

Epoch 5: train loss 0.015791261475284896, accuracy 0.8431333333333333
Round 3 - Centralize Evaluation





Num samples 10000 | Loss 0.07722982017993928 | Acc 0.1203
Client 0 is training


 20%|██        | 1/5 [00:09<00:37,  9.36s/it]

Epoch 1: train loss 0.012005681972329815, accuracy 0.8896333333333334


 40%|████      | 2/5 [00:18<00:27,  9.27s/it]

Epoch 2: train loss 0.009943885834390919, accuracy 0.9091333333333333


 60%|██████    | 3/5 [00:27<00:18,  9.24s/it]

Epoch 3: train loss 0.00959137728822728, accuracy 0.9111


 80%|████████  | 4/5 [00:38<00:09,  9.96s/it]

Epoch 4: train loss 0.009189891709697744, accuracy 0.9152333333333333


100%|██████████| 5/5 [00:48<00:00,  9.71s/it]


Epoch 5: train loss 0.009007405311055481, accuracy 0.9173333333333333
Client 1 is training


 20%|██        | 1/5 [00:11<00:45, 11.42s/it]

Epoch 1: train loss 0.017582275432844956, accuracy 0.8295


 40%|████      | 2/5 [00:23<00:35, 11.76s/it]

Epoch 2: train loss 0.01451245742018024, accuracy 0.8564333333333334


 60%|██████    | 3/5 [00:35<00:23, 11.74s/it]

Epoch 3: train loss 0.013993458000694712, accuracy 0.8632666666666666


 80%|████████  | 4/5 [00:46<00:11, 11.42s/it]

Epoch 4: train loss 0.013950397908439239, accuracy 0.8634


100%|██████████| 5/5 [00:55<00:00, 11.08s/it]

Epoch 5: train loss 0.014189660334090392, accuracy 0.8588666666666667
Round 4 - Centralize Evaluation





Num samples 10000 | Loss 0.07734139859676362 | Acc 0.1325
Client 0 is training


 20%|██        | 1/5 [00:09<00:37,  9.43s/it]

Epoch 1: train loss 0.010435557153138021, accuracy 0.9021333333333333


 40%|████      | 2/5 [00:18<00:28,  9.51s/it]

Epoch 2: train loss 0.008816959737551708, accuracy 0.9184666666666667


 60%|██████    | 3/5 [00:29<00:20, 10.11s/it]

Epoch 3: train loss 0.008417046220352253, accuracy 0.9219666666666667


 80%|████████  | 4/5 [00:41<00:10, 10.91s/it]

Epoch 4: train loss 0.007983106633182614, accuracy 0.9265666666666666


100%|██████████| 5/5 [00:54<00:00, 10.86s/it]


Epoch 5: train loss 0.007714001622935757, accuracy 0.9279666666666667
Client 1 is training


 20%|██        | 1/5 [00:12<00:48, 12.01s/it]

Epoch 1: train loss 0.015121973108748596, accuracy 0.8534


 40%|████      | 2/5 [00:23<00:35, 11.91s/it]

Epoch 2: train loss 0.013384585217386485, accuracy 0.8681666666666666


 60%|██████    | 3/5 [00:35<00:23, 11.96s/it]

Epoch 3: train loss 0.013036164509753386, accuracy 0.8704666666666667


 80%|████████  | 4/5 [00:47<00:11, 11.96s/it]

Epoch 4: train loss 0.01279500408793489, accuracy 0.8732333333333333


100%|██████████| 5/5 [00:59<00:00, 11.98s/it]

Epoch 5: train loss 0.012456294353554645, accuracy 0.878
Round 5 - Centralize Evaluation





Num samples 10000 | Loss 0.07706052141189575 | Acc 0.1305
Client 0 is training


 20%|██        | 1/5 [00:11<00:47, 11.77s/it]

Epoch 1: train loss 0.008936371564989288, accuracy 0.9156333333333333


 40%|████      | 2/5 [00:23<00:35, 11.84s/it]

Epoch 2: train loss 0.0076070491214593255, accuracy 0.9278666666666666


 60%|██████    | 3/5 [00:35<00:23, 11.89s/it]

Epoch 3: train loss 0.00728126170138518, accuracy 0.931


 80%|████████  | 4/5 [00:47<00:11, 11.93s/it]

Epoch 4: train loss 0.006772160805979123, accuracy 0.9360666666666667


100%|██████████| 5/5 [00:59<00:00, 11.90s/it]


Epoch 5: train loss 0.006487900537035117, accuracy 0.9402
Client 1 is training


 20%|██        | 1/5 [00:11<00:47, 11.83s/it]

Epoch 1: train loss 0.01400797833080093, accuracy 0.8663666666666666


 40%|████      | 2/5 [00:23<00:35, 11.91s/it]

Epoch 2: train loss 0.01212767269536853, accuracy 0.8826666666666667


 60%|██████    | 3/5 [00:35<00:23, 11.93s/it]

Epoch 3: train loss 0.011805721516534686, accuracy 0.8858


 80%|████████  | 4/5 [00:47<00:11, 11.92s/it]

Epoch 4: train loss 0.011638955770929655, accuracy 0.8856333333333334


100%|██████████| 5/5 [00:59<00:00, 11.90s/it]

Epoch 5: train loss 0.011004196638117233, accuracy 0.8909333333333334
Round 6 - Centralize Evaluation





Num samples 10000 | Loss 0.07672494449615479 | Acc 0.1352
Client 0 is training


 20%|██        | 1/5 [00:12<00:48, 12.01s/it]

Epoch 1: train loss 0.007731869240229328, accuracy 0.9289


 40%|████      | 2/5 [00:23<00:35, 11.90s/it]

Epoch 2: train loss 0.006531660691757376, accuracy 0.9387666666666666


 60%|██████    | 3/5 [00:36<00:24, 12.25s/it]

Epoch 3: train loss 0.006163860742375254, accuracy 0.9416666666666667


 80%|████████  | 4/5 [00:49<00:12, 12.51s/it]

Epoch 4: train loss 0.0061386867006309334, accuracy 0.9432333333333334


100%|██████████| 5/5 [01:02<00:00, 12.51s/it]


Epoch 5: train loss 0.005856793928534413, accuracy 0.9451
Client 1 is training


 20%|██        | 1/5 [00:12<00:50, 12.52s/it]

Epoch 1: train loss 0.012437747675428788, accuracy 0.8811666666666667


 40%|████      | 2/5 [00:24<00:36, 12.18s/it]

Epoch 2: train loss 0.011090187410016855, accuracy 0.8926333333333333


 60%|██████    | 3/5 [00:36<00:24, 12.15s/it]

Epoch 3: train loss 0.010960629586999615, accuracy 0.8931333333333333


 80%|████████  | 4/5 [00:48<00:12, 12.06s/it]

Epoch 4: train loss 0.010408804128443201, accuracy 0.8971666666666667


100%|██████████| 5/5 [01:00<00:00, 12.12s/it]

Epoch 5: train loss 0.010325750071927904, accuracy 0.9022
Round 7 - Centralize Evaluation





Num samples 10000 | Loss 0.07521801011562347 | Acc 0.1441
Client 0 is training


 20%|██        | 1/5 [00:12<00:48, 12.21s/it]

Epoch 1: train loss 0.007246300582618763, accuracy 0.9335


 40%|████      | 2/5 [00:24<00:36, 12.15s/it]

Epoch 2: train loss 0.006138758607069031, accuracy 0.9431


 60%|██████    | 3/5 [00:36<00:24, 12.02s/it]

Epoch 3: train loss 0.005655267450067913, accuracy 0.9454


 80%|████████  | 4/5 [00:48<00:12, 12.00s/it]

Epoch 4: train loss 0.0057076254412299025, accuracy 0.9463333333333334


100%|██████████| 5/5 [01:00<00:00, 12.04s/it]


Epoch 5: train loss 0.005519295211640808, accuracy 0.9469
Client 1 is training


 20%|██        | 1/5 [00:11<00:47, 11.95s/it]

Epoch 1: train loss 0.01168156103802224, accuracy 0.8872333333333333


 40%|████      | 2/5 [00:23<00:35, 11.90s/it]

Epoch 2: train loss 0.010365772880117098, accuracy 0.9011666666666667


 60%|██████    | 3/5 [00:35<00:23, 11.92s/it]

Epoch 3: train loss 0.010285303761810064, accuracy 0.8991333333333333


 80%|████████  | 4/5 [00:47<00:11, 11.97s/it]

Epoch 4: train loss 0.010103911984153092, accuracy 0.9022333333333333


100%|██████████| 5/5 [00:59<00:00, 11.95s/it]

Epoch 5: train loss 0.009618424387462438, accuracy 0.9045
Round 8 - Centralize Evaluation





Num samples 10000 | Loss 0.07526814022064209 | Acc 0.1463
Client 0 is training


 20%|██        | 1/5 [00:11<00:47, 11.88s/it]

Epoch 1: train loss 0.006734496789022038, accuracy 0.9367333333333333


 40%|████      | 2/5 [00:23<00:35, 11.91s/it]

Epoch 2: train loss 0.0056526163374035, accuracy 0.9472333333333334


 60%|██████    | 3/5 [00:35<00:23, 11.94s/it]

Epoch 3: train loss 0.005314883974116917, accuracy 0.9497666666666666


 80%|████████  | 4/5 [00:47<00:11, 11.95s/it]

Epoch 4: train loss 0.005064097034228811, accuracy 0.9516


100%|██████████| 5/5 [00:59<00:00, 11.92s/it]


Epoch 5: train loss 0.005041513962011474, accuracy 0.9532333333333334
Client 1 is training


 20%|██        | 1/5 [00:09<00:36,  9.19s/it]

Epoch 1: train loss 0.010801703441888093, accuracy 0.8979333333333334


 40%|████      | 2/5 [00:18<00:27,  9.08s/it]

Epoch 2: train loss 0.009832557124768694, accuracy 0.9045666666666666


 60%|██████    | 3/5 [00:27<00:18,  9.04s/it]

Epoch 3: train loss 0.009525445023303231, accuracy 0.9094


 80%|████████  | 4/5 [00:36<00:08,  8.99s/it]

Epoch 4: train loss 0.009337886693701149, accuracy 0.9086333333333333


100%|██████████| 5/5 [00:45<00:00,  9.01s/it]

Epoch 5: train loss 0.009159977478409806, accuracy 0.9106
Round 9 - Centralize Evaluation





Num samples 10000 | Loss 0.07455817804336548 | Acc 0.1517
Client 0 is training


 20%|██        | 1/5 [00:09<00:36,  9.01s/it]

Epoch 1: train loss 0.006376589043609177, accuracy 0.9395333333333333


 40%|████      | 2/5 [00:17<00:26,  8.94s/it]

Epoch 2: train loss 0.005222002388280816, accuracy 0.9500666666666666


 60%|██████    | 3/5 [00:26<00:17,  8.93s/it]

Epoch 3: train loss 0.004941313981823623, accuracy 0.9535333333333333


 80%|████████  | 4/5 [00:35<00:08,  8.97s/it]

Epoch 4: train loss 0.004835334926005453, accuracy 0.9534


100%|██████████| 5/5 [00:44<00:00,  8.95s/it]


Epoch 5: train loss 0.004653562675012896, accuracy 0.9549666666666666
Client 1 is training


 20%|██        | 1/5 [00:08<00:35,  8.96s/it]

Epoch 1: train loss 0.010495063734302918, accuracy 0.9015666666666666


 40%|████      | 2/5 [00:17<00:26,  8.96s/it]

Epoch 2: train loss 0.009591942767240107, accuracy 0.9077333333333333


 60%|██████    | 3/5 [00:26<00:17,  8.95s/it]

Epoch 3: train loss 0.009273194531351328, accuracy 0.9116333333333333


 80%|████████  | 4/5 [00:35<00:08,  8.93s/it]

Epoch 4: train loss 0.008979927549511194, accuracy 0.9131


100%|██████████| 5/5 [00:44<00:00,  8.94s/it]

Epoch 5: train loss 0.008964624838158488, accuracy 0.9134
Round 10 - Centralize Evaluation





Num samples 10000 | Loss 0.07438940155506134 | Acc 0.1538
