# FLDetector for CIFAR10 with Dirichlet distribution


In [1]:
# Widen the notebook container to 90% of the browser window.
# `display` and `HTML` are imported from their public home, `IPython.display`;
# importing them from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

  from IPython.core.display import display, HTML


In [6]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
import copy
import time
from functools import reduce
from torchsummary import summary

import os
import sys
import pickle
sys.path.insert(0,'./utils/')
from logger import *
from eval import *
from misc import *

from sklearn.metrics import roc_auc_score
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict

from SGD import *
import copy
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [7]:
from resnet import *

In [8]:
from cifar10_models import *
from cifar10_util import *

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.autograd import Variable

__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']

def _weights_init(m):
    classname = m.__class__.__name__
    #print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Adapter that lets an arbitrary callable be used as an ``nn.Module``."""

    def __init__(self, lambd):
        super().__init__()
        # The wrapped callable; it is invoked with the layer input unchanged.
        self.lambd = lambd

    def forward(self, x):
        """Return ``lambd(x)``."""
        fn = self.lambd
        return fn(x)


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut branch.

    The shortcut is the identity unless ``stride != 1`` or the channel count
    changes. In that case ``option`` selects it:

    * ``'A'`` - parameter-free: keep every second pixel in both spatial
      dimensions and zero-pad ``planes // 4`` channels on each side
      (the variant the ResNet paper uses for CIFAR10).
    * ``'B'`` - 1x1 strided convolution followed by batch-norm.

    Any other ``option`` value leaves the identity shortcut in place.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super().__init__()
        # Sub-modules are created in a fixed order so that the ordering of
        # state_dict entries (and of random initialisation) stays stable.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity by default (an empty Sequential returns its input).
        self.shortcut = nn.Sequential()
        shape_changes = stride != 1 or in_planes != planes
        if shape_changes and option == 'A':
            channel_pad = planes // 4
            self.shortcut = LambdaLayer(
                lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, channel_pad, channel_pad), "constant", 0))
        elif shape_changes and option == 'B':
            out_planes = self.expansion * planes
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        residual = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + residual)


class ResNet(nn.Module):
    """Three-stage ResNet for 3-channel images (16 -> 32 -> 64 feature maps).

    Args:
        block: Residual block class; it is called as
            ``block(in_planes, planes, stride)`` and must expose ``expansion``.
        num_blocks: Sequence with the number of blocks in each of the 3 stages.
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        # Stages 2 and 3 start with a stride-2 block.
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        # Re-initialise every conv/linear weight in the assembled network.
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1."""
        block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            feats = stage(feats)
        # Pool with a window equal to the feature-map width, then flatten.
        feats = F.avg_pool2d(feats, feats.size()[3])
        feats = feats.view(feats.size(0), -1)
        return self.linear(feats)


def resnet20():
    """Build a ``ResNet`` of ``BasicBlock``s with three blocks in each of its three stages."""
    blocks_per_stage = [3, 3, 3]
    return ResNet(BasicBlock, blocks_per_stage)

In [10]:
def sample_dirichlet_train_data(trainset, no_participants, alpha=0.9, force=False):
    """
    Split the sample indices of ``trainset`` across participants, class by class,
    using proportions drawn from a symmetric Dirichlet distribution.

    For every class a probability vector over the participants is drawn from
    ``Dirichlet(alpha, ..., alpha)``; each participant in turn takes
    ``round(class_size * p_user)`` of the (shuffled) indices still unassigned
    in that class. Because of rounding, a few indices of a class may remain
    unassigned.

    The result is cached in ``./dirichlet_a_<alpha>_nusers_<n>.pkl`` and
    re-used by later calls unless ``force`` is true.

    Args:
        trainset: Iterable of ``(sample, label)`` pairs. Labels are expected to
            be the integers ``0 .. C-1``.
        no_participants: Number of participants to split the data across.
        alpha: Dirichlet concentration parameter.
        force: Regenerate the split even if a cache file exists.

    Returns:
        A dict mapping participant id (``0 .. no_participants-1``) to the list
        of dataset indices assigned to it.

    NOTE(review): ``np.random`` is re-seeded here (a global side effect), but
    the shuffle uses Python's ``random`` module, which is not seeded - a
    regenerated split is therefore not guaranteed to be reproducible.
    """
    cache_path = './dirichlet_a_%.1f_nusers_%d.pkl' % (alpha, no_participants)

    if os.path.exists(cache_path) and not force:
        # The cache is a local file written by this function; pickle must
        # never be used on files from an untrusted source.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    print('generating participant indices for alpha %.1f'%alpha)
    np.random.seed(0)

    # Group dataset indices by class label.
    cifar_classes = {}
    for ind, x in enumerate(trainset):
        _, label = x
        cifar_classes.setdefault(label, []).append(ind)

    per_participant_list = defaultdict(list)
    no_classes = len(cifar_classes)
    for n in range(no_classes):
        random.shuffle(cifar_classes[n])
        # Expected number of images of class n for each participant.
        sampled_probabilities = len(cifar_classes[n]) * np.random.dirichlet(
            np.array(no_participants * [alpha]))
        for user in range(no_participants):
            no_imgs = int(round(sampled_probabilities[user]))
            # Never hand out more than what is left of this class.
            take = min(len(cifar_classes[n]), no_imgs)
            per_participant_list[user].extend(cifar_classes[n][:take])
            cifar_classes[n] = cifar_classes[n][take:]

    with open(cache_path, 'wb') as f:
        pickle.dump(per_participant_list, f)

    return per_participant_list

In [11]:
def get_fang_train_data(trainset, num_workers=100, bias=0.5, num_classes=10):
    """Assign every sample to a worker with a label-biased random rule.

    Workers are divided into ``num_classes`` equally sized groups. A sample
    with label ``y`` is sent to group ``y`` with probability ``bias`` and to
    each other group with probability ``(1 - bias) / (num_classes - 1)``; the
    worker inside the group is then picked uniformly at random.

    Randomness comes from the global ``np.random`` state, which is not seeded
    here.

    Args:
        trainset: Iterable of ``(sample, label)`` pairs with integer labels in
            ``0 .. num_classes-1``.
        num_workers: Total number of workers.
        bias: Probability that a sample goes to the group matching its label.
        num_classes: Number of label classes / worker groups (must be >= 2).

    Returns:
        A ``defaultdict(list)`` mapping worker id to the list of dataset
        indices assigned to it. Workers that received no sample have no key.
    """
    other_group_size = (1 - bias) / (num_classes - 1)
    worker_per_group = num_workers / num_classes

    per_participant_list = defaultdict(list)
    for i, (_, y) in enumerate(trainset):
        # [lower_bound, upper_bound] is the slice of [0, 1) that maps to the
        # sample's own group; the rest is split evenly among the other groups.
        upper_bound = y * (1 - bias) / (num_classes - 1) + bias
        lower_bound = y * (1 - bias) / (num_classes - 1)
        rd = np.random.random_sample()

        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y
        # Guard against floating-point round-off pushing the group past the end.
        worker_group = min(worker_group, num_classes - 1)

        rd = np.random.random_sample()
        selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))
        selected_worker = min(selected_worker, num_workers - 1)

        per_participant_list[selected_worker].append(i)

    return per_participant_list

In [12]:
def get_federated_data(trainset, num_workers, distribution='fang', param=1, force=False):
    """Partition ``trainset`` across workers and split each share into train/test.

    Each worker's indices are split at random, without replacement, into
    ``int(6 * n / 7)`` training indices and the remaining test indices.
    ``np.random`` is re-seeded with 0 before the split (a global side effect).

    Args:
        trainset: Dataset yielding ``(sample, label)`` pairs.
        num_workers: Number of workers.
        distribution: ``'fang'`` or ``'dirichlet'`` - which partitioner to use.
        param: ``bias`` for ``'fang'``, ``alpha`` for ``'dirichlet'``.
        force: Forwarded to the Dirichlet partitioner (regenerate its cache).

    Returns:
        ``(each_worker_idx, each_worker_te_idx, global_test_idx)``: two lists
        of length ``num_workers`` holding integer index arrays, and the
        concatenation of all per-worker test indices.

    Raises:
        ValueError: If ``distribution`` is not one of the supported names.
    """
    if distribution == 'fang':
        per_participant_list = get_fang_train_data(trainset, num_workers, bias=param)
    elif distribution == 'dirichlet':
        per_participant_list = sample_dirichlet_train_data(trainset, num_workers, alpha=param, force=force)
    else:
        raise ValueError("unknown distribution %r (expected 'fang' or 'dirichlet')" % (distribution,))

    each_worker_idx = []
    each_worker_te_idx = []

    np.random.seed(0)

    # Walk over worker ids rather than over the dict length: a partitioner may
    # leave out workers that received no data.
    for worker_idx in range(num_workers):
        # dtype=int keeps an empty share from becoming a float array, which
        # would turn the concatenated test indices into floats.
        w_indices = np.array(per_participant_list.get(worker_idx, []), dtype=int)
        w_len = len(w_indices)
        len_tr = int(6*w_len/7)
        tr_idx = np.random.choice(w_len, len_tr, replace=False)
        te_idx = np.delete(np.arange(w_len), tr_idx)

        each_worker_idx.append(w_indices[tr_idx])
        each_worker_te_idx.append(w_indices[te_idx])

    global_test_idx = np.concatenate(each_worker_te_idx)

    return each_worker_idx, each_worker_te_idx, global_test_idx

In [13]:
def get_train(dataset, indices, batch_size=32, shuffle=False):
    """Return a ``DataLoader`` over the samples of ``dataset`` listed in ``indices``.

    The samples are drawn in random order through a ``SubsetRandomSampler``.
    ``shuffle`` is accepted but not used.
    """
    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=subset_sampler)

In [14]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Root directory the CIFAR10 files are downloaded to / read from.
data_loc='/home/vshejwalkar_umass_edu/data/'

# Training-time preprocessing: random horizontal flip, then tensor conversion
# and per-channel normalisation.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation-time preprocessing: the same normalisation without augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Both official splits wrapped with the TRAINING transform. Note that the
# official test split (train=False) also uses transform_train here.
cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=transform_train)
cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=transform_train)

# The same two splits again, wrapped with the EVALUATION transform.
te_cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=transform_test)
te_cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [15]:
def lbfgs(S_k_list, Y_k_list, v):
    """Approximate a Hessian-vector product from stored difference vectors.

    Builds the L-BFGS approximation in its compact matrix form,
    ``B = sigma*I - [sigma*S, Y] @ inv(M) @ [sigma*S, Y]^T`` with
    ``M = [[sigma*S^T S, L], [L^T, -D]]``, and returns ``B @ v``. Here ``L``
    and ``D`` are the strictly-lower and diagonal parts of ``S^T Y`` and
    ``sigma = (y_last . s_last) / (s_last . s_last)``.

    Args:
        S_k_list: List of 1-D tensors (one per stored step), all the same
            length. # presumably model-weight differences - confirm with caller
        Y_k_list: List of 1-D tensors matching ``S_k_list``.
            # presumably gradient differences - confirm with caller
        v: 1-D tensor to multiply, same length as the stored vectors.

    Returns:
        ``numpy.ndarray`` of shape ``(d, 1)``.

    Raises:
        numpy.linalg.LinAlgError: If the middle matrix ``M`` is singular.
    """
    # Move every operand to host memory exactly once.
    S = torch.stack(S_k_list).T.cpu().numpy()              # (d, m)
    Y = torch.stack(Y_k_list).T.cpu().numpy()              # (d, m)
    s_last = S_k_list[-1].unsqueeze(0).cpu().numpy()       # (1, d)
    y_last = Y_k_list[-1].unsqueeze(0).cpu().numpy()       # (1, d)
    v_row = v.cpu().numpy()                                # (d,)
    v_col = v.unsqueeze(0).T.cpu().numpy()                 # (d, 1)

    S_k_time_Y_k = np.dot(S.T, Y)
    S_k_time_S_k = np.dot(S.T, S)
    R_k = np.triu(S_k_time_Y_k)
    L_k = S_k_time_Y_k - R_k                               # strictly lower part
    # Scaling of the initial approximation, taken from the most recent pair.
    sigma_k = np.dot(y_last, s_last.T) / np.dot(s_last, s_last.T)   # (1, 1)
    D_k_diag = np.diag(S_k_time_Y_k)

    upper_mat = np.concatenate((sigma_k * S_k_time_S_k, L_k), axis=1)
    lower_mat = np.concatenate((L_k.T, -np.diag(D_k_diag)), axis=1)
    mat = np.concatenate((upper_mat, lower_mat), axis=0)   # (2m, 2m)
    mat_inv = np.linalg.inv(mat)

    # sigma * v, laid out as a column vector.
    approx_prod = (sigma_k * v_row).T
    p_mat = np.concatenate((np.dot(S.T, sigma_k * v_col), np.dot(Y.T, v_col)), axis=0)
    approx_prod -= np.dot(np.dot(np.concatenate((sigma_k * S, Y), axis=1), mat_inv), p_mat)

    return approx_prod

In [16]:
def full_trim(v, f):
    '''
    Full-knowledge Trim attack. w.l.o.g., we assume the first f worker devices are compromised.

    For every coordinate the attack looks at the sign of the sum of all
    updates and crafts a value that points the other way: starting from the
    per-coordinate minimum (for a positive sum) or maximum (for a negative
    sum), it halves the value when it has the same sign as the sum and
    doubles it otherwise. The first ``f`` rows of ``v`` are overwritten with
    this crafted vector.

    v: 2-D tensor of squeezed gradients, one row per worker; modified in place.
    f: the number of compromised worker devices (rows to overwrite).

    Returns ``v`` (the same tensor object).
    Raises IndexError if ``f`` exceeds the number of rows of ``v``.
    '''
    vi_shape = v[0].unsqueeze(0).T.shape
    v_tran = v.T

    # Per-coordinate extremes and the sign of the summed update.
    maximum_dim = torch.max(v_tran, dim=1)[0].reshape(vi_shape)
    minimum_dim = torch.min(v_tran, dim=1)[0].reshape(vi_shape)
    direction = torch.sign(torch.sum(v_tran, dim=-1, keepdims=True))
    directed_dim = (direction > 0) * minimum_dim + (direction < 0) * maximum_dim

    # Fixed scaling factor (the randomised factor in [1, 2) is disabled).
    random_12 = 2
    # The crafted vector depends only on quantities computed above, so it is
    # the same for every compromised worker and is built once.
    tmp = directed_dim * ((direction * directed_dim > 0) / random_12 + (direction * directed_dim < 0) * random_12)
    tmp = tmp.squeeze()

    # Overwrite exactly the f compromised rows, as the contract states,
    # instead of a hard-coded 20.
    for i in range(f):
        v[i] = tmp
    return v

In [17]:
def tr_mean(all_updates, n_attackers):
    """Coordinate-wise trimmed mean over the rows of ``all_updates``.

    Each column is sorted independently; the ``n_attackers`` smallest and the
    ``n_attackers`` largest values are dropped and the rest averaged. With a
    falsy ``n_attackers`` this is the plain column mean.
    """
    ranked, _ = torch.sort(all_updates, 0)
    if not n_attackers:
        return torch.mean(ranked, 0)
    kept = ranked[n_attackers:-n_attackers]
    return torch.mean(kept, 0)

In [18]:
def simple_mean(old_gradients, user_grads, b=0, hvp=None):
    """Aggregate client updates by their mean and optionally score each client.

    Args:
        old_gradients: Previous-round updates, one row per client. Only used
            when ``hvp`` is given. # assumes a 2-D tensor shaped like user_grads - confirm
        user_grads: 2-D tensor of current updates, one row per client.
        b: Unused; kept for interface compatibility.
        hvp: Optional ``numpy`` array added to every row of ``old_gradients``
            to predict the current updates.

    Returns:
        ``(agg_grads, distance)``: the column-wise mean of ``user_grads`` and,
        if ``hvp`` was given, a ``numpy`` array holding each client's
        prediction error (L2 norm) normalised to sum to 1; otherwise ``None``.
    """
    distance = None
    if hvp is not None:
        hvp = torch.from_numpy(hvp).to(device)
        # Predicted update = previous update + Hessian-vector product.
        pred_grad = copy.deepcopy(old_gradients)
        for i in range(len(old_gradients)):
            pred_grad[i] += hvp
        distance = torch.norm(pred_grad - user_grads, dim = 1).cpu().numpy()
        distance = distance / np.sum(distance)

    agg_grads = torch.mean(user_grads,dim=0)

    return agg_grads, distance

In [19]:
def median(old_gradients, user_grads, b=0, hvp=None):
    """Aggregate client updates by their coordinate-wise median and optionally score each client.

    Args:
        old_gradients: Previous-round updates, one row per client. Only used
            when ``hvp`` is given. # assumes a 2-D tensor shaped like user_grads - confirm
        user_grads: 2-D tensor of current updates, one row per client.
        b: Unused; kept for interface compatibility.
        hvp: Optional ``numpy`` array added to every row of ``old_gradients``
            to predict the current updates.

    Returns:
        ``(agg_grads, distance)``: the column-wise ``torch.median`` of
        ``user_grads`` and, if ``hvp`` was given, a ``numpy`` array holding
        each client's prediction error (L2 norm) normalised to sum to 1;
        otherwise ``None``.
    """
    distance = None
    if hvp is not None:
        hvp = torch.from_numpy(hvp).to(device)
        # Predicted update = previous update + Hessian-vector product.
        pred_grad = copy.deepcopy(old_gradients)
        for i in range(len(old_gradients)):
            pred_grad[i] += hvp
        distance = torch.norm(pred_grad - user_grads, dim = 1).cpu().numpy()
        distance = distance / np.sum(distance)

    agg_grads = torch.median(user_grads, 0)[0]

    return agg_grads, distance

In [20]:
def trimmed_mean(old_gradients, user_grads, b=0, hvp=None):
    """Aggregate client updates by a trimmed mean and optionally score each client.

    Args:
        old_gradients: Previous-round updates, one row per client. Only used
            when ``hvp`` is given. # assumes a 2-D tensor shaped like user_grads - confirm
        user_grads: 2-D tensor of current updates, one row per client.
        b: Unused; kept for interface compatibility.
        hvp: Optional ``numpy`` array added to every row of ``old_gradients``
            to predict the current updates.

    Returns:
        ``(agg_grads, distance)``: ``tr_mean(user_grads, 20)`` and, if ``hvp``
        was given, a ``numpy`` array holding each client's prediction error
        (L2 norm) normalised to sum to 1; otherwise ``None``.

    NOTE(review): the number of values trimmed from each end is fixed at 20
    regardless of ``b`` - confirm this is intended.
    """
    distance = None
    if hvp is not None:
        hvp = torch.from_numpy(hvp).to(device)
        # Predicted update = previous update + Hessian-vector product.
        pred_grad = copy.deepcopy(old_gradients)
        for i in range(len(old_gradients)):
            pred_grad[i] += hvp
        distance = torch.norm(pred_grad - user_grads, dim = 1).cpu().numpy()
        distance = distance / np.sum(distance)

    agg_grads = tr_mean(user_grads, 20)

    return agg_grads, distance

In [21]:
def detection(score, nobyz):
    """Cluster the scores into two groups and print detection metrics.

    The cluster with the higher mean score is labelled 0 (malicious), the
    other 1 (benign). Ground truth assumes the first ``nobyz`` entries of
    ``score`` belong to malicious clients. Prints accuracy, recall, false
    positive rate, false negative rate and the silhouette score; returns
    nothing.

    Args:
        score: 1-D ``numpy`` array with one suspicion score per client.
        nobyz: Number of malicious clients (must satisfy
            ``0 < nobyz < len(score)`` for the ratios to be defined).
    """
    # Population size comes from the data rather than a hard-coded 100.
    n_clients = len(score)

    estimator = KMeans(n_clusters=2)
    estimator.fit(score.reshape(-1, 1))
    label_pred = estimator.labels_
    if np.mean(score[label_pred==0])<np.mean(score[label_pred==1]):
        #0 is the label of malicious clients
        label_pred = 1 - label_pred

    real_label=np.ones(n_clients)
    real_label[:nobyz]=0
    acc=len(label_pred[label_pred==real_label])/n_clients
    # Predicted label 1 means "benign", so sums count clients judged benign.
    recall=1-np.sum(label_pred[:nobyz])/nobyz
    fpr=1-np.sum(label_pred[nobyz:])/(n_clients-nobyz)
    fnr=np.sum(label_pred[:nobyz])/nobyz
    print("acc %0.4f; recall %0.4f; fpr %0.4f; fnr %0.4f;" % (acc, recall, fpr, fnr))
    print(silhouette_score(score.reshape(-1, 1), label_pred))

def _within_cluster_dispersion(points, estimator):
    """Sum of squared distances from each 1-D point to its assigned cluster centre."""
    centers = estimator.cluster_centers_
    return np.sum(np.square(points - centers[estimator.labels_, 0]))


def detection1(score, nobyz):
    """Decide whether the scores form more than one cluster, using a gap statistic.

    Scores are min-max normalised to [0, 1]. For k = 1..7 the within-cluster
    dispersion of a k-means fit is compared with that of ``nrefs`` uniform
    reference samples. The first k with
    ``gap(k) - gap(k+1) + sd(k+1) >= 0`` is selected; selecting k = 1 means
    no attack.

    Args:
        score: 1-D ``numpy`` array with one suspicion score per client.
        nobyz: Unused; kept for interface compatibility.

    Returns:
        ``0`` if no attack is detected, ``1`` otherwise. A message is printed
        in both cases.
    """
    nrefs = 10
    ks = range(1, 8)
    gaps = np.zeros(len(ks))
    gapDiff = np.zeros(len(ks) - 1)
    sdk = np.zeros(len(ks))

    lo = np.min(score)
    hi = np.max(score)
    if hi == lo:
        # All scores are identical: nothing to cluster, and normalising
        # would divide by zero.
        print('No attack detected!')
        return 0
    score = (score - lo)/(hi-lo)

    for i, k in enumerate(ks):
        estimator = KMeans(n_clusters=k)
        estimator.fit(score.reshape(-1, 1))
        Wk = _within_cluster_dispersion(score, estimator)

        # Dispersion of uniform reference data clustered with the same k.
        WkRef = np.zeros(nrefs)
        for j in range(nrefs):
            rand = np.random.uniform(0, 1, len(score))
            estimator = KMeans(n_clusters=k)
            estimator.fit(rand.reshape(-1, 1))
            WkRef[j] = _within_cluster_dispersion(rand, estimator)
        gaps[i] = np.log(np.mean(WkRef)) - np.log(Wk)
        sdk[i] = np.sqrt((1.0 + nrefs) / nrefs) * np.std(np.log(WkRef))

        if i > 0:
            gapDiff[i - 1] = gaps[i - 1] - gaps[i] + sdk[i]

    # If no k satisfies the criterion, fall back to the largest candidate
    # (reported as an attack) instead of failing on an unbound variable.
    select_k = ks[-1]
    for i in range(len(gapDiff)):
        if gapDiff[i] >= 0:
            select_k = i+1
            break

    if select_k == 1:
        print('No attack detected!')
        return 0
    else:
        print('Attack Detected!')
        return 1

In [22]:
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def train(trainloader, model, model_received, criterion, optimizer, pgd=False, eps=2):
    """Run one pass over ``trainloader``, updating ``model`` in place.

    Args:
        trainloader: Iterable of ``(inputs, targets)`` batches.
        model: Network to train.
        model_received: Flat reference vector used only when ``pgd`` is set.
            # NOTE(review): it is compared against the flattened
            # model.parameters(), so it must have that length - confirm
        criterion: Loss function called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        pgd: If true, after every step project the parameters back onto the
            L2 ball of radius ``eps`` around ``model_received``.
        eps: Radius of that ball.

    Returns:
        ``(average loss, average top-1 accuracy as a fraction)``.
    """
    model.train()

    loss_meter, top1_meter, top5_meter = AverageMeter(), AverageMeter(), AverageMeter()

    for inputs, targets in trainloader:
        inputs = inputs.to(device, torch.float)
        targets = targets.to(device, torch.long)
        n_samples = inputs.size()[0]

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Track running loss / accuracy before the weights change.
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        loss_meter.update(loss.item(), n_samples)
        top1_meter.update(prec1.item()/100.0, n_samples)
        top5_meter.update(prec5.item()/100.0, n_samples)

        # Standard SGD step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if not pgd:
            continue

        # Projection: rescale the offset from the reference when it is too long.
        local_params = list(model.parameters())
        local_vec = parameters_to_vector(local_params)
        offset = local_vec - model_received
        offset_norm = torch.norm(offset)
        if offset_norm > eps:
            projected = eps*offset/offset_norm + model_received
            vector_to_parameters(projected, local_params)

    return (loss_meter.avg, top1_meter.avg)

def test(testloader, model, criterion):
    """Evaluate ``model`` on every batch of ``testloader``.

    Args:
        testloader: Iterable of ``(inputs, targets)`` batches.
        model: Network to evaluate; switched to eval mode.
        criterion: Loss function called as ``criterion(outputs, targets)``.

    Returns:
        ``(average loss, average top-1 accuracy as a fraction)``, as
        accumulated by ``AverageMeter``.
    """
    model.eval()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # Nothing is back-propagated during evaluation, so skip building the
    # autograd graph; this saves memory and time on every forward pass.
    with torch.no_grad():
        for batch_ind, (inputs, targets) in enumerate(testloader):
            inputs = inputs.to(device, torch.float)
            targets = targets.to(device, torch.long)
            outputs = model(inputs)

            loss = criterion(outputs, targets)
            # measure accuracy and record loss
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.data, inputs.size()[0])
            top1.update(prec1/100.0, inputs.size()[0])
            top5.update(prec5/100.0, inputs.size()[0])
    return (losses.avg, top1.avg)

# Dirichlet: Find a good CIFAR10 baseline (with 72% clients)

In [24]:
# Federated-averaging baseline without any attack: Dirichlet split with
# alpha = 0.1, 2 local epochs, batch size 16, momentum SGD on the clients.
torch.cuda.empty_cache()
# --- hyper-parameters (nepochs is assigned twice with the same value) ---
nepochs=2000
local_epochs = 2
batch_size = 16
num_workers = 100

local_lr = 0.1
global_lr = 1
nepochs = 2000

# The first nbyz worker ids are reserved for attackers and never train in
# this cell; byz_type and aggregation are set but not used below.
nbyz = 28
byz_type = 'full_trim'
aggregation = 'trim'


# Same 60k samples twice: once with the training transform, once with the
# evaluation transform.
all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))

num_workers = 100
distribution='dirichlet'
param = .1
# force=True regenerates the partition instead of loading the cached pickle.
force = True
each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
    all_data, num_workers=100, distribution=distribution, param=param, force=force)
# One (worker id, DataLoader) pair per worker over that worker's train indices.
train_loaders = []
for pos, indices in enumerate(each_worker_idx):
    batch_size = batch_size
    train_loaders.append((pos, get_train(all_data, indices, batch_size)))
# test_loaders = []
# for pos, indices in each_worker_te_idx.items():
#     batch_size = batch_size
#     train_loaders.append((pos, get_train(all_test_data, indices, len(indices))))
# Global evaluation loader over every worker's held-out indices
# (get_train's default batch size).
cifar10_test_loader = get_train(all_test_data, global_test_idx)

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

resume=False
round_nclients = num_workers
best_global_acc=0
epoch_num = 0

fed_model = resnet20().cuda()

# Flatten every state_dict entry of the global model into one float vector.
# NOTE: the loop variable `param` overwrites the Dirichlet `param` set above.
model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

# One iteration = one federated round. The condition is inclusive, so
# nepochs + 1 rounds are run.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    # Only workers nbyz .. num_workers-1 (the benign ones) take part.
    round_clients = np.arange(nbyz, num_workers)
    round_benign = round_clients
    user_updates=[]
    benign_norm = 0
    for i in round_benign:
        # Each client starts from a copy of the current global model, with a
        # learning rate decayed by 0.999 per round.
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.999**epoch_num), momentum=0.9, weight_decay=1e-4)
        # optimizer = optim.SGD(model.parameters(), lr = 1)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[i][1], model, model_received, criterion, optimizer)

        # Flatten the trained client model the same way as model_received.
        # NOTE: this inner loop reuses the names `i` and `param`.
        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        # Client update = trained weights - received weights; benign_norm
        # accumulates the mean update norm over the round.
        update =  (params - model_received)
        benign_norm += torch.norm(update)/len(round_benign)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)

    # Plain (unweighted) mean of the client updates.
    agg_update = torch.mean(user_updates, 0)

    del user_updates

    # Apply the aggregated update, then rebuild a model from the flat vector
    # by slicing it back into the shape of each state_dict entry.
    model_received = model_received + global_lr * agg_update
    fed_model = resnet20().cuda()
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name

    fed_model.load_state_dict(state_dict)
    # Evaluate the new global model every round (is_best is not used further).
    val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)

    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))

    # Stop when training has diverged.
    # NOTE(review): `math` is not imported in the visible cells - presumably
    # it arrives through one of the star imports; confirm.
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break

    epoch_num+=1

generating participant indices for alpha 0.1
e 0 benign_norm 15978.102 val loss 2.318 val acc 0.099 best val_acc 0.099
e 10 benign_norm 287.626 val loss 1.930 val acc 0.322 best val_acc 0.322
e 20 benign_norm 284.817 val loss 1.639 val acc 0.437 best val_acc 0.437
e 30 benign_norm 284.760 val loss 1.476 val acc 0.502 best val_acc 0.502
e 40 benign_norm 284.818 val loss 1.342 val acc 0.553 best val_acc 0.557
e 50 benign_norm 285.020 val loss 1.239 val acc 0.588 best val_acc 0.594
e 60 benign_norm 284.978 val loss 1.156 val acc 0.626 best val_acc 0.626
e 70 benign_norm 285.150 val loss 1.081 val acc 0.649 best val_acc 0.649
e 80 benign_norm 285.266 val loss 1.028 val acc 0.663 best val_acc 0.666
e 90 benign_norm 285.471 val loss 0.999 val acc 0.678 best val_acc 0.679
e 100 benign_norm 285.492 val loss 0.944 val acc 0.696 best val_acc 0.696
e 110 benign_norm 285.611 val loss 0.908 val acc 0.713 best val_acc 0.713
e 120 benign_norm 285.742 val loss 0.888 val acc 0.715 best val_acc 0.715
e 

KeyboardInterrupt: 

In [25]:
# Federated-averaging baseline without any attack: Dirichlet split with
# alpha = 0.5, 2 local epochs, batch size 8, momentum SGD on the clients.
torch.cuda.empty_cache()
# --- hyper-parameters (nepochs is assigned twice with the same value) ---
nepochs=200
local_epochs = 2
batch_size = 8
num_workers = 100

local_lr = 0.1
global_lr = 1
nepochs = 200

# The first nbyz worker ids are reserved for attackers and never train in
# this cell; byz_type and aggregation are set but not used below.
nbyz = 28
byz_type = 'full_trim'
aggregation = 'trim'


# Same 60k samples twice: once with the training transform, once with the
# evaluation transform.
all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))

num_workers = 100
distribution='dirichlet'
param = .5
# force=True regenerates the partition instead of loading the cached pickle.
force = True
each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
    all_data, num_workers=100, distribution=distribution, param=param, force=force)
# One (worker id, DataLoader) pair per worker over that worker's train indices.
train_loaders = []
for pos, indices in enumerate(each_worker_idx):
    batch_size = batch_size
    train_loaders.append((pos, get_train(all_data, indices, batch_size)))
# test_loaders = []
# for pos, indices in each_worker_te_idx.items():
#     batch_size = batch_size
#     train_loaders.append((pos, get_train(all_test_data, indices, len(indices))))
# Global evaluation loader over every worker's held-out indices
# (get_train's default batch size).
cifar10_test_loader = get_train(all_test_data, global_test_idx)

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

resume=False
round_nclients = num_workers
best_global_acc=0
epoch_num = 0

fed_model = resnet20().cuda()

# Flatten every state_dict entry of the global model into one float vector.
# NOTE: the loop variable `param` overwrites the Dirichlet `param` set above.
model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

# One iteration = one federated round. The condition is inclusive, so
# nepochs + 1 rounds are run.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    # Only workers nbyz .. num_workers-1 (the benign ones) take part.
    round_clients = np.arange(nbyz, num_workers)
    round_benign = round_clients
    user_updates=[]
    benign_norm = 0
    for i in round_benign:
        # Each client starts from a copy of the current global model, with a
        # learning rate decayed by 0.999 per round.
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.999**epoch_num), momentum=0.9, weight_decay=1e-4)
        # optimizer = optim.SGD(model.parameters(), lr = 1)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[i][1], model, model_received, criterion, optimizer)

        # Flatten the trained client model the same way as model_received.
        # NOTE: this inner loop reuses the names `i` and `param`.
        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        # Client update = trained weights - received weights; benign_norm
        # accumulates the mean update norm over the round.
        update =  (params - model_received)
        benign_norm += torch.norm(update)/len(round_benign)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)

    # Plain (unweighted) mean of the client updates.
    agg_update = torch.mean(user_updates, 0)

    del user_updates

    # Apply the aggregated update, then rebuild a model from the flat vector
    # by slicing it back into the shape of each state_dict entry.
    model_received = model_received + global_lr * agg_update
    fed_model = resnet20().cuda()
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name

    fed_model.load_state_dict(state_dict)
    # Evaluate the new global model every round (is_best is not used further).
    val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)

    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))

    # Stop when training has diverged.
    # NOTE(review): `math` is not imported in the visible cells - presumably
    # it arrives through one of the star imports; confirm.
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break

    epoch_num+=1

generating participant indices for alpha 0.5
e 0 benign_norm 2415.255 val loss 2.305 val acc 0.099 best val_acc 0.099
e 10 benign_norm 572.385 val loss 1.694 val acc 0.387 best val_acc 0.387
e 20 benign_norm 572.360 val loss 1.323 val acc 0.526 best val_acc 0.526
e 30 benign_norm 573.014 val loss 1.121 val acc 0.601 best val_acc 0.601
e 40 benign_norm 573.828 val loss 0.953 val acc 0.663 best val_acc 0.663
e 50 benign_norm 574.615 val loss 0.841 val acc 0.703 best val_acc 0.703
e 60 benign_norm 575.766 val loss 0.765 val acc 0.736 best val_acc 0.736
e 70 benign_norm 577.174 val loss 0.714 val acc 0.754 best val_acc 0.754
e 80 benign_norm 577.915 val loss 0.678 val acc 0.768 best val_acc 0.768
e 90 benign_norm 578.681 val loss 0.651 val acc 0.778 best val_acc 0.780
e 100 benign_norm 578.745 val loss 0.625 val acc 0.786 best val_acc 0.788
e 110 benign_norm 579.562 val loss 0.607 val acc 0.791 best val_acc 0.794
e 120 benign_norm 579.342 val loss 0.597 val acc 0.798 best val_acc 0.799
e 1

# Dirichlet: Slow baselines for CIFAR10 (with 72% clients) similar to original work

In [None]:
# "Slow" baseline without any attack: Dirichlet split with alpha = 0.1, one
# full-batch local step per round with plain SGD, and a decaying global
# learning rate.
torch.cuda.empty_cache()
# --- hyper-parameters (nepochs is assigned twice with the same value) ---
nepochs=2000
local_epochs = 1
# batch_size is set but the loaders below use each worker's full share.
batch_size = 32
num_workers = 100

local_lr = 1
global_lr = .1
nepochs = 2000

# The first nbyz worker ids are reserved for attackers and never train in
# this cell; byz_type and aggregation are set but not used below.
nbyz = 28
byz_type = 'full_trim'
aggregation = 'trim'


# Same 60k samples twice: once with the training transform, once with the
# evaluation transform.
all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))

num_workers = 100
distribution='dirichlet'
param = .1
# force=True regenerates the partition instead of loading the cached pickle.
force = True
each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
    all_data, num_workers=100, distribution=distribution, param=param, force=force)
# One (worker id, DataLoader) pair per worker; the batch size equals the
# number of training indices, so each loader yields a single batch.
train_loaders = []
for pos, indices in enumerate(each_worker_idx):
    batch_size = batch_size
    train_loaders.append((pos, get_train(all_data, indices, len(indices))))
# test_loaders = []
# for pos, indices in each_worker_te_idx.items():
#     batch_size = batch_size
#     train_loaders.append((pos, get_train(all_test_data, indices, len(indices))))
# Global evaluation loader over every worker's held-out indices
# (get_train's default batch size).
cifar10_test_loader = get_train(all_test_data, global_test_idx)

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

resume=False
round_nclients = num_workers
best_global_acc=0
epoch_num = 0

fed_model = resnet20().cuda()

# Flatten every state_dict entry of the global model into one float vector.
# NOTE: the loop variable `param` overwrites the Dirichlet `param` set above.
model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

# One iteration = one federated round. The condition is inclusive, so
# nepochs + 1 rounds are run.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    # Only workers nbyz .. num_workers-1 (the benign ones) take part.
    round_clients = np.arange(nbyz, num_workers)
    round_benign = round_clients
    user_updates=[]
    benign_norm = 0
    for i in round_benign:
        # Each client starts from a copy of the current global model and uses
        # plain SGD (no momentum, no weight decay) at a constant rate.
        model = copy.deepcopy(fed_model)
#         optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.99**epoch_num), momentum=0.9, weight_decay=1e-4)
        optimizer = optim.SGD(model.parameters(), lr = local_lr)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[i][1], model, model_received, criterion, optimizer)

        # Flatten the trained client model the same way as model_received.
        # NOTE: this inner loop reuses the names `i` and `param`.
        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        # Client update = trained weights - received weights; benign_norm
        # accumulates the mean update norm over the round.
        update =  (params - model_received)
        benign_norm += torch.norm(update)/len(round_benign)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)

    # Plain (unweighted) mean of the client updates.
    agg_update = torch.mean(user_updates, 0)

    del user_updates

    # Apply the aggregated update with a global rate decayed by 0.999 per
    # round, then rebuild a model from the flat vector by slicing it back
    # into the shape of each state_dict entry.
    model_received = model_received + global_lr * (0.999**epoch_num) * agg_update
    fed_model = resnet20().cuda()
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name

    fed_model.load_state_dict(state_dict)

    # Evaluation happens only every 10th round (and at round nepochs-1).
    if epoch_num%10==0 or epoch_num==nepochs-1:
        val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
        is_best = best_global_acc < val_acc
        best_global_acc = max(best_global_acc, val_acc)
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))

    # Divergence check. Between evaluations this looks at the val_loss of the
    # most recent evaluated round.
    # NOTE(review): `math` is not imported in the visible cells - presumably
    # it arrives through one of the star imports; confirm.
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break

    epoch_num+=1

generating participant indices for alpha 0.1
e 0 benign_norm 11.427 val loss 12.951 val acc 0.101 best val_acc 0.101
e 10 benign_norm 8.659 val loss 6.440 val acc 0.099 best val_acc 0.101
e 20 benign_norm 8.368 val loss 4.641 val acc 0.099 best val_acc 0.101
e 30 benign_norm 8.211 val loss 3.748 val acc 0.100 best val_acc 0.101
e 40 benign_norm 8.080 val loss 3.059 val acc 0.101 best val_acc 0.101
e 50 benign_norm 7.978 val loss 2.661 val acc 0.103 best val_acc 0.103
e 60 benign_norm 7.886 val loss 2.418 val acc 0.121 best val_acc 0.121
e 70 benign_norm 7.820 val loss 2.252 val acc 0.161 best val_acc 0.161
e 80 benign_norm 7.776 val loss 2.161 val acc 0.201 best val_acc 0.201
e 90 benign_norm 7.748 val loss 2.077 val acc 0.240 best val_acc 0.240
e 100 benign_norm 7.717 val loss 2.014 val acc 0.274 best val_acc 0.274
e 110 benign_norm 7.681 val loss 1.979 val acc 0.300 best val_acc 0.300
e 120 benign_norm 7.657 val loss 1.938 val acc 0.324 best val_acc 0.324
e 130 benign_norm 7.659 val 

In [23]:
# Experiment setup: hyperparameters, Dirichlet data partition, per-worker
# loaders and the initial flattened global model.
torch.cuda.empty_cache()

# Training schedule. The training loop in this cell iterates while
# epoch_num <= nepochs, so round indices 0..nepochs are all executed.
nepochs = 2000
local_epochs = 1
batch_size = 32  # NOTE: not passed to get_train below; loaders receive len(indices)
num_workers = 100

local_lr = 1
global_lr = .5

# The training loop in this cell only iterates over clients
# nbyz..num_workers-1; byz_type and aggregation are not read in this cell.
nbyz = 28
byz_type = 'full_trim'
aggregation = 'trim'

all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))

# Partition the pooled data across workers. With distribution='dirichlet',
# get_federated_data reports `param` as "alpha" in its log line.
distribution = 'dirichlet'
param = .1
force = True
each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
    all_data, num_workers=num_workers, distribution=distribution, param=param, force=force)

# One (worker position, loader) pair per worker. The third argument to
# get_train is len(indices) -- presumably the batch size, i.e. each worker
# trains on its whole shard in one batch; confirm against get_train.
# The per-worker test split (each_worker_te_idx) is not used in this cell.
train_loaders = [
    (pos, get_train(all_data, indices, len(indices)))
    for pos, indices in enumerate(each_worker_idx)
]
cifar10_test_loader = get_train(all_test_data, global_test_idx)

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

resume = False
round_nclients = num_workers
best_global_acc = 0
epoch_num = 0

fed_model = resnet20().cuda()

# Flatten every state_dict entry (parameters and buffers alike) into a single
# 1-D CUDA float vector, in state_dict order. A comprehension variable is used
# so that the Dirichlet `param` defined above is not overwritten by a tensor.
model_received = torch.cat([
    entry.data.view(-1).type(torch.cuda.FloatTensor)
    for entry in fed_model.state_dict().values()
])

# Federated training loop: every round each participating client trains a
# copy of the global model, the flattened weight deltas are averaged, and the
# average is applied to the flat global vector `model_received`.
# NOTE: the condition is `<=`, so round indices 0..nepochs (inclusive) run.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    # Only clients nbyz..num_workers-1 participate in this cell.
    round_clients = np.arange(nbyz, num_workers)
    round_benign = round_clients
    user_updates = []
    benign_norm = 0
    for client_id in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr=local_lr)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client_id][1], model, model_received, criterion, optimizer)

        # Flatten the locally trained state_dict in the same order and dtype
        # as model_received so the two vectors can be subtracted elementwise.
        local_flat = torch.cat([
            entry.data.view(-1).type(torch.cuda.FloatTensor)
            for entry in model.state_dict().values()
        ])

        update = local_flat - model_received
        # Accumulates the mean L2 norm of the client updates for logging.
        benign_norm += torch.norm(update) / len(round_benign)
        user_updates.append(update)

    # Stack once (clients x parameters) instead of growing a tensor with
    # repeated torch.cat, then average over the client dimension.
    agg_update = torch.mean(torch.stack(user_updates), 0)

    del user_updates

    # Server step; the effective global learning rate decays by 0.999 per round.
    model_received = model_received + global_lr * (0.999**epoch_num) * agg_update

    # Rebuild the global model and load the flat vector back into it, slicing
    # consecutive segments in state_dict order with a running offset.
    fed_model = resnet20().cuda()
    state_dict = {}
    offset = 0
    for name, entry in fed_model.state_dict().items():
        count = entry.numel()
        state_dict[name] = model_received[offset:offset + count].reshape(entry.shape)
        offset += count

    fed_model.load_state_dict(state_dict)

    if epoch_num % 10 == 0 or epoch_num == nepochs - 1:
        val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
        is_best = best_global_acc < val_acc
        best_global_acc = max(best_global_acc, val_acc)
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))

        # Divergence check. val_loss only changes on evaluation rounds, so
        # checking here stops at the same round as checking every round, and
        # never reads a val_loss that has not been assigned yet.
        if math.isnan(val_loss) or val_loss > 100000:
            print('val loss %f... exit'%val_loss)
            break

    epoch_num += 1

generating participant indices for alpha 0.1
e 0 benign_norm 12.552 val loss 211.935 val acc 0.097 best val_acc 0.097
e 10 benign_norm 10.165 val loss 120.449 val acc 0.101 best val_acc 0.101
e 20 benign_norm 5.152 val loss 4.042 val acc 0.139 best val_acc 0.139
e 30 benign_norm 3.714 val loss 2.341 val acc 0.166 best val_acc 0.166
e 40 benign_norm 4.137 val loss 2.141 val acc 0.192 best val_acc 0.192
e 50 benign_norm 4.132 val loss 2.149 val acc 0.189 best val_acc 0.192
e 60 benign_norm 3.555 val loss 2.062 val acc 0.270 best val_acc 0.270
e 70 benign_norm 4.839 val loss 2.080 val acc 0.241 best val_acc 0.270
e 80 benign_norm 4.213 val loss 1.963 val acc 0.283 best val_acc 0.283
e 90 benign_norm 3.598 val loss 1.988 val acc 0.284 best val_acc 0.284
e 100 benign_norm 3.641 val loss 1.955 val acc 0.313 best val_acc 0.313
e 110 benign_norm 4.543 val loss 2.220 val acc 0.305 best val_acc 0.313
e 120 benign_norm 3.844 val loss 1.887 val acc 0.322 best val_acc 0.322
e 130 benign_norm 3.888 

KeyboardInterrupt: 