In [1]:
from federated_learning.utils import SHAPUtil
from federated_learning import LocalEnvironment

In [None]:
# Automatically reload edited Python modules (e.g. the federated_learning package)
# before each cell executes, so library changes are picked up without a kernel restart.
# NOTE(review): this cell's execution count is [None] — it was not run in the session
# that produced the outputs below; conventionally it goes before the project imports.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import torch
import torch.nn as nn
from torch import device

from federated_learning.nets import MNISTCNN, FashionMNISTCNN, MNISTFFNN
from federated_learning.dataset import MNISTDataset, FashionMNISTDataset
from federated_learning.dataloader import MNISTDataloader, FashionMNISTDataloader
from federated_learning.client.ffnn_client import FFNNClient

In [3]:
class Configuration():
    """Experiment settings for the federated-learning / label-flipping notebook.

    Every setting is a class attribute. The notebook passes an instance to
    ``config.DATASET(config)`` and ``LocalEnvironment(config, data)``, which read
    the attributes they need.
    """

    # Root directory for all on-disk datasets. The per-dataset paths below are
    # derived from it, so the data location can be changed in one place.
    DATA_DIR = './data'

    # Dataset Config
    BATCH_SIZE_TRAIN = 132
    BATCH_SIZE_TEST = 1000
    DATASET = MNISTDataset  # dataset class; instantiated in the notebook as DATASET(config)

    # DEPRECATED CONFIG
    # NOTE(review): marked deprecated by the original author; kept in case library
    # code still reads it — confirm before removing.
    DATALOADER = MNISTDataloader

    # MNIST_FASHION_DATASET Configurations
    MNIST_FASHION_DATASET_PATH = os.path.join(DATA_DIR, 'mnist_fashion')
    MNIST_FASHION_LABELS = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot']

    # MNIST_DATASET Configurations
    MNIST_DATASET_PATH = os.path.join(DATA_DIR, 'mnist')

    # CIFAR_DATASET Configurations
    CIFAR10_DATASET_PATH = os.path.join(DATA_DIR, 'cifar10')
    CIFAR10_LABELS = ['Plane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

    # Model Training Configurations
    N_EPOCHS = 4
    LEARNING_RATE = 0.01
    MOMENTUM = 0.5  # presumably the SGD momentum used by the client optimiser — confirm
    LOG_INTERVAL = 10
    # Loss *class* (not an instance); presumably instantiated by the client — confirm.
    CRITERION = nn.CrossEntropyLoss
    NETWORK = MNISTFFNN
    NUMBER_TARGETS = 10

    # Local Environment Configurations
    NUMBER_OF_CLIENTS = 1
    CLIENT_TYPE = FFNNClient
    DEVICE = device('cpu')
    # NOTE(review): with one client the run logs "1/1 clients poisoned", so this looks
    # like a fraction of clients (1 == 100%) rather than a percentage — confirm in
    # LocalEnvironment.poison_clients().
    DATA_POISONING_PERCENTAGE = 1

    # Label Flipping Attack: source and target class of the flip
    # (the run logs "Label Flipping ... from 5 to 4").
    FROM_LABEL = 5
    TO_LABEL = 4

In [4]:
# Seed torch's global RNG before the dataset and clients are built. Network weight
# initialisation (and data shuffling, assuming the loaders use torch's default
# generator) then reproduces under Restart & Run All. Without a seed, the two clean
# runs below started from different baselines (3% vs 10% accuracy), which makes
# clean-vs-poisoned comparisons unreliable.
# NOTE(review): if federated_learning also draws from `random` or `numpy`
# (e.g. when choosing which labels to flip), seed those generators as well.
SEED = 42
torch.manual_seed(SEED)

config = Configuration()
data = config.DATASET(config)
sim_env = LocalEnvironment(config, data)

MNIST training data loaded.
MNIST test data loaded.
Create 1 clients


In [5]:
# Baseline run on clean data: evaluate once before any training (the untrained
# network's accuracy), then alternate one training epoch with one evaluation pass.
client = sim_env.clients[0]
client.test()
for epoch in range(1, config.N_EPOCHS + 1):
    client.train(epoch)
    client.test()


Test set: Average loss: 0.0023, Accuracy: 349/10000 (3%)


Test set: Average loss: 0.0006, Accuracy: 8948/10000 (89%)


Test set: Average loss: 0.0003, Accuracy: 9376/10000 (94%)


Test set: Average loss: 0.0002, Accuracy: 9500/10000 (95%)


Test set: Average loss: 0.0002, Accuracy: 9583/10000 (96%)



In [6]:
# Reset the client networks so the next experiment does not continue from the weights
# trained above. The next cell's pre-training test drops back to ~10% accuracy, which
# is consistent with freshly initialised weights (presumed — confirm in LocalEnvironment).
sim_env.reset_client_nets()

In [7]:
# Second clean run: evaluate the reset network, then train for config.N_EPOCHS
# epochs, evaluating on the test set after each epoch.
client = sim_env.clients[0]
client.test()
for epoch in range(1, config.N_EPOCHS + 1):
    client.train(epoch)
    client.test()


Test set: Average loss: 0.0023, Accuracy: 1033/10000 (10%)


Test set: Average loss: 0.0006, Accuracy: 8911/10000 (89%)


Test set: Average loss: 0.0003, Accuracy: 9354/10000 (94%)


Test set: Average loss: 0.0002, Accuracy: 9479/10000 (95%)


Test set: Average loss: 0.0002, Accuracy: 9580/10000 (96%)



In [6]:
# Start the poisoned experiment from reset networks, then apply the label-flipping
# attack. The run logs "1/1 clients poisoned" and "Label Flipping 50.0% from 5 to 4",
# which matches FROM_LABEL / TO_LABEL in the configuration.
# NOTE(review): the 50.0% flip rate is not set anywhere in Configuration — confirm
# where poison_clients() gets it from.
# NOTE(review): this cell's execution count [6] comes after [7], so it ran in a
# different kernel session than the cells above. Restart & Run All before comparing
# clean and poisoned accuracies.
# NOTE(review): the tensor dump in this cell's output is printed from inside
# poison_clients() (reset_client_nets() printed nothing earlier); consider removing
# that debug print from the library.
sim_env.reset_client_nets()
sim_env.poison_clients()

1/1 clients poisoned
tensor([[4],
        [4],
        [4],
        ...,
        [5],
        [5],
        [5]])
Label Flipping 50.0% from 5 to 4


In [None]:
# Poisoned run: same schedule as the clean runs (evaluate, then train/evaluate per
# epoch), now on the client whose labels were flipped in the previous cell.
# NOTE(review): the saved output stops after the baseline test and the execution
# count is [None], so this run was interrupted — re-run to record poisoned accuracies.
client = sim_env.clients[0]
client.test()
for epoch in range(1, config.N_EPOCHS + 1):
    client.train(epoch)
    client.test()


Test set: Average loss: 0.0023, Accuracy: 1009/10000 (10%)

