In [11]:
from federated_learning.utils import SHAPUtil, VMUtil
from federated_learning import ClientPlane, ObserverConfiguration

In [8]:
%load_ext autoreload
%autoreload 2

# Configuration

In [6]:
import os
import torch.nn as nn
from torch import device
from federated_learning.nets import MNISTFFNN, FMNISTCNN, FashionMNISTCNN
from federated_learning.dataset import MNISTDataset, FashionMNISTDataset
from federated_learning.client import FFNNClient, CNNClient

In [12]:
class Configuration:
    """Central settings for the local federated-learning experiment.

    Every setting is a plain class attribute, so it can be read either from
    the class itself or from an instance (e.g. ``Configuration().N_EPOCHS``).
    """

    # --- Dataset ---
    BATCH_SIZE_TRAIN = 64
    BATCH_SIZE_TEST = 1000
    DATASET = MNISTDataset  # dataset class; instantiated as DATASET(config)

    # --- Fashion-MNIST dataset ---
    MNIST_FASHION_DATASET_PATH = os.path.join('./data/mnist_fashion')
    MNIST_FASHION_LABELS = [
        'T-shirt/top',
        'Trouser',
        'Pullover',
        'Dress',
        'Coat',
        'Sandal',
        'Shirt',
        'Sneaker',
        'Bag',
        'Ankle Boot',
    ]

    # --- MNIST dataset ---
    MNIST_DATASET_PATH = os.path.join('./data/mnist')

    # --- CIFAR-10 dataset ---
    CIFAR10_DATASET_PATH = os.path.join('./data/cifar10')
    CIFAR10_LABELS = [
        'Plane',
        'Car',
        'Bird',
        'Cat',
        'Deer',
        'Dog',
        'Frog',
        'Horse',
        'Ship',
        'Truck',
    ]

    # --- Model training ---
    N_EPOCHS = 4
    LEARNING_RATE = 0.01
    MOMENTUM = 0.5
    LOG_INTERVAL = 2
    CRITERION = nn.CrossEntropyLoss  # the loss class itself, not an instance
    NETWORK = MNISTFFNN              # the network class itself, not an instance
    NUMBER_TARGETS = 10

    # --- Local environment ---
    NUMBER_OF_CLIENTS = 1
    CLIENT_TYPE = FFNNClient
    DEVICE = device('cpu')

    # --- Label flipping attack ---
    POISONED = False
    POISONED_CLIENTS = 0
    DATA_POISONING_PERCENTAGE = 1
    FROM_LABEL = 5
    TO_LABEL = 4

    # --- Victoria Metrics ---
    # URL settings live in docker-compose.yml; the value is None when the
    # VM_URL environment variable is not set.
    VM_URL = os.getenv('VM_URL')

# Monitoring

In [None]:
class ObserverConfigurations:
    """Static settings consumed by the observer classes in this notebook.

    The experiment fields are copied onto each observer instance; the client
    fields provide the name and ``type`` value used by the client observer.
    """

    # Experiment identification
    experiment_type = "datasize_shap"
    experiment_id = 0
    test = True

    # Client observer settings
    client_name = "client"
    client_type = "client"
    
    


In [None]:
class Observer(VMUtil):
    """Base observer: a ``VMUtil`` that also carries experiment metadata.

    Args:
        config: experiment configuration; forwarded to ``VMUtil.__init__`` and
            read for ``POISONED_CLIENTS``.
        observer_config: object exposing ``experiment_type``,
            ``experiment_id`` and ``test``.
    """

    def __init__(self, config, observer_config):
        super().__init__(config)
        self.config = config
        self.observer_config = observer_config

        # Copy the metadata onto the instance so subclasses can use it
        # directly when building their label strings.
        self.experiment_type = observer_config.experiment_type
        self.experiment_id = observer_config.experiment_id
        self.poisoned_clients = config.POISONED_CLIENTS
        self.test = observer_config.test

In [None]:
from datetime import datetime

class ClientObserver(Observer):
    """Formats the metrics of one client as text lines and pushes them.

    Every metric value becomes one line shaped as
    ``<name>,<labels> <metric>=<value> <timestamp>``.

    Args:
        config: experiment configuration (reads ``DATA_POISONING_PERCENTAGE``
            and ``NUMBER_TARGETS``).
        observer_config: object exposing ``client_name`` and ``client_type``
            in addition to what ``Observer`` reads.
        client_id: identifier written into the ``client_id`` label.
        poisoned: value written into the ``poisoned`` label.
        dataset_size: value written into the ``dataset_size`` label.
    """

    def __init__(self, config, observer_config, client_id, poisoned, dataset_size):
        super().__init__(config, observer_config)
        self.name = self.observer_config.client_name
        self.client_id = client_id
        self.poisoned = poisoned
        self.poisoned_data = self.config.DATA_POISONING_PERCENTAGE
        self.dataset_size = dataset_size
        self.type = self.observer_config.client_type

        # Extra label templates appended to the common labels, per metric.
        per_target = ",target={}"
        per_target_and_source = ",target={},source={}"
        self.metric_labels = {
            "accuracy": "",
            "recall": per_target,
            "precision": per_target,
            "shap_pos": per_target_and_source,
            "shap_neg": per_target_and_source,
            "shap_mean": per_target_and_source,
        }
        self.metrics = [
            "accuracy",
            "recall",
            "precision",
            "shap_pos",
            "shap_neg",
            "shap_mean",
        ]

    def get_labels(self):
        """Return the comma-separated ``key=value`` labels shared by all metrics."""
        pairs = (
            ("client_id", self.client_id),
            ("test", self.test),
            ("poisoned", self.poisoned),
            ("poisoned_data", self.poisoned_data),
            ("dataset_size", self.dataset_size),
            ("type", self.type),
            ("experiment_type", self.experiment_type),
            ("experiment_id", self.experiment_id),
            ("poisoned_clients", self.poisoned_clients),
        )
        return ",".join("{}={}".format(key, value) for key, value in pairs)

    def get_datastr(self, accuracy, recall, precision, shap_pos, shap_neg, shap_mean):
        """Build one text line per metric value.

        Args:
            accuracy: single number.
            recall, precision: indexable by target (``0..NUMBER_TARGETS-1``).
            shap_pos, shap_neg, shap_mean: indexable as ``[target][source]``,
                both indices ranging over ``0..NUMBER_TARGETS-1``.

        Returns:
            list of str: the accuracy line first, then for each target its
            recall and precision lines followed by the shap_pos / shap_neg /
            shap_mean lines of every source.
        """
        # All lines of one call share the same timestamp (whole seconds).
        timestamp = int(datetime.now().timestamp())
        base_labels = self.get_labels()
        templates = self.metric_labels
        indices = range(self.config.NUMBER_TARGETS)

        def render(extra_labels, field):
            return "{},{} {} {}".format(self.name, base_labels + extra_labels, field, timestamp)

        lines = [render("", "accuracy=%f" % accuracy)]
        for target in indices:
            lines.append(render(templates["recall"].format(target), "recall=%f" % recall[target]))
            lines.append(render(templates["precision"].format(target), "precision=%f" % precision[target]))
            for source in indices:
                lines.append(render(templates["shap_pos"].format(target, source), "shap_pos=%f" % shap_pos[target][source]))
                lines.append(render(templates["shap_neg"].format(target, source), "shap_neg=%f" % shap_neg[target][source]))
                lines.append(render(templates["shap_mean"].format(target, source), "shap_mean=%f" % shap_mean[target][source]))
        return lines

    def push_metrics(self, accuracy, recall, precision, shap_pos, shap_neg, shap_mean):
        """Format all metrics and hand every line to ``push_data``.

        ``push_data`` is not defined in this class; presumably it is inherited
        from ``VMUtil`` (verify against that class). The first line is printed
        as a sample of what is being sent.
        """
        lines = self.get_datastr(accuracy, recall, precision, shap_pos, shap_neg, shap_mean)
        print(lines[0])
        for line in lines:
            self.push_data(line)
        print("Successfully pushed client data to victoria metrics")
        
        
        
                
                
        
        
    

In [13]:
# Build the experiment objects from the Configuration class.
config = Configuration()
# DATASET holds a dataset class (MNISTDataset in Configuration); it is built from the config.
data = config.DATASET(config)
shap_util = SHAPUtil(data.test_dataloader)
# NOTE(review): this is ObserverConfiguration imported from federated_learning,
# not the ObserverConfigurations class defined in this notebook — confirm that is intended.
observer_config = ObserverConfiguration()
client_plane = ClientPlane(config, observer_config, data, shap_util)

MNIST training data loaded.
MNIST test data loaded.
Create 1 clients with dataset of size 60000


In [None]:
# Call test() once before any training, then alternate train(epoch) / test()
# for epochs 1..N_EPOCHS (inclusive) on the first client.
client_plane.clients[0].test()
for epoch in range(1, config.N_EPOCHS + 1):
    client_plane.clients[0].train(epoch)
    client_plane.clients[0].test()

In [None]:
# Presumably fills the accuracy / recall / precision attributes that are read
# from this client when metrics are pushed — TODO confirm. (`analize` is the project's spelling.)
client_plane.clients[0].analize()

In [None]:
# Presumably computes the SHAP values for the first client using the SHAPUtil
# passed to ClientPlane — TODO confirm against the client implementation.
client_plane.clients[0].get_shap_values()

In [None]:
# Presumably derives positive_shap / negative_shap / non_zero_mean, which are
# read from this client further down — TODO confirm.
client_plane.clients[0].analize_shap_values()

In [15]:
# Clean up earlier test runs: for each client metric, delete the stored series
# matching the filter test = 'True' (see the log lines printed by this cell).
vm = VMUtil(config) 
vm.delete_old_metrics('client', ["accuracy", "recall","precision", "shap_pos", "shap_neg", "shap_mean"], "test = 'True'")

Delete old metrics from client_accuracy with test = 'True'
Delete old metrics from client_recall with test = 'True'
Delete old metrics from client_precision with test = 'True'
Delete old metrics from client_shap_pos with test = 'True'
Delete old metrics from client_shap_neg with test = 'True'
Delete old metrics from client_shap_mean with test = 'True'


In [None]:
# Inspect the first client's positive SHAP values, divided by 28*28 and rounded
# to two decimals for display.
# NOTE(review): 28*28 is presumably the pixel count of one input image — confirm.
import numpy as np
array = np.array(client_plane.clients[0].positive_shap)/(28*28)
array.round(2)

In [None]:
# Build the observer from the notebook-local ObserverConfigurations class.
obser_config = ObserverConfigurations()
# Positional arguments: client_id=0, poisoned=False, dataset_size=7500.
# NOTE(review): the setup cell logged "Create 1 clients with dataset of size 60000";
# confirm that the hard-coded dataset_size label of 7500 is intended.
observer = ClientObserver(config, obser_config, 0, False, 7500)

In [None]:
# Push every metric collected on the first client through the observer.
observer.push_metrics(
    client_plane.clients[0].accuracy,
    client_plane.clients[0].recall,
    client_plane.clients[0].precision,
    client_plane.clients[0].positive_shap,
    client_plane.clients[0].negative_shap,
    client_plane.clients[0].non_zero_mean,
)