In [1]:
# Widen the notebook container so wide outputs stay readable.
# `display` and `HTML` are part of the public `IPython.display` module; importing
# them from `IPython.core.display` is deprecated and triggers a warning.
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

  from IPython.core.display import display, HTML


In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
import copy
import time
from functools import reduce
from torchsummary import summary

import os
import sys
import pickle
sys.path.insert(0,'./utils/')
from logger import *
from eval import *
from misc import *

from sklearn.metrics import roc_auc_score
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict

from SGD import *
import copy
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [2]:
from cifar10_models import *
from cifar10_util import *

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.autograd import Variable

__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']

def _weights_init(m):
    classname = m.__class__.__name__
    #print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Adapter that lets an arbitrary callable be used as a layer."""

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return the result of calling the wrapped callable on ``x``."""
        fn = self.lambd
        return fn(x)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm pairs and a shortcut.

    The shortcut is the identity when ``stride == 1`` and
    ``in_planes == planes``. Otherwise:

    * option ``'A'`` keeps every second pixel of the input and zero-pads the
      channel axis by ``planes // 4`` on each side (no parameters);
    * option ``'B'`` uses a strided 1x1 convolution followed by batch-norm.

    Any other ``option`` value leaves the shortcut as the identity.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_planes, planes, stride=stride, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(planes)

        shape_changes = not (stride == 1 and in_planes == planes)
        if shape_changes and option == 'A':
            # For CIFAR10 the ResNet paper uses option A.
            channel_pad = planes // 4

            def _subsample_and_pad(x):
                # Halve the spatial resolution, then zero-pad the channel axis.
                return F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, channel_pad, channel_pad), "constant", 0)

            self.shortcut = LambdaLayer(_subsample_and_pad)
        elif shape_changes and option == 'B':
            out_planes = self.expansion * planes
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        main_path = self.bn2(self.conv2(hidden))
        main_path += self.shortcut(x)
        return F.relu(main_path)


class ResNet(nn.Module):
    """Three-stage residual network for 3-channel images.

    A 3x3 stem convolution (16 channels) is followed by three stages of
    ``block`` modules with 16, 32 and 64 planes; stages two and three start
    with a stride-2 block. The result is average-pooled over the full spatial
    extent and fed to a linear classifier with ``num_classes`` outputs.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count; updated by _make_layer as stages are built.
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        # Re-initialise every conv / linear weight in the finished network.
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use stride 1."""
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        # Pool with a window as wide as the feature map (global pooling for square maps).
        out = F.avg_pool2d(out, out.size(3))
        out = out.view(out.size(0), -1)
        return self.linear(out)


def resnet20():
    """Build a ``ResNet`` of ``BasicBlock``s with three blocks in each of its three stages."""
    blocks_per_stage = [3, 3, 3]
    return ResNet(BasicBlock, blocks_per_stage)

In [4]:
def sample_dirichlet_train_data(trainset, no_participants, alpha=0.9, force=False):
    """Split a labelled dataset across participants with Dirichlet-distributed class shares.

    For every class, a Dirichlet(alpha, ..., alpha) vector of length
    ``no_participants`` is drawn and scaled by the class size to decide how
    many (shuffled) samples of that class each participant receives. Because
    the per-participant counts are rounded, a few samples of a class may end
    up unassigned.

    The result is cached in ``./dirichlet_a_<alpha>_nusers_<n>.pkl``.
    The cache key contains only ``alpha`` (one decimal) and the participant
    count, so pass ``force=True`` after changing the dataset.

    Args:
        trainset: iterable of ``(sample, label)`` pairs.
        no_participants: number of participants to split the data across.
        alpha: Dirichlet concentration parameter.
        force: regenerate the split even when a cache file exists.

    Returns:
        Mapping ``participant id -> list of dataset indices``.
    """
    cache_path = './dirichlet_a_%.1f_nusers_%d.pkl' % (alpha, no_participants)

    if os.path.exists(cache_path) and not force:
        # NOTE: unpickling can execute arbitrary code; only load caches that
        # were written by this function.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    print('generating participant indices for alpha %.1f'%alpha)
    np.random.seed(0)
    # The shuffle below uses Python's `random`, which np.random.seed does not
    # touch; a dedicated seeded generator makes the split reproducible
    # without altering the global `random` state.
    shuffler = random.Random(0)

    cifar_classes = defaultdict(list)
    for ind, (_, label) in enumerate(trainset):
        cifar_classes[label].append(ind)

    per_participant_list = defaultdict(list)
    # Iterate over the labels actually present instead of assuming 0..n-1.
    for label in sorted(cifar_classes):
        remaining = cifar_classes[label]
        shuffler.shuffle(remaining)
        class_shares = len(remaining) * np.random.dirichlet(
            np.array(no_participants * [alpha]))
        for user in range(no_participants):
            # Never hand out more samples than are left in this class.
            no_imgs = min(len(remaining), int(round(class_shares[user])))
            per_participant_list[user].extend(remaining[:no_imgs])
            remaining = remaining[no_imgs:]

    with open(cache_path, 'wb') as f:
        pickle.dump(per_participant_list, f)

    return per_participant_list

In [5]:
def get_fang_train_data(trainset, num_workers=100, bias=0.5, num_classes=10):
    """Assign every sample to a worker using a label-biased random scheme.

    Workers are divided into ``num_classes`` equally sized groups. A sample
    with label ``y`` is sent to group ``y`` with probability ``bias`` and to
    each of the other groups with probability
    ``(1 - bias) / (num_classes - 1)``; inside the chosen group the worker is
    picked uniformly. Two draws from ``np.random`` are consumed per sample, so
    seed NumPy beforehand for a reproducible split.

    Args:
        trainset: iterable of ``(sample, label)`` pairs with integer labels
            in ``[0, num_classes)``.
        num_workers: total number of workers.
        bias: probability mass kept on the worker group matching the label.
        num_classes: number of labels / worker groups (10 reproduces the
            previous hard-coded behaviour).

    Returns:
        ``defaultdict(list)`` mapping worker id -> list of dataset indices.
        Workers that received no sample have no key.

    Raises:
        ValueError: if ``num_classes`` is smaller than 2.
    """
    if num_classes < 2:
        raise ValueError('num_classes must be at least 2, got %r' % (num_classes,))

    bias_weight = bias
    other_group_size = (1 - bias_weight) / (num_classes - 1.)
    worker_per_group = num_workers / num_classes

    per_participant_list = defaultdict(list)
    for i, (x, y) in enumerate(trainset):
        # [lower_bound, upper_bound] is the slice of [0, 1) reserved for the
        # sample's own group; the rest is shared evenly by the other groups.
        lower_bound = (y) * (1 - bias_weight) / (num_classes - 1.)
        upper_bound = (y) * (1 - bias_weight) / (num_classes - 1.) + bias_weight
        rd = np.random.random_sample()

        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y
        # Floating-point round-off at the interval edges must not push the
        # group index outside the valid range.
        worker_group = min(max(worker_group, 0), num_classes - 1)

        rd = np.random.random_sample()
        selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))
        selected_worker = min(selected_worker, num_workers - 1)

        per_participant_list[selected_worker].append(i)

    return per_participant_list

In [6]:
def get_federated_data(trainset, num_workers, distribution='fang', param=1, force=False):
    """Partition ``trainset`` across workers and split each shard into train / test.

    Args:
        trainset: dataset yielding ``(sample, label)`` pairs.
        num_workers: number of federated workers.
        distribution: ``'fang'`` (label-biased groups, ``param`` is the bias)
            or ``'dirichlet'`` (``param`` is the concentration ``alpha``).
        param: parameter forwarded to the chosen partitioning scheme.
        force: forwarded to the Dirichlet sampler to ignore its on-disk cache.

    Returns:
        ``(each_worker_idx, each_worker_te_idx, global_test_idx)`` where the
        first two are lists (one entry per worker) of index arrays holding
        6/7 and the remaining 1/7 of each shard, and the last one is the
        concatenation of all per-worker test indices.

    Raises:
        ValueError: if ``distribution`` is not one of the supported names.
    """
    # Seed before partitioning: the 'fang' scheme draws from np.random, so
    # seeding only afterwards would leave the shards different on every run.
    np.random.seed(0)
    if distribution == 'fang':
        per_participant_list = get_fang_train_data(trainset, num_workers, bias=param)
    elif distribution == 'dirichlet':
        per_participant_list = sample_dirichlet_train_data(trainset, num_workers, alpha=param, force=force)
    else:
        raise ValueError("unknown distribution %r (expected 'fang' or 'dirichlet')" % (distribution,))

    each_worker_idx = [[] for _ in range(num_workers)]
    each_worker_te_idx = [[] for _ in range(num_workers)]

    # Re-seed so the train/test split of a given partition does not depend on
    # how many random numbers the partitioning step consumed.
    np.random.seed(0)

    # Walk over every worker id. The mapping only has keys for workers that
    # received data, so using its length as the loop bound would silently
    # drop the highest-numbered workers whenever some worker is empty.
    for worker_idx in range(num_workers):
        w_indices = np.array(per_participant_list.get(worker_idx, []), dtype=np.int64)
        w_len = len(w_indices)
        len_tr = int(6*w_len/7)
        tr_idx = np.random.choice(w_len, len_tr, replace=False)
        te_idx = np.delete(np.arange(w_len), tr_idx)

        each_worker_idx[worker_idx] = w_indices[tr_idx]
        each_worker_te_idx[worker_idx] = w_indices[te_idx]

    global_test_idx = np.concatenate(each_worker_te_idx)

    return each_worker_idx, each_worker_te_idx, global_test_idx

In [7]:
def get_train(dataset, indices, batch_size=32, shuffle=False):
    """Return a ``DataLoader`` restricted to ``indices`` of ``dataset``.

    Samples are drawn through a ``SubsetRandomSampler``, i.e. in random order
    without replacement. The ``shuffle`` argument is accepted but not used.
    """
    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=subset_sampler)

In [8]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# NOTE(review): absolute, user-specific path. The datasets below are created
# with download=True, so CIFAR-10 is fetched into this directory if missing.
data_loc='/home/vshejwalkar_umass_edu/data/'
# load the train dataset

# Augmenting pipeline: random horizontal flip, tensor conversion, then
# per-channel normalisation with the mean / std constants given below.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Deterministic pipeline: same normalisation, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Both official splits wrapped with the augmenting transform. Note that
# cifar10_test also uses transform_train; the experiment cells concatenate
# these two objects into a single pool (`all_data`).
cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=transform_train)
cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=transform_train)

# The same two splits wrapped with the non-augmenting transform; the
# experiment cells concatenate them into `all_test_data`.
te_cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=transform_test)
te_cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [9]:
def lbfgs(S_k_list, Y_k_list, v):
    """Multiply an L-BFGS style Hessian approximation with the vector ``v``.

    With ``S`` / ``Y`` holding the recorded vectors of ``S_k_list`` /
    ``Y_k_list`` as columns, the function evaluates

        sigma * v - [sigma * S, Y] @ inv(M) @ [sigma * S^T v; Y^T v]

    where ``M = [[sigma * S^T S, L], [L^T, -D]]``, ``L`` and ``D`` are the
    strictly lower-triangular part and the diagonal of ``S^T Y``, and
    ``sigma = (y_last . s_last) / (s_last . s_last)`` uses the most recent pair.

    Args:
        S_k_list: list of 1-D tensors, all of the same length ``d``.
        Y_k_list: list of 1-D tensors matching ``S_k_list``.
        v: 1-D tensor of length ``d``.

    Returns:
        NumPy array of shape ``(d, 1)``.
    """
    # Move everything to NumPy once; all linear algebra below runs on the CPU.
    S = torch.stack(S_k_list).T.cpu().numpy()          # (d, m)
    Y = torch.stack(Y_k_list).T.cpu().numpy()          # (d, m)
    s_last = S_k_list[-1].unsqueeze(0).cpu().numpy()   # (1, d)
    y_last = Y_k_list[-1].unsqueeze(0).cpu().numpy()   # (1, d)
    v_flat = v.cpu().numpy()                           # (d,)
    v_col = v.unsqueeze(0).T.cpu().numpy()             # (d, 1)

    StY = np.dot(S.T, Y)
    StS = np.dot(S.T, S)
    L_k = StY - np.triu(StY)          # strictly lower-triangular part of S^T Y
    D_k = np.diag(np.diag(StY))       # diagonal part of S^T Y
    sigma_k = np.dot(y_last, s_last.T) / (np.dot(s_last, s_last.T))

    top_rows = np.concatenate((sigma_k * StS, L_k), axis=1)
    bottom_rows = np.concatenate((L_k.T, -D_k), axis=1)
    middle_inv = np.linalg.inv(np.concatenate((top_rows, bottom_rows), axis=0))

    projections = np.concatenate((np.dot(S.T, sigma_k * v_col), np.dot(Y.T, v_col)), axis=0)
    basis = np.concatenate((sigma_k * S, Y), axis=1)

    approx_prod = (sigma_k * v_flat).T
    approx_prod -= np.dot(np.dot(basis, middle_inv), projections)

    return approx_prod

In [42]:
def full_trim(v, f):
    '''
    Full-knowledge Trim attack. w.l.o.g., we assume the first f worker devices are compromised.

    For every coordinate the sign of the column sum is taken as the current
    direction; the crafted value is the column minimum (positive direction)
    or maximum (negative direction), halved when it has the same sign as the
    direction and doubled when it has the opposite sign. The first ``f`` rows
    of ``v`` are overwritten in place with this crafted vector.

    v: 2-D tensor with one squeezed gradient per row
    f: the number of compromised worker devices
    returns: the same tensor object ``v``
    '''
    column_shape = v[0].unsqueeze(0).T.shape
    per_dim = v.T

    largest = torch.max(per_dim, dim=1)[0].reshape(column_shape)
    smallest = torch.min(per_dim, dim=1)[0].reshape(column_shape)
    direction = torch.sign(torch.sum(per_dim, dim=-1, keepdims=True))
    directed_dim = (direction > 0) * smallest + (direction < 0) * largest

    # The crafted vector only depends on the statistics gathered above, so it
    # is built once and then copied into each compromised row.
    random_12 = 2
    agreement = direction * directed_dim
    scale = (agreement > 0) / random_12 + (agreement < 0) * random_12
    crafted = (directed_dim * scale).squeeze()

    for row in range(f):
        v[row] = crafted
    return v

In [11]:
def tr_mean(all_updates, n_attackers):
    """Coordinate-wise trimmed mean over the rows of ``all_updates``.

    Each column is sorted independently, the ``n_attackers`` smallest and
    ``n_attackers`` largest entries are discarded, and the remainder is
    averaged. With ``n_attackers == 0`` this is the plain column mean.
    """
    ordered = torch.sort(all_updates, 0)[0]
    if not n_attackers:
        return torch.mean(ordered, 0)
    kept = ordered[n_attackers:-n_attackers]
    return torch.mean(kept, 0)

In [12]:
def simple_mean(old_gradients, user_grads, b=0, hvp=None):
    """Average the client updates and, optionally, score each client.

    Args:
        old_gradients: 2-D tensor with the previous round's update of every
            client (one row per client). Only read when ``hvp`` is given.
        user_grads: 2-D tensor with the current update of every client.
        b: unused by this aggregator; kept so all aggregators share one
            signature.
        hvp: optional 1-D NumPy array that is added to every row of
            ``old_gradients`` to predict the current updates.

    Returns:
        ``(agg_grads, distance)``: the column mean of ``user_grads`` and,
        when ``hvp`` is given, a NumPy array with the L2 distance between
        predicted and actual update of every client, normalised to sum to 1
        (``None`` otherwise).
    """
    if hvp is not None:
        # Follow the device of the stored updates instead of a module-level global.
        hvp_t = torch.from_numpy(hvp).to(old_gradients.device)
        # Work on a copy so the caller's history is not modified; the in-place
        # broadcast add keeps the dtype of the stored updates.
        pred_grad = old_gradients.clone()
        pred_grad += hvp_t
        distance = torch.norm(pred_grad - user_grads, dim=1).cpu().numpy()
        distance = distance / np.sum(distance)
    else:
        distance = None

    agg_grads = torch.mean(user_grads, dim=0)

    return agg_grads, distance

In [13]:
def median(old_gradients, user_grads, b=0, hvp=None):
    """Aggregate client updates by coordinate-wise median and, optionally, score each client.

    Args:
        old_gradients: 2-D tensor with the previous round's update of every
            client (one row per client). Only read when ``hvp`` is given.
        user_grads: 2-D tensor with the current update of every client.
        b: unused by this aggregator; kept so all aggregators share one
            signature.
        hvp: optional 1-D NumPy array that is added to every row of
            ``old_gradients`` to predict the current updates.

    Returns:
        ``(agg_grads, distance)``: the column-wise ``torch.median`` of
        ``user_grads`` and, when ``hvp`` is given, a NumPy array with the L2
        distance between predicted and actual update of every client,
        normalised to sum to 1 (``None`` otherwise).
    """
    if hvp is not None:
        # Follow the device of the stored updates instead of a module-level global.
        hvp_t = torch.from_numpy(hvp).to(old_gradients.device)
        # Work on a copy so the caller's history is not modified; the in-place
        # broadcast add keeps the dtype of the stored updates.
        pred_grad = old_gradients.clone()
        pred_grad += hvp_t
        distance = torch.norm(pred_grad - user_grads, dim=1).cpu().numpy()
        distance = distance / np.sum(distance)
    else:
        distance = None

    agg_grads = torch.median(user_grads, 0)[0]

    return agg_grads, distance

In [14]:
def trimmed_mean(old_gradients, user_grads, b=0, hvp=None):
    """Aggregate client updates by trimmed mean and, optionally, score each client.

    Args:
        old_gradients: 2-D tensor with the previous round's update of every
            client (one row per client). Only read when ``hvp`` is given.
        user_grads: 2-D tensor with the current update of every client.
        b: number of entries trimmed from each end of every coordinate
            (forwarded to ``tr_mean``).
        hvp: optional 1-D NumPy array that is added to every row of
            ``old_gradients`` to predict the current updates.

    Returns:
        ``(agg_grads, distance)``: ``tr_mean(user_grads, b)`` and, when
        ``hvp`` is given, a NumPy array with the L2 distance between
        predicted and actual update of every client, normalised to sum to 1
        (``None`` otherwise).
    """
    if hvp is not None:
        # Follow the device of the stored updates instead of a module-level global.
        hvp_t = torch.from_numpy(hvp).to(old_gradients.device)
        # Work on a copy so the caller's history is not modified; the in-place
        # broadcast add keeps the dtype of the stored updates.
        pred_grad = old_gradients.clone()
        pred_grad += hvp_t
        distance = torch.norm(pred_grad - user_grads, dim=1).cpu().numpy()
        distance = distance / np.sum(distance)
    else:
        distance = None

    agg_grads = tr_mean(user_grads, b)

    return agg_grads, distance

In [15]:
def detection(score, nobyz):
    """Cluster the per-client scores into two groups and print detection metrics.

    The cluster with the higher mean score is labelled 0 (malicious), the
    other one 1 (benign). Ground truth assumes the first ``nobyz`` clients
    are malicious. Accuracy, recall, false-positive rate and false-negative
    rate are printed, followed by the silhouette score of the clustering.

    Args:
        score: 1-D NumPy array with one score per client. Its length is used
            as the number of clients (previously hard-coded to 100).
        nobyz: number of malicious clients, assumed to occupy the first
            positions of ``score``.
    """
    n_clients = len(score)
    estimator = KMeans(n_clusters=2)
    estimator.fit(score.reshape(-1, 1))
    label_pred = estimator.labels_
    if np.mean(score[label_pred==0])<np.mean(score[label_pred==1]):
        # Flip the labels so that 0 always marks the higher-scoring
        # (malicious) cluster.
        label_pred = 1 - label_pred
    real_label = np.ones(n_clients)
    real_label[:nobyz] = 0
    acc = np.sum(label_pred == real_label) / n_clients
    recall = 1 - np.sum(label_pred[:nobyz]) / nobyz
    fpr = 1 - np.sum(label_pred[nobyz:]) / (n_clients - nobyz)
    fnr = np.sum(label_pred[:nobyz]) / nobyz
    print("acc %0.4f; recall %0.4f; fpr %0.4f; fnr %0.4f;" % (acc, recall, fpr, fnr))
    print(silhouette_score(score.reshape(-1, 1), label_pred))

def detection1(score, nobyz):
    """Decide via the gap statistic whether the scores form more than one cluster.

    Scores are min-max scaled to [0, 1]. For k = 1..7 the within-cluster sum
    of squares of a KMeans fit is compared (in log scale) with the average
    over ``nrefs`` fits on uniform reference data. The smallest k with
    ``gap(k) - gap(k+1) + sd(k+1) >= 0`` is selected.

    Args:
        score: 1-D NumPy array with one score per client.
        nobyz: unused; kept for signature compatibility.

    Returns:
        0 (and prints 'No attack detected!') if a single cluster is
        selected, otherwise 1 (and prints 'Attack Detected!').
    """
    nrefs = 10
    ks = range(1, 8)
    gaps = np.zeros(len(ks))
    gapDiff = np.zeros(len(ks) - 1)
    sdk = np.zeros(len(ks))

    lowest = np.min(score)
    highest = np.max(score)
    if highest == lowest:
        # All clients have the same score: there is exactly one cluster and
        # the min-max scaling below would divide by zero.
        print('No attack detected!')
        return 0
    score = (score - lowest)/(highest-lowest)

    def _dispersion(values, fitted):
        # Sum of squared distances of every value to its assigned centre.
        centers = fitted.cluster_centers_
        labels = fitted.labels_
        return np.sum([np.square(values[m]-centers[labels[m]]) for m in range(len(values))])

    for i, k in enumerate(ks):
        estimator = KMeans(n_clusters=k)
        estimator.fit(score.reshape(-1, 1))
        Wk = _dispersion(score, estimator)
        WkRef = np.zeros(nrefs)
        for j in range(nrefs):
            rand = np.random.uniform(0, 1, len(score))
            estimator = KMeans(n_clusters=k)
            estimator.fit(rand.reshape(-1, 1))
            WkRef[j] = _dispersion(rand, estimator)
        gaps[i] = np.log(np.mean(WkRef)) - np.log(Wk)
        sdk[i] = np.sqrt((1.0 + nrefs) / nrefs) * np.std(np.log(WkRef))

        if i > 0:
            gapDiff[i - 1] = gaps[i - 1] - gaps[i] + sdk[i]

    # If no k satisfies the criterion the data is certainly not a single
    # cluster, so default to "more than one cluster". Without a default the
    # name would be unbound below.
    select_k = 2
    for i in range(len(gapDiff)):
        if gapDiff[i] >= 0:
            select_k = i+1
            break
    if select_k == 1:
        print('No attack detected!')
        return 0
    else:
        print('Attack Detected!')
        return 1

In [16]:
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def train(trainloader, model, model_received, criterion, optimizer, pgd=False, eps=2):
    """Run one pass over ``trainloader`` and update ``model`` in place.

    Args:
        trainloader: iterable of ``(inputs, targets)`` batches.
        model: network to optimise.
        model_received: flat reference vector; only read when ``pgd`` is true.
            NOTE(review): it is compared with ``parameters_to_vector`` of the
            model parameters, so it presumably has to contain parameters only
            (no buffers) - confirm against the caller.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimiser bound to ``model``'s parameters.
        pgd: when true, after every step the parameter vector is projected
            back onto the L2 ball of radius ``eps`` around ``model_received``.
        eps: radius of that ball.

    Returns:
        ``(average loss, average top-1 accuracy as a fraction)``.
    """
    model.train()

    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()

    for inputs, targets in trainloader:
        inputs = inputs.to(device, torch.float)
        targets = targets.to(device, torch.long)
        n_samples = inputs.size(0)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Book-keeping: running loss and top-1 / top-5 accuracy.
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), n_samples)
        top1.update(prec1.item()/100.0, n_samples)
        top5.update(prec5.item()/100.0, n_samples)

        # Gradient step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if not pgd:
            continue

        # Projection step: rescale the offset from the reference vector when
        # it is longer than eps.
        curr_model = list(model.parameters())
        curr_model_vec = parameters_to_vector(curr_model)
        offset = curr_model_vec - model_received
        offset_norm = torch.norm(offset)
        if offset_norm > eps:
            curr_model_vec = eps*offset/offset_norm + model_received
            vector_to_parameters(curr_model_vec, curr_model)

    return (losses.avg, top1.avg)

def test(testloader, model, criterion):
    """Evaluate ``model`` on ``testloader``.

    The model is switched to eval mode and the whole pass runs under
    ``torch.no_grad()`` so no autograd graph is built for the forward passes.

    Args:
        testloader: iterable of ``(inputs, targets)`` batches.
        model: network to evaluate.
        criterion: loss function called as ``criterion(outputs, targets)``.

    Returns:
        ``(average loss, average top-1 accuracy as a fraction)`` as
        accumulated by the ``AverageMeter`` helpers.
    """
    model.eval()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    with torch.no_grad():
        for batch_ind, (inputs, targets) in enumerate(testloader):
            inputs = inputs.to(device, torch.float)
            targets = targets.to(device, torch.long)
            outputs = model(inputs)

            loss = criterion(outputs, targets)
            # measure accuracy and record loss
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.data, inputs.size()[0])
            top1.update(prec1/100.0, inputs.size()[0])
            top5.update(prec5/100.0, inputs.size()[0])
    return (losses.avg, top1.avg)

# Fang + FLDetector - with state_dict + No attack

In [26]:
# Experiment cell: federated training of ResNet-20 over 100 workers with
# per-round client scoring and detection. attack_type is 'none' in this run.
torch.cuda.empty_cache()
local_epochs = 1
# NOTE(review): batch_size is not passed to get_train below; the per-worker
# loaders use the full shard as one batch.
batch_size = 32
num_workers = 100

local_lr = 1
global_lr = .5
nepochs = 1000

# NOTE(review): `aggregation` is not read anywhere in this cell; the
# aggregator actually called in the loop is simple_mean.
aggregation = 'trim'

# Pool both CIFAR-10 splits: `all_data` with the augmenting transform,
# `all_test_data` with the deterministic one. Both share the same indexing.
all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))

num_workers = 100
distribution='fang'
param = .5
force = True
each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
    all_data, num_workers=100, distribution=distribution, param=param, force=force)
train_loaders = []
for pos, indices in enumerate(each_worker_idx):
    # batch_size == len(indices): every worker trains on its shard as a single batch.
    train_loaders.append((pos, get_train(all_data, indices, len(indices))))
# test_loaders = []
# for pos, indices in each_worker_te_idx.items():
#     batch_size = batch_size
#     train_loaders.append((pos, get_train(all_test_data, indices, len(indices))))
# Global evaluation loader over the held-out indices of all workers.
cifar10_test_loader = get_train(all_test_data, global_test_idx)

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

resume=False
round_nclients = num_workers
best_global_acc=0
epoch_num = 0

# Detection settings: scores are collected once epoch_num exceeds
# start_detection_epoch, and detection runs on a sliding window of
# window_size score rows.
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)
# The first nbyz workers are treated as the compromised ones.
nbyz = int(num_workers * 0.28)
good_distance_rage = np.zeros((1, nbyz))
# Starts with one all-zero row; one row per scored round is appended below.
malicious_scores = np.zeros((1, num_workers))
attack_type = 'none'
weight_record = []
grad_record = []
test_grads = []
old_grad_list = []

# Flatten the full state_dict (parameters and buffers) of a fresh model into
# one float vector; this vector is the global model exchanged every round.
# NOTE(review): the loop variable `param` overwrites the `param = .5` setting above.
fed_model = resnet20().cuda()
model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

# NOTE(review): `<=` makes this loop run nepochs + 1 rounds.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    user_grads=[]
    benign_norm = 0
    # Local training: every worker starts from a copy of the global model and
    # reports (flattened local state - global vector) as its update.
    for i in round_benign:
        model = copy.deepcopy(fed_model)
#         optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.99**epoch_num), momentum=0.9, weight_decay=1e-4)
        optimizer = optim.SGD(model.parameters(), lr = local_lr)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[i][1], model, model_received, criterion, optimizer)

        # NOTE(review): this inner loop reuses the name `i` of the worker loop;
        # harmless because `i` is not read again before the next iteration.
        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        update =  (params - model_received)
        benign_norm += torch.norm(update)/len(round_benign)
        user_grads = update[None,:] if len(user_grads) == 0 else torch.cat((user_grads, update[None,:]), 0)

    weight = copy.deepcopy(model_received)

    # After the warm-up rounds, apply the L-BFGS approximation to the change
    # of the global vector since the previous round.
    if (epoch_num > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    # Honest updates of the first nbyz workers, kept before any attack
    # overwrites those rows of user_grads.
    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        pred_grad = copy.deepcopy(good_old_grads)
        distance = []
        for i in range(len(good_old_grads)):
            pred_grad[i] += torch.from_numpy(hvp).to(device)
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()[None,:]), 0)
        
    # Attack injection. NOTE(review): when epoch_num == start_detection_epoch
    # neither branch below runs, so no attack is applied in that round.
    if attack_type != 'none' and (epoch_num < start_detection_epoch):
        user_grads = full_trim(user_grads, nbyz)
    elif epoch_num > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                print("LIE")
                z = 0.01
                noise_avg = torch.mean(user_grads[:nbyz], dim=0)
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_avg + z * noise_std
            elif attack_type == 'NDSS21':
                # NOTE(review): `dev_type` is not defined in any visible cell.
                distance_bound = np.mean(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                noise = deviation * (distance_bound / torch.norm(deviation))
            elif attack_type == 'mod_trim':
                # NOTE(review): this branch never assigns `noise`, which the
                # loop below reads; `mal_grads` is not used afterwards.
                mal_grads= full_trim(user_grads[:nbyz], nbyz)
                pass
            else:
                noise = torch.zeros(hvp.shape).to(device)
            for m in range(nbyz):
                user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise

    # Aggregate the updates; `distance` is the normalised per-client score
    # (None while hvp is None).
    agg_grads, distance = simple_mean(old_grad_list, user_grads, nbyz, hvp)
#     agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)
    
    if distance is not None and epoch_num > (start_detection_epoch - window_size):
        print('==>', epoch_num, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    # Once enough rows exist, sum the last window_size score rows per client
    # and stop training as soon as detection1 reports an attack.
    if malicious_scores.shape[0] >= (window_size+1):
        print('performing detection at epoch %d' % epoch_num)
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', epoch_num)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    # History for lbfgs: differences of the global vector and of the
    # aggregated update between consecutive rounds (at most 10 pairs kept).
    # last_weight / last_grad are first assigned at the end of round 0.
    if epoch_num > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > 10):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    # Only unbinds the name; old_grad_list still references the tensor.
    del user_grads
    # Global step with a learning rate decayed by 0.999 per round.
    model_received = model_received + global_lr * (0.999 ** epoch_num) * agg_grads
    # Rebuild a model and load the updated flat vector back into it, slicing
    # the vector in state_dict order.
    fed_model = resnet20().cuda()
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name
    fed_model.load_state_dict(state_dict)

    # `epoch_num%1==0` is always true, so the global model is evaluated every round.
    if epoch_num%1==0 or epoch_num==nepochs-1:
        val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
        is_best = best_global_acc < val_acc
        best_global_acc = max(best_global_acc, val_acc)
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f' % (
            epoch_num, benign_norm, val_loss, val_acc, best_global_acc))

    # NOTE(review): `math` is not imported in any visible cell; presumably it
    # arrives through one of the star imports - confirm.
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break

    epoch_num+=1

e 0 benign_norm 9.039 val loss 46.354 val acc 0.100 best val_acc 0.100
e 1 benign_norm 8.299 val loss 229.226 val acc 0.100 best val_acc 0.100
e 2 benign_norm 10.367 val loss 304.616 val acc 0.100 best val_acc 0.100
e 3 benign_norm 7.427 val loss 792.109 val acc 0.099 best val_acc 0.100
e 4 benign_norm 6.935 val loss 1141.288 val acc 0.099 best val_acc 0.100
e 5 benign_norm 6.612 val loss 847.240 val acc 0.105 best val_acc 0.105
e 6 benign_norm 7.097 val loss 140.016 val acc 0.102 best val_acc 0.105
e 7 benign_norm 5.271 val loss 227.043 val acc 0.098 best val_acc 0.105
e 8 benign_norm 5.580 val loss 125.481 val acc 0.099 best val_acc 0.105
e 9 benign_norm 5.264 val loss 51.188 val acc 0.105 best val_acc 0.105
e 10 benign_norm 4.898 val loss 35.363 val acc 0.101 best val_acc 0.105
==> 11 (1, 100)
e 11 benign_norm 4.367 val loss 20.682 val acc 0.102 best val_acc 0.105
==> 12 (2, 100)
e 12 benign_norm 4.167 val loss 8.249 val acc 0.113 best val_acc 0.113
==> 13 (3, 100)
e 13 benign_norm 

# Fang + FLDetector + trim attack

# Mean aggregation rule (AGR)

In [35]:
# Experiment cell: federated training of ResNet-20 over 100 workers with
# per-round client scoring and detection. attack_type is 'full_trim' in this
# run, and updates are aggregated with simple_mean.
torch.cuda.empty_cache()
local_epochs = 1
# NOTE(review): batch_size is not passed to get_train below; the per-worker
# loaders use the full shard as one batch.
batch_size = 32
num_workers = 100

local_lr = 1
global_lr = .5
nepochs = 1000

# NOTE(review): `aggregation` is not read anywhere in this cell; the
# aggregator actually called in the loop is simple_mean.
aggregation = 'trim'

# Pool both CIFAR-10 splits: `all_data` with the augmenting transform,
# `all_test_data` with the deterministic one. Both share the same indexing.
all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))

num_workers = 100
distribution='fang'
param = .5
force = True
each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
    all_data, num_workers=100, distribution=distribution, param=param, force=force)
train_loaders = []
for pos, indices in enumerate(each_worker_idx):
    # batch_size == len(indices): every worker trains on its shard as a single batch.
    train_loaders.append((pos, get_train(all_data, indices, len(indices))))
# test_loaders = []
# for pos, indices in each_worker_te_idx.items():
#     batch_size = batch_size
#     train_loaders.append((pos, get_train(all_test_data, indices, len(indices))))
# Global evaluation loader over the held-out indices of all workers.
cifar10_test_loader = get_train(all_test_data, global_test_idx)

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

resume=False
round_nclients = num_workers
best_global_acc=0
epoch_num = 0

# Detection settings: scores are collected once epoch_num exceeds
# start_detection_epoch, and detection runs on a sliding window of
# window_size score rows.
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)
# The first nbyz workers are treated as the compromised ones.
nbyz = int(num_workers * 0.28)
good_distance_rage = np.zeros((1, nbyz))
# Starts with one all-zero row; one row per scored round is appended below.
malicious_scores = np.zeros((1, num_workers))
attack_type = 'full_trim'
weight_record = []
grad_record = []
test_grads = []
old_grad_list = []

# Flatten the full state_dict (parameters and buffers) of a fresh model into
# one float vector; this vector is the global model exchanged every round.
# NOTE(review): the loop variable `param` overwrites the `param = .5` setting above.
fed_model = resnet20().cuda()
model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

# NOTE(review): `<=` makes this loop run nepochs + 1 rounds.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    user_grads=[]
    benign_norm = 0
    # Local training: every worker starts from a copy of the global model and
    # reports (flattened local state - global vector) as its update.
    for i in round_benign:
        model = copy.deepcopy(fed_model)
#         optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.99**epoch_num), momentum=0.9, weight_decay=1e-4)
        optimizer = optim.SGD(model.parameters(), lr = local_lr)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[i][1], model, model_received, criterion, optimizer)

        # NOTE(review): this inner loop reuses the name `i` of the worker loop;
        # harmless because `i` is not read again before the next iteration.
        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        update =  (params - model_received)
        benign_norm += torch.norm(update)/len(round_benign)
        user_grads = update[None,:] if len(user_grads) == 0 else torch.cat((user_grads, update[None,:]), 0)

    weight = copy.deepcopy(model_received)

    # After the warm-up rounds, apply the L-BFGS approximation to the change
    # of the global vector since the previous round.
    if (epoch_num > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    # Honest updates of the first nbyz workers, kept before the attack
    # overwrites those rows of user_grads.
    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        pred_grad = copy.deepcopy(good_old_grads)
        distance = []
        for i in range(len(good_old_grads)):
            pred_grad[i] += torch.from_numpy(hvp).to(device)
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()[None,:]), 0)
        
    # Attack injection: with attack_type == 'full_trim' the first nbyz rows are
    # replaced by full_trim both before and after the warm-up.
    # NOTE(review): when epoch_num == start_detection_epoch neither branch
    # below runs, so no attack is applied in that round.
    if attack_type != 'none' and (epoch_num < start_detection_epoch):
        user_grads = full_trim(user_grads, nbyz)
    elif epoch_num > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                print("LIE")
                z = 0.01
                noise_avg = torch.mean(user_grads[:nbyz], dim=0)
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_avg + z * noise_std
            elif attack_type == 'NDSS21':
                # NOTE(review): `dev_type` is not defined in any visible cell.
                distance_bound = np.mean(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                noise = deviation * (distance_bound / torch.norm(deviation))
            elif attack_type == 'mod_trim':
                # NOTE(review): this branch never assigns `noise`, which the
                # loop below reads; `mal_grads` is not used afterwards.
                mal_grads= full_trim(user_grads[:nbyz], nbyz)
                pass
            else:
                noise = torch.zeros(hvp.shape).to(device)
            for m in range(nbyz):
                user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise

    # Aggregate the updates; `distance` is the normalised per-client score
    # (None while hvp is None).
    agg_grads, distance = simple_mean(old_grad_list, user_grads, nbyz, hvp)
#     agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)
    
    if distance is not None and epoch_num > (start_detection_epoch - window_size):
        print('==>', epoch_num, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    # Once enough rows exist, sum the last window_size score rows per client
    # and stop training as soon as detection1 reports an attack.
    if malicious_scores.shape[0] >= (window_size+1):
        print('performing detection at epoch %d' % epoch_num)
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', epoch_num)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    # History for lbfgs: differences of the global vector and of the
    # aggregated update between consecutive rounds (at most 10 pairs kept).
    # last_weight / last_grad are first assigned at the end of round 0.
    if epoch_num > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > 10):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    # Only unbinds the name; old_grad_list still references the tensor.
    del user_grads
    # Global step with a learning rate decayed by 0.999 per round.
    model_received = model_received + global_lr * (0.999 ** epoch_num) * agg_grads
    # Rebuild a model and load the updated flat vector back into it, slicing
    # the vector in state_dict order.
    fed_model = resnet20().cuda()
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name
    fed_model.load_state_dict(state_dict)

    # `epoch_num%1==0` is always true, so the global model is evaluated every round.
    if epoch_num%1==0 or epoch_num==nepochs-1:
        val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
        is_best = best_global_acc < val_acc
        best_global_acc = max(best_global_acc, val_acc)
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f' % (
            epoch_num, benign_norm, val_loss, val_acc, best_global_acc))

    # NOTE(review): `math` is not imported in any visible cell; presumably it
    # arrives through one of the star imports - confirm.
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break

    epoch_num+=1

e 0 benign_norm 9.662 val loss 461.655 val acc 0.100 best val_acc 0.100
e 1 benign_norm 12.622 val loss 6074.586 val acc 0.102 best val_acc 0.102
e 2 benign_norm 13.465 val loss 11501.611 val acc 0.102 best val_acc 0.102
e 3 benign_norm 16.115 val loss 33232.262 val acc 0.102 best val_acc 0.102
e 4 benign_norm 16.025 val loss 7643.963 val acc 0.097 best val_acc 0.102
e 5 benign_norm 12.231 val loss 9276.187 val acc 0.100 best val_acc 0.102
e 6 benign_norm 14.740 val loss 3148.542 val acc 0.097 best val_acc 0.102
e 7 benign_norm 12.070 val loss 1783.328 val acc 0.102 best val_acc 0.102
e 8 benign_norm 12.074 val loss 766.637 val acc 0.102 best val_acc 0.102
e 9 benign_norm 10.755 val loss 2902.072 val acc 0.100 best val_acc 0.102
e 10 benign_norm 11.643 val loss 350.978 val acc 0.102 best val_acc 0.102
==> 11 (1, 100)
e 11 benign_norm 9.070 val loss 2309.509 val acc 0.104 best val_acc 0.104
==> 12 (2, 100)
e 12 benign_norm 12.885 val loss 649.998 val acc 0.107 best val_acc 0.107
==> 13 

In [41]:
# Federated training run: simple-mean aggregation with the 'LIE' attack.
#
# Each round every client trains a copy of the global model, the first `nbyz`
# rows of the update matrix are overwritten with crafted updates, the server
# aggregates with `simple_mean`, and `detection1` is evaluated on the sum of the
# last `window_size` rows of per-client scores returned by the aggregator.
torch.cuda.empty_cache()
local_epochs = 1
batch_size = 32
num_workers = 100

local_lr = 1
global_lr = .5
nepochs = 1000

# Not read in this cell: the aggregator is selected by the explicit call below.
aggregation = 'trim'

all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))

distribution = 'fang'
param = .5
force = True
each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
    all_data, num_workers=num_workers, distribution=distribution, param=param, force=force)
# One (client position, loader) pair per client.
# NOTE(review): the third argument of get_train is presumably the batch size
# (here: the whole client shard) -- confirm against its definition.
train_loaders = [
    (pos, get_train(all_data, indices, len(indices)))
    for pos, indices in enumerate(each_worker_idx)
]
cifar10_test_loader = get_train(all_test_data, global_test_idx)

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

resume = False
round_nclients = num_workers
best_global_acc = 0
epoch_num = 0

start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)
nbyz = int(num_workers * 0.28)
# Both score histories start with a single all-zero row that later rows are appended to.
good_distance_rage = np.zeros((1, nbyz))
malicious_scores = np.zeros((1, num_workers))
init_attack = 'LIE'
attack_type = 'LIE'
dev_type = 'std'
weight_record = []
grad_record = []
test_grads = []
old_grad_list = []

fed_model = resnet20().cuda()
# Global model as one flat float vector: every state_dict entry, in state_dict order.
model_received = torch.cat([
    tensor.view(-1).data.type(torch.cuda.FloatTensor)
    for tensor in fed_model.state_dict().values()
])

# NOTE(review): `<=` runs nepochs + 1 rounds (0..nepochs); kept as in the original.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    benign_norm = 0

    # ---- Local training: one flat update (local - global) per client ----
    client_updates = []
    for client_idx in round_benign:
        model = copy.deepcopy(fed_model)
        # Alternative kept for reference:
        # optim.SGD(model.parameters(), lr=local_lr*(0.99**epoch_num), momentum=0.9, weight_decay=1e-4)
        optimizer = optim.SGD(model.parameters(), lr=local_lr)

        for local_epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client_idx][1], model, model_received, criterion, optimizer)

        params = torch.cat([
            tensor.view(-1).data.type(torch.cuda.FloatTensor)
            for tensor in model.state_dict().values()
        ])
        update = params - model_received
        benign_norm += torch.norm(update) / len(round_benign)
        client_updates.append(update)

    # Stack once instead of re-concatenating the growing matrix for every client.
    user_grads = torch.stack(client_updates)
    del client_updates

    weight = copy.deepcopy(model_received)

    # `hvp` stays a NumPy array because the aggregator receives it as such;
    # `hvp_dev` is the same vector moved to `device` once per round.
    # NOTE(review): lbfgs presumably approximates a Hessian-vector product -- confirm.
    if epoch_num > start_detection_epoch:
        hvp = np.squeeze(lbfgs(weight_record, grad_record, weight - last_weight))
        hvp_dev = torch.from_numpy(hvp).to(device)
    else:
        hvp = None
        hvp_dev = None

    # Untampered updates of the first nbyz clients, saved before any attack overwrites them.
    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        # Predicted update = previous untampered update + hvp (in-place, broadcast over rows).
        pred_grad = copy.deepcopy(good_old_grads)
        pred_grad += hvp_dev
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim=1).cpu().numpy()[None, :]), 0)

    # ---- Attack ----
    # NOTE(review): when epoch_num == start_detection_epoch neither branch runs,
    # so that round is aggregated without any attack -- confirm this is intended.
    if attack_type != 'none' and (epoch_num < start_detection_epoch):
        if init_attack == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif init_attack == 'LIE':
            z = 0.01
            noise_avg = torch.mean(user_grads[:nbyz], dim=0)
            noise_std = torch.std(user_grads[:nbyz], dim=0)
            noise = noise_avg + z * noise_std
            # Every one of the first nbyz rows becomes the same crafted vector.
            user_grads[:nbyz] = noise

    elif epoch_num > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                # This variant uses the per-coordinate std only (no mean term, z unused).
                z = 0.01
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_std
            elif attack_type == 'NDSS21':
                distance_bound = np.mean(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                else:
                    # Previously fell through and reused a stale `deviation` (or raised NameError).
                    raise ValueError('unknown dev_type: %s' % dev_type)
                # Rescale the deviation so its norm equals distance_bound.
                noise = deviation * (distance_bound / torch.norm(deviation))
            elif attack_type == 'mod_trim':
                # NOTE(review): this branch never assigns `noise`, and `mal_grads` is
                # unused; the assignment below will use a stale `noise` or fail.
                mal_grads = full_trim(user_grads[:nbyz], nbyz)
            else:
                noise = torch.zeros(hvp.shape).to(device)
            # Crafted update = previous round's submitted row + hvp + noise.
            user_grads[:nbyz] = old_grad_list[:nbyz] + hvp_dev + noise

    # ---- Aggregation ----
    agg_grads, distance = simple_mean(old_grad_list, user_grads, nbyz, hvp)
#     agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)

    if distance is not None and epoch_num > (start_detection_epoch - window_size):
        print('==>', epoch_num, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    # ---- Detection: needs window_size rows on top of the initial zero row ----
    if malicious_scores.shape[0] >= (window_size + 1):
        print('performing detection at epoch %d' % epoch_num)
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', epoch_num)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    # ---- History passed to lbfgs: at most the 10 most recent differences ----
    if epoch_num > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)

    if len(weight_record) > 10:
        del weight_record[0]
        del grad_record[0]

    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    # Drops only the name; the tensor stays alive through old_grad_list.
    del user_grads

    # ---- Server step, then load the flat vector into a fresh model ----
    model_received = model_received + global_lr * (0.999 ** epoch_num) * agg_grads
    fed_model = resnet20().cuda()
    state_dict = {}
    offset = 0
    for entry_name, entry_tensor in fed_model.state_dict().items():
        n_elems = entry_tensor.numel()
        state_dict[entry_name] = model_received[offset:offset + n_elems].reshape(entry_tensor.shape)
        offset += n_elems
    fed_model.load_state_dict(state_dict)

    # `epoch_num % 1 == 0` is always true, so validation runs every round.
    if epoch_num % 1 == 0 or epoch_num == nepochs - 1:
        val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
        is_best = best_global_acc < val_acc
        best_global_acc = max(best_global_acc, val_acc)
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f' % (
            epoch_num, benign_norm, val_loss, val_acc, best_global_acc))

    # Abort when the run has diverged.
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break

    epoch_num += 1

e 0 benign_norm 10.353 val loss 100.559 val acc 0.102 best val_acc 0.102
e 1 benign_norm 9.587 val loss 305.944 val acc 0.100 best val_acc 0.102
e 2 benign_norm 9.341 val loss 168.438 val acc 0.097 best val_acc 0.102
e 3 benign_norm 7.324 val loss 1003.820 val acc 0.103 best val_acc 0.103
e 4 benign_norm 8.057 val loss 833.734 val acc 0.100 best val_acc 0.103
e 5 benign_norm 6.991 val loss 544.614 val acc 0.098 best val_acc 0.103
e 6 benign_norm 6.419 val loss 153.872 val acc 0.097 best val_acc 0.103
e 7 benign_norm 5.686 val loss 63.572 val acc 0.098 best val_acc 0.103
e 8 benign_norm 5.212 val loss 35.630 val acc 0.098 best val_acc 0.103
e 9 benign_norm 4.887 val loss 19.439 val acc 0.099 best val_acc 0.103
e 10 benign_norm 4.586 val loss 12.241 val acc 0.119 best val_acc 0.119
==> 11 (1, 100)
e 11 benign_norm 4.233 val loss 12.693 val acc 0.101 best val_acc 0.119
==> 12 (2, 100)
e 12 benign_norm 4.215 val loss 24.238 val acc 0.097 best val_acc 0.119
==> 13 (3, 100)
e 13 benign_norm 

# Trimmed-mean aggregation rule (AGR)

In [43]:
# Federated training run: trimmed-mean aggregation with the 'LIE' attack.
#
# Each round every client trains a copy of the global model, the first `nbyz`
# rows of the update matrix are overwritten with crafted updates, the server
# aggregates with `trimmed_mean`, and `detection1` is evaluated on the sum of the
# last `window_size` rows of per-client scores returned by the aggregator.
torch.cuda.empty_cache()
local_epochs = 1
batch_size = 32
num_workers = 100

local_lr = 1
global_lr = .5
nepochs = 1000

# Not read in this cell: the aggregator is selected by the explicit call below.
aggregation = 'trim'

all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))

distribution = 'fang'
param = .5
force = True
each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
    all_data, num_workers=num_workers, distribution=distribution, param=param, force=force)
# One (client position, loader) pair per client.
# NOTE(review): the third argument of get_train is presumably the batch size
# (here: the whole client shard) -- confirm against its definition.
train_loaders = [
    (pos, get_train(all_data, indices, len(indices)))
    for pos, indices in enumerate(each_worker_idx)
]
cifar10_test_loader = get_train(all_test_data, global_test_idx)

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

resume = False
round_nclients = num_workers
best_global_acc = 0
epoch_num = 0

start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)
nbyz = int(num_workers * 0.28)
# Both score histories start with a single all-zero row that later rows are appended to.
good_distance_rage = np.zeros((1, nbyz))
malicious_scores = np.zeros((1, num_workers))
init_attack = 'LIE'
attack_type = 'LIE'
dev_type = 'std'
weight_record = []
grad_record = []
test_grads = []
old_grad_list = []

fed_model = resnet20().cuda()
# Global model as one flat float vector: every state_dict entry, in state_dict order.
model_received = torch.cat([
    tensor.view(-1).data.type(torch.cuda.FloatTensor)
    for tensor in fed_model.state_dict().values()
])

# NOTE(review): `<=` runs nepochs + 1 rounds (0..nepochs); kept as in the original.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    benign_norm = 0

    # ---- Local training: one flat update (local - global) per client ----
    client_updates = []
    for client_idx in round_benign:
        model = copy.deepcopy(fed_model)
        # Alternative kept for reference:
        # optim.SGD(model.parameters(), lr=local_lr*(0.99**epoch_num), momentum=0.9, weight_decay=1e-4)
        optimizer = optim.SGD(model.parameters(), lr=local_lr)

        for local_epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client_idx][1], model, model_received, criterion, optimizer)

        params = torch.cat([
            tensor.view(-1).data.type(torch.cuda.FloatTensor)
            for tensor in model.state_dict().values()
        ])
        update = params - model_received
        benign_norm += torch.norm(update) / len(round_benign)
        client_updates.append(update)

    # Stack once instead of re-concatenating the growing matrix for every client.
    user_grads = torch.stack(client_updates)
    del client_updates

    weight = copy.deepcopy(model_received)

    # `hvp` stays a NumPy array because the aggregator receives it as such;
    # `hvp_dev` is the same vector moved to `device` once per round.
    # NOTE(review): lbfgs presumably approximates a Hessian-vector product -- confirm.
    if epoch_num > start_detection_epoch:
        hvp = np.squeeze(lbfgs(weight_record, grad_record, weight - last_weight))
        hvp_dev = torch.from_numpy(hvp).to(device)
    else:
        hvp = None
        hvp_dev = None

    # Untampered updates of the first nbyz clients, saved before any attack overwrites them.
    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        # Predicted update = previous untampered update + hvp (in-place, broadcast over rows).
        pred_grad = copy.deepcopy(good_old_grads)
        pred_grad += hvp_dev
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim=1).cpu().numpy()[None, :]), 0)

    # ---- Attack ----
    # NOTE(review): when epoch_num == start_detection_epoch neither branch runs,
    # so that round is aggregated without any attack -- confirm this is intended.
    if attack_type != 'none' and (epoch_num < start_detection_epoch):
        if init_attack == 'full_trim':
            user_grads[:nbyz] = full_trim(user_grads[:nbyz], nbyz)
        elif init_attack == 'LIE':
            z = 0.01
            noise_avg = torch.mean(user_grads[:nbyz], dim=0)
            noise_std = torch.std(user_grads[:nbyz], dim=0)
            noise = noise_avg + z * noise_std
            # Every one of the first nbyz rows becomes the same crafted vector.
            user_grads[:nbyz] = noise

    elif epoch_num > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads[:nbyz] = full_trim(user_grads[:nbyz], nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                # This variant uses the per-coordinate std only (no mean term, z unused).
                z = 0.01
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_std
            elif attack_type == 'NDSS21':
                distance_bound = np.mean(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                else:
                    # Previously fell through and reused a stale `deviation` (or raised NameError).
                    raise ValueError('unknown dev_type: %s' % dev_type)
                # Rescale the deviation so its norm equals distance_bound.
                noise = deviation * (distance_bound / torch.norm(deviation))
            elif attack_type == 'mod_trim':
                # NOTE(review): this branch never assigns `noise`, and `mal_grads` is
                # unused; the assignment below will use a stale `noise` or fail.
                mal_grads = full_trim(user_grads[:nbyz], nbyz)
            else:
                noise = torch.zeros(hvp.shape).to(device)
            # Crafted update = previous round's submitted row + hvp + noise.
            user_grads[:nbyz] = old_grad_list[:nbyz] + hvp_dev + noise

    # ---- Aggregation ----
#     agg_grads, distance = simple_mean(old_grad_list, user_grads, nbyz, hvp)
    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)

    if distance is not None and epoch_num > (start_detection_epoch - window_size):
        print('==>', epoch_num, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    # ---- Detection: needs window_size rows on top of the initial zero row ----
    if malicious_scores.shape[0] >= (window_size + 1):
        print('performing detection at epoch %d' % epoch_num)
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', epoch_num)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    # ---- History passed to lbfgs: at most the 10 most recent differences ----
    if epoch_num > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)

    if len(weight_record) > 10:
        del weight_record[0]
        del grad_record[0]

    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    # Drops only the name; the tensor stays alive through old_grad_list.
    del user_grads

    # ---- Server step, then load the flat vector into a fresh model ----
    model_received = model_received + global_lr * (0.999 ** epoch_num) * agg_grads
    fed_model = resnet20().cuda()
    state_dict = {}
    offset = 0
    for entry_name, entry_tensor in fed_model.state_dict().items():
        n_elems = entry_tensor.numel()
        state_dict[entry_name] = model_received[offset:offset + n_elems].reshape(entry_tensor.shape)
        offset += n_elems
    fed_model.load_state_dict(state_dict)

    # `epoch_num % 1 == 0` is always true, so validation runs every round.
    if epoch_num % 1 == 0 or epoch_num == nepochs - 1:
        val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
        is_best = best_global_acc < val_acc
        best_global_acc = max(best_global_acc, val_acc)
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f' % (
            epoch_num, benign_norm, val_loss, val_acc, best_global_acc))

    # Abort when the run has diverged.
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break

    epoch_num += 1

e 0 benign_norm 10.187 val loss 306.335 val acc 0.096 best val_acc 0.096
e 1 benign_norm 11.349 val loss 316.114 val acc 0.100 best val_acc 0.100
e 2 benign_norm 9.612 val loss 226.081 val acc 0.103 best val_acc 0.103
e 3 benign_norm 8.937 val loss 403.680 val acc 0.100 best val_acc 0.103
e 4 benign_norm 7.745 val loss 2553.563 val acc 0.093 best val_acc 0.103
e 5 benign_norm 9.172 val loss 4347.745 val acc 0.103 best val_acc 0.103
e 6 benign_norm 11.121 val loss 1229.835 val acc 0.103 best val_acc 0.103
e 7 benign_norm 9.710 val loss 455.237 val acc 0.102 best val_acc 0.103
e 8 benign_norm 9.996 val loss 101.610 val acc 0.102 best val_acc 0.103
e 9 benign_norm 9.125 val loss 69.453 val acc 0.095 best val_acc 0.103
e 10 benign_norm 8.534 val loss 48.394 val acc 0.101 best val_acc 0.103
==> 11 (1, 100)
e 11 benign_norm 7.675 val loss 26.177 val acc 0.106 best val_acc 0.106
==> 12 (2, 100)
e 12 benign_norm 7.186 val loss 14.765 val acc 0.084 best val_acc 0.106
==> 13 (3, 100)
e 13 benign

In [44]:
# Federated training run: trimmed-mean aggregation, 'full_trim' attack before
# `start_detection_epoch` and 'NDSS21' attack after it.
#
# Each round every client trains a copy of the global model, the first `nbyz`
# rows of the update matrix are overwritten with crafted updates, the server
# aggregates with `trimmed_mean`, and `detection1` is evaluated on the sum of the
# last `window_size` rows of per-client scores returned by the aggregator.
torch.cuda.empty_cache()
local_epochs = 1
batch_size = 32
num_workers = 100

local_lr = 1
global_lr = .5
nepochs = 1000

# Not read in this cell: the aggregator is selected by the explicit call below.
aggregation = 'trim'

all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))

distribution = 'fang'
param = .5
force = True
# The client partition is NOT rebuilt here: this cell reuses `train_loaders` and
# `global_test_idx` left in the kernel by an earlier cell, so it fails on a fresh
# kernel unless that cell has run. Uncomment to re-partition:
# each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
#     all_data, num_workers=num_workers, distribution=distribution, param=param, force=force)
# train_loaders = [
#     (pos, get_train(all_data, indices, len(indices)))
#     for pos, indices in enumerate(each_worker_idx)
# ]
cifar10_test_loader = get_train(all_test_data, global_test_idx)

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

resume = False
round_nclients = num_workers
best_global_acc = 0
epoch_num = 0

start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)
nbyz = int(num_workers * 0.28)
# Both score histories start with a single all-zero row that later rows are appended to.
good_distance_rage = np.zeros((1, nbyz))
malicious_scores = np.zeros((1, num_workers))
init_attack = 'full_trim'
attack_type = 'NDSS21'
dev_type = 'std'
weight_record = []
grad_record = []
test_grads = []
old_grad_list = []

fed_model = resnet20().cuda()
# Global model as one flat float vector: every state_dict entry, in state_dict order.
model_received = torch.cat([
    tensor.view(-1).data.type(torch.cuda.FloatTensor)
    for tensor in fed_model.state_dict().values()
])

# NOTE(review): `<=` runs nepochs + 1 rounds (0..nepochs); kept as in the original.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    benign_norm = 0

    # ---- Local training: one flat update (local - global) per client ----
    client_updates = []
    for client_idx in round_benign:
        model = copy.deepcopy(fed_model)
        # Alternative kept for reference:
        # optim.SGD(model.parameters(), lr=local_lr*(0.99**epoch_num), momentum=0.9, weight_decay=1e-4)
        optimizer = optim.SGD(model.parameters(), lr=local_lr)

        for local_epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client_idx][1], model, model_received, criterion, optimizer)

        params = torch.cat([
            tensor.view(-1).data.type(torch.cuda.FloatTensor)
            for tensor in model.state_dict().values()
        ])
        update = params - model_received
        benign_norm += torch.norm(update) / len(round_benign)
        client_updates.append(update)

    # Stack once instead of re-concatenating the growing matrix for every client.
    user_grads = torch.stack(client_updates)
    del client_updates

    weight = copy.deepcopy(model_received)

    # `hvp` stays a NumPy array because the aggregator receives it as such;
    # `hvp_dev` is the same vector moved to `device` once per round.
    # NOTE(review): lbfgs presumably approximates a Hessian-vector product -- confirm.
    if epoch_num > start_detection_epoch:
        hvp = np.squeeze(lbfgs(weight_record, grad_record, weight - last_weight))
        hvp_dev = torch.from_numpy(hvp).to(device)
    else:
        hvp = None
        hvp_dev = None

    # Untampered updates of the first nbyz clients, saved before any attack overwrites them.
    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        # Predicted update = previous untampered update + hvp (in-place, broadcast over rows).
        pred_grad = copy.deepcopy(good_old_grads)
        pred_grad += hvp_dev
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim=1).cpu().numpy()[None, :]), 0)

    # ---- Attack ----
    # NOTE(review): when epoch_num == start_detection_epoch neither branch runs,
    # so that round is aggregated without any attack -- confirm this is intended.
    if attack_type != 'none' and (epoch_num < start_detection_epoch):
        if init_attack == 'full_trim':
            user_grads[:nbyz] = full_trim(user_grads[:nbyz], nbyz)
        elif init_attack == 'LIE':
            z = 0.01
            noise_avg = torch.mean(user_grads[:nbyz], dim=0)
            noise_std = torch.std(user_grads[:nbyz], dim=0)
            noise = noise_avg + z * noise_std
            # Every one of the first nbyz rows becomes the same crafted vector.
            user_grads[:nbyz] = noise

    elif epoch_num > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads[:nbyz] = full_trim(user_grads[:nbyz], nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                # This variant uses the per-coordinate std only (no mean term, z unused).
                z = 0.01
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_std
            elif attack_type == 'NDSS21':
                distance_bound = np.mean(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                else:
                    # Previously fell through and reused a stale `deviation` (or raised NameError).
                    raise ValueError('unknown dev_type: %s' % dev_type)
                # Rescale the deviation so its norm equals distance_bound.
                noise = deviation * (distance_bound / torch.norm(deviation))
            elif attack_type == 'mod_trim':
                # NOTE(review): this branch never assigns `noise`, and `mal_grads` is
                # unused; the assignment below will use a stale `noise` or fail.
                mal_grads = full_trim(user_grads[:nbyz], nbyz)
            else:
                noise = torch.zeros(hvp.shape).to(device)
            # Crafted update = previous round's submitted row + hvp + noise.
            user_grads[:nbyz] = old_grad_list[:nbyz] + hvp_dev + noise

    # ---- Aggregation ----
#     agg_grads, distance = simple_mean(old_grad_list, user_grads, nbyz, hvp)
    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)

    if distance is not None and epoch_num > (start_detection_epoch - window_size):
        print('==>', epoch_num, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    # ---- Detection: needs window_size rows on top of the initial zero row ----
    if malicious_scores.shape[0] >= (window_size + 1):
        print('performing detection at epoch %d' % epoch_num)
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', epoch_num)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    # ---- History passed to lbfgs: at most the 10 most recent differences ----
    if epoch_num > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)

    if len(weight_record) > 10:
        del weight_record[0]
        del grad_record[0]

    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    # Drops only the name; the tensor stays alive through old_grad_list.
    del user_grads

    # ---- Server step, then load the flat vector into a fresh model ----
    model_received = model_received + global_lr * (0.999 ** epoch_num) * agg_grads
    fed_model = resnet20().cuda()
    state_dict = {}
    offset = 0
    for entry_name, entry_tensor in fed_model.state_dict().items():
        n_elems = entry_tensor.numel()
        state_dict[entry_name] = model_received[offset:offset + n_elems].reshape(entry_tensor.shape)
        offset += n_elems
    fed_model.load_state_dict(state_dict)

    # `epoch_num % 1 == 0` is always true, so validation runs every round.
    if epoch_num % 1 == 0 or epoch_num == nepochs - 1:
        val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
        is_best = best_global_acc < val_acc
        best_global_acc = max(best_global_acc, val_acc)
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f' % (
            epoch_num, benign_norm, val_loss, val_acc, best_global_acc))

    # Abort when the run has diverged.
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break

    epoch_num += 1

e 0 benign_norm 9.795 val loss 128.449 val acc 0.097 best val_acc 0.097
e 1 benign_norm 9.171 val loss 112.578 val acc 0.100 best val_acc 0.100
e 2 benign_norm 8.954 val loss 113.510 val acc 0.100 best val_acc 0.100
e 3 benign_norm 7.812 val loss 216.066 val acc 0.103 best val_acc 0.103
e 4 benign_norm 8.888 val loss 746.109 val acc 0.103 best val_acc 0.103
e 5 benign_norm 9.444 val loss 849.423 val acc 0.097 best val_acc 0.103
e 6 benign_norm 8.700 val loss 448.563 val acc 0.097 best val_acc 0.103
e 7 benign_norm 8.781 val loss 166.903 val acc 0.098 best val_acc 0.103
e 8 benign_norm 7.308 val loss 72.789 val acc 0.104 best val_acc 0.104
e 9 benign_norm 6.128 val loss 84.020 val acc 0.103 best val_acc 0.104
e 10 benign_norm 6.712 val loss 19.809 val acc 0.098 best val_acc 0.104
==> 11 (1, 100)
e 11 benign_norm 4.651 val loss 59.494 val acc 0.108 best val_acc 0.108
==> 12 (2, 100)
e 12 benign_norm 7.048 val loss 48.081 val acc 0.104 best val_acc 0.108
==> 13 (3, 100)
e 13 benign_norm 6

==> 66 (56, 100)
performing detection at epoch 66
No attack detected!
e 66 benign_norm 111.198 val loss 50291.457 val acc 0.103 best val_acc 0.108
==> 67 (57, 100)
performing detection at epoch 67
No attack detected!
e 67 benign_norm 286.445 val loss 11791.866 val acc 0.108 best val_acc 0.108
==> 68 (58, 100)
performing detection at epoch 68
No attack detected!
e 68 benign_norm 262.943 val loss 2406.462 val acc 0.100 best val_acc 0.108
==> 69 (59, 100)
performing detection at epoch 69
No attack detected!
e 69 benign_norm 242.543 val loss 921.980 val acc 0.103 best val_acc 0.108
==> 70 (60, 100)
performing detection at epoch 70
No attack detected!
e 70 benign_norm 226.409 val loss 5708.082 val acc 0.103 best val_acc 0.108
==> 71 (61, 100)
performing detection at epoch 71
No attack detected!
e 71 benign_norm 230.121 val loss 755.251 val acc 0.081 best val_acc 0.108
==> 72 (62, 100)
performing detection at epoch 72
No attack detected!
e 72 benign_norm 201.377 val loss 793.694 val acc 0.10

e 122 benign_norm 31.634 val loss 141.625 val acc 0.094 best val_acc 0.134
==> 123 (113, 100)
performing detection at epoch 123
No attack detected!
e 123 benign_norm 73.371 val loss 38.811 val acc 0.096 best val_acc 0.134
==> 124 (114, 100)
performing detection at epoch 124
No attack detected!
e 124 benign_norm 30.179 val loss 21.186 val acc 0.105 best val_acc 0.134
==> 125 (115, 100)
performing detection at epoch 125
No attack detected!
e 125 benign_norm 32.556 val loss 17.079 val acc 0.100 best val_acc 0.134
==> 126 (116, 100)
performing detection at epoch 126
No attack detected!
e 126 benign_norm 44.062 val loss 28.884 val acc 0.104 best val_acc 0.134
==> 127 (117, 100)
performing detection at epoch 127
No attack detected!
e 127 benign_norm 38.001 val loss 44.256 val acc 0.106 best val_acc 0.134
==> 128 (118, 100)
performing detection at epoch 128
No attack detected!
e 128 benign_norm 54.775 val loss 39.036 val acc 0.103 best val_acc 0.134
==> 129 (119, 100)
performing detection at 

No attack detected!
e 178 benign_norm 50.624 val loss 11.145 val acc 0.103 best val_acc 0.134
==> 179 (169, 100)
performing detection at epoch 179
No attack detected!
e 179 benign_norm 59.080 val loss 8.095 val acc 0.103 best val_acc 0.134
==> 180 (170, 100)
performing detection at epoch 180
No attack detected!
e 180 benign_norm 53.063 val loss 5.620 val acc 0.100 best val_acc 0.134
==> 181 (171, 100)
performing detection at epoch 181
No attack detected!
e 181 benign_norm 55.113 val loss 3.900 val acc 0.104 best val_acc 0.134
==> 182 (172, 100)
performing detection at epoch 182
No attack detected!
e 182 benign_norm 95.704 val loss 794.983 val acc 0.097 best val_acc 0.134
==> 183 (173, 100)
performing detection at epoch 183
No attack detected!
e 183 benign_norm 195.908 val loss 584.969 val acc 0.103 best val_acc 0.134
==> 184 (174, 100)
performing detection at epoch 184
No attack detected!
e 184 benign_norm 189.935 val loss 101.094 val acc 0.103 best val_acc 0.134
==> 185 (175, 100)
per

No attack detected!
e 239 benign_norm 417.774 val loss 21.620 val acc 0.103 best val_acc 0.134
==> 240 (230, 100)
performing detection at epoch 240
No attack detected!
e 240 benign_norm 412.920 val loss 17.440 val acc 0.103 best val_acc 0.134
==> 241 (231, 100)
performing detection at epoch 241
No attack detected!
e 241 benign_norm 399.495 val loss 16.937 val acc 0.134 best val_acc 0.134
==> 242 (232, 100)
performing detection at epoch 242
No attack detected!
e 242 benign_norm 416.271 val loss 123.739 val acc 0.099 best val_acc 0.134
==> 243 (233, 100)
performing detection at epoch 243
No attack detected!
e 243 benign_norm 403.733 val loss 56.341 val acc 0.100 best val_acc 0.134
==> 244 (234, 100)
performing detection at epoch 244
No attack detected!
e 244 benign_norm 329.984 val loss 47.508 val acc 0.103 best val_acc 0.134
==> 245 (235, 100)
performing detection at epoch 245
No attack detected!
e 245 benign_norm 307.597 val loss 42.570 val acc 0.103 best val_acc 0.134
==> 246 (236, 10

No attack detected!
e 294 benign_norm 829.469 val loss 945.743 val acc 0.098 best val_acc 0.134
==> 295 (285, 100)
performing detection at epoch 295
No attack detected!
e 295 benign_norm 649.063 val loss 1415.260 val acc 0.100 best val_acc 0.134
==> 296 (286, 100)
performing detection at epoch 296
No attack detected!
e 296 benign_norm 696.790 val loss 1228.139 val acc 0.095 best val_acc 0.134
==> 297 (287, 100)
performing detection at epoch 297
No attack detected!
e 297 benign_norm 614.678 val loss 1410.003 val acc 0.103 best val_acc 0.134
==> 298 (288, 100)
performing detection at epoch 298
No attack detected!
e 298 benign_norm 644.848 val loss 1219.252 val acc 0.088 best val_acc 0.134
==> 299 (289, 100)
performing detection at epoch 299
No attack detected!
e 299 benign_norm 653.679 val loss 1361.169 val acc 0.094 best val_acc 0.134
==> 300 (290, 100)
performing detection at epoch 300
No attack detected!
e 300 benign_norm 477.870 val loss 5569.326 val acc 0.103 best val_acc 0.134
==> 

# Fast baseline + Fang + FLDetector + attacks

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration and initial global model for the "fast baseline"
# run (trimmed-mean aggregation with the attack/detection loop in this cell).
# ---------------------------------------------------------------------------
torch.cuda.empty_cache()

# Federated schedule.
nepochs = 200        # the loop below runs while epoch_num <= nepochs
local_epochs = 2     # local passes each client makes per round
batch_size = 8
num_workers = 100

# Learning rates: clients use local_lr (decayed per round in the loop below),
# the server scales the aggregated update by global_lr.
local_lr = 0.1
global_lr = 1

# Threat model.
nbyz = 28
byz_type = 'full_trim'
aggregation = 'trim'

all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))

# Data-partitioning settings (only consumed by the disabled block below).
distribution = 'fang'
param = .5
force = True

# NOTE(review): the partitioning code is disabled, so this cell depends on
# `train_loaders` and `global_test_idx` already existing in the kernel
# (presumably from an earlier cell) -- it will not run on a fresh kernel by
# itself. If this block is re-enabled, note that its second loop appends to
# `train_loaders` rather than `test_loaders`.
# each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
#     all_data, num_workers=100, distribution=distribution, param=param, force=force)
# train_loaders = []
# for pos, indices in enumerate(each_worker_idx):
#     batch_size = batch_size
#     train_loaders.append((pos, get_train(all_data, indices, batch_size)))
# test_loaders = []
# for pos, indices in each_worker_te_idx.items():
#     batch_size = batch_size
#     train_loaders.append((pos, get_train(all_test_data, indices, len(indices))))
cifar10_test_loader = get_train(all_test_data, global_test_idx)

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

resume = False
round_nclients = num_workers
best_global_acc = 0
epoch_num = 0

# Use the notebook-level `device` instead of hard-coding CUDA so the cell also
# runs on a CPU-only machine.
fed_model = resnet20().to(device)

# Flatten every state_dict entry (weights and buffers) into one float32 vector.
# This is the "global model" representation exchanged with the clients.
# A comprehension is used so the config value `param` above is not overwritten
# by a loop variable, and so the vector is built with a single concatenation.
model_received = torch.cat([
    tensor.detach().reshape(-1).to(device=device, dtype=torch.float32)
    for tensor in fed_model.state_dict().values()
])

# `math` is needed by the divergence check at the end of each round and does
# not appear in the notebook's visible import cell; importing here is harmless
# if it is already in scope (ideally this moves to the top import cell).
import math

# NOTE(review): this loop reads names that this cell never initialises:
# start_detection_epoch, window_size, attack_type, init_attack, dev_type,
# train_loaders, weight_record, grad_record, malicious_scores,
# good_distance_rage and (on the first round) old_grad_list; depending on the
# thresholds also last_weight, last_grad and good_old_grads. They are
# presumably left over from an earlier cell -- confirm they are reset before
# re-running, otherwise detection mixes state from a previous experiment.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()

    # NOTE(review): only workers nbyz..num_workers-1 are trained, so
    # `user_grads` has num_workers - nbyz rows, and the attack branches below
    # overwrite the first nbyz of *those* rows. Confirm this is the intended
    # "fast" setup.
    round_clients = np.arange(nbyz, num_workers)
    round_benign = round_clients
    benign_norm = 0

    # ---- Local training: one flattened update (local - global) per client ----
    client_updates = []
    for client_id in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr=local_lr * (0.999 ** epoch_num),
                              momentum=0.9, weight_decay=1e-4)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client_id][1], model, model_received, criterion, optimizer)

        # Flatten the trained state_dict in the same order used for model_received.
        params = torch.cat([
            tensor.detach().reshape(-1).to(device=device, dtype=torch.float32)
            for tensor in model.state_dict().values()
        ])

        update = params - model_received
        benign_norm += torch.norm(update) / len(round_benign)
        client_updates.append(update)

    # Stack once instead of re-concatenating the whole matrix for every client.
    user_grads = torch.stack(client_updates) if client_updates else []
    del client_updates

    weight = copy.deepcopy(model_received)

    # ---- Hessian-vector product estimate (only once detection has started) ----
    if epoch_num > start_detection_epoch:
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
        # Convert once per round; reused for every row below.
        hvp_tensor = torch.from_numpy(hvp).to(device)
    else:
        hvp = None
        hvp_tensor = None

    # Snapshot of the first nbyz rows before any attack overwrites them.
    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        # Predicted update = previous snapshot + hvp, added in place to every row.
        pred_grad = copy.deepcopy(good_old_grads)
        pred_grad += hvp_tensor
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim=1).cpu().numpy()[None, :]), 0)

    # ---- Attack injection ----
    if attack_type != 'none' and (epoch_num < start_detection_epoch):
        # Before detection starts, the attack is selected by `init_attack`.
        if init_attack == 'full_trim':
            user_grads[:nbyz] = full_trim(user_grads[:nbyz], nbyz)
        elif init_attack == 'LIE':
            z = 0.01
            noise_avg = torch.mean(user_grads[:nbyz], dim=0)
            noise_std = torch.std(user_grads[:nbyz], dim=0)
            noise = noise_avg + z * noise_std
            user_grads[:nbyz] = torch.stack(nbyz * [noise])

    elif epoch_num > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads[:nbyz] = full_trim(user_grads[:nbyz], nbyz)
        elif attack_type == 'none':
            pass
        else:
            # Remaining attacks craft rows as: previous row + hvp + noise.
            if attack_type == 'LIE':
                # NOTE(review): `z` is assigned but not applied in this branch
                # (noise is the raw std) -- confirm whether z * noise_std was meant.
                z = 0.01
                # noise_avg = torch.mean(user_grads[:nbyz], dim=0)
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_std
            elif attack_type == 'NDSS21':
                distance_bound = np.mean(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    # Unit vector along the mean snapshot update (no sign flip is applied here).
                    deviation = model_re / torch.norm(model_re)
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                else:
                    # Previously fell through and hit an undefined/stale `deviation`.
                    raise ValueError('unknown dev_type: %r' % (dev_type,))
                # Rescale so the perturbation norm equals the latest distance bound.
                noise = deviation * (distance_bound / torch.norm(deviation))
            elif attack_type == 'mod_trim':
                # This branch never defined `noise`, so the crafting loop below
                # either raised NameError or silently reused a stale `noise`
                # from an earlier run. Fail explicitly until it is implemented.
                raise NotImplementedError(
                    "attack_type 'mod_trim' does not define a noise term yet")
            else:
                noise = torch.zeros(hvp.shape).to(device)
            for m in range(nbyz):
                user_grads[m] = old_grad_list[m] + hvp_tensor + noise

    # ---- Robust aggregation ----
    # agg_grads, distance = simple_mean(old_grad_list, user_grads, nbyz, hvp)
    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)

    # ---- Detection on a sliding window of per-client scores ----
    if distance is not None and epoch_num > (start_detection_epoch - window_size):
        print('==>', epoch_num, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    if malicious_scores.shape[0] >= (window_size + 1):
        print('performing detection at epoch %d' % epoch_num)
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', epoch_num)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    # ---- History used by lbfgs; at most 10 entries are kept ----
    if epoch_num > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)

    if len(weight_record) > 10:
        del weight_record[0]
        del grad_record[0]

    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    # Drops only the name; the tensor stays alive through old_grad_list.
    del user_grads

    # ---- Apply the aggregated update and rebuild the global model ----
    model_received = model_received + global_lr * agg_grads
    fed_model = resnet20().to(device)
    state_dict = {}
    offset = 0
    for name, tensor in fed_model.state_dict().items():
        # Slice the flat vector back into each entry, in state_dict order.
        count = tensor.numel()
        state_dict[name] = model_received[offset:offset + count].reshape(tensor.shape)
        offset += count

    fed_model.load_state_dict(state_dict)
    val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)

    # Log every 10th round and the final rounds. The loop runs through
    # epoch_num == nepochs, so `>=` also covers the true last round.
    if epoch_num % 10 == 0 or epoch_num >= nepochs - 1:
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))

    # Abort when training has diverged.
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break

    epoch_num += 1