In [57]:
import os
import math
import datetime
import numpy as np
import matplotlib.pyplot as plt
import random

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms, models
from torchvision.transforms import ToTensor
import torch.nn.functional as F

In [59]:
#定义带两个卷积路径和一条捷径的残差基本块类
class BasicBlock(nn.Module):
	expansion = 1
	def __init__(self, in_planes, planes, stride=1): #初始化函数，in_planes为输入通道数，planes为输出通道数，步长默认为1
		super(BasicBlock, self).__init__()
#定义第一个卷积，默认卷积前后图像大小不变但可修改stride使其变化，通道可能改变
		self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, 				padding=1, bias=False)
#定义第一个批归一化
		self.bn1 = nn.BatchNorm2d(planes)
#定义第二个卷积，卷积前后图像大小不变，通道数不变
		self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,stride=1, padding=1, bias=False)
#定义第二个批归一化
		self.bn2 = nn.BatchNorm2d(planes)
#定义一条捷径，若两个卷积前后的图像尺寸有变化(stride不为1导致图像大小变化或通道数改变)，捷径通过1×1卷积用stride修改大小
#以及用expansion修改通道数，以便于捷径输出和两个卷积的输出尺寸匹配相加
		self.shortcut = nn.Sequential()
		if stride != 1 or in_planes != self.expansion*planes:
			self.shortcut = nn.Sequential(
				nn.Conv2d(in_planes, self.expansion*planes,kernel_size=1, stride=stride, bias=False),
				nn.BatchNorm2d(self.expansion*planes)
			)
#定义前向传播函数，输入图像为x，输出图像为out
	def forward(self, x):
		out = F.relu(self.bn1(self.conv1(x))) #第一个卷积和第一个批归一化后用ReLU函数激活
		out = self.bn2(self.conv2(out))
		out += self.shortcut(x) #第二个卷积和第二个批归一化后与捷径相加
		out = F.relu(out) #两个卷积路径输出与捷径输出相加后用ReLU激活
		return out
#定义残差网络ResNet18
class ResNet(nn.Module):
#定义初始函数，输入参数为残差块，残差块数量，默认参数为分类数10
	def __init__(self, block, num_blocks, num_classes=10):
		super(ResNet, self).__init__()
#设置第一层的输入通道数
		self.in_planes = 64
#定义输入图片先进行一次卷积与批归一化，使图像大小不变，通道数由3变为64得两个操作
		self.conv1 = nn.Conv2d(3, 64, kernel_size=3,stride=1, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(64)
#定义第一层，输入通道数64，有num_blocks[0]个残差块，残差块中第一个卷积步长自定义为1
		self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
#定义第二层，输入通道数128，有num_blocks[1]个残差块，残差块中第一个卷积步长自定义为2
		self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
#定义第三层，输入通道数256，有num_blocks[2]个残差块，残差块中第一个卷积步长自定义为2
		self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
#定义第四层，输入通道数512，有num_blocks[3]个残差块，残差块中第一个卷积步长自定义为2
		self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
#定义全连接层，输入512*block.expansion个神经元，输出10个分类神经元
		self.linear = nn.Linear(512*block.expansion, num_classes)
#定义创造层的函数，在同一层中通道数相同，输入参数为残差块，通道数，残差块数量，步长
	def _make_layer(self, block, planes, num_blocks, stride):
#strides列表第一个元素stride表示第一个残差块第一个卷积步长，其余元素表示其他残差块第一个卷积步长为1
		strides = [stride] + [1]*(num_blocks-1)
#创建一个空列表用于放置层
		layers = []
#遍历strides列表，对本层不同的残差块设置不同的stride
		for stride in strides:
			layers.append(block(self.in_planes, planes, stride)) #创建残差块添加进本层
			self.in_planes = planes * block.expansion #更新本层下一个残差块的输入通道数或本层遍历结束后作为下一层的输入通道数
		return nn.Sequential(*layers) #返回层列表
#定义前向传播函数，输入图像为x，输出预测数据
	def forward(self, x):
		out = F.relu(self.bn1(self.conv1(x))) #第一个卷积和第一个批归一化后用ReLU函数激活
		out = self.layer1(out) #第一层传播
		out = self.layer2(out) #第二层传播
		out = self.layer3(out) #第三层传播
		out = self.layer4(out) #第四层传播
		out = F.avg_pool2d(out, 4) #经过一次4×4的平均池化
		out = out.view(out.size(0), -1) #将数据flatten平坦化
		out = self.linear(out) #全连接传播
		return out

fine_resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])

In [60]:
class Logger(object):
    def __init__(self, output):
        self.log_file_name = output
    def log(self, message):
        print(message)
        with open(self.log_file_name, 'a') as f:
            f.write(message+'\n')

def get_dataset(dir, name):

    if name=='mnist':
        train_dataset = datasets.MNIST(dir, train=True, download=True, transform=transforms.ToTensor())
        eval_dataset = datasets.MNIST(dir, train=False, transform=transforms.ToTensor())
        
    elif name=='cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])


        train_dataset = datasets.CIFAR10(dir, train=True, download=True, transform=transform_train)
        eval_dataset = datasets.CIFAR10(dir, train=False, transform=transform_test)


    return train_dataset, eval_dataset

def get_model(name="fine", pretrained=True):
    if name == "fine":
        model = fine_resnet18
    elif name == "resnet18":
        model = models.resnet18(pretrained=pretrained)
    elif name == "resnet50":
        # model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        model = models.resnet50(pretrained=pretrained) 
    elif name == "densenet121":
        model = models.densenet121(pretrained=pretrained)        
    elif name == "alexnet":
        model = models.alexnet(pretrained=pretrained)
    elif name == "vgg16":
        model = models.vgg16(pretrained=pretrained)
    elif name == "vgg19":
        model = models.vgg19(pretrained=pretrained)
    elif name == "inception_v3":
        model = models.inception_v3(pretrained=pretrained)
    elif name == "googlenet":        
        model = models.googlenet(pretrained=pretrained)

    # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        model = model.cuda()

    return model
        

def serialize_model(model: torch.nn.Module) -> torch.Tensor:
    parameters = [param.data.view(-1) for param in model.parameters()]
    m_parameters = torch.cat(parameters)

    return m_parameters

def deserialize_model(model: torch.nn.Module, serialized_parameters: torch.Tensor, mode="copy"):

    current_index = 0  # keep track of where to read from grad_update
    for parameter in model.parameters():
        
        numel = parameter.data.numel()
        size = parameter.data.size()
        if mode == "copy":
            parameter.data.copy_(serialized_parameters[current_index:current_index + numel].view(size))
        elif mode == "add":
            parameter.data.add_(serialized_parameters[current_index:current_index + numel].view(size))
        else:
            raise ValueError("Invalid deserialize mode {}, require \"copy\" or \"add\" ".format(mode))
        current_index += numel

In [61]:
class Client(object):

    def __init__(self, conf, train_dataset, id = -1):
        
        self.conf = conf
        self.local_model = get_model(self.conf["model_name"]) 
        self.id = id
        self.train_dataset = train_dataset
        
        all_range = list(range(len(self.train_dataset)))
        data_len = int(len(self.train_dataset) / self.conf['no_clients'])
        train_indices = all_range[id * data_len: (id + 1) * data_len]
        # logger.log("client: %d dataset len %d" %(id, data_len))
                
        self.train_loader = DataLoader(self.train_dataset, batch_size=conf["batch_size"], sampler=SubsetRandomSampler(train_indices))
        
        
    def local_train(self, global_model):

        # Download global model to client
        for name, param in global_model.state_dict().items():
            self.local_model.state_dict()[name].copy_(param.clone())
        
        # Initialize the loss function
        loss_fn = nn.CrossEntropyLoss()
        if self.conf['opt'] == 'sgd':
          optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'], momentum=self.conf['momentum'], weight_decay=self.conf['weight_decay'], nesterov="nesterov" in conf['opt'])
        elif self.conf['opt'] == 'rmsprop':
          optimizer = torch.optim.RMSprop(self.local_model.parameters(), lr=self.conf['lr'], momentum=self.conf['momentum'], weight_decay=self.conf['weight_decay'])
        elif self.conf['opt'] == 'adam':
          optimizer = torch.optim.Adam(self.local_model.parameters(), lr=self.conf['lr'], weight_decay=self.conf['weight_decay'])
        elif self.conf['opt'] == 'adamw':
          optimizer = torch.optim.AdamW(self.local_model.parameters(), lr=self.conf['lr'], weight_decay=self.conf['weight_decay'])
        elif self.conf['opt'] == 'nadam':
          optimizer = torch.optim.NAdam(self.local_model.parameters(), lr=self.conf['lr'], weight_decay=self.conf['weight_decay'])
        elif self.conf['opt'] == 'radam':
          optimizer = torch.optim.RAdam(self.local_model.parameters(), lr=self.conf['lr'], weight_decay=self.conf['weight_decay'])

        # size = len(self.train_loader.dataset)
        
        self.local_model.train()
        # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        for e in range(self.conf["local_epochs"]):      
            for inputs, lables in self.train_loader:
                
                if torch.cuda.is_available():
                    inputs = inputs.cuda()
                    lables = lables.cuda()

                # inputs.to(device)
                # lables.to(device)
                
                # Compute prediction and loss
                pred = self.local_model(inputs)
                loss = loss_fn(pred, lables)
                
                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                                
                #if batch_id % 100 == 0:
                    #loss, current = loss.item(), batch_id * len(data)
            if self.conf['no_clients'] <= 5:
                logger.log(f"Begine{self.id+1:>d} | loss: {loss:>7f}  [{self.id+1:>d}/{self.conf['no_clients']:>d}]")

        # update = torch.sub(serialize_model(self.local_model), serialize_model(global_model))
        update = serialize_model(self.local_model)
        return update

In [62]:
class Server(object):
    
    def __init__(self, conf, eval_dataset):
    
        self.conf = conf 
        self.global_model = get_model(self.conf["model_name"])
        self.eval_loader = DataLoader(eval_dataset, batch_size=self.conf["batch_size"], shuffle=True)
        
    def get_mal_updates(self, updates):
        
        attack_type = self.conf['attack_type']
        agg_type = self.conf['agg_type']
        m = self.conf['no_attackers']
        p = self.conf['pertubation']
         
        serialzed_updates = torch.stack(updates)
        avg = torch.mean(serialzed_updates, 0)
        
        if p == 'unit':
            deviation = avg / torch.norm(avg)
        elif p == 'sign':
            deviation = torch.sign(avg)
        elif p == 'std':
            deviation = torch.std(serialzed_updates)
          
        if attack_type == "None":
            mal_updates = updates
        elif attack_type == "lie":
            mal_updates = lie_attack(updates, m)
        elif attack_type == "fang":
            mal_updates = fang_attack(updates, m, avg, agg_type)
        elif attack_type == "optim":
            mal_updates = optim_attack(updates, m, avg, deviation, agg_type)
        elif attack_type == "optim_min_max":
            mal_updates = optim_min(updates, m, avg, deviation, min='max')
        elif attack_type == "optim_min_sum":
            mal_updates = optim_min(updates, m, avg, deviation, min='sum')
            
        return mal_updates
                              
    def model_aggregate(self, updates):
        
        agg_type = self.conf['agg_type']
        m = self.conf['no_attackers']
        serialzed_updates = torch.stack(updates)  

        if agg_type == "fedavg":
            agg_update = torch.mul(torch.sum(serialzed_updates, 0), self.conf['lambda'])
        elif agg_type == "median":
            agg_update = torch.median(serialzed_updates, 0)[0]
        elif agg_type == "mean":
            agg_update = torch.mean(serialzed_updates, 0)
        elif agg_type == "krum":
            candidates = krum(updates, m)
            agg_update = torch.mean(torch.stack(list(candidates.values())), 0)
            logger.log(list(candidates.keys()))
        elif agg_type == "mkrum":
            candidates = krum(updates, m, mkrum=True)
            agg_update = torch.mean(torch.stack(list(candidates.values())), 0)
            logger.log(list(candidates.keys()))
        elif agg_type == "trmean":
            agg_update = trmean(updates, m)
        elif agg_type == "bulyan":
            agg_update = bulyan(updates, m)

        deserialize_model(self.global_model, agg_update)
        
    def model_eval(self):
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.global_model.eval()
        
        loss_fn = nn.CrossEntropyLoss()
        size = len(self.eval_loader.dataset)           
        num_batches  = len(self.eval_loader)
        test_loss, correct = 0, 0
        
        with torch.no_grad():
            for inputs, lables in self.eval_loader:
                
                if torch.cuda.is_available():
                    inputs = inputs.cuda()
                    lables = lables.cuda()

                # X.to(device)
                # y.to(device)
                pred = self.global_model(inputs)
                test_loss += loss_fn(pred, lables).item()
                correct += (pred.argmax(1) == lables).type(torch.float).sum().item()
        
        correct /= size
        test_loss /= num_batches
        return correct, test_loss

In [63]:
def lie_attack(updates, m):
    z = {1: 0.68, 3:0.69847, 5:0.7054, 8:0.71904, 10:0.72575, 12:0.73891}
    
    serialzed_updates = torch.stack(updates)
    avg = torch.mean(serialzed_updates, 0)
    std = torch.std(serialzed_updates)
    
    mal_update = avg + z[m] * std
    mal_updates = [mal_update] * m + updates
        
    return mal_updates
    
def fang_attack(updates, m, avg, agg_type):

    serialzed_updates = torch.stack(updates)
    deviation = torch.sign(avg)
    # Note that Fang attacks on Multi-krum and Bulyan are the same
    if agg_type in ["krum", "mkrum", "bulyan"]:
    
        
    
        n, d = serialzed_updates.shape
    
        for update in updates:
            dist = torch.norm((updates - update), dim=1)
            dists = dist[None, :] if not len(dists) else torch.cat((dists, dist[None, :]), 0)
        
        dists[dists == 0] = 10000
        dists = torch.sort(dists, dim=1)[0]
        scores = torch.sum(dists[:, :n - 2 - m], dim=1)
        min_score = torch.min(scores)

    
        term_1 = min_score / ((n - m - 1) * torch.sqrt(torch.Tensor([d]))[0])
        max_wre_dist = torch.max(torch.norm((serialzed_updates - avg), 1)) / (torch.sqrt(torch.Tensor([d]))[0])
    
        lamda = term_1 + max_wre_dist
        threshold = 1e-5
    
        while lamda > threshold:
            mal_update = (- lamda * deviation)
            mal_updates = mal_update * m + updates
        
            candidates = krum(updates, m)
        
            if np.array(candidates) < m:
                return mal_updates
        
            lamda *= 0.5
        
        if not len(mal_updates):
            # logger.log(lamda, threshold)
            mal_update = (avg - lamda * deviation)
        
            mal_updates = [mal_update] * m + updates
        
        return mal_updates
    
    # Note that Fang attacks on Trimmed-mean and median are the same
    elif agg_type in ["mean", "trmean", "median"]:
        # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        b = 2
        max_vector = torch.max(serialzed_updates, 0)[0]
        min_vector = torch.max(serialzed_updates, 0)[0]
        
        max_ = (max_vector > 0).float()
        min_ = (min_vector < 0).float()
        if torch.cuda.is_available():
            max_ = max_.cuda()
            min_ = min_.cuda()
        # max_ = (max_vector > 0).float().to(device)
        # min_ = (min_vector < 0).float().to(device)
           
        max_[max_ == 1] = b
        max_[max_ == 0] = 1 / b
        min_[min_ == 1] = b
        min_[min_ == 0] = 1 / b

        max_range = torch.cat((max_vector[:, None], (max_vector * max_)[:, None]), 1)
        min_range = torch.cat(((min_vector * min_)[:, None], min_vector[:, None]), 1)

        # rand = torch.rand(len(deviation), m).to(device)
        rand = torch.rand(m, len(deviation))
        if torch.cuda.is_available():
            rand = rand.cuda()
        
        for i in range(m):
            max_rand = max_range[:, 0] + rand[i] * (max_range[:, 1] - max_range[:, 0])
            min_rand = min_range[:, 0] + rand[i] * (min_range[:, 1] - min_range[:, 0])
            mal_update = (deviation > 0).float()* (max_rand + min_rand)
            mal_updates = [mal_update] + updates
        
        return mal_updates
    
    else:
        return updates

def optim_attack(updates, m, avg, dev, agg_type):
    # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if agg_type in ["krum", "mkrum", "bulyan"]:
        # lamda = torch.Tensor([3.0]).to(device)
        lamda = torch.Tensor([3.0])

    elif agg_type in ["mean", "median", "trmean"]:
        # lamda = torch.Tensor([10.0]).to(device)
        lamda = torch.Tensor([10.0])
    if torch.cuda.is_available():
        lamda = lamda.cuda()
    threshold = 1e-5
    prev_loss = -1
    lamda_fail = lamda
    lamda_succ = 0
    iters = 0
    
    while torch.abs(lamda_succ - lamda) > threshold:
        mal_update = (avg - lamda * dev)
        mal_updates = mal_update * m + updates
        
        # Note that optim attacks on multi-krum and Bulyan aggregations are the same
        if agg_type in ["krum", "mkrum", "bulyan"]:
            if agg_type == "krum":
                candidates = krum(mal_updates, m)
            else:
                candidates = krum(mal_updates, m, mkrum=True)
        
            if np.sum(np.array(candidates) < m) == m:
                lamda_succ = lamda
                lamda = lamda + lamda_fail / 2
            else:
                lamda = lamda - lamda_fail / 2
            lamda_fail = lamda_fail / 2
        
        elif agg_type in ["mean", "median", "trmean"]:
            if agg_type == "trmean":
                agg_update = trmean(mal_updates, m)
            elif agg_type == "median":
                agg_update = torch.median(mal_updates, 0)[0]
            elif agg_type == "mean":
                agg_update = torch.mean(mal_updates, 0)
            
            loss = torch.norm(agg_update - avg)
            
            if prev_loss < loss:
                lamda_succ = lamda
                lamda = lamda + lamda_fail / 2
            else:
                lamda = lamda - lamda_fail / 2

            lamda_fail = lamda_fail / 2
            prev_loss = loss
        else:
            return updates
        
    mal_update = (avg - lamda_succ * dev)
    mal_updates = [mal_update] * m + updates
    
    return mal_updates    

def optim_min(updates, m, avg, dev, min='max'):
    # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    lamda = torch.Tensor([10.0])
    if torch.cuda.is_available:
        # lamda = torch.Tensor([10.0]).to(device)
        lamda = lamda.cuda()
    threshold = 1e-5
    lamda_fail = lamda
    lamda_succ = 0
    
    for update in updates:
        dist = torch.norm((updates - update), dim=1)
        dists = dist[None, :] if not len(dists) else torch.cat((dists, dist[None, :]), 0)
        
    max_distance = torch.max(dists)
    scores = torch.sum(dists, dim=1)
    min_score = torch.min(scores)
    
                         
    while torch.abs(lamda_succ - lamda) > threshold:
        mal_update = (avg - lamda * dev)
        distance = torch.norm((updates - mal_update), dim=1)
        
        if min == 'max':
            max_d = torch.max(distance)
            
            if max_d <= max_distance:
                # logger.log('successful lamda is ', lamda)
                lamda_succ = lamda
                lamda = lamda + lamda_fail / 2
            else:
                lamda = lamda - lamda_fail / 2
            lamda_fail = lamda_fail / 2
        elif min == 'sum':
            score = torch.sum(distance)
        
            if score <= min_score:
                # logger.log('successful lamda is ', lamda)
                lamda_succ = lamda
                lamda = lamda + lamda_fail / 2
            else:
                lamda = lamda - lamda_fail / 2
            lamda_fail = lamda_fail / 2

    # logger.log(lamda_succ)
    mal_update = (avg- lamda_succ * dev)
    mal_updates = [mal_update] * m + updates
    
    return mal_updates 

In [64]:
def krum(updates, m, mkrum=False, bulyan=False):
    candidates = {}
    dists = {}
    # Initalize dist matrix
    for i, update in enumerate(updates):
        dist = {}
        for j, update_ in enumerate(updates):
            if j == i:
                continue         
            dist[j] = torch.dist(update, update_).item()  
        dists[i] = dist    
    if bulyan:
        c = 2 * m
    else:
        c = 2 * m + 2

    while len(dists) > c:
        # Find the minimum summed distance to cloesest n-m-2 neighbors
        scores = {}
        for i, dist in dists.items():
            scores[i] = sum(sorted(list(dist.values()), key=lambda x:float(x))[:len(dists) - m - 2])
        # sel_ind = sorted(scores.items(), key=lambda x: float(x[1]))[0][0]
        # logger.log(scores)
        ind = min(scores, key=scores.get)
        candidates[ind] = updates[ind]

        # Remove the selected update from the dist matrix
        del dists[ind]
        for dist_ in dists.values():
            del dist_[ind]

        if not mkrum:
            break
    
        # logger.log(candidates)
    logger.log("Select %d Malicious updates" %(np.sum(np.array(candidates)<m)))
    return candidates

def trmean(updates, m):
    serialzed_updates = torch.stack(updates)
    sorted_updates = torch.sort(serialzed_updates, 0)[0]
    if not m:
        agg_update = torch.mean(sorted_updates, 0)
    else:
        agg_update = torch.mean(sorted_updates[m:-m], 0)
    return agg_update

def bulyan(updates, m):

    candidates = krum(updates, m, mkrum=True, bulyan=True)
    agg_update = trmean(list(candidates.values()), m)
    return agg_update
   

In [65]:
if __name__ == '__main__':

    conf = {
        'model_name': 'fine',
        'datasets': 'cifar10',
        'batch_size': 128,
        'lr': 0.01,
        'momentum': 0.9,
        'weight_decay': 5e-4,
        'opt': 'sgd',
        'no_clients': 5,
        'no_attackers': 1,
        'agg_type': 'mean',
        'attack_type': 'fang',
        'pertubation': 'std',
        'local_epochs': 1,
        'global_epochs': 100,
        'k': 3,
        'lambda': 0.3,
        'chpt_path': './checkpoints/',
        'resume': False,
        'output': './output/',
        # 'eta': 2,
        # 'alpha': 1.0,
        # 'poison_label': 2,
        # 'poisoning_per_batch': 4
    }
    
    test = 0

    # agg_type = ['fedavg', 'mean', 'median', 'krum', 'mkrum', 'trmean', 'bulyan']
    # attack_type = ['lie', 'fang', 'optim', 'optim_min_max', 'optiom_min_sum']
    # pertubation = ['unit', 'sign', 'std']
    
    n = conf['no_clients']
    m = conf['no_attackers']
    att_type = conf["attack_type"]
    agg_type = conf['agg_type']
    p = conf['pertubation']
    epochs = conf['global_epochs']

    if not os.path.exists(conf['chpt_path']):
        os.mkdir(conf['chpt_path'])
    if not os.path.exists(conf['output']):
        os.mkdir(conf['output'])

    chpt_path = conf['chpt_path'] + conf['model_name'] + '_' + conf['datasets'] + '_' + str(n) + '_' + str(m) + '_' + att_type + '_' + agg_type + '_' + p + '.pth'
    output = conf['output'] + conf['model_name'] + '_' + conf['datasets'] + '_' + str(n) + '_' + str(m) + '_' + att_type + '_' + agg_type + '_' + p + '.txt'

    logger = Logger(output)
    # Check the validation of aggregation rule
    if m >= 0.5 * n:
        logger.log("No. of attackers must be less than no. of clients")
        exit()
    if agg_type == "bulyan":
        if n < 4 * m + 3:
            logger.log("Too much malicious clinets")
            exit()

    train_datasets, eval_datasets = get_dataset("./data/", conf["datasets"])
    
    if not conf['resume']:
        logger.log("Model: %s | optimizer: %s | batch_size: %d" %(conf["model_name"], conf["opt"], conf["batch_size"]))
        logger.log("Learning rate: %f | Momentum: %f | Weight decay: %f" %(conf["lr"], conf["momentum"], conf["weight_decay"]))
        logger.log("Datasets: %s | Train datasets: %d | Evaluate datasets: %d" %(conf["datasets"], len(train_datasets), len(eval_datasets)))
        logger.log("Global epochs: %d | Local epochs: %d" %(conf["global_epochs"], conf["local_epochs"]))
        logger.log("Clients: %d | Attackers: %d" %(n, m))
        logger.log("Attack: %s | Aggregation: %s | Pertubation: %s" %(att_type, agg_type, p))
        logger.log("\n")
    
    start = datetime.datetime.now()
    logger.log("Start training at: %s" %(start))
    
    # Initialize the server and clients model
    server = Server(conf, eval_datasets)

    clients = []

    for c in range(n - m):
        clients.append(Client(conf, train_datasets, c))
    

    # Set resume to True to load the checkpoint
    if conf['resume']:
        epoch, server_state_dict = torch.load(chpt_path)
        server.global_model.load_state_dict(server_state_dict)
        epochs = conf['global_epochs'] - epoch
        print("Resume training from epoch %d at %s" %(epoch, datetime.datetime.now()))

    for e in range(epochs):

        # if e == 25:
        #     conf['lr'] = conf['lr'] / 10
        # elif e == 50:
        #     conf['lr'] = conf['lr'] / 10
        # logger.log(f"Epoch {e+1}\n-------------------------------")

        if agg_type == "fedavg":
            clients = random.sample(clients, conf["k"])
        
        updates = []
            
        for client in clients:
        
            update = client.local_train(server.global_model)
            updates.append(update)
        if m > 0:
            uptates = server.get_mal_updates(updates)

        server.model_aggregate(updates)
        correct, test_loss = server.model_eval()
        end = datetime.datetime.now()
        logger.log(f"Time: {(end-start).seconds}s Epoch: {e+1} Learning rate: {conf['lr']} Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} ")
        if test:
            break
        
        # Save the checkpoint
        torch.save([e, server.global_model.state_dict()], chpt_path)
    logger.log("Complete at %s" %(datetime.datetime.now()))

Files already downloaded and verified
Model: fine | optimizer: sgd | batch_size: 128
Learning rate: 0.010000 | Momentum: 0.900000 | Weight decay: 0.000500
Datasets: cifar10 | Train datasets: 50000 | Evaluate datasets: 10000
Global epochs: 100 | Local epochs: 1
Clients: 5 | Attackers: 1
Attack: fang | Aggregation: mean | Pertubation: std


Start training at: 2022-07-24 21:39:55.092510
Begine1 | loss: 1.734655  [1/5]
Begine2 | loss: 1.367931  [2/5]
Begine3 | loss: 1.174004  [3/5]
Begine4 | loss: 1.224118  [4/5]
cuda:0
cuda:0
cpu


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!