# Trim attack on FedAvg + Trimmed-mean for Fashion-MNIST


In [1]:
# Widen the notebook's content area to 90% of the browser window.
# `IPython.display` is the supported import location; importing from
# `IPython.core.display` triggers the deprecation warning seen in this cell's output.
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

  from IPython.core.display import display, HTML


In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
import copy
import time
from functools import reduce
from torchsummary import summary

import os
import sys
import pickle
sys.path.insert(0,'./utils/')
from logger import *
from eval import *
from misc import *

from sklearn.metrics import roc_auc_score
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict

from SGD import *
# Prefer the default CUDA device when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

  warn(f"Failed to load image Python extension: {e}")


cuda


In [3]:
# Fashion-MNIST train and test splits (downloaded into ./data if missing),
# converted to tensors with ToTensor and no further normalisation.
transform = transforms.Compose([transforms.ToTensor()])
_fmnist_kwargs = dict(root='./data', download=True, transform=transform)
trainset = torchvision.datasets.FashionMNIST(train=True, **_fmnist_kwargs)
testset = torchvision.datasets.FashionMNIST(train=False, **_fmnist_kwargs)

In [5]:
def sample_dirichlet_train_data(trainset, no_participants, alpha=0.9, force=False):
    """Split a labelled dataset among participants with a per-class Dirichlet draw.

    For every class (visited in sorted label order), the class's sample indices are
    shuffled. A proportion vector over participants is drawn from
    Dirichlet(alpha, ..., alpha), and consecutive chunks of the shuffled indices are
    handed out in those proportions. Smaller ``alpha`` gives more skewed (non-IID)
    splits. Per-participant counts are rounded, so a few samples of a class may stay
    unassigned, and later participants may get less than their share once a class runs
    out.

    The result is cached in ``./dirichlet_a_<alpha>_nusers_<n>.pkl``, with alpha printed
    to one decimal. It is reloaded on later calls unless ``force`` is true.
    NOTE: the cache name does not include the dataset, so different datasets with the
    same alpha and participant count share one file.

    Side effect: when generating, NumPy's global RNG is reseeded with 0. Shuffling
    uses that same RNG, so regeneration is reproducible.

    Args:
        trainset: iterable of ``(sample, label)`` pairs; labels must be hashable and
            sortable (ints for the datasets used here).
        no_participants: number of participants to split among.
        alpha: Dirichlet concentration parameter.
        force: regenerate even if a cache file exists.

    Returns:
        defaultdict(list) mapping participant id -> list of dataset indices.
    """
    cache_file = './dirichlet_a_%.1f_nusers_%d.pkl' % (alpha, no_participants)
    if os.path.exists(cache_file) and not force:
        with open(cache_file, 'rb') as fh:
            return pickle.load(fh)

    print('generating participant indices for alpha %.1f' % alpha)
    np.random.seed(0)

    # Group dataset indices by label.
    class_indices = defaultdict(list)
    for ind, (_, label) in enumerate(trainset):
        class_indices[label].append(ind)

    per_participant_list = defaultdict(list)
    # Iterate the labels actually present, so non-contiguous labels do not raise KeyError.
    for label in sorted(class_indices):
        remaining = class_indices[label]
        # Use the NumPy RNG seeded above; the stdlib `random` module is not seeded here.
        np.random.shuffle(remaining)
        sampled_probabilities = len(remaining) * np.random.dirichlet(
            np.array(no_participants * [alpha]))
        for user in range(no_participants):
            no_imgs = int(round(sampled_probabilities[user]))
            # Slicing past the end simply yields what is left.
            per_participant_list[user].extend(remaining[:no_imgs])
            remaining = remaining[no_imgs:]

    with open(cache_file, 'wb') as fh:
        pickle.dump(per_participant_list, fh)
    return per_participant_list

In [6]:
def get_fang_train_data(trainset, num_workers=100, bias=0.5, force=False):
    """Assign each sample to a worker using the biased group scheme.

    Workers are split into 10 consecutive groups of ``num_workers / 10``. A sample with
    label ``y`` goes to group ``y`` with probability ``bias``. Otherwise it goes to one of
    the other 9 groups, each with probability ``(1 - bias) / 9``. A worker inside the
    chosen group is then picked uniformly. Labels are assumed to be integers 0..9, and
    ``num_workers`` is intended to be a multiple of 10.

    Randomness comes from NumPy's global RNG, which this function does not seed.

    The mapping is cached in ``fang_nworkers<N>_bias<b>.pkl`` in the current directory,
    with bias printed to one decimal. It is reused unless ``force`` is true.

    Args:
        trainset: iterable of ``(sample, label)`` pairs.
        num_workers: total number of workers.
        bias: probability that a sample lands in its own label's group.
        force: regenerate even if the cache file exists.

    Returns:
        defaultdict(list) mapping worker id -> list of dataset indices.
    """
    dist_file = 'fang_nworkers%d_bias%.1f.pkl' % (num_workers, bias)
    if not force and os.path.exists(dist_file):
        print('Loading fang distribution for num_workers %d and bias %.1f from disk' % (num_workers, bias))
        with open(dist_file, 'rb') as fh:
            return pickle.load(fh)

    other_group_size = (1 - bias) / 9.
    worker_per_group = num_workers / 10
    per_participant_list = defaultdict(list)
    for i, (_, y) in enumerate(trainset):
        # [lower_bound, upper_bound] (width `bias`) maps to group y. The probability mass
        # below and above it is cut into slices of `other_group_size`, one per other group.
        upper_bound = (y) * (1 - bias) / 9. + bias
        lower_bound = (y) * (1 - bias) / 9.
        rd = np.random.random_sample()
        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y
        # Uniform choice of worker within the selected group.
        rd = np.random.random_sample()
        selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))
        per_participant_list[selected_worker].append(i)

    print('Saving fang distribution for num_workers %d and bias %.1f to disk' % (num_workers, bias))
    with open(dist_file, 'wb') as fh:
        pickle.dump(per_participant_list, fh)
    return per_participant_list

In [7]:
def _gather_split(dataset, indices):
    """Return ``(inputs, labels)`` tensors for ``dataset[idx]`` over ``indices``.

    Inputs are stacked along a new leading dim and labels become an int64 tensor. An
    empty ``indices`` gives zero-length tensors (input shape taken from ``dataset[0]``,
    assumed to be a tensor), because ``torch.stack`` rejects an empty list.
    """
    if len(indices) == 0:
        sample = dataset[0][0]
        return sample.new_empty((0,) + tuple(sample.shape)), torch.empty(0, dtype=torch.long)
    samples = [dataset[idx] for idx in indices]  # index each sample once (transform runs once)
    inputs = torch.stack([x for x, _ in samples])
    labels = torch.tensor([y for _, y in samples], dtype=torch.long)
    return inputs, labels


def get_federated_data(trainset, num_workers, distribution='fang', param=1, force=False):
    """Partition ``trainset`` among workers and split each worker's share 5:1:1.

    Each worker's indices are shuffled (NumPy global RNG, reseeded with 0 here). The
    first ``int(5n/7)`` become that worker's training set, the next ``int(n/7)`` its
    validation set, and the rest its test set. The per-worker validation and test sets
    are also concatenated into global ones.

    Args:
        trainset: indexable dataset of ``(tensor, int label)`` pairs.
        num_workers: number of workers.
        distribution: ``'fang'`` (uses ``get_fang_train_data`` with ``bias=param``) or
            ``'dirichlet'`` (uses ``sample_dirichlet_train_data`` with ``alpha=param``).
        param: parameter forwarded to the chosen partitioner.
        force: regenerate the partition instead of loading its cache file.

    Returns:
        Tuple ``(each_worker_data, each_worker_label, each_worker_val_data,
        each_worker_val_label, each_worker_te_data, each_worker_te_label,
        global_val_data, global_val_label, global_te_data, global_te_label)``.
        The ``each_*`` entries are lists with one tensor per worker.

    Raises:
        ValueError: if ``distribution`` is not recognised.
    """
    if distribution == 'fang':
        per_participant_list = get_fang_train_data(trainset, num_workers, bias=param, force=force)
    elif distribution == 'dirichlet':
        per_participant_list = sample_dirichlet_train_data(trainset, num_workers, alpha=param, force=force)
    else:
        raise ValueError("distribution must be 'fang' or 'dirichlet', got %r" % (distribution,))

    each_worker_data, each_worker_label = [], []
    each_worker_val_data, each_worker_val_label = [], []
    each_worker_te_data, each_worker_te_label = [], []

    np.random.seed(0)
    # Visit every worker id, including ones that received no samples (they get empty splits).
    for worker_idx in range(num_workers):
        w_indices = np.array(per_participant_list.get(worker_idx, []), dtype=np.int64)
        w_len = len(w_indices)
        len_tr = int(5 * w_len / 7)
        len_val = int(w_len / 7)
        np.random.shuffle(w_indices)

        data, labels = _gather_split(trainset, w_indices[:len_tr])
        each_worker_data.append(data)
        each_worker_label.append(labels)

        data, labels = _gather_split(trainset, w_indices[len_tr:len_tr + len_val])
        each_worker_val_data.append(data)
        each_worker_val_label.append(labels)

        data, labels = _gather_split(trainset, w_indices[len_tr + len_val:])
        each_worker_te_data.append(data)
        each_worker_te_label.append(labels)

    global_val_data = torch.cat(each_worker_val_data)
    global_val_label = torch.cat(each_worker_val_label)

    global_te_data = torch.cat(each_worker_te_data)
    global_te_label = torch.cat(each_worker_te_label)

    return (each_worker_data, each_worker_label, each_worker_val_data, each_worker_val_label,
            each_worker_te_data, each_worker_te_label, global_val_data, global_val_label,
            global_te_data, global_te_label)

In [8]:
def full_trim(v, f, b=2):
    '''
    Full-knowledge Trim attack. w.l.o.g., the first f rows (worker devices) are compromised.

    For every coordinate j, the per-coordinate maximum/minimum and the sign of the sum
    are taken over ALL rows of ``v`` (benign and compromised). The crafted value pushes
    against that sign:
      - sign > 0: start from min_j; scale it by 1/b if min_j > 0, by b if min_j < 0.
      - sign < 0: start from max_j; scale it by 1/b if max_j < 0, by b if max_j > 0.
      - sign == 0, or a zero extreme: the crafted value is 0.
    All f compromised rows receive the same crafted vector.

    v: 2-D tensor of flattened updates, one row per worker. It is modified IN PLACE.
    f: number of compromised workers (rows 0..f-1 are overwritten); f <= 0 is a no-op.
    b: scaling factor (default 2, the value this notebook has always used).

    Returns ``v``. Raises IndexError if f exceeds the number of rows, before any row is
    modified.
    '''
    if f > v.shape[0]:
        raise IndexError('cannot overwrite %d updates: only %d rows available' % (f, v.shape[0]))
    if f <= 0:
        return v

    maximum_dim = torch.max(v, dim=0).values
    minimum_dim = torch.min(v, dim=0).values
    direction = torch.sign(torch.sum(v, dim=0))
    directed_dim = (direction > 0) * minimum_dim + (direction < 0) * maximum_dim

    # The crafted vector does not depend on which row is being replaced, so compute it once.
    same_sign = direction * directed_dim > 0
    opposite_sign = direction * directed_dim < 0
    crafted = directed_dim * (same_sign / b + opposite_sign * b)

    v[:f] = crafted  # broadcasts the (d,) vector over the first f rows
    return v

In [9]:
def tr_mean(all_updates, n_attackers):
    """Coordinate-wise trimmed mean of worker updates.

    Each coordinate (column) is sorted across workers (rows). The ``n_attackers``
    smallest and ``n_attackers`` largest values are dropped, and the rest are averaged.
    ``n_attackers == 0`` gives the plain mean.

    Args:
        all_updates: 2-D tensor with one row per worker update.
        n_attackers: number of values to trim from each end of every coordinate.

    Returns:
        1-D tensor with one aggregated value per coordinate.

    Raises:
        ValueError: if ``n_attackers`` is negative or would trim away every row. In
            that case the mean would be taken over an empty slice and come out as NaN.
    """
    n_updates = all_updates.shape[0]
    if n_attackers < 0 or 2 * n_attackers >= n_updates:
        raise ValueError('cannot trim %d from each end of %d updates' % (n_attackers, n_updates))
    sorted_updates = torch.sort(all_updates, dim=0).values
    return torch.mean(sorted_updates[n_attackers:n_updates - n_attackers], dim=0)

In [10]:
class cnn(nn.Module):
    """Two-conv / two-FC classifier for single-channel 28x28 images (Fashion-MNIST).

    conv(1->30, 3x3) -> ReLU -> maxpool 2 -> conv(30->50, 3x3) -> ReLU -> maxpool 2
    -> flatten (50 * 5 * 5 = 1250 for a 28x28 input) -> fc(1250->200) -> ReLU
    -> fc(200->10), returning raw logits.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in this exact order. That keeps parameter initialisation
        # (RNG consumption) and state_dict key order the same as before.
        self.conv1 = nn.Conv2d(1, 30, 3)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(30, 50, 3)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(50 * 5 * 5, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        features = self.pool1(F.relu(self.conv1(x)))
        features = self.pool2(F.relu(self.conv2(features)))
        hidden = F.relu(self.fc1(torch.flatten(features, start_dim=1)))
        return self.fc2(hidden)
# Build one model instance and display its total parameter count.
model = cnn()
n_params = sum(tensor.numel() for tensor in model.parameters())
n_params

266060

In [11]:
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def train(train_data, labels, model, optimizer, batch_size=20):
    """Train ``model`` for one pass over ``train_data`` in shuffled mini-batches.

    The sample order is permuted with NumPy's global RNG. The data is then consumed in
    consecutive slices of ``batch_size``; the last slice may be shorter. Each slice is
    moved to whatever device holds ``model``'s parameters, rather than assuming CUDA,
    so CPU-only runs work too.

    Args:
        train_data: input tensor indexed along dim 0, one row per sample.
        labels: class-index tensor aligned with ``train_data``.
        model: network, updated in place.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        batch_size: samples per step.

    Returns:
        ``(avg_loss, avg_top1)`` over the pass, weighted by batch size. Both are
        accumulated with the project's ``AverageMeter``; top-1 comes from its
        ``accuracy`` helper.
    """
    model.train()
    model_device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss()
    losses = AverageMeter()
    top1 = AverageMeter()

    order = np.arange(len(train_data))
    np.random.shuffle(order)
    train_data = train_data[order]
    labels = labels[order]

    for start in range(0, len(train_data), batch_size):
        inputs = train_data[start:start + batch_size].to(model_device)
        targets = labels[start:start + batch_size].to(model_device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        # topk=(1, 5) is kept as before; only top-1 is recorded.
        prec1, _prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return (losses.avg, top1.avg)


def test(test_data, labels, model, criterion, use_cuda, debug_='MEDIUM', batch_size=64):
    """Evaluate ``model`` on an in-memory dataset without gradient tracking.

    The data is walked in order, in slices of ``batch_size``; the last may be shorter.
    Each slice is moved to the device holding ``model``'s parameters.

    Args:
        test_data: input tensor indexed along dim 0.
        labels: class-index tensor aligned with ``test_data``.
        model: network to evaluate; it is switched to eval mode.
        criterion: loss function applied to ``(outputs, targets)``.
        use_cuda: unused. Kept so existing callers keep working; the model's own
            device decides placement.
        debug_: unused; kept for call compatibility.
        batch_size: samples per evaluation step.

    Returns:
        ``(avg_loss, avg_top1)`` over the data, weighted by batch size, as accumulated
        by the project's ``AverageMeter``/``accuracy`` helpers.
    """
    losses = AverageMeter()
    top1 = AverageMeter()
    model.eval()
    model_device = next(model.parameters()).device
    with torch.no_grad():
        for start in range(0, len(test_data), batch_size):
            inputs = test_data[start:start + batch_size].to(model_device)
            targets = labels[start:start + batch_size].to(model_device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            prec1, _prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
    return (losses.avg, top1.avg)


In [12]:
# Pool the official train and test splits, then partition the pool among the workers
# (each worker's share is further split into its own train/val/test parts).
all_data = torch.utils.data.ConcatDataset((trainset, testset))
num_workers = 100
distribution = 'fang'
param = .5
force = False

(each_worker_data, each_worker_label,
 each_worker_val_data, each_worker_val_label,
 each_worker_te_data, each_worker_te_label,
 global_val_data, global_val_label,
 global_te_data, global_te_label) = get_federated_data(
    all_data,
    num_workers=num_workers,
    distribution=distribution,
    param=param,
    force=force,
)

Loading fang distribution for num_workers 100 and bias 0.5 from memory


In [13]:
# Sizes of the pooled (test, validation) splits across all workers.
tuple(map(len, (global_te_label, global_val_label)))

(8664, 8526)

In [19]:
# Baseline run: federated training with trimmed-mean aggregation and no attackers
# (nbyz = 0). Each round, clients max_nbyz - nbyz .. num_workers - 1 train locally
# from the current global model. Their weight deltas are aggregated with tr_mean and
# applied with global_lr.
import collections  # needed for collections.OrderedDict below; only defaultdict was imported earlier

torch.cuda.empty_cache()
use_cuda = torch.cuda.is_available()
criterion = nn.CrossEntropyLoss()

local_epochs = 2
batch_size = 16
num_workers = 100
local_lr = 0.01
global_lr = 1
nepochs = 50
nbyz_ = [0]
max_nbyz = 20

def init_weights(m):
    # Linear layers get Xavier-uniform weights and a constant 0.01 bias; other layers keep their defaults.
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

for nbyz in nbyz_:
    best_global_acc = 0
    best_global_te_acc = 0  # test accuracy at the round with the best validation accuracy
    epoch_num = 0

    fed_model = cnn().to(device)
    fed_model.apply(init_weights)
    # Global model flattened into one float vector, in state_dict order.
    model_received = torch.cat([t.reshape(-1).float() for t in fed_model.state_dict().values()])

    best_accs_per_round, best_te_accs_per_round = [], []
    accs_per_round, te_accs_per_round = [], []
    loss_per_round = []
    home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'  # NOTE(review): machine-specific output root

    # `<=` runs nepochs + 1 rounds (0..nepochs).
    while epoch_num <= nepochs:
        torch.cuda.empty_cache()
        round_clients = np.arange(max_nbyz - nbyz, num_workers)

        client_updates = []
        for client in round_clients:
            model = copy.deepcopy(fed_model)
            optimizer = optim.SGD(model.parameters(), lr=local_lr, momentum=0.9, weight_decay=1e-4)
            for epoch in range(local_epochs):
                train_loss, train_acc = train(
                    each_worker_data[client], each_worker_label[client].long(), model, optimizer, batch_size)
            local_params = torch.cat([t.reshape(-1).float() for t in model.state_dict().values()])
            client_updates.append(local_params - model_received)
        # Stack once, instead of re-concatenating a growing tensor per client (quadratic copying).
        user_updates = torch.stack(client_updates)
        del client_updates

        user_updates = full_trim(user_updates, nbyz)  # no rows are replaced when nbyz == 0
        agg_update = tr_mean(user_updates, nbyz)      # plain mean when nbyz == 0
        del user_updates

        model_received = model_received + global_lr * agg_update
        # A fresh cnn() is built each round and then overwritten from the flat vector.
        # Building it consumes torch RNG state, which later initialisations depend on.
        fed_model = cnn().to(device)
        new_state = {}
        offset = 0
        for name, tensor in fed_model.state_dict().items():
            numel = tensor.numel()
            new_state[name] = model_received[offset:offset + numel].reshape(tensor.shape)
            offset += numel
        fed_model.load_state_dict(new_state)

        val_loss, val_acc = test(global_val_data, global_val_label.long(), fed_model, criterion, use_cuda)
        te_loss, te_acc = test(global_te_data, global_te_label.long(), fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc
        best_global_acc = max(best_global_acc, val_acc)
        if is_best:
            best_global_te_acc = te_acc

        best_accs_per_round.append(best_global_acc)
        best_te_accs_per_round.append(best_global_te_acc)
        accs_per_round.append(val_acc)
        te_accs_per_round.append(te_acc)
        loss_per_round.append(val_loss)

        if epoch_num % 5 == 0 or epoch_num == nepochs - 1:
            print('nat %d | e %d val loss %.3f acc %.3f | test loss %.3f acc %.3f || best val_acc %.3f'% (nbyz, epoch_num, val_loss, val_acc, te_loss, te_acc, best_global_acc))

        epoch_num += 1

    # Per-client accuracy of the final global model on each client's own test split.
    # NOTE(review): only clients max_nbyz..num_workers-1 here, while the attack runs evaluate all clients.
    final_accs_per_client = []
    for client in range(max_nbyz, num_workers):
        client_loss, client_acc = test(each_worker_te_data[client], each_worker_te_label[client].long(),
                                       fed_model, criterion, use_cuda)
        final_accs_per_client.append(client_acc)
    results = collections.OrderedDict(
        final_accs_per_client=np.array(final_accs_per_client),
        accs_per_round=np.array(accs_per_round),
        best_accs_per_round=np.array(best_accs_per_round),
        loss_per_round=np.array(loss_per_round)
    )
    out_dir = os.path.join(home_dir, 'FLDetector_plots_data')
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'fashion100_personalized_eval_noattack.pkl'), 'wb') as fh:
        pickle.dump(results, fh)

nat 0 | e 0 val loss 1.093 acc 64.004 | test loss 1.075 acc 65.512 || best val_acc 64.004
nat 0 | e 5 val loss 0.556 acc 78.489 | test loss 0.545 acc 78.659 || best val_acc 78.489
nat 0 | e 10 val loss 0.486 acc 81.316 | test loss 0.474 acc 81.798 || best val_acc 81.316
nat 0 | e 15 val loss 0.445 acc 83.533 | test loss 0.431 acc 84.245 || best val_acc 83.533
nat 0 | e 20 val loss 0.415 acc 84.987 | test loss 0.399 acc 85.596 || best val_acc 84.987
nat 0 | e 25 val loss 0.394 acc 85.820 | test loss 0.376 acc 86.507 || best val_acc 85.820
nat 0 | e 30 val loss 0.379 acc 86.395 | test loss 0.360 acc 87.211 || best val_acc 86.406
nat 0 | e 35 val loss 0.362 acc 87.098 | test loss 0.343 acc 87.604 || best val_acc 87.204
nat 0 | e 40 val loss 0.351 acc 87.650 | test loss 0.333 acc 87.996 || best val_acc 87.650
nat 0 | e 45 val loss 0.341 acc 87.990 | test loss 0.323 acc 88.250 || best val_acc 87.990
nat 0 | e 49 val loss 0.335 acc 88.177 | test loss 0.316 acc 88.643 || best val_acc 88.177
n

In [22]:
# Evaluate the current global model on each client's own test split (clients
# max_nbyz..num_workers-1) and pickle the per-round curves from the preceding run.
# This appears to redo the no-attack cell's final step, presumably because that cell
# used `collections` before it was imported.
import collections

final_accs_per_client = []
for i in range(max_nbyz, num_workers):
    client_loss, client_acc = test(each_worker_te_data[i], each_worker_te_label[i].long(),
                                   fed_model, criterion, use_cuda)
    final_accs_per_client.append(client_acc)
results = collections.OrderedDict(
    final_accs_per_client=np.array(final_accs_per_client),
    accs_per_round=np.array(accs_per_round),
    best_accs_per_round=np.array(best_accs_per_round),
    loss_per_round=np.array(loss_per_round)
)
out_dir = os.path.join(home_dir, 'FLDetector_plots_data')
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, 'fashion100_personalized_eval_noattack.pkl'), 'wb') as fh:
    pickle.dump(results, fh)

In [27]:
# Full-knowledge Trim attack. Each round, clients max_nbyz - nbyz .. num_workers - 1
# participate, and the first nbyz of them are compromised. full_trim crafts their
# updates from ALL participants' updates, and the server aggregates with
# tr_mean(..., nbyz). The number of benign participants stays num_workers - max_nbyz
# for every nbyz.
import collections  # needed for collections.OrderedDict below

torch.cuda.empty_cache()
use_cuda = torch.cuda.is_available()
criterion = nn.CrossEntropyLoss()

local_epochs = 2
batch_size = 16
num_workers = 100
local_lr = 0.01
global_lr = 1
nepochs = 50
nbyz_ = [1, 5, 10, 15, 20]
max_nbyz = 20

def init_weights(m):
    # Linear layers get Xavier-uniform weights and a constant 0.01 bias; other layers keep their defaults.
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

for nbyz in nbyz_:
    best_global_acc = 0
    best_global_te_acc = 0  # test accuracy at the round with the best validation accuracy
    epoch_num = 0

    fed_model = cnn().to(device)
    fed_model.apply(init_weights)
    # Global model flattened into one float vector, in state_dict order.
    model_received = torch.cat([t.reshape(-1).float() for t in fed_model.state_dict().values()])

    best_accs_per_round, best_te_accs_per_round = [], []
    accs_per_round, te_accs_per_round = [], []
    loss_per_round = []
    home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'  # NOTE(review): machine-specific output root

    # `<=` runs nepochs + 1 rounds (0..nepochs).
    while epoch_num <= nepochs:
        torch.cuda.empty_cache()
        round_clients = np.arange(max_nbyz - nbyz, num_workers)

        client_updates = []
        for client in round_clients:
            model = copy.deepcopy(fed_model)
            optimizer = optim.SGD(model.parameters(), lr=local_lr, momentum=0.9, weight_decay=1e-4)
            for epoch in range(local_epochs):
                train_loss, train_acc = train(
                    each_worker_data[client], each_worker_label[client].long(), model, optimizer, batch_size)
            local_params = torch.cat([t.reshape(-1).float() for t in model.state_dict().values()])
            client_updates.append(local_params - model_received)
        # Stack once, instead of re-concatenating a growing tensor per client (quadratic copying).
        user_updates = torch.stack(client_updates)
        del client_updates

        # Full knowledge: the crafted rows are derived from every participant's update.
        user_updates = full_trim(user_updates, nbyz)
        agg_update = tr_mean(user_updates, nbyz)
        del user_updates

        model_received = model_received + global_lr * agg_update
        # A fresh cnn() is built each round and then overwritten from the flat vector.
        # Building it consumes torch RNG state, which later initialisations depend on.
        fed_model = cnn().to(device)
        new_state = {}
        offset = 0
        for name, tensor in fed_model.state_dict().items():
            numel = tensor.numel()
            new_state[name] = model_received[offset:offset + numel].reshape(tensor.shape)
            offset += numel
        fed_model.load_state_dict(new_state)

        val_loss, val_acc = test(global_val_data, global_val_label.long(), fed_model, criterion, use_cuda)
        te_loss, te_acc = test(global_te_data, global_te_label.long(), fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc
        best_global_acc = max(best_global_acc, val_acc)
        if is_best:
            best_global_te_acc = te_acc

        best_accs_per_round.append(best_global_acc)
        best_te_accs_per_round.append(best_global_te_acc)
        accs_per_round.append(val_acc)
        te_accs_per_round.append(te_acc)
        loss_per_round.append(val_loss)

        if epoch_num % 10 == 0 or epoch_num == nepochs - 1:
            print('nat %d | e %d val loss %.3f acc %.3f | test loss %.3f acc %.3f || best val_acc %.3f'% (nbyz, epoch_num, val_loss, val_acc, te_loss, te_acc, best_global_acc))

        epoch_num += 1

    # Per-client accuracy of the final global model on every client's own test split.
    final_accs_per_client = []
    for client in range(num_workers):
        client_loss, client_acc = test(each_worker_te_data[client], each_worker_te_label[client].long(),
                                       fed_model, criterion, use_cuda)
        final_accs_per_client.append(client_acc)
    results = collections.OrderedDict(
        final_accs_per_client=np.array(final_accs_per_client),
        accs_per_round=np.array(accs_per_round),
        best_accs_per_round=np.array(best_accs_per_round),
        loss_per_round=np.array(loss_per_round)
    )
    print('---> len of final_accs_per_client: %d \n'% len(final_accs_per_client))
    out_dir = os.path.join(home_dir, 'FLDetector_plots_data')
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'fashion100_personalized_eval_full_knowledge_trim%d.pkl'%(nbyz)), 'wb') as fh:
        pickle.dump(results, fh)


nat 1 | e 0 val loss 1.048 acc 65.705 | test loss 1.032 acc 67.417 || best val_acc 65.705
nat 1 | e 10 val loss 0.476 acc 81.703 | test loss 0.464 acc 82.214 || best val_acc 81.703
nat 1 | e 20 val loss 0.409 acc 85.057 | test loss 0.394 acc 85.803 || best val_acc 85.057
nat 1 | e 30 val loss 0.372 acc 86.887 | test loss 0.358 acc 87.281 || best val_acc 86.887
nat 1 | e 40 val loss 0.350 acc 87.661 | test loss 0.337 acc 87.950 || best val_acc 87.661
nat 1 | e 49 val loss 0.334 acc 88.201 | test loss 0.323 acc 88.573 || best val_acc 88.201
nat 1 | e 50 val loss 0.333 acc 88.271 | test loss 0.321 acc 88.723 || best val_acc 88.271
---> len of final_accs_per_client: 100 

nat 5 | e 0 val loss 1.132 acc 67.488 | test loss 1.117 acc 69.090 || best val_acc 67.488
nat 5 | e 10 val loss 0.508 acc 80.577 | test loss 0.493 acc 81.221 || best val_acc 80.577
nat 5 | e 20 val loss 0.451 acc 83.005 | test loss 0.435 acc 84.014 || best val_acc 83.028
nat 5 | e 30 val loss 0.416 acc 84.577 | test loss 

In [26]:
# Partial-knowledge Trim attack. Each round, clients max_nbyz - nbyz .. num_workers - 1
# participate, and the first nbyz of them are compromised. full_trim sees ONLY the
# compromised clients' own updates when crafting replacements. The server aggregates
# with tr_mean(..., nbyz).
import collections  # needed for collections.OrderedDict below

torch.cuda.empty_cache()
use_cuda = torch.cuda.is_available()
criterion = nn.CrossEntropyLoss()

local_epochs = 2
batch_size = 16
num_workers = 100
local_lr = 0.01
global_lr = 1
nepochs = 50
nbyz_ = [1, 5, 10, 15, 20]
max_nbyz = 20

def init_weights(m):
    # Linear layers get Xavier-uniform weights and a constant 0.01 bias; other layers keep their defaults.
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

for nbyz in nbyz_:
    best_global_acc = 0
    best_global_te_acc = 0  # test accuracy at the round with the best validation accuracy
    epoch_num = 0

    fed_model = cnn().to(device)
    fed_model.apply(init_weights)
    # Global model flattened into one float vector, in state_dict order.
    model_received = torch.cat([t.reshape(-1).float() for t in fed_model.state_dict().values()])

    best_accs_per_round, best_te_accs_per_round = [], []
    accs_per_round, te_accs_per_round = [], []
    loss_per_round = []
    home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'  # NOTE(review): machine-specific output root

    # `<=` runs nepochs + 1 rounds (0..nepochs).
    while epoch_num <= nepochs:
        torch.cuda.empty_cache()
        round_clients = np.arange(max_nbyz - nbyz, num_workers)

        client_updates = []
        for client in round_clients:
            model = copy.deepcopy(fed_model)
            optimizer = optim.SGD(model.parameters(), lr=local_lr, momentum=0.9, weight_decay=1e-4)
            for epoch in range(local_epochs):
                train_loss, train_acc = train(
                    each_worker_data[client], each_worker_label[client].long(), model, optimizer, batch_size)
            local_params = torch.cat([t.reshape(-1).float() for t in model.state_dict().values()])
            client_updates.append(local_params - model_received)
        # Stack once, instead of re-concatenating a growing tensor per client (quadratic copying).
        user_updates = torch.stack(client_updates)
        del client_updates

        # Partial knowledge: craft replacements from the compromised rows only.
        user_updates[:nbyz] = full_trim(user_updates[:nbyz], nbyz)
        agg_update = tr_mean(user_updates, nbyz)

        del user_updates
        model_received = model_received + global_lr * agg_update
        # A fresh cnn() is built each round and then overwritten from the flat vector.
        # Building it consumes torch RNG state, which later initialisations depend on.
        fed_model = cnn().to(device)
        new_state = {}
        offset = 0
        for name, tensor in fed_model.state_dict().items():
            numel = tensor.numel()
            new_state[name] = model_received[offset:offset + numel].reshape(tensor.shape)
            offset += numel
        fed_model.load_state_dict(new_state)

        val_loss, val_acc = test(global_val_data, global_val_label.long(), fed_model, criterion, use_cuda)
        te_loss, te_acc = test(global_te_data, global_te_label.long(), fed_model, criterion, use_cuda)

        is_best = best_global_acc < val_acc
        best_global_acc = max(best_global_acc, val_acc)
        if is_best:
            best_global_te_acc = te_acc

        best_accs_per_round.append(best_global_acc)
        best_te_accs_per_round.append(best_global_te_acc)
        accs_per_round.append(val_acc)
        te_accs_per_round.append(te_acc)
        loss_per_round.append(val_loss)

        if epoch_num % 10 == 0 or epoch_num == nepochs - 1:
            print('nat %d | e %d val loss %.3f acc %.3f | test loss %.3f acc %.3f || best val_acc %.3f'% (nbyz, epoch_num, val_loss, val_acc, te_loss, te_acc, best_global_acc))

        epoch_num += 1

    # Per-client accuracy of the final global model on every client's own test split.
    final_accs_per_client = []
    for client in range(num_workers):
        client_loss, client_acc = test(each_worker_te_data[client], each_worker_te_label[client].long(),
                                       fed_model, criterion, use_cuda)
        final_accs_per_client.append(client_acc)
    results = collections.OrderedDict(
        final_accs_per_client=np.array(final_accs_per_client),
        accs_per_round=np.array(accs_per_round),
        best_accs_per_round=np.array(best_accs_per_round),
        loss_per_round=np.array(loss_per_round)
    )
    print('---> len of final_accs_per_client: %d \n'% len(final_accs_per_client))
    out_dir = os.path.join(home_dir, 'FLDetector_plots_data')
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'fashion100_personalized_eval_trim%d.pkl'%(nbyz)), 'wb') as fh:
        pickle.dump(results, fh)

nat 1 | e 0 val loss 1.082 acc 67.816 | test loss 1.067 acc 69.183 || best val_acc 67.816
nat 1 | e 10 val loss 0.477 acc 82.477 | test loss 0.464 acc 82.872 || best val_acc 82.477
nat 1 | e 20 val loss 0.406 acc 85.445 | test loss 0.390 acc 86.184 || best val_acc 85.538
nat 1 | e 30 val loss 0.368 acc 87.028 | test loss 0.353 acc 87.385 || best val_acc 87.028
nat 1 | e 40 val loss 0.344 acc 87.825 | test loss 0.330 acc 88.066 || best val_acc 87.825
nat 1 | e 49 val loss 0.331 acc 88.259 | test loss 0.316 acc 88.504 || best val_acc 88.283
nat 1 | e 50 val loss 0.329 acc 88.318 | test loss 0.315 acc 88.631 || best val_acc 88.318
---> len of final_accs_per_client: 100 

nat 5 | e 0 val loss 1.025 acc 66.561 | test loss 1.005 acc 67.313 || best val_acc 66.561
nat 5 | e 10 val loss 0.469 acc 82.524 | test loss 0.455 acc 82.906 || best val_acc 82.524
nat 5 | e 20 val loss 0.408 acc 85.233 | test loss 0.394 acc 85.849 || best val_acc 85.233
nat 5 | e 30 val loss 0.374 acc 86.254 | test loss 