In [2]:
'''
Copyright (C) 2019. Huawei Technologies Co., Ltd. All rights reserved.
This program is free software; you can redistribute it and/or modify
it under the terms of the Apache 2.0 License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Apache 2.0 License for more details.
'''
import sys
sys.path.append("..")

import torch
from utils.aggregate import aggregate, aggregate_lr, zero_model, aggregate_momentum
from algs.individual_train import individual_train

In [4]:
class FedAvg(torch.nn.Module):  # abstract class for FL algorithms
    """subclass should implement the following
    --aggregate(): the server aggregation of models
    --local_updates(): the update of each client"""
    def __init__(self, models, optimizers, num_clients, num_local_epochs, loss_func):
        super(FedAvg, self).__init__()
        self.num_clients = num_clients
        self.num_local_epochs = num_local_epochs
        self.models = models
        self.optimizers = optimizers
        self.loss_func = loss_func
        self.losses = None
    
    def aggregate(self, weights=None):
        aggregate(models, weights=weights)
    
    def local_updates(self, train_loaders, test_loaders, device, output_dir):
        loss_func = self.loss_func
        self.losses = losses(self.models, train_loaders, self.loss_func, device)
        for i in range(self.num_clients):
            individual_train(train_loaders[i], loss_func, self.optimizers[i], self.models[i], \
                             test_loaders[i], \
                         device=device, client_id=i, epochs=self.num_local_epochs, \
                         output_dir=output_dir, show=False, save=False)

In [3]:
from utils.project import project
class AFL(FedAvg):
    def __init__(self, models, optimizers, num_clients, num_local_epochs, \
                 loss_func, lambda_, step_size_lambda=0.1):
        super(AFL, self).__init__(models, optimizers, num_clients, num_local_epochs, loss_func)
        self.lambda_ = lambda_
        self.step_size_lambda = step_size_lambda
    
    def aggregate(self, weights):
        # update lamdba
        self.lambda_ = project(np.array(self.lambda_) + self.step_size_lambda * \
                               np.array(self.losses))
        super(AFL, self).aggregate(weights=self.lambda_)

In [None]:
PropFair(models, optimizers, args.num_clients, args.num_local_epochs, \
             loss_func, base=5.0, epsilon = 0.2)

In [13]:
class PropFair(FedAvg):
    def __init__(self, models, optimizers, num_clients, num_local_epochs, \
             loss_func, base, epsilon = 0.2):
        super(PropFair, self).__init__(models, optimizers, \
                                       num_clients, num_local_epochs, loss_func)
        self.base = base
        self.epsilon = epsilon
        
    def local_updates(self, train_loaders, test_loaders, device, output_dir):
        
        def log_loss(output, target, base=self.base):
            ce_loss = self.loss_func(output, target)
            base = torch.tensor(base).to(device)
            if base - ce_loss < self.epsilon:           
                # for the bad performing batches, we enforce a constant to avoid divergence
                return ce_loss/base
            else:
                return -torch.log(1 - ce_loss/base)
        
        loss_func = log_loss
        for i in range(self.num_clients):
            individual_train(train_loaders[i], loss_func, self.optimizers[i], self.models[i], \
                             test_loaders[i], \
                         device=device, client_id=i, epochs=self.num_local_epochs, \
                         output_dir=output_dir, show=False, save=False)

In [3]:
class qFFL(FedAvg):
    def __init__(self, models, optimizers, num_clients, num_local_epochs, \
             loss_func, q=1.0):
        super(qFFL, self).__init__(models, optimizers, num_clients, num_local_epochs, loss_func)
        self.q = q
        
    def local_updates(self, train_loaders, test_loaders, device, output_dir):
        
        def q_loss(output, target, q=self.q):
            ce_loss = loss_func(output, target)
            return ce_loss ** (q + 1.0) / (q + 1.0)
        
        loss_func = q_loss
        for i in range(self.num_clients):
            individual_train(train_loaders[i], loss_func, self.optimizers[i], self.models[i], \
                             test_loaders[i], \
                         device=device, client_id=i, epochs=self.num_local_epochs, \
                         output_dir=output_dir, show=False, save=False)

In [5]:
from copy import deepcopy
from utils.aggregate import aggregate, sum_models, zero_model, \
                        assign_models, scale_model, sub_models, norm2_model

class qFedAvg(FedAvg):
    def __init__(self, models, optimizers, num_clients, num_local_epochs, \
             loss_func, Lipschitz, q=1.0):
        super(qFedAvg, self).__init__(models, optimizers, num_clients, \
                                      num_local_epochs, loss_func)
        self.q = q
        self.Lipschitz = Lipschitz
        self.old_models = deepcopy(models[0])
        self.losses = None
        
    def local_updates(self, train_loaders, test_loaders, device, output_dir):
        self.old_model = deepcopy(models[0])
        super(qFedAvg, self).local_updates(train_loaders, test_loaders, device, output_dir)
    
    def aggregate(self, weights):
        delta_w = [scale_model(sub_models(self.old_model, model), self.Lipschitz)\
                   for model in self.models]
        Delta = [scale_model(delta_w[i], (self.losses[i] ** self.q)) for i in range(len(delta_w))]
        h = [self.q * (self.losses[i] ** (self.q - 1)) * norm2_model(delta_w[i]) + \
               self.Lipschitz * (self.losses[i] ** self.q) for i in range(len(delta_w))]
        new_model = sub_models(model, scale_model(sum_models(Delta), 1.0 / sum(h)))
        assign_models(self.models, new_model)
        

In [13]:
class TERM(FedAvg):
    def __init__(self, models, optimizers, num_clients, num_local_epochs, \
             loss_func, Lipschitz, alpha=0.1):
        super(qFedAvg, self).__init__(models, optimizers, num_clients, \
                                      num_local_epochs, loss_func)
        self.alpha = alpha
         
    def aggregate(self, weights):
        weights_term = [np.exp(self.alpha * self.losses[i])* weights \
                        for i in range(self.num_clients)]
        weights_term = list(weights_term / np.sum(weights_term))
        super(TERM, self).aggregate(weights=weights_term)

In [6]:
# training
import sys
sys.path.append("..")

import argparse
import os
import torch
from copy import deepcopy
import numpy as np
from torch import optim, nn
from torch.utils.data import DataLoader
import pickle
import threading
from tqdm import tqdm
import json
from utils.io import Tee, to_csv
from utils.eval import accuracy, accuracies, losses
from utils.aggregate import aggregate, aggregate_lr, zero_model, aggregate_momentum
from algs.individual_train import individual_train
from utils.concurrency import multithreads
from models.models import resnet18, CNN, CNN_FEMNIST, RNN_Shakespeare, RNN_StackOverflow
from utils.print import print_acc, round_list
from utils.save import save_acc_loss
from utils.stat import mean_std


root = '..' 

parser = argparse.ArgumentParser(description='training')
parser.add_argument('--device', type=str, default='7')
parser.add_argument('--data_dir', type=str, default='iid-4')
parser.add_argument('--dataset', type=str, default='MNIST')
parser.add_argument('--algorithm', type=str, default='qFedAvg')
parser.add_argument('--num_clients', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=0, help='for data loader')
parser.add_argument('--num_epochs', type=int, default=10)
parser.add_argument('--num_local_epochs', type=int, default=1)
parser.add_argument('--learning_rate', type=float, default=0.1)
parser.add_argument('--save_epoch', type=int, default=5)
parser.add_argument('--seed', type=int, default=0)


args = parser.parse_args('')
output_dir = os.path.join(root, 'results', args.dataset, args.data_dir, 'FedAvg', \
                          f'batch_size_{args.batch_size}', f'seed_{args.seed}')
args.data_dir = os.path.join(root, 'data', args.dataset, args.data_dir)
os.makedirs(output_dir, exist_ok=True)
print(args)
print('output_dir: ', output_dir)

with open(os.path.join(output_dir, 'args.json'), 'w') as fp:
    json.dump(vars(args), fp)

os.environ['CUDA_VISIBLE_DEVICES'] = args.device
if torch.cuda.is_available():
    device = torch.device('cuda:0')  # use the first GPU
else:
    device = torch.device('cpu')

in_file = os.path.join(args.data_dir, 'in.pickle')
out_file = os.path.join(args.data_dir, 'out.pickle')

with open(in_file, 'rb') as f_in:
    in_data = pickle.load(f_in)
with open(out_file, 'rb') as f_out:
    out_data = pickle.load(f_out)  

weights = np.array([len(in_data[i]) for i in range(args.num_clients)])
weights_test = np.array([len(out_data[i]) for i in range(args.num_clients)])

print('total train samples: {}'.format(np.sum(weights)))
print('total test samples: {}'.format(np.sum(weights_test)))
print('total samples: {}'.format(np.sum(weights)+np.sum(weights_test)))

print('samples: ', weights)
weights = list(weights / np.sum(weights))

# data loaders
train_loaders = [DataLoader(
    dataset=in_data[i],
    batch_size=args.batch_size,
    num_workers=args.num_workers, drop_last=False, pin_memory=True, shuffle=True)
    for i in range(args.num_clients)]

test_loaders = [DataLoader(
    dataset=out_data[i],
    batch_size=args.batch_size,
    num_workers=args.num_workers, drop_last=False, pin_memory=True, shuffle=True)
    for i in range(args.num_clients)]

if args.dataset == 'MNIST':
    models = [CNN().to(device) for _ in range(args.num_clients)]
elif args.dataset == 'CIFAR10':
    models = [resnet18(num_classes=10).to(device)  for _ in range(args.num_clients)]
elif args.dataset == 'CIFAR100':
    models = [resnet18(num_classes=100).to(device)  for _ in range(args.num_clients)]
elif args.dataset == 'CINIC10':
    models = [resnet18(num_classes=10).to(device)  for _ in range(args.num_clients)]
elif args.dataset == 'FEMNIST':
    models = [CNN_FEMNIST().to(device) for _ in range(args.num_clients)]      
elif args.dataset == 'Shakespeare':
    models = [RNN_Shakespeare().to(device)  for _ in range(args.num_clients)]
elif args.dataset == 'StackOverflow':
    models = [RNN_StackOverflow().to(device)  for _ in range(args.num_clients)]     

# loss functions, optimizer
loss_func = nn.CrossEntropyLoss()
#loss_func = nn.MSELoss()
optimizers = [optim.SGD(model.parameters(), lr = args.learning_rate, \
                        momentum=0.0) for model in models]

# checkpoint
model_path = output_dir  + f'/model_last.pth'
if os.path.exists(model_path):
    start_epoch = torch.load(model_path)['epoch']
    for model in models:
        model.load_state_dict(torch.load(model_path)['state_dict'])
else:
    start_epoch = 0

json_file = os.path.join(output_dir, 'log.json')
with open(json_file, 'w') as f:
    f.write('')

if args.algorithm == 'FedAvg':
    alg = FedAvg(models, optimizers, args.num_clients, args.num_local_epochs, loss_func)
elif args.algorithm == 'AFL':
    alg = AFL(models, optimizers, args.num_clients, args.num_local_epochs, loss_func, \
              lambda_=weights, step_size_lambda=0.1)
elif args.algorithm == 'PropFair':
    alg = PropFair(models, optimizers, args.num_clients, args.num_local_epochs, \
             loss_func, base=5.0, epsilon = 0.2)
elif args.algorithm == 'qFFL':
    alg = qFFL(models, optimizers, args.num_clients, args.num_local_epochs, \
             loss_func)
elif args.algorithm == 'qFedAvg':
    alg = qFedAvg(models, optimizers, args.num_clients, args.num_local_epochs, \
             loss_func, Lipschitz = 1 / args.learning_rate)
else:
    raise NotImplemented

mean_accs = []
for t in range(start_epoch + 1, args.num_epochs):
    
    alg.local_updates(train_loaders, test_loaders, device, output_dir)
    alg.aggregate(weights=weights)
    
    accs = accuracies(alg.models, test_loaders, device)
    losses_ = losses(alg.models, train_loaders, loss_func, device)
    print(f'global epoch: {t}')
    mean, std = mean_std(accs)
    mean_accs.append(mean)
    print(f'losses: {round_list(losses_)}')
    save_acc_loss(json_file, t, accs, losses_)
    if t % args.save_epoch == 0:
        torch.save({'epoch': t, 'state_dict': models[0].state_dict()}, \
            output_dir  + f'/model_last.pth')
    
mean, std = mean_std(accs)
print('mean: ', mean, 'std: ', std)
print(f'accs: {[round(i, 3) for i in accs]}')

acc_file = "mean_acc.pkl".format(args.dataset, args.seed)
acc_file = os.path.join(output_dir, acc_file)

with open(acc_file, 'wb') as f_out:
    pickle.dump(mean_accs, f_out) 

Namespace(algorithm='qFedAvg', batch_size=64, data_dir='../data/MNIST/iid-4', dataset='MNIST', device='7', learning_rate=0.1, num_clients=4, num_epochs=10, num_local_epochs=1, num_workers=0, save_epoch=5, seed=0)
output_dir:  ../results/MNIST/iid-4/FedAvg/batch_size_64/seed_0
total train samples: 30000
total test samples: 30000
total samples: 60000
samples:  [7500 7500 7500 7500]
global epoch: 6
losses: [0.103, 0.1055, 0.0898, 0.0811]
global epoch: 7
losses: [0.0849, 0.0854, 0.0809, 0.0587]
global epoch: 8
losses: [0.1117, 0.1063, 0.1052, 0.08]
global epoch: 9
losses: [0.0804, 0.0797, 0.08, 0.0463]
mean:  97.50666666666666 std:  0.18135294011647532
accs: [97.547, 97.76, 97.253, 97.467]
