In [1]:
import torch
import torchvision
import pandas as pd
import os
import copy

In [2]:
# Move into the aggregator source directory. The project-local imports in the
# following cells and the '../../...' paths in `args` presumably rely on this
# being the working directory.
#
# A relative chdir is not idempotent: re-running this cell would try to move a
# second time (and either fail or land somewhere else). Skip the move when the
# kernel is already inside src/aggregator.
if not os.path.normpath(os.getcwd()).endswith(os.path.join('src', 'aggregator')):
    os.chdir('../src/aggregator')

In [3]:
# args
# Single configuration dict that is passed to every project helper called in
# this notebook (model construction, data loaders, train, test, aggregator).
args = {
    "architecture": 'TwoLayerNet',
    "test_batch_size": 1000,
    # Number of outer aggregation rounds run by the training cell.
    "agg_epochs": 10,
    # Use the first GPU when one is available, otherwise fall back to CPU so
    # the notebook still runs on machines without CUDA.
    "device": "cuda:0" if torch.cuda.is_available() else "cpu",
    "momentum": 0.5,
    "log_interval": 50,
    "aggregator": "fedavg",
    "agg_iterations": 50,
    # NOTE(review): an optimizer name stored under "agg_optim_lr" and 0.01
    # stored under "agg_optim_momentum" looks like the values are shifted by
    # one key (e.g. an "agg_optim" entry may be missing). Left unchanged
    # because the code that consumes these keys is not visible here - confirm.
    "agg_optim_lr": 'Adam',
    "agg_optim_momentum": 0.01,
    "wandb": False,
    # Number of simulated nodes; the data-loading cell reads one entry of
    # train_csv_list per node.
    "num_of_nodes": 1,
    "test_csv": '../../csv/test.csv',
    "data_location": '../../x-ray',
    "aggregated_model_location": '../../aggregated_model/',
    "labels": ['NORMAL', 'PNEUMONIA'],
    "image_dim": (28, 28),
    # Learning rate of the per-node Adam optimizer created in the training cell.
    "lr": 0.001,
    # Local training epochs each node runs per aggregation round.
    "epochs": 1,
    "train_batch_size": 16,
    "node_name": 'node_0',
    "app_ip": 'localhost',
    # Overwritten per node by the data-loading cell with train_csv_list[i].
    "train_csv": '../../csv/train.csv',
    "model_location": '../../node_model/',
}

# One training CSV per simulated node, indexed by node id; needs at least
# args["num_of_nodes"] entries.
train_csv_list = ['../../csv/original_train.csv']

In [4]:
from modelloader import createInitialModel, getModelArchitecture, loadModel

In [5]:
# init model - agg model
# Build the starting aggregated (global) model from the run configuration
# (presumably selected by args["architecture"] - confirm in modelloader).
# The training cell deep-copies this model for every node and rebinds
# `agg_model` to the aggregator's output after each aggregation round.
agg_model = getModelArchitecture(args)

In [6]:
# create dataloader for n nodes
from dataloader import getNumSamples, getTestLoader, getTrainLoader

# Every simulated node needs its own training CSV. Check this up front so a
# misconfiguration gives a clear message instead of an IndexError part-way
# through loading.
num_nodes = args['num_of_nodes']
if num_nodes > len(train_csv_list):
    raise ValueError(
        "num_of_nodes is {} but train_csv_list only has {} entr{}".format(
            num_nodes, len(train_csv_list), "y" if len(train_csv_list) == 1 else "ies"
        )
    )

trainloaders = []  # one training loader per node, indexed by node id
node_samples = []  # getNumSamples(args) result per node, paired with its model later
for i in range(num_nodes):
    # The loader helpers take the whole args dict, so the node's CSV is chosen
    # by overwriting args['train_csv'] before each call. Note that args keeps
    # the last node's CSV after this loop.
    args['train_csv'] = train_csv_list[i]
    trainloaders.append(getTrainLoader(args))
    node_samples.append(getNumSamples(args))

# A single test loader is shared by all nodes and by the aggregated model.
test_loader = getTestLoader(args)
print(trainloaders)

-----------------
['NORMAL', 'PNEUMONIA']
[<torch.utils.data.dataloader.DataLoader object at 0x7f2417741590>]


In [7]:
# def compare_models(model_1, model_2):
#     models_differ = 0
#     for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
#         if torch.equal(key_item_1[1], key_item_2[1]):
#             pass
#         else:
#             models_differ += 1
#             if (key_item_1[0] == key_item_2[0]):
#                 print('Mismatch found at', key_item_1[0])
#             else:
#                 raise Exception
#     if models_differ == 0:
#         print('Models match perfectly! :)')

In [12]:
# train n nodes 
import train
import testing
from checks import compare_models
from torch import optim
from aggregatorloader import selectAggregator

# Simulated federated training: in every aggregation round each node trains a
# private copy of the current aggregated model on its own data, then the
# aggregator combines the node models into the next aggregated model.
aggregator = selectAggregator(args)

# Destination of the aggregated model; overwritten after every round.
agg_model_path = os.path.join(args['aggregated_model_location'], 'agg_model.pt')

for agg_epoch in range(args['agg_epochs']):
    # (model, sample_count) pairs collected from the nodes in this round.
    node_models = []
    for i in range(args['num_of_nodes']):
        # Each node starts the round from the current aggregated weights and
        # gets a fresh optimizer, so no optimizer state carries across rounds.
        local_model = copy.deepcopy(agg_model)
        optimizer = optim.Adam(local_model.parameters(), lr=args['lr'])

        for epoch in range(1, args['epochs'] + 1):
            train.train(
                logger=None,
                args=args,
                model=local_model,
                train_loader=trainloaders[i],
                optimizer=optimizer,
                epoch=epoch,
            )
        testing.test(args, local_model, test_loader, logger=None)

        # Collect inside the per-node loop. Previously this line sat one level
        # out, so with more than one node only the last node's model was
        # passed to the aggregator.
        node_models.append((local_model, node_samples[i]))

    agg_model = aggregator(node_models, args)

    # Sanity check against node 0. Presumably only expected to report a
    # perfect match when a single node is simulated - confirm in checks.py.
    compare_models(agg_model, node_models[0][0])

    # test model performance
    testing.test(args, agg_model, test_loader, logger=None)

    # save model
    torch.save(agg_model, agg_model_path)

100%|██████████| 326/326 [01:10<00:00,  4.64it/s]


total_loss: 0.013133432377866138

Test set: Average loss: 0.0012, Accuracy: 475/624 (76%)

Precision:0.7295238095238096	Recall:0.982051282051282	F1:0.8371584699453551
Models match perfectly! :)


  0%|          | 0/326 [00:00<?, ?it/s]


Test set: Average loss: 0.0012, Accuracy: 475/624 (76%)

Precision:0.7295238095238096	Recall:0.982051282051282	F1:0.8371584699453551


100%|██████████| 326/326 [01:10<00:00,  4.61it/s]


total_loss: 0.006560861089539484

Test set: Average loss: 0.0026, Accuracy: 430/624 (69%)

Precision:0.6683848797250859	Recall:0.9974358974358974	F1:0.8004115226337449
Models match perfectly! :)


  0%|          | 0/326 [00:00<?, ?it/s]


Test set: Average loss: 0.0026, Accuracy: 430/624 (69%)

Precision:0.6683848797250859	Recall:0.9974358974358974	F1:0.8004115226337449


 61%|██████    | 199/326 [00:43<00:28,  4.53it/s]


KeyboardInterrupt: 