# FLDetector for CIFAR10 with Fang distribution


In [1]:
# Widen the notebook's main container to 90% of the browser window.
# `display` is imported from the public `IPython.display` module; importing it
# from `IPython.core.display` is deprecated and emits a warning.
from IPython.display import HTML, display
display(HTML("<style>.container { width:90% !important; }</style>"))

  from IPython.core.display import display, HTML


In [2]:
import collections
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random
import copy
import time
from functools import reduce
from torchsummary import summary

import os
import sys
import pickle
sys.path.insert(0,'./utils/')
from logger import *
from eval import *
from misc import *

from sklearn.metrics import roc_auc_score
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from collections import defaultdict

from SGD import *
import copy
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

  warn(f"Failed to load image Python extension: {e}")


ModuleNotFoundError: No module named 'torchsummary'

In [None]:
from cifar10_models import *
from cifar10_util import *

In [None]:
def sample_dirichlet_train_data(trainset, no_participants, alpha=0.9, force=False):
    """Partition ``trainset`` indices across participants with a Dirichlet prior.

    For every class, a Dirichlet(alpha, ..., alpha) vector over the
    participants is drawn and scaled by the class size; each participant then
    receives (approximately, because of rounding) that many samples of the
    class. The result is cached in
    ``./dirichlet_a_<alpha>_nusers_<no_participants>.pkl`` and reloaded on
    later calls.

    Args:
        trainset: iterable of ``(sample, label)`` pairs. Labels are expected
            to be the integers ``0 .. num_classes - 1``.
        no_participants: number of participants to split the data across.
        alpha: Dirichlet concentration parameter.
        force: regenerate the partition even if a cache file exists.

    Returns:
        Mapping ``participant id -> list of dataset indices``.
    """
    cache_path = './dirichlet_a_%.1f_nusers_%d.pkl' % (alpha, no_participants)

    if os.path.exists(cache_path) and not force:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # cache files that this function wrote itself.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    print('generating participant indices for alpha %.1f' % alpha)
    np.random.seed(0)

    # Group dataset indices by class label.
    cifar_classes = defaultdict(list)
    for ind, (_, label) in enumerate(trainset):
        cifar_classes[label].append(ind)

    per_participant_list = defaultdict(list)
    no_classes = len(cifar_classes)
    for n in range(no_classes):
        # Shuffle with the seeded NumPy generator so the partition is
        # reproducible (the stdlib `random` module is never seeded here).
        np.random.shuffle(cifar_classes[n])
        sampled_probabilities = len(cifar_classes[n]) * np.random.dirichlet(
            np.array(no_participants * [alpha]))
        for user in range(no_participants):
            no_imgs = int(round(sampled_probabilities[user]))
            # Never hand out more samples than the class still has left.
            take = min(len(cifar_classes[n]), no_imgs)
            per_participant_list[user].extend(cifar_classes[n][:take])
            cifar_classes[n] = cifar_classes[n][take:]

    with open(cache_path, 'wb') as f:
        pickle.dump(per_participant_list, f)

    return per_participant_list

In [3]:
def get_fang_train_data(trainset, num_workers=100, bias=0.5, force=False):
    """Partition ``trainset`` indices across workers with a label-biased split.

    Workers are organised into 10 groups. A sample with label ``y`` goes to
    group ``y`` with probability ``bias`` and to each of the other nine groups
    with probability ``(1 - bias) / 9``; inside the chosen group the worker is
    picked uniformly at random. The result is cached in
    ``fang_nworkers<num_workers>_bias<bias>.pkl``.

    Args:
        trainset: iterable of ``(sample, label)`` pairs with integer labels
            in ``0 .. 9``.
        num_workers: total number of workers.
        bias: probability mass given to the group matching the label.
        force: regenerate the partition even if a cache file exists.

    Returns:
        Mapping ``worker id -> list of dataset indices``. Workers that
        received no sample have no entry.
    """
    dist_file = 'fang_nworkers%d_bias%.1f.pkl' % (num_workers, bias)
    if not force and os.path.exists(dist_file):
        print('Loading fang distribution for num_workers %d and bias %.1f from memory' % (num_workers, bias))
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # cache files that this function wrote itself.
        with open(dist_file, 'rb') as f:
            return pickle.load(f)

    num_groups = 10
    bias_weight = bias
    other_group_size = (1 - bias_weight) / (num_groups - 1.)
    worker_per_group = num_workers / num_groups
    per_participant_list = defaultdict(list)
    for i, (_, y) in enumerate(trainset):
        # The unit interval is cut into ten segments: a wide one of width
        # `bias` for group y, and nine of width `other_group_size` for the
        # remaining groups. `rd` selects the segment.
        upper_bound = (y) * (1 - bias_weight) / 9. + bias_weight
        lower_bound = (y) * (1 - bias_weight) / 9.
        rd = np.random.random_sample()
        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
            # Floating-point rounding can push the quotient onto the next
            # integer; clamp so the group index stays valid.
            worker_group = min(worker_group, num_groups - 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y
        # Pick a worker uniformly inside the selected group.
        rd = np.random.random_sample()
        selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))
        per_participant_list[selected_worker].append(i)

    print('Saving fang distribution for num_workers %d and bias %.1f from memory' % (num_workers, bias))
    with open(dist_file, 'wb') as f:
        pickle.dump(per_participant_list, f)
    return per_participant_list

In [4]:
def get_federated_data(trainset, num_workers, distribution='fang', param=1, force=False):
    """Split ``trainset`` across workers and carve out a per-worker test set.

    The dataset is first partitioned with the requested distribution; each
    worker's share is then split at random into 5/6 training and 1/6 test
    indices (NumPy is re-seeded with 0 before the split).

    Args:
        trainset: dataset of ``(sample, label)`` pairs.
        num_workers: number of workers.
        distribution: ``'fang'`` or ``'dirichlet'``.
        param: ``bias`` for the fang split, ``alpha`` for the dirichlet split.
        force: forwarded to the partition function to bypass its cache.

    Returns:
        ``(each_worker_idx, each_worker_te_idx, global_test_idx)``: two lists
        with one index array per worker (train / test) and the concatenation
        of all per-worker test indices.

    Raises:
        ValueError: if ``distribution`` is not a supported name.
    """
    if distribution == 'fang':
        per_participant_list = get_fang_train_data(trainset, num_workers, bias=param, force=force)
    elif distribution == 'dirichlet':
        per_participant_list = sample_dirichlet_train_data(trainset, num_workers, alpha=param, force=force)
    else:
        raise ValueError(
            "unknown distribution %r (expected 'fang' or 'dirichlet')" % (distribution,))

    each_worker_idx = [[] for _ in range(num_workers)]
    each_worker_te_idx = [[] for _ in range(num_workers)]

    np.random.seed(0)

    # Iterate over every worker id rather than len(per_participant_list):
    # a worker that received no samples has no entry in the mapping, which
    # would otherwise shift the range and skip the last worker.
    for worker_idx in range(num_workers):
        # dtype=int keeps empty shares integer-typed so the final
        # concatenation is not upcast to float.
        w_indices = np.array(per_participant_list.get(worker_idx, []), dtype=int)
        w_len = len(w_indices)
        if w_len == 0:
            each_worker_idx[worker_idx] = w_indices
            each_worker_te_idx[worker_idx] = w_indices
            continue

        len_tr = int(5*w_len/6)
        tr_idx = np.random.choice(w_len, len_tr, replace=False)
        te_idx = np.delete(np.arange(w_len), tr_idx)

        each_worker_idx[worker_idx] = w_indices[tr_idx]
        each_worker_te_idx[worker_idx] = w_indices[te_idx]

    global_test_idx = np.concatenate(each_worker_te_idx)

    return each_worker_idx, each_worker_te_idx, global_test_idx

In [5]:
def get_train(dataset, indices, batch_size=32, shuffle=False):
    """Build a DataLoader that draws only the given ``indices`` of ``dataset``.

    Samples are drawn through a ``SubsetRandomSampler``, i.e. in random order
    without replacement. ``shuffle`` is accepted for interface compatibility
    but is not used.
    """
    subset_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=subset_sampler,
    )

In [6]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Directory the CIFAR10 files are read from / downloaded to.
data_loc = '/home/vshejwalkar_umass_edu/data/'

# Per-channel normalisation constants shared by both pipelines.
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random horizontal flip, then tensor conversion + normalisation.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Evaluation pipeline: no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Both official splits loaded with the training transform ...
cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=transform_train)
cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=transform_train)

# ... and the same two splits again with the evaluation transform.
te_cifar10_train = datasets.CIFAR10(root=data_loc, train=True, download=True, transform=transform_test)
te_cifar10_test = datasets.CIFAR10(root=data_loc, train=False, download=True, transform=transform_test)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [7]:
def full_trim(v, f):
    '''
    Full-knowledge Trim attack. w.l.o.g., we assume the first f worker devices are compromised.

    v: 2-D tensor of flattened gradients/updates, one row per worker
       (modified in place and also returned)
    f: the number of compromised worker devices; rows 0 .. f-1 of v are overwritten

    Every compromised row receives the same crafted vector. Per coordinate,
    the statistics are taken over all rows of v: the crafted value starts
    from the minimum when the coordinate-wise sum is positive and from the
    maximum when it is negative, and is then scaled by 1/2 or 2 so that it
    moves further against the sign of the sum.
    '''
    vi_shape = v[0].unsqueeze(0).T.shape
    v_tran = v.T

    maximum_dim = torch.max(v_tran, dim=1)[0].reshape(vi_shape)
    minimum_dim = torch.min(v_tran, dim=1)[0].reshape(vi_shape)
    direction = torch.sign(torch.sum(v_tran, dim=-1, keepdim=True))
    directed_dim = (direction > 0) * minimum_dim + (direction < 0) * maximum_dim

    # The crafted vector depends only on the statistics gathered above (all
    # new tensors), so build it once instead of once per compromised worker.
    scale = 2
    same_sign = direction * directed_dim > 0
    opposite_sign = direction * directed_dim < 0
    crafted = directed_dim * (same_sign / scale + opposite_sign * scale)
    crafted = crafted.squeeze()

    for i in range(f):
        v[i] = crafted
    return v

In [8]:
def tr_mean(all_updates, n_attackers):
    """Coordinate-wise trimmed mean over the rows of ``all_updates``.

    Each coordinate is sorted across rows, the ``n_attackers`` smallest and
    the ``n_attackers`` largest values are dropped, and the remainder is
    averaged. With ``n_attackers == 0`` this is the plain mean.
    """
    ordered, _ = torch.sort(all_updates, 0)
    if n_attackers:
        ordered = ordered[n_attackers:-n_attackers]
    return torch.mean(ordered, 0)

In [9]:
from torch.nn.utils import parameters_to_vector, vector_to_parameters

def train(trainloader, model, model_received, criterion, optimizer, pgd=False, eps=2):
    """Run one pass over ``trainloader``, updating ``model`` in place.

    Args:
        trainloader: iterable yielding ``(inputs, targets)`` batches.
        model: network to train.
        model_received: flat parameter vector used as the centre of the
            projection ball; only read when ``pgd`` is true.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimizer stepping ``model``'s parameters.
        pgd: when true, after every step the parameters are projected back
            onto the L2 ball of radius ``eps`` around ``model_received``.
        eps: radius of that ball.

    Returns:
        ``(average loss, average top-1 accuracy)`` as tracked by the
        ``AverageMeter`` helpers (defined outside this block).
    """
    model.train()

    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    for inputs, targets in trainloader:
        inputs = inputs.to(device, torch.float)
        targets = targets.to(device, torch.long)
        n_samples = inputs.size()[0]

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Track running loss / accuracy, weighted by batch size.
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        loss_meter.update(loss.item(), n_samples)
        acc1_meter.update(prec1.item()/100.0, n_samples)
        acc5_meter.update(prec5.item()/100.0, n_samples)

        # Standard SGD step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if not pgd:
            continue

        # Projection: if the parameters drifted further than eps from the
        # received model, rescale the drift onto the ball's surface.
        live_params = list(model.parameters())
        flat_params = parameters_to_vector(live_params)
        drift = flat_params - model_received
        drift_norm = torch.norm(drift)
        if drift_norm > eps:
            projected = eps*drift/drift_norm + model_received
            vector_to_parameters(projected, live_params)

    return (loss_meter.avg, acc1_meter.avg)

def test(testloader, model, criterion):
    """Evaluate ``model`` on ``testloader``.

    Args:
        testloader: iterable yielding ``(inputs, targets)`` batches.
        model: network to evaluate (switched to eval mode).
        criterion: loss function called as ``criterion(outputs, targets)``.

    Returns:
        ``(average loss, average top-1 accuracy)`` as tracked by the
        ``AverageMeter`` helpers (defined outside this block).
    """
    model.eval()
    losses = AverageMeter()
    top1 = AverageMeter()

    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch and lowers memory use.
    with torch.no_grad():
        for batch_ind, (inputs, targets) in enumerate(testloader):
            inputs = inputs.to(device, torch.float)
            targets = targets.to(device, torch.long)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.data, inputs.size()[0])
            top1.update(prec1/100.0, inputs.size()[0])
    return (losses.avg, top1.avg)

# AlexNet

In [10]:
class AlexNet(nn.Module):
    """AlexNet-style CNN: five convolutions followed by one linear layer.

    The classifier takes 256 flattened features, which is what the feature
    stack produces for 3x32x32 inputs (256 channels at 1x1 resolution).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Layer order matters: the positions become the state_dict keys
        # (features.0, features.3, ...).
        feature_layers = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*feature_layers)
        self.classifier = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        feats = self.features(x)
        # Collapse channel and spatial dimensions into one vector per sample.
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)


def alexnet(**kwargs):
    r"""Build an :class:`AlexNet` (architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper).

    All keyword arguments are forwarded to the ``AlexNet`` constructor.
    """
    return AlexNet(**kwargs)

In [11]:
# Pool both CIFAR10 splits: once with the training transform, once with the
# evaluation transform (same sample order, so indices are interchangeable).
all_data = torch.utils.data.ConcatDataset((cifar10_train, cifar10_test))
all_test_data = torch.utils.data.ConcatDataset((te_cifar10_train, te_cifar10_test))

batch_size = 16
num_workers = 100
distribution='fang'
param = .5
force = False

# Partition the pooled data across workers and split each share into
# train / test indices.
each_worker_idx, each_worker_te_idx, global_test_idx = get_federated_data(
    all_data,
    num_workers=num_workers,
    distribution=distribution,
    param=param,
    force=force,
)

# One (worker position, loader) pair per worker.
train_loaders = [
    (pos, get_train(all_data, indices, batch_size))
    for pos, indices in enumerate(each_worker_idx)
]
# Each worker's test loader yields its whole test share as a single batch.
test_loaders = [
    (pos, get_train(all_test_data, indices, len(indices)))
    for pos, indices in enumerate(each_worker_te_idx)
]
# Global evaluation loader over the union of all per-worker test shares.
cifar10_test_loader = get_train(all_test_data, global_test_idx)

len(train_loaders), len(test_loaders), len(global_test_idx)

NameError: name 'os' is not defined

In [14]:
# Federated training with the full-knowledge trim attack on the first `nbyz`
# clients and trimmed-mean aggregation on the server.
import math  # needed for the divergence check; not among the visible imports

torch.cuda.empty_cache()
nepochs=200
local_epochs = 2
batch_size = 16
num_workers = 100
local_lr = 0.01
global_lr = 1
nbyz = 20
criterion = nn.CrossEntropyLoss()
use_cuda = torch.cuda.is_available()

resume=False
round_nclients = num_workers
best_global_acc=0
epoch_num = 0


def _flatten_state(net):
    """Concatenate every state_dict entry of `net` into one float32 vector on `device`."""
    return torch.cat([t.view(-1).data.to(device, torch.float) for t in net.state_dict().values()])


# Use the selected `device` instead of a hard-coded .cuda() so the cell also
# runs on CPU-only machines.
fed_model = alexnet().to(device)
model_received = _flatten_state(fed_model)

# `<` (not `<=`): run exactly `nepochs` rounds, numbered 0 .. nepochs-1, which
# matches the `epoch_num==nepochs-1` "last round" check below.
while epoch_num < nepochs:
    torch.cuda.empty_cache()
    round_clients = np.arange(num_workers)
    round_benign = round_clients
    client_updates = []
    benign_norm = 0
    for client in round_benign:
        # Every client starts from a copy of the current global model.
        model = copy.deepcopy(fed_model)
        optimizer = optim.SGD(model.parameters(), lr = local_lr, momentum=0.9, weight_decay=1e-4)
        for epoch in range(local_epochs):
            train_loss, train_acc = train(train_loaders[client][1], model, model_received, criterion, optimizer)
        # The update is the difference between the locally trained and the
        # received global parameters.
        update = _flatten_state(model) - model_received
        benign_norm += torch.norm(update)/len(round_benign)
        client_updates.append(update)

    # Stack once instead of growing a tensor with torch.cat every iteration.
    user_updates = torch.stack(client_updates)
    del client_updates

    # Overwrite the first `nbyz` updates with the crafted malicious update,
    # then aggregate with the trimmed mean.
    user_updates[:nbyz] = full_trim(user_updates[:nbyz], nbyz)
    agg_update = tr_mean(user_updates, nbyz)
    del user_updates

    model_received = model_received + global_lr * agg_update

    # Rebuild the global model from the flat vector by slicing it back into
    # the shapes of the state_dict entries, in state_dict order.
    fed_model = alexnet().to(device)
    state_dict = {}
    offset = 0
    for name, reference in fed_model.state_dict().items():
        n_values = reference.numel()
        state_dict[name] = model_received[offset:offset + n_values].reshape(reference.shape)
        offset += n_values

    fed_model.load_state_dict(state_dict)
    val_loss, val_acc = test(cifar10_test_loader, fed_model, criterion)
    is_best = best_global_acc < val_acc
    best_global_acc = max(best_global_acc, val_acc)

    if epoch_num%2==0 or epoch_num==nepochs-1:
        print('e %d benign_norm %.3f val loss %.3f val acc %.3f best val_acc %.3f'% (epoch_num, benign_norm, val_loss, val_acc, best_global_acc))

    # Stop as soon as training has diverged.
    if math.isnan(val_loss) or val_loss > 100000:
        print('val loss %f... exit'%val_loss)
        break

    epoch_num+=1

e 0 benign_norm 1.291 val loss 2.300 val acc 0.100 best val_acc 0.100
e 2 benign_norm 1.200 val loss 2.264 val acc 0.140 best val_acc 0.140
e 4 benign_norm 1.191 val loss 2.179 val acc 0.191 best val_acc 0.191
e 6 benign_norm 1.136 val loss 2.041 val acc 0.225 best val_acc 0.225
e 8 benign_norm 1.147 val loss 2.005 val acc 0.221 best val_acc 0.225
e 10 benign_norm 1.178 val loss 1.974 val acc 0.241 best val_acc 0.241
e 12 benign_norm 1.213 val loss 1.945 val acc 0.253 best val_acc 0.255
e 14 benign_norm 1.279 val loss 1.905 val acc 0.276 best val_acc 0.276
e 16 benign_norm 1.339 val loss 1.869 val acc 0.287 best val_acc 0.287
e 18 benign_norm 1.392 val loss 1.840 val acc 0.302 best val_acc 0.302
e 20 benign_norm 1.469 val loss 1.840 val acc 0.303 best val_acc 0.307
e 22 benign_norm 1.538 val loss 1.817 val acc 0.323 best val_acc 0.323
e 24 benign_norm 1.581 val loss 1.829 val acc 0.317 best val_acc 0.323
e 26 benign_norm 1.613 val loss 1.783 val acc 0.339 best val_acc 0.339
e 28 benign