In [None]:
import os

import numpy as np
import pandas as pd
import torch
from torchsummary import summary

from dataloader import get_dataset
from models_server import ResNet50
# from models_server_bigResNet import ResNet50
from models_client import ResNet8

from server import GKTServerTrainer
from client import GKTClientTrainer

In [None]:
# ---- experiment configuration ----

# Data partitioning
iid = 1                    # 1 -> i.i.d. split of the data, 0 -> non-i.i.d.
unbalanced = 0             # non-i.i.d. only: whether clients get unequal shares
num_users = 100            # total number of clients
client_number = num_users  # number of workers; the data is partitioned into this many groups
partition_alpha = 0.5      # forwarded to each client trainer (meaning defined there)

# Federated schedule
communication_rounds = 2   # number of server <-> client communication rounds
frac = 0.1                 # fraction of clients sampled for each federated update
server_epochs = 3          # epochs value handed to the server trainer
client_epochs = 1          # epochs value handed to each client trainer

# Optimisation
optimizer = "sgd"          # "sgd" or "adam"
lr = 0.001                 # learning rate
local_batch_size = 10      # batch size of the local updates on each client
loss_function = "CrossEntropyLoss"
temperature = 3.0          # forwarded to server and client trainers
                           # (presumably the distillation temperature -- confirm)

# Model / hardware
normalization_type = "GroupNorm"  # "BatchNorm" or "GroupNorm"
gpu = 0                    # GPU flag; reassigned by the device-detection cell

In [None]:
# Load the train/test datasets together with the per-client grouping
# (`user_groups` is later indexed by client id).
train_dataset, test_dataset, user_groups = get_dataset(
    iid=iid,
    unbalanced=unbalanced,
    num_users=num_users,
)

In [None]:
def create_client_model():
    """Build the client-side network.

    Returns:
        A new ``ResNet8`` built with its default constructor arguments.
        Note that ``normalization_type`` is not forwarded to the model here.
    """
    return ResNet8()

def create_server_model():
    """Build the server-side network.

    Returns:
        A new ``ResNet50`` built with its default constructor arguments.
        Per the original author's note it is effectively a ResNet49, because
        the first layer is done by the client's ResNet8.
    """
    return ResNet50()

In [None]:
# Instantiate both networks; the server model is still constructed first.
server_model, client_model = create_server_model(), create_client_model()

In [None]:
# Select the compute device; `gpu` mirrors the choice as a 0/1 flag that is
# later handed to the client trainers.
train_on_gpu = torch.cuda.is_available()

if train_on_gpu:
    print('CUDA is available!  Training on GPU ...')
    device = torch.device("cuda")
    gpu = 1
else:
    print('CUDA is not available.  Training on CPU ...')
    device = torch.device("cpu")
    gpu = 0

# Server network: move to the selected device and switch to training mode.
server_model.to(device)
server_model.train()

# Client network: same treatment.
client_model.to(device)
client_model.train()

In [None]:
# 23,503,626 parameters (ResNet49)
# 23,520,842 parameters (ResNet50)
# summary(server_model, (16, 32, 32))

In [None]:
# summary(client_model, (3, 32, 32))

In [None]:
# Server-side trainer wrapping the large model.
server_trainer = GKTServerTrainer(
    server_model, num_users, lr, server_epochs, device, optimizer, temperature
)

# Every client gets a trainer up front; the subset that actually trains is
# re-sampled in each communication round.
idxs_users = range(num_users)
clients_trainer = []  # clients_trainer[i] is the trainer of client i

# NOTE(review): all trainers receive the very same `client_model` object --
# confirm whether GKTClientTrainer copies it or the weights are meant to be shared.
for idx in idxs_users:
    client_trainer = GKTClientTrainer(
        client_model, train_dataset, test_dataset,
        user_groups[idx], idx, gpu, optimizer, local_batch_size,
        lr, client_epochs, temperature, partition_alpha,
    )
    clients_trainer.append(client_trainer)

# Main federated loop: in every round a random subset of clients trains
# locally, uploads its results to the server, the server trains on them, and
# the server's logits are sent back to those same clients.
for communication_round in range(1, communication_rounds+1):
    print(f'\nCommunication Round: {communication_round} \n')

    # number of users to be used for federated updates, at least 1
    m = max(int(frac * num_users), 1)
    # sample m distinct clients for this round
    idxs_chosen_users = np.random.choice(range(num_users), m, replace=False)

    print(idxs_chosen_users)
    for idx in idxs_chosen_users:
        # Local training on the client. The target list is parenthesised so
        # the five-way unpacking can span two lines; without the parentheses
        # the first line is a bare tuple expression that raises NameError.
        (extracted_feature_dict, logits_dict, labels_dict,
         extracted_feature_dict_test, labels_dict_test) = clients_trainer[idx].train()

        # send client result to server
        server_trainer.add_local_trained_result(idx, extracted_feature_dict, logits_dict, labels_dict,
                                                extracted_feature_dict_test, labels_dict_test)

    # NOTE: the "wait until all clients have reported" check
    # (server_trainer.check_whether_all_receive()) is intentionally not used;
    # the server trains on the clients sampled in this round only.
    server_trainer.train(communication_round, idxs_chosen_users)

    for idx in idxs_chosen_users:
        # get global logits
        global_logits = server_trainer.get_global_logits(idx)

        # send global logits to client
        clients_trainer[idx].update_large_model_logits(global_logits)

# get lists of train loss and accuracy
train_loss, train_accuracy = server_trainer.get_loss_acc_list()

In [None]:
# Signal the end of training (fixes the misspelled "Exectution").
print("Execution completed!")

# TODO: print a summary of the parameters used

# TODO: print the last accuracy

# TODO: add some graphs

In [None]:
# Save the training loss and accuracy lists to a CSV whose name encodes the
# hyper-parameters of this run.
file_name = 'fedGKT_results/{}_{}_{}_lr_[{}]_C[{}]_iid[{}]_Es[{}]_Ec[{}]_B[{}]_{}_unbalanced[{}].csv'.\
    format("ResNet50", normalization_type, communication_rounds, lr, frac, iid,
           server_epochs, client_epochs, local_batch_size, optimizer, unbalanced)

# DataFrame.to_csv does not create missing parent directories, so make sure
# the output folder exists before writing (no-op when it is already there).
os.makedirs(os.path.dirname(file_name), exist_ok=True)

# one row per recorded (loss, accuracy) pair
data = list(zip(train_loss, train_accuracy))
pd.DataFrame(data, columns=['train_loss','train_accuracy']).to_csv(file_name)