In [1]:
import sys
sys.path.append("..")
import os
import torch
import random
import neptune
import syft as sy
import numpy as np
from tqdm import tqdm
import cvxpy as cp
from time import sleep
from copy import deepcopy
import torchattacks
import torchvision
import coloredlogs, logging
from torchvision import transforms
from collections import defaultdict
from torch.nn import functional as F
from federated_learning.FLNet import FLNet
from federated_learning.FLCustomDataset import FLCustomDataset
from federated_learning.Arguments import Arguments
from federated_learning.helper import utils
import time

CONFIG_PATH = '../configs/defaults.yml'
TQDM_R_BAR = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{postfix}] '


In [2]:

    
arguments = dict()
arguments['--log'] = False
arguments['--nep-log'] = False
arguments['--avg'] = True
arguments['--opt'] = False

configs = utils.load_config(CONFIG_PATH)

args = Arguments(
    configs['runtime']['batch_size'],
    configs['runtime']['test_batch_size'],
    configs['runtime']['rounds'],
    configs['runtime']['epochs'],
    configs['runtime']['lr'],
    configs['runtime']['momentum'],
    configs['runtime']['weight_decay'],
    configs['mnist']['shards_num'],
    configs['mnist']['shards_per_worker_num'],
    configs['mnist']['total_users_num'],
    configs['mnist']['selected_users_num'],
    configs['server']['data_fraction'],
    "avg" if arguments['--avg'] else "opt",
    configs['attack']['attack_type'],
    configs['attack']['attackers_num'],
    configs['runtime']['use_cuda'],
    torch.device("cuda" if configs['runtime']['use_cuda'] else "cpu"),
    configs['runtime']['random_seed'],
    configs['log']['interval'],
    configs['log']['level'],
    configs['log']['format'],
    utils.make_output_dir(
        configs['log']['root_output_dir'], arguments['--output-prefix']
        ) if arguments['--log'] else "",
    True if arguments['--nep-log'] else False,
    True if arguments['--log'] else False
)

logger = logging.getLogger(__name__)
coloredlogs.install(level=args.log_level, fmt=args.log_format)

print(
        "Configs:\n\
        Epoch:\t{}\n\
        Rounds:\t{}\n\
        Total Number of Users:\t{}\n\
        Selected Users:\t{}\n\
        Mode:\t{}\n\
        Attack:\t{}".format(
            args.epochs, args.rounds, args.total_users_num, args.selected_users_num, 
            args.mode, args.attack_type, args.attackers_num
        ))

torch.manual_seed(args.seed)
random.seed(args.seed)

# syft initialization
hook = sy.TorchHook(torch)

Configs:
        Epoch:	5
        Rounds:	500
        Total Number of Users:	100
        Selected Users:	30
        Mode:	avg
        Attack:	1


In [3]:
server_model = FLNet().to(args.device)
# server_model.load_state_dict(torch.load("/home/savi/ehsan/FederatedLearning/data/R290_server_model"))

In [4]:
def create_workers(hook, workers_idx):
    logging.info("Creating {} workers...".format(len(workers_idx)))
    workers = dict()
    for worker_id in workers_idx:
        logging.debug("Creating the worker: {}".format(worker_id))
        workers[worker_id] = sy.VirtualWorker(hook, id=worker_id)
    logging.info("Creating {} workers..... OK".format(len(workers_idx)))
    return workers

In [5]:
logging.info("Total number of users: {}".format(args.total_users_num))
workers_idx = ["worker_" + str(i) for i in range(args.total_users_num)]
workers = create_workers(hook, workers_idx)

2021-02-04 22:45:34: Total number of users: 100
2021-02-04 22:45:34: Creating 100 workers...
2021-02-04 22:45:34: Creating 100 workers..... OK


In [6]:
train_dataset = utils.load_mnist_dataset(
        train=True, transform=transforms.Compose([
                                transforms.ToTensor(),]))

2021-02-04 22:45:35: Loading train data from MNIST dataset...


In [7]:
sorted_train_dataset = utils.sort_mnist_dataset(train_dataset)
splitted_train_dataset = utils.split_dataset(
    sorted_train_dataset, int(len(sorted_train_dataset) / args.shards_num))
mapped_train_datasets = utils.map_shards_to_worker(splitted_train_dataset, workers_idx, args.shards_per_worker_num)

server_dataset = utils.fraction_of_datasets(mapped_train_datasets, 5)

2021-02-04 22:45:36: Sorting the MNIST dataset based on labels...
2021-02-04 22:45:38: Splitting the dataset into tensors with 300 samples...
2021-02-04 22:45:38: Federated data to 100 users..... OK
2021-02-04 22:45:38: Extracting 500.0 of users data (total: 300000) to be sent to the server...
2021-02-04 22:45:38: Extracted... Ok, The size of the extracted data: torch.Size([60000, 28, 28])


In [8]:
server_dataset = utils.fraction_of_datasets(mapped_train_datasets, args.server_data_fraction)

2021-02-04 22:46:11: Extracting 5.0 of users data (total: 3000) to be sent to the server...


IndexError: too many indices for tensor of dimension 3

In [31]:
datasets = mapped_train_datasets
fraction = 0.05
logging.info("Extracting {}% of users data (total: {}) to be sent to the server...".format(
    fraction * 100.0, int(fraction * len(datasets) * len(list(datasets.values())[0].targets))))
images, labels = [], []
for ww_id, dataset in datasets.items():
    idx = torch.randperm(len(dataset.targets))[:int(fraction * len(dataset.targets))]
    print(idx, type(idx))
    print(len(dataset.data[idx.tolist()]))
    print(dataset.targets[idx.tolist()])
#     print(len(dataset.targets), int(fraction * len(dataset.targets)))
#     images.append(dataset.data[list(idx)])
#     labels.append(dataset.targets[(idx)])
    break

# aggregate_dataset = FLCustomDataset(
#     cat(images), cat(labels),
#     transform=transforms.Compose([
#         transforms.ToTensor()])
# )
# logging.info("Extracted... Ok, The size of the extracted data: {}".format(
#     aggregate_dataset.data.shape))
# return aggregate_dataset

2021-02-04 22:55:20: Extracting 5.0% of users data (total: 3000) to be sent to the server...


tensor([213,  84, 324, 491, 137, 120, 374, 129, 486, 325, 241, 481, 311,  20,
        562,  19, 424, 280, 149, 124, 243, 323,  68,  11, 568, 454, 548, 489,
        512,  12]) <class 'torch.Tensor'>
30
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0])


In [None]:
(server_dataset.data.shape)

In [None]:
test_dataset = utils.load_mnist_dataset(
        train=False, transform=transforms.Compose([
                                transforms.ToTensor(),]))
test_loader = utils.get_dataloader(
    test_dataset, args.test_batch_size, shuffle=True, drop_last=False)

In [None]:
def test(model, test_loader, round_no, args, atk=None):
    model.eval()
    test_loss = 0
    correct = 0
    with tqdm(total=len(test_loader), ncols=80, leave=False, desc="Test\t", bar_format=TQDM_R_BAR) as t1:
        for jj, (data, target) in enumerate(test_loader):
            if atk is not None:
                data = atk(data, target)
                data, target = data.to(args.device), target.to(args.device, dtype=torch.int64)
            else:
                data, target = data.to(args.device), target.to(args.device, dtype=torch.int64)
            with torch.no_grad():
                output = model(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
                pred = output.argmax(1, keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
            t1.update()

    test_loss /= len(test_loader.dataset)
    test_acc = 100. * correct / len(test_loader.dataset)

    if args.neptune_log:
        neptune.log_metric("test_loss", test_loss)
        neptune.log_metric("test_acc", test_acc)
    if args.local_log:
        file = open(args.log_dir +  "accuracy", "a")
        TO_FILE = '{} {} "{{/*Accuracy:}}\\n{}%" {}\n'.format(
            round_no, test_loss, test_acc, test_acc)
        file.write(TO_FILE)
        file.close()
    
    logging.debug('Test Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), test_acc))
    return test_loss, test_acc

In [None]:
test_loss, test_acc = test(server_model, test_loader, 1, args)
print('Test Average loss: {:.4f}, Accuracy: {:.0f}%'.format(test_loss, test_acc))

In [None]:
import torchattacks
atks = [torchattacks.PGD(server_model, eps=8/255, alpha=2/255, steps=7),
        torchattacks.BIM(server_model, eps=8/255, alpha=2/255, steps=7),
        torchattacks.CW(server_model, c=1, kappa=0, steps=1000, lr=0.01),
        torchattacks.RFGSM(server_model, eps=8/255, alpha=4/255, steps=1),
        torchattacks.FGSM(server_model, eps=8/255),
        torchattacks.FFGSM(server_model, eps=8/255, alpha=12/255),
        torchattacks.TPGD(server_model, eps=8/255, alpha=2/255, steps=7),
        torchattacks.MIFGSM(server_model, eps=8/255, decay=1.0, steps=5),
       ]

In [None]:
import time
for atk in atks :
    print("-"*70)
    print(atk)
    start = time.time()
    test_loss, test_acc = test(server_model, test_loader, 1, args, atk)
    print('Test Average loss: {:.4f}, Accuracy: {:.0f}%, Time: {}'.format(test_loss, test_acc, time.time()-start))

In [None]:
atk = torchattacks.FGSM(server_model, eps=0.06)
print("-"*70)
print(atk)
start = time.time()
test_loss, test_acc = test(server_model, test_loader, 1, args, atk)
print('Test Average loss: {:.4f}, Accuracy: {:.0f}%, Time: {}'.format(test_loss, test_acc, time.time()-start))

In [None]:
atk = torchattacks.FGSM(server_model, eps=0.06)
print("-"*70)
print(atk)
start = time.time()
test_loss, test_acc = test(server_model, test_loader, 1, args, atk)
print('Test Average loss: {:.4f}, Accuracy: {:.0f}%, Time: {}'.format(test_loss, test_acc, time.time()-start))

In [None]:
atk = torchattacks.FGSM(server_model, eps=0.09)
print("-"*70)
print(atk)
start = time.time()
test_loss, test_acc = test(server_model, test_loader, 1, args, atk)
print('Test Average loss: {:.4f}, Accuracy: {:.0f}%, Time: {}'.format(test_loss, test_acc, time.time()-start))

In [None]:
atk = torchattacks.FGSM(server_model, eps=0.25)
print("-"*70)
print(atk)
start = time.time()
test_loss, test_acc = test(server_model, test_loader, 1, args, atk)
print('Test Average loss: {:.4f}, Accuracy: {:.0f}%, Time: {}'.format(test_loss, test_acc, time.time()-start))

In [None]:
data, target = next(iter(test_loader))

In [None]:
# atk = torchattacks.FGSM(server_model, eps=0.25)
# data_ = atk(data, target)
data_ = data.numpy()
target_ = target.numpy()
import matplotlib.pyplot as plt
figure = plt.figure(figsize=(8,8))
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.axis("off")
    plt.title(target_[i])
    plt.imshow(data_[i][0], cmap='gray')

In [None]:
atk = torchattacks.FGSM(server_model, eps=0.10)
data_ = atk(data, target)
data_ = data_.numpy()
target_ = target.numpy()
import matplotlib.pyplot as plt
figure = plt.figure(figsize=(8,8))
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.axis("off")
    plt.title(target_[i])
    plt.imshow(data_[i][0], cmap='gray')

In [None]:
atk = torchattacks.FGSM(server_model, eps=0.20)
data_ = atk(data, target)
data_ = data_.numpy()
target_ = target.numpy()
import matplotlib.pyplot as plt
figure = plt.figure(figsize=(8,8))
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.axis("off")
    plt.title(target_[i])
    plt.imshow(data_[i][0], cmap='gray')

In [None]:
atk = torchattacks.FGSM(server_model, eps=0.25)
data_ = atk(data, target)
data_ = data_.numpy()
target_ = target.numpy()
import matplotlib.pyplot as plt
figure = plt.figure(figsize=(8,8))
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.axis("off")
    plt.title(target_[i])
    plt.imshow(data_[i][0], cmap='gray')

In [None]:
data, target = next(iter(test_loader))

worker = sy.VirtualWorker(hook, id="worker2")

data = data.send(worker)
target = target.send(worker)

print(target.location)
print(target)
tt = target.get()
print(tt)

# atk = torchattacks.FGSM(server_model, eps=0.25)
# data, target = data.get(), target.get()
# data_ = atk(data, target)
# data_ = data_.numpy()
# target_ = target.numpy()
# print(data_.shape)
# print(target_.shape)

# import matplotlib.pyplot as plt
# figure = plt.figure(figsize=(8,8))
# for i in range(20):
#     plt.subplot(4, 5, i + 1)
#     plt.axis("off")
#     plt.title(target_[i])
#     plt.imshow(data_[i][0], cmap='gray')