Resources#

https://github.com/pytorch/pytorch/blob/cc46dc45e1b4b2a9ffab4ad5442f8b864148e45a/torch/optim/_functional.py#L156

https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py

https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html

https://discuss.pytorch.org/t/custom-loss-functions/29387

https://discuss.pytorch.org/t/good-way-to-calculate-element-wise-difference-between-two-models-of-the-same-structure/67592

In [810]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np
import datetime

In [811]:
def convert_model_to_param_list(model):
    '''
    num_of_param=0
    for param in model.state_dict().values():
        num_of_param += torch.numel(param)
    

    params=torch.ones([num_of_param])
    '''
    if torch.typename(model)!='OrderedDict':
        model = model.state_dict()

    idx=0
    params_to_copy_group=[]
    for name, param in model.items():
        num_params_to_copy = torch.numel(param)
        params_to_copy_group.append(param.reshape([num_params_to_copy]).clone().detach())
        idx+=num_params_to_copy

    params=torch.ones([idx])
    idx=0
    for param in params_to_copy_group:    
        for par in param:
            params[idx].copy_(par)
            idx += 1

    return params

def cos_calc_btn_grads(l1, l2):
    return torch.dot(l1, l2)/(torch.linalg.norm(l1)+1e-9)/(torch.linalg.norm(l2)+1e-9)


def cos_calc(n1, n2):
    l1 = convert_model_to_param_list(n1)
    l2 = convert_model_to_param_list(n2)
    return cos_calc_btn_grads(l1, l2)

In [812]:
class SimpleNet(nn.Module):
    def __init__(self, name=None, created_time=None, is_malicious=False, net_id=-1):
        super(SimpleNet, self).__init__()
        self.created_time = created_time
        self.name=name

        self.is_malicious=is_malicious
        self.net_id=net_id


    def save_stats(self, epoch, loss, acc):
        self.stats['epoch'].append(epoch)
        self.stats['loss'].append(loss)
        self.stats['acc'].append(acc)

    def copy_params(self, state_dict, coefficient_transfer=100):

        own_state = self.state_dict()

        for name, param in state_dict.items():
            if name in own_state:
                shape = param.shape
                own_state[name].copy_(param.clone())
    
    def set_param_to_zero(self):
        own_state = self.state_dict()

        for name, param in own_state.items():
            shape = param.shape
            param.mul_(0)       

    def aggregate(self, state_dicts, aggr_weights=None):
        #self.copy_params(state_dicts[0])
        own_state = self.state_dict()
        
        nw = len(state_dicts)
        if aggr_weights is None:
            aggr_weights = [1/nw]*nw

        for i, state_dict in enumerate(state_dicts):
            for name, param in state_dict.items():
                if name in own_state:
                    shape = param.shape
                    own_state[name].add_(param.clone().mul_(aggr_weights[i]))

    def calc_grad(self, state_dict, change_self=True):
        if change_self:
            own_state = self.state_dict()

            for name, param in state_dict.items():
                if name in own_state:
                    shape = param.shape
                    own_state[name].sub_(param.clone())
        else:
            self_params = convert_model_to_param_list(self)
            ref_params = convert_model_to_param_list(state_dict)

            self_params.sub_(ref_params)
            self.grad_params = self_params

In [813]:
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class MnistNet(SimpleNet):#model for mnist
    def __init__(self, name=None, created_time=None,num_of_classes = 10):
        super(MnistNet, self).__init__(f'{name}_Simple', created_time)
     
        self.fc_layer = torch.nn.Sequential(#1 * 28 * 28
            Flatten(),#784
            nn.Linear(784, num_of_classes),
        )
    
    def forward(self, x):
     
        out = self.fc_layer(x)
        return out

In [814]:
class FLNet(SimpleNet):
    def __init__(self, name=None, created_time=None,num_of_classes = 10):
        super(FLNet, self).__init__(f'{name}_Simple', created_time)

        self.mnistnet1 = MnistNet()

        self.mnistnet2 = MnistNet()

        self.mnistnetavg = MnistNet()

    def forward(self, x):
        out = self.mnistnet1(x)
        return out

In [815]:

class CNN(SimpleNet):
    def __init__(self, name=None, created_time=None,num_of_classes = 10,network_id=-1, net_id=-1, is_malicious=False):
        super(CNN, self).__init__(f'{name}_Simple', created_time, is_malicious=is_malicious, net_id=net_id)

        self.network_id=network_id

        self.conv1 = nn.Sequential(         
            nn.Conv2d(
                in_channels=1,              
                out_channels=16,            
                kernel_size=5,              
                stride=1,                   
                padding=2,                  
            ),                              
            nn.ReLU(),                      
            nn.MaxPool2d(kernel_size=2),    
        )
        self.conv2 = nn.Sequential(         
            nn.Conv2d(16, 32, 5, 1, 2),     
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)       
        output = self.out(x)
        return output    # return x for visualization
'''
class CNN(SimpleNet):
    def __init__(self, name=None, created_time=None,num_of_classes = 10,network_id=-1, net_id=-1, is_malicious=False):
        super(CNN, self).__init__(f'{name}_Simple', created_time, is_malicious=is_malicious, net_id=net_id)
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
'''

"\nclass CNN(SimpleNet):\n    def __init__(self, name=None, created_time=None,num_of_classes = 10,network_id=-1, net_id=-1, is_malicious=False):\n        super(CNN, self).__init__(f'{name}_Simple', created_time, is_malicious=is_malicious, net_id=net_id)\n        self.conv1 = nn.Conv2d(3, 6, 5)\n        self.pool = nn.MaxPool2d(2, 2)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n        self.fc2 = nn.Linear(120, 84)\n        self.fc3 = nn.Linear(84, 10)\n\n    def forward(self, x):\n        x = self.pool(F.relu(self.conv1(x)))\n        x = self.pool(F.relu(self.conv2(x)))\n        x = torch.flatten(x, 1) # flatten all dimensions except batch\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n"

In [816]:
import copy

def add_pixel_pattern(ori_image):
    image = copy.deepcopy(ori_image)
    poison_patterns= poison_dict['poison_pattern']
    delta =  poison_dict['poison_delta']

    for i in range(0, len(poison_patterns)):
        pos = poison_patterns[i]
        image[0][pos[0]][pos[1]] = min( image[0][pos[0]][pos[1]] + delta/np.sqrt(len(poison_patterns)), 1)


    return image

def get_poison_batch(bptt,adversarial_index=-1, evaluation=False, attack_type='label_flip'):
    if attack_type=='backdoor':
        import sys
        print("still backdooring")
        sys.exit()

        images, targets = bptt

        poison_count= 0
        new_images=images
        new_targets=targets

        for index in range(0, len(images)):
            if evaluation: # poison all data when testing
                new_targets[index] = poison_dict['poison_label_swap']
                new_images[index] = add_pixel_pattern(images[index])
                poison_count+=1

            else: # poison part of data when training
                if index < poison_dict['poisoning_per_batch']:
                    new_targets[index] = poison_dict['poison_label_swap']
                    new_images[index] = add_pixel_pattern(images[index])
                    poison_count += 1
                else:
                    new_images[index] = images[index]
                    new_targets[index]= targets[index]

        new_images = new_images.to(device)
        new_targets = new_targets.to(device).long()
        if evaluation:
            new_images.requires_grad_(False)
            new_targets.requires_grad_(False)
        return new_images,new_targets,poison_count
    
    elif attack_type=='degrade' or attack_type=='label_flip':
        images, targets = bptt

        poison_count= 0
        new_images=images
        new_targets=targets

        num_of_classes = 10

        for index in range(0, len(images)):
            if evaluation: # poison all data when testing
                if attack_type=='degrade':
                    new_targets[index] = poison_dict['poison_label_swap']
                elif attack_type=='label_flip':
                    new_targets[index] = (targets[index]+1)%num_of_classes
                new_images[index] = images[index]
                poison_count+=1

            else: # poison part of data when training
                if index < poison_dict['poisoning_per_batch']:
                    new_targets[index] = poison_dict['poison_label_swap']
                    new_images[index] = add_pixel_pattern(images[index])
                    poison_count += 1
                else:
                    new_images[index] = images[index]
                    new_targets[index]= targets[index]

        new_images = new_images.to(device)
        new_targets = new_targets.to(device).long()
        if evaluation:
            new_images.requires_grad_(False)
            new_targets.requires_grad_(False)
        return new_images,new_targets,poison_count
        

def get_batch(bptt, evaluation=False):
    data, target = bptt
    data = data.to(device)
    target = target.to(device)
    if evaluation:
        data.requires_grad_(False)
        target.requires_grad_(False)
    return data, target

In [817]:
import torch
from torch import Tensor
from typing import List, Optional

from tensorflow import print as prnt

def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        ref_params: List[Tensor],
        ref_grad_params: List[Tensor],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool,
        inertia: float,
        minimizeDist: bool):
    r"""Functional API that performs SGD algorithm computation.
    See :class:`~torch.optim.SGD` for details.
    """
    #print("dummy", len(params), len(ref_params))

    if len(ref_params)!=0 and len(params)!=len(ref_params):
        #print(params, ref_params)
        print('params', params, '\n')
        print('ref_params', ref_params, '\n')
        sys.exit()

    if len(ref_grad_params)!=0 and len(params)!=len(ref_grad_params):
        print('params', params, '\n')
        print('ref_params', ref_grad_params, '\n')
        sys.exit()
    for i, param in enumerate(params):

        d_p = d_p_list[i]
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf
        
        if ref_params is not None and len(ref_params) != 0 and minimizeDist:
            diff = 0
            ### if minimize euclidean distance
            #diff = param - ref_params[i]
            ### if minimize cosine similarity, ref_params contains the gradient of the reference network weights
            
            if ref_grad_params is not None:

                dir_sign = torch.sign(ref_params[i]-param)
                diff -= ref_params[i]-param
                diff -= ref_grad_params[i]
                diff = diff.mul_(0.5)
            
            #d_p = d_p.mul_(0)
            d_p = d_p.add(diff, alpha=inertia)
        
        #if diff_total==0:
            #print('diff_total 0')

        alpha = lr if maximize else -lr
        param.add_(d_p, alpha=alpha)

In [818]:
import torch
#from torch.optim import _functional as F
from torch.optim.optimizer import Optimizer, required


class SGD(Optimizer):
    def __init__(self, params, lr=required, ref_param_groups=None, ref_grad_param_groups=None, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, *, maximize=False, inertia=1.0, minimizeDist=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov, maximize=maximize)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        #self.ref_param_groups = ref_param_groups
        super(SGD, self).__init__(params, defaults)
        #self.ref_param_groups = ref_param_groups
        self.inertia=inertia
        self.minimizeDist=minimizeDist

        if ref_param_groups is not None:
            self.__setrefparams__(ref_param_groups)
        else:
            self.ref_param_groups = ref_param_groups

        if ref_grad_param_groups is not None:
            self.__setrefgradparams__(ref_grad_param_groups)
        else:
            self.ref_grad_param_groups = ref_param_groups

    def filter_ref_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.
        This can be useful when fine tuning a pre-trained network as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.
        Args:
            param_group (dict): Specifies what Tensors should be optimized along with group
                specific optimization options.
        """
        assert isinstance(param_group, dict), "param group must be a dict"

        params = param_group['params']
        if isinstance(params, torch.Tensor):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. Please use a list instead.')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, torch.Tensor):
                raise TypeError("optimizer can only optimize Tensors, "
                                "but one of the params is " + torch.typename(param))
            if not param.is_leaf:
                raise ValueError("can't optimize a non-leaf Tensor")
        '''
        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                                name)
            else:
                param_group.setdefault(name, default)
        '''

        params = param_group['params']
        if len(params) != len(set(params)):
            warnings.warn("optimizer contains a parameter group with duplicate parameters; "
                          "in future, this will cause an error; "
                          "see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3)

        param_set = set()
        for group in self.param_groups:
            param_set.update(set(group['params']))

        if not param_set.isdisjoint(set(param_group['params'])):
            raise ValueError("some parameters appear in more than one parameter group")

        #self.ref_param_groups.append(param_group)
        return param_group
    
    def __setrefparams__(self, params):
        self.ref_param_groups = []

        param_groups = list(params)
        '''
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        '''
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.ref_param_groups.append(self.filter_ref_param_group(param_group))

    def __setrefgradparams__(self, params):
        self.ref_grad_param_groups = []

        param_groups = list(params)
        '''
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        '''
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.ref_grad_param_groups.append(self.filter_ref_param_group(param_group))

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('maximize', False)
    
    

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        
        #for group in self.ref_param_groups:
            #for p in group['params']:

        if self.ref_param_groups is not None and len(self.param_groups)!=len(self.ref_param_groups):
            print(len(self.param_groups), len(self.ref_param_groups))
            sys.exit()

        if self.ref_grad_param_groups is not None and len(self.param_groups)!=len(self.ref_grad_param_groups):
            print(len(self.param_groups), len(self.ref_grad_param_groups))
            sys.exit()
        for i in range(len(self.param_groups)):
            group = self.param_groups[i]
            if self.ref_param_groups is not None:
                ref_group = self.ref_param_groups[i]
            if self.ref_grad_param_groups is not None:
                ref_grad_group = self.ref_grad_param_groups[i]
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            maximize = group['maximize']
            lr = group['lr']

            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)
                    d_p_list.append(p.grad)

                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        momentum_buffer_list.append(None)
                    else:
                        momentum_buffer_list.append(state['momentum_buffer'])
            ref_params=[]
            ref_grad_params=[]
            if self.ref_param_groups is not None:
                for p in ref_group['params']:
                    ref_params.append(p)
            if self.ref_grad_param_groups is not None:
                for p in ref_grad_group['params']:
                    ref_grad_params.append(p)
            
            sgd(params_with_grad,
                  d_p_list,
                  momentum_buffer_list,
                  ref_params=ref_params,
                  ref_grad_params=ref_grad_params,
                  weight_decay=weight_decay,
                  momentum=momentum,
                  lr=lr,
                  dampening=dampening,
                  nesterov=nesterov,
                  maximize=maximize,
                  inertia=self.inertia,
                  minimizeDist=self.minimizeDist)
            

            # update momentum_buffers in state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]
                state['momentum_buffer'] = momentum_buffer

        return loss

    ##  add add_param method



In [819]:
def new_loss(output, target):
    loss = F.nll_loss(output, target)

    return loss

In [820]:
def inv_grad_test(model):
    for name, param in model.named_parameters():
        print(name, torch.isfinite(param.grad).all())

In [821]:
loss_func=nn.CrossEntropyLoss()

def train(network, optimizer, epoch):
  network.train()

  if network.network_id!=-1:
      (_, temp_train_loader)=train_loaders[network.network_id]
  else:
      temp_train_loader=train_loader

  for batch_idx, (data, target) in enumerate(temp_train_loader):
    if network.network_id>=1:
        data, target, poison_num = get_poison_batch((data, target))
    optimizer.zero_grad()
    output = network(data)
    loss = loss_func(output, target)
    loss.backward()
    #inv_grad_test(network)
    optimizer.step()
    if batch_idx % log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
      epoch, batch_idx * len(data), len(train_loader.dataset),
      100. * batch_idx / len(train_loader), loss.item()))
      train_losses.append(loss.item())
      train_counter.append(
        (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))
      torch.save(network.state_dict(), 'model.pth')
      torch.save(optimizer.state_dict(), 'optimizer.pth')
  if network.network_id>=1:
      sys.exit()

In [866]:
from tqdm import tqdm

def train_net(network, optimizer, trainloader, epoch, poisonNow=False, print_flag=False, tqdm_disable=True, attack_type='backdoor'):
    for batch_idx, (data, target) in enumerate(tqdm(trainloader, disable=tqdm_disable)):
        if poisonNow:
            data, target, poison_num = get_poison_batch((data, target), attack_type=attack_type)
        else:
            data, target = get_batch((data, target))
        optimizer.zero_grad()
        output = network(data)
        loss = loss_func(output, target)
        loss.backward()
        #inv_grad_test(network)
        optimizer.step()
        if batch_idx % log_interval == 0:
            if print_flag:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(trainloader.dataset),
                100. * batch_idx / len(trainloader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append(
                (batch_idx*64) + ((epoch-1)*len(trainloader.dataset)))

In [867]:
def test(network):
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in tqdm(test_loader):
        data, target = get_batch((data, target))
        output = network(data)
        test_loss += loss_func(output, target).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))
  return 100. * correct / len(test_loader.dataset)

In [868]:
def backdoor_test(network):
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in tqdm(test_loader):
      data, target, poison_num = get_poison_batch((data, target), evaluation=True, attack_type='backdoor')
      output = network(data)
      test_loss += loss_func(output, target).item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  print('\nBackdoor Test set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

In [869]:
def calcDiff(network, network2):
    return sum((x - y).abs().sum() for x, y in zip(network.state_dict().values(), network2.state_dict().values()))

In [870]:
def get_scaled_up_grads(glob_net, networks, self=None, iter=-1):
    nets_grads=[]
    clean_server_grad=None
    grads=[]

    for i in range(len(networks)):
        grad_net=CNN().to(device)
        grad_net.to(device)
        grad_net.copy_params(networks[i].state_dict())
        nets_grads.append(grad_net)
        
        grad_net.calc_grad(glob_net.state_dict())
        grads.append(convert_model_to_param_list(grad_net))

    scaled_grad=CNN().to(device)
    scaled_grad.set_param_to_zero()
    scaled_grad.aggregate([n.state_dict() for n in nets_grads], aggr_weights=[-1]*(len(networks)-1)+[len(networks)])

    self.log.append((iter, 'Cos_sim btn scaled grad and clean server grad', 'get_scaled_up_grads', cos_calc(scaled_grad, nets_grads[-2])))
    print(self.log[-1])
    self.log.append((iter, 'Cos_sim btn mal grad and clean server grad', 'get_scaled_up_grads', cos_calc(nets_grads[-1], nets_grads[-2])))
    print(self.log[-1])

    scaled_grad.aggregate([glob_net.state_dict()], aggr_weights=[1])
    return scaled_grad

    

In [884]:
class CustomFL:
    def __init__(self, num_of_benign_nets=1, num_of_mal_nets=1, inertia=0.1, n_iter=10,
                 n_epochs=3, poison_starts_at_iter=3, learning_rate=0.1, momentum=0, weight_decay=0.1,
                 attack_type='label_flip', scale_up=False, minimizeDist=True):
        self.global_net = CNN().to(device)
        self.global_net_optim = SGD(self.global_net.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
        self.global_net.to(device)
        self.benign_nets = []
        self.benign_net_optims = []
        for i in range(num_of_benign_nets):
            network = CNN(net_id=i)
            network.copy_params(self.global_net.state_dict())
            optim = SGD(network.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay, inertia=inertia)
            network.to(device)
            self.benign_nets.append(network)
            self.benign_net_optims.append(optim)
        self.mal_nets = []
        self.mal_net_optims = []
        for i in range(num_of_mal_nets):
            network = CNN(is_malicious=True, net_id=i)
            network.copy_params(self.global_net.state_dict())
            optim = SGD(network.parameters(), lr=learning_rate,
                        momentum=momentum, weight_decay=weight_decay, inertia=inertia, minimizeDist=minimizeDist)
            network.to(device)
            self.mal_nets.append(network)
            self.mal_net_optims.append(optim)

        self.current_iter=0
        self.num_of_benign_nets=num_of_benign_nets
        self.num_of_mal_nets=num_of_mal_nets
        self.inertia_rate=inertia
        self.n_iter=n_iter
        self.n_epochs=n_epochs
        self.learning_rate=learning_rate
        self.momentum=momentum
        self.poison_starts_at_iter=poison_starts_at_iter
        self.weight_decay=weight_decay
        self.attack_type=attack_type
        self.scale_up=scale_up

        self.log=[]
        self.debug_log={}
        self.debug_log['cluster']=[]
        self.debug_log['cluster_without_running_avg']=[]
        self.debug_log['coses']=[]
        n_nets=num_of_benign_nets+num_of_mal_nets
        self.cos_matrices=[]
        #self.cos_matrix.append(np.zeros((n_nets, n_nets)))

    def cluster_grads(self, iter=-1):
        nets = self.benign_nets + self.mal_nets
        for net in nets:
            net.calc_grad(self.global_net.state_dict(), change_self=False)

        from sklearn.cluster import AgglomerativeClustering
        X = [np.array(net.grad_params) for net in nets]
        X= np.array(X)
        clustering = AgglomerativeClustering(n_clusters=num_of_distributions, affinity='cosine', linkage='complete').fit(X)
        from sklearn.metrics.cluster import adjusted_rand_score
        print('Original Copylist', copylist)
        print('Found clusters', clustering.labels_)
        print('Original groups', [np.argwhere(np.array(copylist)==i).flatten() for i in range(num_of_distributions)])
        print('Clustered groups', [np.argwhere(clustering.labels_==i).flatten() for i in range(num_of_distributions)])
        print('Clustering score', adjusted_rand_score(clustering.labels_.tolist(), copylist))
        self.log.append((iter, 'Original copylist', 'cluster_grads', copylist))
        self.log.append((iter, 'Clusters', 'cluster_grads', clustering.labels_))
        self.debug_log['cluster_without_running_avg'].append((iter, 'Cluster Score', 'cluster_grads', adjusted_rand_score(clustering.labels_.tolist(), copylist)))
        
        coses=[]
        
        for i1, net1 in enumerate(nets):
            coses_l=[]
            for i2, net2 in enumerate(nets):
                coses_l.append(cos_calc_btn_grads(net1.grad_params, net2.grad_params))
            coses.append(coses_l)
            
        coses = np.array(coses)
        
        '''
        self.cos_matrix = self.cos_matrix + coses
        self.cos_matrix = self.cos_matrix/np.max(self.cos_matrix)
        '''
        self.cos_matrices.append(coses)
        print(len(self.cos_matrices))
        
        num_of_coses = np.minimum(len(self.cos_matrices), 5)
        
        cos_matrix=np.zeros((len(nets), len(nets)))
        for i in range(num_of_coses):
            cos_matrix+=self.cos_matrices[-i-1]
            
        cos_matrix = cos_matrix/num_of_coses
                    
        print(cos_matrix)
        
        clustering = AgglomerativeClustering(n_clusters=num_of_distributions, affinity='precomputed', linkage='complete').fit(1-cos_matrix)
        print('Original Copylist', copylist)
        print('Found clusters', clustering.labels_)
        print('Original groups', [np.argwhere(np.array(copylist)==i).flatten() for i in range(num_of_distributions)])
        print('Clustered groups', [np.argwhere(clustering.labels_==i).flatten() for i in range(num_of_distributions)])
        print('Clustering score', adjusted_rand_score(clustering.labels_.tolist(), copylist))
        self.debug_log['cluster'].append((iter, 'Cluster Score', 'cluster_grads', adjusted_rand_score(clustering.labels_.tolist(), copylist)))
        
        
        '''
        X = [np.array(net.grad_params) for net in self.benign_nets]
        X= np.array(X)
        copylist2=copylist[:self.num_of_benign_nets]
        clustering = AgglomerativeClustering(n_clusters=len(set(copylist2)), affinity='cosine', linkage='complete').fit(X)
        print('Original Copylist', copylist2)
        print('Found clusters', clustering.labels_)
        print('Original groups', [np.argwhere(np.array(copylist2)==i).flatten() for i in range(num_of_distributions)])
        print('Clustered groups', [np.argwhere(clustering.labels_==i).flatten() for i in range(num_of_distributions)])
        print('Clustering score', adjusted_rand_score(clustering.labels_.tolist(), copylist2))
        '''
        
        return coses

    def FLtrust(self, iter=-1):
        clean_server_grad=None
        grads=[]
        nets_grads=[]
        
        nets = self.benign_nets + self.mal_nets
        for net in nets:
            net.calc_grad(self.global_net.state_dict(), change_self=False)
            grad_net = CNN().to(device)
            grad_net.to(device)
            grad_net.copy_params(net.state_dict())
            grad_net.aggregate([self.global_net.state_dict()], aggr_weights=[-1])
            nets_grads.append(grad_net)

        for i in range(self.num_of_benign_nets):
            grads.append(self.benign_nets[i].grad_params)
            if i==self.num_of_benign_nets-1:
                clean_server_grad=grads[i]

        for i in range(self.num_of_mal_nets):
            grads.append(self.mal_nets[i].grad_params)
        
        norms = [torch.linalg.norm(grad) for grad in grads]
        print('Norms of local gradients ', norms)
        self.log.append((iter, 'Norms of local gradients ', 'FLTrust', norms))

        
        cos_sims=[cos_calc_btn_grads(grad, clean_server_grad) for grad in grads]

        '''
        for grad in grads:
            cos_sims.append(torch.dot(grad, clean_server_grad)/ (torch.linalg.norm(grad)+ 1e-9) / (torch.linalg.norm(clean_server_grad)+ 1e-9))
        '''
        print('\n Aggregating models')

        #print([cos_calc() ])

        print('Cosine Similarities: ', cos_sims)
        self.log.append((iter, 'Cosine Similarities', 'FLtrust', cos_sims))
        cos_sims = np.maximum(np.array(cos_sims), 0)
        norm_weights = cos_sims/(np.sum(cos_sims)+1e-9)
        for i in range(len(norm_weights)):
            norm_weights[i] = norm_weights[i] * torch.linalg.norm(clean_server_grad) / (torch.linalg.norm(grads[i]))
        
        print('Aggregation Weights: ', norm_weights)
        self.log.append((iter, 'Aggregation Weights', 'FLtrust', norm_weights))

        self.global_net.aggregate([grad.state_dict() for grad in nets_grads], aggr_weights=norm_weights)
            
            

    def train_local_net(self, is_malicious, net_id, iter, ref_net_for_minimizing_dist=None, print_flag=False):
        if is_malicious:
            network=self.mal_nets[net_id]
            optim=self.mal_net_optims[net_id]
            # will change later to aggregate of benign_nets
            if ref_net_for_minimizing_dist is None:
                ref_net_for_minimizing_dist = self.benign_nets[0].parameters()
            ref_grad, ref_net = ref_net_for_minimizing_dist
            if ref_grad is None:
                import sys
                sys.exit()
            if ref_grad is not None:
                optim.__setrefgradparams__(ref_grad.parameters())
            optim.__setrefparams__(ref_net.parameters())
        else:
            network=self.benign_nets[net_id]
            optim=self.benign_net_optims[net_id]

        (_, _, trainloader) = train_loaders[iter][net_id + is_malicious*self.num_of_benign_nets]

        poisonNow = True if is_malicious and iter>=self.poison_starts_at_iter else False
        for epoch in range(self.n_epochs if not poisonNow else (self.n_epochs*1)):
            clientType = 'Malicious' if is_malicious else 'Benign'
            if print_flag:
                print(f'Iter {iter} - Epoch {epoch} - Client Type: {clientType} - Client Number {net_id} - Poison Training {poisonNow}')
            train_net(network, optim, trainloader, epoch, poisonNow=poisonNow, attack_type=self.attack_type)
        network.calc_grad(self.global_net.state_dict(), change_self=False)
        if poisonNow:
            acc=test(network)
            self.log.append((iter, 'Local net test accuracy: mal', 'train_local_net', acc))
            if self.attack_type=='backdoorq':
                acc = backdoor_test(network)
                self.log.append((iter, 'Local net backdoor test accuracy: mal', 'train_local_net', acc))

    def train(self, tqdm_disable=False):
        for iter in range(self.n_iter):
            distanceList=[]
            cosList=[]
            networks=[]
            networks+=self.benign_nets
            networks+=self.mal_nets

            for i in tqdm(range(self.num_of_benign_nets), disable=tqdm_disable):
                self.train_local_net(False, i, iter)


            coses = self.cluster_grads(iter)

            self.debug_log['coses'].append((iter, coses))

            benign_aggr_net=CNN().to(device)
            benign_aggr_net.set_param_to_zero()

            ### if adversary knows benign_net_aggregates
            benign_aggr_net.aggregate([net.state_dict() for net in self.benign_nets])
            ### if adversary knows clean server
            #benign_aggr_net.copy_params(self.benign_nets[-1].state_dict())

            benign_aggr_net_grad=CNN().to(device)
            benign_aggr_net_grad.copy_params(benign_aggr_net.state_dict())
            benign_aggr_net_grad.aggregate([self.global_net.state_dict()], aggr_weights=[-1])

            
            for i in range(self.num_of_mal_nets):
                self.train_local_net(True, i, iter, ref_net_for_minimizing_dist=(benign_aggr_net_grad, benign_aggr_net))
                
                if self.scale_up:
                    scaled_up_grad = get_scaled_up_grads(self.global_net, networks, self, iter)
                    self.mal_nets[i].copy_params(scaled_up_grad.state_dict())
                    #self.mal_nets[i].aggregate([benign_aggr_net.state_dict()])
                

            cosList=[cos_calc_btn_grads(net.grad_params, self.benign_nets[-1].grad_params) for net in networks]
            distanceList=[calcDiff(net, self.benign_nets[-1]) for net in networks]

            #self.cluster_grads()

            self.log.append((iter, 'Benign net distance', 'train', distanceList[:self.num_of_benign_nets]))
            print('Benign net distance', distanceList[:self.num_of_benign_nets])
            self.log.append((iter, 'Malicious net distance', 'train', distanceList[self.num_of_benign_nets:]))
            print('Malicious net distance', distanceList[self.num_of_benign_nets:])
            self.log.append((iter, 'Cos sim list', 'train', cosList))
            print('cos_sim list ', cosList)

            # aggregate nets
            #self.global_net.set_param_to_zero()
            #self.global_net.aggregate([network.state_dict() for network in networks])
            self.FLtrust(iter=iter)
            print('\n\n\nAggregate test at iter ', iter)
            acc=test(self.global_net)
            self.log.append((iter, 'Test accuracy: agg net', 'train', acc))
            #backdoor_test(self.global_net)
            self.log.append((iter, 'Backdoor test accuracy: agg net', 'train', acc))
            self.log.append((iter, 'Distance between aggregate global and clean server', 'train', calcDiff(self.global_net, self.benign_nets[-1])))

            # set all local nets equal to global net at the end of the iteration
            
            for network in networks:
                network.copy_params(self.global_net.state_dict())
            

In [885]:
print('Original groups', [np.argwhere(np.array(copylist)==i).flatten() for i in range(num_of_distributions)])

Original groups [array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64), array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=int64), array([20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=int64), array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39], dtype=int64), array([40, 41, 42, 43, 44, 45, 46, 47, 48, 49], dtype=int64), array([50, 51, 52, 53, 54, 55, 56, 57, 58, 59], dtype=int64), array([60, 61, 62, 63, 64, 65, 66, 67, 68, 69], dtype=int64), array([70, 71, 72, 73, 74, 75, 76, 77, 78, 79], dtype=int64), array([80, 81, 82, 83, 84, 85, 86, 87, 88, 89], dtype=int64), array([90, 91, 92, 93, 94, 95, 96, 97, 98, 99], dtype=int64), array([100], dtype=int64), array([], dtype=int64)]


In [886]:
def assign_data(train_data, bias, ctx, num_labels=10, num_workers=100, server_pc=100, p=0.1, dataset="FashionMNIST", seed=1):
    # assign data to the clients
    other_group_size = (1 - bias) / (num_labels - 1)
    worker_per_group = num_workers / num_labels

    #assign training data to each worker
    each_worker_data = [[] for _ in range(num_workers)]
    each_worker_label = [[] for _ in range(num_workers)]   
    server_data = []
    server_label = [] 
    
    # compute the labels needed for each class
    real_dis = [1. / num_labels for _ in range(num_labels)]
    samp_dis = [0 for _ in range(num_labels)]
    num1 = int(server_pc * p)
    samp_dis[1] = num1
    average_num = (server_pc - num1) / (num_labels - 1)
    resid = average_num - np.floor(average_num)
    sum_res = 0.
    for other_num in range(num_labels - 1):
        if other_num == 1:
            continue
        samp_dis[other_num] = int(average_num)
        sum_res += resid
        if sum_res >= 1.0:
            samp_dis[other_num] += 1
            sum_res -= 1
    samp_dis[num_labels - 1] = server_pc - np.sum(samp_dis[:num_labels - 1])

    # randomly assign the data points based on the labels
    server_counter = [0 for _ in range(num_labels)]
    for _, (x, y) in enumerate(train_data):
        '''
        if dataset == "FashionMNIST":
            x = x.as_in_context(ctx).reshape(1,1,28,28)
        else:
            raise NotImplementedError
        y = y.as_in_context(ctx)
        '''

        upper_bound = y * (1. - bias) / (num_labels - 1) + bias
        lower_bound = y * (1. - bias) / (num_labels - 1)
        rd = np.random.random_sample()

        if rd > upper_bound:
            worker_group = int(np.floor((rd - upper_bound) / other_group_size) + y + 1)
        elif rd < lower_bound:
            worker_group = int(np.floor(rd / other_group_size))
        else:
            worker_group = y

        
        if server_counter[int(y)] < samp_dis[int(y)]:
            server_data.append(x)
            server_label.append(y)
            server_counter[int(y)] += 1
        else:
            rd = np.random.random_sample()
            selected_worker = int(worker_group * worker_per_group + int(np.floor(rd * worker_per_group)))
            each_worker_data[selected_worker].append(x)
            each_worker_label[selected_worker].append(y)
    
   
    #     server_data = nd.concat(*server_data, dim=0)
    #     server_label = nd.concat(*server_label, dim=0)
    
    
#     each_worker_data = [nd.concat(*each_worker, dim=0) for each_worker in each_worker_data] 
#     each_worker_label = [nd.concat(*each_worker, dim=0) for each_worker in each_worker_label]
    

#     # randomly permute the workers
#     random_order = np.random.RandomState(seed=seed).permutation(num_workers)
#     each_worker_data = [each_worker_data[i] for i in random_order]
#     each_worker_label = [each_worker_label[i] for i in random_order]
    
    
    return server_data, server_label, each_worker_data, each_worker_label
    

In [887]:
sd, sl, ewd, ewl = assign_data(train_dataset, 0.5, None)

In [888]:
xx=[1, 2 ,3]
yy=[4, 5, 6]
print((x, y) for (x, y) in zip(xx, yy))

<generator object <genexpr> at 0x000001C2E5B86CC8>


In [889]:
poison_dict = dict()
poison_dict['poison_delta'] = 0.1
poison_dict['poison_pattern'] = [[23,25], [24,24],[25,23],[25,25]]
poison_dict['poisoning_per_batch'] = 80
poison_dict['poison_label_swap'] = 0

In [897]:
#@title
import torchvision
import random

batch_size_train = 100
batch_size_test = 1000
learning_rate = 0.01
log_interval = 10


### important hyperparameters
num_of_workers=101
num_of_mal_workers=0
n_iter=50
n_epochs=2
poison_starts_at_iter=n_iter
inertia=0.1
momentum=0.1
attack_type='label_flip'
scale_up=False
minimizeDist=False

iid = False
num_of_distributions = int(num_of_workers/10)+1
num_of_workers_in_distribs = num_of_workers * np.random.dirichlet(np.array(num_of_distributions * [3.0]))
num_of_workers_in_distribs = [int(val) for val in num_of_workers_in_distribs]
while 0 in num_of_workers_in_distribs:
    num_of_workers_in_distribs.remove(0)
num_of_workers_in_distribs.append(num_of_workers-sum(num_of_workers_in_distribs))
print(num_of_workers_in_distribs, sum(num_of_workers_in_distribs))
num_of_distributions = len(num_of_workers_in_distribs)
copylist = []
for i in range(len(num_of_workers_in_distribs)):
    copylist += num_of_workers_in_distribs[i]*[i]
random.shuffle(copylist)
print(copylist)

label_skew_ratios=[]

device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

[12, 8, 4, 11, 6, 6, 7, 6, 10, 16, 8, 7] 101
[1, 3, 10, 8, 3, 7, 4, 0, 0, 11, 0, 3, 4, 7, 0, 1, 8, 7, 3, 4, 1, 3, 11, 0, 10, 7, 5, 9, 2, 0, 8, 1, 4, 11, 8, 9, 7, 5, 9, 11, 9, 8, 10, 6, 9, 3, 4, 9, 10, 9, 9, 9, 10, 3, 5, 11, 9, 8, 5, 3, 6, 8, 5, 6, 10, 8, 0, 9, 8, 11, 3, 1, 6, 0, 9, 4, 0, 0, 2, 1, 6, 6, 8, 6, 9, 7, 3, 0, 3, 5, 0, 1, 1, 2, 9, 9, 11, 2, 9, 10, 10]


<torch._C.Generator at 0x1c2913e2250>

In [898]:
train_losses = []
train_counter = []
test_losses = []
#test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

In [899]:
#@title
from torchvision import datasets, transforms

dataPath = ''

import random

train_loaders=[]

transform = transforms.Compose([transforms.ToTensor(),
    ### if dataset is mnist
    transforms.Normalize((0.1307,), (0.3081,))])
    ### if dataset is cifar
    #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_dataset = datasets.FashionMNIST('', train=True, download=True,
                               transform=transform)
test_dataset = datasets.FashionMNIST('', train=False, transform=transform)

test_loader = torch.utils.data.DataLoader(
  test_dataset,
  batch_size=batch_size_test, shuffle=True)

all_range = list(range(len(train_dataset)))
random.shuffle(all_range)

def get_train_iid(all_range, model_no, iter_no):
    """
    This method equally splits the dataset.
    :param params:
    :param all_range:
    :param model_no:
    :return:
    """

    data_len_for_iter = int(len(train_dataset) / n_iter)
    data_len = int(data_len_for_iter/num_of_workers)
    sub_indices_for_iter = all_range[iter_no * data_len_for_iter: (iter_no + 1) * data_len_for_iter]
    sub_indices = sub_indices_for_iter[model_no * data_len: (model_no + 1) * data_len ]
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                        batch_size=batch_size_train,
                                        sampler=torch.utils.data.sampler.SubsetRandomSampler(sub_indices)
                                        )
    return train_loader

def get_train_noniid(indices):
    """
    This method is used along with Dirichlet distribution
    :param params:
    :param indices:
    :return:
    """
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                        batch_size=int(len(train_dataset)/num_of_workers),
                                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                                            indices))
    return train_loader

def poison_test_dataset(test_dataset, batch_size):
    logger.info('get poison test loader')
    # delete the test data with target label
    test_classes = {}
    for ind, x in enumerate(test_dataset):
        _, label = x
        if label in test_classes:
            test_classes[label].append(ind)
        else:
            test_classes[label] = [ind]

    range_no_id = list(range(0, len(test_dataset)))
    for image_ind in test_classes[poison_dict['poison_label_swap']]:
        if image_ind in range_no_id:
            range_no_id.remove(image_ind)
    poison_label_inds = test_classes[poison_dict['poison_label_swap']]

    return torch.utils.data.DataLoader(test_dataset,
                        batch_size=batch_size,
                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                            range_no_id)), \
            torch.utils.data.DataLoader(test_dataset,
                                        batch_size=batch_size,
                                        sampler=torch.utils.data.sampler.SubsetRandomSampler(
                                            poison_label_inds))
    

In [900]:
from collections import defaultdict

def sample_dirichlet_train_data(no_participants=num_of_workers, dataset=train_dataset, alpha=0.9, copylist=np.arange(num_of_workers)):
    """
        Input: Number of participants and alpha (param for distribution)
        Output: A list of indices denoting data in CIFAR training set.
        Requires: dataset_classes, a preprocessed class-indice dictionary.
        Sample Method: take a uniformly sampled 10-dimension vector as parameters for
        dirichlet distribution to sample number of images in each class.
    """

    dataset_classes = {}
    for ind, x in enumerate(dataset):
        _, label = x
        #if ind in self.params['poison_images'] or ind in self.params['poison_images_test']:
        #    continue
        if label in dataset_classes:
            dataset_classes[label].append(ind)
        else:
            dataset_classes[label] = [ind]
    class_size = len(dataset_classes[0])
    per_participant_list = defaultdict(list)
    no_classes = len(dataset_classes.keys())

    for n in range(no_classes):
        random.shuffle(dataset_classes[n])
        num_of_non_iid_participants = len(np.unique(copylist))
        sampled_probabilities = np.random.dirichlet(
            np.array(num_of_non_iid_participants * [alpha]))
        new_list = []
        for ip in copylist:
            #new_list.append(np.random.normal(loc=sampled_probabilities[ip], scale=0.005))
            new_list.append(sampled_probabilities[ip])
        sampled_probabilities = class_size * np.array(new_list)/np.sum(np.array(new_list))
        sigmas = 0.0 * sampled_probabilities
        sampled_probabilities = np.random.normal(sampled_probabilities, scale=sigmas)
        print(sampled_probabilities)
        label_skew_ratios.append(sampled_probabilities)
        for user in range(no_participants):
            no_imgs = int(round(sampled_probabilities[user]))
            sampled_list = dataset_classes[n][:min(len(dataset_classes[n]), no_imgs)]
            per_participant_list[user].extend(sampled_list)
            dataset_classes[n] = dataset_classes[n][min(len(dataset_classes[n]), no_imgs):]

    return per_participant_list

In [901]:
if iid:
    train_loaders=[]
    for i in range(n_iter):
        train_loaders.append([(i, pos, get_train_iid(all_range, pos, i))
                                for pos in range(num_of_workers)])
else:
    indices_per_participant = sample_dirichlet_train_data(
        num_of_workers,
        #dataset= torch.utils.data.Subset(train_dataset, list(range(240))),
        alpha=0.95,
        copylist=copylist)
    train_loaders = [(-1, pos, get_train_noniid(indices)) for pos, indices in
                    indices_per_participant.items()]
    train_loaders = n_iter * [train_loaders]

print(train_loaders)

[ 29.69597106 170.24622788  85.18080779 144.55761235 170.24622788
  21.55850334   8.47436622  23.81291249  23.81291249  55.28636056
  23.81291249 170.24622788   8.47436622  21.55850334  23.81291249
  29.69597106 144.55761235  21.55850334 170.24622788   8.47436622
  29.69597106 170.24622788  55.28636056  23.81291249  85.18080779
  21.55850334  28.49562733  20.59026894  14.33540942  23.81291249
 144.55761235  29.69597106   8.47436622  55.28636056 144.55761235
  20.59026894  21.55850334  28.49562733  20.59026894  55.28636056
  20.59026894 144.55761235  85.18080779  50.28353473  20.59026894
 170.24622788   8.47436622  20.59026894  85.18080779  20.59026894
  20.59026894  20.59026894  85.18080779 170.24622788  28.49562733
  55.28636056  20.59026894 144.55761235  28.49562733 170.24622788
  50.28353473 144.55761235  28.49562733  50.28353473  85.18080779
 144.55761235  23.81291249  20.59026894 144.55761235  55.28636056
 170.24622788  29.69597106  50.28353473  23.81291249  20.59026894
   8.47436

In [902]:
train_loaders = []
for id_worker in range(len(ewd)):
    dataset_per_worker=[]
    for idx in range(len(ewd[id_worker])):
        dataset_per_worker.append((ewd[id_worker][idx], ewl[id_worker][idx]))
    train_loader = torch.utils.data.DataLoader(dataset_per_worker, batch_size=batch_size_train, shuffle=True)
    train_loaders.append((-1, id_worker, train_loader))
    
dataset_server=[]
for idx in range(len(sd)):
    dataset_server.append((sd[idx], sl[idx]))
train_loader = torch.utils.data.DataLoader(dataset_server, batch_size=batch_size_train, shuffle=True)
train_loaders.append((-1, id_worker+1, train_loader))

#train_loaders = [(-1, idx, torch.utils.data.DataLoader(ew, batch_size=batch_size_train, shuffle=True)) for idx, ew in enumerate(ewd)]
train_loaders = n_iter * [train_loaders]

copylist=[int(np.floor(i/((num_of_workers-1)/10))) for i in range(num_of_workers-1)]
copylist.append(copylist[-1]+1)
print(copylist)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10]


In [None]:
fl = CustomFL(n_iter=n_iter, n_epochs=n_epochs, poison_starts_at_iter=poison_starts_at_iter, num_of_benign_nets=num_of_workers-num_of_mal_workers, num_of_mal_nets=num_of_mal_workers, 
              inertia=inertia, momentum=momentum,
              attack_type=attack_type, scale_up=scale_up, minimizeDist=minimizeDist
)
fl.train()

100%|████████████████████████████████████████████████████████████████████████████████| 101/101 [01:15<00:00,  1.33it/s]


Original Copylist [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10]
Found clusters [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0 10 10 10 10
 10 10 10 10 10 10  5  5  5  5  5  5  5  5  5  5  0  0  0  0  0  0  0  0
  0  0  8  8  2  8  9  8  8  2  8  8  0  6  0  0  6  6  0  0  0  6 11 11
 11  3 11 11 11 11  3 11  4  4  4  4  4  4  4  4  4  4  1  1  1  1  1  1
  1  1  1  1  7]
Original groups [array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64), array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=int64), array([20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=int64), array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39], dtype=int64), array([40, 41, 42, 43, 44, 45, 46, 47, 48, 49], dtype=int64), array([50, 51, 52, 53, 54, 55, 56, 5

100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.94it/s]



Test set: Avg. loss: 0.0023, Accuracy: 2215/10000 (22%)



 19%|███████████████▏                                                                 | 19/101 [00:16<01:09,  1.19it/s]

In [None]:
import pickle
with open(f'file_1_dec27_plain_log_all_benign_iters_50_epoch_{n_epochs}.txt', 'wb') as f:
    pickle.dump(fl.debug_log, f)
f.close()

In [None]:
import os
from os import listdir
from os.path import isfile, join
mypath=os.getcwd()
onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))]

In [None]:
import pickle
with open('file_1_dec23_label_flip_minDist_off.txt', 'rb') as f:
    log_file=pickle.load(f)
f.close()

with open('file_1_dec23_label_flip_minDist_on.txt', 'rb') as f:
    log_file_2=pickle.load(f)
f.close()

log_by_iter=[]
for iter in range(n_iter):
    log_by_iter.append([])
idx=0
cluster_logs=[]
for l in log_file:
    (i, _, _, _)=l
    log_by_iter[i].append(l)
    if i==-1:
        cluster_logs.append(l)


log_by_iter_2=[]
for iter in range(n_iter):
    log_by_iter_2.append([])
idx=0
cluster_logs_2=[]
for l in log_file_2:
    (i, _, _, _)=l
    log_by_iter_2[i].append(l)
    if i==-1:
        cluster_logs_2.append(l)


cluster_scores=[]
for i, log in enumerate(cluster_logs):
    if (i+1)%3==0:
        (_, _, _, sc) = log
        cluster_scores.append(sc)
        
cluster_scores_2=[]
for i, log in enumerate(cluster_logs_2):
    if (i+1)%3==0:
        (_, _, _, sc) = log
        cluster_scores_2.append(sc)
        
print(cluster_scores, cluster_scores_2)

In [None]:
from sklearn.cluster import AgglomerativeClustering
import numpy as np
X = np.array([[1, 2], [1, 4], [1, 0],
              [4, 2], [4, 4], [4, 0]])
clustering = AgglomerativeClustering(n_clusters=3).fit(X)
labels=clustering.labels_
print(labels, labels[labels==0], np.argwhere(labels==0))
print([np.argwhere(labels==i).flatten() for i in range(3)])

print()

In [None]:
t_arr=np.array([100, 10, 1])
t_arr_sig = 0.1 * t_arr
print(np.random.normal(t_arr, t_arr_sig))

In [None]:
import matplotlib.pyplot as plt

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

for i in [0, 3, 5]:
    (_, _, tl) = train_loaders[0][i]

    print(tl)

    # get some random training images
    dataiter = iter(tl)
    images, labels = dataiter.next()
    labels = np.array(labels)
    imshow(torchvision.utils.make_grid(images[:16]))
    for n in range(10):
        print(n, np.count_nonzero(labels == n)/len(labels))


In [None]:
sampled_probabilities = 1000 * np.random.dirichlet(
                np.array(3 * [0.9]))
print(np.sum(sampled_probabilities), sampled_probabilities)

In [None]:
print(3 * [1, 2])

In [None]:
print(fl.log)

In [None]:
import pickle
with open('file_1_dec16.txt', 'rb') as f:
    log_file=pickle.load(f)
f.close()

In [None]:
log_by_iter=[]
for iter in range(n_iter):
    log_by_iter.append([])
idx=0
for l in log_file:
    (i, _, _, _)=l
    log_by_iter[i].append(l)


In [None]:
for i in range(5,10):
    log_by_iter[i]=log_by_iter[i][2:]

for l in log_by_iter[0]:
    print(l)
for l in log_by_iter[5]:
    print(l)

In [None]:
from matplotlib import pyplot as plt

cos_sim_benign=[]
cos_sim_mal=[]
for i in range(10):
    (_, _, _, sim) = log_by_iter[i][2]
    cos_sim_benign.append(sim[0])
    cos_sim_mal.append(sim[2])

plt.plot(cos_sim_benign, c='blue')
plt.plot(cos_sim_mal, c='red')
plt.show()

In [None]:
aggr_weight_benign=[]
aggr_weight_mal=[]
for i in range(10):
    (_, _, _, sim) = log_by_iter[i][4]
    aggr_weight_benign.append(sim[0])
    aggr_weight_mal.append(sim[2])

plt.plot(aggr_weight_benign, c='blue')
plt.plot(aggr_weight_mal, c='red')
plt.show()


In [None]:
distance_benign=[]
distance_mal=[]
for i in range(10):
    (_, _, _, sim) = log_by_iter[i][0]
    distance_benign.append(sim[0])
    (_, _, _, sim) = log_by_iter[i][1]
    distance_mal.append(sim)

plt.plot(distance_benign, c='blue')
plt.plot(distance_mal, c='red')
plt.show()

In [None]:
accuracy=[]
for i in range(10):
    (_, _, _, sim) = log_by_iter[i][5]
    accuracy.append(sim)

plt.plot(accuracy, c='blue')
