In [1]:
import logging
import random

import torch
from torch import nn

from fedavg import utils
from fedavg.model import Net
from fedavg.node import Node

In [2]:
# The training progress messages are emitted under the root logger (they show
# up as "INFO:root:..." in the output below), so lower its threshold to INFO.
logger = logging.getLogger()
logger.setLevel(level=logging.INFO)

In [3]:
# Reproducibility: client selection changes from round to round (see the
# "Selected Clients" log lines), and the data loaders / Net() weights are
# presumably drawn from torch's RNG. Seed before any loader or model is built
# so that Restart & Run All reproduces the same run.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CPU and CUDA
# NOTE(review): if fedavg samples clients or splits data with NumPy, seed
# numpy.random here as well — TODO confirm which RNG fedavg uses.

# Federated-averaging settings shared by the server and the clients.
config = {
    "participant_n": 10,  # clients per round (each round's log selects 10)
    "client_n": 50,  # total clients; one training loader is created per client
    "criterion": nn.CrossEntropyLoss(),  # one loss object shared by every Node
}

In [4]:
# Per-client MNIST training loaders, config["client_n"] of them; each loader
# becomes the training data of one client Node in the next cells.
train_loaders = utils.split_mnist_to_loaders(
    config["client_n"]
)

In [5]:
# MNIST test loader, handed to the server below for evaluation.
test_loader = (
    utils.get_mnist_test_loader()
)

In [6]:
# One client Node per training loader. The node id is the loader's index
# (0 .. client_n - 1), which matches the client ids in the training log.
# Every client gets a freshly constructed Net and shares the configured loss.
clients = [
    Node(
        idx,
        model=Net(),
        train_loader=client_loader,
        criterion=config["criterion"],
    )
    for idx, client_loader in enumerate(train_loaders)
]

In [7]:
# The server Node holds the global model and the client list. Each round it
# selects clients (config["participant_n"] of them, judging by the log) and
# evaluates on the test loader.
# Its id lies beyond the client id range 0 .. len(clients) - 1, so it cannot
# collide with a client id (id len(clients) itself is simply unused).
server = Node(
    node_id=len(clients) + 1,
    model=Net(),
    clients=clients,
    test_loader=test_loader,
    criterion=config["criterion"],
    config=config,
)

In [8]:
# Baseline before any federated round: expected to be near chance level.
print(f"Initial test accuracy: {server.test()} %")

Initial test accuracy: 9.83 %


In [9]:
# Federated training. Each server.round call logs the selected clients, each
# client's average loss and the round's test accuracy (see the output below).
rounds = 30  # number of communication rounds
for round_number in range(1, rounds + 1):  # 1-based, matching the "Round N" log lines
    server.round(round_number)

INFO:root:Round 1 Starts.
INFO:root:Selected Clients: [12, 44, 1, 17, 19, 26, 7, 29, 25, 34]
INFO:root:Client 12 Avg Loss: 0.036
INFO:root:Client 44 Avg Loss: 0.036
INFO:root:Client 1 Avg Loss: 0.036
INFO:root:Client 17 Avg Loss: 0.036
INFO:root:Client 19 Avg Loss: 0.036
INFO:root:Client 26 Avg Loss: 0.036
INFO:root:Client 7 Avg Loss: 0.036
INFO:root:Client 29 Avg Loss: 0.036
INFO:root:Client 25 Avg Loss: 0.036
INFO:root:Client 34 Avg Loss: 0.036
INFO:root:Round 1 Test Accuracy: 13.42 %
INFO:root:Round 2 Starts.
INFO:root:Selected Clients: [11, 41, 30, 28, 37, 46, 7, 38, 42, 36]
INFO:root:Client 11 Avg Loss: 0.036
INFO:root:Client 41 Avg Loss: 0.036
INFO:root:Client 30 Avg Loss: 0.036
INFO:root:Client 28 Avg Loss: 0.036
INFO:root:Client 37 Avg Loss: 0.036
INFO:root:Client 46 Avg Loss: 0.036
INFO:root:Client 7 Avg Loss: 0.036
INFO:root:Client 38 Avg Loss: 0.036
INFO:root:Client 42 Avg Loss: 0.036
INFO:root:Client 36 Avg Loss: 0.036
INFO:root:Round 2 Test Accuracy: 20.88 %
INFO:root:Roun

INFO:root:Round 16 Starts.
INFO:root:Selected Clients: [11, 31, 45, 44, 33, 30, 4, 23, 19, 34]
INFO:root:Client 11 Avg Loss: 0.009
INFO:root:Client 31 Avg Loss: 0.010
INFO:root:Client 45 Avg Loss: 0.010
INFO:root:Client 44 Avg Loss: 0.010
INFO:root:Client 33 Avg Loss: 0.011
INFO:root:Client 30 Avg Loss: 0.010
INFO:root:Client 4 Avg Loss: 0.010
INFO:root:Client 23 Avg Loss: 0.010
INFO:root:Client 19 Avg Loss: 0.010
INFO:root:Client 34 Avg Loss: 0.010
INFO:root:Round 16 Test Accuracy: 83.41 %
INFO:root:Round 17 Starts.
INFO:root:Selected Clients: [33, 39, 22, 37, 41, 16, 23, 46, 29, 49]
INFO:root:Client 33 Avg Loss: 0.010
INFO:root:Client 39 Avg Loss: 0.009
INFO:root:Client 22 Avg Loss: 0.010
INFO:root:Client 37 Avg Loss: 0.009
INFO:root:Client 41 Avg Loss: 0.010
INFO:root:Client 16 Avg Loss: 0.009
INFO:root:Client 23 Avg Loss: 0.010
INFO:root:Client 46 Avg Loss: 0.010
INFO:root:Client 29 Avg Loss: 0.009
INFO:root:Client 49 Avg Loss: 0.009
INFO:root:Round 17 Test Accuracy: 84.68 %
INFO:r

In [10]:
# Accuracy of the global model after all federated rounds.
print(f"Test accuracy after training: {server.test()} %")

Test accuracy after training: 91.14 %
