### argparse.py

In [None]:
import argparse

def _str2bool(value):
    """Convert a command-line token into a bool.

    ``argparse`` with ``type=bool`` calls ``bool("0")``, which is ``True``
    for every non-empty string, so an explicit converter is required.

    :param value: raw token (or an actual bool)
    :return: the parsed boolean
    :raises argparse.ArgumentTypeError: if the token is not a known boolean
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean (1/0, true/false, yes/no), got %r' % (value,))


def args_parser(argv=None):
    """Build the experiment argument parser and parse the command line.

    :param argv: optional list of argument strings; ``None`` (the default)
                 makes argparse read ``sys.argv[1:]`` as before. Passing an
                 explicit list (e.g. ``[]``) allows use from a notebook.
    :return: the populated ``argparse.Namespace``
    """
    parser = argparse.ArgumentParser(description='FMNIST baseline')
    parser.add_argument('--name', '-n',
                        default="default",
                        type=str,
                        help='experiment name, used for saving results')
    parser.add_argument('--backend',
                        default="gloo",
                        type=str,
                        help='backend name')
    parser.add_argument('--out_fname',
                        default=".",
                        type=str,
                        help='where to store log files')
    parser.add_argument('--model',
                        default="MLP",
                        type=str,
                        help='neural network model')
    parser.add_argument('--alpha',
                        default=0.2,
                        type=float,
                        help='control the non-iidness of dataset')
    parser.add_argument('--num_classes',
                        type=int,
                        default=10,
                        help='number of classes')
    parser.add_argument('--gmf',
                        default=0,
                        type=float,
                        help='global (server) momentum factor')
    parser.add_argument('--lr',
                        default=0.1,
                        type=float,
                        help='client learning rate')
    parser.add_argument('--momentum',
                        default=0.0,
                        type=float,
                        help='local (client) momentum factor')
    parser.add_argument('--bs',
                        default=64,
                        type=int,
                        help='batch size on each worker/client')
    parser.add_argument('--rounds',
                        default=500,
                        type=int,
                        help='total communication rounds')
    parser.add_argument('--localE',
                        default=30,
                        type=int,
                        help='number of local epochs')
    # type=bool would turn "--decay 0" into True; parse the token explicitly.
    parser.add_argument('--decay',
                        default=True,
                        type=_str2bool,
                        help='1: decay LR, 0: no decay')
    parser.add_argument('--print_freq',
                        default=100,
                        type=int,
                        help='print info frequency')
    parser.add_argument('--size',
                        default=3,
                        type=int,
                        help='number of local workers')
    parser.add_argument('--powd',
                        default=6,
                        type=int,
                        help='number of selected subset workers per round ($d$)')
    parser.add_argument('--fracC',
                        default=0.03,
                        type=float,
                        help='fraction of selected workers per round')
    # Backslash is escaped explicitly; the rendered help text is unchanged.
    parser.add_argument('--seltype',
                        default='rand',
                        type=str,
                        help='type of client selection ($\\pi$)')
    parser.add_argument('--ensize',
                        default=100,
                        type=int,
                        help='number of all workers')
    parser.add_argument('--rank',
                        default=0,
                        type=int,
                        help='the rank of worker')
    parser.add_argument('--rnd_ratio',
                        default=0.1,
                        type=float,
                        help='hyperparameter for afl')
    parser.add_argument('--delete_ratio',
                        default=0.75,
                        type=float,
                        help='hyperparameter for afl')
    parser.add_argument('--seed',
                        default=1,
                        type=int,
                        help='random seed')
    parser.add_argument('--save', '-s',
                        action='store_true',
                        help='whether save the training results')
    parser.add_argument('--p', '-p',
                        action='store_true',
                        help='whether the dataset is partitioned or not')
    parser.add_argument('--NIID',
                        action='store_true',
                        help='whether the dataset is non-iid or not')
    parser.add_argument('--commE',
                        action='store_true',
                        help='activation of $cpow-d$')
    parser.add_argument('--constantE',
                        action='store_true',
                        help='whether all the local workers have an identical \
                        number of local epochs or not')
    parser.add_argument('--optimizer',
                        default='local',
                        type=str,
                        help='optimizer name')
    parser.add_argument('--initmethod',
                        default='env://',
                        type=str,
                        help='init method')
    parser.add_argument('--mu',
                        default=0,
                        type=float,
                        help='mu parameter in fedprox')
    parser.add_argument('--dataset',
                        default='fmnist',
                        type=str,
                        help='type of dataset')
    parser.add_argument('--img_size',
                        default=32,
                        type=int,
                        help='image size')

    # argv=None keeps the historical behaviour of reading sys.argv[1:].
    args = parser.parse_args(argv)

    return args

### Models.py

In [None]:
import math
from torch import nn
import torch.nn.functional as F

In [None]:
class MLP(nn.Module):
    """Fully connected classifier with two hidden layers.

    The input is flattened to ``shape[1] * shape[-2] * shape[-1]`` features
    per row, passed through three linear layers (ReLU after the first two,
    dropout before the last) and turned into log-probabilities.
    """

    def __init__(self, dim_in, dim_hidden1, dim_hidden2, dim_out):
        super().__init__()
        # Sub-modules are created in the original order so that seeded
        # parameter initialisation and state_dict layout are unchanged.
        self.layer_input = nn.Linear(dim_in, dim_hidden1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden1 = nn.Linear(dim_hidden1, dim_hidden2)
        self.layer_hidden2 = nn.Linear(dim_hidden2, dim_out)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        n_features = x.shape[1] * x.shape[-2] * x.shape[-1]
        hidden = self.relu(self.layer_input(x.view(-1, n_features)))
        hidden = self.relu(self.layer_hidden1(hidden))
        # The ReLU after dropout mirrors the original pipeline; dropout never
        # produces negative values from a non-negative input.
        hidden = self.relu(self.dropout(hidden))
        logits = self.layer_hidden2(hidden)
        return self.logsoftmax(logits)

In [None]:
class CNNMnist(nn.Module):
    """Small two-convolution CNN producing log-probabilities.

    ``fc1`` expects 320 = 20 * 4 * 4 features, which is what the two
    5x5 convolutions and 2x2 poolings yield for 28x28 inputs.

    :param args: namespace providing ``num_classes`` and, optionally,
                 ``num_channels`` (input channels; defaults to 1 when the
                 attribute is missing, since the argument parser in this
                 file does not define it).
    """

    def __init__(self, args):
        super().__init__()
        # Fall back to single-channel input instead of raising
        # AttributeError when the namespace has no ``num_channels``.
        in_channels = getattr(args, 'num_channels', 1)
        self.conv1 = nn.Conv2d(in_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        """Return per-class log-probabilities for the image batch ``x``."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten (channels, height, width) into one feature axis.
        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [None]:
class CNNFashion_Mnist(nn.Module):
    """Two conv/batch-norm/ReLU/max-pool stages followed by a linear head.

    The head expects 7 * 7 * 32 features, i.e. single-channel 28x28 inputs.
    Raw logits are returned (no softmax).

    :param args: namespace; ``args.num_classes`` sets the number of outputs,
                 consistent with the other CNNs in this file. When the
                 attribute is absent (or ``args`` is ``None``) the previous
                 fixed value of 10 is used.
    """

    def __init__(self, args):
        super().__init__()
        n_classes = getattr(args, 'num_classes', 10)
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fc = nn.Linear(7 * 7 * 32, n_classes)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)  # flatten everything but the batch axis
        out = self.fc(out)
        return out

In [None]:
class CNNCifar(nn.Module):
    """LeNet-style CNN for 3-channel images, returning log-probabilities.

    The flattening step assumes 16 feature maps of size 5x5, which is what
    32x32 inputs produce after the two conv + pool stages.

    :param args: namespace providing ``num_classes``
    """

    def __init__(self, args):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        """Return per-class log-probabilities for the image batch ``x``."""
        # Convolutional feature extractor: (conv -> ReLU -> pool) twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        features = x.view(-1, 16 * 5 * 5)
        # Fully connected head.
        hidden = F.relu(self.fc2(F.relu(self.fc1(features))))
        return F.log_softmax(self.fc3(hidden), dim=1)

In [None]:
class modelC(nn.Module):
    """All-convolutional classifier ending in global average pooling.

    :param input_size: number of input channels (passed to the first
                       ``nn.Conv2d`` as ``in_channels``)
    :param n_classes: number of output classes
    """

    def __init__(self, input_size, n_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(input_size, 96, 3, padding=1)
        self.conv2 = nn.Conv2d(96, 96, 3, padding=1)
        self.conv3 = nn.Conv2d(96, 96, 3, padding=1, stride=2)
        self.conv4 = nn.Conv2d(96, 192, 3, padding=1)
        self.conv5 = nn.Conv2d(192, 192, 3, padding=1)
        self.conv6 = nn.Conv2d(192, 192, 3, padding=1, stride=2)
        self.conv7 = nn.Conv2d(192, 192, 3, padding=1)
        self.conv8 = nn.Conv2d(192, 192, 1)

        self.class_conv = nn.Conv2d(192, n_classes, 1)

    def forward(self, x):
        """Return a ``(batch, n_classes)`` tensor of pooled class scores."""
        # F.dropout defaults to training=True; tie it to the module mode so
        # that evaluation (model.eval()) is deterministic.
        x_drop = F.dropout(x, .2, training=self.training)
        conv1_out = F.relu(self.conv1(x_drop))
        conv2_out = F.relu(self.conv2(conv1_out))
        conv3_out = F.relu(self.conv3(conv2_out))
        conv3_out_drop = F.dropout(conv3_out, .5, training=self.training)
        conv4_out = F.relu(self.conv4(conv3_out_drop))
        conv5_out = F.relu(self.conv5(conv4_out))
        conv6_out = F.relu(self.conv6(conv5_out))
        conv6_out_drop = F.dropout(conv6_out, .5, training=self.training)
        conv7_out = F.relu(self.conv7(conv6_out_drop))
        conv8_out = F.relu(self.conv8(conv7_out))

        class_out = F.relu(self.class_conv(conv8_out))
        # Global average pooling, then drop the two trailing 1x1 axes.
        pool_out = F.adaptive_avg_pool2d(class_out, 1)
        return pool_out.squeeze(-1).squeeze(-1)

In [None]:
# Public names of the VGG part of the model zoo: the model class followed by
# its factory functions (plain and batch-normalised variants).
__all__ = [
    'VGG',
    'vgg11', 'vgg11_bn',
    'vgg13', 'vgg13_bn',
    'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]

In [None]:
class VGG(nn.Module):
    """
    VGG model

    :param features: convolutional feature extractor; its flattened output
                     must have 512 features per sample
    :param num_classes: number of output logits (default 10, the previously
                        hard-coded value)
    """

    def __init__(self, features, num_classes=10):
        super().__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, num_classes),
        )
        # Initialize conv weights from N(0, sqrt(2 / fan_out)), biases to 0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                # Convolutions built with bias=False have no bias tensor.
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [None]:
def make_layers(cfg, batch_norm=False, in_channels=1):
    """Build the VGG convolutional feature extractor.

    :param cfg: sequence of layer specs; an int adds a 3x3 convolution with
                that many output channels (followed by ReLU), ``'M'`` adds a
                2x2 max-pooling layer
    :param batch_norm: insert ``nn.BatchNorm2d`` after every convolution
    :param in_channels: channels of the input image; defaults to 1 (the
                        previously hard-coded value), use 3 for RGB data
    :return: ``nn.Sequential`` of the generated layers
    """
    layers = []
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(v))
        layers.append(nn.ReLU(inplace=True))
        # The next convolution consumes what this one produced.
        in_channels = v
    return nn.Sequential(*layers)

In [None]:
# VGG layer layouts keyed by configuration letter. Integers are convolution
# output channels, 'M' is a max-pooling step; one pooling stage per row.
cfg = {
    'A': [64, 'M',
          128, 'M',
          256, 256, 'M',
          512, 512, 'M',
          512, 512, 'M'],
    'B': [64, 64, 'M',
          128, 128, 'M',
          256, 256, 'M',
          512, 512, 'M',
          512, 512, 'M'],
    'D': [64, 64, 'M',
          128, 128, 'M',
          256, 256, 256, 'M',
          512, 512, 512, 'M',
          512, 512, 512, 'M'],
    'E': [64, 64, 'M',
          128, 128, 'M',
          256, 256, 256, 256, 'M',
          512, 512, 512, 512, 'M',
          512, 512, 512, 512, 'M'],
}

In [None]:
def vgg11():
    """Return an 11-layer VGG (configuration "A") without batch normalization."""
    feature_extractor = make_layers(cfg['A'])
    return VGG(feature_extractor)

def vgg11_bn():
    """Return an 11-layer VGG (configuration "A") with batch normalization."""
    feature_extractor = make_layers(cfg['A'], batch_norm=True)
    return VGG(feature_extractor)

def vgg13():
    """Return a 13-layer VGG (configuration "B") without batch normalization."""
    feature_extractor = make_layers(cfg['B'])
    return VGG(feature_extractor)

def vgg13_bn():
    """Return a 13-layer VGG (configuration "B") with batch normalization."""
    feature_extractor = make_layers(cfg['B'], batch_norm=True)
    return VGG(feature_extractor)

def vgg16():
    """Return a 16-layer VGG (configuration "D") without batch normalization."""
    feature_extractor = make_layers(cfg['D'])
    return VGG(feature_extractor)

def vgg16_bn():
    """Return a 16-layer VGG (configuration "D") with batch normalization."""
    feature_extractor = make_layers(cfg['D'], batch_norm=True)
    return VGG(feature_extractor)

def vgg19():
    """Return a 19-layer VGG (configuration "E") without batch normalization."""
    feature_extractor = make_layers(cfg['E'])
    return VGG(feature_extractor)

def vgg19_bn():
    """Return a 19-layer VGG (configuration 'E') with batch normalization."""
    feature_extractor = make_layers(cfg['E'], batch_norm=True)
    return VGG(feature_extractor)

### Utils.py

In [None]:
import random
from random import Random

In [None]:
import numpy as np
import torch
import torch.utils.data.distributed
import torchvision
from numpy.random import RandomState
from torchvision import transforms
import torch.distributed as dist

In [None]:
class Partition:
    """Read-only view exposing only selected items of an indexable dataset.

    :param data: underlying dataset (anything supporting ``data[i]``)
    :param index: sequence of positions in ``data`` that belong to the view
    """

    def __init__(self, data, index):
        self.data, self.index = data, index

    def __len__(self):
        """Number of items visible through this view."""
        return len(self.index)

    def __getitem__(self, index):
        """Translate a view position into the underlying dataset item."""
        return self.data[self.index[index]]

In [None]:
class DataPartitioner(object):
    """ Partitions a dataset into different chunks.

    After construction the instance exposes:

    * ``partitions`` - one list of sample indices per client
    * ``ratio``      - relative data size per client
    * ``dat_stat``   - ``{client: {label: count}}`` (``None`` when the
      dataset exposes no labels)
    * ``endat_size`` - total number of samples assigned to clients

    :param data: indexable dataset; label-dependent paths read
                 ``data.targets`` or, failing that, ``data.train_labels``
    :param sizes: fraction of the data per partition
                  (default ``[0.7, 0.2, 0.1]``)
    :param rnd: communication round index, only used to throttle printing
    :param seed: seed for the IID shuffle
    :param isNonIID: use the Dirichlet label-skew split instead of IID
    :param alpha: Dirichlet concentration parameter (non-IID only)
    :param dataset: dataset name, stored for reference
    :param print_f: print statistics when ``rnd % print_f == 0``
    """

    def __init__(self, data, sizes=None, rnd=0, seed=1234, isNonIID=False, alpha=0,
                 dataset=None, print_f=50):
        # Avoid a shared mutable default argument.
        if sizes is None:
            sizes = [0.7, 0.2, 0.1]
        self.data = data
        self.dataset = dataset

        if isNonIID:
            (self.partitions, self.ratio,
             self.dat_stat, self.endat_size) = self.__getDirichletData__(data, sizes, alpha, rnd, print_f)

        else:
            self.partitions = []
            self.ratio = sizes
            rng = Random()
            rng.seed(seed)  # seed is fixed so same random number is generated
            data_len = len(data)
            indexes = list(range(data_len))
            rng.shuffle(indexes)  # Same shuffling (with each seed)

            for frac in sizes:
                part_len = int(frac * data_len)
                self.partitions.append(indexes[0:part_len])
                indexes = indexes[part_len:]

            # Provide the same attributes as the non-IID branch so callers
            # can read them unconditionally.
            try:
                labels = self._labels(data)
            except AttributeError:
                labels = None
            self.dat_stat = None if labels is None else self._class_counts(labels, self.partitions)
            self.endat_size = sum(len(part) for part in self.partitions)

    def use(self, partition):
        """Return the ``Partition`` view for client number ``partition``."""
        return Partition(self.data, self.partitions[partition])

    @staticmethod
    def _labels(data):
        """Return the dataset labels as a NumPy array.

        ``targets`` is tried first and ``train_labels`` second, so datasets
        exposing either attribute are supported.

        :raises AttributeError: if neither attribute is available
        """
        for attr in ('targets', 'train_labels'):
            labels = getattr(data, attr, None)
            if labels is not None:
                return np.asarray(labels)
        raise AttributeError("dataset exposes neither 'targets' nor 'train_labels'")

    @staticmethod
    def _class_counts(labels, partitions):
        """Count, for every partition, how many samples each label has."""
        counts = {}
        for net_i, dataidx in enumerate(partitions):
            unq, unq_cnt = np.unique(labels[np.asarray(dataidx, dtype=int)], return_counts=True)
            counts[net_i] = {unq[i]: unq_cnt[i] for i in range(len(unq))}
        return counts

    def __getNonIIDdata__(self, data, sizes, seed, alpha):
        """Legacy label-skew split: every partition first receives a fixed
        share (``alpha``) of two "major" labels, the rest is filled randomly.

        :return: list with one index list per partition
        """
        # Plain Python scalars so that equal labels hash to the same dict key
        # (tensor elements would each become a distinct key).
        labelList = self._labels(data).tolist()
        rng = Random()
        rng.seed(seed)
        a = [(label, idx) for idx, label in enumerate(labelList)]

        # Same Part
        labelIdxDict = dict()
        for label, idx in a:
            labelIdxDict.setdefault(label, [])
            labelIdxDict[label].append(idx)
        labelNum = len(labelIdxDict)
        labelNameList = [key for key in labelIdxDict]
        labelIdxPointer = [0] * labelNum

        # sizes = number of nodes
        partitions = [list() for i in range(len(sizes))]
        eachPartitionLen = int(len(labelList) / len(sizes))

        # majorLabelNumPerPartition = ceil(labelNum/len(partitions))
        majorLabelNumPerPartition = 2
        basicLabelRatio = alpha
        interval = 1
        labelPointer = 0

        # basic part
        for partPointer in range(len(partitions)):
            requiredLabelList = list()
            for _ in range(majorLabelNumPerPartition):
                requiredLabelList.append(labelPointer)
                labelPointer += interval
                if labelPointer > labelNum - 1:
                    labelPointer = interval
                    interval += 1
            for labelIdx in requiredLabelList:
                start = labelIdxPointer[labelIdx]
                idxIncrement = int(basicLabelRatio * len(labelIdxDict[labelNameList[labelIdx]]))
                partitions[partPointer].extend(labelIdxDict[labelNameList[labelIdx]][start:start + idxIncrement])
                labelIdxPointer[labelIdx] += idxIncrement

        # random part
        remainLabels = list()
        for labelIdx in range(labelNum):
            remainLabels.extend(labelIdxDict[labelNameList[labelIdx]][labelIdxPointer[labelIdx]:])
        rng.shuffle(remainLabels)
        for partPointer in range(len(partitions)):
            idxIncrement = eachPartitionLen - len(partitions[partPointer])
            partitions[partPointer].extend(remainLabels[:idxIncrement])
            rng.shuffle(partitions[partPointer])
            remainLabels = remainLabels[idxIncrement:]

        return partitions

    def __getDirichletData__(self, data, psizes, alpha, rnd, print_f):
        """Split the data with per-class Dirichlet(alpha) proportions.

        :return: ``(partitions, weights, class_counts, total_size)``
        """
        n_nets = len(psizes)
        # K is both the number of classes iterated over (labels 0..K-1) and
        # the minimum partition size required to accept a draw.
        K = 10
        labelList = self._labels(data)
        min_size = 0
        N = len(labelList)
        rann = RandomState(2020)  # fixed seed: identical split on every call

        while min_size < K:
            idx_batch = [[] for _ in range(n_nets)]
            # for each class in the dataset
            for k in range(K):
                idx_k = np.where(labelList == k)[0]
                rann.shuffle(idx_k)
                proportions = rann.dirichlet(np.repeat(alpha, n_nets))
                ## Balance: clients already holding their fair share get nothing more
                proportions = np.array([p * (len(idx_j) < N / n_nets) for p, idx_j in zip(proportions, idx_batch)])
                proportions = proportions / proportions.sum()
                proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
                min_size = min([len(idx_j) for idx_j in idx_batch])

        for j in range(n_nets):
            rann.shuffle(idx_batch[j])

        net_cls_counts = self._class_counts(labelList, idx_batch)

        local_sizes = np.array([len(idx_j) for idx_j in idx_batch])
        weights = local_sizes / np.sum(local_sizes)

        # print_f == 0 disables printing instead of raising ZeroDivisionError.
        if print_f and rnd % print_f == 0:
            print('Data statistics: %s' % str(net_cls_counts))
            print('Data ratio: %s' % str(weights))

        return idx_batch, weights, net_cls_counts, np.sum(local_sizes)

In [None]:
def partition_dataset(size, args, rnd):
    """Load the requested dataset and split the training set across clients.

    Side effects: downloads the dataset into ``./data`` if needed and sets
    ``args.img_size`` to the shape of the first training image.

    :param size: forwarded to the data loaders as ``num_workers``
    :param args: namespace providing ``dataset`` ('cifar', 'fmnist' or
                 'emnist'), ``ensize``, ``NIID``, ``alpha`` and ``print_freq``
    :param rnd: communication round index (forwarded to the partitioner)
    :return: ``(partition, train_loader, test_loader, ratio, dat_stat,
             endat_size)``
    :raises ValueError: if ``args.dataset`` is not a supported name
    """
    if args.dataset == 'cifar':
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize])

        trainset = torchvision.datasets.CIFAR10(root='./data',
                                                train=True,
                                                download=True,
                                                transform=transform_train)
        testset = torchvision.datasets.CIFAR10(root='./data',
                                               train=False,
                                               download=True,
                                               transform=transform_test)
        shuffle_train = False

    elif args.dataset == 'fmnist':
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        trainset = torchvision.datasets.FashionMNIST(root='./data',
                                                     train=True,
                                                     download=True,
                                                     transform=apply_transform)
        testset = torchvision.datasets.FashionMNIST(root='./data',
                                                    train=False,
                                                    download=True,
                                                    transform=apply_transform)
        # FMNIST is the only dataset whose full-train loader shuffles.
        shuffle_train = True

    elif args.dataset == 'emnist':
        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        trainset = torchvision.datasets.EMNIST(root='./data',
                                               split='digits',
                                               train=True,
                                               download=True,
                                               transform=apply_transform)
        testset = torchvision.datasets.EMNIST(root='./data',
                                              split='digits',
                                              train=False,
                                              download=True,
                                              transform=apply_transform)
        shuffle_train = False

    # add more datasets here

    else:
        # Previously an unknown name surfaced later as an UnboundLocalError.
        raise ValueError("unsupported dataset %r (expected 'cifar', 'fmnist' or 'emnist')"
                         % (args.dataset,))

    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=64,
                                               shuffle=shuffle_train,
                                               num_workers=size)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=64,
                                              shuffle=False,
                                              num_workers=size)

    # Every client nominally owns an equal share of the training data.
    partition_sizes = [1.0 / args.ensize for _ in range(args.ensize)]
    partition = DataPartitioner(trainset, partition_sizes, rnd, isNonIID=args.NIID, alpha=args.alpha,
                                dataset=args.dataset, print_f=args.print_freq)
    ratio = partition.ratio  # Ratio of data sizes

    args.img_size = trainset[0][0].shape

    # The IID partitioner may not define these statistics; fall back to
    # neutral values instead of raising AttributeError.
    dat_stat = getattr(partition, 'dat_stat', None)
    endat_size = getattr(partition, 'endat_size', None)
    if endat_size is None:
        endat_size = sum(len(part) for part in partition.partitions)

    return partition, train_loader, test_loader, ratio, dat_stat, endat_size

In [None]:
def partitiondata_loader(partition, rank, batch_size):
    """
    Build a loader over one random mini-batch of a client's local data.

    :param partition: object whose ``use(rank)`` returns the client's dataset
    :param rank: client index
    :param batch_size: requested batch size; capped at the local data size
    :return: ``DataLoader`` over the sampled subset
    """
    local_data = partition.use(rank)
    n_local = len(local_data)

    # Draw the mini-batch indices without replacement via the global `random`.
    n_pick = int(min(batch_size, n_local))
    picked = random.sample(range(n_local), k=n_pick)

    return torch.utils.data.DataLoader(
        torch.utils.data.Subset(local_data, indices=picked),
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True)

In [None]:
def _available_clients(ensize, rnd, delete=0.2):
    """Return the integer ids of the clients reachable in round ``rnd``.

    Even rounds expose the lower half of the clients, odd rounds the upper
    half; a fraction ``delete`` of that half is dropped at random.
    """
    half = ensize // 2
    pool = np.arange(0, half) if rnd % 2 == 0 else np.arange(half, ensize)
    n_drop = int(delete * ensize / 2)
    # Draw *positions* inside the pool: np.delete expects positions, not
    # client ids, so this is valid for the upper half as well.
    drop_pos = np.random.choice(len(pool), size=n_drop, replace=False)
    return np.delete(pool, drop_pos)


def _renormalized(DataRatios, idx):
    """Return the data ratios of the clients ``idx`` rescaled to sum to 1."""
    probs = np.asarray([DataRatios[int(i)] for i in idx], dtype=float)
    return probs / probs.sum()


def _top_by_value(values, candidates, k):
    """Return the ``k`` candidates with the largest values as a tuple.

    The sort is stable, so ties keep the order in which they were drawn.
    """
    ranked = sorted(zip(values, candidates), key=lambda pair: pair[0], reverse=True)
    return tuple(cand for _, cand in ranked[:k])


def sel_client(DataRatios, cli_loss, cli_val, args, rnd):
    r"""
    Client selection part returning the indices the set $\mathcal{S}$ and $\mathcal{A}$

    Supported ``args.seltype`` values: 'rand', 'randint', 'pow-d', 'rpow-d',
    'pow-dint', 'rpow-dint' and 'afl'. Randomness comes from the global
    NumPy generator. ``cli_val`` is never modified.

    :param DataRatios: $p_k$
    :param cli_loss: actual local loss F_k(w)
    :param cli_val: proxy of the local loss
    :param args: variable arguments
    :param rnd: communication round index
    :return: idxs_users (indices of $\mathcal{S}$), rnd_idx (indices of $\mathcal{A}$)
    :raises ValueError: if ``args.seltype`` is not a known strategy
    """
    # If reproducibility is needed, seed the global NumPy generator before
    # calling this function.

    rnd_idx = []
    seltype = args.seltype
    n_select = int(args.size)

    if seltype == 'rand':
        # random selection in proportion to $p_k$ with replacement
        idxs_users = np.random.choice(args.ensize, p=DataRatios, size=args.size, replace=True)

    elif seltype == 'randint':
        # 'rand' for intermittent client availability
        search_idx = _available_clients(args.ensize, rnd)
        idxs_users = np.random.choice(search_idx, p=_renormalized(DataRatios, search_idx),
                                      size=args.size, replace=True)

    elif seltype == 'pow-d':
        # standard power-of-choice strategy
        rnd_idx = np.random.choice(args.ensize, p=DataRatios, size=args.powd, replace=False)
        idxs_users = _top_by_value([cli_loss[i] for i in rnd_idx], rnd_idx, n_select)

    elif seltype == 'rpow-d':
        # computation/communication efficient variant of 'pow-d'; the
        # candidate set is deliberately not reported through rnd_idx
        candidates = np.random.choice(args.ensize, p=DataRatios, size=args.powd, replace=False)
        idxs_users = _top_by_value([cli_val[i] for i in candidates], candidates, n_select)

    elif seltype in ('pow-dint', 'rpow-dint'):
        # 'pow-d' / 'rpow-d' for intermittent client availability
        search_idx = _available_clients(args.ensize, rnd)
        rnd_idx = np.random.choice(search_idx, p=_renormalized(DataRatios, search_idx),
                                   size=args.powd, replace=False)
        scores = cli_loss if seltype == 'pow-dint' else cli_val
        idxs_users = _top_by_value([scores[int(i)] for i in rnd_idx], rnd_idx, n_select)

    elif seltype == 'afl':
        # benchmark strategy
        soft_temp = 0.01
        # Work on a copy so the caller's valuations are not overwritten
        # with -inf.
        valuations = np.array(cli_val)
        sorted_loss_idx = np.argsort(valuations)
        valuations[sorted_loss_idx[:int(args.delete_ratio * args.ensize)]] = -np.inf

        softmax_w = np.exp(soft_temp * valuations)
        loss_prob = softmax_w / sum(softmax_w)
        n_by_loss = int(np.floor((1 - args.rnd_ratio) * args.size))
        idx1 = np.random.choice(int(args.ensize), p=loss_prob, size=n_by_loss, replace=False)

        # Fill the remaining slots uniformly from the clients not yet chosen.
        new_idx = np.delete(np.arange(0, args.ensize), idx1)
        idx2 = np.random.choice(new_idx, size=int(args.size - n_by_loss), replace=False)

        idxs_users = list(idx1) + list(idx2)

    else:
        raise ValueError('unknown client selection type: %r' % (seltype,))

    return idxs_users, rnd_idx

In [None]:
def choices(population, weights=None, cum_weights=None, k=1):
    """Deterministically pick ``k`` population elements with replacement.

    If the relative weights or cumulative weights are not specified,
    the selections are made with equal probability.

    Draw number ``i`` uses a generator seeded with ``i``, so the result only
    depends on the arguments. The global ``random`` state is left untouched.

    NOTE(review): every picked element is *extended* into the result (not
    appended), so elements must be iterable and the result is the
    concatenation of the picked elements. Kept as-is because callers may
    rely on it; confirm whether ``append`` was intended.

    :raises TypeError: if both ``weights`` and ``cum_weights`` are given
    :raises ValueError: if the number of weights differs from the population
    """
    from bisect import bisect

    def draw(i):
        # Same value as `random.seed(i); random.random()` without reseeding
        # (and thereby clobbering) the shared module-level generator.
        return random.Random(i).random()

    if cum_weights is None:
        if weights is None:
            total = len(population)
            result = []
            for i in range(k):
                result.extend(population[int(draw(i) * total)])
            return result
        cum_weights = []
        c = 0
        for x in weights:
            c += x
            cum_weights.append(c)
    elif weights is not None:
        raise TypeError('Cannot specify both weights and cumulative weights')
    if len(cum_weights) != len(population):
        raise ValueError('The number of weights does not match the population')
    total = cum_weights[-1]
    hi = len(cum_weights) - 1
    result = []
    for i in range(k):
        result.extend(population[bisect(cum_weights, draw(i) * total, 0, hi)])
    return result

In [None]:
class Meter(object):
    """ Tracks the current value, running average and spread of a metric.

    The spread is the sample standard deviation (``std``), plus the mean
    absolute deviation (``mad``) when the meter is stateful.
    """

    def __init__(self, init_dict=None, ptag='Time', stateful=False,
                 csv_format=True):
        """
        :param init_dict: Dictionary to initialize meter values
        :param ptag: Print tag used in __str__() to identify meter
        :param stateful: Whether to store value history and compute MAD
        :param csv_format: Whether __str__() emits 'val,avg,spread' instead
            of the tagged human-readable form
        """
        self.reset()
        self.ptag = ptag
        self.stateful = stateful
        # History is only kept for stateful meters (needed for the MAD).
        self.value_history = [] if stateful else None
        self.csv_format = csv_format
        if init_dict is not None:
            for key in init_dict:
                try:
                    # TODO: add type checking to init_dict values
                    self.__dict__[key] = init_dict[key]
                except Exception:
                    print('(Warning) Invalid key {} in init_dict'.format(key))

    def reset(self):
        """ Zero all running statistics (the value history is left untouched). """
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.std = 0
        self.sqsum = 0
        self.mad = 0

    def update(self, val, n=1):
        """
        Record a new observation.

        :param val: observed value
        :param n: weight (number of samples) the value stands for
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.sqsum += (val ** 2) * n
        if self.count > 1:
            variance = ((self.sqsum - (self.sum ** 2) / self.count)
                        / (self.count - 1))
            # Floating-point cancellation can leave a tiny negative variance;
            # a negative float ** 0.5 would be a complex number and break the
            # '.3f' formatting in __str__(), so clamp at zero.
            self.std = max(variance, 0.0) ** 0.5
        if self.stateful:
            self.value_history.append(val)
            # MAD is recomputed against the *current* average on each update.
            total_dev = sum(abs(v - self.avg) for v in self.value_history)
            self.mad = total_dev / len(self.value_history)

    def __str__(self):
        # Stateful meters report the MAD, stateless ones the std.
        spread = self.mad if self.stateful else self.std
        if self.csv_format:
            return '{:.3f},{:.3f},{:.3f}'.format(self.val, self.avg, spread)
        return str(self.ptag) + \
            ': {:.3f} ({:.3f} +- {:.3f})'.format(self.val, self.avg, spread)

### comm_helpers.py

In [None]:
def flatten_tensors(tensors):
    """Concatenate a sequence of tensors into one new 1-D tensor.

    Each tensor is viewed as 1-D (``view(-1)``) and the pieces are joined in
    order.  A single-element sequence is cloned instead, so the result never
    shares storage with the input.
    """
    if len(tensors) != 1:
        pieces = [item.view(-1) for item in tensors]
        return torch.cat(pieces, dim=0)
    # view(-1) alone would alias the input; clone to get an independent buffer.
    only = tensors[0]
    return only.view(-1).clone()

def unflatten_tensors(flat, tensors):
    """Slice a flat 1-D tensor back into pieces shaped like ``tensors``.

    Returns a tuple with one entry per reference tensor; each entry is a
    ``narrow`` view into ``flat`` (no copy) reshaped with ``view_as``.
    Pieces are taken consecutively from the start of ``flat``.
    """
    def _views():
        start = 0
        for ref in tensors:
            length = ref.numel()
            yield flat.narrow(0, start, length).view_as(ref)
            start += length

    return tuple(_views())

def communicate(tensors, communication_op, attention=False):
    """Apply a collective op to a list of tensors through one flat buffer.

    The tensors are flattened into a single buffer, ``communication_op`` is
    called with ``tensor=<buffer>`` (it is expected to modify the buffer in
    place), and each input tensor is then re-pointed at its slice of the
    buffer via ``set_``.  In that default mode nothing is returned.

    :param tensors: tensors to communicate
    :param communication_op: callable accepting a ``tensor=`` keyword
    :param attention: if true, return ``tensors / buffer`` instead of writing
        the result back
    """
    buffer = flatten_tensors(tensors)
    communication_op(tensor=buffer)
    if not attention:
        for target, chunk in zip(tensors, unflatten_tensors(buffer, tensors)):
            target.set_(chunk)
        return None
    # NOTE(review): this divides the `tensors` argument itself by the flat
    # buffer; that only works if `tensors` supports `/` (a plain list does
    # not) -- confirm the intended input for attention=True.
    return tensors / buffer

### FedAvg.py

In [None]:
import torch
import torch.distributed as dist
from torch.optim.optimizer import Optimizer
from .comm_helpers import communicate, flatten_tensors, unflatten_tensors

In [None]:
class fedavg(Optimizer):
    r"""Implements stochastic gradient descent for FedAvg.

    ``step`` performs a local SGD update (optional weight decay, momentum,
    Nesterov momentum and a proximal term scaled by ``mu``) and accumulates
    the applied update ``lr * d_p`` in ``state[p]['cum_grad']``.
    ``average`` scales that accumulated update by a weight, all-reduces it
    across workers and moves the round-start parameters
    (``state[p]['old_init']``) by the reduced update, optionally through a
    global momentum buffer controlled by ``gmf``.

    Arguments:
        params (iterable): parameters or parameter groups to optimize
        ratio: stored on the optimizer as ``self.ratio`` (not read by this class)
        gmf (float): global momentum factor used in ``average``; 0 disables it
        mu (float): proximal coefficient; 0 disables the proximal term
        lr (float): local learning rate
        momentum (float): local momentum factor
        dampening (float): dampening for the local momentum
        weight_decay (float): L2 penalty added to the gradient
        nesterov (bool): enables Nesterov momentum
        variance: stored in the group defaults (not read by this class)
    """

    def __init__(self, params, ratio, gmf, mu=0, lr=0.01, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, variance=0):

        self.gmf = gmf
        self.ratio = ratio
        self.etamu = mu * lr
        self.mu = mu

        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov, variance=variance)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(fedavg, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(fedavg, self).__setstate__(state)
        # Restored states may lack the 'nesterov' key; default it to False.
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data

                # weight decay is folded into the gradient in place
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)

                param_state = self.state[p]
                # remember the parameters this round started from
                if 'old_init' not in param_state:
                    param_state['old_init'] = torch.clone(p.data).detach()

                local_lr = group['lr']

                # apply momentum updates
                if momentum != 0:
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                # apply proximal updates
                if self.etamu != 0:
                    d_p.add_(p.data - param_state['old_init'], alpha=self.mu)

                # accumulate the update applied since the last average()
                if 'cum_grad' not in param_state:
                    param_state['cum_grad'] = torch.clone(d_p).detach()
                    param_state['cum_grad'].mul_(local_lr)

                else:
                    param_state['cum_grad'].add_(d_p, alpha=local_lr)

                p.data.add_(d_p, alpha=-local_lr)

        return loss

    def average(self, weight):
        """Aggregate the accumulated local updates across workers.

        Each parameter's ``cum_grad`` is scaled by ``weight`` and reduced with
        ``dist.all_reduce``; ``old_init`` is then moved by the reduced update
        (through the global momentum buffer when ``gmf != 0``) and copied
        into the parameter.  ``cum_grad`` and any local momentum buffer are
        zeroed afterwards.

        Every parameter must already have ``cum_grad`` and ``old_init`` state,
        i.e. have gone through ``step`` with a gradient at least once.

        Arguments:
            weight: scale applied to this worker's accumulated update before
                the reduction.
        """
        param_list = []

        for group in self.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['cum_grad'].mul_(weight)
                param_list.append(param_state['cum_grad'])

        communicate(param_list, dist.all_reduce)

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                param_state = self.state[p]

                if self.gmf != 0:
                    if 'global_momentum_buffer' not in param_state:
                        buf = param_state['global_momentum_buffer'] = torch.clone(param_state['cum_grad']).detach()
                        buf.div_(lr)
                    else:
                        buf = param_state['global_momentum_buffer']
                        buf.mul_(self.gmf).add_(param_state['cum_grad'], alpha=1 / lr)
                    param_state['old_init'].sub_(buf, alpha=lr)
                else:
                    param_state['old_init'].sub_(param_state['cum_grad'])

                p.data.copy_(param_state['old_init'])
                param_state['cum_grad'].zero_()

                # Reinitialize momentum buffer
                if 'momentum_buffer' in param_state:
                    param_state['momentum_buffer'].zero_()

### train_dnn.py

In [None]:
import datetime

In [None]:
import numpy as np
import random

In [None]:
import time
import pathlib
import logging

In [None]:
import torch
import torch.distributed as dist
import torch.utils.data.distributed
import torch.nn as nn
import torch.backends.cudnn as cudnn
from distoptim import fedavg
import util_v4 as util
import models
from params import args_parser

In [None]:
# The root logger is configured at INFO level, so the debug call below (and
# any other logging.debug call) is filtered out and prints nothing, unless
# logging was already configured elsewhere before this line runs.
logging.basicConfig(format='%(levelname)s - %(message)s', level=logging.INFO)
logging.debug('This message should appear on the console')

In [None]:
# Built once at import time; run(), train(), evaluate_client() and
# init_processes() all read this module-level `args` as a global.
args = args_parser()

In [None]:
def run(rank, size):
    """Run federated training with client selection on one distributed worker.

    Per round, each rank trains the client assigned to it for ``args.localE``
    local epochs, the local losses and selected client ids are all-gathered,
    the optimizer averages the updates, the global model is evaluated, and
    rank 0 selects the clients for the next round and broadcasts them.
    Metrics are appended to a per-rank CSV file (``args.out_fname`` is
    overwritten with that file's path).

    Requires an initialised process group and CUDA (model and tensors are
    moved with ``.cuda()``).

    :param rank: rank of this process; indexes the broadcast client list and
        ``dataratios``
    :param size: passed to ``util.partition_dataset``; the number of clients
        per round is read from ``args.size``
    """
    # initiate experiments folder
    save_path = '/mnt/batch/tasks/shared/LS_root/mounts/clusters/tsiameh1/code/Users/tsiameh/dnn'
    fold = 'lr{:.4f}_bs{}_cp{}_a{:.2f}_e{}_r0_n{}_f{:.2f}/'.format(
        args.lr, args.bs, args.localE, args.alpha, args.seed, args.ensize, args.fracC)
    if args.commE:
        fold = 'com_' + fold
    # NOTE(review): no '/' is inserted between save_path and args.name, so the
    # experiment name is appended directly to 'dnn' -- confirm this is intended.
    folder_name = save_path + args.name + '/' + fold
    file_name = '{}_rr{:.2f}_dr{:.2f}_lr{:.3f}_bs{:d}_cp{:d}_a{:.2f}_e{}_r{}_n{}_f{:.2f}_p{}.csv'.format(
        args.seltype,
        args.rnd_ratio,
        args.delete_ratio,
        args.lr,
        args.bs,
        args.localE,
        args.alpha,
        args.seed,
        rank,
        args.ensize,
        args.fracC,
        args.powd)
    pathlib.Path(folder_name).mkdir(parents=True, exist_ok=True)

    # initiate log files
    saveFileName = folder_name + file_name
    args.out_fname = saveFileName
    with open(args.out_fname, 'w+') as f:
        print('BEGIN-TRAINING\n' 'World-Size,{ws}\n' 'Batch-Size,{bs}\n' 'Epoch,itr,'
              'loss,trainloss,avg:Loss,Prec@1,avg:Prec@1,val,trainval,updtime,comptime,seltime,entime'.format(
            ws=args.size, bs=args.bs), file=f)

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # load data
    partition, train_loader, test_loader, dataratios, datstat, endat = util.partition_dataset(size, args, 0)

    # initialization for client selection
    cli_loss, cli_freq, cli_val = np.zeros(args.ensize) + 1, np.zeros(args.ensize), np.zeros(args.ensize)

    # receive buffers for the per-round all_gather calls
    tmp_cli = [torch.tensor(0, dtype=torch.float32).cuda() for _ in range(dist.get_world_size())]
    tmp_clifreq = [torch.tensor(0).cuda() for _ in range(dist.get_world_size())]

    dist.barrier()
    # select client for each round, in total m ranks
    # Receive buffer for non-root ranks; int64 so it matches the dtype of the
    # torch.tensor(int) elements that rank 0 broadcasts.
    send = torch.zeros(args.size, dtype=torch.int64).cuda()
    if rank == 0:
        replace_param = False
        if args.seltype == 'rand':
            replace_param = True

        idxs_users = np.random.choice(args.ensize, size=args.size, replace=replace_param)
        send = [torch.tensor(int(ii)).cuda() for ii in idxs_users]
    dist.barrier()

    for i in range(args.size):
        dist.broadcast(tensor=send[i], src=0)
    dist.barrier()
    # client index this rank trains in the coming round
    sel_idx = int(send[rank])

    # define neural nets model, criterion, and optimizer
    if args.model == 'MLP':
        len_in = 1
        for x in args.img_size:
            len_in *= x
        model = models.MLP(dim_in=len_in, dim_hidden1=64, dim_hidden2=30, dim_out=args.num_classes).cuda()

    else:
        model = models.vgg11().cuda()  # vgg

    criterion = nn.NLLLoss().cuda()

    # select optimizer according to algorithm
    algorithms = {'fedavg': fedavg}

    selected_opt = algorithms[args.optimizer]
    optimizer = selected_opt(model.parameters(),
                             lr=args.lr,
                             gmf=args.gmf,  # set to 0
                             mu=args.mu,  # set to 0
                             ratio=dataratios[rank],
                             momentum=args.momentum,  # set to 0
                             nesterov=False,
                             weight_decay=1e-4)

    for rnd in range(args.rounds):

        # Initialize hyperparameters
        local_epochs = args.localE
        weight = 1 / args.size

        # Decay learning rate according to round index (optional)
        if args.decay:
            update_learning_rate(optimizer, rnd, args.lr)

        # Clients locally train for several local epochs
        loss_final = 0
        dist.barrier()
        comm_update_start = time.time()
        for t in range(local_epochs):
            singlebatch_loader = util.partitiondata_loader(partition, sel_idx, args.bs)
            loss = train(model, criterion, optimizer, singlebatch_loader, t)
            loss_final += loss / local_epochs
        dist.barrier()
        comm_update_end = time.time()
        update_time = comm_update_end - comm_update_start

        # Getting value function for client selection (required only for 'rpow-d', 'afl')
        dist.barrier()  # TODO: implement with multi-arm bandit
        dist.all_gather(tmp_cli, torch.tensor(loss_final).cuda())
        dist.all_gather(tmp_clifreq, torch.tensor(int(sel_idx)).cuda())
        dist.barrier()
        for i, i_val in enumerate(tmp_clifreq):
            cli_freq[i_val.item()] += 1  # Cli freq is the entire clients that are selected for all rounds
            cli_val[i_val.item()] = tmp_cli[i].item()
        not_visited = np.where(cli_freq == 0)[0]

        # clients never selected so far get a sentinel value
        for ii in not_visited:
            if args.seltype == 'afl':
                cli_val[ii] = -np.inf
            else:
                cli_val[ii] = np.inf

        # synchronize parameters
        dist.barrier()
        optimizer.average(weight=weight)
        dist.barrier()

        # evaluate test accuracy
        test_acc, test_loss = evaluate(model, test_loader, criterion)

        # evaluate loss values and sync selected frequency
        cli_loss, cli_comptime = evaluate_client(model, criterion, partition)
        train_loss = sum([cli_loss[i] * dataratios[i] for i in range(args.ensize)])
        train_loss1 = sum(cli_loss) / args.ensize

        dist.barrier()
        # Select client for each round, in total m ranks
        # int64 receive buffer, matching the tensors rank 0 broadcasts
        send = torch.zeros(args.size, dtype=torch.int64).cuda()
        comp_time, sel_time = 0, 0

        if rank == 0:
            sel_time_start = time.time()
            idxs_users, rnd_idx = util.sel_client(dataratios, cli_loss, cli_val, args, rnd)
            sel_time_end = time.time()
            sel_time = sel_time_end - sel_time_start

            if args.seltype == 'pow-d' or args.seltype == 'pow-dint':
                comp_time = max([cli_comptime[int(i)] for i in rnd_idx])

            send = [torch.tensor(int(ii)).cuda() for ii in idxs_users]
        dist.barrier()
        for i in range(args.size):
            dist.broadcast(tensor=send[i], src=0)
        dist.barrier()
        sel_idx = int(send[rank])

        # record metrics
        logging.info("Round {} rank {} test accuracy {:.3f} test loss {:.3f}".format(rnd, rank, test_acc, test_loss))
        with open(args.out_fname, '+a') as f:
            print('{ep},{itr},{loss:.4f},{trainloss:.4f},{filler},'
                  '{filler},{filler},'
                  '{val:.4f},{other:.4f},{updtime:.4f},{comptime:.4f},{seltime:.4f},{entime:.4f}'
                  .format(ep=rnd, itr=-1, loss=test_loss, trainloss=train_loss,
                          filler=-1, val=test_acc, other=train_loss1, updtime=update_time, comptime=comp_time,
                          seltime=sel_time, entime=update_time + comp_time + sel_time), file=f)

In [None]:
def evaluate_client(model, criterion, partition):
    """
    Compute, for every client, the loss of the current model on that client's data.

    The model is switched to eval mode.  When ``args.commE`` is set, only a
    random subset of at most ``args.bs`` samples per client is evaluated.
    Each client's data is loaded with a batch size equal to its full length.

    :param model: current global model
    :param criterion: loss function
    :param partition: dataset dict for clients
    :return: (list of per-client loss values, list of per-client wall-clock
        evaluation times), both indexed by client id
    """
    model.eval()
    losses, durations = [], []

    for client_id in range(args.ensize):
        client_data = partition.use(client_id)

        # cpow-d: evaluate on a small random subset instead of all the data
        if args.commE:
            picked = random.sample(range(len(client_data)),
                                   k=int(min(args.bs, len(client_data))))
            client_data = torch.utils.data.Subset(client_data, indices=picked)

        loader = torch.utils.data.DataLoader(client_data,
                                             batch_size=len(client_data),
                                             shuffle=False,
                                             pin_memory=True)

        # Average the per-batch losses and time the whole evaluation
        batch_losses = []
        with torch.no_grad():
            started = time.time()
            for data, target in loader:
                data = data.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)
                batch_losses.append(criterion(model(data), target).item())
            mean_loss = sum(batch_losses) / len(batch_losses)
            finished = time.time()

        durations.append(finished - started)
        losses.append(mean_loss)

    return losses, durations

In [None]:
def evaluate(model, test_loader, criterion):
    """
    Evaluate the model on a data loader.

    The model is switched to eval mode and batches are moved to the GPU.
    Both returned numbers are unweighted means over batches: accuracy is the
    mean of the per-batch accuracies (in percent) and loss the mean of the
    per-batch criterion values.

    :return: (accuracy in percent, mean loss)
    """
    model.eval()
    loss_sum, batches, acc_sum = 0.0, 0.0, 0.0

    with torch.no_grad():
        for data, target in test_loader:
            data = data.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # Inference
            outputs = model(data)
            loss_sum += criterion(outputs, target).item()

            # Prediction: index of the highest score along dim 1
            predicted = torch.max(outputs, 1)[1].view(-1)
            hits = torch.sum(torch.eq(predicted, target)).item()
            acc_sum += hits / len(predicted)
            batches += 1

        accuracy = (acc_sum / batches) * 100
        mean_loss = loss_sum / batches

    return accuracy, mean_loss

In [None]:
def train(model, criterion, optimizer, loader, epoch):
    """
    Train the model for one pass over ``loader``.

    For every batch: forward, backward, ``optimizer.step()`` and
    ``optimizer.zero_grad()``.  Running means of the per-batch loss and
    accuracy are kept; a CSV row with them is appended to ``args.out_fname``
    after the last batch, and also every ``args.print_freq`` batches when
    ``args.save`` is set.

    :param epoch: local epoch index, only used in the log output
    :return: mean of the per-batch losses over the pass
    """
    row_template = '{ep},{itr},{loss:.4f},-1,-1,{top1:.3f},-1,-1,-1,-1,-1,-1'

    model.train()
    running_loss, batches, running_acc = 0.0, 0.0, 0.0

    for batch_idx, (data, target) in enumerate(loader):
        # move the batch to the GPU
        data = data.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # forward and backward pass
        output = model(data)
        batch_loss = criterion(output, target)
        batch_loss.backward()

        # gradient step, then clear the gradients for the next batch
        optimizer.step()
        optimizer.zero_grad()

        # running statistics
        running_loss += batch_loss.item()
        predicted = torch.max(output, 1)[1]
        hits = torch.sum(torch.eq(predicted.view(-1), target)).item()
        running_acc += hits / len(predicted)
        batches += 1

        acc = (running_acc / batches) * 100
        los = running_loss / batches

        # periodic progress row (`rank` is read from module scope)
        if batch_idx % args.print_freq == 0:
            if args.save:
                logging.debug('epoch {} itr {}, rank {}, loss value {:.4f}, train accuracy {:.3f}'
                              .format(epoch, batch_idx, rank, los, acc))

                with open(args.out_fname, '+a') as f:
                    print(row_template.format(ep=epoch, itr=batch_idx, loss=los, top1=acc), file=f)

    # final row for this pass
    with open(args.out_fname, '+a') as f:
        print(row_template.format(ep=epoch, itr=batch_idx, loss=los, top1=acc), file=f)

    return los

In [None]:
def update_learning_rate(optimizer, epoch, target_lr):
    """
    Step-decay the learning rate at fixed milestones.

    At epoch 149 every param group's lr is set to ``target_lr / 2`` and at
    epoch 299 to ``target_lr / 4``; at any other epoch nothing is changed.
    ** note: target_lr is the reference learning rate from which to scale down
    """
    milestones = ((149, 2), (299, 4))
    for milestone, divisor in milestones:
        if epoch != milestone:
            continue
        lr = target_lr / divisor
        logging.info('Updating learning rate to {}'.format(lr))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

In [None]:
def init_processes(rank, size, fn):
    """ Set up the process group, then hand control to fn(rank, size).

    Backend and init method come from the module-level ``args``; the group
    timeout is fixed at five hours.
    """
    group_options = dict(
        backend=args.backend,
        timeout=datetime.timedelta(hours=5),
        init_method=args.initmethod,
        rank=rank,
        world_size=size,
    )
    dist.init_process_group(**group_options)
    fn(rank, size)

In [None]:
if __name__ == "__main__":
    # `rank` must remain a module-level name: train() reads it as a global
    # when formatting its debug log line.
    rank = args.rank
    size = args.size

    # Join the process group, then run training on this worker.
    init_processes(rank, size, run)