In [1]:
from __future__ import print_function
import argparse, os, sys, csv, shutil, time, random, operator, pickle, ast, math
import numpy as np
import pandas as pd
from torch.optim import Optimizer
import torch.nn.functional as F
import torch
import pickle
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torch.multiprocessing as mp

sys.path.insert(0,'./../utils/')
from logger import *
from eval import *
from misc import *

from cifar10_normal_train import *
from cifar10_util import *

In [2]:
from adam import Adam
from sgd import SGD
from gradient_aggregation_rules import *
from poisoning_attacks import *

In [3]:
# Widen the notebook's main container to 90% of the browser window.
# `display` is imported from the public `IPython.display` module; importing it
# from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [4]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Directory holding the CIFAR-10 files (fetched there if absent, since download=True).
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
data_loc='/mnt/nfs/work1/amir/vshejwalkar/cifar10_data/'

# Per-channel constants handed to Normalize, applied after ToTensor.
_channel_mean = (0.4914, 0.4822, 0.4465)
_channel_std = (0.2023, 0.1994, 0.2010)

# One pipeline is shared by both splits; it contains no augmentation step.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

# Arguments common to the train and test dataset objects.
_cifar10_kwargs = dict(root=data_loc, download=True, transform=train_transform)

# load the train split, then the test split
cifar10_train = datasets.CIFAR10(train=True, **_cifar10_kwargs)

cifar10_test = datasets.CIFAR10(train=False, **_cifar10_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Pool the CIFAR-10 train and test splits (train samples first) into two
# parallel arrays: X holds the transformed images, Y the labels.
X=[]
Y=[]
for split in (cifar10_train, cifar10_test):
    for i in range(len(split)):
        # Unpack once per sample instead of indexing the dataset twice.
        image, label = split[i]
        X.append(image.numpy())
        Y.append(label)

X=np.array(X)
Y=np.array(Y)

print('total data len: ',len(X))

# A single permutation is cached on disk so that every run sees the same
# ordering. It is created on first use and reloaded afterwards.
if not os.path.isfile('./cifar10_shuffle.pkl'):
    all_indices = np.arange(len(X))
    np.random.shuffle(all_indices)
    with open('./cifar10_shuffle.pkl','wb') as shuffle_file:
        pickle.dump(all_indices, shuffle_file)
else:
    # NOTE: unpickling can execute arbitrary code; only load a shuffle file
    # that this notebook produced.
    with open('./cifar10_shuffle.pkl','rb') as shuffle_file:
        all_indices=pickle.load(shuffle_file)

# A stale cache of a different length would silently change the data pool.
assert len(all_indices) == len(X), 'cached shuffle has %d entries but data has %d'%(len(all_indices), len(X))

X=X[all_indices]
Y=Y[all_indices]

total data len:  60000


In [6]:
# data loading
#
# Carve the already-shuffled pool (X, Y) into consecutive train / validation /
# test ranges, then cut the train range into `nusers` equal per-user shards.

nusers=50
user_tr_len=1000

total_tr_len=user_tr_len*nusers
val_len=5000
te_len=5000

print('total data len: ',len(X))

# The shuffle permutation was already loaded and applied to X and Y when they
# were built, so the load/create logic is not repeated here.

# Fail loudly instead of silently producing short splits.
assert total_tr_len+val_len+te_len <= len(X), 'requested splits (%d) exceed available data (%d)'%(total_tr_len+val_len+te_len, len(X))

total_tr_data=X[:total_tr_len]
total_tr_label=Y[:total_tr_len]

val_data=X[total_tr_len:(total_tr_len+val_len)]
val_label=Y[total_tr_len:(total_tr_len+val_len)]

te_data=X[(total_tr_len+val_len):(total_tr_len+val_len+te_len)]
te_label=Y[(total_tr_len+val_len):(total_tr_len+val_len+te_len)]

# Images become float tensors, labels become long tensors.
total_tr_data_tensor=torch.from_numpy(total_tr_data).type(torch.FloatTensor)
total_tr_label_tensor=torch.from_numpy(total_tr_label).type(torch.LongTensor)

val_data_tensor=torch.from_numpy(val_data).type(torch.FloatTensor)
val_label_tensor=torch.from_numpy(val_label).type(torch.LongTensor)

te_data_tensor=torch.from_numpy(te_data).type(torch.FloatTensor)
te_label_tensor=torch.from_numpy(te_label).type(torch.LongTensor)

print('total tr len %d | val len %d | test len %d'%(len(total_tr_data_tensor),len(val_data_tensor),len(te_data_tensor)))

#==============================================================================================================

# Per-user shards: user i owns training rows [i*user_tr_len, (i+1)*user_tr_len).
user_tr_data_tensors=[]
user_tr_label_tensors=[]

for i in range(nusers):
    shard = slice(user_tr_len*i, user_tr_len*(i+1))

    user_tr_data_tensor=torch.from_numpy(total_tr_data[shard]).type(torch.FloatTensor)
    user_tr_label_tensor=torch.from_numpy(total_tr_label[shard]).type(torch.LongTensor)

    user_tr_data_tensors.append(user_tr_data_tensor)
    user_tr_label_tensors.append(user_tr_label_tensor)
    print('user %d tr len %d'%(i,len(user_tr_data_tensor)))

total data len:  60000
total tr len 50000 | val len 5000 | test len 5000
user 0 tr len 1000
user 1 tr len 1000
user 2 tr len 1000
user 3 tr len 1000
user 4 tr len 1000
user 5 tr len 1000
user 6 tr len 1000
user 7 tr len 1000
user 8 tr len 1000
user 9 tr len 1000
user 10 tr len 1000
user 11 tr len 1000
user 12 tr len 1000
user 13 tr len 1000
user 14 tr len 1000
user 15 tr len 1000
user 16 tr len 1000
user 17 tr len 1000
user 18 tr len 1000
user 19 tr len 1000
user 20 tr len 1000
user 21 tr len 1000
user 22 tr len 1000
user 23 tr len 1000
user 24 tr len 1000
user 25 tr len 1000
user 26 tr len 1000
user 27 tr len 1000
user 28 tr len 1000
user 29 tr len 1000
user 30 tr len 1000
user 31 tr len 1000
user 32 tr len 1000
user 33 tr len 1000
user 34 tr len 1000
user 35 tr len 1000
user 36 tr len 1000
user 37 tr len 1000
user 38 tr len 1000
user 39 tr len 1000
user 40 tr len 1000
user 41 tr len 1000
user 42 tr len 1000
user 43 tr len 1000
user 44 tr len 1000
user 45 tr len 1000
user 46 tr len 10

In [7]:
def lie_attack(all_updates, z):
    """Return the coordinate-wise mean of the updates shifted by ``z`` standard deviations.

    Args:
        all_updates: tensor whose first dimension indexes the individual
            updates; statistics are taken along that dimension.
        z: scalar multiplier applied to the per-coordinate standard deviation.

    Returns:
        Tensor equal to ``mean + z * std``, both computed over dimension 0.
    """
    center = all_updates.mean(dim=0)
    spread = all_updates.std(dim=0)
    return center + z * spread

In [8]:
def dnc(all_updates, n_attackers, num_buckets=1, bucket=100000):
    """Aggregate updates after filtering rows that project strongly on the top singular direction.

    For each of ``num_buckets`` rounds, a random subset of ``bucket`` coordinates
    is drawn (via the global NumPy RNG), the sub-sampled updates are centered,
    and every row is scored by its squared projection onto the first right
    singular vector of the centered matrix. The ``n - int(1.5 * n_attackers)``
    rows with the lowest scores are kept. Rows kept in *every* round are
    averaged over all coordinates.

    Args:
        all_updates: 2-D tensor of shape ``(n, d)``, one flattened update per row.
        n_attackers: assumed number of malicious rows; controls how many rows
            are discarded per round.
        num_buckets: number of independent coordinate sub-samples (>= 1).
        bucket: number of coordinates per sub-sample; clamped to ``d`` so that
            a value larger than the update dimension no longer fails.

    Returns:
        ``(aggregate, final_idx)``: the mean of the surviving rows (a tensor
        over all ``d`` coordinates) and a NumPy int64 array with the surviving
        row indices in ascending order.

    Raises:
        ValueError: if ``num_buckets < 1`` or no row would be kept per round.
        RuntimeError: if no row survives the intersection across rounds.
    """
    n, d = all_updates.shape

    if num_buckets < 1:
        raise ValueError('num_buckets must be >= 1, got %d' % num_buckets)

    n_keep = n - int(1.5 * n_attackers)
    if n_keep <= 0:
        raise ValueError('n_attackers=%d leaves no update to keep out of %d' % (n_attackers, n))

    # Sampling without replacement cannot draw more coordinates than exist.
    bucket = min(bucket, d)

    final_indices = []

    for _ in range(num_buckets):
        idx = np.sort(np.random.choice(d, bucket, replace=False))
        sampled_all_updates = all_updates[:, idx]
        centered_all_updates = sampled_all_updates - torch.mean(sampled_all_updates, 0)

        # Only the right singular vectors are needed for scoring.
        _, _, v = torch.svd(centered_all_updates)

        # Projection of each row on the top singular direction; squaring below
        # makes the score independent of the sign of that direction.
        scores = torch.mm(centered_all_updates, v[:, 0][:, None]).cpu().numpy()

        final_indices.append(list(np.argsort(scores[:, 0] ** 2)[:n_keep]))

    # Keep only the rows selected in every round.
    result = set(final_indices[0])
    for currSet in final_indices[1:]:
        result.intersection_update(currSet)

    if not result:
        raise RuntimeError('no update was selected by all %d buckets' % num_buckets)

    # Explicit integer dtype and fixed order so the index array is always valid.
    final_idx = np.array(sorted(result), dtype=np.int64)

    return torch.mean(all_updates[final_idx], 0), final_idx

In [9]:
# ---------------------------------------------------------------------------
# Federated SGD with poisoned gradients and DnC aggregation.
#
# Every round: each user contributes the gradient of one mini-batch on the
# current global model, the first `n_attacker` rows are replaced by the
# crafted malicious update, `dnc` aggregates the rows, and the aggregate is
# applied to the global model through `optimizer_fed`.
# ---------------------------------------------------------------------------
batch_size=250
resume=0
nepochs=1200
schedule=[1000]          # rounds at which the learning rate is multiplied by `gamma`
nbatches = user_tr_len//batch_size

gamma=.5
opt = 'sgd'
fed_lr=0.5
criterion=nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

aggregation='dnc'

at_type='lie'
dev_type = 'std'

partial = False
# Multiplier handed to lie_attack, keyed by the number of attackers.
z_values={3:0.69847, 5:0.7054, 8:0.71904, 10:0.72575, 12:0.73891}
n_attackers=[10]

arch='alexnet'
chkpt='./'+aggregation

bucket_size = 10000
num_buckets = 2

for n_attacker in n_attackers:
    candidates = []
    # Seed the CPU generator as well; the CUDA-only calls do not cover it.
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.cuda.manual_seed(0)
    np.random.seed(0)
    torch.cuda.empty_cache()

    epoch_num = 0
    best_global_acc = 0
    best_global_te_acc = 0

    if partial:
        fed_file='partial_chkpt_%s_af_%d.pth.tar' % (at_type, n_attacker)
        fed_best_file='partial_best_%s_af_%d.pth.tar' % (at_type, n_attacker)
    else:
        fed_file='chkpt_%s_af_%d.pth.tar' % (at_type, n_attacker)
        fed_best_file='best_%s_af_%d.pth.tar' % (at_type, n_attacker)

    fed_model, _ = return_model(arch, 0.1, 0.9, parallel=False)
    optimizer_fed = SGD(fed_model.parameters(), lr=fed_lr)

    if resume:
        fed_checkpoint = chkpt+'/'+fed_file
        assert os.path.isfile(fed_checkpoint), 'Error: no user checkpoint at %s'%(fed_checkpoint)
        checkpoint = torch.load(fed_checkpoint, map_location='cuda:%d'%torch.cuda.current_device())
        fed_model.load_state_dict(checkpoint['state_dict'])
        optimizer_fed.load_state_dict(checkpoint['optimizer'])
        resume = 0
        # Checkpoints written by this cell store 'best_global_acc' and
        # 'best_global_te_acc'; the shorter names are accepted as a fallback.
        best_global_acc=checkpoint['best_global_acc'] if 'best_global_acc' in checkpoint else checkpoint['best_acc']
        best_global_te_acc=checkpoint['best_global_te_acc'] if 'best_global_te_acc' in checkpoint else checkpoint['best_te_acc']
        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        epoch_num += checkpoint['epoch']
        print('resuming from epoch %d | val acc %.4f | best acc %.3f | best te acc %.3f'%(epoch_num, val_acc, best_global_acc, best_global_te_acc))

    r=np.arange(user_tr_len)

    # NOTE(review): `<=` makes this run rounds 0..nepochs inclusive, while the
    # logging below treats `nepochs-1` as the last round. Left unchanged.
    while epoch_num <= nepochs:
        # NOTE(review): `not epoch_num` limits this to round 0, so each user's
        # data is permuted exactly once. If a reshuffle every `nbatches` rounds
        # was intended, the test should be `epoch_num % nbatches == 0`. Left
        # unchanged to keep the training trajectory.
        if not epoch_num and epoch_num%nbatches == 0:
            np.random.shuffle(r)
            for i in range(nusers):
                user_tr_data_tensors[i]=user_tr_data_tensors[i][r]
                user_tr_label_tensors[i]=user_tr_label_tensors[i][r]

        # All users train on the same batch position within their own shard.
        batch_start = (epoch_num%nbatches)*batch_size
        batch_end = batch_start + batch_size

        per_user_grads=[]
        for i in range(nusers):

            inputs = user_tr_data_tensors[i][batch_start:batch_end]
            targets = user_tr_label_tensors[i][batch_start:batch_end]

            inputs, targets = inputs.cuda(), targets.cuda()

            outputs = fed_model(inputs)
            loss = criterion(outputs, targets)
            fed_model.zero_grad()
            # The graph is rebuilt by the next forward pass, so it is not retained.
            loss.backward()

            # Flatten all parameter gradients into one vector; torch.cat copies,
            # so the row is unaffected by the next user's backward pass.
            per_user_grads.append(torch.cat([param.grad.data.view(-1) for param in fed_model.parameters()]))

        # One row per user.
        user_grads = torch.stack(per_user_grads)
        del per_user_grads

        if epoch_num in schedule:
            for param_group in optimizer_fed.param_groups:
                param_group['lr'] *= gamma
                print('New learnin rate ', param_group['lr'])

        # Start from the benign gradients so that `malicious_grads` is always
        # defined: it is the input of the non-LIE attacks below and is what
        # gets aggregated when there are no attackers.
        malicious_grads = user_grads

        if n_attacker > 0:
            mal_update = None
            if at_type == 'lie':
                if partial:
                    mal_update = lie_attack(user_grads[:n_attacker], z_values[n_attacker])
                else:
                    mal_update = lie_attack(user_grads, z_values[n_attacker])
            elif at_type == 'fang':
                if aggregation == 'trmean' or aggregation == 'median':
                    if partial:
                        mal_update = fang_attack_trmean_partial(malicious_grads, n_attacker)
                    else:
                        mal_update = fang_attack_trmean(malicious_grads, n_attacker)
                else:
                    mal_update = fang_attack_krum(malicious_grads, n_attacker)
            elif at_type == 'our_agr':
                if aggregation == 'krum':
                    mal_update = our_attack_krum(malicious_grads, n_attacker, dev_type, threshold=2.0, threshold_diff=1e-7)
                elif aggregation == 'mkrum' or aggregation == 'bulyan':
                    mal_update = our_attack_mkrum(malicious_grads, n_attacker, dev_type, threshold=2.0, threshold_diff=1e-7)
                elif aggregation == 'trmean':
                    mal_update = our_attack_trmean(malicious_grads, n_attacker, dev_type, threshold=50.0, threshold_diff=1e-2)
                elif aggregation == 'median':
                    mal_update = our_attack_median(malicious_grads, n_attacker, dev_type, threshold=50.0, threshold_diff=1e-2)
            elif at_type == 'our_noagr_dist':
                mal_update = our_attack_dist(malicious_grads, n_attacker, dev_type, threshold_diff = 1e-5)
            elif at_type == 'our_noagr_score':
                mal_update = our_attack_score(malicious_grads, n_attacker, dev_type, threshold_diff = 1e-5)

            # Refuse to continue with an undefined (or stale) malicious update.
            if mal_update is None:
                raise ValueError('no attack implemented for at_type=%r with aggregation=%r'%(at_type, aggregation))

            if at_type == 'fang' and (aggregation == 'trmean' or aggregation == 'median'):
                # This attack's result is used directly as the stack of malicious rows.
                mal_updates = mal_update
            else:
                # Every attacker submits the same crafted update.
                mal_updates = torch.stack([mal_update] * n_attacker)

            # The first n_attacker rows are malicious, the rest stay benign.
            malicious_grads = torch.cat((mal_updates, user_grads[n_attacker:]), 0)

        if epoch_num == 0: print('malicious_grads shape is ', malicious_grads.shape)

        agg_grads, our_candidates = dnc(malicious_grads, n_attacker, num_buckets=num_buckets, bucket=bucket_size)

        if n_attacker:
            # Periodically report how many malicious rows survived DnC in each
            # round since the last report, then start a new window.
            if epoch_num > 0 and (epoch_num%50==0 or epoch_num == (nepochs-1)):
                try:
                    print('number of malicious grads chosen are ', np.array(candidates).reshape(5, 10))
                except ValueError:
                    # The window does not hold exactly 50 entries; print it flat.
                    print('number of malicious grads chosen are ', np.array(candidates))
                candidates = []
            # Rows with index < n_attacker are the malicious ones.
            candidates.append(np.sum(our_candidates < n_attacker))

        del user_grads

        start_idx=0

        optimizer_fed.zero_grad()

        model_grads=[]

        # Cut the flat aggregate back into tensors shaped like each parameter.
        for i, param in enumerate(fed_model.parameters()):
            n_elements=param.data.numel()
            param_=agg_grads[start_idx:start_idx+n_elements].reshape(param.data.shape)
            start_idx=start_idx+n_elements
            param_=param_.cuda()
            model_grads.append(param_)

        optimizer_fed.step(model_grads)

        val_loss, val_acc = test(val_data_tensor,val_label_tensor,fed_model,criterion,use_cuda)
        te_loss, te_acc = test(te_data_tensor,te_label_tensor, fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc

        best_global_acc = max(best_global_acc, val_acc)

        # Test accuracy is recorded at the round with the best validation accuracy.
        if is_best:
            best_global_te_acc = te_acc

        save_checkpoint_global(
            {
                'epoch': epoch_num,
                'state_dict': fed_model.state_dict(),
                'best_global_acc': best_global_acc,
                'best_global_te_acc': best_global_te_acc,
                'optimizer': optimizer_fed.state_dict()
            },
            is_best,
            checkpoint=chkpt,
            filename=fed_file,
            best_filename=fed_best_file,
        )

        if epoch_num%10==0 or epoch_num==nepochs-1:
            # .get keeps logging alive for attacker counts without a z entry.
            print('%s: at %s n_at %d z %.2f e %d fed_model val loss %.4f val acc %.4f best val_acc %f te_acc %f'%(aggregation, at_type, n_attacker, z_values.get(n_attacker, float('nan')), epoch_num, val_loss, val_acc, best_global_acc,best_global_te_acc))

        # Stop this run once the validation loss exceeds 10.
        if val_loss > 10:
            print('val loss %f too high'%val_loss)
            break

        epoch_num+=1

malicious_grads shape is  torch.Size([50, 2472266])
bulyan: at lie n_at 10 z 0.73 e 0 fed_model val loss 2.3031 val acc 12.7435 best val_acc 12.743506 te_acc 12.297078
bulyan: at lie n_at 10 z 0.73 e 10 fed_model val loss 2.2984 val acc 15.7873 best val_acc 15.990260 te_acc 16.761364
bulyan: at lie n_at 10 z 0.73 e 20 fed_model val loss 2.2824 val acc 22.1997 best val_acc 22.199675 te_acc 20.921266
bulyan: at lie n_at 10 z 0.73 e 30 fed_model val loss 2.2231 val acc 18.5471 best val_acc 22.199675 te_acc 20.921266
bulyan: at lie n_at 10 z 0.73 e 40 fed_model val loss 2.2184 val acc 14.4075 best val_acc 22.463474 te_acc 22.260552
number of malicious grads chosen are  [[10 10 10 10 10 10 10  0 10  5]
 [10  8 10 10 10 10 10 10 10  6]
 [10 10 10 10 10 10 10 10 10 10]
 [10  9 10  8  8 10 10 10 10 10]
 [10 10 10  7 10 10 10 10 10 10]]
bulyan: at lie n_at 10 z 0.73 e 50 fed_model val loss 2.1084 val acc 22.4635 best val_acc 25.426136 te_acc 23.863636
bulyan: at lie n_at 10 z 0.73 e 60 fed_mode