# FLDetector for MNIST with Dirichlet distribution

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
import copy
import time
from functools import reduce
from torchsummary import summary

import os
import sys
import pickle
sys.path.insert(0,'./utils/')
from logger import *
from eval import *
from misc import *

from sklearn.metrics import roc_auc_score
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict

from SGD import *
# Run on the GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): later cells call `.cuda()` directly, so the CPU fallback is not
# honoured everywhere.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [2]:
# MNIST train/test sets, converted to tensors only (no normalisation is applied).
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Single-batch loader over the whole training set (batch_size equals the 60000 requested here).
train_data = torch.utils.data.DataLoader(trainset, batch_size=60000, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# Test loader used by the evaluation step at the end of every training epoch.
test_data = torch.utils.data.DataLoader(testset, batch_size=5000, shuffle=False)

In [3]:
def sample_dirichlet_train_data(trainset, no_participants, alpha=0.9, force=False):
    """Split dataset indices across participants with per-class Dirichlet proportions.

    For every class a proportion vector is drawn from a symmetric Dirichlet
    distribution with concentration ``alpha`` (one entry per participant) and
    scaled by the class size to decide how many samples of that class each
    participant receives.

    The result is cached in ``./dirichlet_a_<alpha>_nusers_<n>.pkl`` and is
    regenerated when that file is missing or ``force`` is true.

    Args:
        trainset: Iterable of ``(sample, label)`` pairs. Labels must be the
            integers ``0 .. num_classes - 1`` because classes are visited with
            ``range(num_classes)``.
        no_participants: Number of participants to split the indices among.
        alpha: Dirichlet concentration parameter.
        force: Regenerate the split even if a cache file exists.

    Returns:
        A mapping ``participant index -> list of dataset indices``. Counts are
        rounded per participant, so a few samples of a class may stay
        unassigned.
    """
    cache_path = './dirichlet_a_%.1f_nusers_%d.pkl' % (alpha, no_participants)

    if os.path.exists(cache_path) and not force:
        # NOTE: unpickling is only acceptable because this cache is produced by
        # this very function; never point it at an untrusted file.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    print('generating participant indices for alpha %.1f'%alpha)
    np.random.seed(0)
    # Dedicated, seeded shuffler: the split is reproducible end to end without
    # disturbing the NumPy stream (Dirichlet draws) or the global `random` state.
    shuffler = random.Random(0)

    class_indices = {}
    for ind, (_, label) in enumerate(trainset):
        class_indices.setdefault(label, []).append(ind)

    per_participant_list = defaultdict(list)
    no_classes = len(class_indices)
    for n in range(no_classes):
        remaining = class_indices[n]
        shuffler.shuffle(remaining)
        sampled_counts = len(remaining) * np.random.dirichlet(
            np.array(no_participants * [alpha]))
        for user in range(no_participants):
            # Never hand out more samples than are still left in this class.
            no_imgs = min(len(remaining), int(round(sampled_counts[user])))
            per_participant_list[user].extend(remaining[:no_imgs])
            remaining = remaining[no_imgs:]

    with open(cache_path, 'wb') as f:
        pickle.dump(per_participant_list, f)

    return per_participant_list

In [26]:
def get_client_train_data(trainset, num_workers=100, bias=0.5):
    """Partition a 10-class dataset across workers with a label-dependent bias.

    Workers are split into 10 equally sized groups. A sample with label ``y``
    is sent to group ``y`` with probability ``bias`` and to each of the other
    nine groups with probability ``(1 - bias) / 9``; inside the chosen group
    the worker is picked uniformly at random. Randomness comes from the global
    NumPy generator.

    Args:
        trainset: Iterable of ``(tensor, int_label)`` pairs with labels in 0..9.
        num_workers: Total number of workers.
        bias: Probability that a sample lands in the group matching its label.

    Returns:
        ``(each_worker_data, each_worker_label)``: two lists of length
        ``num_workers``. Data entries are stacked tensors with the samples
        along dim 0 (a worker that received no sample keeps an empty list);
        label entries are 1-D long tensors.
    """
    bias_weight = bias
    other_group_size = (1 - bias_weight) / 9.
    worker_per_group = num_workers / 10

    samples_per_worker = [[] for _ in range(num_workers)]
    labels_per_worker = [[] for _ in range(num_workers)]

    for x, y in trainset:
        # Interval [lower_bound, upper_bound] of width `bias` maps to the
        # sample's own group; the remaining mass is shared by the other groups.
        lower_bound = y * (1 - bias_weight) / 9.
        upper_bound = lower_bound + bias_weight
        rd = np.random.random_sample()

        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y

        rd = np.random.random_sample()
        selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))

        # Collect in Python lists and stack once at the end; concatenating a
        # growing tensor per sample copies the whole buffer every time.
        samples_per_worker[selected_worker].append(x)
        labels_per_worker[selected_worker].append(y)

    each_worker_data = [torch.stack(samples) if samples else [] for samples in samples_per_worker]
    each_worker_label = [torch.Tensor(np.array(labels)).long() for labels in labels_per_worker]
    return each_worker_data, each_worker_label

In [5]:
def get_client_data_dirichlet(trainset, num_workers, alpha=1, force=False):
    """Materialise per-worker tensors for a Dirichlet split of ``trainset``.

    Args:
        trainset: Indexable dataset returning ``(tensor, label)`` pairs.
        num_workers: Number of workers to split the data among.
        alpha: Dirichlet concentration passed to ``sample_dirichlet_train_data``.
        force: Passed through; regenerates the cached index split when true.

    Returns:
        ``(each_worker_data, each_worker_label)``: lists of length
        ``num_workers`` holding, per worker, the stacked samples and a 1-D long
        tensor of labels.

    Note:
        A worker that was assigned no index makes ``torch.stack`` raise, as it
        receives an empty list.
    """
    per_participant_list = sample_dirichlet_train_data(trainset, num_workers, alpha=alpha, force=force)
    each_worker_data = [[] for _ in range(num_workers)]
    each_worker_label = [[] for _ in range(num_workers)]
    for worker_idx in range(len(per_participant_list)):
        samples, labels = [], []
        for idx in per_participant_list[worker_idx]:
            # Index the dataset once per sample: every access re-runs the
            # dataset's transform pipeline.
            sample, label = trainset[idx]
            samples.append(sample)
            labels.append(label)
        each_worker_data[worker_idx] = torch.stack(samples)
        each_worker_label[worker_idx] = torch.Tensor(labels).long()
    return each_worker_data, each_worker_label

In [6]:
# Build the per-worker datasets (`each_worker_data`, `each_worker_label`).
num_workers = 100
distribution='dirichlet'
if distribution=='bias':
    # Use the configured worker count instead of repeating the literal.
    each_worker_data, each_worker_label = get_client_train_data(trainset, num_workers=num_workers, bias=0.5)
elif distribution == 'dirichlet':
    alpha = .1
    # force=True regenerates the index split and overwrites the cached pickle.
    force = True
    each_worker_data, each_worker_label = get_client_data_dirichlet(trainset, num_workers, alpha=alpha, force=force)
else:
    # Fail loudly rather than leaving the per-worker globals undefined.
    raise ValueError('unknown distribution: %s' % distribution)

generating participant indices for alpha 0.1


In [7]:
# Experiment configuration.
# NOTE(review): several of these values (dataset, bias, gpu, seed, nbyz,
# byz_type, aggregation) are not read by any of the cells visible in this
# notebook; the training cells hard-code their own settings.
dataset = 'mnist'
bias = 0.1
# Rebound to a model instance (`cnn().to(device)`) by the training cells.
net = 'cnn'
batch_size = 32
# lr = 0.0002
# lr = 1e-3
# Overwritten with 0.15 by the training cells.
lr = 0.01
nworkers = 100
nepochs = 100
gpu = 3
seed = 41
nbyz = 84
byz_type = 'full_trim'
aggregation = 'trim'

In [8]:
def lbfgs(S_k_list, Y_k_list, v):
    """Approximate a Hessian-vector product with the compact L-BFGS form.

    With ``S`` / ``Y`` the matrices whose columns are the stored difference
    vectors, the function evaluates

        sigma * v - [sigma * S, Y] @ inv(M) @ [sigma * S^T v; Y^T v]

    where ``M = [[sigma * S^T S, L], [L^T, -D]]``, ``L`` is the strictly lower
    triangle of ``S^T Y``, ``D`` its diagonal, and ``sigma`` is computed from
    the most recent pair as ``(y . s) / (s . s)``.

    Args:
        S_k_list: List of 1-D tensors (same length ``d``), one per stored step.
        Y_k_list: List of 1-D tensors matching ``S_k_list``.
        v: 1-D tensor of length ``d`` to multiply with the approximation.

    Returns:
        NumPy array of shape ``(d, 1)``.
    """
    # Move every operand to host memory exactly once; all algebra is NumPy.
    S = torch.stack(S_k_list).T.cpu().numpy()            # (d, m)
    Y = torch.stack(Y_k_list).T.cpu().numpy()            # (d, m)
    v_col = v.unsqueeze(0).T.cpu().numpy()               # (d, 1)
    s_last = S_k_list[-1].unsqueeze(0).cpu().numpy()     # (1, d)
    y_last = Y_k_list[-1].unsqueeze(0).cpu().numpy()     # (1, d)

    S_k_time_Y_k = np.dot(S.T, Y)
    S_k_time_S_k = np.dot(S.T, S)
    R_k = np.triu(S_k_time_Y_k)
    L_k = S_k_time_Y_k - R_k            # strictly lower-triangular part
    sigma_k = np.dot(y_last, s_last.T) / np.dot(s_last, s_last.T)   # (1, 1)
    D_k_diag = np.diag(S_k_time_Y_k)

    upper_mat = np.concatenate((sigma_k * S_k_time_S_k, L_k), axis=1)
    lower_mat = np.concatenate((L_k.T, -np.diag(D_k_diag)), axis=1)
    mat = np.concatenate((upper_mat, lower_mat), axis=0)
    mat_inv = np.linalg.inv(mat)

    approx_prod = sigma_k * v_col       # (d, 1)
    p_mat = np.concatenate((np.dot(S.T, sigma_k * v_col), np.dot(Y.T, v_col)), axis=0)
    approx_prod -= np.dot(np.dot(np.concatenate((sigma_k * S, Y), axis=1), mat_inv), p_mat)

    return approx_prod

In [9]:
def full_trim(v, f):
    '''
    Full-knowledge Trim attack. w.l.o.g., we assume the first f worker devices are compromised.

    For every coordinate the crafted value is pushed against the sign of the
    coordinate-wise sum of all gradients: the per-coordinate minimum is taken
    when the sum is positive and the maximum when it is negative, then the
    value is halved or doubled so that it moves further in that direction.

    v: 2-D tensor of flattened gradients, one row per worker; modified in place
    f: the number of compromised worker devices (rows 0 .. f-1 are overwritten)
    returns: the same tensor v
    '''
    vi_shape = v[0].unsqueeze(0).T.shape    # column vector, one entry per coordinate
    v_tran = v.T

    maximum_dim = torch.max(v_tran, dim=1)[0].reshape(vi_shape)
    minimum_dim = torch.min(v_tran, dim=1)[0].reshape(vi_shape)
    direction = torch.sign(torch.sum(v_tran, dim=-1, keepdim=True))
    directed_dim = (direction > 0) * minimum_dim + (direction < 0) * maximum_dim

    # apply attack to compromised worker devices with randomness
    ##random_12 = 1. + np.random.uniform(size=vi_shape)
    ##random_12 = torch.Tensor(random_12).float().cuda()
    # The random factor is disabled: a fixed factor of 2 is used, so the crafted
    # vector is identical for every compromised worker and is computed once.
    random_12 = 2
    crafted = directed_dim * ((direction * directed_dim > 0) / random_12 + (direction * directed_dim < 0) * random_12)
    crafted = crafted.squeeze()

    # Honour `f` instead of a fixed count of 20 rows.
    for i in range(f):
        v[i] = crafted
    return v

In [10]:
def tr_mean(all_updates, n_attackers):
    """Coordinate-wise trimmed mean over dim 0 of ``all_updates``.

    Each coordinate is sorted across rows; the ``n_attackers`` smallest and
    ``n_attackers`` largest values are dropped and the rest averaged. With
    ``n_attackers`` equal to 0 this is the plain mean.
    """
    ordered, _ = torch.sort(all_updates, 0)
    if not n_attackers:
        return torch.mean(ordered, 0)
    kept = ordered[n_attackers:-n_attackers]
    return torch.mean(kept, 0)

In [20]:
def simple_mean(old_gradients, user_grads, b=0, hvp=None):
    """Aggregate client gradients by their mean and score each client.

    Args:
        old_gradients: 2-D tensor with the previous round's gradients, one row
            per client. Only read when ``hvp`` is given; never modified.
        user_grads: 2-D tensor with this round's gradients, one row per client.
        b: Unused; kept so existing call sites keep working.
        hvp: Optional 1-D NumPy array (Hessian-vector product estimate). When
            given, each client's gradient is predicted as
            ``old_gradients[i] + hvp``.

    Returns:
        ``(agg_grads, distance)``. ``distance`` is ``None`` without ``hvp``;
        otherwise a NumPy array with the per-client Euclidean distance between
        predicted and received gradient, normalised to sum to 1.
    """
    distance = None
    if hvp is not None:
        # Follow the gradients' own device/dtype instead of a module global.
        hvp_t = torch.from_numpy(hvp).to(device=old_gradients.device, dtype=old_gradients.dtype)
        pred_grad = old_gradients + hvp_t       # broadcast over clients
        distance = torch.norm(pred_grad - user_grads, dim=1).cpu().numpy()
        distance = distance / np.sum(distance)

    agg_grads = torch.mean(user_grads, dim=0)

    return agg_grads, distance

In [19]:
def median(old_gradients, user_grads, b=0, hvp=None):
    """Aggregate client gradients by coordinate-wise median and score each client.

    Args:
        old_gradients: 2-D tensor with the previous round's gradients, one row
            per client. Only read when ``hvp`` is given; never modified.
        user_grads: 2-D tensor with this round's gradients, one row per client.
        b: Unused; kept so existing call sites keep working.
        hvp: Optional 1-D NumPy array (Hessian-vector product estimate). When
            given, each client's gradient is predicted as
            ``old_gradients[i] + hvp``.

    Returns:
        ``(agg_grads, distance)``. ``distance`` is ``None`` without ``hvp``;
        otherwise a NumPy array with the per-client Euclidean distance between
        predicted and received gradient, normalised to sum to 1.
    """
    distance = None
    if hvp is not None:
        # Follow the gradients' own device/dtype instead of a module global.
        hvp_t = torch.from_numpy(hvp).to(device=old_gradients.device, dtype=old_gradients.dtype)
        pred_grad = old_gradients + hvp_t       # broadcast over clients
        distance = torch.norm(pred_grad - user_grads, dim=1).cpu().numpy()
        distance = distance / np.sum(distance)

    agg_grads = torch.median(user_grads, 0)[0]

    return agg_grads, distance

In [21]:
def trimmed_mean(old_gradients, user_grads, b=0, hvp=None):
    """Aggregate client gradients by trimmed mean and score each client.

    Args:
        old_gradients: 2-D tensor with the previous round's gradients, one row
            per client. Only read when ``hvp`` is given; never modified.
        user_grads: 2-D tensor with this round's gradients, one row per client.
        b: Unused; kept so existing call sites keep working.
        hvp: Optional 1-D NumPy array (Hessian-vector product estimate). When
            given, each client's gradient is predicted as
            ``old_gradients[i] + hvp``.

    Returns:
        ``(agg_grads, distance)``. ``distance`` is ``None`` without ``hvp``;
        otherwise a NumPy array with the per-client Euclidean distance between
        predicted and received gradient, normalised to sum to 1.
    """
    distance = None
    if hvp is not None:
        # Follow the gradients' own device/dtype instead of a module global.
        hvp_t = torch.from_numpy(hvp).to(device=old_gradients.device, dtype=old_gradients.dtype)
        pred_grad = old_gradients + hvp_t       # broadcast over clients
        distance = torch.norm(pred_grad - user_grads, dim=1).cpu().numpy()
        distance = distance / np.sum(distance)

    # NOTE(review): the trim count is fixed at 20 values per side and does not
    # follow `b`; it is left as is because `b` defaults to 0.
    agg_grads = tr_mean(user_grads, 20)

    return agg_grads, distance

In [15]:
def detection(score, nobyz):
    """Cluster client scores into two groups and print detection metrics.

    The scores are split with 2-means; the cluster with the larger mean score
    is treated as malicious (label 0). Ground truth assumes the first
    ``nobyz`` clients are the malicious ones.

    Args:
        score: 1-D NumPy array with one suspicion score per client.
        nobyz: Number of malicious clients (occupying indices ``0 .. nobyz-1``).

    Prints accuracy, recall, false-positive rate, false-negative rate and the
    silhouette score of the clustering; returns nothing.
    """
    # Derive the client count from the input rather than assuming 100.
    n_clients = len(score)

    estimator = KMeans(n_clusters=2)
    estimator.fit(score.reshape(-1, 1))
    label_pred = estimator.labels_
    if np.mean(score[label_pred==0])<np.mean(score[label_pred==1]):
        #0 is the label of malicious clients
        label_pred = 1 - label_pred
    real_label=np.ones(n_clients)
    real_label[:nobyz]=0
    acc=len(label_pred[label_pred==real_label])/n_clients
    recall=1-np.sum(label_pred[:nobyz])/nobyz
    fpr=1-np.sum(label_pred[nobyz:])/(n_clients-nobyz)
    fnr=np.sum(label_pred[:nobyz])/nobyz
    print("acc %0.4f; recall %0.4f; fpr %0.4f; fnr %0.4f;" % (acc, recall, fpr, fnr))
    print(silhouette_score(score.reshape(-1, 1), label_pred))

def _within_cluster_dispersion(values, k):
    """Fit k-means on the 1-D ``values`` and return the within-cluster sum of squares."""
    estimator = KMeans(n_clusters=k)
    estimator.fit(values.reshape(-1, 1))
    assigned_centers = estimator.cluster_centers_[estimator.labels_, 0]
    return np.sum(np.square(values - assigned_centers))


def detection1(score, nobyz):
    """Decide with the gap statistic whether the scores form more than one cluster.

    Scores are min-max normalised to [0, 1]. For k = 1..7 the k-means
    dispersion of the scores is compared with that of ``nrefs`` uniform
    reference samples. The selected k is the smallest one with
    ``gap(k) - gap(k+1) + s(k+1) >= 0``.

    Args:
        score: 1-D NumPy array with one suspicion score per client.
        nobyz: Unused; kept so existing call sites keep working.

    Returns:
        0 when a single cluster is selected (no attack), 1 otherwise. Prints
        ``'Attack Detected!'`` in the latter case.
    """
    nrefs = 10
    ks = range(1, 8)
    gaps = np.zeros(len(ks))
    gapDiff = np.zeros(len(ks) - 1)
    sdk = np.zeros(len(ks))

    lowest = np.min(score)
    highest = np.max(score)
    if highest == lowest:
        # All scores identical: nothing to separate, and normalising would
        # divide by zero.
        return 0
    score = (score - lowest)/(highest - lowest)

    for i, k in enumerate(ks):
        Wk = _within_cluster_dispersion(score, k)
        WkRef = np.zeros(nrefs)
        for j in range(nrefs):
            rand = np.random.uniform(0, 1, len(score))
            WkRef[j] = _within_cluster_dispersion(rand, k)
        gaps[i] = np.log(np.mean(WkRef)) - np.log(Wk)
        sdk[i] = np.sqrt((1.0 + nrefs) / nrefs) * np.std(np.log(WkRef))

        if i > 0:
            gapDiff[i - 1] = gaps[i - 1] - gaps[i] + sdk[i]
    #print(gapDiff)

    # If no k in the tested range satisfies the criterion, the best k lies
    # beyond the range, i.e. more than one cluster: treat that as an attack
    # instead of leaving `select_k` unbound.
    select_k = len(ks)
    for i in range(len(gapDiff)):
        if gapDiff[i] >= 0:
            select_k = i+1
            break
    if select_k == 1:
        # print('No attack detected!')
        return 0
    else:
        print('Attack Detected!')
        return 1

In [14]:
# Initialise notebook-level state from the configuration cell.
num_workers = nworkers
lr = lr  # no-op self-assignment
epochs = nepochs
grad_list = []
# Previous round's per-client gradients; the training cells rebind it every epoch.
old_grad_list = []
# Windows of weight / aggregated-gradient differences fed to lbfgs().
weight_record = []
grad_record = []
train_acc_list = []
distance1 = []
distance2 = []
auc_list = []

# Baseline FLDetector with various data distributions

In [28]:
# Experiment setup: build the per-worker data split and reset detector state.
num_workers = 100
distribution='bias'
# Bias for the 'bias' split, or Dirichlet alpha for the 'dirichlet' split.
# NOTE(review): the name `param` is reused as a loop variable by the training
# loop that follows, so this value is overwritten once training starts.
param = 0.5
if distribution=='bias':
    each_worker_data, each_worker_label = get_client_train_data(trainset, num_workers=num_workers, bias=param)
elif distribution == 'dirichlet':
    force = True
    each_worker_data, each_worker_label = get_client_data_dirichlet(trainset, num_workers, alpha=param, force=force)

torch.cuda.empty_cache()
weight_record = []
grad_record = []
test_grads = []
# Row 0 is an all-zero placeholder; one row of per-client scores is appended per scored epoch.
malicious_scores = np.zeros((1, num_workers))
attack_type='LIE'
criterion = nn.CrossEntropyLoss()
# `cnn` is not defined in this notebook; presumably it comes from one of the star imports.
net = cnn().to(device)
lr = 0.15
# Epoch index that separates the full_trim phase (before) from the adaptive attack phase (after).
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)

# Federated training loop with FLDetector scoring (trimmed-mean aggregation).
for e in range(100):
    # A fresh optimizer per epoch carries the decayed learning rate lr * 0.96**e.
    cnn_optimizer = SGD(net.parameters(), lr = lr*(0.96**e))
    user_grads = []

    # Every worker computes one full-batch gradient on a copy of the global model.
    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        # NOTE(review): `.cuda()` is hard-coded here although `device` may be the CPU.
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, each_worker_label[i].cuda())
        loss.backward(retain_graph = True)
        # Flatten all parameter gradients into one vector for this worker.
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        # Stack the per-worker vectors into a (num_workers, n_params) tensor.
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_
    # Flatten the current global weights the same way.
    tmp = []
    for param in net.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp    

    # After the warm-up, estimate the Hessian-vector product for this round's weight change.
    if (e > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    # Attack schedule. Note that neither branch runs when e == start_detection_epoch,
    # so that single epoch is unattacked.
    if(e < start_detection_epoch):
        user_grads = full_trim(user_grads, 20)
    elif e > start_detection_epoch:
        if attack_type == 'LIE':
            # Offset built from the mean and std of the first 20 workers' gradients;
            # z grows by 10% per epoch.
            z = 0.3*(1.1**(e-20))
            noise_avg = torch.mean(user_grads[:20], dim=0)
            noise_std = torch.std(user_grads[:20], dim=0)
            noise = noise_avg + z * noise_std
        elif attack_type == 'NDSS21':
            # NOTE(review): `noise` is never assigned on this branch, so the loop
            # below would fail or reuse a stale value.
            pass
        else:
            noise = torch.zeros(hvp.shape).to(device)
        # The first 20 workers submit: previous gradient + predicted change + noise.
        for m in range(20):
            user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise
            
#     agg_grads, distance = simple_mean(old_grad_list, user_grads, 0, hvp)
    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, 20, hvp)
    
    # `distance` is None until hvp is available, so rows are only appended after the warm-up.
    if distance is not None and e > (start_detection_epoch - window_size):
        # print('==>', e, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    # 11 rows = the zero placeholder row plus 10 scored epochs.
    if malicious_scores.shape[0] >= 11:
        # print('performing detection at epoch %d' % e)
        # Scores are summed over the last 10 rows before clustering.
        if detection1(np.sum(malicious_scores[-10:], axis=0), 20):
            print('Stop at iteration:', e)
            detection(np.sum(malicious_scores[-10:], axis=0), 20)
            break

    # Record weight / aggregated-gradient differences for lbfgs(); skipped at
    # e == 0 (with the settings above), where last_weight / last_grad do not exist yet.
    if e > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)

    # Keep at most the 10 most recent difference pairs.
    if (len(weight_record) > 10):
        del weight_record[0]
        del grad_record[0]

    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
  
    # Only removes the name; the tensor stays alive through old_grad_list.
    del user_grads
    
    start_idx=0

    cnn_optimizer.zero_grad()

    model_grads=[]

    # Slice the flat aggregated gradient back into per-parameter shapes.
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    # `SGD` is the project-local optimizer (star import); its step() is handed the
    # gradient list directly — its exact contract is not visible here.
    cnn_optimizer.step(model_grads)
    # Evaluate the global model on the test set and print "<epoch> <accuracy>".
    total, correct = 0,0
    with torch.no_grad():
        for i, data in enumerate(test_data):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(e,correct/total)

0 0.056
1 0.0862
2 0.1001
3 0.101
4 0.101
5 0.101
6 0.101
7 0.101
8 0.101
9 0.0937
10 0.0696
11 0.098
12 0.0784
13 0.099
14 0.0958
15 0.0892
16 0.1009
17 0.1009
18 0.1009
19 0.098
Attack Detected!
Stop at iteration: 20
acc 0.9000; recall 1.0000; fpr 0.1250; fnr 0.0000;
0.864210377672373


# Adaptive attack against FLDetector based on simple noise addition

## Method: noise is added to the malicious updates in the style of the *little-is-enough* (LIE) attack

## Observation: FLDetector does not detect the attack, and the model fails to train (test accuracy stays low).

In [26]:
# Reset detector state and model for the adaptive-attack run with mean aggregation.
# The per-worker data (each_worker_data / each_worker_label) is whatever an
# earlier cell left in the notebook state.
torch.cuda.empty_cache()
weight_record = []
grad_record = []
test_grads = []
# Row 0 is an all-zero placeholder; one row of per-client scores is appended per scored epoch.
malicious_scores = np.zeros((1, num_workers))
attack_type = 'LIE'

criterion = nn.CrossEntropyLoss()
# `cnn` is not defined in this notebook; presumably it comes from one of the star imports.
net = cnn().to(device)
lr = 0.15
# Epoch index that separates the full_trim phase (before) from the adaptive attack phase (after).
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)

# Federated training loop with FLDetector scoring (plain-mean aggregation).
for e in range(100):
    # A fresh optimizer per epoch carries the decayed learning rate lr * 0.96**e.
    cnn_optimizer = SGD(net.parameters(), lr = lr*(0.96**e))
    user_grads = []

    # Every worker computes one full-batch gradient on a copy of the global model.
    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        # NOTE(review): `.cuda()` is hard-coded here although `device` may be the CPU.
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, each_worker_label[i].cuda())
        loss.backward(retain_graph = True)
        # Flatten all parameter gradients into one vector for this worker.
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        # Stack the per-worker vectors into a (num_workers, n_params) tensor.
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_
    # Flatten the current global weights the same way.
    tmp = []
    for param in net.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp    

    # After the warm-up, estimate the Hessian-vector product for this round's weight change.
    if (e > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    # Attack schedule. Note that neither branch runs when e == start_detection_epoch,
    # so that single epoch is unattacked.
    if(e < start_detection_epoch):
        user_grads = full_trim(user_grads, 20)
    elif e > start_detection_epoch:
        if attack_type == 'LIE':
            print("LIE")
            # Offset built from the mean and std of the first 20 workers' gradients;
            # z grows by 10% per epoch.
            z = 0.3*(1.1**(e-20))
            noise_avg = torch.mean(user_grads[:20], dim=0)
            noise_std = torch.std(user_grads[:20], dim=0)
            noise = noise_avg + z * noise_std
        elif attack_type == 'NDSS21':
            # NOTE(review): `noise` is never assigned on this branch, so the loop
            # below would fail or reuse a stale value.
            pass
        else:
            noise = torch.zeros(hvp.shape).to(device)
        # The first 20 workers submit: previous gradient + predicted change + noise.
        for m in range(20):
            user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise
#             user_grads[m] = torch.from_numpy(hvp).to(device) + noise

    agg_grads, distance = simple_mean(old_grad_list, user_grads, 20, hvp)
#     agg_grads, distance = trimmed_mean(old_grad_list, user_grads, 20, hvp)
    
    # `distance` is None until hvp is available, so rows are only appended after the warm-up.
    if distance is not None and e > (start_detection_epoch - window_size):
        print('==>', e, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    # 11 rows = the zero placeholder row plus 10 scored epochs.
    if malicious_scores.shape[0] >= 11:
        print('performing detection at epoch %d' % e)
        # Scores are summed over the last 10 rows before clustering.
        if detection1(np.sum(malicious_scores[-10:], axis=0), 20):
            print('Stop at iteration:', e)
            detection(np.sum(malicious_scores[-10:], axis=0), 20)
            break

#     agg_grads = tr_mean(user_grads, 20)
#     agg_grads=torch.median(user_grads,dim=0)[0]
#     agg_grads=torch.mean(user_grads,dim=0)

    # Record weight / aggregated-gradient differences for lbfgs(); skipped at
    # e == 0 (with the settings above), where last_weight / last_grad do not exist yet.
    if e > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    # Keep at most the 10 most recent difference pairs.
    if (len(weight_record) > 10):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
  
    # Only removes the name; the tensor stays alive through old_grad_list.
    del user_grads
    
    start_idx=0

    cnn_optimizer.zero_grad()

    model_grads=[]

    # Slice the flat aggregated gradient back into per-parameter shapes.
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    # `SGD` is the project-local optimizer (star import); its step() is handed the
    # gradient list directly — its exact contract is not visible here.
    cnn_optimizer.step(model_grads)
    # Evaluate the global model on the test set and print "<epoch> <accuracy>".
    total, correct = 0,0
    with torch.no_grad():
        for i, data in enumerate(test_data):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(e,correct/total)

0 0.0147
1 0.0807
2 0.1035
3 0.1294
4 0.1663
5 0.1848
6 0.1858
7 0.163
8 0.1281
9 0.2149
10 0.1629
LIE
<class 'torch.Tensor'>
Detection AUC: 0.7250; Detection AUC: 0.0000
==> 11 (1, 100)
11 0.098
LIE
<class 'torch.Tensor'>
Detection AUC: 1.0000; Detection AUC: 0.9625
==> 12 (2, 100)
12 0.1689
LIE
<class 'torch.Tensor'>
Detection AUC: 1.0000; Detection AUC: 0.0000
==> 13 (3, 100)
13 0.098
LIE
<class 'torch.Tensor'>
Detection AUC: 1.0000; Detection AUC: 0.0000
==> 14 (4, 100)
14 0.1062
LIE
<class 'torch.Tensor'>
Detection AUC: 0.4500; Detection AUC: 0.0000
==> 15 (5, 100)
15 0.135
LIE
<class 'torch.Tensor'>
Detection AUC: 1.0000; Detection AUC: 0.0000
==> 16 (6, 100)
16 0.098
LIE
<class 'torch.Tensor'>
Detection AUC: 0.9875; Detection AUC: 0.0000
==> 17 (7, 100)
17 0.098
LIE
<class 'torch.Tensor'>
Detection AUC: 0.6875; Detection AUC: 0.5875
==> 18 (8, 100)
18 0.098
LIE
<class 'torch.Tensor'>
Detection AUC: 1.0000; Detection AUC: 1.0000
==> 19 (9, 100)
19 0.098
LIE
<class 'torch.Tensor'>

In [40]:
# Reset detector state and model for the adaptive-attack run with trimmed-mean aggregation.
# The per-worker data (each_worker_data / each_worker_label) is whatever an
# earlier cell left in the notebook state.
torch.cuda.empty_cache()
weight_record = []
grad_record = []
test_grads = []
# Row 0 is an all-zero placeholder; one row of per-client scores is appended per scored epoch.
malicious_scores = np.zeros((1, num_workers))
attack_type = 'LIE'

criterion = nn.CrossEntropyLoss()
# `cnn` is not defined in this notebook; presumably it comes from one of the star imports.
net = cnn().to(device)
lr = 0.15
# Epoch index that separates the full_trim phase (before) from the adaptive attack phase (after).
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)

# Federated training loop with FLDetector scoring (trimmed-mean aggregation,
# slower-growing LIE noise).
for e in range(100):
    # A fresh optimizer per epoch carries the decayed learning rate lr * 0.96**e.
    cnn_optimizer = SGD(net.parameters(), lr = lr*(0.96**e))
    user_grads = []

    # Every worker computes one full-batch gradient on a copy of the global model.
    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        # NOTE(review): `.cuda()` is hard-coded here although `device` may be the CPU.
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, each_worker_label[i].cuda())
        loss.backward(retain_graph = True)
        # Flatten all parameter gradients into one vector for this worker.
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        # Stack the per-worker vectors into a (num_workers, n_params) tensor.
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_
    # Flatten the current global weights the same way.
    tmp = []
    for param in net.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp    

    # After the warm-up, estimate the Hessian-vector product for this round's weight change.
    if (e > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    # Attack schedule. Note that neither branch runs when e == start_detection_epoch,
    # so that single epoch is unattacked.
    if(e < start_detection_epoch):
        user_grads = full_trim(user_grads, 20)
    elif e > start_detection_epoch:
        if attack_type == 'LIE':
            print("LIE")
            # Offset built from the mean and std of the first 20 workers' gradients;
            # z grows by 3% per epoch.
            z = 0.3*(1.03**(e))
            noise_avg = torch.mean(user_grads[:20], dim=0)
            noise_std = torch.std(user_grads[:20], dim=0)
            noise = noise_avg + z * noise_std
        elif attack_type == 'NDSS21':
            # NOTE(review): `noise` is never assigned on this branch, so the loop
            # below would fail or reuse a stale value.
            pass
        else:
            noise = torch.zeros(hvp.shape).to(device)
        # The first 20 workers submit: previous gradient + predicted change + noise.
        for m in range(20):
            user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise
#             user_grads[m] = torch.from_numpy(hvp).to(device) + noise

#     agg_grads, distance = simple_mean(old_grad_list, user_grads, 20, hvp)
    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, 20, hvp)
    
    # `distance` is None until hvp is available, so rows are only appended after the warm-up.
    if distance is not None and e > (start_detection_epoch - window_size):
        print('==>', e, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    # 11 rows = the zero placeholder row plus 10 scored epochs.
    if malicious_scores.shape[0] >= 11:
        print('performing detection at epoch %d' % e)
        # Scores are summed over the last `window_size` rows before clustering.
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), 20):
            print('Stop at iteration:', e)
            detection(np.sum(malicious_scores[-window_size:], axis=0), 20)
            break

#     agg_grads = tr_mean(user_grads, 20)
#     agg_grads=torch.median(user_grads,dim=0)[0]
#     agg_grads=torch.mean(user_grads,dim=0)

    # Record weight / aggregated-gradient differences for lbfgs(); skipped at
    # e == 0 (with the settings above), where last_weight / last_grad do not exist yet.
    if e > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    # Keep at most the 10 most recent difference pairs.
    if (len(weight_record) > 10):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
  
    # Only removes the name; the tensor stays alive through old_grad_list.
    del user_grads
    
    start_idx=0

    cnn_optimizer.zero_grad()

    model_grads=[]

    # Slice the flat aggregated gradient back into per-parameter shapes.
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    # `SGD` is the project-local optimizer (star import); its step() is handed the
    # gradient list directly — its exact contract is not visible here.
    cnn_optimizer.step(model_grads)
    # Evaluate the global model on the test set and print "<epoch> <accuracy>".
    total, correct = 0,0
    with torch.no_grad():
        for i, data in enumerate(test_data):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(e,correct/total)

0 0.1078
1 0.1097
2 0.1114
3 0.1129
4 0.1142
5 0.1121
6 0.1067
7 0.1047
8 0.3145
9 0.1795
10 0.1304
LIE
Detection AUC: 0.0000; Detection AUC: 0.6500
==> 11 (1, 100)
11 0.101
LIE
Detection AUC: 0.0000; Detection AUC: 0.0000
==> 12 (2, 100)
12 0.1032
LIE
Detection AUC: 0.0000; Detection AUC: 0.0000
==> 13 (3, 100)
13 0.1365
LIE
Detection AUC: 0.0000; Detection AUC: 0.0000
==> 14 (4, 100)
14 0.2116
LIE
Detection AUC: 0.0000; Detection AUC: 1.0000
==> 15 (5, 100)
15 0.2325
LIE
Detection AUC: 0.0000; Detection AUC: 0.9750
==> 16 (6, 100)
16 0.2072
LIE
Detection AUC: 0.0000; Detection AUC: 0.9750
==> 17 (7, 100)
17 0.1739
LIE
Detection AUC: 0.0000; Detection AUC: 0.7000
==> 18 (8, 100)
18 0.1208
LIE
Detection AUC: 0.0000; Detection AUC: 0.9000
==> 19 (9, 100)
19 0.1597
LIE
Detection AUC: 0.0000; Detection AUC: 1.0000
==> 20 (10, 100)
performing detection at epoch 20
No attack detected!
20 0.2201
LIE
Detection AUC: 0.0000; Detection AUC: 1.0000
==> 21 (11, 100)
performing detection at epoch 2