# FLDetector for Fashion MNIST


In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

  from IPython.core.display import display, HTML


In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
import copy
import time
from functools import reduce
from torchsummary import summary

import os
import sys
import pickle
sys.path.insert(0,'./utils/')
from logger import *
from eval import *
from misc import *

from sklearn.metrics import roc_auc_score
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict

from SGD import *
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

  warn(f"Failed to load image Python extension: {e}")


cuda


In [3]:
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

In [4]:
def sample_dirichlet_train_data(trainset, no_participants, alpha=0.9, force=False):
        """
            Input: Number of participants and alpha (param for distribution)
            Output: A list of indices denoting data in CIFAR training set.
            Requires: cifar_classes, a preprocessed class-indice dictionary.
            Sample Method: take a uniformly sampled 10-dimension vector as parameters for
            dirichlet distribution to sample number of images in each class.
        """
        if not os.path.exists('./dirichlet_a_%.1f_nusers_%d.pkl'%(alpha, no_participants)) or force:
            print('generating participant indices for alpha %.1f'%alpha)
            np.random.seed(0)
            cifar_classes = {}
            for ind, x in enumerate(trainset):
                _, label = x
                if label in cifar_classes:
                    cifar_classes[label].append(ind)
                else:
                    cifar_classes[label] = [ind]

            per_participant_list = defaultdict(list)
            no_classes = len(cifar_classes.keys())
            for n in range(no_classes):
                random.shuffle(cifar_classes[n])
                sampled_probabilities = len(cifar_classes[n]) * np.random.dirichlet(
                    np.array(no_participants * [alpha]))
                for user in range(no_participants):
                    no_imgs = int(round(sampled_probabilities[user]))
                    sampled_list = cifar_classes[n][:min(len(cifar_classes[n]), no_imgs)]
                    per_participant_list[user].extend(sampled_list)
                    cifar_classes[n] = cifar_classes[n][min(len(cifar_classes[n]), no_imgs):]
            with open('./dirichlet_a_%.1f_nusers_%d.pkl'%(alpha, no_participants), 'wb') as f:
                pickle.dump(per_participant_list, f)
        else:
            per_participant_list = pickle.load(open('./dirichlet_a_%.1f_nusers_%d.pkl'%(alpha, no_participants), 'rb'))
            
        return per_participant_list

In [5]:
def get_client_train_data(trainset, num_workers=100, bias=0.5):
    np.random.seed(42)
    bias_weight = bias
    other_group_size = (1 - bias_weight) / 9.
    worker_per_group = num_workers / 10

    each_worker_data = [[] for _ in range(num_workers)]
    each_worker_label = [[] for _ in range(num_workers)]
    
    for i, (x, y) in enumerate(trainset):
        # assign a data point to a group
        upper_bound = (y) * (1 - bias_weight) / 9. + bias_weight
        lower_bound = (y) * (1 - bias_weight) / 9.
        rd = np.random.random_sample()

        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y

        rd = np.random.random_sample()
        selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))
        
        if not len(each_worker_data[selected_worker]):
            each_worker_data[selected_worker] = x[None, :]
        else:
            each_worker_data[selected_worker]= torch.concat((each_worker_data[selected_worker], x[None, :]))
        
        each_worker_label[selected_worker].append(y)
    
    each_worker_tr_data = [[] for _ in range(num_workers)]
    each_worker_tr_label = [[] for _ in range(num_workers)]
    each_worker_val_data = [[] for _ in range(num_workers)]
    each_worker_val_label = [[] for _ in range(num_workers)]
    each_worker_te_data = [[] for _ in range(num_workers)]
    each_worker_te_label = [[] for _ in range(num_workers)]

    for i in range(num_workers):
        w_len = len(each_worker_data[i])
        len_tr = int(5 * w_len / 7)
        len_val = int(1 * w_len / 7)
        len_te = w_len - (len_tr + len_val)
        # tr_idx = np.random.choice(w_len, len_tr, replace=False)
        # te_idx = np.delete(np.arange(w_len), tr_idx)
        w = np.arange(w_len)
        np.random.shuffle(w)
        tr_idx, val_idx, te_idx = w[:len_tr], w[len_tr : (len_tr + len_val)], w[-len_te:]
        each_worker_tr_data[i] = each_worker_data[i][tr_idx]
        each_worker_tr_label[i] = torch.Tensor(each_worker_label[i])[tr_idx]
        each_worker_val_data[i] = each_worker_data[i][val_idx]
        each_worker_val_label[i] = torch.Tensor(each_worker_label[i])[val_idx]
        each_worker_te_data[i] = each_worker_data[i][te_idx]
        each_worker_te_label[i] = torch.Tensor(each_worker_label[i])[te_idx]
        
    global_test_data = torch.concat(each_worker_te_data)
    global_test_label = torch.concat(each_worker_te_label)
    del each_worker_data, each_worker_label
    return each_worker_tr_data, each_worker_tr_label, each_worker_val_data, each_worker_val_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label

In [6]:
def get_client_data_dirichlet(trainset, num_workers, alpha=1, force=False):
    per_participant_list = sample_dirichlet_train_data(trainset, num_workers, alpha=alpha, force=force)
    
    each_worker_data = [[] for _ in range(num_workers)]
    each_worker_label = [[] for _ in range(num_workers)]
    
    each_worker_te_data = [[] for _ in range(num_workers)]
    each_worker_te_label = [[] for _ in range(num_workers)]
    
    np.random.seed(0)
    
    for worker_idx in range(len(per_participant_list)):
        w_indices = np.array(per_participant_list[worker_idx])
        w_len = len(w_indices)
        len_tr = int(6*w_len/7)
        len_te = w_len - len_tr
        tr_idx = np.random.choice(w_len, len_tr, replace=False)
        te_idx = np.delete(np.arange(w_len), tr_idx)
        
        for idx in w_indices[tr_idx]:
            each_worker_data[worker_idx].append(trainset[idx][0])
            each_worker_label[worker_idx].append(trainset[idx][1])
        each_worker_data[worker_idx] = torch.stack(each_worker_data[worker_idx])
        each_worker_label[worker_idx] = torch.Tensor(each_worker_label[worker_idx]).long()
        
        for idx in w_indices[te_idx]:
            each_worker_te_data[worker_idx].append(trainset[idx][0])
            each_worker_te_label[worker_idx].append(trainset[idx][1])
        each_worker_te_data[worker_idx] = torch.stack(each_worker_te_data[worker_idx])
        each_worker_te_label[worker_idx] = torch.Tensor(each_worker_te_label[worker_idx]).long()
    
    global_test_data = torch.concat(each_worker_te_data)
    global_test_label = torch.concat(each_worker_te_label)
    
    return each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label


In [7]:
def full_trim(v, f):
    '''
    Full-knowledge Trim attack. w.l.o.g., we assume the first f worker devices are compromised.
    v: the list of squeezed gradients
    f: the number of compromised worker devices
    '''
    vi_shape = v[0].unsqueeze(0).T.shape
    v_tran = v.T
    
    maximum_dim = torch.max(v_tran, dim=1)
    maximum_dim = maximum_dim[0].reshape(vi_shape)
    minimum_dim = torch.min(v_tran, dim=1)
    minimum_dim = minimum_dim[0].reshape(vi_shape)
    direction = torch.sign(torch.sum(v_tran, dim=-1, keepdims=True))
    directed_dim = (direction > 0) * minimum_dim + (direction < 0) * maximum_dim

    for i in range(f):
        random_12 = 2
        tmp = directed_dim * ((direction * directed_dim > 0) / random_12 + (direction * directed_dim < 0) * random_12)
        tmp = tmp.squeeze()
        v[i] = tmp
    return v

In [8]:
def tr_mean(all_updates, n_attackers):
    sorted_updates = torch.sort(all_updates, 0)[0]
    out = torch.mean(sorted_updates[n_attackers:-n_attackers], 0) if n_attackers else torch.mean(sorted_updates,0)
    return out

In [9]:
class cnn(nn.Module):
    def __init__(self):
        super(cnn, self).__init__()
        self.conv1 = nn.Conv2d(1, 30, 3)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(30, 50, 3)
        self.pool2 = nn.MaxPool2d(2,2)
        self.fc1 = nn.Linear(1250, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
#         print(x.shape)
        x = self.pool1(F.relu(self.conv1(x)))
#         print(x.shape)
        x = self.pool2(F.relu(self.conv2(x)))
#         print(x.shape)
        x = x.view(x.size(0), -1)
#         print(x.shape)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [10]:
model = cnn()
sum(p.numel() for p in model.parameters())

266060

# Good FedAvg baseline Fashion MNIST + Fang distribution + 80 clients

In [11]:
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def train(train_data, labels, model, optimizer, batch_size=20):
    model.train()
    criterion = nn.CrossEntropyLoss()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    len_t = (len(train_data) // batch_size)
    if len(train_data)%batch_size:
        len_t += 1

    r=np.arange(len(train_data))
    np.random.shuffle(r)
    
    train_data = train_data[r]
    labels = labels[r]
    
    for ind in range(len_t):

        inputs = train_data[ind * batch_size:(ind + 1) * batch_size]
        targets = labels[ind * batch_size:(ind + 1) * batch_size]

        inputs, targets = inputs.cuda(), targets.cuda()

        inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return (losses.avg, top1.avg)


def test(test_data, labels, model, criterion, use_cuda, debug_='MEDIUM', batch_size=64):
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    len_t = (len(test_data) // batch_size)
    if len(test_data)%batch_size:
        len_t += 1

    with torch.no_grad():
        for ind in range(len_t):
            # measure data loading time
            inputs = test_data[ind * batch_size:(ind + 1) * batch_size]
            targets = labels[ind * batch_size:(ind + 1) * batch_size]

            inputs, targets = inputs.cuda(), targets.cuda()

            inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

    return (losses.avg, top1.avg)

# Mean without any attack

In [13]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

torch.cuda.empty_cache()
use_cuda = torch.cuda.is_available()
criterion = nn.CrossEntropyLoss()
local_epochs = 2
batch_size = 16
num_workers = 100
local_lr = 0.01
global_lr = 1
nepochs = 50
nbyz = 20
force=True
all_data = torch.utils.data.ConcatDataset((trainset, testset))
distribution='dirichlet'
params = [.1, .2, .3, .4, .5]

for dir_p in params:
    for run in range(3):
        print('===> Processing Dir %.1f run %d' % (dir_p, run))
        each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=dir_p, force=force)
        best_global_acc=0
        epoch_num = 0

        fed_model = cnn().to(device)
        fed_model.apply(init_weights)
        model_received = []
        for i, (name, param) in enumerate(fed_model.state_dict().items()):
            model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

        while epoch_num <= nepochs:
            torch.cuda.empty_cache()
            round_clients = np.arange(nbyz, num_workers)
            round_benign = round_clients
            user_updates=[]
            benign_norm = 0

            for i in round_benign:
                model = copy.deepcopy(fed_model)
                optimizer = optim.SGD(model.parameters(), lr = local_lr, momentum=0.9, weight_decay=1e-4)
                for epoch in range(local_epochs):
                    train_loss, train_acc = train(
                        each_worker_data[i], torch.Tensor(each_worker_label[i]).long(), model, optimizer, batch_size)
                params = []
                for i, (name, param) in enumerate(model.state_dict().items()):
                    params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat(
                        (params, param.view(-1).data.type(torch.cuda.FloatTensor)))
                update =  (params - model_received)
                benign_norm += torch.norm(update)/len(round_benign)
                user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)

            agg_update = torch.mean(user_updates, 0)
            del user_updates
            model_received = model_received + global_lr * agg_update
            fed_model = cnn().to(device)
            fed_model.apply(init_weights)
            start_idx=0
            state_dict = {}
            previous_name = 'none'
            for i, (name, param) in enumerate(fed_model.state_dict().items()):
                start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
                start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
                params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
                state_dict[name] = params
                previous_name = name
            fed_model.load_state_dict(state_dict)
            val_loss, val_acc = test(global_test_data, global_test_label.long(), fed_model, criterion, use_cuda)
            is_best = best_global_acc < val_acc
            best_global_acc = max(best_global_acc, val_acc)

            if epoch_num%10==0 or epoch_num==nepochs-1:
                print('Dir %.1f r %d e %d val loss %.3f val acc %.3f best val_acc %.3f'% (dir_p, run, epoch_num, val_loss, val_acc, best_global_acc))
            epoch_num+=1

===> Processing Dir 0.1 run 0
generating participant indices for alpha 0.1
Dir 0.1 r 0 e 0 val loss 2.071 val acc 30.404 best val_acc 30.404
Dir 0.1 r 0 e 10 val loss 0.595 val acc 77.448 best val_acc 77.836
Dir 0.1 r 0 e 20 val loss 0.489 val acc 82.026 best val_acc 82.026
Dir 0.1 r 0 e 30 val loss 0.447 val acc 83.808 best val_acc 83.808
Dir 0.1 r 0 e 40 val loss 0.416 val acc 84.952 best val_acc 84.952
Dir 0.1 r 0 e 49 val loss 0.397 val acc 85.599 best val_acc 85.599
Dir 0.1 r 0 e 50 val loss 0.397 val acc 85.629 best val_acc 85.629
===> Processing Dir 0.1 run 1
generating participant indices for alpha 0.1
Dir 0.1 r 1 e 0 val loss 2.033 val acc 37.609 best val_acc 37.609
Dir 0.1 r 1 e 10 val loss 0.609 val acc 76.533 best val_acc 76.533
Dir 0.1 r 1 e 20 val loss 0.497 val acc 82.006 best val_acc 82.006
Dir 0.1 r 1 e 40 val loss 0.415 val acc 84.813 best val_acc 84.902
Dir 0.1 r 1 e 49 val loss 0.393 val acc 85.500 best val_acc 85.500
Dir 0.1 r 1 e 50 val loss 0.392 val acc 85.649 b

# Trim attack on Dir + TrMean

In [17]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

torch.cuda.empty_cache()
use_cuda = torch.cuda.is_available()
criterion = nn.CrossEntropyLoss()
local_epochs = 2
batch_size = 16
num_workers = 100
local_lr = 0.01
global_lr = 1
nepochs = 50
nbyz = 20

all_data = torch.utils.data.ConcatDataset((trainset, testset))
fast_ndss = True
params = [.1, .2, .3, .4, .5, 10]

results = {}
for dir_p in np.array(params)[::-1]:
    results[dir_p] = []
    for run in range(3):
        print('===> Processing NDSS on Dir %.1f run %d' % (dir_p, run))
        each_worker_data, each_worker_label, each_worker_te_data, each_worker_te_label, global_test_data, global_test_label = get_client_data_dirichlet(all_data, num_workers, alpha=dir_p, force=force)
        best_global_acc=0
        epoch_num = 0

        fed_model = cnn().to(device)
        fed_model.apply(init_weights)
        model_received = []
        for i, (name, param) in enumerate(fed_model.state_dict().items()):
            model_received = param.view(-1).data.type(torch.cuda.FloatTensor) if len(model_received) == 0 else torch.cat((model_received, param.view(-1).data.type(torch.cuda.FloatTensor)))

        while epoch_num <= nepochs:
            torch.cuda.empty_cache()
            round_clients = np.arange(num_workers)
            round_benign = round_clients
            user_updates=[]
            benign_norm = 0

            for i in round_benign:
                model = copy.deepcopy(fed_model)
                optimizer = optim.SGD(model.parameters(), lr = local_lr, momentum=0.9, weight_decay=1e-4)
                for epoch in range(local_epochs):
                    train_loss, train_acc = train(
                        each_worker_data[i], torch.Tensor(each_worker_label[i]).long(), model, optimizer, batch_size)
                params = []
                for i, (name, param) in enumerate(model.state_dict().items()):
                    params = param.view(-1).data.type(torch.cuda.FloatTensor) if len(params) == 0 else torch.cat(
                        (params, param.view(-1).data.type(torch.cuda.FloatTensor)))
                update =  (params - model_received)
                benign_norm += torch.norm(update)/len(round_benign)
                user_updates = update[None,:] if len(user_updates) == 0 else torch.cat((user_updates, update[None,:]), 0)

            user_updates = full_trim(user_updates, nbyz)
            agg_update = tr_mean(user_updates, nbyz)

            del user_updates
            model_received = model_received + global_lr * agg_update
            fed_model = cnn().to(device)
            fed_model.apply(init_weights)
            start_idx=0
            state_dict = {}
            previous_name = 'none'
            for i, (name, param) in enumerate(fed_model.state_dict().items()):
                start_idx = 0 if i == 0 else start_idx + len(fed_model.state_dict()[previous_name].data.view(-1))
                start_end = start_idx + len(fed_model.state_dict()[name].data.view(-1))
                params = model_received[start_idx:start_end].reshape(fed_model.state_dict()[name].data.shape)
                state_dict[name] = params
                previous_name = name
            fed_model.load_state_dict(state_dict)
            val_loss, val_acc = test(global_test_data, global_test_label.long(), fed_model, criterion, use_cuda)
            is_best = best_global_acc < val_acc
            best_global_acc = max(best_global_acc, val_acc)

            if epoch_num%10==0 or epoch_num==nepochs-1:
                print('Trim on Dir %.1f + TrMean | r %d e %d val loss %.3f val acc %.3f best val_acc %.3f'% (dir_p, run, epoch_num, val_loss, val_acc, best_global_acc))
            epoch_num+=1
        results[dir_p].append(best_global_acc)

for fcj_p in params:
    print('======= Results for Trim attack on Dir %.1f  =====' % fcj_p)
    print(results[fcj_p])

===> Processing NDSS on Dir 10.0 run 0
generating participant indices for alpha 10.0
Trim on Dir 10.0 + TrMean | r 0 e 0 val loss 1.115 val acc 68.625 best val_acc 68.625
Trim on Dir 10.0 + TrMean | r 0 e 10 val loss 0.524 val acc 80.145 best val_acc 80.145
Trim on Dir 10.0 + TrMean | r 0 e 20 val loss 0.478 val acc 82.525 best val_acc 82.525
Trim on Dir 10.0 + TrMean | r 0 e 30 val loss 0.453 val acc 83.740 best val_acc 83.740
Trim on Dir 10.0 + TrMean | r 0 e 40 val loss 0.443 val acc 84.367 best val_acc 84.367
Trim on Dir 10.0 + TrMean | r 0 e 49 val loss 0.431 val acc 84.985 best val_acc 84.985
Trim on Dir 10.0 + TrMean | r 0 e 50 val loss 0.429 val acc 85.184 best val_acc 85.184
===> Processing NDSS on Dir 10.0 run 1
generating participant indices for alpha 10.0
Trim on Dir 10.0 + TrMean | r 1 e 0 val loss 1.078 val acc 70.915 best val_acc 70.915
Trim on Dir 10.0 + TrMean | r 1 e 10 val loss 0.516 val acc 80.215 best val_acc 80.215
Trim on Dir 10.0 + TrMean | r 1 e 20 val loss 0.4

Trim on Dir 0.2 + TrMean | r 0 e 0 val loss 2.170 val acc 31.763 best val_acc 31.763
Trim on Dir 0.2 + TrMean | r 0 e 10 val loss 0.943 val acc 64.781 best val_acc 64.781
Trim on Dir 0.2 + TrMean | r 0 e 20 val loss 0.802 val acc 71.135 best val_acc 71.135
Trim on Dir 0.2 + TrMean | r 0 e 30 val loss 0.771 val acc 72.430 best val_acc 72.430
Trim on Dir 0.2 + TrMean | r 0 e 40 val loss 0.757 val acc 73.018 best val_acc 73.566
Trim on Dir 0.2 + TrMean | r 0 e 49 val loss 0.759 val acc 73.038 best val_acc 73.566
Trim on Dir 0.2 + TrMean | r 0 e 50 val loss 0.764 val acc 72.699 best val_acc 73.566
===> Processing NDSS on Dir 0.2 run 1
generating participant indices for alpha 0.2
Trim on Dir 0.2 + TrMean | r 1 e 0 val loss 2.194 val acc 25.608 best val_acc 25.608
Trim on Dir 0.2 + TrMean | r 1 e 10 val loss 0.998 val acc 63.147 best val_acc 63.147
Trim on Dir 0.2 + TrMean | r 1 e 20 val loss 0.846 val acc 68.327 best val_acc 68.327
Trim on Dir 0.2 + TrMean | r 1 e 30 val loss 0.831 val acc 

KeyError: tensor(-0.6187, device='cuda:0')

In [18]:
results

{10.0: [85.18371005211648, 84.73563677331882, 84.91486608423014],
 0.5: [81.30413141093959, 81.9313090930595, 81.67247386987438],
 0.4: [79.00607508224458, 78.88656508999746, 81.36639776990633],
 0.3: [76.70550740438324, 75.53032565565654, 75.09212228945795],
 0.2: [73.56573705179282, 69.89043822269515, 72.06175298804781],
 0.1: [54.80692675159236, 55.40406050955414, 56.47890127388535]}