In [1]:
import torch
from torchsummary import summary

import numpy as np

# from dataloader import get_dataset
from models_server import ResNet50
# from models_server_bigResNet import ResNet50
from models_client import ResNet8

from server import GKTServerTrainer
from client import GKTClientTrainer

In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration: every tunable value of the run lives in this cell.
# ---------------------------------------------------------------------------

# -- model / loss --
normalization_type = "BatchNorm"  # accepted values: "BatchNorm" or "GroupNorm"
loss_function = "CrossEntropyLoss"
temperature = 3.0  # handed to both the server trainer and the client trainers

# -- data partitioning --
iid = 0  # truthy -> the data is i.i.d., falsy -> it is not
unbalanced = 0  # non-i.i.d. setting only: split the data between clients equally or not
num_users = 100  # number of clients
partition_alpha = 0.5
# number of workers in a distributed cluster;
# the data will be partitioned in client_number groups
client_number = num_users

# -- federated schedule --
frac = 0.1  # fraction of the clients used for federated updates in each round
communication_rounds = 10  # number of communication rounds
server_epochs = 10
client_epochs = 1

# -- optimisation --
optimizer = "sgd"  # "sgd" or "adam"
lr = 0.001  # learning rate
local_batch_size = 128  # batch size of the local updates run by each user

# -- hardware --
gpu = 0  # initial value; the device-selection cell assigns it again

In [3]:
# Two helper modules expose the same three names; which one is loaded depends
# on whether the i.i.d. flag is set.
if not iid:
    from utils import get_dataset, average_weights, exp_details
else:
    from utils_v2 import get_dataset, average_weights, exp_details

In [4]:
# Fetch the train/test datasets together with the per-user groups
# (user_groups is later indexed by client id when the client trainers are built).
train_dataset, test_dataset, user_groups = get_dataset(
    iid=iid,
    unbalanced=unbalanced,
    num_users=num_users,
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
def create_client_model():
    """Build the client-side network.

    The model is constructed with ``ResNet8``'s default arguments; the
    configured normalization type is not forwarded to the constructor.

    Returns:
        A newly constructed ``ResNet8`` instance.
    """
    return ResNet8()

def create_server_model():
    """Build the server-side network.

    Per the original author's note, this is effectively a ResNet49: the first
    layer is done by the client-side ResNet8.

    Returns:
        A newly constructed ``ResNet50`` instance.
    """
    return ResNet50()

In [6]:
# Build both networks in one statement; the right-hand side is evaluated left
# to right, so the server model is still constructed before the client model.
server_model, client_model = create_server_model(), create_client_model()

In [7]:
# Select the compute device from CUDA availability and record the choice in
# `gpu` (1 when CUDA is used, 0 otherwise).
train_on_gpu = torch.cuda.is_available()

if train_on_gpu:
    print('CUDA is available!  Training on GPU ...')
    device = torch.device("cuda")
    gpu = 1
else:
    print('CUDA is not available.  Training on CPU ...')
    device = torch.device("cpu")
    gpu = 0

# Move both networks onto the selected device.
server_model.to(device)
client_model.to(device)

# Put both networks in training mode. The client call is deliberately the last
# expression of the cell, so its return value is what the notebook displays.
server_model.train()
client_model.train()

CUDA is available!  Training on GPU ...


ResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1),

In [8]:
# 23,503,626 parameters (ResNet49)
# 23,520,842 parameters (ResNet50)
# summary(server_model, (16, 32, 32))

In [9]:
# summary(client_model, (3, 32, 32))

In [10]:
# --- server side -----------------------------------------------------------
server_trainer = GKTServerTrainer(
    server_model, num_users, lr, server_epochs, device, optimizer, temperature
)

# --- client side -----------------------------------------------------------
# One trainer per client id; every trainer is handed the same client_model
# object and its own entry of user_groups.
idxs_users = range(num_users)
clients_trainer = [
    GKTClientTrainer(
        client_model, train_dataset, test_dataset,
        user_groups[idx], idx, gpu, optimizer, local_batch_size,
        lr, client_epochs, temperature, partition_alpha,
    )
    for idx in idxs_users
]

# Number of clients taking part in each round (never fewer than one). The
# value does not change between rounds, so it is computed once up front.
m = max(int(frac * num_users), 1)

for communication_round in range(1, communication_rounds + 1):
    print(f'\nCommunication Round: {communication_round} \n')

    # Draw this round's participants at random, without repetition.
    idxs_chosen_users = np.random.choice(range(num_users), m, replace=False)
    print(idxs_chosen_users)

    # Phase 1: each selected client trains, and everything it returns is
    # registered with the server under that client's id.
    for idx in idxs_chosen_users:
        (
            extracted_feature_dict,
            logits_dict,
            labels_dict,
            extracted_feature_dict_test,
            labels_dict_test,
        ) = clients_trainer[idx].train()

        server_trainer.add_local_trained_result(
            idx,
            extracted_feature_dict,
            logits_dict,
            labels_dict,
            extracted_feature_dict_test,
            labels_dict_test,
        )

    # Phase 2: the server trains for this round on the selected clients.
    server_trainer.train(communication_round, idxs_chosen_users)

    # Phase 3: the server's global logits go back to each selected client.
    for idx in idxs_chosen_users:
        clients_trainer[idx].update_large_model_logits(
            server_trainer.get_global_logits(idx)
        )

# Collect the lists of train loss and accuracy kept by the server trainer.
train_loss, train_accuracy = server_trainer.get_loss_acc_list()


Communication Round: 1 

[99 44 90 32 42 83 29 26 58 10]


  return torch.tensor(image), torch.tensor(label)


client 99 - Update Epoch: 1 - Loss: 1.808949
add model - client_id = 99
client 44 - Update Epoch: 1 - Loss: 2.372289
add model - client_id = 44
client 90 - Update Epoch: 1 - Loss: 2.290651
add model - client_id = 90
client 32 - Update Epoch: 1 - Loss: 1.983563
add model - client_id = 32
client 42 - Update Epoch: 1 - Loss: 2.409932
add model - client_id = 42
client 83 - Update Epoch: 1 - Loss: 2.324446
add model - client_id = 83
client 29 - Update Epoch: 1 - Loss: 2.193654
add model - client_id = 29
client 26 - Update Epoch: 1 - Loss: 2.182953
add model - client_id = 26
client 58 - Update Epoch: 1 - Loss: 2.241385
add model - client_id = 58
client 10 - Update Epoch: 1 - Loss: 2.819190
add model - client_id = 10
train_and_eval - round_idx = 1, epoch = 1

- TRAIN METRICS: train_loss: 3.441 ; train_accuracy: 28.912

train_and_eval - round_idx = 1, epoch = 2

- TRAIN METRICS: train_loss: 3.129 ; train_accuracy: 8.917

train_and_eval - round_idx = 1, epoch = 3

- TRAIN METRICS: train_loss: 2




Communication Round: 2 

[22 98 40 50 14 97 12 55 88  6]
client 22 - Update Epoch: 1 - Loss: 2.488237
add model - client_id = 22
client 98 - Update Epoch: 1 - Loss: 2.141943
add model - client_id = 98
client 40 - Update Epoch: 1 - Loss: 2.286222
add model - client_id = 40
client 50 - Update Epoch: 1 - Loss: 2.450955
add model - client_id = 50
client 14 - Update Epoch: 1 - Loss: 2.445546
add model - client_id = 14
client 97 - Update Epoch: 1 - Loss: 1.978880
add model - client_id = 97
client 12 - Update Epoch: 1 - Loss: 2.151867
add model - client_id = 12
client 55 - Update Epoch: 1 - Loss: 1.872464
add model - client_id = 55
client 88 - Update Epoch: 1 - Loss: 2.209647
add model - client_id = 88
client 6 - Update Epoch: 1 - Loss: 2.571182
add model - client_id = 6
train_and_eval - round_idx = 2, epoch = 1

- TRAIN METRICS: train_loss: 2.527 ; train_accuracy: 26.533

train_and_eval - round_idx = 2, epoch = 2

- TRAIN METRICS: train_loss: 2.077 ; train_accuracy: 30.925

train_and_eval -

In [11]:
# Marker that the training cell above has finished.
# NOTE(review): the message below misspells "Execution"; it is left untouched
# here because it is runtime output, not a comment.
print("Exectution completed!")

# TODO: print a summary of the parameters used for this run

# TODO: print the last recorded accuracy

# TODO: plot some graphs of the results

Exectution completed!


In [12]:
# Save the train loss and accuracy lists to a CSV file whose name encodes the
# run configuration.
from pathlib import Path

import pandas as pd

file_name = 'fegGKT_results/{}_{}_{}_lr_[{}]_C[{}]_iid[{}]_Es[{}]_Ec[{}]_B[{}]_{}_unbalanced[{}].csv'.\
    format("ResNet50", normalization_type, communication_rounds, lr, frac, iid,
           server_epochs, client_epochs, local_batch_size, optimizer, unbalanced)

# DataFrame.to_csv does not create missing parent directories, so make sure
# the results folder exists before writing; otherwise the write fails on a
# fresh checkout after the whole training run has completed.
Path(file_name).parent.mkdir(parents=True, exist_ok=True)

# One row per recorded (loss, accuracy) pair.
data = list(zip(train_loss, train_accuracy))
pd.DataFrame(data, columns=['train_loss','train_accuracy']).to_csv(file_name)