# Experiment 3: Poisoning Attacks in Federated Learning

* Poisoning setups
    * experiment_id = 1; label flipping 5 -> 4; poisoned_clients = [10, 25, 50, 75, 100, 125]; rounds = 50
    * experiment_id = 0; label flipping 5 -> 4; poisoned_clients = [0, 10, 25, 50, 75, 100]; rounds = 200

In [None]:
from federated_learning.utils import SHAPUtil
from federated_learning import ClientPlane, Configuration
from federated_learning.server import Server
from datetime import datetime
import random

In [None]:
class ObserverConfiguration:
    """Static observer settings for this experiment.

    An instance is passed, together with the main ``Configuration``, to the
    ``Server`` and ``ClientPlane`` constructors and to their ``update_config``
    methods. All values are plain class attributes.
    """

    # Experiment identification
    experiment_type = "shap_fl_poisoned"
    experiment_id = 0
    test = False
    dataset_type = "MNIST"

    # Client side: name and type share the same label
    client_name = client_type = "client"

    # Server side: name and type share the same label
    server_name = server_type = "server"
    server_id = 0

In [None]:
# Build the shared experiment configuration for the label-flipping attack.
config = Configuration()
# Source and target classes of the label flip (5 -> 4, matching the header).
config.FROM_LABEL = 5
config.TO_LABEL = 4
# config.DATASET is called with the configuration itself; the labels above are
# set first. NOTE(review): presumably DATASET is the dataset class -- confirm.
data = config.DATASET(config)
# SHAP helper built on the test dataloader; passed to Server and ClientPlane.
shap_util = SHAPUtil(data.test_dataloader)
observer_config = ObserverConfiguration()

In [None]:
def set_rounds(rounds):
    """Forward ``rounds`` to the client plane first, then to the server."""
    for participant in (client_plane, server):
        participant.set_rounds(rounds)
    
def update_configs():
    """Push the global ``config`` and ``observer_config`` to the client plane, then the server."""
    for component in (client_plane, server):
        component.update_config(config, observer_config)
    
def run_round(rounds):
    """Run a single federated learning round.

    Steps: hand the server's current parameters to the clients, let the
    server select clients, train the selected clients, and aggregate their
    parameters on the server.

    ``rounds`` is accepted for call-site compatibility; it is not read here.
    """
    global_parameters = server.get_nn_parameters()
    client_plane.update_clients(global_parameters)

    participants = server.select_clients()
    trained_parameters = client_plane.train_selected_clients(participants)

    server.aggregate_model(trained_parameters)

def select_random_clean():
    """Return the index of a client that is not in ``client_plane.poisoned_clients``.

    Index 0 is tried first and returned if it is clean; otherwise indices are
    drawn uniformly from ``0 .. config.NUMBER_OF_CLIENTS - 1`` until a clean
    one is found.

    Returns:
        int: index of a non-poisoned client.

    Raises:
        ValueError: if every index in ``range(config.NUMBER_OF_CLIENTS)`` is
            poisoned, so no clean client can be selected.
    """
    num_clients = config.NUMBER_OF_CLIENTS
    poisoned = client_plane.poisoned_clients

    # Without this guard the rejection loop below would never terminate when
    # there is no clean client left.
    if all(candidate in poisoned for candidate in range(num_clients)):
        raise ValueError("no clean client available: all clients are poisoned")

    idx = 0
    while idx in poisoned:
        # randrange excludes the upper bound. random.randint(0, num_clients)
        # is inclusive and could return num_clients, one past the last index
        # (assumes clients are indexed 0..NUMBER_OF_CLIENTS-1 -- TODO confirm).
        idx = random.randrange(num_clients)
    return idx

def train_poisoned_client_only(rounds):
    """Train the first poisoned client on its own and push its metrics.

    When ``rounds`` equals 5, also print the first entry of the client's
    training targets selected by its ``poisoned_indices``.
    """
    first_poisoned = client_plane.clients[client_plane.poisoned_clients[0]]
    first_poisoned.train(rounds)
    first_poisoned.push_metrics()

    if rounds != 5:
        return

    # One-off check of a label at a poisoned index.
    targets = first_poisoned.train_dataloader.dataset.dataset.targets
    print(targets[first_poisoned.poisoned_indices][0])
    
def train_clean_client_only(idx, rounds):
    """Train the client at position ``idx`` on its own and push its metrics."""
    clean_client = client_plane.clients[idx]
    clean_client.train(rounds)
    clean_client.push_metrics()

In [None]:
# Create the server and the client plane from the shared configuration.
server = Server(config, observer_config, data.test_dataloader, shap_util)
# NOTE(review): presumably stores the initial model so load_default_model()
# can restore it later -- confirm against Server.
server.create_default_model()
client_plane = ClientPlane(config, observer_config, data, shap_util)

In [None]:
# Sweep over the number of poisoned clients; each setting runs 200 rounds.
for num_p_clients in [0, 10, 25, 50, 75, 100]:
    # Fresh server and client plane for every setting.
    server = Server(config, observer_config, data.test_dataloader, shap_util)
    client_plane = ClientPlane(config, observer_config, data, shap_util)
    # Set the poisoned-client count, then push the config to both components
    # before the clients are poisoned.
    config.POISONED_CLIENTS = num_p_clients
    update_configs()
    client_plane.poison_clients()
    # One non-poisoned client is monitored for the whole setting.
    clean_idx = select_random_clean()
    
    for i in range(200):
        # Rounds are passed 1-based.
        set_rounds(i+1)
        run_round(i+1)
        # Every 5th round: test the server model and monitor single clients.
        if (i+1)%5 == 0:
            server.test()
            server.push_metrics()
            # Hand the freshly aggregated parameters to the clients before
            # the monitored clients are trained individually.
            client_plane.update_clients(server.get_nn_parameters())
            # Monitor a poisoned client (only if this setting has any).
            if num_p_clients > 0:
                train_poisoned_client_only(i+1)
            # Monitor the selected clean client.
            train_clean_client_only(clean_idx, i+1)
        print("Round {} finished".format(i+1))
        
    # Reset the server model, the client nets and the poisoning state after
    # each setting.
    server.load_default_model()
    client_plane.load_default_client_nets()
    client_plane.reset_poisoning_attack()

In [None]:
# `server` is the instance left over from the final sweep iteration above.
# NOTE(review): presumably computes SHAP values for the server model -- confirm.
server.get_shap_values()