In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

  from IPython.core.display import display, HTML


In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
import copy
import time
from functools import reduce
from torchsummary import summary

import os
import sys
import pickle
sys.path.insert(0,'./utils/')
from logger import *
from eval import *
from misc import *

from sklearn.metrics import roc_auc_score
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict

from SGD import *
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [3]:
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [4]:
def sample_dirichlet_train_data(trainset, no_participants, alpha=0.9, force=False):
        """
            Input: Number of participants and alpha (param for distribution)
            Output: A list of indices denoting data in CIFAR training set.
            Requires: cifar_classes, a preprocessed class-indice dictionary.
            Sample Method: take a uniformly sampled 10-dimension vector as parameters for
            dirichlet distribution to sample number of images in each class.
        """
        if not os.path.exists('./dirichlet_a_%.1f_nusers_%d.pkl'%(alpha, no_participants)) or force:
            print('generating participant indices for alpha %.1f'%alpha)
            np.random.seed(0)
            cifar_classes = {}
            for ind, x in enumerate(trainset):
                _, label = x
                if label in cifar_classes:
                    cifar_classes[label].append(ind)
                else:
                    cifar_classes[label] = [ind]

            per_participant_list = defaultdict(list)
            no_classes = len(cifar_classes.keys())
            for n in range(no_classes):
                random.shuffle(cifar_classes[n])
                sampled_probabilities = len(cifar_classes[n]) * np.random.dirichlet(
                    np.array(no_participants * [alpha]))
                for user in range(no_participants):
                    no_imgs = int(round(sampled_probabilities[user]))
                    sampled_list = cifar_classes[n][:min(len(cifar_classes[n]), no_imgs)]
                    per_participant_list[user].extend(sampled_list)
                    cifar_classes[n] = cifar_classes[n][min(len(cifar_classes[n]), no_imgs):]
            with open('./dirichlet_a_%.1f_nusers_%d.pkl'%(alpha, no_participants), 'wb') as f:
                pickle.dump(per_participant_list, f)
        else:
            per_participant_list = pickle.load(open('./dirichlet_a_%.1f_nusers_%d.pkl'%(alpha, no_participants), 'rb'))
            
        return per_participant_list

In [5]:
def get_client_train_data(trainset, num_workers=100, bias=0.5):
    bias_weight = bias
    other_group_size = (1 - bias_weight) / 9.
    worker_per_group = num_workers / 10

    each_worker_data = [[] for _ in range(num_workers)]
    each_worker_label = [[] for _ in range(num_workers)]
    
    for i, (x, y) in enumerate(trainset):
        # assign a data point to a group
        upper_bound = (y) * (1 - bias_weight) / 9. + bias_weight
        lower_bound = (y) * (1 - bias_weight) / 9.
        rd = np.random.random_sample()

        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y

        rd = np.random.random_sample()
        selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))
        
        if not len(each_worker_data[selected_worker]):
            each_worker_data[selected_worker] = x[None, :]
        else:
            each_worker_data[selected_worker]= torch.concat((each_worker_data[selected_worker], x[None, :]))
        
        each_worker_label[selected_worker].append(y)
    
    each_worker_tr_data = [[] for _ in range(num_workers)]
    each_worker_tr_label = [[] for _ in range(num_workers)]
    each_worker_te_data = [[] for _ in range(num_workers)]
    each_worker_te_label = [[] for _ in range(num_workers)]
    
    np.random.seed(0)
    for i in range(num_workers):
        w_len = len(each_worker_data[i])
        len_tr = int(6 * w_len / 7)
        len_te = w_len - len_tr
        tr_idx = np.random.choice(w_len, len_tr, replace=False)
        te_idx = np.delete(np.arange(w_len), tr_idx)
        each_worker_tr_data[i] = each_worker_data[i][tr_idx]
        each_worker_tr_label[i] = torch.Tensor(each_worker_label[i])[tr_idx]
        
        each_worker_te_data[i] = each_worker_data[i][te_idx]
        each_worker_te_label[i] = torch.Tensor(each_worker_label[i])[te_idx]
        
    global_test_data = torch.concat(each_worker_te_data)
    global_test_label = torch.concat(each_worker_te_label)
    del each_worker_data, each_worker_label
    return each_worker_tr_data, each_worker_tr_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label

In [6]:
def get_client_data_dirichlet(trainset, num_workers, alpha=1, force=False):
    per_participant_list = sample_dirichlet_train_data(trainset, num_workers, alpha=alpha, force=force)
    
    each_worker_data = [[] for _ in range(num_workers)]
    each_worker_label = [[] for _ in range(num_workers)]
    
    each_worker_te_data = [[] for _ in range(num_workers)]
    each_worker_te_label = [[] for _ in range(num_workers)]
    
    np.random.seed(0)
    
    for worker_idx in range(len(per_participant_list)):
        w_indices = np.array(per_participant_list[worker_idx])
        w_len = len(w_indices)
        len_tr = int(6*w_len/7)
        len_te = w_len - len_tr
        tr_idx = np.random.choice(w_len, len_tr, replace=False)
        te_idx = np.delete(np.arange(w_len), tr_idx)
        
        for idx in tr_idx:
            each_worker_data[worker_idx].append(trainset[idx][0])
            each_worker_label[worker_idx].append(trainset[idx][1])
        each_worker_data[worker_idx] = torch.stack(each_worker_data[worker_idx])
        each_worker_label[worker_idx] = torch.Tensor(each_worker_label[worker_idx]).long()
        
        for idx in te_idx:
            each_worker_te_data[worker_idx].append(trainset[idx][0])
            each_worker_te_label[worker_idx].append(trainset[idx][1])
        each_worker_te_data[worker_idx] = torch.stack(each_worker_te_data[worker_idx])
        each_worker_te_label[worker_idx] = torch.Tensor(each_worker_te_label[worker_idx]).long()
    
    global_test_data = torch.concat(each_worker_te_data)
    global_test_label = torch.concat(each_worker_te_label)
    
    return each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label


In [7]:
dataset = 'mnist'
bias = 0.1
net = 'cnn'
batch_size = 32
# lr = 0.0002
# lr = 1e-3
lr = 0.01
nworkers = 100
nepochs = 100
gpu = 3
seed = 41
nbyz = 20
byz_type = 'full_trim'
aggregation = 'trim'

In [8]:
def lbfgs(S_k_list, Y_k_list, v):
    curr_S_k = torch.stack(S_k_list).T
    curr_Y_k = torch.stack(Y_k_list).T
    S_k_time_Y_k = np.dot(curr_S_k.T.cpu().numpy(), curr_Y_k.cpu().numpy())
    S_k_time_S_k = np.dot(curr_S_k.T.cpu().numpy(), curr_S_k.cpu().numpy())
    R_k = np.triu(S_k_time_Y_k)
    L_k = S_k_time_Y_k - R_k
    sigma_k = np.dot(Y_k_list[-1].unsqueeze(0).cpu().numpy(), S_k_list[-1].unsqueeze(0).T.cpu().numpy()) / (np.dot(S_k_list[-1].unsqueeze(0).cpu().numpy(), S_k_list[-1].unsqueeze(0).T.cpu().numpy()))
    D_k_diag = np.diag(S_k_time_Y_k)
    upper_mat = np.concatenate((sigma_k * S_k_time_S_k, L_k), axis=1)
    lower_mat = np.concatenate((L_k.T, -np.diag(D_k_diag)), axis=1)
    mat = np.concatenate((upper_mat, lower_mat), axis=0)
    mat_inv = np.linalg.inv(mat)

    approx_prod = sigma_k * v.cpu().numpy()
    approx_prod = approx_prod.T
    p_mat = np.concatenate((np.dot(curr_S_k.T.cpu().numpy(), sigma_k * v.unsqueeze(0).T.cpu().numpy()), np.dot(curr_Y_k.T.cpu().numpy(), v.unsqueeze(0).T.cpu().numpy())), axis=0)
    approx_prod -= np.dot(np.dot(np.concatenate((sigma_k * curr_S_k.cpu().numpy(), curr_Y_k.cpu().numpy()), axis=1), mat_inv), p_mat)

    return approx_prod

In [22]:
def full_trim(v, f):
    '''
    Full-knowledge Trim attack. w.l.o.g., we assume the first f worker devices are compromised.
    v: the list of squeezed gradients
    f: the number of compromised worker devices
    '''
    vi_shape = v[0].unsqueeze(0).T.shape
    v_tran = v.T
    
    maximum_dim = torch.max(v_tran, dim=1)
    maximum_dim = maximum_dim[0].reshape(vi_shape)
    minimum_dim = torch.min(v_tran, dim=1)
    minimum_dim = minimum_dim[0].reshape(vi_shape)
    direction = torch.sign(torch.sum(v_tran, dim=-1, keepdims=True))
    directed_dim = (direction > 0) * minimum_dim + (direction < 0) * maximum_dim

    for i in range(f):
        # apply attack to compromised worker devices with randomness
        ##random_12 = 1. + np.random.uniform(size=vi_shape)
        ##random_12 = torch.Tensor(random_12).float().cuda()
        random_12 = 2
        tmp = directed_dim * ((direction * directed_dim > 0) / random_12 + (direction * directed_dim < 0) * random_12)
        tmp = tmp.squeeze()
        v[i] = tmp
    return v

In [10]:
def tr_mean(all_updates, n_attackers):
    sorted_updates = torch.sort(all_updates, 0)[0]
    out = torch.mean(sorted_updates[n_attackers:-n_attackers], 0) if n_attackers else torch.mean(sorted_updates,0)
    return out

In [11]:
def simple_mean(old_gradients, user_grads, b=0, hvp=None):
    if hvp is not None:
        hvp = torch.from_numpy(hvp).to(device)
        pred_grad = copy.deepcopy(old_gradients)
        distance = []
        for i in range(len(old_gradients)):
            pred_grad[i] += hvp
        
        pred = np.zeros(len(old_gradients))
        pred[:b] = 1
        distance = torch.norm(pred_grad - user_grads, dim = 1).cpu().numpy()
        auc2 = roc_auc_score(pred, distance)
        auc1 = 0
        print("Detection AUC: %0.4f; Detection AUC: %0.4f" % (auc1, auc2))
        
        distance = distance / np.sum(distance)
    else:
        distance = None
    
    agg_grads = torch.mean(user_grads, dim=0)
    
    return agg_grads, distance

In [12]:
def trimmed_mean(old_gradients, user_grads, b=0, hvp=None):
    if hvp is not None:
        hvp = torch.from_numpy(hvp).to(device)
        pred_grad = copy.deepcopy(old_gradients)
        distance = []
        for i in range(len(old_gradients)):
            pred_grad[i] += hvp
        
        pred = np.zeros(len(old_gradients))
        pred[:b] = 1
        auc1 = 0
        distance = torch.norm(pred_grad - user_grads, dim = 1).cpu().numpy()
        auc2 = roc_auc_score(pred, distance)
        print("Detection AUC: %0.4f; Detection AUC: %0.4f" % (auc1, auc2))        
        distance = distance / np.sum(distance)
    else:
        distance = None
    
    agg_grads = tr_mean(user_grads, b)
    
    return agg_grads, distance

In [15]:
def detection(score, nobyz):
    estimator = KMeans(n_clusters=2)
    estimator.fit(score.reshape(-1, 1))
    label_pred = estimator.labels_
    if np.mean(score[label_pred==0])<np.mean(score[label_pred==1]):
        #0 is the label of malicious clients
        label_pred = 1 - label_pred
    real_label=np.ones(len(score))
    real_label[:nobyz]=0
    acc=len(label_pred[label_pred==real_label])/100
    recall=1-np.sum(label_pred[:nobyz])/nobyz
    fpr=1-np.sum(label_pred[nobyz:])/(len(score)-nobyz)
    fnr=np.sum(label_pred[:nobyz])/nobyz
    print("acc %0.4f; recall %0.4f; fpr %0.4f; fnr %0.4f;" % (acc, recall, fpr, fnr))
    print(silhouette_score(score.reshape(-1, 1), label_pred))

def detection1(score, nobyz):
    nrefs = 10
    ks = range(1, 8)
    gaps = np.zeros(len(ks))
    gapDiff = np.zeros(len(ks) - 1)
    sdk = np.zeros(len(ks))
    min = np.min(score)
    max = np.max(score)
    score = (score - min)/(max-min)
    for i, k in enumerate(ks):
        estimator = KMeans(n_clusters=k)
        estimator.fit(score.reshape(-1, 1))
        label_pred = estimator.labels_
        center = estimator.cluster_centers_
        Wk = np.sum([np.square(score[m]-center[label_pred[m]]) for m in range(len(score))])
        WkRef = np.zeros(nrefs)
        for j in range(nrefs):
            rand = np.random.uniform(0, 1, len(score))
            estimator = KMeans(n_clusters=k)
            estimator.fit(rand.reshape(-1, 1))
            label_pred = estimator.labels_
            center = estimator.cluster_centers_
            WkRef[j] = np.sum([np.square(rand[m]-center[label_pred[m]]) for m in range(len(rand))])
        gaps[i] = np.log(np.mean(WkRef)) - np.log(Wk)
        sdk[i] = np.sqrt((1.0 + nrefs) / nrefs) * np.std(np.log(WkRef))

        if i > 0:
            gapDiff[i - 1] = gaps[i - 1] - gaps[i] + sdk[i]
    print(gapDiff)
    for i in range(len(gapDiff)):
        if gapDiff[i] >= 0:
            select_k = i+1
            break
    if select_k == 1:
        print('No attack detected!')
        return 0
    else:
        print('Attack Detected!')
        return 1

In [16]:
num_workers = nworkers
lr = lr
epochs = nepochs
grad_list = []
old_grad_list = []
weight_record = []
grad_record = []
train_acc_list = []
distance1 = []
distance2 = []
auc_list = []

In [17]:
class cnn(nn.Module):
    def __init__(self):
        super(cnn, self).__init__()
        self.conv1 = nn.Conv2d(1, 30, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(30, 50, 5)
        self.pool2 = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(800, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Match baseline with dirichlet distribution and 72 clients


In [19]:
all_data = torch.utils.data.ConcatDataset((trainset, testset))
num_workers = 100
distribution='fang'
param = .5
force = True

if distribution=='fang':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_train_data(all_data, num_workers=100, bias=param)
elif distribution == 'dirichlet':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=param, force=force)

# trimmed-mean without attack (72% clients)

In [21]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = 28

lr = 0.02
for e in range(1000):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []
    for i in range(nbyz, num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())

        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    agg_grads = torch.mean(user_grads, 0)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]

    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    if e%50==0 or e==999:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(e,correct/total)

0 0.06598984771573604
50 0.5348860356325271
100 0.8523937493779238
150 0.8958893201950832
200 0.915497163332338
250 0.9233602070269732
300 0.9331143624962676
350 0.9375933114362496
400 0.9425699213695631
450 0.946252612720215
500 0.9494376430775355
550 0.9526226734348562
600 0.9546133174081816
650 0.9572011545735045
700 0.958495073156166
750 0.9609833781228228
800 0.9631730864934807
850 0.9643674728774758
900 0.9652632626654722
950 0.9665571812481337
999 0.9678510998307953


In [23]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = 28

lr = 0.02
for e in range(1000):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []
    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())

        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    user_grads = full_trim(user_grads, nbyz)
    agg_grads = tr_mean(user_grads, nbyz)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]

    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    if e%50==0 or e==999:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(e,correct/total)

0 0.09415745993829004
50 0.2762018512988952
100 0.307654026077436
150 0.27819249527222056
200 0.33681696028665276
250 0.4464019110182144
300 0.4975614611326764
350 0.8360704687966557
400 0.7418134766596994
450 0.8636408878272122
500 0.8825520055738031
550 0.8245247337513686
600 0.8218373643873793
650 0.893998208420424
700 0.8330845028366677
750 0.8992734149497362
800 0.9117149397830198
850 0.9305265253309446
900 0.9231611426296407
950 0.9328157659002687
999 0.9216681596496467


In [58]:
torch.cuda.empty_cache()
weight_record = []
grad_record = []
test_grads = []
num_workers = 100
malicious_scores = np.zeros((1, num_workers))

attack_type = 'NDSS21'
dev_type = 'std'
all_data = torch.utils.data.ConcatDataset((trainset, testset))

# distribution='fang'
# param = .5
# force = True
# if distribution=='fang':
#     each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_train_data(all_data, num_workers=100, bias=param)
# elif distribution == 'dirichlet':
#     each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=param, force=force)

criterion = nn.CrossEntropyLoss()
net = cnn().to(device)
lr = 0.02
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)
nbyz = 28
good_distance_rage = np.zeros((1, nbyz))

for e in range(1000):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []

    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_
    tmp = []
    for param in net.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp    

    if (e > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        pred_grad = copy.deepcopy(good_old_grads)
        distance = []
        for i in range(len(good_old_grads)):
            pred_grad[i] += torch.from_numpy(hvp).to(device)
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()[None,:]), 0)
        
    if attack_type != 'none' and (e < start_detection_epoch):
        user_grads = full_trim(user_grads, nbyz)
    elif e > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                print("LIE")
                z = 0.1
                noise_avg = torch.mean(user_grads[:nbyz], dim=0)
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_avg + z * noise_std
            elif attack_type == 'NDSS21':
                distance_bound = np.mean(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                # noise = deviation * ((distance_bound + np.random.uniform(0, np.std(good_distance_rage[-1]))) / torch.norm(deviation))
                noise = deviation * ((distance_bound)) / torch.norm(deviation)
            elif attack_type == 'mod_trim':
                mal_grads= full_trim(user_grads[:nbyz], nbyz)
                pass
            else:
                noise = torch.zeros(hvp.shape).to(device)
            for m in range(nbyz):
                user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise

    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)
    
    if distance is not None and e > (start_detection_epoch - window_size):
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    if malicious_scores.shape[0] >= 11:
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', e)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    if e > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > window_size):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    del user_grads
    
    start_idx=0

    cnn_optimizer.zero_grad()

    model_grads=[]

    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    
    if e %10 ==0 or e==999:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('===>  ', e,correct/total)


===>   0 0.1353637901861252
===>   10 0.08151687070767394
Detection AUC: 0.0000; Detection AUC: 0.3472
Detection AUC: 0.0000; Detection AUC: 0.9722
Detection AUC: 0.0000; Detection AUC: 0.4028
Detection AUC: 0.0000; Detection AUC: 0.4028
Detection AUC: 0.0000; Detection AUC: 0.4306
Detection AUC: 0.0000; Detection AUC: 0.4028
Detection AUC: 0.0000; Detection AUC: 0.4028
Detection AUC: 0.0000; Detection AUC: 0.3333
Detection AUC: 0.0000; Detection AUC: 0.3611
Detection AUC: 0.0000; Detection AUC: 0.3472
[ 0.65755221  0.05327445  0.264413   -0.04188027  0.18342663  0.01243712]
No attack detected!
===>   20 0.08241266049567035
Detection AUC: 0.0000; Detection AUC: 0.4444
[ 0.63313885 -0.07475372  0.33561425 -0.03779776  0.02242355  0.19360693]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.4167
[ 0.69999202  0.00494086  0.28261571  0.02189032 -0.04166124  0.17756099]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.3889
[ 0.65331771 -0.02703131  0.23056889  0.06343

Detection AUC: 0.0000; Detection AUC: 0.5972
[ 0.56307062 -0.11556678  0.02550552  0.03260476  0.12683715  0.0278452 ]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.5278
[ 0.67847903 -0.16076997  0.21234517  0.07333614  0.14676848  0.06699319]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.6389
[ 0.64805146 -0.15786231  0.30377933  0.03979009  0.08487627  0.16583863]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.5278
[ 0.69796601 -0.23810225  0.44519039 -0.11804426  0.05582435  0.13466097]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.5139
[ 0.62410635 -0.13259003  0.29308692 -0.00189181  0.19479431  0.06569161]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.5139
[ 0.58247945 -0.07787602  0.25495471 -0.04908652  0.20724702  0.09438522]
No attack detected!
===>   80 0.10012939185826615
Detection AUC: 0.0000; Detection AUC: 0.5000
[ 0.56063558 -0.11009097  0.25268788 -0.06230569  0.09258773  0.08111064]
No attack detecte

[0.37905185 0.33544364 0.04459581 0.23573149 0.19424153 0.04823279]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.6111
[0.37728704 0.20067991 0.01203918 0.20062544 0.24641895 0.11043733]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.5972
[ 0.51949247  0.06911464  0.15898452  0.2878074  -0.04993712  0.20644392]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.6667
[0.44153341 0.19347882 0.16088767 0.31104957 0.10261434 0.01245934]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.4722
[0.46290791 0.0838309  0.33009244 0.2061926  0.00271313 0.03872403]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.4444
[0.43137906 0.21918001 0.2425231  0.15710062 0.07303756 0.14159013]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.5278
[ 0.50009562  0.31131305  0.12224987  0.14748559  0.05510513 -0.00379202]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.5139
[ 0.46598238  0.32377104  0.07382531  0.21635901  0.04511

Detection AUC: 0.0000; Detection AUC: 0.7083
[ 0.40302104 -0.03024116  0.19258606  0.06647902  0.02130004  0.05772022]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.8611
[ 0.44615672 -0.01992096  0.18710915  0.16913377  0.14083574  0.01579372]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.7639
[ 0.31920642  0.00693675  0.22030843  0.17155677 -0.04764456 -0.03982263]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.7639
[ 0.35642886 -0.07506648  0.34002179  0.1269215  -0.11046803  0.1354772 ]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.7639
[ 0.32312101 -0.0441231   0.12621512  0.12559638  0.0762071  -0.04924114]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.6528
[ 0.29390316 -0.07990647  0.13802844  0.17532088  0.08702877  0.03081946]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.7917
[ 0.32294748 -0.13626709  0.1511899   0.06214906  0.20933105 -0.07491127]
No attack detected!
Detection AUC: 0.0000; Dete

Detection AUC: 0.0000; Detection AUC: 0.9167
[ 0.23495358  0.06940541  0.01281434 -0.06496064  0.04741145  0.20553777]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.9167
[ 0.21348537 -0.05585399 -0.01515814  0.07079216  0.18435404  0.15249729]
No attack detected!
===>   270 0.4018114860157261
Detection AUC: 0.0000; Detection AUC: 0.9028
[ 0.3257442  -0.07145041 -0.16501319  0.09790062  0.21319397  0.1507091 ]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.8889
[ 0.32485768 -0.27136195 -0.04550522  0.08092333  0.06913017  0.06186379]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.9306
[ 0.4438958  -0.39518851  0.04612607  0.21416283  0.04845698  0.05260262]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.8472
[ 0.44206146 -0.46198997  0.09769917  0.19308769  0.17428378  0.11067105]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.8750
[ 0.47158071 -0.49745023  0.03003431  0.02002342  0.33464711  0.19319975]
No attack detecte

Detection AUC: 0.0000; Detection AUC: 0.8611
[ 0.11562276  0.17083671 -0.05405338 -0.0863441   0.20955104  0.19456291]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.9306
[ 0.20409823  0.0174101   0.17115961 -0.04033825  0.10838084  0.2156361 ]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.8750
[0.08450679 0.08664669 0.10288407 0.0938176  0.22643021 0.07484872]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.8889
[ 0.06137523  0.18501521  0.10864743 -0.01180571  0.19254917  0.15240704]
No attack detected!
===>   330 0.6348163630934608
Detection AUC: 0.0000; Detection AUC: 0.9306
[-0.18444892  0.22015007  0.09933993 -0.1182948   0.23000923  0.12577539]
Attack Detected!
Stop at iteration: 331
acc 0.7400; recall 1.0000; fpr 0.3611; fnr 0.0000;
0.69490125564236


In [59]:
torch.cuda.empty_cache()
weight_record = []
grad_record = []
test_grads = []
num_workers = 100
malicious_scores = np.zeros((1, num_workers))

attack_type = 'NDSS21'
dev_type = 'std'
all_data = torch.utils.data.ConcatDataset((trainset, testset))

# distribution='fang'
# param = .5
# force = True
# if distribution=='fang':
#     each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_train_data(all_data, num_workers=100, bias=param)
# elif distribution == 'dirichlet':
#     each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=param, force=force)

criterion = nn.CrossEntropyLoss()
net = cnn().to(device)
lr = 0.02
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)
nbyz = 28
good_distance_rage = np.zeros((1, nbyz))

for e in range(1000):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []

    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_
    tmp = []
    for param in net.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp    

    if (e > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        pred_grad = copy.deepcopy(good_old_grads)
        distance = []
        for i in range(len(good_old_grads)):
            pred_grad[i] += torch.from_numpy(hvp).to(device)
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()[None,:]), 0)
        
    if attack_type != 'none' and (e < start_detection_epoch):
        user_grads = full_trim(user_grads, nbyz)
    elif e > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                print("LIE")
                z = 0.1
                noise_avg = torch.mean(user_grads[:nbyz], dim=0)
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_avg + z * noise_std
            elif attack_type == 'NDSS21':
                distance_bound = np.random.choice(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                # noise = deviation * ((distance_bound + np.random.uniform(0, np.std(good_distance_rage[-1]))) / torch.norm(deviation))
                noise = deviation * ((distance_bound)) / torch.norm(deviation)
            elif attack_type == 'mod_trim':
                mal_grads= full_trim(user_grads[:nbyz], nbyz)
                pass
            else:
                noise = torch.zeros(hvp.shape).to(device)
            for m in range(nbyz):
                user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise

    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)
    
    if distance is not None and e > (start_detection_epoch - window_size):
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    if malicious_scores.shape[0] >= 11:
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', e)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    if e > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > window_size):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    del user_grads
    
    start_idx=0

    cnn_optimizer.zero_grad()

    model_grads=[]

    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    
    if e %10 ==0 or e==999:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('===>  ', e,correct/total)



===>   0 0.13486612919279387
===>   10 0.10719617796357121
Detection AUC: 0.0000; Detection AUC: 0.3472
Detection AUC: 0.0000; Detection AUC: 0.9444
Detection AUC: 0.0000; Detection AUC: 0.1806
Detection AUC: 0.0000; Detection AUC: 0.8889
Detection AUC: 0.0000; Detection AUC: 0.8611
Detection AUC: 0.0000; Detection AUC: 0.8472
Detection AUC: 0.0000; Detection AUC: 0.8889
Detection AUC: 0.0000; Detection AUC: 0.6250
Detection AUC: 0.0000; Detection AUC: 0.6667
Detection AUC: 0.0000; Detection AUC: 0.4583
[ 0.3127919   0.31176786 -0.18399457  0.15290231 -0.02395434  0.15519789]
No attack detected!
===>   20 0.12481337712750075
Detection AUC: 0.0000; Detection AUC: 0.5139
[ 0.36620293  0.3253052  -0.21798101  0.21276772 -0.01310617  0.18608286]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.2083
[0.4674064  0.1007364  0.13404149 0.07585961 0.04580398 0.05092567]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.9583
[ 0.22801514  0.40427601 -0.03367152  0.22985395  

Detection AUC: 0.0000; Detection AUC: 0.2639
[ 0.39664371  0.27189208  0.00095918  0.15904306  0.16905033 -0.118741  ]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.1528
[ 0.39954067  0.42926126 -0.05611842  0.09192077  0.12712783 -0.01003118]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.1667
[ 0.32435783  0.44024806 -0.05910464  0.20023121 -0.01527472  0.20011072]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.3611
[ 0.44351134  0.38472677 -0.03832486  0.11441934 -0.02953614  0.13637196]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.1806
[ 0.41761143  0.43313664 -0.20508853  0.11916648  0.04173572  0.11124712]
No attack detected!
Detection AUC: 0.0000; Detection AUC: 0.1111
[ 0.45340796  0.31332241 -0.20765573  0.12511425  0.07404368  0.10899721]
No attack detected!
===>   80 0.18712053349258484
Detection AUC: 0.0000; Detection AUC: 0.5278
[ 0.46340663  0.2126283  -0.11173369  0.24957703  0.03732479  0.10456418]
No attack detecte

In [62]:
torch.cuda.empty_cache()
weight_record = []
grad_record = []
test_grads = []
num_workers = 100
malicious_scores = np.zeros((1, num_workers))

attack_type = 'NDSS21'
dev_type = 'std'
all_data = torch.utils.data.ConcatDataset((trainset, testset))

distribution='fang'
param = .5
force = True
if distribution=='fang':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_train_data(all_data, num_workers=100, bias=param)
elif distribution == 'dirichlet':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=param, force=force)

criterion = nn.CrossEntropyLoss()
net = cnn().to(device)
lr = 0.1
cnn_optimizer = SGD(net.parameters(), lr = lr, momentum=0.9, weight_decay=1e-5)
sch = torch.optim.lr_scheduler.ExponentialLR(cnn_optimizer, gamma=0.995)

start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)
nbyz = 28
good_distance_rage = np.zeros((1, nbyz))

for e in range(50):
    
    user_grads = []

    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_
    tmp = []
    for param in net.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp    

    if (e > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        pred_grad = copy.deepcopy(good_old_grads)
        distance = []
        for i in range(len(good_old_grads)):
            pred_grad[i] += torch.from_numpy(hvp).to(device)
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()[None,:]), 0)
        
    if attack_type != 'none' and (e < start_detection_epoch):
        user_grads = full_trim(user_grads, nbyz)
    elif e > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                print("LIE")
                z = 0.1
                noise_avg = torch.mean(user_grads[:nbyz], dim=0)
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_avg + z * noise_std
            elif attack_type == 'NDSS21':
                distance_bound = np.random.choice(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
#                 noise = deviation * ((distance_bound)) / torch.norm(deviation)
            elif attack_type == 'mod_trim':
                mal_grads= full_trim(user_grads[:nbyz], nbyz)
                pass
            else:
                noise = torch.zeros(hvp.shape).to(device)
            for m in range(nbyz):
                noise = deviation * ((distance_bound + np.random.uniform(0, np.std(good_distance_rage[-1]))) / torch.norm(deviation))
                user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise

    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)
    
    if distance is not None and e > (start_detection_epoch - window_size):
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    if malicious_scores.shape[0] >= 11:
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', e)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    if e > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > window_size):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    del user_grads
    
    start_idx=0

    cnn_optimizer.zero_grad()

    model_grads=[]

    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    sch.step()
    if e %10 ==0 or e==999:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('===>  ', e,correct/total)




===>   0 0.048576547879753136
===>   10 0.09944256420465857
Detection AUC: 0.0000; Detection AUC: 0.2416
Detection AUC: 0.0000; Detection AUC: 0.0084
Detection AUC: 0.0000; Detection AUC: 0.0035
Detection AUC: 0.0000; Detection AUC: 0.9420
Detection AUC: 0.0000; Detection AUC: 0.8611
Detection AUC: 0.0000; Detection AUC: 0.4593
Detection AUC: 0.0000; Detection AUC: 0.7014
Detection AUC: 0.0000; Detection AUC: 0.6721
Detection AUC: 0.0000; Detection AUC: 0.8442
Detection AUC: 0.0000; Detection AUC: 0.6493
[-0.13903459 -0.39140504  0.25809531  0.06815942  0.06915195  0.01891466]
Attack Detected!
Stop at iteration: 20
acc 0.6200; recall 0.0000; fpr 0.1389; fnr 1.0000;
0.8342425253551271


In [63]:
np.sum(malicious_scores[-window_size:], axis=0)


array([0.10075014, 0.10034355, 0.10102042, 0.10100772, 0.10075673,
       0.10063589, 0.10067415, 0.10149158, 0.10160903, 0.10087983,
       0.10054865, 0.10065437, 0.10118496, 0.10070448, 0.0996544 ,
       0.10107337, 0.1010631 , 0.09991002, 0.10105943, 0.10132726,
       0.10154339, 0.10041358, 0.10175377, 0.10087505, 0.10109687,
       0.10123139, 0.10084032, 0.10174571, 0.10194788, 0.09904937,
       0.09380541, 0.09702941, 0.0917795 , 0.09676857, 0.09446125,
       0.09442798, 0.09400419, 0.09607597, 0.09496837, 0.09882355,
       0.09555456, 0.09841656, 0.09388723, 0.09923857, 0.0974157 ,
       0.10165723, 0.09743717, 0.10044993, 0.09387382, 0.09859051,
       0.08701711, 0.09191471, 0.08727049, 0.08341347, 0.08564741,
       0.09158969, 0.09108287, 0.09792132, 0.08573843, 0.08928625,
       0.09588788, 0.09164253, 0.09321117, 0.09239886, 0.08703439,
       0.09073279, 0.09083117, 0.09847255, 0.09185112, 0.09050218,
       0.09175321, 0.09317467, 0.0921819 , 0.09017967, 0.08596

# trimmed mean + no attack + 72% of 500 clients

In [32]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
num_workers = 500
all_data = torch.utils.data.ConcatDataset((trainset, testset))
distribution='fang'
param = .5
force = True
if distribution=='fang':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_train_data(all_data, num_workers=num_workers, bias=param)
elif distribution == 'dirichlet':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=param, force=force)
    
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = int(500 * 0.28)

lr = 0.02
for e in range(1000):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []
    for i in range(nbyz, num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())

        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    agg_grads = torch.mean(user_grads, 0)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]

    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    if e%50==0 or e==999:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(e,correct/total)

0 0.12518351766663405
50 0.7559949104433786
100 0.8257805618087501
150 0.9033963002838407
200 0.9200352353919937
250 0.9304101008123715
300 0.9370656748556327
350 0.9423509836546932
400 0.9466575315650386
450 0.9509640794753842
500 0.9543897425858863
550 0.9574239013409024
600 0.9601644318293041
650 0.9619262014289909
700 0.9643731036507781
750 0.9660369971615934
800 0.9670157580503083
850 0.9682881472056377
900 0.9686796515611236
950 0.9695605363609671
999 0.9703435450719389


# Median

In [16]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = 28
lr = 0.02
all_data = torch.utils.data.ConcatDataset((trainset, testset))
num_workers = 100
distribution='fang'
param = .5
force = True
if distribution=='fang':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_train_data(all_data, num_workers=100, bias=param)
elif distribution == 'dirichlet':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=param, force=force)

for e in range(1000):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []
    for i in range(nbyz, num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())

        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    agg_grads = torch.median(user_grads, 0)[0]
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]

    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    if e%50==0 or e==999:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(e, correct/total)

0 0.13424957673538493
50 0.7415596056169704
100 0.6385818145603027
150 0.7149686286226471
200 0.789961159247087
250 0.8070909271984862
300 0.8423463798426452
350 0.8603724728612688
400 0.879792849317797
450 0.8967234339209242
500 0.9374564286425655
550 0.9353650034857086
600 0.9426351956976396
650 0.9431331540683199
700 0.9525943631112439
750 0.9539886465491485
800 0.9546857882681008
850 0.9579723135145902
900 0.9585698635594064
950 0.9586694552335425
999 0.9608604720645354


In [28]:
torch.cuda.empty_cache()
weight_record = []
grad_record = []
test_grads = []
num_workers = 100
malicious_scores = np.zeros((1, num_workers))

attack_type = 'NDSS21'
dev_type = 'unit_vec'
all_data = torch.utils.data.ConcatDataset((trainset, testset))

distribution='fang'
param = .5
force = True
if distribution=='fang':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_train_data(all_data, num_workers=100, bias=param)
elif distribution == 'dirichlet':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=param, force=force)

criterion = nn.CrossEntropyLoss()
net = cnn().to(device)
lr = 0.02
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)
nbyz = 28
good_distance_rage = np.zeros((1, nbyz))

for e in range(100):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []

    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_
    tmp = []
    for param in net.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp    

    if (e > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        pred_grad = copy.deepcopy(good_old_grads)
        distance = []
        for i in range(len(good_old_grads)):
            pred_grad[i] += torch.from_numpy(hvp).to(device)
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()[None,:]), 0)
        
    if attack_type != 'none' and (e < start_detection_epoch):
        user_grads = full_trim(user_grads, nbyz)
    elif e > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                print("LIE")
                z = 0.1
                noise_avg = torch.mean(user_grads[:nbyz], dim=0)
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_avg + z * noise_std
            elif attack_type == 'NDSS21':
                distance_bound = np.mean(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                noise = deviation * (distance_bound / torch.norm(deviation))
            elif attack_type == 'mod_trim':
                mal_grads= full_trim(user_grads[:nbyz], nbyz)
                pass
            else:
                noise = torch.zeros(hvp.shape).to(device)
            for m in range(nbyz):
                user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise

#     agg_grads, distance = simple_mean(old_grad_list, user_grads, nbyz, hvp)
    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)
    
    if distance is not None and e > (start_detection_epoch - window_size):
        print('==>', e, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    if malicious_scores.shape[0] >= 11:
        print('performing detection at epoch %d' % e)
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', e)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    if e > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > 10):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    del user_grads
    
    start_idx=0

    cnn_optimizer.zero_grad()

    model_grads=[]

    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    total, correct = 0,0
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(e,correct/total)

0 0.08542413381123058
1 0.08462763839107926
2 0.08462763839107926
3 0.08422939068100359
4 0.0843289526085225
5 0.08343289526085225
6 0.08233771405814416
7 0.08114297092791717
8 0.08064516129032258
9 0.07994822779769016
10 0.08243727598566308
Detection AUC: 0.0000; Detection AUC: 0.2917
==> 11 (1, 100)
11 0.08383114297092792
Detection AUC: 0.0000; Detection AUC: 0.7361
==> 12 (2, 100)
12 0.08512544802867383
Detection AUC: 0.0000; Detection AUC: 0.9722
==> 13 (3, 100)
13 0.08582238152130625
Detection AUC: 0.0000; Detection AUC: 0.9444
==> 14 (4, 100)
14 0.08711668657905217
Detection AUC: 0.0000; Detection AUC: 0.9444
==> 15 (5, 100)
15 0.08811230585424133
Detection AUC: 0.0000; Detection AUC: 0.9167
==> 16 (6, 100)
16 0.08950617283950617
Detection AUC: 0.0000; Detection AUC: 0.9028
==> 17 (7, 100)
17 0.09099960175228992
Detection AUC: 0.0000; Detection AUC: 0.9444
==> 18 (8, 100)
18 0.09209478295499801
Detection AUC: 0.0000; Detection AUC: 0.8889
==> 19 (9, 100)
19 0.09279171644763043
De

[ 0.46475184  0.14313485 -0.04805602  0.12840855  0.11375356  0.14347217]
No attack detected!
54 0.270310633213859
Detection AUC: 0.0000; Detection AUC: 0.6806
==> 55 (45, 100)
performing detection at epoch 55
[ 0.4984429   0.11331674  0.04302251  0.26150822 -0.09430352  0.14750381]
No attack detected!
55 0.2712066905615293
Detection AUC: 0.0000; Detection AUC: 0.6806
==> 56 (46, 100)
performing detection at epoch 56
[ 0.51657045  0.19587491 -0.04665865  0.17691828 -0.00355763  0.16490586]
No attack detected!
56 0.27220230983671845
Detection AUC: 0.0000; Detection AUC: 0.6250
==> 57 (47, 100)
performing detection at epoch 57
[ 0.43926496  0.20359289  0.04100342  0.10162832  0.27018666 -0.02409785]
No attack detected!
57 0.2735961768219833
Detection AUC: 0.0000; Detection AUC: 0.6667
==> 58 (48, 100)
performing detection at epoch 58
[0.55089391 0.11520016 0.04594439 0.11370693 0.17830411 0.1392178 ]
No attack detected!
58 0.2737953006770211
Detection AUC: 0.0000; Detection AUC: 0.6944
=

[ 0.34240059  0.29026831  0.08913794  0.17743676 -0.00430063  0.06507579]
No attack detected!
93 0.2756869772998805
Detection AUC: 0.0000; Detection AUC: 0.5833
==> 94 (84, 100)
performing detection at epoch 94
[ 0.4416672   0.1725647   0.10662793  0.19982152 -0.12263675  0.15865676]
No attack detected!
94 0.27588610115491835
Detection AUC: 0.0000; Detection AUC: 0.5139
==> 95 (85, 100)
performing detection at epoch 95
[ 0.6698769   0.04385919  0.22972631 -0.00306109  0.06779541  0.15839839]
No attack detected!
95 0.27598566308243727
Detection AUC: 0.0000; Detection AUC: 0.7083
==> 96 (86, 100)
performing detection at epoch 96
[0.54199924 0.01412372 0.34470791 0.01067764 0.10966078 0.18982432]
No attack detected!
96 0.27678215850258864
Detection AUC: 0.0000; Detection AUC: 0.3889
==> 97 (87, 100)
performing detection at epoch 97
[ 0.53379     0.00365687  0.31449595 -0.00454636  0.04362614  0.23300664]
No attack detected!
97 0.2769812823576264
Detection AUC: 0.0000; Detection AUC: 0.430

In [20]:
torch.cuda.empty_cache()
weight_record = []
grad_record = []
test_grads = []
num_workers = 100
malicious_scores = np.zeros((1, num_workers))

attack_type = 'none'
dev_type = 'unit_vec'
all_data = torch.utils.data.ConcatDataset((trainset, testset))

distribution='dirichlet'
param = .2
force = True
if distribution=='fang':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_train_data(all_data, num_workers=num_workers, bias=param)
elif distribution == 'dirichlet':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=param, force=force)

criterion = nn.CrossEntropyLoss()
net = cnn().to(device)
lr = 0.02
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)
nbyz = int(num_workers * 0.28)
good_distance_rage = np.zeros((1, nbyz))

for e in range(100):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []

    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    tmp = []
    for param in net.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp    

    if (e > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        pred_grad = copy.deepcopy(good_old_grads)
        distance = []
        for i in range(len(good_old_grads)):
            pred_grad[i] += torch.from_numpy(hvp).to(device)
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()[None,:]), 0)
        
    if attack_type != 'none' and (e < start_detection_epoch):
        user_grads = full_trim(user_grads, nbyz)
    elif e > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                z = 0.1
                noise_avg = torch.mean(user_grads[:nbyz], dim=0)
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_avg + z * noise_std
            elif attack_type == 'NDSS21':
                distance_bound = np.mean(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                noise = deviation * (distance_bound / torch.norm(deviation))
            elif attack_type == 'mod_trim':
                mal_grads= full_trim(user_grads[:nbyz], nbyz)
                pass
            else:
                noise = torch.zeros(hvp.shape).to(device)
            
            for m in range(nbyz):
                user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise

#     agg_grads, distance = simple_mean(old_grad_list, user_grads, nbyz, hvp)
    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)
    
    if distance is not None and e > (start_detection_epoch - window_size):
        # print('==>', e, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    if malicious_scores.shape[0] >= 11:
        print('performing detection at epoch %d' % e)
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', e)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    if e > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > 10):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    del user_grads
    
    start_idx=0

    cnn_optimizer.zero_grad()

    model_grads=[]

    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    total, correct = 0,0
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(e,correct/total)

generating participant indices for alpha 0.2
0 0.10628210193599368
1 0.10588700118530225
2 0.10598577637297511
3 0.10598577637297511
4 0.10924535756617938
5 0.1230738838403793
6 0.1412485183721849
7 0.15645989727380483
8 0.18549980244962466
9 0.2081193204267088
10 0.23143026471750297
Detection AUC: 0.0000; Detection AUC: 0.4708
11 0.25276570525484
Detection AUC: 0.0000; Detection AUC: 0.4592
12 0.2623468984591071
Detection AUC: 0.0000; Detection AUC: 0.4673
13 0.27647175029632554
Detection AUC: 0.0000; Detection AUC: 0.4603
14 0.2809166337416041
Detection AUC: 0.0000; Detection AUC: 0.4748
15 0.29405373370209403
Detection AUC: 0.0000; Detection AUC: 0.4669
16 0.30264717502963256
Detection AUC: 0.0000; Detection AUC: 0.4816
17 0.31775977874357964
Detection AUC: 0.0000; Detection AUC: 0.4957
18 0.3273409719478467
Detection AUC: 0.0000; Detection AUC: 0.4761
19 0.33277360726985383
Detection AUC: 0.0000; Detection AUC: 0.4746
performing detection at epoch 20
[ 0.77852798 -0.07350276 -0.174

Detection AUC: 0.0000; Detection AUC: 0.4731
performing detection at epoch 58
[ 0.60098586  0.04460655 -0.21185116  0.03341911  0.00198022 -0.20315948]
No attack detected!
58 0.4830106677202687
Detection AUC: 0.0000; Detection AUC: 0.4764
performing detection at epoch 59
[ 0.68608455  0.0279264  -0.19908624  0.05409766 -0.02350215 -0.17063646]
No attack detected!
59 0.4876531015408929
Detection AUC: 0.0000; Detection AUC: 0.4686
performing detection at epoch 60
[ 0.6754157   0.0892888  -0.24644016  0.03261533 -0.03881181 -0.0851335 ]
No attack detected!
60 0.4968391939944686
Detection AUC: 0.0000; Detection AUC: 0.4785
performing detection at epoch 61
[ 0.68801547  0.02266    -0.22659611  0.03008208 -0.016577   -0.13709313]
No attack detected!
61 0.50622283682339
Detection AUC: 0.0000; Detection AUC: 0.4822
performing detection at epoch 62
[ 0.67805271  0.12856968 -0.29723075  0.1088831  -0.07968101 -0.0919678 ]
No attack detected!
62 0.5094824180165942
Detection AUC: 0.0000; Detection

KeyboardInterrupt: 

In [29]:
torch.cuda.empty_cache()
weight_record = []
grad_record = []
test_grads = []
num_workers = 100
malicious_scores = np.zeros((1, num_workers))

attack_type = 'NDSS21'
dev_type = 'unit_vec'
all_data = torch.utils.data.ConcatDataset((trainset, testset))

distribution='dirichlet'
param = .5
force = True
if distribution=='fang':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_train_data(all_data, num_workers=100, bias=param)
elif distribution == 'dirichlet':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=param, force=force)

criterion = nn.CrossEntropyLoss()
net = cnn().to(device)
lr = 0.02
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)
nbyz = 28
good_distance_rage = np.zeros((1, nbyz))

for e in range(100):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []

    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_
    tmp = []
    for param in net.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp    

    if (e > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        pred_grad = copy.deepcopy(good_old_grads)
        distance = []
        for i in range(len(good_old_grads)):
            pred_grad[i] += torch.from_numpy(hvp).to(device)
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()[None,:]), 0)
        
    if attack_type != 'none' and (e < start_detection_epoch):
        user_grads = full_trim(user_grads, nbyz)
    elif e > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                print("LIE")
                z = 0.1
                noise_avg = torch.mean(user_grads[:nbyz], dim=0)
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_avg + z * noise_std
            elif attack_type == 'NDSS21':
                distance_bound = np.mean(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                noise = deviation * (distance_bound / torch.norm(deviation))
            elif attack_type == 'mod_trim':
                mal_grads= full_trim(user_grads[:nbyz], nbyz)
                pass
            else:
                noise = torch.zeros(hvp.shape).to(device)
            for m in range(nbyz):
                user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise

#     agg_grads, distance = simple_mean(old_grad_list, user_grads, nbyz, hvp)
    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)
    
    if distance is not None and e > (start_detection_epoch - window_size):
        print('==>', e, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    if malicious_scores.shape[0] >= 11:
        print('performing detection at epoch %d' % e)
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', e)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    if e > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > 10):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    del user_grads
    
    start_idx=0

    cnn_optimizer.zero_grad()

    model_grads=[]

    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    total, correct = 0,0
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(e,correct/total)

generating participant indices for alpha 0.5
0 0.10749153217772464
1 0.11805140466228332
2 0.13409045626618848
3 0.15162382944809724
4 0.16626818091253237
5 0.17493524606495317
6 0.18619246861924685
7 0.19426180514046623
8 0.19924287706714486
9 0.20721259214983065
10 0.22414823670053796
Detection AUC: 0.0000; Detection AUC: 0.7639
==> 11 (1, 100)
11 0.2363020522016338
Detection AUC: 0.0000; Detection AUC: 0.7222
==> 12 (2, 100)
12 0.24297668858338314
Detection AUC: 0.0000; Detection AUC: 0.7083
==> 13 (3, 100)
13 0.24755927475592748
Detection AUC: 0.0000; Detection AUC: 0.6944
==> 14 (4, 100)
14 0.25841801155608685
Detection AUC: 0.0000; Detection AUC: 0.6806
==> 15 (5, 100)
15 0.2654911336919705
Detection AUC: 0.0000; Detection AUC: 0.7083
==> 16 (6, 100)
16 0.2837218569436143
Detection AUC: 0.0000; Detection AUC: 0.6806
==> 17 (7, 100)
17 0.2894002789400279
Detection AUC: 0.0000; Detection AUC: 0.6944
==> 18 (8, 100)
18 0.301255230125523
Detection AUC: 0.0000; Detection AUC: 0.7083
=

Detection AUC: 0.0000; Detection AUC: 0.7361
==> 54 (44, 100)
performing detection at epoch 54
[ 0.57552539 -0.10001964  0.32009299  0.22136742 -0.02498425 -0.21938299]
No attack detected!
54 0.64305638573421
Detection AUC: 0.0000; Detection AUC: 0.7500
==> 55 (45, 100)
performing detection at epoch 55
[ 0.60308767 -0.05017173  0.34232311  0.17229676  0.00421216 -0.12348446]
No attack detected!
55 0.646543136082885
Detection AUC: 0.0000; Detection AUC: 0.5972
==> 56 (46, 100)
performing detection at epoch 56
[ 0.61952219 -0.15799175  0.38874303  0.12211199  0.007711   -0.13470023]
No attack detected!
56 0.6548117154811716
Detection AUC: 0.0000; Detection AUC: 0.7083
==> 57 (47, 100)
performing detection at epoch 57
[ 0.64273338 -0.08953329  0.33362946  0.1814249   0.03812316 -0.11898965]
No attack detected!
57 0.657501494321578
Detection AUC: 0.0000; Detection AUC: 0.6944
==> 58 (48, 100)
performing detection at epoch 58
[ 0.60743177 -0.01573766  0.25558611  0.1362129   0.07121395 -0.1

Detection AUC: 0.0000; Detection AUC: 0.7500
==> 93 (83, 100)
performing detection at epoch 93
[ 0.55458723 -0.0766091   0.28577077  0.16389106  0.07383577  0.03926943]
No attack detected!
93 0.7027296274158199
Detection AUC: 0.0000; Detection AUC: 0.7500
==> 94 (84, 100)
performing detection at epoch 94
[ 0.51950012 -0.11859419  0.30514584  0.20125224  0.11868994 -0.04147754]
No attack detected!
94 0.7033273560470213
Detection AUC: 0.0000; Detection AUC: 0.7083
==> 95 (85, 100)
performing detection at epoch 95
[ 0.50419781 -0.11931024  0.3613591   0.17568109  0.08093963 -0.0612606 ]
No attack detected!
95 0.7059175134488942
Detection AUC: 0.0000; Detection AUC: 0.7361
==> 96 (86, 100)
performing detection at epoch 96
[ 0.50200848 -0.06154303  0.38596246  0.08569101  0.13591435 -0.08436136]
No attack detected!
96 0.7071129707112971
Detection AUC: 0.0000; Detection AUC: 0.7083
==> 97 (87, 100)
performing detection at epoch 97
[ 0.63358757 -0.05782452  0.30649224  0.15387878  0.09794853 

In [17]:
torch.cuda.empty_cache()
weight_record = []
grad_record = []
test_grads = []
num_workers = 100
malicious_scores = np.zeros((1, num_workers))

attack_type = 'LIE'
dev_type = 'unit_vec'
all_data = torch.utils.data.ConcatDataset((trainset, testset))

distribution='dirichlet'
param = .1
force = True
if distribution=='fang':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_train_data(all_data, num_workers=100, bias=param)
elif distribution == 'dirichlet':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=param, force=force)

criterion = nn.CrossEntropyLoss()
net = cnn().to(device)
lr = 0.02
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)
nbyz = int(num_workers * 0.28)
good_distance_rage = np.zeros((1, nbyz))

for e in range(100):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []

    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_
    tmp = []
    for param in net.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp    

    if (e > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        pred_grad = copy.deepcopy(good_old_grads)
        distance = []
        for i in range(len(good_old_grads)):
            pred_grad[i] += torch.from_numpy(hvp).to(device)
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()[None,:]), 0)
        
    if attack_type != 'none' and (e < start_detection_epoch):
        user_grads = full_trim(user_grads, nbyz)
    elif e > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                print("LIE")
                z = 0.01
                noise_avg = torch.mean(user_grads[:nbyz], dim=0)
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_avg + z * noise_std
            elif attack_type == 'NDSS21':
                distance_bound = np.mean(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                noise = deviation * (distance_bound / torch.norm(deviation))
            elif attack_type == 'mod_trim':
                mal_grads= full_trim(user_grads[:nbyz], nbyz)
                pass
            else:
                noise = torch.zeros(hvp.shape).to(device)
            for m in range(nbyz):
                user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise

#     agg_grads, distance = simple_mean(old_grad_list, user_grads, nbyz, hvp)
    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)
    
    if distance is not None and e > (start_detection_epoch - window_size):
        print('==>', e, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    if malicious_scores.shape[0] >= 11:
        print('performing detection at epoch %d' % e)
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', e)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    if e > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > 10):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    del user_grads
    
    start_idx=0

    cnn_optimizer.zero_grad()

    model_grads=[]

    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    total, correct = 0,0
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(e,correct/total)

generating participant indices for alpha 0.1
0 0.09005778043434948
1 0.09075513050408449
2 0.0928471807132895
3 0.09424188085275952
4 0.09723052400876668
5 0.09414225941422594
6 0.09563658099222953
7 0.09643355250049811
8 0.09832635983263599
9 0.10111576011157601
10 0.1033074317593146
LIE
Detection AUC: 0.0000; Detection AUC: 1.0000
==> 11 (1, 100)
11 0.10420402470611675
LIE
Detection AUC: 0.0000; Detection AUC: 1.0000
==> 12 (2, 100)
12 0.1108786610878661
LIE
Detection AUC: 0.0000; Detection AUC: 1.0000
==> 13 (3, 100)
13 0.12522414823670053
LIE
Detection AUC: 0.0000; Detection AUC: 1.0000
==> 14 (4, 100)
14 0.14973102211595934
LIE
Detection AUC: 0.0000; Detection AUC: 1.0000
==> 15 (5, 100)
15 0.16915720263000597
LIE
Detection AUC: 0.0000; Detection AUC: 1.0000
==> 16 (6, 100)
16 0.19207013349272764
LIE
Detection AUC: 0.0000; Detection AUC: 1.0000
==> 17 (7, 100)
17 0.2127913927077107
LIE
Detection AUC: 0.0000; Detection AUC: 1.0000
==> 18 (8, 100)
18 0.2268380155409444
LIE
Detection

UnboundLocalError: local variable 'select_k' referenced before assignment

# impact of distribution on good baselines

In [24]:
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def train(train_data, labels, model, optimizer, batch_size=20):
    model.train()
    criterion = nn.CrossEntropyLoss()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    len_t = (len(train_data) // batch_size)
    if len(train_data)%batch_size:
        len_t += 1

    r=np.arange(len(train_data))
    np.random.shuffle(r)
    
    train_data = train_data[r]
    labels = labels[r]
    
    for ind in range(len_t):

        inputs = train_data[ind * batch_size:(ind + 1) * batch_size]
        targets = labels[ind * batch_size:(ind + 1) * batch_size]

        inputs, targets = inputs.cuda(), targets.cuda()

        inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return (losses.avg, top1.avg)


def test(test_data, labels, model, criterion, use_cuda, debug_='MEDIUM', batch_size=64):
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    len_t = (len(test_data) // batch_size)
    if len(test_data)%batch_size:
        len_t += 1

    with torch.no_grad():
        for ind in range(len_t):
            # measure data loading time
            inputs = test_data[ind * batch_size:(ind + 1) * batch_size]
            targets = labels[ind * batch_size:(ind + 1) * batch_size]

            inputs, targets = inputs.cuda(), targets.cuda()

            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

    return (losses.avg, top1.avg)

In [31]:
torch.cuda.empty_cache()
weight_record = []
grad_record = []
test_grads = []
num_workers = 100
malicious_scores = np.zeros((1, num_workers))

attack_type = 'none'
dev_type = 'unit_vec'
all_data = torch.utils.data.ConcatDataset((trainset, testset))

distribution='dirichlet'
param = .1
force = True
# if distribution=='fang':
#     each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_train_data(all_data, num_workers=100, bias=param)
# elif distribution == 'dirichlet':
#     each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=param, force=force)

criterion = nn.CrossEntropyLoss()

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
fed_model = cnn().to(device)
net.apply(init_weights)
model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

local_epochs = 2
batch_size = 8
local_lr = 0.05
global_lr = 1
nepochs = 50

start_detection_epoch = 5
window_size = 5
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)
nbyz = int(num_workers * 0.28)
good_distance_rage = np.zeros((1, nbyz))

for e in range(nepochs):
    torch.cuda.empty_cache()
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    user_grads=[]
    
    for i in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = lr)
        for epoch in range(local_epochs):
            train_loss, train_acc = train(
                each_worker_data[i], torch.Tensor(each_worker_label[i]).long(), model, optimizer, batch_size)

        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat(
                (params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        update =  (params - model_received)
        user_grads = update[None,:] if len(user_grads) == 0 else torch.cat((user_grads, update[None,:]), 0)

    tmp = []
    for param in net.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp    

    if (e > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        pred_grad = copy.deepcopy(good_old_grads)
        distance = []
        for i in range(len(good_old_grads)):
            pred_grad[i] += torch.from_numpy(hvp).to(device)
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()[None,:]), 0)
        
    if attack_type != 'none' and (e < start_detection_epoch):
        user_grads = full_trim(user_grads, nbyz)
    elif e > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                print("LIE")
                z = 0.01
                noise_avg = torch.mean(user_grads[:nbyz], dim=0)
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_avg + z * noise_std
            elif attack_type == 'NDSS21':
                distance_bound = np.mean(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                noise = deviation * (distance_bound / torch.norm(deviation))
            elif attack_type == 'mod_trim':
                mal_grads= full_trim(user_grads[:nbyz], nbyz)
                pass
            else:
                noise = torch.zeros(hvp.shape).to(device)
            for m in range(nbyz):
                user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise

#     agg_grads, distance = simple_mean(old_grad_list, user_grads, nbyz, hvp)
    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, nbyz, hvp)
    
    if distance is not None and e > (start_detection_epoch - window_size):
        print('==>', e, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    if malicious_scores.shape[0] >= (window_size+1):
        print('performing detection at epoch %d' % e)
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', e)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz)
            break

    if e > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > 10):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads

    del user_grads
    model_received = model_received + global_lr * agg_grads
    fed_model = cnn().to(device)
    net.apply(init_weights)
    
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name

    fed_model.load_state_dict(state_dict)
    total, correct = 0,0
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = fed_model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(e,correct/total)

0 0.6925682406853955
1 0.8588364215979278
2 0.9029687188683004
3 0.9274755927475593
4 0.9444112372982666
5 0.9564654313608288
Detection AUC: 0.0000; Detection AUC: 0.4980
==> 6 (1, 100)
6 0.9620442319187089
Detection AUC: 0.0000; Detection AUC: 0.4891
==> 7 (2, 100)
7 0.9656306037059175
Detection AUC: 0.0000; Detection AUC: 0.5789
==> 8 (3, 100)
8 0.9720063757720662
Detection AUC: 0.0000; Detection AUC: 0.5248
==> 9 (4, 100)
9 0.975293883243674
Detection AUC: 0.0000; Detection AUC: 0.5084
==> 10 (5, 100)
performing detection at epoch 10
[ 0.32872502 -0.08098534  0.2041028   0.0797909   0.13231687  0.04238797]
No attack detected!
10 0.9787806335923491
Detection AUC: 0.0000; Detection AUC: 0.5109
==> 11 (6, 100)
performing detection at epoch 11
[ 0.32257274  0.14540135  0.24418138 -0.2098959   0.08544233  0.12395957]
No attack detected!
11 0.9785813907152819


KeyboardInterrupt: 

# Adaptive attack against FLDetector based on simple noise addition

## Method: I added noise in the *little-is-enough attack fashion*

## Observation: FLdetector does not detect the attack and the model does not train.

In [39]:
torch.cuda.empty_cache()
weight_record = []
grad_record = []
test_grads = []
malicious_scores = np.zeros((1, num_workers))
attack_type = 'LIE'

criterion = nn.CrossEntropyLoss()
net = cnn().to(device)
lr = 0.15
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)

for e in range(100):
    cnn_optimizer = SGD(net.parameters(), lr = lr*(0.96**e))
    user_grads = []

    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, each_worker_label[i].cuda())
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_
    tmp = []
    for param in net.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp    

    if (e > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    if(e < start_detection_epoch):
        user_grads = full_trim(user_grads, 20)
    elif e > start_detection_epoch:
        if attack_type == 'LIE':
            print("LIE")
            z = 0.3*(1.01**(e))
            noise_avg = torch.mean(user_grads[:20], dim=0)
            noise_std = torch.std(user_grads[:20], dim=0)
            noise = noise_avg + z * noise_std
        elif attack_type == 'NDSS21':
            pass
        else:
            noise = torch.zeros(hvp.shape).to(device)
        for m in range(20):
            user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise
#             user_grads[m] = torch.from_numpy(hvp).to(device) + noise

#     agg_grads, distance = simple_mean(old_grad_list, user_grads, 20, hvp)
    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, 20, hvp)
    
    if distance is not None and e > (start_detection_epoch - window_size):
        print('==>', e, malicious_scores.shape)
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    if malicious_scores.shape[0] >= 11:
        print('performing detection at epoch %d' % e)
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), 20):
            print('Stop at iteration:', e)
            detection(np.sum(malicious_scores[-window_size:], axis=0), 20)
            break

#     agg_grads = tr_mean(user_grads, 20)
#     agg_grads=torch.median(user_grads,dim=0)[0]
#     agg_grads=torch.mean(user_grads,dim=0)

    if e > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > 10):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
  
    del user_grads
    
    start_idx=0

    cnn_optimizer.zero_grad()

    model_grads=[]

    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    total, correct = 0,0
    with torch.no_grad():
        for i, data in enumerate(test_data):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(e,correct/total)

0 0.0652
1 0.116
2 0.1337
3 0.133
4 0.1128
5 0.1003
6 0.1492
7 0.147
8 0.2046
9 0.101
10 0.0974
LIE
Detection AUC: 0.0000; Detection AUC: 0.3750
==> 11 (1, 100)
11 0.101
LIE
Detection AUC: 0.0000; Detection AUC: 0.0000
==> 12 (2, 100)
12 0.1028
LIE
Detection AUC: 0.0000; Detection AUC: 0.0000
==> 13 (3, 100)
13 0.1028
LIE
Detection AUC: 0.0000; Detection AUC: 0.0000
==> 14 (4, 100)
14 0.1028
LIE
Detection AUC: 0.0000; Detection AUC: 0.9875
==> 15 (5, 100)
15 0.1028
LIE
Detection AUC: 0.0000; Detection AUC: 0.9625
==> 16 (6, 100)
16 0.1029
LIE
Detection AUC: 0.0000; Detection AUC: 0.9625
==> 17 (7, 100)
17 0.1698
LIE
Detection AUC: 0.0000; Detection AUC: 0.9625
==> 18 (8, 100)
18 0.1172
LIE
Detection AUC: 0.0000; Detection AUC: 0.9375
==> 19 (9, 100)
19 0.0974
LIE
Detection AUC: 0.0000; Detection AUC: 0.0000
==> 20 (10, 100)
performing detection at epoch 20
No attack detected!
20 0.0974
LIE
Detection AUC: 0.0000; Detection AUC: 0.1125
==> 21 (11, 100)
performing detection at epoch 21
No

In [22]:
def our_attack_trmean(all_updates, model_re, n_attackers, dev_type='unit_vec'):

    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)

    lamda = torch.Tensor([10.0]).cuda() #compute_lambda_our(all_updates, model_re, n_attackers)
    # print(lamda)
    threshold_diff = 1e-5
    prev_loss = -1
    lamda_fail = lamda
    lamda_succ = 0
    iters = 0 
    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        mal_updates = torch.stack([mal_update] * n_attackers)
        mal_updates = torch.cat((mal_updates, all_updates), 0)

        agg_grads = tr_mean(mal_updates, n_attackers)
        
        loss = torch.norm(agg_grads - model_re)
        
        if prev_loss < loss:
            # print('successful lamda is ', lamda)
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2
        prev_loss = loss
        
    mal_update = (model_re - lamda_succ * deviation)
    mal_updates = torch.stack([mal_update] * n_attackers)
    mal_updates = torch.cat((mal_updates, all_updates), 0)

    return mal_updates

In [34]:
def get_ndss21_noise_partial_knwoledge(good_old_grads, good_current_grads, n_attackers, hvp, dev_type='std'):
    model_re = torch.mean(good_current_grads, dim=0)
    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(good_current_grads, 0)
    
    norm_ = torch.norm(deviation).cpu().numpy()
    pred_grad = copy.deepcopy(good_old_grads)
    distance = []
    for i in range(len(good_old_grads)):
        pred_grad[i] += hvp
    distance = torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()
    max_dist = np.mean(distance)

#     while torch.abs(lamda_succ - lamda) > threshold_diff:
#         # mal_update = (bad_old_grad + hvp + lamda * deviation)
#         distance = np.concatenate(distance, np.array(n_attackers * [lamda * norm_]))
#         distance = distance / np.sum(distance)
#         if detection1(distance, 0):
#             if lamda * norm_ > max_dist:
#                 lamda -= 1
#             else:
#                 lamda += 1

    return deviation * (max_dist/torch.norm(deviation))

In [44]:
torch.cuda.empty_cache()
weight_record = []
grad_record = []
test_grads = []
malicious_scores = np.zeros((1, num_workers))
attack_type = 'NDSS21'
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
#     elif isinstance(m, nn.Conv2d):
#         torch.nn.init.xavier_uniform_(m.weight)
#         m.bias.data.fill_(0.01)

criterion = nn.CrossEntropyLoss()
net = cnn().to(device)
net.apply(init_weights)
lr = 0.15
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)

for e in range(100):
    cnn_optimizer = SGD(net.parameters(), lr = lr*(0.96**e))
#     cnn_optimizer = SGD(net.parameters(), lr=lr)
    user_grads = []

    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_
    tmp = []
    for param in net.parameters():
        tmp = param.data.view(-1) if not len(tmp) else torch.cat((tmp, param.data.view(-1)))
    weight = tmp    

    if (e > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    good_current_grads = copy.deepcopy(user_grads[:20])
    if hvp is not None:
        pred_grad = copy.deepcopy(good_old_grads)
        distance = []
        for i in range(len(good_old_grads)):
            pred_grad[i] += torch.from_numpy(hvp).to(device)
        good_distance_rage = torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()

    if(e < start_detection_epoch):
#         user_grads[:20]= full_trim(user_grads[:20], 20)
        user_grads = full_trim(user_grads, 20)
    elif e > start_detection_epoch:
        if attack_type == 'LIE':
            print("LIE")
            z = 0.1*(1.1**(e))
            noise_avg = torch.mean(user_grads[:20], dim=0)
            noise_std = torch.std(user_grads[:20], dim=0)
            noise = noise_avg + z * noise_std
        elif attack_type == 'NDSS21':
            dev_type = 'unit_vec'
            distance_bound = np.mean(good_distance_rage)
            model_re = torch.mean(good_current_grads, dim=0)
            if dev_type == 'unit_vec':
                deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
            elif dev_type == 'sign':
                deviation = torch.sign(model_re)
            elif dev_type == 'std':
                deviation = torch.std(good_current_grads, 0)
            noise = deviation * (distance_bound / torch.norm(deviation))
        elif attack_type == 'mod_trim':
            mal_grads= full_trim(user_grads[:20], 20)
#             distance_bound
            pass
        else:
            noise = torch.zeros(hvp.shape).to(device)
        for m in range(20):
            user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise

#     agg_grads, distance = simple_mean(old_grad_list, user_grads, 20, hvp)
    agg_grads, distance = trimmed_mean(old_grad_list, user_grads, 20, hvp)
    
    if distance is not None and e > (start_detection_epoch - window_size):
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    if malicious_scores.shape[0] >= 11:
        print('performing detection at epoch %d' % e)
        if detection1(np.sum(malicious_scores[-10:], axis=0), 20):
            print('Stop at iteration:', e)
            detection(np.sum(malicious_scores[-10:], axis=0), 20)
            break

    if e > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > 10):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads
  
    del user_grads
    
    start_idx=0

    cnn_optimizer.zero_grad()

    model_grads=[]

    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    total, correct = 0,0
    with torch.no_grad():
        for i, data in enumerate(test_data):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(e,correct/total)

0 0.097
1 0.1002
2 0.1127
3 0.1403
4 0.2097
5 0.2568
6 0.28
7 0.312
8 0.3211
9 0.3384
10 0.2196
Detection AUC: 0.0000; Detection AUC: 0.0000
11 0.2041
Detection AUC: 0.0000; Detection AUC: 0.0875
12 0.3749
Detection AUC: 0.0000; Detection AUC: 0.3875
13 0.2253
Detection AUC: 0.0000; Detection AUC: 0.2625
14 0.4456
Detection AUC: 0.0000; Detection AUC: 0.6375
15 0.1886
Detection AUC: 0.0000; Detection AUC: 0.3250
16 0.4548
Detection AUC: 0.0000; Detection AUC: 0.6250
17 0.3113
Detection AUC: 0.0000; Detection AUC: 0.6750
18 0.4516
Detection AUC: 0.0000; Detection AUC: 0.4375
19 0.4375
Detection AUC: 0.0000; Detection AUC: 0.6250
performing detection at epoch 20
No attack detected!
20 0.5131
Detection AUC: 0.0000; Detection AUC: 0.7500
performing detection at epoch 21
Attack Detected!
Stop at iteration: 21
acc 0.6300; recall 0.0000; fpr 0.2125; fnr 1.0000;
0.7783978053352452
