# FLDetector for CIFAR10 with Fang/Dirichlet distribution


In [1]:
# `IPython.core.display` is deprecated (see the warning this cell used to emit);
# `IPython.display` is the supported public location of the same helpers.
from IPython.display import display, HTML
# Widen the notebook's main container to 90% of the browser window.
display(HTML("<style>.container { width:90% !important; }</style>"))

  from IPython.core.display import display, HTML


In [1]:
import collections
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
import copy
import time
from functools import reduce
from torchsummary import summary

import os
import sys
import pickle
sys.path.insert(0,'./utils/')
from logger import *
from eval import *
from misc import *

from sklearn.metrics import roc_auc_score
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict

from SGD import *
import copy
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [2]:
from cifar10_models import *
from cifar10_util import *

In [3]:
def sample_dirichlet_train_data(trainset, no_participants, alpha=0.9, force=False):
    """Partition `trainset` among clients with a per-class Dirichlet split.

    For every label (in sorted order) the indices of that label are shuffled,
    proportions are drawn from Dirichlet(alpha, ..., alpha) over the
    `no_participants` clients, and each client in turn takes
    round(proportion * class_size) of the remaining indices. Indices left over
    because of rounding are not assigned to anyone.

    The result is cached in ``./dirichlet_a_<alpha>_nusers_<n>.pkl`` and loaded
    from there on later calls unless `force` is True. The cache is read with
    pickle, so only trust files this function wrote.

    Args:
        trainset: iterable of ``(sample, label)`` pairs; labels must be hashable
            and sortable.
        no_participants: number of clients.
        alpha: Dirichlet concentration; smaller values give more skewed splits.
        force: regenerate even if a cached file exists.

    Returns:
        ``defaultdict(list)`` mapping client id -> list of dataset indices.
    """
    dist_file = './dirichlet_a_%.1f_nusers_%d.pkl' % (alpha, no_participants)
    if not force and os.path.exists(dist_file):
        with open(dist_file, 'rb') as fh:
            return pickle.load(fh)

    print('generating participant indices for alpha %.1f' % alpha)
    np.random.seed(0)
    # A dedicated, seeded RNG for shuffling: the global `random` module is not
    # seeded here, so using it made regenerated partitions non-reproducible.
    shuffler = random.Random(0)

    cifar_classes = defaultdict(list)
    for ind, (_, label) in enumerate(trainset):
        cifar_classes[label].append(ind)

    per_participant_list = defaultdict(list)
    # Iterate over the labels actually present (also works for non-contiguous labels).
    for label in sorted(cifar_classes):
        class_indices = cifar_classes[label]
        shuffler.shuffle(class_indices)
        sampled_counts = len(class_indices) * np.random.dirichlet(
            np.array(no_participants * [alpha]))
        for user in range(no_participants):
            no_imgs = int(round(sampled_counts[user]))
            take = min(len(class_indices), no_imgs)
            per_participant_list[user].extend(class_indices[:take])
            class_indices = class_indices[take:]

    with open(dist_file, 'wb') as fh:
        pickle.dump(per_participant_list, fh)
    return per_participant_list

In [4]:
def get_fang_train_data(trainset, num_workers=100, bias=0.5, force=False):
    """Assign every sample of `trainset` to one client using the Fang et al. non-IID scheme.

    Clients are split into 10 groups of ``num_workers / 10``. A sample with label
    `y` goes to group `y` with probability `bias` and to each of the other 9
    groups with probability ``(1 - bias) / 9``; within the chosen group the
    client is picked uniformly. Randomness comes from the global ``np.random``
    state, which this function does not seed.

    The result is cached in ``fang_nworkers<n>_bias<b>.pkl`` and reused unless
    `force` is True. The cache is read with pickle, so only trust files this
    function wrote.

    Args:
        trainset: iterable of ``(sample, label)`` pairs with integer labels 0..9.
        num_workers: total number of clients.
        bias: probability of landing in the label's own group.
        force: regenerate even if a cached file exists.

    Returns:
        ``defaultdict(list)`` mapping client id -> list of dataset indices.
    """
    dist_file = 'fang_nworkers%d_bias%.1f.pkl' % (num_workers, bias)
    if not force and os.path.exists(dist_file):
        print('Loading fang distribution for num_workers %d and bias %.1f from memory' % (num_workers, bias))
        with open(dist_file, 'rb') as fh:
            return pickle.load(fh)

    bias_weight = bias
    other_group_size = (1 - bias_weight) / 9.
    worker_per_group = num_workers / 10
    per_participant_list = defaultdict(list)
    for i, (_, y) in enumerate(trainset):
        # [lower_bound, upper_bound] is the slice of [0, 1) that maps to group y;
        # the remaining mass is split evenly among the other nine groups.
        upper_bound = y * (1 - bias_weight) / 9. + bias_weight
        lower_bound = y * (1 - bias_weight) / 9.
        rd = np.random.random_sample()
        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y
        # Floating-point rounding for rd very close to 1 could yield group 10.
        worker_group = min(worker_group, 9)
        rd = np.random.random_sample()
        selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))
        per_participant_list[selected_worker].append(i)

    print('Saving fang distribution for num_workers %d and bias %.1f to memory' % (num_workers, bias))
    with open(dist_file, 'wb') as fh:
        pickle.dump(per_participant_list, fh)
    return per_participant_list

In [5]:
def get_federated_data(trainset, num_workers, distribution='fang', param=1, force=False):
    """Partition `trainset` among clients and split each client's share into train/held-out.

    Args:
        trainset: dataset of ``(sample, label)`` pairs.
        num_workers: number of clients.
        distribution: ``'fang'`` (``param`` is the bias) or ``'dirichlet'``
            (``param`` is alpha).
        param: parameter forwarded to the chosen partitioner.
        force: regenerate the cached partition.

    Returns:
        ``(each_worker_idx, each_worker_te_idx, global_test_idx)``. The first two
        hold one int64 index array per client: a random 5/6 of the client's
        indices for training, the rest held out. ``global_test_idx`` is the
        concatenation of all held-out indices.

    Raises:
        ValueError: for an unknown `distribution`.
    """
    if distribution == 'fang':
        per_participant_list = get_fang_train_data(trainset, num_workers, bias=param, force=force)
    elif distribution == 'dirichlet':
        per_participant_list = sample_dirichlet_train_data(trainset, num_workers, alpha=param, force=force)
    else:
        raise ValueError("distribution must be 'fang' or 'dirichlet', got %r" % (distribution,))

    each_worker_idx = []
    each_worker_te_idx = []
    np.random.seed(0)
    # Iterate over client ids, not len(partition): the partition is a defaultdict
    # that only has keys for clients that received data, so len() can be smaller
    # than num_workers and would skip the highest-numbered clients.
    for worker_idx in range(num_workers):
        # int64 even when empty, so concatenation below keeps integer indices.
        w_indices = np.asarray(per_participant_list.get(worker_idx, []), dtype=np.int64)
        w_len = len(w_indices)
        len_tr = int(5 * w_len / 6)
        tr_idx = np.random.choice(w_len, len_tr, replace=False)
        te_idx = np.delete(np.arange(w_len), tr_idx)
        each_worker_idx.append(w_indices[tr_idx])
        each_worker_te_idx.append(w_indices[te_idx])
    global_test_idx = np.concatenate(each_worker_te_idx)
    return each_worker_idx, each_worker_te_idx, global_test_idx

In [6]:
def get_train(dataset, indices, batch_size=32, shuffle=False):
    """Return a DataLoader that yields `batch_size` batches drawn from `indices` of `dataset`.

    Order is always randomised by ``SubsetRandomSampler``; `shuffle` is accepted
    for interface compatibility but has no effect.
    """
    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=subset_sampler,
    )
    return loader

In [7]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Directory the CIFAR-10 files are read from (and downloaded to if missing).
data_loc='/home/vshejwalkar_umass_edu/data/'

# Per-channel normalisation constants shared by both pipelines.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Augmented pipeline (random horizontal flip) for samples clients train on.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Deterministic pipeline for evaluation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Augmented copies of both splits. The test split deliberately uses
# transform_train too: it is pooled with the train split into the federated
# training set.
cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=transform_train)
cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=transform_train)

# Non-augmented copies of both splits, used for evaluation.
te_cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=transform_test)
te_cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [8]:
def full_trim(v, f):
    '''
    Full-knowledge Trim attack. w.l.o.g., we assume the first f worker devices are compromised.

    v: 2-D tensor of client updates, one row per worker (n_workers, d).
       It is modified IN PLACE: rows 0..f-1 are overwritten.
    f: the number of compromised worker devices.

    For each coordinate the sign of the sum over all workers gives the benign
    direction. Every malicious row takes the benign extreme on the opposite side
    (the minimum if the direction is positive, the maximum if negative), pushed
    further away: halved if it still has the benign sign, doubled otherwise.

    Returns v (the same tensor object).
    '''
    shrink = 2  # factor used to push past the benign extreme
    v_tran = v.T  # (d, n_workers): one row per coordinate

    maximum_dim = torch.max(v_tran, dim=1)[0].unsqueeze(1)
    minimum_dim = torch.min(v_tran, dim=1)[0].unsqueeze(1)
    direction = torch.sign(torch.sum(v_tran, dim=-1, keepdim=True))
    directed_dim = (direction > 0) * minimum_dim + (direction < 0) * maximum_dim

    # The malicious vector is identical for every compromised worker, so build it once.
    same_sign = (direction * directed_dim > 0).to(v.dtype)
    opposite_sign = (direction * directed_dim < 0).to(v.dtype)
    malicious = directed_dim * (same_sign / shrink + opposite_sign * shrink)
    v[:f] = malicious.squeeze()
    return v

In [9]:
def tr_mean(all_updates, n_attackers):
    """Coordinate-wise trimmed mean of client updates.

    Args:
        all_updates: tensor of shape (n_clients, ...); reduced over dim 0.
        n_attackers: number of values dropped from EACH end per coordinate
            before averaging; 0 gives the plain mean.

    Raises:
        ValueError: if trimming would remove every value (2 * n_attackers >= n_clients)
            or n_attackers is negative. Previously this silently produced NaN.
    """
    n_clients = all_updates.shape[0]
    if n_attackers < 0 or 2 * n_attackers >= n_clients:
        raise ValueError('cannot trim %d values from each end of %d updates' % (n_attackers, n_clients))
    sorted_updates = torch.sort(all_updates, 0)[0]
    if n_attackers:
        sorted_updates = sorted_updates[n_attackers:n_clients - n_attackers]
    return torch.mean(sorted_updates, 0)

In [10]:
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def train(trainloader, model, model_received, criterion, optimizer, pgd=False, eps=2):
    """Run one pass of local training over `trainloader` and return (avg loss, avg top-1).

    Args:
        trainloader: iterable of ``(inputs, targets)`` batches.
        model: network trained in place.
        model_received: flat vector of the model this client started from. Callers
            in this notebook build it from ``model.state_dict()`` (parameters AND
            buffers); a parameters-only vector is also accepted.
        criterion: loss function.
        optimizer: optimizer over ``model.parameters()``.
        pgd: if True, after every step project the parameters back into an L2
            ball of radius `eps` around the received parameters.
        eps: radius of that ball.

    Returns:
        ``(losses.avg, top1.avg)`` from the project's ``AverageMeter``; the
        top-1 value is ``accuracy(...)``'s output divided by 100.
    """
    model.train()
    losses = AverageMeter()
    top1 = AverageMeter()
    # Only needed for the projection; computed once instead of per batch.
    received_params = _received_parameter_vector(model, model_received) if pgd else None
    for inputs, targets in trainloader:
        inputs = inputs.to(device, torch.float)
        targets = targets.to(device, torch.long)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        prec1, _ = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size()[0])
        top1.update(prec1.item()/100.0, inputs.size()[0])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if pgd:
            curr_model = list(model.parameters())
            curr_model_vec = parameters_to_vector(curr_model)
            delta_norm = torch.norm(curr_model_vec - received_params)
            if delta_norm > eps:
                curr_model_vec = eps * (curr_model_vec - received_params) / delta_norm + received_params
                vector_to_parameters(curr_model_vec, curr_model)

    return (losses.avg, top1.avg)


def _received_parameter_vector(model, model_received):
    """Return the entries of `model_received` that correspond to ``model.parameters()``.

    `model_received` may already be parameters-only, or a flattened
    ``state_dict`` (parameters interleaved with buffers such as BatchNorm
    running statistics). In the latter case the parameter slices are picked out
    in state_dict order, which matches the order of ``model.parameters()``.
    """
    n_params = sum(p.numel() for p in model.parameters())
    if model_received.numel() == n_params:
        return model_received
    param_names = {name for name, _ in model.named_parameters()}
    chunks = []
    offset = 0
    for name, tensor in model.state_dict().items():
        numel = tensor.numel()
        if name in param_names:
            chunks.append(model_received[offset:offset + numel])
        offset += numel
    return torch.cat(chunks)

def test(testloader, model, criterion):
    """Evaluate `model` on `testloader` and return (avg loss, avg top-1 accuracy).

    Both values come from the project's ``AverageMeter`` and are accumulated
    from tensors (``loss.data`` and ``prec1 / 100``), so callers can call
    ``.cpu().item()`` on them.
    """
    model.eval()
    losses = AverageMeter()
    top1 = AverageMeter()
    # Inference only: don't build an autograd graph for every batch.
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device, torch.float)
            targets = targets.to(device, torch.long)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            prec1, _ = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.data, inputs.size()[0])
            top1.update(prec1/100.0, inputs.size()[0])
    return (losses.avg, top1.avg)

# resnet-BN

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.autograd import Variable

# Only the names actually defined here; listing undefined ones makes `from ... import *` fail.
__all__ = ['ResNet', 'resnet20']

def _weights_init(m):
    classname = m.__class__.__name__
    #print(classname)
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Module wrapper around an arbitrary callable, so it can be used as a layer."""

    def __init__(self, lambd):
        super().__init__()
        # Callable applied to the input on every forward pass.
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)


class BasicBlock(nn.Module):
    """Residual block: conv3x3-BN-ReLU, conv3x3-BN, plus shortcut, then ReLU.

    Args:
        in_planes: input channels.
        planes: output channels.
        stride: stride of the first convolution.
        option: shortcut type used when the shape changes.
            ``'A'``: parameter-free; subsample spatially by 2 and zero-pad
            ``planes // 4`` channels on each side. Shapes only match when
            ``in_planes == planes // 2`` and ``stride == 2``.
            ``'B'``: 1x1 convolution + BatchNorm.

    Raises:
        ValueError: if a shortcut is needed and `option` is neither 'A' nor 'B'.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # For CIFAR10 the ResNet paper uses option A.
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                     nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                     nn.BatchNorm2d(self.expansion * planes)
                )
            else:
                # Previously an identity shortcut was kept, failing later with a shape mismatch.
                raise ValueError("option must be 'A' or 'B', got %r" % (option,))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem, three residual stages (16/32/64 channels), linear head.

    Args:
        block: residual block class; called as ``block(in_planes, planes, stride)``
            and must expose an ``expansion`` attribute.
        num_blocks: number of blocks in each of the three stages.
        num_classes: size of the output layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        # Submodules are registered in this order; it fixes the state_dict order.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage uses `stride`; the rest keep resolution.
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            h = stage(h)
        # Average-pool with a kernel equal to the feature-map width, then flatten per sample.
        h = F.avg_pool2d(h, h.size()[3])
        h = h.view(h.size(0), -1)
        return self.linear(h)


def resnet20():
    """Build a ResNet-20 (three stages of three BasicBlocks) with 10 output classes."""
    blocks_per_stage = [3, 3, 3]
    return ResNet(BasicBlock, blocks_per_stage)

# FLDetector utils

In [12]:
from sklearn.metrics import roc_auc_score
def fldetector(old_gradients, user_grads, b=0, hvp=None, agr='tr_mean'):
    """Aggregate client updates and, when an HVP estimate is given, score every client.

    Args:
        old_gradients: (n_clients, d) tensor of the previous round's client
            updates; only read when `hvp` is given, and never modified.
        user_grads: (n_clients, d) tensor of this round's client updates.
        b: number of values trimmed from each end by ``tr_mean``.
        hvp: 1-D numpy array of length d (the L-BFGS Hessian-vector product),
            or None when no estimate is available yet.
        agr: aggregation rule, one of ``'average'``, ``'median'``, ``'tr_mean'``.

    Returns:
        ``(agg_grads, distance)``. `agg_grads` is the aggregated d-vector;
        `distance` is None without `hvp`, otherwise a numpy array holding, per
        client, the L2 distance between its update and its predicted update
        (old update + hvp), normalised to sum to 1.

    Raises:
        ValueError: for an unknown `agr`.
    """
    if hvp is not None:
        hvp_t = torch.from_numpy(hvp).to(device=user_grads.device, dtype=old_gradients.dtype)
        # Broadcasting adds the same HVP to every client's previous update.
        pred_grad = old_gradients + hvp_t
        distance = torch.norm(pred_grad - user_grads, dim=1).cpu().numpy()
        distance = distance / np.sum(distance)
    else:
        distance = None

    if agr == 'average':
        agg_grads = torch.mean(user_grads, 0)
    elif agr == 'median':
        agg_grads = torch.median(user_grads, 0)[0]
    elif agr == 'tr_mean':
        agg_grads = tr_mean(user_grads, b)
    else:
        raise ValueError("agr must be 'average', 'median' or 'tr_mean', got %r" % (agr,))
    return agg_grads, distance

def lbfgs(S_k_list, Y_k_list, v):
    """Approximate the Hessian-vector product B_k @ v with the compact L-BFGS formula.

    Args:
        S_k_list: list of 1-D tensors (the recent global-model differences; the
            experiment loops pass ``weight_record``).
        Y_k_list: list of 1-D tensors paired with ``S_k_list`` (the loops pass
            ``grad_record``, differences of aggregated updates).
        v: 1-D tensor of the same length d as the list entries.

    Returns:
        numpy array of shape (d, 1).

    With S and Y holding the list entries as columns,
    sigma = y_k^T s_k / s_k^T s_k (using the last pair), L the strictly lower
    triangle of S^T Y and D = diag(S^T Y), this evaluates
        sigma*v - [sigma*S, Y] @ inv([[sigma*S^T S, L], [L^T, -D]]) @ [sigma*S^T v; Y^T v]
    i.e. the compact representation of the L-BFGS Hessian approximation
    (Byrd, Nocedal & Schnabel). All linear algebra runs in NumPy on the CPU.
    """
    curr_S_k = torch.stack(S_k_list).T  # (d, k)
    curr_Y_k = torch.stack(Y_k_list).T  # (d, k)
    S_k_time_Y_k = np.dot(curr_S_k.T.cpu().numpy(), curr_Y_k.cpu().numpy())  # S^T Y, (k, k)
    S_k_time_S_k = np.dot(curr_S_k.T.cpu().numpy(), curr_S_k.cpu().numpy())  # S^T S, (k, k)
    R_k = np.triu(S_k_time_Y_k)
    L_k = S_k_time_Y_k - R_k  # strictly lower-triangular part of S^T Y
    # Scalar (as a 1x1 array) scaling of the initial Hessian approximation sigma * I.
    sigma_k = np.dot(Y_k_list[-1].unsqueeze(0).cpu().numpy(), S_k_list[-1].unsqueeze(0).T.cpu().numpy()) / (np.dot(S_k_list[-1].unsqueeze(0).cpu().numpy(), S_k_list[-1].unsqueeze(0).T.cpu().numpy()))
    D_k_diag = np.diag(S_k_time_Y_k)
    # Middle (2k x 2k) matrix of the compact representation.
    upper_mat = np.concatenate((sigma_k * S_k_time_S_k, L_k), axis=1)
    lower_mat = np.concatenate((L_k.T, -np.diag(D_k_diag)), axis=1)
    mat = np.concatenate((upper_mat, lower_mat), axis=0)
    mat_inv = np.linalg.inv(mat)

    # (1, 1) * (d,) broadcasts to (1, d); transpose to a (d, 1) column.
    approx_prod = sigma_k * v.cpu().numpy()
    approx_prod = approx_prod.T
    # [sigma * S^T v; Y^T v], shape (2k, 1).
    p_mat = np.concatenate((np.dot(curr_S_k.T.cpu().numpy(), sigma_k * v.unsqueeze(0).T.cpu().numpy()), np.dot(curr_Y_k.T.cpu().numpy(), v.unsqueeze(0).T.cpu().numpy())), axis=0)
    approx_prod -= np.dot(np.dot(np.concatenate((sigma_k * curr_S_k.cpu().numpy(), curr_Y_k.cpu().numpy()), axis=1), mat_inv), p_mat)

    return approx_prod

def detection(score, nobyz, nworkers):
    """Split clients into two groups with 2-means on their scores and print/return metrics.

    The cluster with the higher mean score is labelled 0 ("malicious"), the
    other 1 ("benign"). The ground truth assumes the first `nobyz` clients are
    malicious.

    Args:
        score: 1-D numpy array with one score per client.
        nobyz: number of malicious clients (the first `nobyz` entries).
        nworkers: total number of clients.

    Returns:
        ``(acc, fpr, fnr, auc)``. A rate whose denominator would be zero is NaN
        (e.g. fnr/recall when ``nobyz == 0``, auc when only one class is present)
        instead of dividing by zero or letting ``roc_auc_score`` fail.
    """
    nan = float('nan')
    estimator = KMeans(n_clusters=2)
    estimator.fit(score.reshape(-1, 1))
    label_pred = estimator.labels_
    if np.mean(score[label_pred==0])<np.mean(score[label_pred==1]):
        #0 is the label of malicious clients
        label_pred = 1 - label_pred
    real_label=np.ones(nworkers)
    real_label[:nobyz]=0
    n_benign = nworkers - nobyz
    acc=len(label_pred[label_pred==real_label])/nworkers
    fnr = np.sum(label_pred[:nobyz]) / nobyz if nobyz > 0 else nan
    recall = 1 - fnr if nobyz > 0 else nan
    fpr = 1 - np.sum(label_pred[nobyz:]) / n_benign if n_benign > 0 else nan
    # ROC AUC is undefined unless both classes are present in real_label.
    auc = roc_auc_score(real_label, label_pred) if 0 < nobyz < nworkers else nan
    print("acc %0.4f; recall %0.4f; fpr %0.4f; fnr %0.4f; auc %.4f" % (acc, recall, fpr, fnr, auc))
    return acc, fpr, fnr, auc

def detection1(score, nobyz):
    """Return 1 if a gap-statistic style test finds more than one cluster in `score`, else 0.

    The scores are min-max scaled. For k = 1..7, the within-cluster sum of
    squares Wk of a k-means fit is compared with the mean Wk of fits on `nrefs`
    uniform[0, 1) reference samples: gap(k) = log(mean(WkRef)) - log(Wk). The
    smallest k with gap(k) >= gap(k+1) - sdk(k+1) is selected. If no k
    qualifies, the largest k is used, which counts as "more than one cluster".

    Args:
        score: 1-D numpy array with one score per client.
        nobyz: unused; kept for interface compatibility.
    """
    nrefs = 10
    ks = range(1, 8)
    gaps = np.zeros(len(ks))
    gap_diff = np.zeros(len(ks) - 1)
    sdk = np.zeros(len(ks))
    score_min = np.min(score)
    score_max = np.max(score)
    if score_max == score_min:
        # All scores identical: a single cluster, and the scaling below would divide by zero.
        return 0
    score = (score - score_min) / (score_max - score_min)
    for i, k in enumerate(ks):
        estimator = KMeans(n_clusters=k)
        estimator.fit(score.reshape(-1, 1))
        centers = estimator.cluster_centers_[estimator.labels_, 0]
        Wk = np.sum(np.square(score - centers))
        WkRef = np.zeros(nrefs)
        for j in range(nrefs):
            rand = np.random.uniform(0, 1, len(score))
            estimator = KMeans(n_clusters=k)
            estimator.fit(rand.reshape(-1, 1))
            ref_centers = estimator.cluster_centers_[estimator.labels_, 0]
            WkRef[j] = np.sum(np.square(rand - ref_centers))
        gaps[i] = np.log(np.mean(WkRef)) - np.log(Wk)
        sdk[i] = np.sqrt((1.0 + nrefs) / nrefs) * np.std(np.log(WkRef))

        if i > 0:
            gap_diff[i - 1] = gaps[i - 1] - gaps[i] + sdk[i]

    # Previously `select_k` was unbound (UnboundLocalError) when no k qualified.
    select_k = ks[-1]
    for i, diff in enumerate(gap_diff):
        if diff >= 0:
            select_k = i + 1
            break
    return 0 if select_k == 1 else 1

In [14]:
# Pool CIFAR-10's train and test splits: `all_data` (augmented transforms) is
# what clients train on; `all_test_data` (no augmentation) is used for evaluation.
all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))
batch_size = 16
num_workers = 100
distribution='fang'
param = .5
force = False

# Pass the `num_workers` setting above rather than repeating the literal 100.
each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
    all_data, num_workers=num_workers, distribution=distribution, param=param, force=force)

# One loader per client for local training.
train_loaders = [(pos, get_train(all_data, indices, batch_size))
                 for pos, indices in enumerate(each_worker_idx)]
# One loader per client serving its whole held-out split as a single batch.
# The batch size is clamped to 1 because DataLoader rejects batch_size=0
# (a client with no held-out samples).
test_loaders = [(pos, get_train(all_test_data, indices, max(len(indices), 1)))
                for pos, indices in enumerate(each_worker_te_idx)]
cifar10_test_loader = get_train(all_test_data, global_test_idx)

len(train_loaders), len(test_loaders), len(global_test_idx)

Loading fang distribution for num_workers 100 and bias 0.5 from memory


(100, 100, 10038)

# Baseline: plain mean aggregation, no attack (nbyz = 0)

In [24]:
# Baseline experiment: plain mean aggregation ('average'), no malicious clients
# (nbyz = 0, attack_type = 'none'), FLDetector scoring enabled. Each round every
# client trains locally from the global model, the server aggregates the
# flattened updates, FLDetector scores the clients, and the new global model is
# evaluated on the pooled held-out split.
use_cuda = torch.cuda.is_available()
torch.cuda.empty_cache()
nepochs=200
local_epochs = 2
batch_size = 16
num_workers = 100
local_lr = 0.1
global_lr = 1
nbyz = 0
agr = 'average'

criterion = nn.CrossEntropyLoss()
best_global_acc=0
epoch_num = 0

# Global model, and its state_dict (parameters AND buffers, cast to float) flattened into one vector.
fed_model = resnet20().cuda()
model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

# FLDetector initializations
weight_record = []
grad_record = []
test_grads = []
old_grad_list = []
# Row 0 is a zero placeholder; one row of normalised distances is appended per scored round.
malicious_scores = np.zeros((1, num_workers))
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)

# Adaptive attack initializers
# NOTE(review): 'rage' is presumably a typo for 'range'; the name is used consistently within this cell.
good_distance_rage = np.zeros((1, nbyz))
attack_type = 'none'
dev_type = 'std'

# Runs rounds 0..nepochs inclusive unless detection or divergence breaks out early.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    user_grads=[]
    benign_norm = 0
    # Every client trains every round, starting from a copy of the global model.
    for i in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.999**epoch_num), momentum=0.9, weight_decay=1e-4)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[i][1], model, model_received, criterion, optimizer)

        # NOTE(review): this loop rebinds `i` (the client index); harmless only because
        # the outer `for` reassigns it on the next iteration.
        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        # Client update = locally trained flattened state minus the received global state.
        update =  (params - model_received)
        benign_norm += torch.norm(update)/len(round_benign)
        user_grads = update[None,:] if len(user_grads) == 0 else torch.cat((user_grads, update[None,:]), 0)

    weight = model_received

    # After start_detection_epoch, estimate the HVP of the latest global-model change
    # from the recorded (weight change, aggregate change) history.
    if (epoch_num > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    # Honest updates of the first nbyz clients, copied before any attack below overwrites them.
    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        pred_grad = copy.deepcopy(good_old_grads)
        distance = []
        for i in range(len(good_old_grads)):
            pred_grad[i] += torch.from_numpy(hvp).to(device)
        good_distance_rage = np.concatenate(
            (good_distance_rage, torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()[None,:]), 0)
        
    # Attack simulation on the first nbyz clients; a no-op here since attack_type == 'none'.
    # NOTE(review): at epoch_num == start_detection_epoch neither branch runs.
    if attack_type != 'none' and (epoch_num < start_detection_epoch):
        user_grads = full_trim(user_grads, nbyz)
        # user_grads[:nbyz] = full_trim(user_grads[:nbyz], nbyz)
    elif epoch_num > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                print("LIE")
                z = 0.1
                noise_avg = torch.mean(user_grads[:nbyz], dim=0)
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_avg + z * noise_std
            elif attack_type == 'NDSS21':
                distance_bound = np.random.choice(np.mean(good_distance_rage[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                # noise = deviation * ((distance_bound + np.random.uniform(0, np.std(good_distance_rage[-1]))) / torch.norm(deviation))
                noise = deviation * ((distance_bound)) / torch.norm(deviation)
            elif attack_type == 'mod_trim':
                # NOTE(review): this branch never defines `noise`, so the loop below raises
                # NameError (unless `noise` survives from an earlier round/cell); `mal_grads` is unused.
                mal_grads= full_trim(user_grads[:nbyz], nbyz)
                pass
            else:
                noise = torch.zeros(hvp.shape).to(device)
            for m in range(nbyz):
                user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise

    agg_grads, distance = fldetector(old_grad_list, user_grads, nbyz, hvp, agr=agr)
    
    # Accumulate per-round client scores; detection runs once window_size scored rounds exist.
    if distance is not None and epoch_num > (start_detection_epoch - window_size):
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    if malicious_scores.shape[0] >= window_size+1:
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', epoch_num)
            # NOTE(review): with nbyz == 0 (this cell) detection() divides by nobyz and
            # gives roc_auc_score a single-class label vector, which may raise.
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz, num_workers)
            break

    # L-BFGS history: append the latest (model change, aggregate change) pair, keep at most window_size.
    if epoch_num > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > window_size):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads
    del user_grads

    # Apply the aggregate update and load the vector back into a fresh resnet20,
    # slicing it tensor by tensor in state_dict order.
    model_received = model_received + global_lr * agg_grads
    fed_model = resnet20().cuda()
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name
    fed_model.load_state_dict(state_dict)
    val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    
    best_accs_per_round.append(best_global_acc.cpu().item())
    accs_per_round.append(val_acc.cpu().item())
    loss_per_round.append(val_loss.cpu().item())
    
    # NOTE(review): `epoch_num%1==0` is always true, so every round is printed.
    if epoch_num%1==0 or epoch_num==nepochs-1:
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))
    # NOTE(review): `math` is not imported in the visible cells; presumably it comes from one of the star imports - verify.
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break
    epoch_num+=1

# final_accs_per_client=[]
# for i in range(num_workers):
#     client_loss, client_acc = test(test_loaders[i][1], fed_model, criterion)
#     final_accs_per_client.append(client_acc)
# results = collections.OrderedDict(
#     final_accs_per_client=np.array(final_accs_per_client),
#     accs_per_round=np.array(accs_per_round),
#     best_accs_per_round=np.array(best_accs_per_round),
#     loss_per_round=np.array(loss_per_round)
# )
# pickle.dump(results, open(os.path.join(home_dir, 'FLDetector_plots_data/rq1_baseline_cifar10_fast_trmean_trim20.pkl'), 'wb'))

e 0 benign_norm 1359.627 val loss 2.304 val acc 0.100 best val_acc 0.100
e 1 benign_norm 647.379 val loss 2.312 val acc 0.100 best val_acc 0.100
e 2 benign_norm 281.629 val loss 2.377 val acc 0.105 best val_acc 0.105
e 3 benign_norm 278.466 val loss 2.259 val acc 0.126 best val_acc 0.126
e 4 benign_norm 277.227 val loss 2.078 val acc 0.216 best val_acc 0.216
e 5 benign_norm 276.699 val loss 1.925 val acc 0.277 best val_acc 0.277
e 6 benign_norm 276.419 val loss 1.836 val acc 0.304 best val_acc 0.304
e 7 benign_norm 276.116 val loss 1.772 val acc 0.329 best val_acc 0.329
e 8 benign_norm 275.852 val loss 1.702 val acc 0.369 best val_acc 0.369
e 9 benign_norm 275.644 val loss 1.633 val acc 0.398 best val_acc 0.398
e 10 benign_norm 275.504 val loss 1.585 val acc 0.417 best val_acc 0.417
e 11 benign_norm 275.477 val loss 1.549 val acc 0.427 best val_acc 0.427
e 12 benign_norm 275.416 val loss 1.494 val acc 0.448 best val_acc 0.448
e 13 benign_norm 275.439 val loss 1.472 val acc 0.456 best v

e 114 benign_norm 277.663 val loss 0.548 val acc 0.810 best val_acc 0.812
e 115 benign_norm 277.560 val loss 0.550 val acc 0.811 best val_acc 0.812
e 116 benign_norm 277.522 val loss 0.542 val acc 0.812 best val_acc 0.812
e 117 benign_norm 277.673 val loss 0.537 val acc 0.815 best val_acc 0.815
e 118 benign_norm 277.734 val loss 0.537 val acc 0.814 best val_acc 0.815
e 119 benign_norm 277.839 val loss 0.538 val acc 0.813 best val_acc 0.815
e 120 benign_norm 277.763 val loss 0.533 val acc 0.815 best val_acc 0.815
e 121 benign_norm 277.674 val loss 0.535 val acc 0.816 best val_acc 0.816
e 122 benign_norm 277.660 val loss 0.528 val acc 0.818 best val_acc 0.818
e 123 benign_norm 277.740 val loss 0.535 val acc 0.815 best val_acc 0.818
e 124 benign_norm 277.841 val loss 0.525 val acc 0.818 best val_acc 0.818
e 125 benign_norm 277.799 val loss 0.526 val acc 0.817 best val_acc 0.818
e 126 benign_norm 277.846 val loss 0.525 val acc 0.820 best val_acc 0.820
e 127 benign_norm 277.690 val loss 0.5

KeyboardInterrupt: 

# Trimmed-mean aggregation + FLDetector, m = 20% malicious clients (NDSS21 adaptive attack), Fang distribution with bias 0.5

In [30]:
# Attack experiment: trimmed-mean aggregation ('tr_mean'), the first nbyz = 20 of
# 100 clients are malicious. They run the full-knowledge Trim attack before
# start_detection_epoch and the NDSS21 adaptive attack afterwards, while
# FLDetector scores clients every round.
use_cuda = torch.cuda.is_available()
torch.cuda.empty_cache()
nepochs=200
local_epochs = 2
batch_size = 16
num_workers = 100
local_lr = 0.1
global_lr = 1
nbyz = 20
agr = 'tr_mean'

criterion = nn.CrossEntropyLoss()
best_global_acc=0
epoch_num = 0

# Global model, and its state_dict (parameters AND buffers, cast to float) flattened into one vector.
fed_model = resnet20().cuda()
model_received = []
for i, (name, param) in enumerate(fed_model.state_dict().items()):
    model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

# FLDetector initializations
weight_record = []
grad_record = []
test_grads = []
old_grad_list = []
# Row 0 is a zero placeholder; one row of normalised distances is appended per scored round.
malicious_scores = np.zeros((1, num_workers))
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)

# Adaptive attack initializers
# Row 0 is a zero placeholder; later rows hold, per round, each attacker's honest
# distance between its update and its HVP-predicted update.
good_distance_range = np.zeros((1, nbyz))
attack_type = 'NDSS21'
dev_type = 'std'

# Runs rounds 0..nepochs inclusive unless detection or divergence breaks out early.
while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    user_grads=[]
    benign_norm = 0
    # Every client (attackers included) first trains honestly from a copy of the global model.
    for i in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr*(0.999**epoch_num), momentum=0.9, weight_decay=1e-4)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[i][1], model, model_received, criterion, optimizer)

        # NOTE(review): this loop rebinds `i` (the client index); harmless only because
        # the outer `for` reassigns it on the next iteration.
        params = []
        for i, (name, param) in enumerate(model.state_dict().items()):
            params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat((params, param.view(-1).data.type(torch.cuda.FloatTensor)))

        # Client update = locally trained flattened state minus the received global state.
        update =  (params - model_received)
        benign_norm += torch.norm(update)/len(round_benign)
        user_grads = update[None,:] if len(user_grads) == 0 else torch.cat((user_grads, update[None,:]), 0)

    weight = model_received

    # After start_detection_epoch, estimate the HVP of the latest global-model change
    # from the recorded (weight change, aggregate change) history.
    if (epoch_num > start_detection_epoch):
        hvp = lbfgs(weight_record, grad_record, weight - last_weight)
        hvp = np.squeeze(hvp)
    else:
        hvp = None

    # Honest updates of the first nbyz clients, copied before the attack below overwrites them.
    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        pred_grad = copy.deepcopy(good_old_grads)
        distance = []
        for i in range(len(good_old_grads)):
            pred_grad[i] += torch.from_numpy(hvp).to(device)
        good_distance_range = np.concatenate(
            (good_distance_range, torch.norm(pred_grad - good_current_grads, dim = 1).cpu().numpy()[None,:]), 0)
        
    # Attack simulation on the first nbyz clients.
    # NOTE(review): at epoch_num == start_detection_epoch neither branch runs, so that round is attack-free.
    if attack_type != 'none' and (epoch_num < start_detection_epoch):
        user_grads = full_trim(user_grads, nbyz)
        # user_grads[:nbyz] = full_trim(user_grads[:nbyz], nbyz)
    elif epoch_num > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'NDSS21':
                # Noise is `deviation` rescaled to an L2 norm equal to one of the last
                # round's honest prediction distances, picked at random.
                distance_bound = np.random.choice(np.mean(good_distance_range[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)  # unit vector, dir opp to good dir
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                # noise = deviation * ((distance_bound + np.random.uniform(0, np.std(good_distance_range[-1]))) / torch.norm(deviation))
                noise = deviation * ((distance_bound)) / torch.norm(deviation)
            elif attack_type == 'mod_trim':
                # NOTE(review): this branch never defines `noise`, so the loop below raises
                # NameError (unless `noise` survives from an earlier round/cell); `mal_grads` is unused.
                mal_grads= full_trim(user_grads[:nbyz], nbyz)
                pass
            else:
                noise = torch.zeros(hvp.shape).to(device)
            # Malicious update = previous (submitted) update + HVP + noise, i.e. close to what FLDetector predicts.
            for m in range(nbyz):
                user_grads[m] = old_grad_list[m] + torch.from_numpy(hvp).to(device) + noise

    agg_grads, distance = fldetector(old_grad_list, user_grads, nbyz, hvp, agr=agr)
    
    # Accumulate per-round client scores; detection runs once window_size scored rounds exist.
    if distance is not None and epoch_num > (start_detection_epoch - window_size):
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    if malicious_scores.shape[0] >= window_size+1:
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', epoch_num)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz, len(user_grads))
            break

    # L-BFGS history: append the latest (model change, aggregate change) pair, keep at most window_size.
    if epoch_num > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)
    
    if (len(weight_record) > window_size):
        del weight_record[0]
        del grad_record[0]
    
    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads
    del user_grads

    # Apply the aggregate update and load the vector back into a fresh resnet20,
    # slicing it tensor by tensor in state_dict order.
    model_received = model_received + global_lr * agg_grads
    fed_model = resnet20().cuda()
    start_idx=0
    state_dict = {}
    previous_name = 'none'
    for i, (name, param) in enumerate(fed_model.state_dict().items()):
        start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
        start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
        params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
        state_dict[name] = params
        previous_name = name
    fed_model.load_state_dict(state_dict)
    val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    
    best_accs_per_round.append(best_global_acc.cpu().item())
    accs_per_round.append(val_acc.cpu().item())
    loss_per_round.append(val_loss.cpu().item())
    
    # NOTE(review): `epoch_num%1==0` is always true, so every round is printed.
    if epoch_num%1==0 or epoch_num==nepochs-1:
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))
    # NOTE(review): `math` is not imported in the visible cells; presumably it comes from one of the star imports - verify.
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break
    epoch_num+=1

e 11 benign_norm 275.926 val loss 1.944 val acc 0.272 best val_acc 0.272
e 12 benign_norm 277.947 val loss 2.101 val acc 0.258 best val_acc 0.272
e 13 benign_norm 280.003 val loss 1.974 val acc 0.278 best val_acc 0.278
e 14 benign_norm 279.630 val loss 1.998 val acc 0.297 best val_acc 0.297
e 15 benign_norm 279.360 val loss 2.104 val acc 0.278 best val_acc 0.297
e 16 benign_norm 279.709 val loss 2.149 val acc 0.290 best val_acc 0.297
e 17 benign_norm 279.903 val loss 2.373 val acc 0.269 best val_acc 0.297
e 18 benign_norm 280.937 val loss 2.408 val acc 0.269 best val_acc 0.297
e 19 benign_norm 280.945 val loss 2.294 val acc 0.284 best val_acc 0.297
e 20 benign_norm 282.025 val loss 2.107 val acc 0.323 best val_acc 0.323
e 21 benign_norm 283.073 val loss 2.012 val acc 0.330 best val_acc 0.330
e 22 benign_norm 283.225 val loss 1.941 val acc 0.349 best val_acc 0.349
e 23 benign_norm 284.876 val loss 1.870 val acc 0.372 best val_acc 0.372
e 24 benign_norm 285.311 val loss 1.839 val acc 0.3

e 123 benign_norm 2488.839 val loss 1.258 val acc 0.555 best val_acc 0.560
e 124 benign_norm 2167.619 val loss 1.296 val acc 0.544 best val_acc 0.560
e 125 benign_norm 2246.414 val loss 1.328 val acc 0.524 best val_acc 0.560
e 126 benign_norm 2017.557 val loss 1.337 val acc 0.535 best val_acc 0.560
e 127 benign_norm 2466.261 val loss 1.273 val acc 0.545 best val_acc 0.560
e 128 benign_norm 2185.093 val loss 1.249 val acc 0.559 best val_acc 0.560
e 129 benign_norm 2132.378 val loss 1.283 val acc 0.556 best val_acc 0.560
e 130 benign_norm 2908.978 val loss 1.439 val acc 0.484 best val_acc 0.560
e 131 benign_norm 2432.674 val loss 1.317 val acc 0.543 best val_acc 0.560
e 132 benign_norm 2299.885 val loss 1.247 val acc 0.564 best val_acc 0.564
e 133 benign_norm 2200.564 val loss 1.214 val acc 0.575 best val_acc 0.575
e 134 benign_norm 2749.518 val loss 1.304 val acc 0.540 best val_acc 0.575
e 135 benign_norm 2521.189 val loss 1.300 val acc 0.547 best val_acc 0.575
e 136 benign_norm 2533.59

In [36]:
'''
FAST BASELINE
'''
# FLDetector with trimmed-mean aggregation ('tr_mean') under an adaptive attack
# (attack_type / dev_type below). For every malicious-client count in
# `n_malicious` the experiment is repeated `n_runs` times; detection metrics
# (acc, fpr, fnr, auc, stop epoch) are appended to `results[n_mal]` and the
# per-round curves are pickled under `home_dir`.
# Relies on names defined in earlier cells / ./utils: resnet20, train, test,
# lbfgs, full_trim, fldetector, detection, detection1, train_loaders,
# test_loaders, cifar10_test_loader, device and math.

use_cuda = torch.cuda.is_available()
criterion = nn.CrossEntropyLoss()
n_malicious = [20]
n_runs = 3
results = {}
for n_mal in n_malicious:
    results[n_mal] = {'acc': [], 'fpr': [], 'fnr': [], 'auc': [], 'stop_e': []}
    for run in range(n_runs):
        torch.cuda.empty_cache()
        nepochs = 200
        local_epochs = 2
        batch_size = 16
        num_workers = 100
        local_lr = 0.1
        global_lr = 1
        agr = 'tr_mean'
        best_global_acc = 0
        epoch_num = 0
        # Bind the malicious-client count BEFORE anything below is sized by it.
        # (It used to be assigned after malicious_scores / good_distance_range
        # were allocated, so their widths came from a stale global.)
        nbyz = n_mal

        # Flatten the freshly initialised global model into one float vector.
        fed_model = resnet20().cuda()
        model_received = torch.cat([
            p.view(-1).data.type(torch.cuda.FloatTensor)
            for p in fed_model.state_dict().values()])

        best_accs_per_round = []
        accs_per_round = []
        loss_per_round = []
        home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

        # FLDetector initializations
        weight_record = []
        grad_record = []
        test_grads = []
        old_grad_list = []
        # One column per participating client (clients 20-nbyz .. num_workers-1);
        # row 0 is an all-zero placeholder.
        malicious_scores = np.zeros((1, num_workers - (20 - nbyz)))
        start_detection_epoch = 10
        window_size = 10
        assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)

        # Adaptive attack initializers: one column per client whose benign
        # update is used to calibrate the attack (the first nbyz rows).
        good_distance_range = np.zeros((1, nbyz))
        attack_type = 'NDSS21'
        dev_type = 'std'

        while epoch_num <= nepochs:
            torch.cuda.empty_cache()
            # The first nbyz rows of user_grads are the ones the attack overwrites.
            round_clients = np.arange(20 - nbyz, num_workers)
            round_benign = round_clients
            updates = []
            benign_norm = 0
            for client in round_benign:
                model = copy.deepcopy(fed_model)
                optimizer = optim.SGD(model.parameters(), lr=local_lr * (0.999 ** epoch_num), momentum=0.9, weight_decay=1e-4)

                for epoch in range(local_epochs):
                    train_loss, train_acc = train(train_loaders[client][1], model, model_received, criterion, optimizer)

                # Flatten the locally trained model in state_dict order (one cat
                # instead of re-concatenating the growing vector per tensor).
                params = torch.cat([
                    p.view(-1).data.type(torch.cuda.FloatTensor)
                    for p in model.state_dict().values()])

                update = params - model_received
                benign_norm += torch.norm(update) / len(round_benign)
                updates.append(update)
            # (n_clients, n_params) matrix of client updates.
            user_grads = torch.stack(updates, 0)
            del updates

            weight = model_received

            # L-BFGS Hessian-vector estimate; only computed after start_detection_epoch.
            if epoch_num > start_detection_epoch:
                hvp = np.squeeze(lbfgs(weight_record, grad_record, weight - last_weight))
                hvp_t = torch.from_numpy(hvp).to(device)
            else:
                hvp = None
                hvp_t = None

            # Benign updates of the first nbyz clients, used to calibrate the attack.
            good_current_grads = copy.deepcopy(user_grads[:nbyz])
            if hvp is not None:
                # Predicted update = previous round's update + hvp; record how far
                # the actual benign updates land from that prediction.
                pred_grad = copy.deepcopy(good_old_grads)
                for row in range(len(good_old_grads)):
                    pred_grad[row] += hvp_t
                good_distance_range = np.concatenate(
                    (good_distance_range, torch.norm(pred_grad - good_current_grads, dim=1).cpu().numpy()[None, :]), 0)

            if attack_type != 'none' and (epoch_num < start_detection_epoch):
                # Before detection starts the attackers apply the full-trim attack.
                user_grads = full_trim(user_grads, nbyz)
            elif epoch_num > start_detection_epoch:
                if attack_type == 'full_trim':
                    user_grads = full_trim(user_grads, nbyz)
                elif attack_type == 'none':
                    pass
                else:
                    if attack_type == 'NDSS21':
                        # Noise along `deviation`, rescaled to norm `distance_bound`,
                        # a value drawn from the last row of benign prediction distances.
                        distance_bound = np.random.choice(np.mean(good_distance_range[-1:], 0))
                        model_re = torch.mean(good_current_grads, dim=0)
                        if dev_type == 'unit_vec':
                            deviation = model_re / torch.norm(model_re)  # unit vector along the mean benign update
                        elif dev_type == 'sign':
                            deviation = torch.sign(model_re)
                        elif dev_type == 'std':
                            deviation = torch.std(good_current_grads, 0)
                        noise = deviation * distance_bound / torch.norm(deviation)
                    elif attack_type == 'mod_trim':
                        # NOTE(review): `mal_grads` is never used and `noise` is not
                        # set on this path, so the loop below raises NameError (or
                        # reuses a stale `noise`). Unreachable with 'NDSS21'.
                        mal_grads = full_trim(user_grads[:nbyz], nbyz)
                    else:
                        noise = torch.zeros(hvp.shape).to(device)
                    # Malicious updates = the same prediction used for pred_grad
                    # above (previous update + hvp) plus the crafted noise.
                    for m in range(nbyz):
                        user_grads[m] = old_grad_list[m] + hvp_t + noise

            agg_grads, distance = fldetector(old_grad_list, user_grads, nbyz, hvp, agr=agr)

            if distance is not None and epoch_num > (start_detection_epoch - window_size):
                malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

            # Once a full window of scores exists, check whether to stop and
            # evaluate detection on the window's summed scores.
            if malicious_scores.shape[0] >= window_size + 1:
                if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
                    print('!!!! Stop at iteration:', epoch_num)
                    print('r %d nmal %d | e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (run, nbyz, epoch_num, benign_norm, val_loss, val_acc, best_global_acc))
                    acc, fpr, fnr, auc = detection(
                        np.sum(malicious_scores[-window_size:], axis=0), nbyz, len(user_grads))
                    results[n_mal]['acc'].append(acc)
                    results[n_mal]['fpr'].append(fpr)
                    results[n_mal]['fnr'].append(fnr)
                    results[n_mal]['auc'].append(auc)
                    results[n_mal]['stop_e'].append(epoch_num)
                    break

            # Sliding window (length window_size) of weight / aggregate deltas for L-BFGS.
            if epoch_num > (start_detection_epoch - window_size):
                weight_record.append(weight - last_weight)
                grad_record.append(agg_grads - last_grad)

            if len(weight_record) > window_size:
                del weight_record[0]
                del grad_record[0]

            last_weight = weight
            last_grad = agg_grads
            old_grad_list = user_grads
            good_old_grads = good_current_grads
            del user_grads

            # Apply the aggregate and unflatten it back into a fresh model.
            model_received = model_received + global_lr * agg_grads
            fed_model = resnet20().cuda()
            state_dict = {}
            offset = 0
            for name, ref in fed_model.state_dict().items():
                numel = ref.numel()
                state_dict[name] = model_received[offset:offset + numel].reshape(ref.shape)
                offset += numel
            fed_model.load_state_dict(state_dict)
            val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
            is_best = best_global_acc < val_acc
            best_global_acc = max(best_global_acc, val_acc)

            best_accs_per_round.append(best_global_acc.cpu().item())
            accs_per_round.append(val_acc.cpu().item())
            loss_per_round.append(val_loss.cpu().item())

            if epoch_num % 5 == 0 or epoch_num == nepochs - 1:
                print('r %d nmal %d | e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (run, nbyz, epoch_num, benign_norm, val_loss, val_acc, best_global_acc))
            if math.isnan(val_loss) or val_loss > 100000:
                print('val loss %f... exit'%val_loss)
                break
            epoch_num += 1

        # Per-client accuracy of the final global model.
        final_accs_per_client = []
        for client in range(num_workers):
            client_loss, client_acc = test(test_loaders[client][1], fed_model, criterion)
            final_accs_per_client.append(client_acc.cpu().item())
        all_results = collections.OrderedDict(
            final_accs_per_client=np.array(final_accs_per_client),
            accs_per_round=np.array(accs_per_round),
            best_accs_per_round=np.array(best_accs_per_round),
            loss_per_round=np.array(loss_per_round),
            results=results
        )
        out_path = os.path.join(home_dir, 'FLDetector_plots_data/rq4_adaptive_cifar10_fast_trmean_NDSS21_std_m%d_r%d.pkl' % (nbyz, run))
        # Context manager so the file handle is flushed and closed.
        with open(out_path, 'wb') as out_file:
            pickle.dump(all_results, out_file)

r 0 nmal 20 | e 0 benign_norm 1116.596 val loss 2.312 val acc 0.099 best val_acc 0.099
r 0 nmal 20 | e 5 benign_norm 282.334 val loss 2.342 val acc 0.099 best val_acc 0.099
r 0 nmal 20 | e 10 benign_norm 277.146 val loss 1.887 val acc 0.283 best val_acc 0.283
r 0 nmal 20 | e 15 benign_norm 278.692 val loss 2.663 val acc 0.276 best val_acc 0.284
r 0 nmal 20 | e 20 benign_norm 281.196 val loss 2.135 val acc 0.343 best val_acc 0.343
r 0 nmal 20 | e 25 benign_norm 287.255 val loss 1.648 val acc 0.431 best val_acc 0.431
r 0 nmal 20 | e 30 benign_norm 296.231 val loss 1.554 val acc 0.457 best val_acc 0.460
r 0 nmal 20 | e 35 benign_norm 310.453 val loss 1.451 val acc 0.482 best val_acc 0.482
r 0 nmal 20 | e 40 benign_norm 336.919 val loss 1.471 val acc 0.481 best val_acc 0.487
r 0 nmal 20 | e 45 benign_norm 357.224 val loss 1.503 val acc 0.474 best val_acc 0.487
r 0 nmal 20 | e 50 benign_norm 395.278 val loss 1.521 val acc 0.466 best val_acc 0.490
r 0 nmal 20 | e 55 benign_norm 454.854 val l

# Load Dirichlet (alpha = 0.1) data split

In [13]:
# Partition CIFAR10 (train + test images pooled) across `num_workers` clients
# with distribution=`distribution`, param=`param` (alpha = 0.1 here), and build
# per-client train/test loaders plus one global test loader.
# NOTE(review): te_cifar10_* presumably hold the same images with test-time
# transforms -- confirm against the cell that defines them.
all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))
batch_size = 16
num_workers = 100
distribution = 'dirichlet'
param = .1
force = False

# Use the `num_workers` variable rather than repeating the literal, so the
# split and the loaders below always agree on the client count.
each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
    all_data, num_workers=num_workers, distribution=distribution, param=param, force=force)

# (client_id, loader) pairs; per-client test loaders use one full-size batch.
train_loaders = [(pos, get_train(all_data, indices, batch_size))
                 for pos, indices in enumerate(each_worker_idx)]
test_loaders = [(pos, get_train(all_test_data, indices, len(indices)))
                for pos, indices in enumerate(each_worker_te_idx)]
cifar10_test_loader = get_train(all_test_data, global_test_idx)

len(train_loaders), len(test_loaders), len(global_test_idx)

(100, 100, 10043)

# Mean aggregation (FedAvg) with no attack

In [None]:
'''
FAST BASELINE
'''
# FedAvg reference run: the aggregate is the plain mean of client updates and
# there are no malicious clients (nbyz = 0). Tracks the best model by
# validation accuracy in `best_model`.

use_cuda = torch.cuda.is_available()
criterion = nn.CrossEntropyLoss()
torch.cuda.empty_cache()
nepochs = 200
local_epochs = 2
batch_size = 32
# NOTE(review): this rebinds the notebook-global `train_loaders` with
# batch_size=32; later cells that do not rebuild it will train with these
# loaders as well.
train_loaders = [(pos, get_train(all_data, indices, batch_size))
                 for pos, indices in enumerate(each_worker_idx)]
num_workers = 100
local_lr = 0.1
global_lr = 1
agr = 'average'
best_global_acc = 0
epoch_num = 0
nbyz = 0

# Flatten the freshly initialised global model into one float vector.
fed_model = resnet20().cuda()
model_received = torch.cat([
    p.view(-1).data.type(torch.cuda.FloatTensor)
    for p in fed_model.state_dict().values()])

while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    # Participants are clients 20-nbyz .. num_workers-1 (80 clients with nbyz=0).
    round_clients = np.arange(20 - nbyz, num_workers)
    round_benign = round_clients
    updates = []
    for client in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr=local_lr * (0.999 ** epoch_num), momentum=0.9, weight_decay=1e-4)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client][1], model, model_received, criterion, optimizer)

        # One cat instead of re-concatenating the growing vector per tensor.
        params = torch.cat([
            p.view(-1).data.type(torch.cuda.FloatTensor)
            for p in model.state_dict().values()])
        updates.append(params - model_received)
    # (n_clients, n_params) matrix of client updates.
    user_grads = torch.stack(updates, 0)
    del updates

    agg_grads = torch.mean(user_grads, 0)
    del user_grads

    # Apply the aggregate and unflatten it back into a fresh model.
    model_received = model_received + global_lr * agg_grads
    fed_model = resnet20().cuda()
    state_dict = {}
    offset = 0
    for name, ref in fed_model.state_dict().items():
        numel = ref.numel()
        state_dict[name] = model_received[offset:offset + numel].reshape(ref.shape)
        offset += numel
    fed_model.load_state_dict(state_dict)
    val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)
    if is_best:
        best_model = copy.deepcopy(fed_model)

    if epoch_num % 10 == 0 or epoch_num == nepochs - 1:
        print('e %d val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, val_loss, val_acc, best_global_acc))
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break
    epoch_num += 1

e 0 val loss 2.365 val acc 0.101 best val_acc 0.101
e 10 val loss 2.090 val acc 0.228 best val_acc 0.255
e 20 val loss 1.836 val acc 0.370 best val_acc 0.371
e 30 val loss 1.699 val acc 0.445 best val_acc 0.445
e 40 val loss 1.608 val acc 0.472 best val_acc 0.492
e 50 val loss 1.498 val acc 0.521 best val_acc 0.570
e 60 val loss 1.401 val acc 0.557 best val_acc 0.601
e 70 val loss 1.292 val acc 0.585 best val_acc 0.609
e 80 val loss 1.196 val acc 0.621 best val_acc 0.621
e 90 val loss 1.137 val acc 0.650 best val_acc 0.650
e 100 val loss 1.100 val acc 0.662 best val_acc 0.678
e 110 val loss 1.053 val acc 0.677 best val_acc 0.678
e 120 val loss 1.032 val acc 0.683 best val_acc 0.691
e 130 val loss 1.023 val acc 0.678 best val_acc 0.694
e 140 val loss 0.972 val acc 0.700 best val_acc 0.700
e 150 val loss 0.911 val acc 0.719 best val_acc 0.719
e 160 val loss 0.882 val acc 0.718 best val_acc 0.720
e 170 val loss 0.874 val acc 0.718 best val_acc 0.720
e 180 val loss 0.846 val acc 0.727 best

In [None]:
# FLDetector run with plain-mean aggregation ('average'), all num_workers
# clients participating and no attack (attack_type = 'none', nbyz = 0).
# Relies on names defined in earlier cells / ./utils: resnet20, train, test,
# lbfgs, full_trim, fldetector, detection, detection1, train_loaders,
# cifar10_test_loader, device and math.
use_cuda = torch.cuda.is_available()
torch.cuda.empty_cache()
nepochs = 200
local_epochs = 2
# NOTE(review): batch_size is not used to (re)build train_loaders in this cell;
# training uses whatever global `train_loaders` an earlier cell created.
batch_size = 16
num_workers = 100
local_lr = 0.1
global_lr = 1
nbyz = 0
agr = 'average'

criterion = nn.CrossEntropyLoss()
best_global_acc = 0
epoch_num = 0

# Flatten the freshly initialised global model into one float vector.
fed_model = resnet20().cuda()
model_received = torch.cat([
    p.view(-1).data.type(torch.cuda.FloatTensor)
    for p in fed_model.state_dict().values()])

best_accs_per_round = []
accs_per_round = []
loss_per_round = []
home_dir = '/home/vshejwalkar_umass_edu/fedrecover/'

# FLDetector initializations
weight_record = []
grad_record = []
test_grads = []
old_grad_list = []
# One column per client; row 0 is an all-zero placeholder.
malicious_scores = np.zeros((1, num_workers))
start_detection_epoch = 10
window_size = 10
assert (start_detection_epoch - window_size >= 0), 'start_detection_epoch %d should be more than window_size %d' % (start_detection_epoch, window_size)

# Adaptive attack initializers (name aligned with the other cells, which
# call this array good_distance_range).
good_distance_range = np.zeros((1, nbyz))
attack_type = 'none'
dev_type = 'std'

while epoch_num <= nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    updates = []
    benign_norm = 0
    for client in round_benign:
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr=local_lr * (0.999 ** epoch_num), momentum=0.9, weight_decay=1e-4)

        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client][1], model, model_received, criterion, optimizer)

        # One cat instead of re-concatenating the growing vector per tensor.
        params = torch.cat([
            p.view(-1).data.type(torch.cuda.FloatTensor)
            for p in model.state_dict().values()])

        update = params - model_received
        benign_norm += torch.norm(update) / len(round_benign)
        updates.append(update)
    # (n_clients, n_params) matrix of client updates.
    user_grads = torch.stack(updates, 0)
    del updates

    weight = model_received

    # L-BFGS Hessian-vector estimate; only computed after start_detection_epoch.
    if epoch_num > start_detection_epoch:
        hvp = np.squeeze(lbfgs(weight_record, grad_record, weight - last_weight))
        hvp_t = torch.from_numpy(hvp).to(device)
    else:
        hvp = None
        hvp_t = None

    # Benign updates of the first nbyz clients, used to calibrate an attack.
    good_current_grads = copy.deepcopy(user_grads[:nbyz])
    if hvp is not None:
        # Predicted update = previous round's update + hvp.
        pred_grad = copy.deepcopy(good_old_grads)
        for row in range(len(good_old_grads)):
            pred_grad[row] += hvp_t
        good_distance_range = np.concatenate(
            (good_distance_range, torch.norm(pred_grad - good_current_grads, dim=1).cpu().numpy()[None, :]), 0)

    if attack_type != 'none' and (epoch_num < start_detection_epoch):
        user_grads[:nbyz] = full_trim(user_grads[:nbyz], nbyz)
    elif epoch_num > start_detection_epoch:
        if attack_type == 'full_trim':
            user_grads = full_trim(user_grads, nbyz)
        elif attack_type == 'none':
            pass
        else:
            if attack_type == 'LIE':
                print("LIE")
                z = 0.1
                noise_avg = torch.mean(user_grads[:nbyz], dim=0)
                noise_std = torch.std(user_grads[:nbyz], dim=0)
                noise = noise_avg + z * noise_std
            elif attack_type == 'NDSS21':
                # Noise along `deviation`, rescaled to norm `distance_bound`, a
                # value drawn from the last row of benign prediction distances.
                distance_bound = np.random.choice(np.mean(good_distance_range[-1:], 0))
                model_re = torch.mean(good_current_grads, dim=0)
                if dev_type == 'unit_vec':
                    deviation = model_re / torch.norm(model_re)  # unit vector along the mean benign update
                elif dev_type == 'sign':
                    deviation = torch.sign(model_re)
                elif dev_type == 'std':
                    deviation = torch.std(good_current_grads, 0)
                noise = deviation * distance_bound / torch.norm(deviation)
            elif attack_type == 'mod_trim':
                # NOTE(review): `mal_grads` is never used and `noise` is not set on
                # this path, so the loop below raises NameError (or reuses a stale
                # `noise`). Unreachable with attack_type = 'none'.
                mal_grads = full_trim(user_grads[:nbyz], nbyz)
            else:
                noise = torch.zeros(hvp.shape).to(device)
            # Malicious updates = previous update + hvp (the prediction) + noise.
            for m in range(nbyz):
                user_grads[m] = old_grad_list[m] + hvp_t + noise

    agg_grads, distance = fldetector(old_grad_list, user_grads, nbyz, hvp, agr=agr)

    if distance is not None and epoch_num > (start_detection_epoch - window_size):
        malicious_scores = np.concatenate((malicious_scores, distance[None, :]), 0)

    # Once a full window of scores exists, check whether to stop.
    if malicious_scores.shape[0] >= window_size + 1:
        if detection1(np.sum(malicious_scores[-window_size:], axis=0), nbyz):
            print('Stop at iteration:', epoch_num)
            detection(np.sum(malicious_scores[-window_size:], axis=0), nbyz, num_workers)
            break

    # Sliding window (length window_size) of weight / aggregate deltas for L-BFGS.
    if epoch_num > (start_detection_epoch - window_size):
        weight_record.append(weight - last_weight)
        grad_record.append(agg_grads - last_grad)

    if len(weight_record) > window_size:
        del weight_record[0]
        del grad_record[0]

    last_weight = weight
    last_grad = agg_grads
    old_grad_list = user_grads
    good_old_grads = good_current_grads
    del user_grads

    # Apply the aggregate and unflatten it back into a fresh model.
    model_received = model_received + global_lr * agg_grads
    fed_model = resnet20().cuda()
    state_dict = {}
    offset = 0
    for name, ref in fed_model.state_dict().items():
        numel = ref.numel()
        state_dict[name] = model_received[offset:offset + numel].reshape(ref.shape)
        offset += numel
    fed_model.load_state_dict(state_dict)
    val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)

    best_accs_per_round.append(best_global_acc.cpu().item())
    accs_per_round.append(val_acc.cpu().item())
    loss_per_round.append(val_loss.cpu().item())

    if epoch_num % 10 == 0 or epoch_num == nepochs - 1:
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break
    epoch_num += 1

e 0 benign_norm 1110.241 val loss 2.312 val acc 0.100 best val_acc 0.100
e 10 benign_norm 279.361 val loss 1.987 val acc 0.255 best val_acc 0.255
e 20 benign_norm 276.344 val loss 1.675 val acc 0.422 best val_acc 0.422
e 30 benign_norm 276.152 val loss 1.487 val acc 0.495 best val_acc 0.496


# Trimmed-mean + FLDetector, m = 20% malicious clients, Dirichlet (alpha = 0.1) split