In [1]:
# implementation of FedProx
# Federated Optimization in Heterogeneous Networks by Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar, Virginia Smith
# https://arxiv.org/abs/1812.06127
# https://github.com/ayushm-agrawal/Federated-Learning-Implementations

import copy
import os
import random

import numpy as np
import pandas as pd
import torch
from torchsummary import summary

from models import ResNet50
# from utils import get_dataset, average_weights, exp_details
# from utils_v2 import get_dataset, average_weights, exp_details
from update_FedProx import LocalUpdate, test_inference

In [2]:
# ---- experiment configuration ----

# data partitioning
iid = 1         # 1: i.i.d. split across clients, 0: non-i.i.d. split
unbalanced = 0  # non-i.i.d. only: 1 gives clients unequal amounts of data

# federation
num_users = 100  # total number of clients
frac = 0.1       # fraction of clients sampled in every global round
n_epochs = 30    # number of global rounds

# device flag; re-assigned later once CUDA availability is known
gpu = 0

# local (client-side) optimisation
optimizer = "sgd"      # "sgd" or "adam"
local_batch_size = 10  # batch size of the local updates
lr = 0.01              # learning rate
local_epochs = 10      # maximum number of local epochs (E)
loss_function = "CrossEntropyLoss"

# system heterogeneity: percentage of stragglers, i.e. clients that
# run fewer than `local_epochs` epochs
percentage = 50

# FedProx proximal term constant
mu = 0.01

# normalisation layers: 0 selects BatchNorm, any other value GroupNorm
num_groups = 0
normalization_type = "BatchNorm" if num_groups == 0 else "GroupNorm"

In [3]:
# Select the helper module that matches the data split: utils_v2 is imported
# for the i.i.d. setting and utils otherwise. The same four helper names are
# imported from either module, so the rest of the notebook is agnostic to
# which one was chosen.
if iid:
    from utils_v2 import get_dataset, average_weights, weighted_average_weights, exp_details
else:
    from utils import get_dataset, average_weights, weighted_average_weights, exp_details

In [4]:
# print a summary of the experiment configuration
exp_details(
    "ResNet50",
    optimizer,
    lr,
    normalization_type,
    n_epochs,
    iid,
    frac,
    local_batch_size,
    local_epochs,
    unbalanced,
    num_users,
)


Experimental details:
    Model     : ResNet50
    Optimizer : sgd
    Learning  : 0.01
    Normalization  : BatchNorm
    Global Rounds   : 30

    Federated parameters:
    IID
    NUmber of users  : 100
    Fraction of users  : 0.1
    Local Batch size   : 10
    Local Epochs       : 10



In [5]:
# for REPRODUCIBILITY https://pytorch.org/docs/stable/notes/randomness.html
# Every random number generator the notebook has in scope is seeded from one
# constant, so a "Restart & Run All" reproduces the same client sampling and
# straggler assignment.
SEED = 0

torch.manual_seed(SEED)

# Dedicated torch generator.
# NOTE(review): `g` is not referenced by any other cell visible here;
# presumably meant to be handed to a DataLoader -- confirm or remove.
g = torch.Generator()
g.manual_seed(SEED)

# Python's built-in `random` module is imported by this notebook but was never
# seeded, so anything drawing from it would differ between runs.
random.seed(SEED)

# kept as the last statement: it returns None, so the cell shows no output
np.random.seed(SEED)

In [6]:
# load the train / test datasets together with the per-client partition
# (user_groups is later indexed by client id to obtain that client's samples)
train_dataset, test_dataset, user_groups = get_dataset(
    iid=iid,
    unbalanced=unbalanced,
    num_users=num_users,
)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# build the global model
model = ResNet50(n_type=normalization_type)
# model = CNNCifar()

# pick the device and keep the `gpu` flag (passed to LocalUpdate and
# test_inference) in sync with what is actually available
train_on_gpu = torch.cuda.is_available()

if train_on_gpu:
    print('CUDA is available!  Training on GPU ...')
    device = torch.device("cuda")
    gpu = 1
else:
    print('CUDA is not available.  Training on CPU ...')
    device = torch.device("cpu")
    gpu = 0

model.to(device)

# switch to training mode; as the last expression of the cell this also
# displays the model architecture
model.train()

CUDA is available!  Training on GPU ...


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (

In [8]:
def GenerateLocalEpochs(percentage, size, max_epochs):
    ''' Generate the number of local epochs for each selected client
    to replicate system heterogeneity.

    A share of the clients (the "stragglers") is assigned a random number of
    epochs drawn uniformly from 1 .. max_epochs-1; every other client runs the
    full max_epochs. The resulting list is shuffled so that stragglers are
    not tied to a position in the selection.

    Params:
      percentage: percentage (0-100) of clients to have fewer than E epochs;
                  values above 100 are treated as 100
      size:       total size of the list (number of selected clients)
      max_epochs: maximum value for local epochs (E)

    Returns:
      Integer numpy array of length `size` with the local epochs for each
      Client Update

    '''

    # No stragglers requested, or E == 1 so that no client can run "fewer
    # than E" epochs (np.random.randint(1, 1, ...) would raise): every client
    # runs for E epochs.
    if percentage == 0 or max_epochs == 1:
        return np.array([max_epochs]*size)

    # Number of clients that get fewer than E epochs. Multiply before dividing
    # so that e.g. 29% of 100 is not truncated to 28 by floating-point
    # rounding, and never ask for more stragglers than there are clients.
    heterogenous_size = min(int(percentage * size / 100), size)

    # random uniform epochs in [1, E) for the stragglers
    straggler_epochs = np.random.randint(1, max_epochs, heterogenous_size)

    # the rest of the clients will have E epochs; built with the same integer
    # dtype so the result never gets promoted to float (appending an empty
    # Python list, as happens for percentage == 100, would do exactly that)
    full_epochs = np.full(size - heterogenous_size, max_epochs,
                          dtype=straggler_epochs.dtype)

    epoch_list = np.concatenate((straggler_epochs, full_epochs))

    # shuffle the list and return
    np.random.shuffle(epoch_list)

    return epoch_list

In [9]:
# Keep a handle on the initial weights of the global model.
# NOTE(review): the state dict is not wrapped in copy.deepcopy here, so this
# is presumably a reference to the live parameters rather than a snapshot --
# confirm if a frozen copy is ever needed. The training cell re-assigns
# global_weights after every round.
global_weights = model.state_dict()

In [None]:
# Test metrics of the last 10 global rounds; the final reported accuracy is
# the mean over these entries.
test_acc_list = []
test_loss_list = []

# per-round training history
train_loss, train_accuracy = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []

for epoch in range(1, n_epochs+1):
    local_weights, local_losses = [], []

    print(f'Epoch: {epoch} \n')

    # ------------------------------------------------------------------
    # train: sampled clients start from the current global model
    # ------------------------------------------------------------------
    model.train()

    # sample a fresh set of clients for this round (at least one)
    m = max(int(frac * num_users), 1)
    idxs_users = np.random.choice(range(num_users), m, replace=False)

    # local epoch budget of every sampled client (stragglers get fewer)
    heterogenous_epoch_list = GenerateLocalEpochs(percentage, size=m, max_epochs=local_epochs)

    for idx, client_epochs in zip(idxs_users, heterogenous_epoch_list):
        local_model = LocalUpdate(dataset=train_dataset, idxs=user_groups[idx], mu=mu,
                                  gpu=gpu, optimizer=optimizer,
                                  local_batch_size=local_batch_size, lr=lr,
                                  local_epochs=client_epochs, loss_function=loss_function)

        # every client trains on its own deep copy of the global model
        client_weights, client_loss = local_model.update_weights(
            model=copy.deepcopy(model), global_round=epoch)

        print('| Client : {} | Average Loss: {:.4f} '.format(
            idx, client_loss))

        local_weights.append(copy.deepcopy(client_weights))
        local_losses.append(copy.deepcopy(client_loss))

    # aggregate the client models into the new global weights
    if unbalanced:
        global_weights = weighted_average_weights(local_weights, user_groups, idxs_users)
    else:
        global_weights = average_weights(local_weights)

    model.load_state_dict(global_weights)

    # mean of the client losses of this round
    train_loss.append(sum(local_losses) / len(local_losses))

    # ------------------------------------------------------------------
    # validate: accuracy of the new global model on every client's data
    # ------------------------------------------------------------------
    model.eval()
    list_acc, list_loss = [], []
    for client in range(num_users):
        local_model = LocalUpdate(dataset=train_dataset, idxs=user_groups[client], mu=mu,
                                  gpu=gpu, optimizer=optimizer,
                                  local_batch_size=local_batch_size, lr=lr,
                                  local_epochs=local_epochs, loss_function=loss_function)

        acc, loss = local_model.inference(model=model)
        list_acc.append(acc)
        list_loss.append(loss)

    train_accuracy.append(sum(list_acc)/len(list_acc))

    # the printed loss is the mean over all rounds so far; the accuracy is
    # the value of the current round
    print(f'\nAverage training statistics (global epoch) : {epoch}')
    print(f'|---- Trainig Loss : {np.mean(np.array(train_loss))}')
    print('|---- Training Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

    # evaluate on the test set during the last 10 global rounds only
    if epoch >= n_epochs - 9:
        test_acc, test_loss = test_inference(model=model, test_dataset=test_dataset, gpu=gpu,
                                             loss_function=loss_function)
        test_acc_list.append(test_acc)
        test_loss_list.append(test_loss)

Epoch: 1 



  return torch.tensor(image), torch.tensor(label)


| Global Round : 1 | Local Epoch : 1 | Train Loss: 6.2518 | Train Accuracy: 0.12
| Global Round : 1 | Local Epoch : 2 | Train Loss: 4.1298 | Train Accuracy: 0.12
| Global Round : 1 | Local Epoch : 3 | Train Loss: 3.7448 | Train Accuracy: 0.12
| Global Round : 1 | Local Epoch : 4 | Train Loss: 3.6008 | Train Accuracy: 0.16
| Global Round : 1 | Local Epoch : 5 | Train Loss: 3.5955 | Train Accuracy: 0.18
| Global Round : 1 | Local Epoch : 6 | Train Loss: 3.4919 | Train Accuracy: 0.20
| Global Round : 1 | Local Epoch : 7 | Train Loss: 3.4167 | Train Accuracy: 0.16
| Global Round : 1 | Local Epoch : 8 | Train Loss: 3.5089 | Train Accuracy: 0.17
| Global Round : 1 | Local Epoch : 9 | Train Loss: 3.4362 | Train Accuracy: 0.19


In [None]:
# # training loss
# train_loss = []
#
# # test accuracy
# test_acc = []
#
# # store last loss for convergence
# last_loss = 0.0
#
# for epoch in range(1, n_epochs+1):
#     w, local_loss, lst_local_train_time = [], [], []
#
#     # different clients at each epoch
#     m = max(int(frac * num_users), 1) # number of users to be used for federated updates, at least 1
#     idxs_users = np.random.choice(range(num_users), m, replace=False) # choose randomly m users
#
#     heterogenous_epoch_list = GenerateLocalEpochs(percentage, size=m, max_epochs=local_epochs)
#
#     for idx, i in zip(idxs_users, range(len(heterogenous_epoch_list))):  # for each user
#         # get local model
#         local_model = LocalUpdate(dataset=train_dataset, idxs=user_groups[idx], mu=mu,
#                                   gpu=gpu, optimizer=optimizer,
#                                   local_batch_size=local_batch_size, lr=lr,
#                                   local_epochs=heterogenous_epoch_list[i], loss_function=loss_function)
#         # get updated weight and loss from local model
#         weights, loss = local_model.update_weights(model=copy.deepcopy(model), # pass the global model to the clients
#                                              global_round=epoch)
#
#         w.append(copy.deepcopy(weights))
#         local_loss.append(copy.deepcopy(loss))
#
#         # compute global weights (average of local weights)
#         if unbalanced:
#             global_weights = weighted_average_weights(w, user_groups, idxs_users)
#         else:
#             global_weights = average_weights(w)
#
#
#         # update weights of the global model
#         model.load_state_dict(global_weights)


In [None]:
# save the weights of the trained global model

filename_pt = 'fedProx_results/{}_{}_{}_lr_[{}]_C[{}]_iid[{}]_unbalanced[{}]_E[{}]_B[{}]_{}_numGroups[{}]_percentage[{}].pt'\
    .format("ResNet50", n_epochs, optimizer, lr, frac, iid, unbalanced,
            local_epochs, local_batch_size, normalization_type, num_groups, percentage)

# torch.save does not create missing parent directories; make sure the results
# folder exists so a long training run is not lost on a fresh checkout
os.makedirs(os.path.dirname(filename_pt), exist_ok=True)

torch.save(model.state_dict(), filename_pt)

In [None]:
# save the per-round training metrics (loss and accuracy) as CSV

filename_csv = 'fedProx_results/{}_{}_{}_lr_[{}]_C[{}]_iid[{}]_unbalanced[{}]_E[{}]_B[{}]_{}_numGroups[{}]_percentage[{}].csv'\
    .format("ResNet50", n_epochs, optimizer, lr, frac, iid, unbalanced,
            local_epochs, local_batch_size, normalization_type, num_groups, percentage)

# DataFrame.to_csv fails if the target directory is missing; create it first
os.makedirs(os.path.dirname(filename_csv), exist_ok=True)

# one row per global round
data = list(zip(train_loss, train_accuracy))
pd.DataFrame(data, columns=['train_loss','train_accuracy']).to_csv(filename_csv)

In [None]:
# evaluate the final global model on test_dataset

test_acc, test_loss = test_inference(
    model=model,
    test_dataset=test_dataset,
    gpu=gpu,
    loss_function=loss_function,
)

# accuracies are stored as fractions; report them as percentages
final_train_pct = 100*train_accuracy[-1]
final_test_pct = 100*test_acc

print(f'\nResults after {n_epochs} global rounds of training:')
print("|---- Avgerage Train Accuracy: {:.2f}%".format(final_train_pct))
print("|---- Test Accuracy: {:.2f}%".format(final_test_pct))

In [None]:
# report the test metrics averaged over the rounds collected in
# test_loss_list / test_acc_list (the last 10 global rounds)
print("\n\n")

print(f'\nResults after {n_epochs} global rounds of training:')

mean_test_loss = sum(test_loss_list) / len(test_loss_list)
print("|---- Test Loss: {:.2f}".format(mean_test_loss))

mean_test_acc = sum(test_acc_list) / len(test_acc_list)
print("|---- Test Accuracy: {:.2f}%".format(100 * mean_test_acc))