In [13]:
# implementation of FedGKT
# Group Knowledge Transfer: Federated Learning of Large CNNs at the Edge by Chaoyang He, Murali Annavaram, Salman Avestimehr
# https://papers.nips.cc/paper/2020/hash/a1d4c20b182ad7137ab3606f0e3fc8a4-Abstract.html
# https://github.com/FedML-AI/FedML/tree/master/fedml_experiments/distributed/fedgkt

import torch
from torchsummary import summary

import numpy as np

# from dataloader import get_dataset
from models_server import ResNet50
# from models_server_bigResNet import ResNet50
from models_client import ResNet8

from server import GKTServerTrainer
from client import GKTClientTrainer

In [14]:
# parameters
iid = 0 # if the data is i.i.d or not
unbalanced = 0 # in non i.i.d. setting split the data between clients equally or not
num_users = 100 # number of client
frac = 0.1 # fraction of the clients to be used for federated updates
server_epochs = 10
gpu = 0
optimizer = "sgd" #sgd or adam
local_batch_size = 128 # batch size of local updates in each user
lr = 0.001 # learning rate
client_epochs = 1
loss_function = "CrossEntropyLoss"

partition_alpha = 0.5
client_number = num_users  # number of workers in a distributed cluster
# the data will be partitioned in client_number groups

temperature = 3.0

communication_rounds = 10  # number of communication rounds

num_groups = 0  # 0 for BatchNorm, > 0 for GroupNorm
if num_groups == 0:
    normalization_type = "BatchNorm"
else:
    normalization_type = "GroupNorm"

In [15]:
if iid:
    from utils_v2 import get_dataset, average_weights, exp_details
else:
    from utils import get_dataset, average_weights, exp_details

In [16]:
train_dataset, test_dataset, user_groups = get_dataset(iid=iid, unbalanced=unbalanced, num_users=num_users)

Files already downloaded and verified
Files already downloaded and verified


In [17]:
def create_client_model():
    # client_model = RenNet8(normalization_type)
    client_model = ResNet8()
    return client_model

def create_server_model():
    server_model = ResNet50(n_type=normalization_type) # actually is a ResNet49 the first layer is done by the ResNet8
    return server_model

In [18]:
server_model = create_server_model()
client_model = create_client_model()

In [19]:
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
    device = torch.device("cpu")
    gpu = 0
else:
    print('CUDA is available!  Training on GPU ...')
    device = torch.device("cuda")
    gpu = 1

server_model.to(device)
client_model.to(device)

# set the models to train
server_model.train()
client_model.train()

CUDA is available!  Training on GPU ...


ResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1),

In [20]:
# 23,503,626 parameters (ResNet49)
# 23,520,842 parameters (ResNet50)
summary(server_model, (16, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,024
       BatchNorm2d-2           [-1, 64, 32, 32]             128
            Conv2d-3           [-1, 64, 32, 32]          36,864
       BatchNorm2d-4           [-1, 64, 32, 32]             128
            Conv2d-5          [-1, 256, 32, 32]          16,384
       BatchNorm2d-6          [-1, 256, 32, 32]             512
            Conv2d-7          [-1, 256, 32, 32]           4,096
       BatchNorm2d-8          [-1, 256, 32, 32]             512
        Bottleneck-9          [-1, 256, 32, 32]               0
           Conv2d-10           [-1, 64, 32, 32]          16,384
      BatchNorm2d-11           [-1, 64, 32, 32]             128
           Conv2d-12           [-1, 64, 32, 32]          36,864
      BatchNorm2d-13           [-1, 64, 32, 32]             128
           Conv2d-14          [-1, 256,

In [21]:
# 10,586 parameters (ResNet8)
summary(client_model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0
            Conv2d-4           [-1, 16, 32, 32]             256
       BatchNorm2d-5           [-1, 16, 32, 32]              32
              ReLU-6           [-1, 16, 32, 32]               0
            Conv2d-7           [-1, 16, 32, 32]           2,304
       BatchNorm2d-8           [-1, 16, 32, 32]              32
              ReLU-9           [-1, 16, 32, 32]               0
           Conv2d-10           [-1, 64, 32, 32]           1,024
      BatchNorm2d-11           [-1, 64, 32, 32]             128
           Conv2d-12           [-1, 64, 32, 32]           1,024
      BatchNorm2d-13           [-1, 64, 32, 32]             128
             ReLU-14           [-1, 64,

In [22]:
# init server
server_trainer = GKTServerTrainer(server_model, num_users, lr, server_epochs, device,
                                  optimizer, temperature)
clients_trainer = []  # list of client_trainer

# different clients at each epoch

idxs_users = range(num_users)

# init all clients
for idx in idxs_users:
    client_trainer = GKTClientTrainer(client_model, train_dataset, test_dataset,
                                      user_groups[idx], idx, gpu, optimizer, local_batch_size,
                                      lr, client_epochs, temperature, partition_alpha)
    clients_trainer.append(client_trainer)
    # print(f"client \t{idx}/{num_users} initialized")

for communication_round in range(1, communication_rounds+1):
    print(f'\nCommunication Round: {communication_round} \n')

    m = max(int(frac * num_users), 1) # number of users to be used for federated updates, at least 1
    idxs_chosen_users = np.random.choice(range(num_users), m, replace=False) # choose randomly m users

    print(idxs_chosen_users)
    for idx in idxs_chosen_users:
        # the server broadcast k-th Z_c to the client
        extracted_feature_dict, logits_dict, labels_dict, extracted_feature_dict_test,\
        labels_dict_test = clients_trainer[idx].train()

        # send client result to server
        server_trainer.add_local_trained_result(idx, extracted_feature_dict, logits_dict, labels_dict,
                                                extracted_feature_dict_test, labels_dict_test)

    # # check if all updates are received
    # b_all_received = server_trainer.check_whether_all_receive()
    # print("b_all received" + str(b_all_received))
    #
    # if b_all_received:
    #     server_trainer.train(communication_round)

    server_trainer.train(communication_round, idxs_chosen_users)

    for idx in idxs_chosen_users:
        # get global logits
        global_logits = server_trainer.get_global_logits(idx)

        # print(type(global_logits))
        # print(len(global_logits))
        # print(global_logits)

        # send global logits to client
        clients_trainer[idx].update_large_model_logits(global_logits)

# get lists of train loss and accuracy
train_loss, train_accuracy = server_trainer.get_loss_acc_list()


Communication Round: 1 

[17 67 57 29 63 64 53 37 28 33]


  return torch.tensor(image), torch.tensor(label)


client 17 - Update Epoch: 1 - Loss: 1.711598
add model - client_id = 17
client 67 - Update Epoch: 1 - Loss: 3.087377
add model - client_id = 67
client 57 - Update Epoch: 1 - Loss: 1.474509
add model - client_id = 57
client 29 - Update Epoch: 1 - Loss: 2.277935
add model - client_id = 29
client 63 - Update Epoch: 1 - Loss: 2.056274
add model - client_id = 63
client 64 - Update Epoch: 1 - Loss: 2.642022
add model - client_id = 64
client 53 - Update Epoch: 1 - Loss: 2.647139
add model - client_id = 53
client 37 - Update Epoch: 1 - Loss: 2.668604
add model - client_id = 37
client 28 - Update Epoch: 1 - Loss: 2.154992
add model - client_id = 28
client 33 - Update Epoch: 1 - Loss: 2.687364
add model - client_id = 33
train_and_eval - round_idx = 1, epoch = 1

- TRAIN METRICS: train_loss: 4.870 ; train_accuracy: 19.686

train_and_eval - round_idx = 1, epoch = 2

- TRAIN METRICS: train_loss: 3.225 ; train_accuracy: 15.247

train_and_eval - round_idx = 1, epoch = 3

- TRAIN METRICS: train_loss: 




Communication Round: 2 

[46 82 57  4 95 28 11 70 36 79]
client 46 - Update Epoch: 1 - Loss: 2.385154
add model - client_id = 46
client 82 - Update Epoch: 1 - Loss: 2.421107
add model - client_id = 82
client 57 - Update Epoch: 1 - Loss: 1.924684
add model - client_id = 57
client 4 - Update Epoch: 1 - Loss: 2.244481
add model - client_id = 4
client 95 - Update Epoch: 1 - Loss: 2.152624
add model - client_id = 95
client 28 - Update Epoch: 1 - Loss: 2.045080
add model - client_id = 28
client 11 - Update Epoch: 1 - Loss: 2.126417
add model - client_id = 11
client 70 - Update Epoch: 1 - Loss: 2.546243
add model - client_id = 70
client 36 - Update Epoch: 1 - Loss: 1.985712
add model - client_id = 36
client 79 - Update Epoch: 1 - Loss: 2.376925
add model - client_id = 79
train_and_eval - round_idx = 2, epoch = 1

- TRAIN METRICS: train_loss: 2.396 ; train_accuracy: 19.218

train_and_eval - round_idx = 2, epoch = 2

- TRAIN METRICS: train_loss: 2.256 ; train_accuracy: 20.424

train_and_eval -

In [23]:
print("Exectution completed!")

# TODO print summary of parameters used

# TODO print last accuracy

# TODO some graphs

Exectution completed!


In [24]:
# save train loss and accuracy
import pandas as pd

file_name = 'fedGKT_results/new_random_seed/{}_{}_{}_lr_[{}]_C[{}]_iid[{}]_Es[{}]_Ec[{}]_B[{}]_{}_unbalanced[{}].csv'.\
    format("ResNet50", normalization_type, communication_rounds, lr, frac, iid,
           server_epochs, client_epochs, local_batch_size, optimizer, unbalanced)

data = list(zip(train_loss, train_accuracy))
pd.DataFrame(data, columns=['train_loss','train_accuracy']).to_csv(file_name)