# FLDetector for Fashion MNIST


In [2]:
# `display`/`HTML` are imported from the public `IPython.display` module;
# importing them from `IPython.core.display` emits a deprecation warning.
from IPython.display import display, HTML
# Widen the notebook container so long code/output lines are easier to read.
display(HTML("<style>.container { width:90% !important; }</style>"))

  from IPython.core.display import display, HTML


In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
import copy
import time
from functools import reduce
from torchsummary import summary

import os
import sys
import pickle
sys.path.insert(0,'./utils/')
from logger import *
from eval import *
from misc import *

from sklearn.metrics import roc_auc_score
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict

from SGD import *
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

  warn(f"Failed to load image Python extension: {e}")


cuda


In [4]:
# Only tensor conversion is applied; no augmentation or normalisation.
transform = transforms.Compose([transforms.ToTensor()])
# Load (downloading into ./data if necessary) the training split, then the test split.
trainset, testset = (
    torchvision.datasets.FashionMNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [5]:
def sample_dirichlet_train_data(trainset, no_participants, alpha=0.9, force=False):
    """Partition dataset indices across participants with a per-class Dirichlet split.

    For every class, a proportion vector is drawn from a symmetric
    ``Dirichlet(alpha)`` distribution with one entry per participant and scaled
    by the class size; each participant then receives that many (rounded)
    shuffled samples of the class, in participant order, until the class is
    exhausted. Because of rounding, a few samples of a class may remain
    unassigned.

    The partition is cached in ``./dirichlet_a_<alpha>_nusers_<n>.pkl`` and
    reloaded on later calls unless ``force`` is true or the file is missing.

    Args:
        trainset: iterable of ``(sample, label)`` pairs with hashable,
            sortable labels.
        no_participants: number of participants to split the data across.
        alpha: Dirichlet concentration parameter (also part of the cache name).
        force: regenerate the partition even if a cache file exists.

    Returns:
        A ``defaultdict(list)`` mapping participant id to the list of dataset
        indices assigned to that participant.
    """
    cache_path = './dirichlet_a_%.1f_nusers_%d.pkl' % (alpha, no_participants)

    if os.path.exists(cache_path) and not force:
        # NOTE: pickle must only be used on trusted files; this cache is
        # expected to have been written by this very function.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    print('generating participant indices for alpha %.1f' % alpha)
    np.random.seed(0)
    # A dedicated, seeded generator keeps the shuffle reproducible; the global
    # `random` module is never seeded in this file, which made every
    # regeneration produce a different partition.
    shuffler = random.Random(0)

    # Group dataset indices by class label.
    class_indices = defaultdict(list)
    for ind, (_, label) in enumerate(trainset):
        class_indices[label].append(ind)

    per_participant_list = defaultdict(list)
    # Iterate over the labels actually present (sorted), so the function does
    # not rely on labels being exactly 0..k-1.
    for label in sorted(class_indices):
        remaining = class_indices[label]
        shuffler.shuffle(remaining)
        sampled_probabilities = len(remaining) * np.random.dirichlet(
            np.array(no_participants * [alpha]))
        for user in range(no_participants):
            no_imgs = int(round(sampled_probabilities[user]))
            # Slicing clamps at the end of the list, so a participant simply
            # gets fewer samples once the class runs out.
            per_participant_list[user].extend(remaining[:no_imgs])
            remaining = remaining[no_imgs:]

    with open(cache_path, 'wb') as f:
        pickle.dump(per_participant_list, f)

    return per_participant_list

In [6]:
def get_client_train_data(trainset, num_workers=100, bias=0.5):
    """Assign samples to workers with a label-biased scheme and split each worker's data.

    Workers are arranged in 10 groups of ``num_workers / 10`` workers. A sample
    with label ``y`` lands in group ``y`` with probability ``bias`` and in each
    of the other 9 groups with probability ``(1 - bias) / 9``; inside the
    group the worker is chosen uniformly. Each worker's samples are then
    shuffled and split into train (5/7), validation (1/7) and test (rest).
    NumPy's global RNG is seeded with 42, so the assignment is deterministic.

    Args:
        trainset: iterable of ``(x, y)`` with ``x`` a tensor and ``y`` an
            integer label in ``0..9`` (the code hard-codes 10 groups).
        num_workers: total number of workers.
        bias: probability of assigning a sample to its own label's group.

    Returns:
        An 8-tuple ``(tr_data, tr_label, val_data, val_label, te_data,
        te_label, global_test_data, global_test_label)``. The first six are
        lists with one tensor per worker; the last two are the concatenation
        of all workers' test splits. Label tensors are float tensors built
        with ``torch.Tensor``, as in the original implementation.
    """
    np.random.seed(42)
    bias_weight = bias
    other_group_size = (1 - bias_weight) / 9.
    worker_per_group = num_workers / 10

    # Collect samples in Python lists and stack once per worker; concatenating
    # a growing tensor for every sample copies the data quadratically.
    samples_per_worker = [[] for _ in range(num_workers)]
    labels_per_worker = [[] for _ in range(num_workers)]
    sample_shape = None

    for x, y in trainset:
        if sample_shape is None:
            sample_shape = tuple(x.shape)

        # [lower_bound, upper_bound] is the slice of [0, 1) of width `bias`
        # reserved for the sample's own group; the remainder is shared evenly
        # by the other groups.
        upper_bound = y * (1 - bias_weight) / 9. + bias_weight
        lower_bound = y * (1 - bias_weight) / 9.
        rd = np.random.random_sample()

        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y

        rd = np.random.random_sample()
        selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))

        samples_per_worker[selected_worker].append(x)
        labels_per_worker[selected_worker].append(y)

    # Placeholder for workers that received no sample, shaped so that it can
    # still be indexed and concatenated with the other workers' tensors.
    empty_data = torch.empty((0,) + (sample_shape or ()))

    each_worker_tr_data = [[] for _ in range(num_workers)]
    each_worker_tr_label = [[] for _ in range(num_workers)]
    each_worker_val_data = [[] for _ in range(num_workers)]
    each_worker_val_label = [[] for _ in range(num_workers)]
    each_worker_te_data = [[] for _ in range(num_workers)]
    each_worker_te_label = [[] for _ in range(num_workers)]

    for i in range(num_workers):
        worker_data = torch.stack(samples_per_worker[i]) if samples_per_worker[i] else empty_data
        worker_label = torch.Tensor(labels_per_worker[i])

        w_len = len(labels_per_worker[i])
        len_tr = int(5 * w_len / 7)
        len_val = int(1 * w_len / 7)

        w = np.arange(w_len)
        np.random.shuffle(w)
        w = torch.as_tensor(w, dtype=torch.long)
        # Explicit start offsets: a negative-length slice such as `w[-0:]`
        # would select the whole array instead of nothing.
        tr_idx = w[:len_tr]
        val_idx = w[len_tr:len_tr + len_val]
        te_idx = w[len_tr + len_val:]

        each_worker_tr_data[i] = worker_data[tr_idx]
        each_worker_tr_label[i] = worker_label[tr_idx]
        each_worker_val_data[i] = worker_data[val_idx]
        each_worker_val_label[i] = worker_label[val_idx]
        each_worker_te_data[i] = worker_data[te_idx]
        each_worker_te_label[i] = worker_label[te_idx]

    global_test_data = torch.cat(each_worker_te_data)
    global_test_label = torch.cat(each_worker_te_label)
    return each_worker_tr_data, each_worker_tr_label, each_worker_val_data, each_worker_val_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label

In [7]:
def get_client_data_dirichlet(trainset, num_workers, alpha=1, force=False):
    """Materialise per-worker train/test tensors from a Dirichlet index partition.

    The index partition comes from ``sample_dirichlet_train_data``. Each
    worker's indices are split at random (NumPy seed 0) into 6/7 train and the
    remainder test, and the corresponding samples are stacked into tensors.

    Args:
        trainset: indexable dataset whose items start with ``(sample, label)``.
        num_workers: number of workers to partition the data across.
        alpha: Dirichlet concentration parameter forwarded to the partitioner.
        force: forwarded to the partitioner to bypass its on-disk cache.

    Returns:
        ``(train_data, train_label, test_data, test_label, global_test_data,
        global_test_label)``; the first four are per-worker lists of tensors
        (labels as ``long``), the last two concatenate all workers' test splits.

    Raises:
        RuntimeError: from ``torch.stack`` if a worker ends up with an empty
            train or test split.
    """
    per_participant_list = sample_dirichlet_train_data(trainset, num_workers, alpha=alpha, force=force)

    each_worker_data = [[] for _ in range(num_workers)]
    each_worker_label = [[] for _ in range(num_workers)]

    each_worker_te_data = [[] for _ in range(num_workers)]
    each_worker_te_label = [[] for _ in range(num_workers)]

    np.random.seed(0)

    def _gather(dataset_indices):
        """Fetch the given dataset items and return (stacked samples, long labels)."""
        samples, labels = [], []
        for ds_idx in dataset_indices:
            # Fetch each item once; indexing the dataset re-applies its transform.
            item = trainset[int(ds_idx)]
            samples.append(item[0])
            labels.append(item[1])
        return torch.stack(samples), torch.Tensor(labels).long()

    for worker_idx in range(len(per_participant_list)):
        w_indices = np.array(per_participant_list[worker_idx])
        w_len = len(w_indices)
        len_tr = int(6*w_len/7)
        # These are POSITIONS inside `w_indices`, not dataset indices.
        tr_pos = np.random.choice(w_len, len_tr, replace=False)
        te_pos = np.delete(np.arange(w_len), tr_pos)

        # Map positions back to the dataset indices assigned to this worker.
        # Using the positions directly would make every worker read the first
        # `w_len` samples of the dataset and ignore the Dirichlet partition.
        each_worker_data[worker_idx], each_worker_label[worker_idx] = _gather(w_indices[tr_pos])
        each_worker_te_data[worker_idx], each_worker_te_label[worker_idx] = _gather(w_indices[te_pos])

    global_test_data = torch.concat(each_worker_te_data)
    global_test_label = torch.concat(each_worker_te_label)

    return each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label


In [8]:
def full_trim(v, f):
    '''
    Full-knowledge Trim attack; w.l.o.g. the first f rows of v belong to the
    compromised workers and are overwritten in place.

    v: 2-D tensor of flattened gradients, one row per worker
    f: the number of compromised worker devices

    Returns v itself, with rows 0..f-1 replaced by the crafted gradient.
    '''
    per_coord = v.T  # one row per coordinate, one column per worker

    # Column vectors holding, per coordinate, the extreme values and the sign
    # of the summed update.
    coord_max = per_coord.max(dim=1).values.unsqueeze(1)
    coord_min = per_coord.min(dim=1).values.unsqueeze(1)
    direction = torch.sign(per_coord.sum(dim=-1, keepdim=True))

    # Take the extreme lying opposite to the aggregate direction: the minimum
    # where the sum is positive, the maximum where it is negative.
    directed_dim = (direction > 0) * coord_min + (direction < 0) * coord_max

    # Halve the value when it has the same sign as the direction, double it
    # when the sign is opposite. The factor is constant, so the crafted
    # gradient is the same for every compromised worker.
    random_12 = 2
    same_sign = direction * directed_dim > 0
    opposite_sign = direction * directed_dim < 0
    crafted = (directed_dim * (same_sign / random_12 + opposite_sign * random_12)).squeeze()

    for row in range(f):
        v[row] = crafted
    return v

In [9]:
def our_attack_median(all_updates, n_attackers, dev_type='unit_vec'):
    """Craft malicious updates against coordinate-wise median aggregation.

    A single malicious update ``reference - lamda * deviation`` is replicated
    ``n_attackers`` times and prepended to the benign updates. ``lamda`` is
    searched with a halving step (starting at 10) and the last value for which
    the distance between the median aggregate and the reference increased is
    used.

    NOTE(review): the reference is taken to be the coordinate-wise mean of
    ``all_updates``. The previous code used ``all_updates`` itself, which made
    the ``stack``/``cat`` below fail on mismatched dimensions for any input;
    confirm the mean is the intended reference.

    Args:
        all_updates: 2-D tensor of benign updates, one row per client.
        n_attackers: number of malicious copies to prepend.
        dev_type: ``'unit_vec'`` (reference divided by its norm), ``'sign'``
            (sign of the reference) or ``'std'`` (per-coordinate std of
            ``all_updates``).

    Returns:
        Tensor with ``n_attackers + len(all_updates)`` rows: the malicious
        copies first, then the untouched benign updates.

    Raises:
        ValueError: if ``dev_type`` is not one of the supported values.
    """
    model_re = torch.mean(all_updates, 0)
    if dev_type == 'unit_vec':
        deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
    elif dev_type == 'sign':
        deviation = torch.sign(model_re)
    elif dev_type == 'std':
        deviation = torch.std(all_updates, 0)
    else:
        raise ValueError("unknown dev_type %r; expected 'unit_vec', 'sign' or 'std'" % (dev_type,))

    def _with_malicious(scale):
        """Return the benign updates preceded by n_attackers crafted copies."""
        mal_update = model_re - scale * deviation
        # repeat() also handles n_attackers == 0 (yields zero malicious rows).
        mal_rows = mal_update.unsqueeze(0).repeat(n_attackers, 1)
        return torch.cat((mal_rows, all_updates), 0)

    # Keep the search variable on the same device as the updates instead of
    # hard-coding CUDA, so the attack also runs on CPU-only machines.
    lamda = torch.tensor([10.0], device=all_updates.device)
    threshold_diff = 1e-5
    prev_loss = -1
    lamda_fail = lamda
    lamda_succ = 0
    while torch.abs(lamda_succ - lamda) > threshold_diff:
        agg_grads = torch.median(_with_malicious(lamda), 0)[0]
        loss = torch.norm(agg_grads - model_re)

        if prev_loss < loss:
            # The aggregate moved further from the reference: accept and push further.
            lamda_succ = lamda
            lamda = lamda + lamda_fail / 2
        else:
            lamda = lamda - lamda_fail / 2

        lamda_fail = lamda_fail / 2
        prev_loss = loss

    return _with_malicious(lamda_succ)

In [10]:
def tr_mean(all_updates, n_attackers):
    """Coordinate-wise trimmed mean over the rows of ``all_updates``.

    Every coordinate is sorted across rows, the ``n_attackers`` smallest and
    ``n_attackers`` largest values are dropped, and the rest are averaged.
    With ``n_attackers == 0`` this is the plain coordinate-wise mean.
    """
    ordered, _ = torch.sort(all_updates, 0)
    if n_attackers:
        ordered = ordered[n_attackers:-n_attackers]
    return ordered.mean(0)

In [11]:
class cnn(nn.Module):
    """Two-convolution classifier for single-channel images producing 10 logits.

    ``fc1`` takes 1250 flattened features, i.e. 50 channels of 5x5, which is
    what two (3x3 conv -> 2x2 max-pool) stages yield for a 28x28 input.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 30 -> 50 channels, each conv followed by a 2x2 pool.
        self.conv1 = nn.Conv2d(1, 30, 3)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(30, 50, 3)
        self.pool2 = nn.MaxPool2d(2, 2)
        # Classifier head.
        self.fc1 = nn.Linear(1250, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        """Return raw class logits of shape (batch, 10)."""
        stage1 = self.pool1(F.relu(self.conv1(x)))
        stage2 = self.pool2(F.relu(self.conv2(stage1)))
        # Flatten everything but the batch dimension.
        flat = stage2.view(stage2.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [12]:
# Instantiate the network and display its total number of parameter elements.
model = cnn()
sum(map(torch.numel, model.parameters()))

266060

In [19]:
# Pool the official train and test splits; the partitioning helpers below
# carve out their own per-client splits from the pooled data.
all_data = torch.utils.data.ConcatDataset((trainset, testset))
num_workers = 100
distribution='fang'
# `param` is the bias for the 'fang' scheme and the Dirichlet alpha otherwise.
param = .5
force = True
if distribution=='fang':
    # Pass the configured worker count instead of a second hard-coded 100 so
    # the two cannot drift apart.
    each_worker_data, each_worker_label, each_worker_val_data, each_worker_val_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_train_data(all_data, num_workers=num_workers, bias=param)
elif distribution == 'dirichlet':
    # NOTE: this branch defines no validation split (each_worker_val_*).
    each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=param, force=force)
else:
    # Fail loudly rather than leaving the per-worker variables undefined/stale.
    raise ValueError("unknown distribution %r; expected 'fang' or 'dirichlet'" % (distribution,))

In [15]:
# Display the number of labels in the pooled global test split.
len(global_test_label)

10048

# Slow baselines + fang distribution + 80 benign clients

In [20]:
# Baseline: mean aggregation over the benign clients only (indices
# nbyz..num_workers-1). Each round every client contributes the gradient of
# one randomly drawn mini-batch, evaluated on a copy of the global model.
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()


def init_weights(m):
    """Xavier-initialise the weight of every ``nn.Linear`` and set its bias to 0.1."""
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.1)


net = cnn().to(device)
net.apply(init_weights)
num_workers = 100
nbyz = 20
lr = 0.1
batch_size = 32
nepochs=500
best_acc = 0
best_accs_per_round = []
accs_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

# `SGD` is the project-local optimizer pulled in by `from SGD import *`; its
# step() is handed the per-parameter gradients explicitly.
# NOTE(review): presumably it applies them as a plain SGD update - confirm in utils/SGD.
cnn_optimizer = SGD(net.parameters(), lr = lr)
for e in range(nepochs):
    # ---- local gradients -------------------------------------------------
    client_grads = []
    for i in range(nbyz, num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        # Mini-batch indices are drawn with replacement.
        idx = np.random.choice(len(each_worker_data[i]), batch_size)
        # Use the configured `device` rather than hard-coding CUDA.
        output = net_(each_worker_data[i][idx].to(device))
        loss = criterion(output, torch.as_tensor(each_worker_label[i][idx]).long().to(device))
        # The graph is used once, so it does not need to be retained.
        loss.backward()
        # Flatten all parameter gradients into a single vector.
        client_grads.append(torch.cat([p.grad.view(-1) for p in net_.parameters()]))
        del net_
    # Stack once instead of re-concatenating a growing matrix per client.
    user_grads = torch.stack(client_grads)
    del client_grads

    # ---- aggregation and global update ------------------------------------
    agg_grads = torch.mean(user_grads, 0)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    # `p` (not `param`) so the module-level `param` setting is not overwritten.
    for p in net.parameters():
        n_el = p.numel()
        model_grads.append(agg_grads[start_idx:start_idx + n_el].reshape(p.shape).to(device))
        start_idx += n_el
    cnn_optimizer.step(model_grads)

    # ---- evaluation on the pooled test split ------------------------------
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total = labels.size(0)
        correct = (predicted == labels).sum().item()
        acc = correct/total
        best_acc = max(best_acc, acc)
    best_accs_per_round.append(best_acc)
    accs_per_round.append(acc)
    if e%20==0 or e==nepochs-1:
        print(e, acc, best_acc)

# final_accs_per_client=[]
# for i in range(num_workers):
#     total, correct = 0,0
#     with torch.no_grad():
#         outputs = net(each_worker_te_data[i].cuda())
#         _, predicted = torch.max(outputs.data, 1)
#         total += len(each_worker_te_label[i])
#         correct += (predicted == each_worker_te_label[i].long().cuda()).sum().item()
#         acc = correct/total
#     final_accs_per_client.append(acc)

# results = collections.OrderedDict(
#     final_accs_per_client=np.array(final_accs_per_client),
#     accs_per_round=np.array(accs_per_round),
#     best_accs_per_round=np.array(best_accs_per_round),
# )

0 0.12201170518797738 0.12201170518797738
20 0.48467413947029064 0.49727209602222
40 0.6145223688126178 0.6856462652514631
60 0.6839599246106537 0.6918956452732864
80 0.7349469298680686 0.7349469298680686
100 0.7505207816684852 0.7528023013589922
120 0.7559765896240452 0.7635155242535463
140 0.7336573752603909 0.7737327646066858
160 0.767086598551731 0.7797837516119432
180 0.7768078563634561 0.785437952584069
200 0.7687729391925404 0.7949608173792283
220 0.7600436464636445 0.8017061799424661
240 0.7942664418212478 0.8073603809145918
260 0.7718480309493105 0.8145025295109612
280 0.7984326951691301 0.8145025295109612
300 0.812518599345303 0.8216446781073307
320 0.8024997520087293 0.8237278047812717
340 0.8204543200079357 0.8268028965380418
360 0.820156730483087 0.8287868267037001
380 0.8140065469695467 0.830969149885924
400 0.821247892074199 0.8337466521178455
420 0.8166848526931852 0.8359289753000695
440 0.8221406606487451 0.8393016565816883
460 0.8300763813113778 0.839599246106537
480 

## Less slow Mean for FedRecover experiments

In [45]:
# Mean aggregation over the benign clients (indices nbyz..num_workers-1) with
# full-batch client gradients and an exponentially decayed learning rate.
# `collections` is imported here because the visible file only imports
# `defaultdict` from it.
import collections

torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()


def init_weights(m):
    """Xavier-initialise the weight of every ``nn.Linear`` and set its bias to 0.1."""
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.1)


net = cnn().to(device)
net.apply(init_weights)
num_workers = 100
nbyz = 20
lr = 0.02
nepochs=400
best_acc = 0
best_accs_per_round = []
accs_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

for e in range(nepochs):
    # A fresh optimizer per round carries the decayed learning rate.
    # `SGD` is the project-local optimizer from `from SGD import *`.
    cnn_optimizer = SGD(net.parameters(), lr = lr*(.9995**e))

    # ---- local gradients (each client's whole training set) ---------------
    client_grads = []
    for i in range(nbyz, num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        # Use the configured `device` rather than hard-coding CUDA.
        output = net_(each_worker_data[i].to(device))
        loss = criterion(output, torch.as_tensor(each_worker_label[i]).long().to(device))
        # The graph is used once, so it does not need to be retained.
        loss.backward()
        client_grads.append(torch.cat([p.grad.view(-1) for p in net_.parameters()]))
        del net_
    # Stack once instead of re-concatenating a growing matrix per client.
    user_grads = torch.stack(client_grads)
    del client_grads

    # ---- aggregation and global update ------------------------------------
    agg_grads = torch.mean(user_grads, 0)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    # `p` (not `param`) so the module-level `param` setting is not overwritten.
    for p in net.parameters():
        n_el = p.numel()
        model_grads.append(agg_grads[start_idx:start_idx + n_el].reshape(p.shape).to(device))
        start_idx += n_el
    cnn_optimizer.step(model_grads)

    # ---- evaluation on the pooled test split ------------------------------
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total = labels.size(0)
        correct = (predicted == labels).sum().item()
        acc = correct/total
        best_acc = max(best_acc, acc)
    best_accs_per_round.append(best_acc)
    accs_per_round.append(acc)
    if e%20==0 or e==nepochs-1:
        print(e, acc, best_acc)

# Accuracy of the final model on every client's own test split
# (all clients, including indices below nbyz).
final_accs_per_client=[]
for i in range(num_workers):
    with torch.no_grad():
        outputs = net(each_worker_te_data[i].to(device))
        _, predicted = torch.max(outputs.data, 1)
        total = len(each_worker_te_label[i])
        correct = (predicted == each_worker_te_label[i].long().to(device)).sum().item()
        acc = correct/total
    final_accs_per_client.append(acc)

results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
)

0 0.17454943741909787 0.17454943741909787
20 0.46151548342128845 0.46151548342128845
40 0.5153838494473763 0.5153838494473763
60 0.5330080653191277 0.5589963158418799
80 0.6352683461117196 0.6352683461117196
100 0.6417405157821369 0.6419396594643035
120 0.6835606890371403 0.6897341431843075
140 0.667131335258389 0.6897341431843075
160 0.6735039330877228 0.6997908991337249
180 0.6897341431843075 0.7013840485910584
200 0.7128348103156428 0.7128348103156428
220 0.7156228218659763 0.7156228218659763
240 0.7043712038235587 0.7258787214975605
260 0.7194065518271433 0.7367320521756447
280 0.7203026983968933 0.7392213482027282
300 0.7362341929702281 0.7392213482027282
320 0.7461913770785622 0.7461913770785622
340 0.7515682564970626 0.7534601214776461
360 0.7495768196753958 0.7542566962063129
380 0.7454943741909787 0.7542566962063129
399 0.7600318629891467 0.7600318629891467


In [22]:
# Total number of samples in the train / validation / test splits, summed over
# the workers present in `each_worker_data`.
worker_ids = range(len(each_worker_data))
tr_len = sum(len(each_worker_data[i]) for i in worker_ids)
val_len = sum(len(each_worker_val_data[i]) for i in worker_ids)
te_len = sum(len(each_worker_te_data[i]) for i in worker_ids)
print(f'Lengths of data: train {tr_len} validation {val_len} test {te_len}')

Lengths of data: train 49960 validation 9959 test 10081


# Mean without any attack

In [23]:
# Mean aggregation without any attack: all num_workers clients contribute a
# full-batch gradient, and the global model is updated with torch.optim.SGD
# by writing the aggregate into each parameter's `.grad`.
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()


def init_weights(m):
    """Xavier-initialise the weight of every ``nn.Linear`` and set its bias to 0.01."""
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


net = cnn().to(device)
net.apply(init_weights)
num_workers = 100
nbyz = 20
lr = 0.01
nepochs=1000
best_acc = 0
best_accs_per_round = []
accs_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

cnn_optimizer = optim.SGD(net.parameters(), lr = lr)

for e in range(nepochs):
    # ---- local gradients (each client's whole training set) ---------------
    client_grads = []
    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        # Use the configured `device` rather than hard-coding CUDA.
        output = net_(each_worker_data[i].to(device))
        loss = criterion(output, torch.as_tensor(each_worker_label[i]).long().to(device))
        # The graph is used once, so it does not need to be retained.
        loss.backward()
        client_grads.append(torch.cat([p.grad.view(-1) for p in net_.parameters()]))
        del net_
    # Stack once instead of re-concatenating a growing matrix per client.
    user_grads = torch.stack(client_grads)
    del client_grads

    # ---- aggregation and global update ------------------------------------
    agg_grads = torch.mean(user_grads, 0)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    # `p` (not `param`) so the module-level `param` setting is not overwritten.
    for p in net.parameters():
        n_el = p.numel()
        # Install the matching slice of the aggregate as this parameter's gradient.
        p.grad = agg_grads[start_idx:start_idx + n_el].reshape(p.shape).to(device)
        start_idx += n_el
    cnn_optimizer.step()

    # ---- evaluation on the pooled test split ------------------------------
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total = labels.size(0)
        correct = (predicted == labels).sum().item()
        acc = correct/total
        best_acc = max(best_acc, acc)
    best_accs_per_round.append(best_acc)
    accs_per_round.append(acc)
    if e%20==0 or e==nepochs-1:
        print(e, acc, best_acc)

0 0.09969249082432298 0.09969249082432298
20 0.37238369209403827 0.37238369209403827
40 0.5212776510266839 0.5212776510266839
60 0.5851602023608768 0.5851602023608768
80 0.625136395198889 0.625136395198889
100 0.6549945441920444 0.6549945441920444
120 0.6736434877492312 0.6736434877492312
140 0.6859438547763119 0.6859438547763119
160 0.653209007042952 0.689018946533082
180 0.658764011506795 0.689018946533082
200 0.66302946136296 0.689018946533082
220 0.6715603610752902 0.689018946533082
240 0.6784049201468109 0.689018946533082
260 0.6819759944449956 0.689018946533082
280 0.6867374268425751 0.6906060906656085
300 0.6907052871738915 0.6947723440134908
320 0.689118143041365 0.6992361868862216
340 0.6925900208312668 0.7089574446979466
360 0.7006249380021823 0.710643785338756
380 0.7113381608967364 0.7141156631286578
400 0.7196706675925008 0.7196706675925008
420 0.7209602222001785 0.7211586152167444
440 0.7199682571173495 0.7211586152167444
460 0.7025096716595576 0.7253248685646265
480 0.72

In [49]:
# Baseline: mean aggregation over the benign clients only (indices
# nbyz..num_workers-1) with full-batch client gradients; results are pickled
# under `home_dir`.
# `collections` is imported here because the visible file only imports
# `defaultdict` from it.
import collections

torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()


def init_weights(m):
    """Xavier-initialise the weight of every ``nn.Linear`` and set its bias to 0.01."""
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


net = cnn().to(device)
net.apply(init_weights)
num_workers = 100
nbyz = 20
lr = 0.01
nepochs=1000
best_acc = 0
best_accs_per_round = []
accs_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

for e in range(nepochs):
    # `SGD` is the project-local optimizer from `from SGD import *`. It is
    # re-created every round as in the original cell.
    # NOTE(review): hoist out of the loop only if it keeps no per-instance state.
    cnn_optimizer = SGD(net.parameters(), lr = lr)

    # ---- local gradients (each client's whole training set) ---------------
    client_grads = []
    for i in range(nbyz, num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        # Use the configured `device` rather than hard-coding CUDA.
        output = net_(each_worker_data[i].to(device))
        loss = criterion(output, torch.as_tensor(each_worker_label[i]).long().to(device))
        # The graph is used once, so it does not need to be retained.
        loss.backward()
        client_grads.append(torch.cat([p.grad.view(-1) for p in net_.parameters()]))
        del net_
    # Stack once instead of re-concatenating a growing matrix per client.
    user_grads = torch.stack(client_grads)
    del client_grads

    # ---- aggregation and global update ------------------------------------
    agg_grads = torch.mean(user_grads, 0)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    # `p` (not `param`) so the module-level `param` setting is not overwritten.
    for p in net.parameters():
        n_el = p.numel()
        model_grads.append(agg_grads[start_idx:start_idx + n_el].reshape(p.shape).to(device))
        start_idx += n_el
    cnn_optimizer.step(model_grads)

    # ---- evaluation on the pooled test split ------------------------------
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total = labels.size(0)
        correct = (predicted == labels).sum().item()
        acc = correct/total
        best_acc = max(best_acc, acc)
    best_accs_per_round.append(best_acc)
    accs_per_round.append(acc)
    if e%100==0 or e==nepochs-1:
        print(e, acc, best_acc)

# Accuracy of the final model on every client's own test split
# (all clients, including indices below nbyz).
final_accs_per_client=[]
for i in range(num_workers):
    with torch.no_grad():
        outputs = net(each_worker_te_data[i].to(device))
        _, predicted = torch.max(outputs.data, 1)
        total = len(each_worker_te_label[i])
        correct = (predicted == each_worker_te_label[i].long().to(device)).sum().item()
        acc = correct/total
    final_accs_per_client.append(acc)

results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
)
# Close the results file deterministically instead of leaking the handle.
with open(os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_fashion_slow_mean.pkl'), 'wb') as results_file:
    pickle.dump(results, results_file)


0 0.09937269740117495 0.09937269740117495
100 0.5121975505327093 0.5121975505327093
200 0.636861495569053 0.6486109728168874
300 0.6866474161107239 0.6866474161107239
400 0.716120681071393 0.716120681071393
500 0.7206014139201434 0.7275714427959773
600 0.7442995120979787 0.7442995120979787
700 0.7543562680473962 0.7543562680473962
800 0.7470875236483122 0.7549536990938962
900 0.7655083142487304 0.7662053171363139
999 0.7624215871751469 0.7693916160509808


# Mean under full trim attack

In [34]:
# Full-trim attack experiment: all num_workers clients compute a full-batch
# gradient, `full_trim` overwrites the first nbyz rows with the crafted
# gradient, and the rows are aggregated with `tr_mean(..., 0)` (no trimming,
# i.e. a plain mean). Results are pickled under `home_dir`.
# `collections` is imported here because the visible file only imports
# `defaultdict` from it.
import collections

torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()


def init_weights(m):
    """Xavier-initialise the weight of every ``nn.Linear`` and set its bias to 0.01."""
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


net = cnn().to(device)
net.apply(init_weights)
num_workers = 100
nbyz = 20
lr = 0.01
nepochs=1000
best_acc = 0
best_accs_per_round = []
accs_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'


for e in range(nepochs):
    # `SGD` is the project-local optimizer from `from SGD import *`. It is
    # re-created every round as in the original cell.
    # NOTE(review): hoist out of the loop only if it keeps no per-instance state.
    cnn_optimizer = SGD(net.parameters(), lr = lr)

    # ---- local gradients (each client's whole training set) ---------------
    client_grads = []
    for i in range(num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        # Use the configured `device` rather than hard-coding CUDA.
        output = net_(each_worker_data[i].to(device))
        loss = criterion(output, torch.as_tensor(each_worker_label[i]).long().to(device))
        # The graph is used once, so it does not need to be retained.
        loss.backward()
        client_grads.append(torch.cat([p.grad.view(-1) for p in net_.parameters()]))
        del net_
    # Stack once instead of re-concatenating a growing matrix per client.
    user_grads = torch.stack(client_grads)
    del client_grads

    # ---- attack, aggregation and global update ----------------------------
    user_grads = full_trim(user_grads, nbyz)
    agg_grads = tr_mean(user_grads, 0)
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    # `p` (not `param`) so the module-level `param` setting is not overwritten.
    for p in net.parameters():
        n_el = p.numel()
        model_grads.append(agg_grads[start_idx:start_idx + n_el].reshape(p.shape).to(device))
        start_idx += n_el
    cnn_optimizer.step(model_grads)

    # ---- evaluation on the pooled test split ------------------------------
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total = labels.size(0)
        correct = (predicted == labels).sum().item()
        acc = correct/total
        best_acc = max(best_acc, acc)
    best_accs_per_round.append(best_acc)
    accs_per_round.append(acc)
    if e%100==0 or e==nepochs-1:
        print(e, acc, best_acc)

# Accuracy of the final model on every client's own test split.
final_accs_per_client=[]
for i in range(num_workers):
    with torch.no_grad():
        outputs = net(each_worker_te_data[i].to(device))
        _, predicted = torch.max(outputs.data, 1)
        total = len(each_worker_te_label[i])
        correct = (predicted == each_worker_te_label[i].long().to(device)).sum().item()
        acc = correct/total
    final_accs_per_client.append(acc)

results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
)
# Close the results file deterministically instead of leaking the handle.
with open(os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_fashion_slow_trmean_trim20.pkl'), 'wb') as results_file:
    pickle.dump(results, results_file)

0 0.09767997610275814 0.09767997610275814
100 0.23867370307676988 0.23867370307676988
200 0.3300806531912775 0.3300806531912775
300 0.43234093398386936 0.43234093398386936
400 0.5213581599123768 0.5213581599123768
500 0.5155829931295429 0.5443592552026287
600 0.5495369909389625 0.5729363735935478
700 0.5351986458229613 0.6112715324106343
800 0.568654784426964 0.6192372796973016
900 0.6142586876431345 0.6192372796973016
999 0.5324106342726277 0.6404460818480534


# Median without any attack

In [53]:
# Baseline: coordinate-wise median aggregation over the benign clients only
# (indices nbyz..num_workers-1) with full-batch client gradients; results are
# pickled under `home_dir`.
# `collections` is imported here because the visible file only imports
# `defaultdict` from it.
import collections

torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()


def init_weights(m):
    """Xavier-initialise the weight of every ``nn.Linear`` and set its bias to 0.01."""
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


net = cnn().to(device)
net.apply(init_weights)
num_workers = 100
nbyz = 20
lr = 0.01
nepochs=1000
best_acc = 0
best_accs_per_round = []
accs_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

for e in range(nepochs):
    # `SGD` is the project-local optimizer from `from SGD import *`. It is
    # re-created every round as in the original cell.
    # NOTE(review): hoist out of the loop only if it keeps no per-instance state.
    cnn_optimizer = SGD(net.parameters(), lr = lr)

    # ---- local gradients (each client's whole training set) ---------------
    client_grads = []
    for i in range(nbyz, num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        # Use the configured `device` rather than hard-coding CUDA.
        output = net_(each_worker_data[i].to(device))
        loss = criterion(output, torch.as_tensor(each_worker_label[i]).long().to(device))
        # The graph is used once, so it does not need to be retained.
        loss.backward()
        client_grads.append(torch.cat([p.grad.view(-1) for p in net_.parameters()]))
        del net_
    # Stack once instead of re-concatenating a growing matrix per client.
    user_grads = torch.stack(client_grads)
    del client_grads

    # ---- aggregation and global update ------------------------------------
    # torch.median along a dim returns (values, indices); only values are used.
    agg_grads = torch.median(user_grads, 0)[0]
    del user_grads
    start_idx=0
    cnn_optimizer.zero_grad()
    model_grads=[]
    # `p` (not `param`) so the module-level `param` setting is not overwritten.
    for p in net.parameters():
        n_el = p.numel()
        model_grads.append(agg_grads[start_idx:start_idx + n_el].reshape(p.shape).to(device))
        start_idx += n_el
    cnn_optimizer.step(model_grads)

    # ---- evaluation on the pooled test split ------------------------------
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total = labels.size(0)
        correct = (predicted == labels).sum().item()
        acc = correct/total
        best_acc = max(best_acc, acc)
    best_accs_per_round.append(best_acc)
    accs_per_round.append(acc)
    if e%100==0 or e==nepochs-1:
        print(e, acc, best_acc)

# Accuracy of the final model on every client's own test split
# (all clients, including indices below nbyz).
final_accs_per_client=[]
for i in range(num_workers):
    with torch.no_grad():
        outputs = net(each_worker_te_data[i].to(device))
        _, predicted = torch.max(outputs.data, 1)
        total = len(each_worker_te_label[i])
        correct = (predicted == each_worker_te_label[i].long().to(device)).sum().item()
        acc = correct/total
    final_accs_per_client.append(acc)

results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
)
# Close the results file deterministically instead of leaking the handle.
with open(os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_fashion_slow_median.pkl'), 'wb') as results_file:
    pickle.dump(results, results_file)

0 0.09907398187792492 0.09907398187792492
100 0.5924524544458827 0.5924524544458827
200 0.49766006173454147 0.6430349497162202
300 0.6203325699492184 0.6430349497162202
400 0.5586976003186299 0.6684257691924723
500 0.6465199641541373 0.6716120681071392
600 0.6584685850841382 0.6871452753161406
700 0.6669321915762223 0.704172060141392
800 0.6710146370606392 0.704172060141392
900 0.683759832719307 0.704172060141392
999 0.6952105944438912 0.704172060141392


# Median under full trim attack

In [54]:
torch.cuda.empty_cache()
criterion = nn.CrossEntropyLoss()
def init_weights(m):
    """Initialise a linear layer in place; intended for ``Module.apply``.

    ``nn.Linear`` modules get Xavier-uniform weights and a constant 0.01
    bias. Every other module type is left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    # nn.Linear(..., bias=False) has `bias is None`; skip it instead of
    # raising AttributeError.
    if m.bias is not None:
        m.bias.data.fill_(0.01)
# Fresh global model and bookkeeping for this run.
net = cnn().to(device)
net.apply(init_weights)
num_workers = 100
nbyz = 20
lr = 0.01
nepochs = 1000
best_acc = 0
best_accs_per_round = []
accs_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'


for e in range(nepochs):
    cnn_optimizer = SGD(net.parameters(), lr=lr)

    # One flattened full-batch gradient per worker with index in
    # [nbyz, num_workers). Each is computed on a throw-away copy of the
    # global model so `net` itself is untouched.
    worker_grads = []
    for i in range(nbyz, num_workers):
        net_ = copy.deepcopy(net)
        net_.zero_grad()
        output = net_(each_worker_data[i].to(device))
        loss = criterion(output, torch.Tensor(each_worker_label[i]).long().to(device))
        # The graph is never reused, so it does not need to be retained.
        loss.backward()
        worker_grads.append(torch.cat([param.grad.view(-1) for param in net_.parameters()]))
        del net_
    # Stack once instead of re-concatenating the whole matrix per worker.
    user_grads = torch.stack(worker_grads)
    del worker_grads

    # `full_trim` is defined outside this cell; presumably it injects the
    # `nbyz` malicious rows of the full-knowledge trim attack -- verify.
    user_grads = full_trim(user_grads, nbyz)
    # Coordinate-wise median across the rows is the aggregate gradient.
    agg_grads = torch.median(user_grads, 0)[0]
    del user_grads

    # Slice the flat aggregate back into per-parameter tensors and hand them
    # to the custom SGD optimizer, whose step() takes the gradients directly.
    start_idx = 0
    cnn_optimizer.zero_grad()
    model_grads = []
    for param in net.parameters():
        n_values = param.numel()
        model_grads.append(agg_grads[start_idx:start_idx + n_values].reshape(param.shape).to(device))
        start_idx += n_values
    cnn_optimizer.step(model_grads)

    # Accuracy of the updated global model on the shared test set.
    with torch.no_grad():
        inputs, labels = global_test_data.to(device), global_test_label.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs, 1)
        acc = (predicted == labels).sum().item() / labels.size(0)
    best_acc = max(best_acc, acc)
    best_accs_per_round.append(best_acc)
    accs_per_round.append(acc)
    if e%100==0 or e==nepochs-1:
        print(e, acc, best_acc)

# Score the final global model on every worker's local test split.
final_accs_per_client = []
for i in range(num_workers):
    with torch.no_grad():
        outputs = net(each_worker_te_data[i].to(device))
        _, predicted = torch.max(outputs, 1)
        correct = (predicted == each_worker_te_label[i].long().to(device)).sum().item()
        acc = correct / len(each_worker_te_label[i])
    final_accs_per_client.append(acc)

results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
)
# Context manager guarantees the results file is flushed and closed.
results_path = os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_fashion_slow_median_trim20.pkl')
with open(results_path, 'wb') as results_file:
    pickle.dump(results, results_file)

0 0.0887185104052574 0.0887185104052574
100 0.18729463307776562 0.21776361644926814
200 0.19675395798068307 0.4230807527631186
300 0.3940057751667828 0.4230807527631186
400 0.34989544956686247 0.5362939360748781
500 0.37429055063228117 0.5362939360748781
600 0.41929702280195164 0.5446579707258787
700 0.4897938862889575 0.5446579707258787
800 0.42178631882903517 0.6279996017126357
900 0.5562083042915463 0.6279996017126357
999 0.5091108234591257 0.6349696305884696


# Good FedAvg baseline Fashion MNIST + Fang distribution + 80 clients

In [55]:
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def train(train_data, labels, model, optimizer, batch_size=20):
    """Run one shuffled mini-batch pass over a single client's data.

    Args:
        train_data: Tensor of samples, indexed along dim 0.
        labels: Tensor of class indices aligned with ``train_data``.
        model: Network updated in place. Batches are moved to the device of
            its parameters.
        optimizer: Optimizer already bound to ``model``'s parameters.
        batch_size: Samples per step; the last batch may be smaller.

    Returns:
        Tuple ``(losses.avg, top1.avg)``: the running averages of the loss
        and of the first value returned by ``accuracy`` over the pass.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    losses = AverageMeter()
    top1 = AverageMeter()

    # Follow the model's device rather than hard-coding CUDA, so the function
    # also works when the notebook falls back to CPU.
    model_device = next(model.parameters()).device

    # Visit the samples in a fresh random order on every call.
    order = np.arange(len(train_data))
    np.random.shuffle(order)
    train_data = train_data[order]
    labels = labels[order]

    for start in range(0, len(train_data), batch_size):
        inputs = train_data[start:start + batch_size].to(model_device)
        targets = labels[start:start + batch_size].to(model_device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Measure accuracy and record loss; only the top-1 value is reported.
        prec1, _ = accuracy(outputs.detach(), targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))

        # Compute gradient and do an SGD step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return (losses.avg, top1.avg)


def test(test_data, labels, model, criterion, use_cuda, debug_='MEDIUM', batch_size=64):
    """Evaluate ``model`` on a dataset in mini-batches without tracking gradients.

    Args:
        test_data: Tensor of samples, indexed along dim 0.
        labels: Tensor of class indices aligned with ``test_data``.
        model: Network to evaluate; it is switched to eval mode.
        criterion: Loss callable applied as ``criterion(outputs, targets)``.
        use_cuda: Accepted for interface compatibility; not consulted.
            Batches are moved to the device of the model's parameters.
        debug_: Accepted for interface compatibility; unused.
        batch_size: Samples per forward pass; the last batch may be smaller.

    Returns:
        Tuple ``(losses.avg, top1.avg)``: the running averages of the loss
        and of the first value returned by ``accuracy``.
    """
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()
    # Follow the model's device rather than hard-coding CUDA, so evaluation
    # also works when the notebook falls back to CPU.
    model_device = next(model.parameters()).device

    with torch.no_grad():
        for start in range(0, len(test_data), batch_size):
            inputs = test_data[start:start + batch_size].to(model_device)
            targets = labels[start:start + batch_size].to(model_device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Only the top-1 value is reported to the caller.
            prec1, _ = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))

    return (losses.avg, top1.avg)

# Mean without any attack

In [59]:
# Release unused cached CUDA memory before starting this run.
torch.cuda.empty_cache()
# Forwarded to test() further down in this cell.
use_cuda = torch.cuda.is_available()

# Passes of train() each client runs per round.
local_epochs = 2
# Mini-batch size handed to train().
batch_size = 16
# Total number of workers; the final per-client evaluation loops over all of them.
num_workers = 100
# Learning rate meant for client-side SGD.
# NOTE(review): confirm the round loop's optimizer actually reads this name.
local_lr = 0.01
# Server-side step size applied to the aggregated update.
global_lr = 1
# Round-count bound for the training loop.
nepochs = 50
# Number of leading worker indices left out of training in this no-attack run.
nbyz = 20

# Best validation accuracy so far and the current round index.
best_global_acc=0
epoch_num = 0
def init_weights(m):
    """Initialise a linear layer in place; intended for ``Module.apply``.

    ``nn.Linear`` modules get Xavier-uniform weights and a constant 0.01
    bias. Every other module type is left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    # nn.Linear(..., bias=False) has `bias is None`; skip it instead of
    # raising AttributeError.
    if m.bias is not None:
        m.bias.data.fill_(0.01)
def _flatten_state_dict(model):
    """Return every ``state_dict`` entry of *model* concatenated into one float32 vector on ``device``."""
    return torch.cat([
        entry.reshape(-1).to(device=device, dtype=torch.float32)
        for entry in model.state_dict().values()
    ])


# Fresh global model; `model_received` is the server's flattened copy of its state.
fed_model = cnn().to(device)
fed_model.apply(init_weights)
model_received = _flatten_state_dict(fed_model)

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

# NOTE(review): `<=` makes this run nepochs + 1 rounds (0..nepochs). Left
# unchanged so the saved curves keep their recorded length.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    # No-attack baseline: only workers with index >= nbyz take part.
    round_clients = np.arange(nbyz, num_workers)
    round_benign = round_clients
    benign_norm = 0
    client_updates = []

    for i in round_benign:
        # Each client starts from the current global model.
        model = copy.deepcopy(fed_model)
        # Use this cell's `local_lr`, not a global `lr` left over from an
        # earlier cell.
        optimizer = optim.SGD(model.parameters(), lr=local_lr, momentum=0.9, weight_decay=1e-4)
        for epoch in range(local_epochs):
            train_loss, train_acc = train(
                each_worker_data[i], torch.Tensor(each_worker_label[i]).long(), model, optimizer, batch_size)

        # Client update = trained state minus the state it was sent.
        update = _flatten_state_dict(model) - model_received
        benign_norm += torch.norm(update)/len(round_benign)
        client_updates.append(update)

    # Stack once instead of re-concatenating the whole matrix per client.
    user_updates = torch.stack(client_updates)
    del client_updates
    agg_update = torch.mean(user_updates, 0)
    del user_updates

    model_received = model_received + global_lr * agg_update

    # Rebuild the global model and load the flat vector back into it, slicing
    # in state_dict order (the same order used when flattening).
    fed_model = cnn().to(device)
    fed_model.apply(init_weights)
    state_dict = {}
    start_idx = 0
    for name, entry in fed_model.state_dict().items():
        end_idx = start_idx + entry.numel()
        state_dict[name] = model_received[start_idx:end_idx].reshape(entry.shape)
        start_idx = end_idx
    fed_model.load_state_dict(state_dict)

    # `criterion` is not defined in this cell; it must exist from an earlier one.
    val_loss, val_acc = test(global_test_data, global_test_label.long(), fed_model, criterion, use_cuda)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    best_accs_per_round.append(best_global_acc)
    accs_per_round.append(val_acc)
    loss_per_round.append(val_loss)
    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, val_loss, val_acc, best_global_acc))
    epoch_num+=1

# Score the final global model on every worker's local test split.
final_accs_per_client = []
for i in range(num_workers):
    client_loss, client_acc = test(each_worker_te_data[i], each_worker_te_label[i].long(),
                                   fed_model, criterion, use_cuda)
    final_accs_per_client.append(client_acc)
results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
    loss_per_round=np.array(loss_per_round)
)
# Context manager guarantees the results file is flushed and closed.
results_path = os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_fashion_fast_mean.pkl')
with open(results_path, 'wb') as results_file:
    pickle.dump(results, results_file)

e 0 val loss 1.402 val acc 69.700 best val_acc 69.700
e 10 val loss 0.442 val acc 83.979 best val_acc 83.979
e 20 val loss 0.362 val acc 87.016 best val_acc 87.016
e 30 val loss 0.327 val acc 88.270 best val_acc 88.290
e 40 val loss 0.306 val acc 89.246 best val_acc 89.266
e 49 val loss 0.295 val acc 89.585 best val_acc 89.585
e 50 val loss 0.293 val acc 89.565 best val_acc 89.585


# Trimmed mean under full trim attack

In [60]:
# Release unused cached CUDA memory before starting this run.
torch.cuda.empty_cache()
# Forwarded to test() further down in this cell.
use_cuda = torch.cuda.is_available()

# Passes of train() each client runs per round.
local_epochs = 2
# Mini-batch size handed to train().
batch_size = 16
# Total number of workers; every one of them is trained each round in this cell.
num_workers = 100
# Learning rate meant for client-side SGD.
# NOTE(review): confirm the round loop's optimizer actually reads this name.
local_lr = 0.01
# Server-side step size applied to the aggregated update.
global_lr = 1
# Round-count bound for the training loop.
nepochs = 50
# Count passed to full_trim() and tr_mean() below.
nbyz = 20

# Best validation accuracy so far and the current round index.
best_global_acc=0
epoch_num = 0
def init_weights(m):
    """Initialise a linear layer in place; intended for ``Module.apply``.

    ``nn.Linear`` modules get Xavier-uniform weights and a constant 0.01
    bias. Every other module type is left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    # nn.Linear(..., bias=False) has `bias is None`; skip it instead of
    # raising AttributeError.
    if m.bias is not None:
        m.bias.data.fill_(0.01)
def _flatten_state_dict(model):
    """Return every ``state_dict`` entry of *model* concatenated into one float32 vector on ``device``."""
    return torch.cat([
        entry.reshape(-1).to(device=device, dtype=torch.float32)
        for entry in model.state_dict().values()
    ])


# Fresh global model; `model_received` is the server's flattened copy of its state.
fed_model = cnn().to(device)
fed_model.apply(init_weights)
model_received = _flatten_state_dict(fed_model)

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

# NOTE(review): `<=` makes this run nepochs + 1 rounds (0..nepochs). Left
# unchanged so the saved curves keep their recorded length.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    # Every worker trains honestly first; the attack is applied afterwards.
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    benign_norm = 0
    client_updates = []

    for i in round_benign:
        # Each client starts from the current global model.
        model = copy.deepcopy(fed_model)
        # Use this cell's `local_lr`, not a global `lr` left over from an
        # earlier cell.
        optimizer = optim.SGD(model.parameters(), lr=local_lr, momentum=0.9, weight_decay=1e-4)
        for epoch in range(local_epochs):
            train_loss, train_acc = train(
                each_worker_data[i], torch.Tensor(each_worker_label[i]).long(), model, optimizer, batch_size)

        # Client update = trained state minus the state it was sent.
        update = _flatten_state_dict(model) - model_received
        benign_norm += torch.norm(update)/len(round_benign)
        client_updates.append(update)

    # Stack once instead of re-concatenating the whole matrix per client.
    user_updates = torch.stack(client_updates)
    del client_updates

    # `full_trim` and `tr_mean` are defined outside this cell. Presumably
    # full_trim crafts `nbyz` malicious rows and tr_mean is the trimmed-mean
    # aggregator -- verify against their definitions.
    user_updates = full_trim(user_updates, nbyz)
    agg_update = tr_mean(user_updates, nbyz)
    del user_updates

    model_received = model_received + global_lr * agg_update

    # Rebuild the global model and load the flat vector back into it, slicing
    # in state_dict order (the same order used when flattening).
    fed_model = cnn().to(device)
    fed_model.apply(init_weights)
    state_dict = {}
    start_idx = 0
    for name, entry in fed_model.state_dict().items():
        end_idx = start_idx + entry.numel()
        state_dict[name] = model_received[start_idx:end_idx].reshape(entry.shape)
        start_idx = end_idx
    fed_model.load_state_dict(state_dict)

    # `criterion` is not defined in this cell; it must exist from an earlier one.
    val_loss, val_acc = test(global_test_data, global_test_label.long(), fed_model, criterion, use_cuda)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    best_accs_per_round.append(best_global_acc)
    accs_per_round.append(val_acc)
    loss_per_round.append(val_loss)
    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, val_loss, val_acc, best_global_acc))
    epoch_num+=1

# Score the final global model on every worker's local test split.
final_accs_per_client = []
for i in range(num_workers):
    client_loss, client_acc = test(each_worker_te_data[i], each_worker_te_label[i].long(),
                                   fed_model, criterion, use_cuda)
    final_accs_per_client.append(client_acc)
results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
    loss_per_round=np.array(loss_per_round)
)
# Context manager guarantees the results file is flushed and closed.
results_path = os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_fashion_fast_trmean_trim20.pkl')
with open(results_path, 'wb') as results_file:
    pickle.dump(results, results_file)

e 0 val loss 1.904 val acc 57.204 best val_acc 57.204
e 10 val loss 0.557 val acc 78.941 best val_acc 78.941
e 20 val loss 0.507 val acc 81.470 best val_acc 81.470
e 30 val loss 0.498 val acc 82.137 best val_acc 82.376
e 40 val loss 0.489 val acc 82.585 best val_acc 82.655
e 49 val loss 0.480 val acc 83.142 best val_acc 83.142
e 50 val loss 0.481 val acc 83.142 best val_acc 83.142


# Median without any attack

In [None]:
# Release unused cached CUDA memory before starting this run.
torch.cuda.empty_cache()
# Forwarded to test() further down in this cell.
use_cuda = torch.cuda.is_available()

# Passes of train() each client runs per round.
local_epochs = 2
# Mini-batch size handed to train().
batch_size = 16
# Total number of workers; the final per-client evaluation loops over all of them.
num_workers = 100
# Learning rate meant for client-side SGD.
# NOTE(review): confirm the round loop's optimizer actually reads this name.
local_lr = 0.01
# Server-side step size applied to the aggregated update.
global_lr = 1
# Round-count bound for the training loop.
nepochs = 50
# Number of leading worker indices left out of training in this no-attack run.
nbyz = 20

# Best validation accuracy so far and the current round index.
best_global_acc=0
epoch_num = 0
def init_weights(m):
    """Initialise a linear layer in place; intended for ``Module.apply``.

    ``nn.Linear`` modules get Xavier-uniform weights and a constant 0.01
    bias. Every other module type is left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    # nn.Linear(..., bias=False) has `bias is None`; skip it instead of
    # raising AttributeError.
    if m.bias is not None:
        m.bias.data.fill_(0.01)
def _flatten_state_dict(model):
    """Return every ``state_dict`` entry of *model* concatenated into one float32 vector on ``device``."""
    return torch.cat([
        entry.reshape(-1).to(device=device, dtype=torch.float32)
        for entry in model.state_dict().values()
    ])


# Fresh global model; `model_received` is the server's flattened copy of its state.
fed_model = cnn().to(device)
fed_model.apply(init_weights)
model_received = _flatten_state_dict(fed_model)

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

# NOTE(review): `<=` makes this run nepochs + 1 rounds (0..nepochs). Left
# unchanged so the saved curves keep their recorded length.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    # No-attack baseline: only workers with index >= nbyz take part.
    round_clients = np.arange(nbyz, num_workers)
    round_benign = round_clients
    benign_norm = 0
    client_updates = []

    for i in round_benign:
        # Each client starts from the current global model.
        model = copy.deepcopy(fed_model)
        # Use this cell's `local_lr`, not a global `lr` left over from an
        # earlier cell.
        optimizer = optim.SGD(model.parameters(), lr=local_lr, momentum=0.9, weight_decay=1e-4)
        for epoch in range(local_epochs):
            train_loss, train_acc = train(
                each_worker_data[i], torch.Tensor(each_worker_label[i]).long(), model, optimizer, batch_size)

        # Client update = trained state minus the state it was sent.
        update = _flatten_state_dict(model) - model_received
        benign_norm += torch.norm(update)/len(round_benign)
        client_updates.append(update)

    # Stack once instead of re-concatenating the whole matrix per client.
    user_updates = torch.stack(client_updates)
    del client_updates
    # Coordinate-wise median across clients is the aggregate update.
    agg_update = torch.median(user_updates, 0)[0]
    del user_updates

    model_received = model_received + global_lr * agg_update

    # Rebuild the global model and load the flat vector back into it, slicing
    # in state_dict order (the same order used when flattening).
    fed_model = cnn().to(device)
    fed_model.apply(init_weights)
    state_dict = {}
    start_idx = 0
    for name, entry in fed_model.state_dict().items():
        end_idx = start_idx + entry.numel()
        state_dict[name] = model_received[start_idx:end_idx].reshape(entry.shape)
        start_idx = end_idx
    fed_model.load_state_dict(state_dict)

    # `criterion` is not defined in this cell; it must exist from an earlier one.
    val_loss, val_acc = test(global_test_data, global_test_label.long(), fed_model, criterion, use_cuda)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    best_accs_per_round.append(best_global_acc)
    accs_per_round.append(val_acc)
    loss_per_round.append(val_loss)
    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, val_loss, val_acc, best_global_acc))
    epoch_num+=1

# Score the final global model on every worker's local test split.
final_accs_per_client = []
for i in range(num_workers):
    client_loss, client_acc = test(each_worker_te_data[i], each_worker_te_label[i].long(),
                                   fed_model, criterion, use_cuda)
    final_accs_per_client.append(client_acc)
results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
    loss_per_round=np.array(loss_per_round)
)
# Context manager guarantees the results file is flushed and closed.
results_path = os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_fashion_fast_median.pkl')
with open(results_path, 'wb') as results_file:
    pickle.dump(results, results_file)

e 0 val loss 1.726 val acc 59.653 best val_acc 59.653
e 10 val loss 0.451 val acc 83.352 best val_acc 83.352
e 20 val loss 0.369 val acc 86.677 best val_acc 86.677


# Median under full trim attack

In [None]:
# Release unused cached CUDA memory before starting this run.
torch.cuda.empty_cache()
# Forwarded to test() further down in this cell.
use_cuda = torch.cuda.is_available()

# Passes of train() each client runs per round.
local_epochs = 2
# Mini-batch size handed to train().
batch_size = 16
# Total number of workers; every one of them is trained each round in this cell.
num_workers = 100
# Learning rate meant for client-side SGD.
# NOTE(review): confirm the round loop's optimizer actually reads this name.
local_lr = 0.01
# Server-side step size applied to the aggregated update.
global_lr = 1
# Round-count bound for the training loop.
nepochs = 50
# Count passed to full_trim() below.
nbyz = 20

# Best validation accuracy so far and the current round index.
best_global_acc=0
epoch_num = 0
def init_weights(m):
    """Initialise a linear layer in place; intended for ``Module.apply``.

    ``nn.Linear`` modules get Xavier-uniform weights and a constant 0.01
    bias. Every other module type is left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    # nn.Linear(..., bias=False) has `bias is None`; skip it instead of
    # raising AttributeError.
    if m.bias is not None:
        m.bias.data.fill_(0.01)
def _flatten_state_dict(model):
    """Return every ``state_dict`` entry of *model* concatenated into one float32 vector on ``device``."""
    return torch.cat([
        entry.reshape(-1).to(device=device, dtype=torch.float32)
        for entry in model.state_dict().values()
    ])


# Fresh global model; `model_received` is the server's flattened copy of its state.
fed_model = cnn().to(device)
fed_model.apply(init_weights)
model_received = _flatten_state_dict(fed_model)

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

# NOTE(review): `<=` makes this run nepochs + 1 rounds (0..nepochs). Left
# unchanged so the saved curves keep their recorded length.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    # Every worker trains honestly first; the attack is applied afterwards.
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    benign_norm = 0
    client_updates = []

    for i in round_benign:
        # Each client starts from the current global model.
        model = copy.deepcopy(fed_model)
        # Use this cell's `local_lr`, not a global `lr` left over from an
        # earlier cell.
        optimizer = optim.SGD(model.parameters(), lr=local_lr, momentum=0.9, weight_decay=1e-4)
        for epoch in range(local_epochs):
            train_loss, train_acc = train(
                each_worker_data[i], torch.Tensor(each_worker_label[i]).long(), model, optimizer, batch_size)

        # Client update = trained state minus the state it was sent.
        update = _flatten_state_dict(model) - model_received
        benign_norm += torch.norm(update)/len(round_benign)
        client_updates.append(update)

    # Stack once instead of re-concatenating the whole matrix per client.
    user_updates = torch.stack(client_updates)
    del client_updates

    # `full_trim` is defined outside this cell; presumably it crafts `nbyz`
    # malicious rows -- verify against its definition.
    user_updates = full_trim(user_updates, nbyz)
    # Coordinate-wise median across the rows is the aggregate update.
    agg_update = torch.median(user_updates, 0)[0]
    del user_updates

    model_received = model_received + global_lr * agg_update

    # Rebuild the global model and load the flat vector back into it, slicing
    # in state_dict order (the same order used when flattening).
    fed_model = cnn().to(device)
    fed_model.apply(init_weights)
    state_dict = {}
    start_idx = 0
    for name, entry in fed_model.state_dict().items():
        end_idx = start_idx + entry.numel()
        state_dict[name] = model_received[start_idx:end_idx].reshape(entry.shape)
        start_idx = end_idx
    fed_model.load_state_dict(state_dict)

    # `criterion` is not defined in this cell; it must exist from an earlier one.
    val_loss, val_acc = test(global_test_data, global_test_label.long(), fed_model, criterion, use_cuda)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    best_accs_per_round.append(best_global_acc)
    accs_per_round.append(val_acc)
    loss_per_round.append(val_loss)
    if epoch_num%10==0 or epoch_num==nepochs-1:
        print('e %d val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, val_loss, val_acc, best_global_acc))
    epoch_num+=1

# Score the final global model on every worker's local test split.
final_accs_per_client = []
for i in range(num_workers):
    client_loss, client_acc = test(each_worker_te_data[i], each_worker_te_label[i].long(),
                                   fed_model, criterion, use_cuda)
    final_accs_per_client.append(client_acc)
results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
    loss_per_round=np.array(loss_per_round)
)
# Context manager guarantees the results file is flushed and closed.
results_path = os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_fashion_fast_median_trim20.pkl')
with open(results_path, 'wb') as results_file:
    pickle.dump(results, results_file)