# FLDetector for MNIST


In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

  from IPython.core.display import display, HTML


In [2]:
import collections
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
import copy
import time
from functools import reduce
from torchsummary import summary

import os
import sys
import pickle
sys.path.insert(0,'./utils/')
from logger import *
from eval import *
from misc import *

from sklearn.metrics import roc_auc_score
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict

from SGD import *
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda


In [3]:
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [4]:
def sample_dirichlet_train_data(trainset, no_participants, alpha=0.9, force=False):
        """
            Input: Number of participants and alpha (param for distribution)
            Output: A list of indices denoting data in CIFAR training set.
            Requires: cifar_classes, a preprocessed class-indice dictionary.
            Sample Method: take a uniformly sampled 10-dimension vector as parameters for
            dirichlet distribution to sample number of images in each class.
        """
        if not os.path.exists('./dirichlet_a_%.1f_nusers_%d.pkl'%(alpha, no_participants)) or force:
            print('generating participant indices for alpha %.1f'%alpha)
            np.random.seed(0)
            cifar_classes = {}
            for ind, x in enumerate(trainset):
                _, label = x
                if label in cifar_classes:
                    cifar_classes[label].append(ind)
                else:
                    cifar_classes[label] = [ind]

            per_participant_list = defaultdict(list)
            no_classes = len(cifar_classes.keys())
            for n in range(no_classes):
                random.shuffle(cifar_classes[n])
                sampled_probabilities = len(cifar_classes[n]) * np.random.dirichlet(
                    np.array(no_participants * [alpha]))
                for user in range(no_participants):
                    no_imgs = int(round(sampled_probabilities[user]))
                    sampled_list = cifar_classes[n][:min(len(cifar_classes[n]), no_imgs)]
                    per_participant_list[user].extend(sampled_list)
                    cifar_classes[n] = cifar_classes[n][min(len(cifar_classes[n]), no_imgs):]
            with open('./dirichlet_a_%.1f_nusers_%d.pkl'%(alpha, no_participants), 'wb') as f:
                pickle.dump(per_participant_list, f)
        else:
            per_participant_list = pickle.load(open('./dirichlet_a_%.1f_nusers_%d.pkl'%(alpha, no_participants), 'rb'))
            
        return per_participant_list

In [5]:
def get_client_train_data(trainset, num_workers=100, bias=0.5):
    bias_weight = bias
    other_group_size = (1 - bias_weight) / 9.
    worker_per_group = num_workers / 10

    each_worker_data = [[] for _ in range(num_workers)]
    each_worker_label = [[] for _ in range(num_workers)]
    
    for i, (x, y) in enumerate(trainset):
        # assign a data point to a group
        upper_bound = (y) * (1 - bias_weight) / 9. + bias_weight
        lower_bound = (y) * (1 - bias_weight) / 9.
        rd = np.random.random_sample()

        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y

        rd = np.random.random_sample()
        selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))
        
        if not len(each_worker_data[selected_worker]):
            each_worker_data[selected_worker] = x[None, :]
        else:
            each_worker_data[selected_worker]= torch.concat((each_worker_data[selected_worker], x[None, :]))
        
        each_worker_label[selected_worker].append(y)
    
    each_worker_tr_data = [[] for _ in range(num_workers)]
    each_worker_tr_label = [[] for _ in range(num_workers)]
    each_worker_te_data = [[] for _ in range(num_workers)]
    each_worker_te_label = [[] for _ in range(num_workers)]
    
    np.random.seed(0)
    for i in range(num_workers):
        w_len = len(each_worker_data[i])
        len_tr = int(6 * w_len / 7)
        len_te = w_len - len_tr
        tr_idx = np.random.choice(w_len, len_tr, replace=False)
        te_idx = np.delete(np.arange(w_len), tr_idx)
        each_worker_tr_data[i] = each_worker_data[i][tr_idx]
        each_worker_tr_label[i] = torch.Tensor(each_worker_label[i])[tr_idx]
        
        each_worker_te_data[i] = each_worker_data[i][te_idx]
        each_worker_te_label[i] = torch.Tensor(each_worker_label[i])[te_idx]
        
    global_test_data = torch.concat(each_worker_te_data)
    global_test_label = torch.concat(each_worker_te_label)
    del each_worker_data, each_worker_label
    return each_worker_tr_data, each_worker_tr_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label

In [6]:
def get_client_data_dirichlet(trainset, num_workers, alpha=1, force=False):
    per_participant_list = sample_dirichlet_train_data(trainset, num_workers, alpha=alpha, force=force)
    
    each_worker_data = [[] for _ in range(num_workers)]
    each_worker_label = [[] for _ in range(num_workers)]
    
    each_worker_te_data = [[] for _ in range(num_workers)]
    each_worker_te_label = [[] for _ in range(num_workers)]
    
    np.random.seed(0)
    
    for worker_idx in range(len(per_participant_list)):
        w_indices = np.array(per_participant_list[worker_idx])
        w_len = len(w_indices)
        len_tr = int(6*w_len/7)
        len_te = w_len - len_tr
        tr_idx = np.random.choice(w_len, len_tr, replace=False)
        te_idx = np.delete(np.arange(w_len), tr_idx)
        
        for idx in tr_idx:
            each_worker_data[worker_idx].append(trainset[idx][0])
            each_worker_label[worker_idx].append(trainset[idx][1])
        each_worker_data[worker_idx] = torch.stack(each_worker_data[worker_idx])
        each_worker_label[worker_idx] = torch.Tensor(each_worker_label[worker_idx]).long()
        
        for idx in te_idx:
            each_worker_te_data[worker_idx].append(trainset[idx][0])
            each_worker_te_label[worker_idx].append(trainset[idx][1])
        each_worker_te_data[worker_idx] = torch.stack(each_worker_te_data[worker_idx])
        each_worker_te_label[worker_idx] = torch.Tensor(each_worker_te_label[worker_idx]).long()
    
    global_test_data = torch.concat(each_worker_te_data)
    global_test_label = torch.concat(each_worker_te_label)
    
    return each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label


In [7]:
#Trim attack
def full_trim(v, f):
    '''
    Full-knowledge Trim attack. w.l.o.g., we assume the first f worker devices are compromised.
    v: the list of squeezed gradients
    f: the number of compromised worker devices
    '''
    vi_shape = v[0].unsqueeze(0).T.shape
    v_tran = v.T
    
    maximum_dim = torch.max(v_tran, dim=1)
    maximum_dim = maximum_dim[0].reshape(vi_shape)
    minimum_dim = torch.min(v_tran, dim=1)
    minimum_dim = minimum_dim[0].reshape(vi_shape)
    direction = torch.sign(torch.sum(v_tran, dim=-1, keepdims=True))
    directed_dim = (direction > 0) * minimum_dim + (direction < 0) * maximum_dim

    for i in range(f):
        random_12 = 2
        tmp = directed_dim * ((direction * directed_dim > 0) / random_12 + (direction * directed_dim < 0) * random_12)
        tmp = tmp.squeeze()
        v[i] = tmp
    return v

In [8]:
#NDSS attack
def our_attack_median(all_updates, n_attackers, dev_type='unit_vec'):
    model_re = (all_updates)
    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)
    lamda = torch.Tensor([10.0]).cuda() #compute_lambda_our(all_updates, model_re, n_attackers)
    threshold_diff = 1e-5
    prev_loss = -1
    lamda_fail = lamda
    lamda_succ = 0
    iters = 0 
    while torch.abs(lamda_succ - lamda) > threshold_diff:
        mal_update = (model_re - lamda * deviation)
        mal_updates = torch.stack([mal_update] * n_attackers)
        mal_updates = torch.cat((mal_updates, all_updates), 0)

        agg_grads = torch.median(mal_updates, 0)[0]
        
        loss = torch.norm(agg_grads - model_re)
        
        if prev_loss < loss: #attack successful
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:#attack failed
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2
        prev_loss = loss
        
    mal_update = (model_re - lamda_succ * deviation)
    mal_updates = torch.stack([mal_update] * n_attackers)
    mal_updates = torch.cat((mal_updates, all_updates), 0)

    return mal_updates

In [9]:
def tr_mean(all_updates, n_attackers):
    sorted_updates = torch.sort(all_updates, 0)[0]
    out = torch.mean(sorted_updates[n_attackers:-n_attackers], 0) if n_attackers else torch.mean(sorted_updates,0)
    return out

In [10]:
class cnn(nn.Module):
    def __init__(self):
        super(cnn, self).__init__()
        self.conv1 = nn.Conv2d(1, 30, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(30, 50, 5)
        self.pool2 = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(800, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Match baseline with fang distribution and 80 clients

In [11]:
all_data = torch.utils.data.ConcatDataset((trainset, testset))
num_workers = 100
distribution='fang'
param = .5
force = True

if distribution=='fang':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_train_data(all_data, num_workers=100, bias=param)
elif distribution == 'dirichlet':
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=param, force=force)

# Mean

In [22]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
num_workers = 100
nbyz = 20
lr = 0.01
nepochs=1000
best_acc = 0
best_accs_per_round = []
accs_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'
for e in range(nepochs):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []
    for i in range(nbyz, num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_
    agg_grads = torch.mean(user_grads, 0)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)
    cnn_optimizer.step(model_grads)
    total, correct = 0,0
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        acc = correct/total
        best_acc = max(best_acc, acc)
    best_accs_per_round.append(best_acc)
    accs_per_round.append(acc)
    if e%100==0 or e==nepochs-1:
        print(e, acc, best_acc)

final_accs_per_client=[]
for i in range(num_workers):
    total, correct = 0,0
    with torch.no_grad():
        outputs = net(each_worker_te_data[i].cuda())
        _, predicted = torch.max(outputs.data, 1)
        total += len(each_worker_te_label[i])
        correct += (predicted == each_worker_te_label[i].long().cuda()).sum().item()
        acc = correct/total
    final_accs_per_client.append(acc)

results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
)

pickle.dump(results, open(os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_mnist_slow_mean.pkl'), 'wb'), )

0 0.10195141377937077 0.10195141377937077
100 0.7241138988450816 0.7241138988450816
200 0.8746515332536838 0.8746515332536838
300 0.9005376344086021 0.9005376344086021
400 0.9122859418558343 0.9122859418558343
500 0.9217443249701314 0.9217443249701314
600 0.9287136598964556 0.9287136598964556
700 0.9359816806053365 0.9359816806053365
800 0.9430505774591796 0.9430505774591796
900 0.946236559139785 0.9463361210673039
999 0.9495221027479092 0.9495221027479092


In [None]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = 20
lr = 0.01
nepochs=1000
best_acc=0
best_accs_per_round = []
accs_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

for e in range(nepochs):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []
    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())

        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    user_grads = full_trim(user_grads, nbyz)
    agg_grads = tr_mean(user_grads, nbyz)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    total, correct = 0,0
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        acc = correct/total
        best_acc = max(best_acc, acc)
    best_accs_per_round.append(best_acc)
    accs_per_round.append(acc)
    if e%100==0 or e==nepochs-1:
        print(e, acc, best_acc)

final_accs_per_client=[]
for i in range(num_workers):
    total, correct = 0,0
    with torch.no_grad():
        outputs = net(each_worker_te_data[i].cuda())
        _, predicted = torch.max(outputs.data, 1)
        total += len(each_worker_te_label[i])
        correct += (predicted == each_worker_te_label[i].long().cuda()).sum().item()
        acc = correct/total
    final_accs_per_client.append(acc)

results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
)

pickle.dump(results, open(os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_mnist_slow_trmean_trim20.pkl'), 'wb'), )

0 0.13839107925129432 0.13839107925129432
100 0.26473516527279967 0.26473516527279967
200 0.428215850258861 0.428215850258861
300 0.6329151732377539 0.6329151732377539
400 0.7400438072481084 0.8118279569892473
500 0.8461768219832736 0.8634010354440462
600 0.7329749103942652 0.8865989645559538
700 0.7830545599362804 0.8951612903225806
800 0.8276583034647551 0.9036240541616886
900 0.9074074074074074 0.9175627240143369


# Median

In [20]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
num_workers = 100
nbyz = 20
lr = 0.01
nepochs=1000
best_acc = 0
best_accs_per_round = []
accs_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

for e in range(nepochs):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []
    for i in range(nbyz, num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    agg_grads = torch.median(user_grads, 0)[0]
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)
    cnn_optimizer.step(model_grads)
    total, correct = 0,0
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        acc = correct/total
        best_acc = max(best_acc, acc)
    best_accs_per_round.append(best_acc)
    accs_per_round.append(acc)
    if e%100==0 or e==nepochs-1:
        print(e, acc, best_acc)

final_accs_per_client=[]
for i in range(num_workers):
    total, correct = 0,0
    with torch.no_grad():
        outputs = net(each_worker_te_data[i].cuda())
        _, predicted = torch.max(outputs.data, 1)
        total += len(each_worker_te_label[i])
        correct += (predicted == each_worker_te_label[i].long().cuda()).sum().item()
        acc = correct/total
    final_accs_per_client.append(acc)

results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
)

pickle.dump(results, open(os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_mnist_slow_median.pkl'), 'wb'), )

0 0.11927518916766229 0.11927518916766229
100 0.7697132616487455 0.7697132616487455
200 0.7053962564715253 0.7697132616487455
300 0.7737953006770211 0.8020708880923935
400 0.8378136200716846 0.8378136200716846
500 0.824173636001593 0.8472720031859817
600 0.8395061728395061 0.856431700517722
700 0.8619076065312624 0.8767423337315811
800 0.8738550378335325 0.8890880127439267
900 0.8368180007964954 0.9168657905217045
999 0.9414575866188769 0.9445440063719633


In [24]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = 20
lr = 0.01
nepochs=1000
best_acc=0
best_accs_per_round = []
accs_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

for e in range(nepochs):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []
    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())

        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    # user_grads[:nbyz] = full_trim(user_grads[:nbyz], nbyz)
    user_grads = full_trim(user_grads, nbyz)
    agg_grads = torch.median(user_grads, 0)[0]
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    total, correct = 0,0
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        acc = correct/total
        best_acc = max(best_acc, acc)
    best_accs_per_round.append(best_acc)
    accs_per_round.append(acc)
    if e%100==0 or e==nepochs-1:
        print(e, acc, best_acc)

final_accs_per_client=[]
for i in range(num_workers):
    total, correct = 0,0
    with torch.no_grad():
        outputs = net(each_worker_te_data[i].cuda())
        _, predicted = torch.max(outputs.data, 1)
        total += len(each_worker_te_label[i])
        correct += (predicted == each_worker_te_label[i].long().cuda()).sum().item()
        acc = correct/total
    final_accs_per_client.append(acc)

results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
)

pickle.dump(results, open(os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_mnist_slow_median_trim20.pkl'), 'wb'), )

100 0.1163878932696137 0.24870569494225409
200 0.008960573476702509 0.24870569494225409
300 0.17293906810035842 0.24870569494225409
400 0.11031461569095978 0.24870569494225409
500 0.1424731182795699 0.24870569494225409
600 0.2670250896057348 0.2670250896057348
700 0.4345878136200717 0.4369772998805257
800 0.4911389884508164 0.49193548387096775
900 0.6145957785742732 0.7382516925527678
999 0.7387495021903624 0.8428912783751493


# Good FedAvg baseline MNIST + Fang distribution + 80 clients

In [12]:
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def train(train_data, labels, model, optimizer, batch_size=20):
    model.train()
    criterion = nn.CrossEntropyLoss()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    len_t = (len(train_data) // batch_size)
    if len(train_data)%batch_size:
        len_t += 1

    r=np.arange(len(train_data))
    np.random.shuffle(r)
    
    train_data = train_data[r]
    labels = labels[r]
    
    for ind in range(len_t):

        inputs = train_data[ind * batch_size:(ind + 1) * batch_size]
        targets = labels[ind * batch_size:(ind + 1) * batch_size]

        inputs, targets = inputs.cuda(), targets.cuda()

        inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return (losses.avg, top1.avg)


def test(test_data, labels, model, criterion, use_cuda, debug_='MEDIUM', batch_size=64):
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    len_t = (len(test_data) // batch_size)
    if len(test_data)%batch_size:
        len_t += 1

    with torch.no_grad():
        for ind in range(len_t):
            # measure data loading time
            inputs = test_data[ind * batch_size:(ind + 1) * batch_size]
            targets = labels[ind * batch_size:(ind + 1) * batch_size]

            inputs, targets = inputs.cuda(), targets.cuda()

            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

    return (losses.avg, top1.avg)

# Mean without any attack

In [29]:
torch.cuda.empty_cache()
local_epochs = 2
batch_size = 16
num_workers = 100
local_lr = 0.01
global_lr = 1
nepochs = 50
nbyz = 20

use_cuda=True
criterion = nn.CrossEntropyLoss()
resume=False
round_nclients = num_workers
best_global_acc=0
epoch_num = 0
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
fed_model = cnn().to(device)
fed_model.apply(init_weights)
model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(nbyz, num_workers)
    round_benign = round_clients
    user_updates=[]
    
    for i in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr, momentum=0.9)
        for epoch in range(local_epochs):
            train_loss, train_acc = train(
                each_worker_data[i], torch.Tensor(each_worker_label[i]).long(), model, optimizer, batch_size)
        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat(
                (params, param.view(-1).data.type(torch.cuda.FloatTensor)))
        update =  (params - model_received)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)
    agg_update = torch.mean(user_updates, 0)
    del user_updates
    #update the global model
    model_received = model_received + global_lr * agg_update
    fed_model = cnn().to(device)
    fed_model.apply(init_weights)
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name
    #load the state dictionary back into the global model
    fed_model.load_state_dict(state_dict)
    val_loss, val_acc = test(global_test_data, global_test_label.long(), fed_model, criterion, use_cuda)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    best_accs_per_round.append(best_global_acc)
    accs_per_round.append(val_acc)
    loss_per_round.append(val_loss)
    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, val_loss, val_acc, best_global_acc))
    epoch_num+=1

final_accs_per_client=[]
for i in range(num_workers):
    client_loss, client_acc = test(each_worker_te_data[i], each_worker_te_label[i].long(),
                                   fed_model, criterion, use_cuda)
    final_accs_per_client.append(client_acc)
results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
    loss_per_round=np.array(loss_per_round)
)
pickle.dump(results, open(os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_mnist_fast_mean.pkl'), 'wb'))

e 0 val loss 0.902 val acc 85.872 best val_acc 85.872
e 10 val loss 0.079 val acc 97.601 best val_acc 97.601
e 20 val loss 0.059 val acc 98.208 best val_acc 98.208
e 30 val loss 0.052 val acc 98.457 best val_acc 98.467
e 40 val loss 0.048 val acc 98.616 best val_acc 98.616
e 49 val loss 0.045 val acc 98.686 best val_acc 98.686
e 50 val loss 0.045 val acc 98.686 best val_acc 98.686


# Mean under full trim attack

In [30]:
torch.cuda.empty_cache()
local_epochs = 2
batch_size = 16
num_workers = 100
local_lr = 0.01
global_lr = 1
nepochs = 50
nbyz = 20

use_cuda=True
criterion = nn.CrossEntropyLoss()
resume=False
round_nclients = num_workers
best_global_acc=0
epoch_num = 0
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
fed_model = cnn().to(device)
fed_model.apply(init_weights)
model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    user_updates=[]
    
    for i in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr, momentum=0.9)
        for epoch in range(local_epochs):
            train_loss, train_acc = train(
                each_worker_data[i], torch.Tensor(each_worker_label[i]).long(), model, optimizer, batch_size)
        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat(
                (params, param.view(-1).data.type(torch.cuda.FloatTensor)))
        update =  (params - model_received)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)
    
    user_updates = full_trim(user_updates, nbyz)
    agg_update = tr_mean(user_updates, nbyz)
    del user_updates
    model_received = model_received + global_lr * agg_update
    fed_model = cnn().to(device)
    fed_model.apply(init_weights)
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name
    fed_model.load_state_dict(state_dict)
    val_loss, val_acc = test(global_test_data, global_test_label.long(), fed_model, criterion, use_cuda)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    best_accs_per_round.append(best_global_acc)
    accs_per_round.append(val_acc)
    loss_per_round.append(val_loss)
    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, val_loss, val_acc, best_global_acc))
    epoch_num+=1

final_accs_per_client=[]
for i in range(num_workers):
    client_loss, client_acc = test(each_worker_te_data[i], each_worker_te_label[i].long(),
                                   fed_model, criterion, use_cuda)
    final_accs_per_client.append(client_acc)
results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
    loss_per_round=np.array(loss_per_round)
)
pickle.dump(results, open(os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_mnist_fast_mean_trmean_trim20.pkl'), 'wb'))

e 0 val loss 1.760 val acc 78.465 best val_acc 78.465
e 10 val loss 0.134 val acc 95.998 best val_acc 95.998
e 20 val loss 0.122 val acc 96.426 best val_acc 96.426
e 30 val loss 0.118 val acc 96.724 best val_acc 96.724
e 40 val loss 0.116 val acc 96.794 best val_acc 96.854
e 49 val loss 0.115 val acc 96.864 best val_acc 96.864
e 50 val loss 0.116 val acc 96.824 best val_acc 96.864


# Median without any attack

In [31]:
torch.cuda.empty_cache()
local_epochs = 2
batch_size = 16
num_workers = 100
local_lr = 0.01
global_lr = 1
nepochs = 50
nbyz = 20

use_cuda=True
criterion = nn.CrossEntropyLoss()
resume=False
round_nclients = num_workers
best_global_acc=0
epoch_num = 0
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
fed_model = cnn().to(device)
fed_model.apply(init_weights)
model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(nbyz, num_workers)
    round_benign = round_clients
    user_updates=[]
    
    for i in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr, momentum=0.9)
        for epoch in range(local_epochs):
            train_loss, train_acc = train(
                each_worker_data[i], torch.Tensor(each_worker_label[i]).long(), model, optimizer, batch_size)
        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat(
                (params, param.view(-1).data.type(torch.cuda.FloatTensor)))
        update =  (params - model_received)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)
    
    agg_update = torch.median(user_updates, 0)[0]
    del user_updates
    model_received = model_received + global_lr * agg_update
    fed_model = cnn().to(device)
    fed_model.apply(init_weights)
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name
    fed_model.load_state_dict(state_dict)
    val_loss, val_acc = test(global_test_data, global_test_label.long(), fed_model, criterion, use_cuda)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    best_accs_per_round.append(best_global_acc)
    accs_per_round.append(val_acc)
    loss_per_round.append(val_loss)
    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, val_loss, val_acc, best_global_acc))
    epoch_num+=1

final_accs_per_client=[]
for i in range(num_workers):
    client_loss, client_acc = test(each_worker_te_data[i], each_worker_te_label[i].long(),
                                   fed_model, criterion, use_cuda)
    final_accs_per_client.append(client_acc)
results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
    loss_per_round=np.array(loss_per_round)
)
pickle.dump(results, open(os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_mnist_fast_median.pkl'), 'wb'))

e 0 val loss 1.133 val acc 86.858 best val_acc 86.858
e 10 val loss 0.082 val acc 97.541 best val_acc 97.541
e 20 val loss 0.062 val acc 98.108 best val_acc 98.108
e 30 val loss 0.054 val acc 98.407 best val_acc 98.407
e 40 val loss 0.050 val acc 98.477 best val_acc 98.507
e 49 val loss 0.047 val acc 98.576 best val_acc 98.576
e 50 val loss 0.047 val acc 98.596 best val_acc 98.596


# Median under full trim attack

In [32]:
torch.cuda.empty_cache()
local_epochs = 2
batch_size = 16
num_workers = 100
local_lr = 0.01
global_lr = 1
nepochs = 50
nbyz = 20

use_cuda=True
criterion = nn.CrossEntropyLoss()
resume=False
round_nclients = num_workers
best_global_acc=0
epoch_num = 0
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
fed_model = cnn().to(device)
fed_model.apply(init_weights)
model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    user_updates=[]
    
    for i in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr, momentum=0.9)
        for epoch in range(local_epochs):
            train_loss, train_acc = train(
                each_worker_data[i], torch.Tensor(each_worker_label[i]).long(), model, optimizer, batch_size)
        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat(
                (params, param.view(-1).data.type(torch.cuda.FloatTensor)))
        update =  (params - model_received)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)
    
    user_updates = full_trim(user_updates, nbyz)
    agg_update = torch.median(user_updates, 0)[0]
    del user_updates
    model_received = model_received + global_lr * agg_update
    fed_model = cnn().to(device)
    fed_model.apply(init_weights)
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name
    fed_model.load_state_dict(state_dict)
    val_loss, val_acc = test(global_test_data, global_test_label.long(), fed_model, criterion, use_cuda)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    best_accs_per_round.append(best_global_acc)
    accs_per_round.append(val_acc)
    loss_per_round.append(val_loss)
    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, val_loss, val_acc, best_global_acc))
    epoch_num+=1

final_accs_per_client=[]
for i in range(num_workers):
    client_loss, client_acc = test(each_worker_te_data[i], each_worker_te_label[i].long(),
                                   fed_model, criterion, use_cuda)
    final_accs_per_client.append(client_acc)
results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
    loss_per_round=np.array(loss_per_round)
)
pickle.dump(results, open(os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_mnist_fast_median_trim20.pkl'), 'wb'))

e 0 val loss 1.735 val acc 76.563 best val_acc 76.563
e 10 val loss 0.126 val acc 96.207 best val_acc 96.207
e 20 val loss 0.109 val acc 96.774 best val_acc 96.774
e 30 val loss 0.103 val acc 97.023 best val_acc 97.023
e 40 val loss 0.101 val acc 97.123 best val_acc 97.143
e 49 val loss 0.099 val acc 97.202 best val_acc 97.232
e 50 val loss 0.099 val acc 97.182 best val_acc 97.232


# ====== Below are previous explorations. Ignore them! ======

# Faster baselines with server momentum

In [None]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = 20
nepochs=100
best_acc=0
lr = 0.1
cnn_optimizer = SGD(net.parameters(), lr = lr, momentum=0.9, weight_decay=1e-5)
for e in range(nepochs):
    user_grads = []
    for i in range(nbyz, num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    agg_grads = torch.mean(user_grads, 0)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    if e%100==0 or e==nepochs-1:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            acc = correct/total
            best_acc = max(best_acc, acc)
        print(e, acc, best_acc)

In [22]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = 20
nepochs=200
best_acc=0
lr = 0.05
cnn_optimizer = SGD(net.parameters(), lr = lr, momentum=0.9, weight_decay=1e-5)
for e in range(nepochs):
    user_grads = []
    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    user_grads[:nbyz] = full_trim(user_grads[:nbyz], nbyz)
    agg_grads = tr_mean(user_grads, nbyz)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    if e%20==0 or e==nepochs-1:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            acc = correct/total
            best_acc = max(best_acc, acc)
        print(e, acc, best_acc)

0 0.18996415770609318 0.18996415770609318
20 0.09259259259259259 0.18996415770609318
40 0.23725607327757867 0.23725607327757867
60 0.09219434488251693 0.23725607327757867
80 0.1942453205894066 0.23725607327757867
100 0.5595380326563122 0.5595380326563122
120 0.09747112704101951 0.5595380326563122
140 0.09747112704101951 0.5595380326563122
160 0.09747112704101951 0.5595380326563122
180 0.09747112704101951 0.5595380326563122
199 0.09747112704101951 0.5595380326563122


In [None]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = 20
nepochs=1000
best_acc=0
lr = 0.01
cnn_optimizer = SGD(net.parameters(), lr = lr, momentum=0.9, weight_decay=1e-5)
for e in range(nepochs):
    user_grads = []
    for i in range(nbyz, num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    agg_grads = torch.median(user_grads, 0)[0]
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    if e%100==0 or e==nepochs-1:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            acc = correct/total
            best_acc = max(best_acc, acc)
        print(e, acc, best_acc)

In [None]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = 20
nepochs=100
best_acc=0
lr = 0.1
cnn_optimizer = SGD(net.parameters(), lr = lr, momentum=0.9, weight_decay=1e-5)
for e in range(nepochs):
    user_grads = []
    for i in range(nbyz, num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())
        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    user_grads[:nbyz] = full_trim(user_grads[:nbyz], nbyz)
    agg_grads = torch.median(user_grads, 0)[0]
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)

    cnn_optimizer.step(model_grads)
    if e%20==0 or e==nepochs-1:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            acc = correct/total
            best_acc = max(best_acc, acc)
        print(e, acc, best_acc)

# Mount trim attack bad baselines

In [37]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = 28
lr = 0.02
for e in range(1000):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []
    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())

        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    user_grads = full_trim(user_grads, nbyz)
    agg_grads = tr_mean(user_grads, nbyz)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)
    cnn_optimizer.step(model_grads)
    if e%50==0 or e==999:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(e,correct/total)

0 0.08076080462059351
50 0.29874526986656047
100 0.38548097988448515
150 0.4051981676956781
200 0.4898426608245369
250 0.6350328619796853
300 0.8146783509261103
350 0.7818163712407887
400 0.807010555666202
450 0.6992630950009958
500 0.7984465245966939
550 0.8383788090021908
600 0.8644692292372037
650 0.9151563433578969
700 0.8799044015136427
750 0.7283409679346744
800 0.9242182832105158
850 0.9305915156343358
900 0.9334793865763792
950 0.9288986257717586
999 0.933379804819757


In [44]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = 20
lr = 0.02
for e in range(1000):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []
    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())

        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    user_grads = full_trim(user_grads, nbyz)
    agg_grads = tr_mean(user_grads, nbyz)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)
    cnn_optimizer.step(model_grads)
    if e%50==0 or e==999:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(e,correct/total)

0 0.10250796178343949
50 0.2483081210191083
100 0.3976910828025478
150 0.5472730891719745
200 0.7301950636942676
250 0.8241441082802548
300 0.7927945859872612
350 0.794187898089172
400 0.8379777070063694
450 0.8304140127388535
500 0.8810708598726115
550 0.8898288216560509
600 0.9166003184713376
650 0.9093351910828026
700 0.9161027070063694
750 0.9248606687898089
800 0.8802746815286624
850 0.9094347133757962
900 0.9301353503184714
950 0.9267515923566879
999 0.929140127388535


In [45]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = 20
lr = 0.01
for e in range(1000):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []
    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())

        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    user_grads = full_trim(user_grads, nbyz)
    agg_grads = tr_mean(user_grads, nbyz)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)
    cnn_optimizer.step(model_grads)
    if e%50==0 or e==999:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(e,correct/total)

0 0.1506767515923567
50 0.11275875796178345
100 0.3004578025477707
150 0.34156050955414013
200 0.40017914012738853
250 0.43580812101910826
300 0.4590963375796178
350 0.4823845541401274
400 0.6436106687898089
450 0.7318869426751592
500 0.6098726114649682
550 0.7991640127388535
600 0.7964769108280255
650 0.7870222929936306
700 0.7532842356687898
750 0.800656847133758
800 0.8605692675159236
850 0.7570660828025477
900 0.897890127388535
950 0.9058519108280255
999 0.8165804140127388


In [38]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
net = cnn().to(device)
net.apply(init_weights)
nbyz = 28
lr = 0.02
for e in range(1000):
    cnn_optimizer = SGD(net.parameters(), lr = lr)
    user_grads = []
    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].cuda())
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().cuda())

        loss.backward(retain_graph = True)
        param_grad=[]
        for param in net_.parameters():
            param_grad=param.grad.data.view(-1) if not len(param_grad) else torch.cat((param_grad,param.grad.view(-1)))
        user_grads=param_grad[None, :] if len(user_grads)==0 else torch.cat((user_grads,param_grad[None,:]), 0)
        del net_

    user_grads = full_trim(user_grads, nbyz)
    agg_grads = torch.median(user_grads, 0)[0]
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    for i, param in enumerate(net.parameters()):
        param_=agg_grads[start_idx:start_idx+len(param.data.view(-1))].reshape(param.data.shape)
        start_idx=start_idx+len(param.data.view(-1))
        param_=param_.cuda()
        model_grads.append(param_)
    cnn_optimizer.step(model_grads)
    if e%50==0 or e==999:
        total, correct = 0,0
        with torch.no_grad():
            inputs, labels = global_test_data.to(device), global_test_label.to(device)
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(e,correct/total)

0 0.09659430392352121
50 0.0898227444732125
100 0.02101175064728142
150 0.12557259510057758
200 0.14280023899621588
250 0.2940649273053177
300 0.40818562039434375
350 0.5251941844254132
400 0.5418243377813184
450 0.6308504282015535
500 0.7849034056960765
550 0.8366859191396137
600 0.8732324238199561
650 0.8091017725552678
700 0.8170683130850428
750 0.8386775542720574
800 0.8091017725552678
850 0.7159928301135232
900 0.9263095000995818
950 0.9284007169886477
999 0.9061939852619


# attack good baselines

In [40]:
torch.cuda.empty_cache()
nepochs=2000
local_epochs = 2
batch_size = 8
num_workers = 100
local_lr = 0.05
global_lr = 1
nepochs = 50
nbyz = 28
byz_type = 'full_trim'
aggregation = 'trim'

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

resume=False
round_nclients = num_workers
best_global_acc=0
epoch_num = 0
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
fed_model = cnn().to(device)
net.apply(init_weights)

model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(nbyz, num_workers)
    round_benign = round_clients
    user_updates=[]
    benign_norm = 0
    
    for i in round_benign:
        model = copy.deepcopy(fed_model)
        # optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.99**epoch_num))
        optimizer = optim.SGD(model.parameters(), lr = lr)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(
                each_worker_data[i], torch.Tensor(each_worker_label[i]).long(), model, optimizer, batch_size)

        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat(
                (params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        update =  (params - model_received)
        benign_norm += torch.norm(update)/len(round_benign)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)
    
    user_updates = full_trim(user_updates, nbyz)
    agg_update = tr_mean(user_updates, nbyz)
    # agg_update = torch.mean(user_updates, 0)

    del user_updates
    model_received = model_received + global_lr * agg_update
    fed_model = cnn().to(device)
    net.apply(init_weights)
    
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name

    fed_model.load_state_dict(state_dict)

    if epoch_num%5==0 or epoch_num==nepochs-1:
        val_loss, val_acc = test(global_test_data, global_test_label.long(), fed_model, criterion, use_cuda)
        is_best = best_global_acc < val_acc
        best_global_acc = max(best_global_acc, val_acc)
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))

    epoch_num+=1

e 0 benign_norm 1.284 val loss 2.157 val acc 50.827 best val_acc 50.827
e 5 benign_norm 0.698 val loss 0.288 val acc 91.546 best val_acc 91.546
e 10 benign_norm 0.639 val loss 0.227 val acc 93.129 best val_acc 93.129
e 15 benign_norm 0.632 val loss 0.214 val acc 93.467 best val_acc 93.467
e 20 benign_norm 0.636 val loss 0.208 val acc 93.736 best val_acc 93.736
e 25 benign_norm 0.645 val loss 0.205 val acc 93.916 best val_acc 93.916
e 30 benign_norm 0.655 val loss 0.207 val acc 93.926 best val_acc 93.926
e 35 benign_norm 0.660 val loss 0.204 val acc 94.155 best val_acc 94.155
e 40 benign_norm 0.672 val loss 0.203 val acc 94.264 best val_acc 94.264
e 45 benign_norm 0.681 val loss 0.203 val acc 94.344 best val_acc 94.344
e 49 benign_norm 0.681 val loss 0.192 val acc 94.812 best val_acc 94.812
e 50 benign_norm 0.689 val loss 0.193 val acc 94.692 best val_acc 94.812


In [42]:
torch.cuda.empty_cache()
nepochs=2000
local_epochs = 2
batch_size = 8
num_workers = 100
local_lr = 0.05
global_lr = 1
nepochs = 50
nbyz = 28
byz_type = 'full_trim'
aggregation = 'trim'

criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

resume=False
round_nclients = num_workers
best_global_acc=0
epoch_num = 0
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
fed_model = cnn().to(device)
net.apply(init_weights)

model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(nbyz, num_workers)
    round_benign = round_clients
    user_updates=[]
    benign_norm = 0
    
    for i in round_benign:
        model = copy.deepcopy(fed_model)
        # optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.99**epoch_num))
        optimizer = optim.SGD(model.parameters(), lr = lr)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(
                each_worker_data[i], torch.Tensor(each_worker_label[i]).long(), model, optimizer, batch_size)

        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat(
                (params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        update =  (params - model_received)
        benign_norm += torch.norm(update)/len(round_benign)
        user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)
    
    user_updates = full_trim(user_updates, nbyz)
    agg_update = torch.median(user_updates, 0)[0]

    del user_updates

    model_received = model_received + global_lr * agg_update
    fed_model = cnn().to(device)
    net.apply(init_weights)
    
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name

    fed_model.load_state_dict(state_dict)

    if epoch_num%5==0 or epoch_num==nepochs-1:
        val_loss, val_acc = test(global_test_data, global_test_label.long(), fed_model, criterion, use_cuda)
        is_best = best_global_acc < val_acc
        best_global_acc = max(best_global_acc, val_acc)
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))

    epoch_num+=1

e 0 benign_norm 1.295 val loss 2.157 val acc 60.795 best val_acc 60.795
e 5 benign_norm 0.706 val loss 0.305 val acc 91.316 best val_acc 91.316
e 10 benign_norm 0.631 val loss 0.234 val acc 92.920 best val_acc 92.920
e 15 benign_norm 0.626 val loss 0.217 val acc 93.457 best val_acc 93.457
e 20 benign_norm 0.630 val loss 0.209 val acc 93.527 best val_acc 93.527
e 25 benign_norm 0.640 val loss 0.200 val acc 93.726 best val_acc 93.726
e 30 benign_norm 0.645 val loss 0.199 val acc 93.935 best val_acc 93.935
e 35 benign_norm 0.659 val loss 0.197 val acc 94.045 best val_acc 94.045
e 40 benign_norm 0.679 val loss 0.195 val acc 94.294 best val_acc 94.294
e 45 benign_norm 0.682 val loss 0.199 val acc 94.384 best val_acc 94.384
e 49 benign_norm 0.687 val loss 0.193 val acc 94.523 best val_acc 94.523
e 50 benign_norm 0.696 val loss 0.190 val acc 94.553 best val_acc 94.553
