In [1]:
from federated_learning.utils import SHAPUtil
from federated_learning import ClientPlane

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import random

import torch
import torch.nn as nn
from torch import device

from federated_learning.nets import MNISTCNN, FashionMNISTCNN, MNISTFFNN
from federated_learning.dataset import MNISTDataset, FashionMNISTDataset
from federated_learning.client.ffnn_client import FFNNClient

In [4]:
class Configuration:
    """Settings for the label-flipping experiment in this notebook.

    All settings are plain class attributes. The notebook creates one instance
    and hands it to ``DATASET(config)`` and ``ClientPlane(config, data)``.
    """

    # ---- Dataset -----------------------------------------------------------
    BATCH_SIZE_TRAIN = 132
    BATCH_SIZE_TEST = 1000
    DATASET = MNISTDataset  # a dataset *class*; instantiated as DATASET(config)

    # ---- Fashion-MNIST -------------------------------------------------------
    MNIST_FASHION_DATASET_PATH = './data/mnist_fashion'
    MNIST_FASHION_LABELS = [
        'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot',
    ]

    # ---- MNIST ---------------------------------------------------------------
    MNIST_DATASET_PATH = './data/mnist'

    # ---- CIFAR-10 --------------------------------------------------------------
    CIFAR10_DATASET_PATH = './data/cifar10'
    CIFAR10_LABELS = [
        'Plane', 'Car', 'Bird', 'Cat', 'Deer',
        'Dog', 'Frog', 'Horse', 'Ship', 'Truck',
    ]

    # ---- Model training ------------------------------------------------------
    N_EPOCHS = 4
    LEARNING_RATE = 0.01
    MOMENTUM = 0.5              # presumably SGD momentum — confirm in the client
    LOG_INTERVAL = 10           # presumably batches between progress logs
    CRITERION = nn.CrossEntropyLoss  # loss *class*, not an instance
    NETWORK = MNISTFFNN              # network *class* for each client
    NUMBER_TARGETS = 10         # presumably the number of output classes

    # ---- Local simulation environment ----------------------------------------
    # NOTE(review): the saved ClientPlane output reports "Create 1 clients",
    # so the outputs may predate this value. Re-run the notebook to refresh them.
    NUMBER_OF_CLIENTS = 3
    CLIENT_TYPE = FFNNClient    # client *class* instantiated by the plane
    DEVICE = device('cpu')

    # ---- Label-flipping attack -------------------------------------------------
    # Despite the name, this is a fraction: 1 is logged as "Label Flipping 100.0%".
    DATA_POISONING_PERCENTAGE = 1
    FROM_LABEL = 5              # source class of the flip
    TO_LABEL = 4                # class the flipped samples are relabelled as

In [5]:
# Seed the RNGs before any data or model is created. This makes weight
# initialisation (and data shuffling, if the loaders draw from torch's RNG)
# reproducible under Restart & Run All.
# NOTE(review): if the federated_learning package also draws from numpy's
# RNG (e.g. when choosing samples to poison), seed numpy here as well.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds torch's CPU (and any CUDA) generators

config = Configuration()
data = config.DATASET(config)        # loads the train and test splits
sim_env = ClientPlane(config, data)  # presumably builds config.NUMBER_OF_CLIENTS clients

MNIST training data loaded.
MNIST test data loaded.
Create 1 clients


In [6]:
# Clean baseline. Evaluate the untrained client once, then train for
# config.N_EPOCHS epochs (1-based numbering) and evaluate after each epoch.
client = sim_env.clients[0]
client.test()
for epoch_idx in range(config.N_EPOCHS):
    client.train(epoch_idx + 1)
    client.test()


Test set: Average loss: 0.0023, Accuracy: 1019/10000 (10%)


Test set: Average loss: 0.0007, Accuracy: 8807/10000 (88%)


Test set: Average loss: 0.0004, Accuracy: 9303/10000 (93%)


Test set: Average loss: 0.0003, Accuracy: 9469/10000 (95%)


Test set: Average loss: 0.0002, Accuracy: 9565/10000 (96%)



In [7]:
# Prepare the poisoned run.
# 1. reset_client_nets(): presumably re-initialises each client's network, so
#    the poisoned run does not continue from the clean-trained weights (it logs
#    "Reset network successfully").
# 2. poison_clients(): applies the label-flipping attack. The log
#    ("from 5 to 4", "100.0%") matches Configuration.FROM_LABEL / TO_LABEL /
#    DATA_POISONING_PERCENTAGE, so it presumably reads those settings. Confirm
#    in ClientPlane.
# Keep the reset before the poisoning.
sim_env.reset_client_nets()
sim_env.poison_clients()

Reset network successfully
1/1 clients poisoned
Label Flipping 100.0% from 5 to 4


In [8]:
# Poisoned run. This uses the same protocol as the clean baseline: one
# evaluation of the reset client, then train and evaluate once per epoch.
client = sim_env.clients[0]
client.test()
for epoch_number in range(1, config.N_EPOCHS + 1):
    client.train(epoch_number)
    client.test()


Test set: Average loss: 0.0023, Accuracy: 899/10000 (9%)


Test set: Average loss: 0.0010, Accuracy: 8194/10000 (82%)


Test set: Average loss: 0.0010, Accuracy: 8450/10000 (84%)


Test set: Average loss: 0.0010, Accuracy: 8593/10000 (86%)


Test set: Average loss: 0.0010, Accuracy: 8685/10000 (87%)

