# FLDetector for CIFAR10 with Fang distribution


In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

  from IPython.core.display import display, HTML


In [2]:
import collections
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
import copy
import time
from functools import reduce
from torchsummary import summary

import os
import sys
import pickle
sys.path.insert(0,'./utils/')
from logger import *
from eval import *
from misc import *

from sklearn.metrics import roc_auc_score
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict

from SGD import *
import copy
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [3]:
from cifar10_models import *
from cifar10_util import *

In [4]:
def sample_dirichlet_train_data(trainset, no_participants, alpha=0.9, force=False):
        """
            Input: Number of participants and alpha (param for distribution)
            Output: A list of indices denoting data in CIFAR training set.
            Requires: cifar_classes, a preprocessed class-indice dictionary.
            Sample Method: take a uniformly sampled 10-dimension vector as parameters for
            dirichlet distribution to sample number of images in each class.
        """
        if not os.path.exists('./dirichlet_a_%.1f_nusers_%d.pkl'%(alpha, no_participants)) or force:
            print('generating participant indices for alpha %.1f'%alpha)
            np.random.seed(0)
            cifar_classes = {}
            for ind, x in enumerate(trainset):
                _, label = x
                if label in cifar_classes:
                    cifar_classes[label].append(ind)
                else:
                    cifar_classes[label] = [ind]

            per_participant_list = defaultdict(list)
            no_classes = len(cifar_classes.keys())
            for n in range(no_classes):
                random.shuffle(cifar_classes[n])
                sampled_probabilities = len(cifar_classes[n]) * np.random.dirichlet(
                    np.array(no_participants * [alpha]))
                for user in range(no_participants):
                    no_imgs = int(round(sampled_probabilities[user]))
                    sampled_list = cifar_classes[n][:min(len(cifar_classes[n]), no_imgs)]
                    per_participant_list[user].extend(sampled_list)
                    cifar_classes[n] = cifar_classes[n][min(len(cifar_classes[n]), no_imgs):]
            with open('./dirichlet_a_%.1f_nusers_%d.pkl'%(alpha, no_participants), 'wb') as f:
                pickle.dump(per_participant_list, f)
        else:
            per_participant_list = pickle.load(open('./dirichlet_a_%.1f_nusers_%d.pkl'%(alpha, no_participants), 'rb'))

        return per_participant_list

In [5]:
def get_fang_train_data(trainset, num_workers=100, bias=0.5, force=False):
    dist_file = 'fang_nworkers%d_bias%.1f.pkl' % (num_workers, bias)
    if not force and os.path.exists(dist_file):
        print('Loading fang distribution for num_workers %d and bias %.1f from memory' % (num_workers, bias))
        return pickle.load(open(dist_file, 'rb'))
    bias_weight = bias
    other_group_size = (1 - bias_weight) / 9.
    worker_per_group = num_workers / 10
    each_worker_data = [[] for _ in range(num_workers)]
    each_worker_label = [[] for _ in range(num_workers)]
    per_participant_list = defaultdict(list)
    for i, (x, y) in enumerate(trainset):
        # assign a data point to a group
        upper_bound = (y) * (1 - bias_weight) / 9. + bias_weight
        lower_bound = (y) * (1 - bias_weight) / 9.
        rd = np.random.random_sample()
        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y
        rd = np.random.random_sample()
        selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))
        per_participant_list[selected_worker].extend([i])
    
    print('Saving fang distribution for num_workers %d and bias %.1f to memory' % (num_workers, bias))
    pickle.dump(per_participant_list, open(dist_file, 'wb'))
    return per_participant_list

In [6]:
def get_federated_data(trainset, num_workers, distribution='fang', param=1, force=False):
    if distribution == 'fang':
        per_participant_list = get_fang_train_data(trainset, num_workers, bias=param, force=force)
    elif distribution == 'dirichlet':
        per_participant_list = sample_dirichlet_train_data(trainset, num_workers, alpha=param, force=force)

    each_worker_idx = [[] for _ in range(num_workers)] 
    each_worker_te_idx = [[] for _ in range(num_workers)]
    np.random.seed(0)
    for worker_idx in range(len(per_participant_list)):
        w_indices = np.array(per_participant_list[worker_idx])
        w_len = len(w_indices)
        len_tr = int(5*w_len/6)
        tr_idx = np.random.choice(w_len, len_tr, replace=False)
        te_idx = np.delete(np.arange(w_len), tr_idx)
        each_worker_idx[worker_idx] = w_indices[tr_idx]
        each_worker_te_idx[worker_idx] = w_indices[te_idx]
    global_test_idx = np.concatenate(each_worker_te_idx)
    return each_worker_idx, each_worker_te_idx, global_test_idx

In [7]:
def get_train(dataset, indices, batch_size=32, shuffle=False):
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=batch_size,
                                               sampler=torch.utils.data.sampler.SubsetRandomSampler(indices))
    
    return train_loader

In [8]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
data_loc='/home/vshejwalkar_umass_edu/data/'
# load the train dataset

transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=transform_train)
cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=transform_train)

te_cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=transform_test)
te_cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [9]:
def full_trim(v, f):
    '''
    Full-knowledge Trim attack. w.l.o.g., we assume the first f worker devices are compromised.
    v: the list of squeezed gradients
    f: the number of compromised worker devices
    '''
    vi_shape = v[0].unsqueeze(0).T.shape
    v_tran = v.T
    
    maximum_dim = torch.max(v_tran, dim=1)
    maximum_dim = maximum_dim[0].reshape(vi_shape)
    minimum_dim = torch.min(v_tran, dim=1)
    minimum_dim = minimum_dim[0].reshape(vi_shape)
    direction = torch.sign(torch.sum(v_tran, dim=-1, keepdims=True))
    directed_dim = (direction > 0) * minimum_dim + (direction < 0) * maximum_dim

    for i in range(f):
        random_12 = 2
        tmp = directed_dim * ((direction * directed_dim > 0) / random_12 + (direction * directed_dim < 0) * random_12)
        tmp = tmp.squeeze()
        v[i] = tmp
    return v

In [10]:
def tr_mean(all_updates, n_attackers):
    sorted_updates = torch.sort(all_updates, 0)[0]
    out = torch.sum(sorted_updates[n_attackers:-n_attackers], 0) if n_attackers else torch.mean(sorted_updates,0)
    return out

In [11]:
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def train(trainloader, model, model_received, criterion, optimizer, pgd=False, eps=2):
    model.train()
    losses = AverageMeter()
    top1 = AverageMeter()
    for batch_ind, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device, torch.float)
        targets = targets.to(device, torch.long)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size()[0])
        top1.update(prec1.item()/100.0, inputs.size()[0])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if pgd:
            curr_model = list(model.parameters())
            curr_model_vec = parameters_to_vector(curr_model)
            if torch.norm(curr_model_vec - model_received) > eps:
                curr_model_vec = eps*(curr_model_vec - model_received)/torch.norm(curr_model_vec - model_received) + model_received
                vector_to_parameters(curr_model_vec, curr_model)
        
    return (losses.avg, top1.avg)

def test(testloader, model, criterion):
    model.eval()
    losses = AverageMeter()
    top1 = AverageMeter()
    for batch_ind, (inputs, targets) in enumerate(testloader):
        inputs = inputs.to(device, torch.float)
        targets = targets.to(device, torch.long)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.data, inputs.size()[0])
        top1.update(prec1/100.0, inputs.size()[0])
    return (losses.avg, top1.avg)

# resnet-BN

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.autograd import Variable

__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']

def _weights_init(m):
    classname = m.__class__.__name__
    #print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                """
                For CIFAR10 ResNet paper uses option A.
                """
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet20():
    return ResNet(BasicBlock, [3, 3, 3])

In [13]:
all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))
batch_size = 32
num_workers = 100
distribution='fang'
param = .5
force = False

each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
    all_data, num_workers=100, distribution=distribution, param=param, force=force)

train_loaders = []
for pos, indices in enumerate(each_worker_idx):
    train_loaders.append((pos, get_train(all_data, indices, batch_size)))
test_loaders = []
for pos, indices in enumerate(each_worker_te_idx):
    test_loaders.append((pos, get_train(all_test_data, indices, len(indices))))
cifar10_test_loader = get_train(all_test_data, global_test_idx)

len(train_loaders), len(test_loaders), len(global_test_idx)

Loading fang distribution for num_workers 100 and bias 0.5 from memory


(100, 100, 10038)

In [17]:
len(train_loaders[0][1])

17

# Mean with FedNOVA

# Attack Resnet20BN

In [16]:
use_cuda = torch.cuda.is_available()
torch.cuda.empty_cache()
nepochs=200
local_epochs = 2
batch_size = 32
num_workers = 100
local_lr = 0.1
global_lr = 1
nbyz = 20
total_data_len = 0
client_data_len = []
tau_per_client = []

for client_idx, worker_indices in enumerate(each_worker_idx[20-nbyz:]):
    if client_idx < nbyz:
        total_data_len += len(worker_indices)
        client_data_len.append(len(worker_indices))
        tau_per_client.append(len(train_loaders[client_idx][1]))
    else:
        total_data_len += len(worker_indices)
        client_data_len.append(len(worker_indices))
        tau_per_client.append(len(train_loaders[client_idx][1]))
client_data_len = np.array(client_data_len)
tau_per_client = np.array(tau_per_client)
client_ws = client_data_len / total_data_len

print('total data len: ', total_data_len)
print('client_data_len shape: ', client_data_len.shape)
print('tau_per_client shape: ', tau_per_client.shape)
print('client_ws shape: ', client_ws.shape)

criterion = nn.CrossEntropyLoss()
best_global_acc=0
epoch_num = 0

fed_model = resnet20().cuda()

model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(20-nbyz, num_workers)
    round_benign = round_clients[nbyz:]
    round_malicious = round_clients[:nbyz]
    user_updates=[]
    benign_norm = 0
    
    for client_idx in round_malicious:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.999**epoch_num), momentum=0.9, weight_decay=1e-4)
        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client_idx][1], model, model_received, criterion, optimizer)
        params = []
        for _, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))
        update =  (params - model_received)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)

    for client_idx in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.999**epoch_num), momentum=0.9, weight_decay=1e-4)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client_idx][1], model, model_received, criterion, optimizer)

        params = []
        for _, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        update =  (params - model_received)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)

    if epoch_num == 0:
        renormalized_client_weights = client_ws * np.sum(client_ws * tau_per_client) / tau_per_client
        print('len tau_per_client %d client_data_len %d' % (len(tau_per_client), len(client_data_len)))
        print('len client_ws %d renormalized_client_ws %d' % (len(client_ws), len(renormalized_client_weights)))

    user_updates *= torch.from_numpy(renormalized_client_weights)[:, None].to(device)
    user_updates[:nbyz] = full_trim(user_updates[:nbyz], nbyz)
    agg_update = tr_mean(user_updates, nbyz)
    del user_updates

    model_received = model_received + global_lr * (.999**epoch_num) * agg_update
    fed_model = resnet20().cuda()
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name
    fed_model.load_state_dict(state_dict)
    val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    
    best_accs_per_round.append(best_global_acc)
    accs_per_round.append(val_acc)
    loss_per_round.append(val_loss)
    
    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break
    epoch_num+=1

total data len:  49962
client_data_len shape:  (100,)
tau_per_client shape:  (100,)
client_ws shape:  (100,)
len tau_per_client 100 client_data_len 100
len client_ws 100 renormalized_client_ws 100
e 0 benign_norm 0.000 val loss 2.391 val acc 0.101 best val_acc 0.101
e 10 benign_norm 0.000 val loss 1.812 val acc 0.301 best val_acc 0.301
e 20 benign_norm 0.000 val loss 1.608 val acc 0.398 best val_acc 0.398
e 30 benign_norm 0.000 val loss 1.527 val acc 0.439 best val_acc 0.441
e 40 benign_norm 0.000 val loss 1.525 val acc 0.451 best val_acc 0.457
e 50 benign_norm 0.000 val loss 1.582 val acc 0.436 best val_acc 0.467
e 60 benign_norm 0.000 val loss 1.464 val acc 0.462 best val_acc 0.469
e 70 benign_norm 0.000 val loss 1.390 val acc 0.493 best val_acc 0.493
e 80 benign_norm 0.000 val loss 1.406 val acc 0.501 best val_acc 0.507
e 90 benign_norm 0.000 val loss 1.397 val acc 0.501 best val_acc 0.512
e 100 benign_norm 0.000 val loss 1.368 val acc 0.507 best val_acc 0.527
e 110 benign_norm 0.00

In [18]:
use_cuda = torch.cuda.is_available()
torch.cuda.empty_cache()
nepochs=200
local_epochs = 2
batch_size = 32
num_workers = 100
local_lr = 0.1
global_lr = 1
nbyz = 20
total_data_len = 0
client_data_len = []
tau_per_client = []

for client_idx, worker_indices in enumerate(each_worker_idx[20-nbyz:]):
    if client_idx < nbyz:
        total_data_len += len(worker_indices)
        client_data_len.append(len(worker_indices))
        tau_per_client.append(len(train_loaders[client_idx][1]))
    else:
        total_data_len += len(worker_indices)
        client_data_len.append(len(worker_indices))
        tau_per_client.append(len(train_loaders[client_idx][1]))
client_data_len = np.array(client_data_len)
tau_per_client = np.array(tau_per_client)
client_ws = client_data_len / total_data_len

print('total data len: ', total_data_len)
print('client_data_len shape: ', client_data_len.shape)
print('tau_per_client shape: ', tau_per_client.shape)
print('client_ws shape: ', client_ws.shape)

criterion = nn.CrossEntropyLoss()
best_global_acc=0
epoch_num = 0

fed_model = resnet20().cuda()

model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(20-nbyz, num_workers)
    round_benign = round_clients[nbyz:]
    round_malicious = round_clients[:nbyz]
    user_updates=[]
    benign_norm = 0
    
    for client_idx in round_malicious:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.999**epoch_num), momentum=0.9, weight_decay=1e-4)
        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client_idx][1], model, model_received, criterion, optimizer)
        params = []
        for _, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))
        update =  (params - model_received)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)

    for client_idx in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.999**epoch_num), momentum=0.9, weight_decay=1e-4)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client_idx][1], model, model_received, criterion, optimizer)

        params = []
        for _, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        update =  (params - model_received)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)

    if epoch_num == 0:
        renormalized_client_weights = client_ws * np.sum(client_ws * tau_per_client) / tau_per_client
        print('len tau_per_client %d client_data_len %d' % (len(tau_per_client), len(client_data_len)))
        print('len client_ws %d renormalized_client_ws %d' % (len(client_ws), len(renormalized_client_weights)))

    user_updates *= torch.from_numpy(renormalized_client_weights)[:, None].to(device)
    user_updates[:nbyz] = full_trim(user_updates[:nbyz], nbyz)
    agg_update = tr_mean(user_updates, nbyz)
    del user_updates

    model_received = model_received + global_lr * (.999**epoch_num) * agg_update
    fed_model = resnet20().cuda()
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name
    fed_model.load_state_dict(state_dict)
    val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    
    best_accs_per_round.append(best_global_acc)
    accs_per_round.append(val_acc)
    loss_per_round.append(val_loss)
    
    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break
    epoch_num+=1

total data len:  49962
client_data_len shape:  (100,)
tau_per_client shape:  (100,)
client_ws shape:  (100,)
e 10 benign_norm 0.000 val loss 1.850 val acc 0.283 best val_acc 0.283
e 20 benign_norm 0.000 val loss 1.653 val acc 0.388 best val_acc 0.392
e 30 benign_norm 0.000 val loss 1.750 val acc 0.371 best val_acc 0.413
e 40 benign_norm 0.000 val loss 1.797 val acc 0.350 best val_acc 0.413
e 50 benign_norm 0.000 val loss 1.663 val acc 0.397 best val_acc 0.414
e 60 benign_norm 0.000 val loss 1.643 val acc 0.443 best val_acc 0.455
e 70 benign_norm 0.000 val loss 1.460 val acc 0.486 best val_acc 0.486
e 80 benign_norm 0.000 val loss 1.343 val acc 0.529 best val_acc 0.529
e 90 benign_norm 0.000 val loss 1.315 val acc 0.533 best val_acc 0.554
e 100 benign_norm 0.000 val loss 1.267 val acc 0.555 best val_acc 0.560
e 110 benign_norm 0.000 val loss 1.267 val acc 0.550 best val_acc 0.560
e 120 benign_norm 0.000 val loss 1.210 val acc 0.567 best val_acc 0.567
e 130 benign_norm 0.000 val loss 1.5

In [None]:
use_cuda = torch.cuda.is_available()
torch.cuda.empty_cache()
nepochs=200
local_epochs = 2
batch_size = 32
num_workers = 100
local_lr = 0.1
global_lr = 1
nbyz = 20
total_data_len = 0
client_data_len = []
tau_per_client = []

for client_idx, worker_indices in enumerate(each_worker_idx[20-nbyz:]):
    if client_idx < nbyz:
        total_data_len += len(worker_indices)
        client_data_len.append(len(worker_indices))
        tau_per_client.append(len(train_loaders[client_idx][1]))
    else:
        total_data_len += len(worker_indices)
        client_data_len.append(len(worker_indices))
        tau_per_client.append(len(train_loaders[client_idx][1]))
client_data_len = np.array(client_data_len)
tau_per_client = np.array(tau_per_client)
client_ws = client_data_len / total_data_len

print('total data len: ', total_data_len)
print('client_data_len shape: ', client_data_len.shape)
print('tau_per_client shape: ', tau_per_client.shape)
print('client_ws shape: ', client_ws.shape)

criterion = nn.CrossEntropyLoss()
best_global_acc = best_ema_acc = 0
epoch_num = 0
ema_decay = 0.95

fed_model = resnet20().cuda()

model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))
ema_model_received = copy.deepcopy(model_received)

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(20-nbyz, num_workers)
    round_benign = round_clients[nbyz:]
    round_malicious = round_clients[:nbyz]
    user_updates=[]
    benign_norm = 0
    
    for client_idx in round_malicious:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.999**epoch_num), momentum=0.9, weight_decay=1e-4)
        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client_idx][1], model, model_received, criterion, optimizer)
        params = []
        for _, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))
        update =  (params - model_received)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)

    for client_idx in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.999**epoch_num), momentum=0.9, weight_decay=1e-4)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client_idx][1], model, model_received, criterion, optimizer)

        params = []
        for _, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        update =  (params - model_received)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)

    if epoch_num == 0:
        renormalized_client_weights = client_ws * np.sum(client_ws * tau_per_client) / tau_per_client
        print('len tau_per_client %d client_data_len %d' % (len(tau_per_client), len(client_data_len)))
        print('len client_ws %d renormalized_client_ws %d' % (len(client_ws), len(renormalized_client_weights)))

    user_updates *= torch.from_numpy(renormalized_client_weights)[:, None].to(device)
    user_updates[:nbyz] = full_trim(user_updates[:nbyz], nbyz)
    agg_update = tr_mean(user_updates, nbyz)
    del user_updates

    model_received = model_received + global_lr * (.999**epoch_num) * agg_update
    round_ema_decay = min(ema_decay, (1+epoch_num)/(10+epoch_num))
    ema_model_received = ema_model_received * round_ema_decay + model_received * (1 - round_ema_decay)
    
    fed_model = resnet20().cuda()
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for _, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name
    fed_model.load_state_dict(state_dict)

    ema_model = resnet20().cuda()
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for _, (name, _) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = ema_model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name
    ema_model.load_state_dict(state_dict)
    
    val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
    ema_loss, ema_acc = test(cifar10_test_loader, ema_model, criterion)
    is_best = best_global_acc < val_acc
    
    best_global_acc = max(best_global_acc, val_acc)
    best_ema_acc = max(best_ema_acc, ema_acc)
    best_accs_per_round.append(best_global_acc)
    accs_per_round.append(val_acc)
    loss_per_round.append(val_loss)
    
    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break
    epoch_num+=1

In [29]:
use_cuda = torch.cuda.is_available()
torch.cuda.empty_cache()
nepochs=200
local_epochs = 2
batch_size = 16
num_workers = 100
local_lr = 0.1
global_lr = 1
nbyz = 20

criterion = nn.CrossEntropyLoss()
best_global_acc=0
epoch_num = 0

fed_model = resnet20().cuda()
model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    user_updates=[]
    benign_norm = 0
    for i in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.999**epoch_num), momentum=0.9, weight_decay=1e-4)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[i][1], model, model_received, criterion, optimizer)

        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        update =  (params - model_received)
        benign_norm += torch.norm(update)/len(round_benign)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)

    user_updates[:nbyz] = full_trim(user_updates[:nbyz], nbyz)
    agg_update = tr_mean(user_updates, nbyz)
    del user_updates

    model_received = model_received + global_lr * agg_update
    fed_model = resnet20().cuda()
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name
    fed_model.load_state_dict(state_dict)
    val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    
    best_accs_per_round.append(best_global_acc.cpu().item())
    accs_per_round.append(val_acc.cpu().item())
    loss_per_round.append(val_loss.cpu().item())
    
    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break
    epoch_num+=1
    
final_accs_per_client=[]
for i in range(num_workers):
    client_loss, client_acc = test(test_loaders[i][1], fed_model, criterion)
    final_accs_per_client.append(client_acc.cpu().item())
results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
    loss_per_round=np.array(loss_per_round)
)
pickle.dump(results, open(os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_cifar10_fast_trmean_trim20_resnet20BN.pkl'), 'wb'))

e 0 benign_norm 782.413 val loss 2.310 val acc 0.098 best val_acc 0.098
e 10 benign_norm 277.659 val loss 1.768 val acc 0.346 best val_acc 0.346
e 20 benign_norm 277.721 val loss 1.667 val acc 0.409 best val_acc 0.413
e 30 benign_norm 278.085 val loss 1.646 val acc 0.419 best val_acc 0.438
e 40 benign_norm 278.773 val loss 1.695 val acc 0.429 best val_acc 0.451
e 50 benign_norm 279.670 val loss 1.637 val acc 0.478 best val_acc 0.478
e 60 benign_norm 281.176 val loss 1.695 val acc 0.474 best val_acc 0.483
e 70 benign_norm 282.138 val loss 1.683 val acc 0.482 best val_acc 0.492
e 80 benign_norm 283.360 val loss 1.581 val acc 0.516 best val_acc 0.516
e 90 benign_norm 286.147 val loss 1.540 val acc 0.504 best val_acc 0.523
e 100 benign_norm 291.872 val loss 1.792 val acc 0.498 best val_acc 0.533
e 110 benign_norm 293.738 val loss 1.615 val acc 0.524 best val_acc 0.533
e 120 benign_norm 297.880 val loss 1.694 val acc 0.519 best val_acc 0.533
e 130 benign_norm 306.630 val loss 1.681 val acc 

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.