In [1]:
from federated_learning.utils import SHAPUtil, experiment_util, Visualizer
from federated_learning import ClientPlane, Configuration, ObserverConfiguration
from federated_learning.server import Server
from datetime import datetime
import random

In [None]:
class ExperimentalSummary:
    """Placeholder for a summary of the experiment's results.

    TODO: not implemented yet. A class statement needs a body, so this
    docstring keeps the cell valid under Restart & Run All.
    """

In [15]:
import numpy as np
def _median_nonzero(values):
    """Return the median of the non-zero entries of ``values``; ``nan`` if there are none."""
    arr = np.asarray(values)
    nonzero = arr[np.nonzero(arr)]
    # np.median of an empty array also gives nan, but emits a RuntimeWarning.
    return float(np.median(nonzero)) if nonzero.size else float("nan")


def diag_mean_values(shap_values, server_shap=None):
    """Print (and return) the median non-zero value of each diagonal entry.

    For every index ``k`` the entry ``shap_values[k][k]`` (row ``k``, image
    ``k``) is reduced to the median of its non-zero elements. Despite the
    function's name, the statistic is a *median*, not a mean. If
    ``server_shap`` is given, the same statistic is also computed on
    ``shap_values[k][k] - server_shap[k][k]``.

    NOTE(review): presumably row ``k`` holds the explanations for output
    class ``k`` and image ``k`` is a sample of that class -- TODO confirm
    against SHAPUtil.

    Args:
        shap_values: nested sequence (list of lists or ndarray) indexed as
            ``shap_values[row][image]``; each entry is array-like.
        server_shap: optional reference values with the same indexing.
            ``None`` or ``False`` (the previous default) disables the
            difference statistic.

    Returns:
        Tuple ``(diag_mean, diag_diff_mean)`` of lists of floats.
        ``diag_diff_mean`` is empty when no ``server_shap`` is given. An entry
        is ``nan`` when the corresponding array has no non-zero elements.
    """
    # A plain ``if server_shap:`` raises ValueError for a multi-element
    # ndarray, so compare against the "absent" sentinels by identity instead.
    use_server = server_shap is not None and server_shap is not False
    diag_mean = []
    diag_diff_mean = []
    for idx, row in enumerate(shap_values):
        # Only row[idx] is needed; a row shorter than idx + 1 has no diagonal
        # entry and is skipped.
        if idx >= len(row):
            continue
        image = np.asarray(row[idx])
        diag_mean.append(_median_nonzero(image))
        if use_server:
            diff = np.subtract(image, server_shap[idx][idx])
            diag_diff_mean.append(_median_nonzero(diff))
    print(diag_mean)
    print(diag_diff_mean)
    return diag_mean, diag_diff_mean
    
def convolve_values(s_client, s_server):
    """Print (and return) the summed client-minus-server difference on the diagonal.

    Despite its name, this function performs no convolution. It computes
    ``s_client - s_server`` element-wise, then sums all entries of each
    diagonal element ``diff[k][k]`` (row ``k``, image ``k``).

    NOTE(review): presumably row ``k`` holds the SHAP values for output
    class ``k`` -- TODO confirm against SHAPUtil.

    Args:
        s_client: client SHAP values, array-like indexed ``[row][image][...]``.
        s_server: server SHAP values, same shape as ``s_client``.

    Returns:
        List of floats with one sum per diagonal entry
        (``min(n_rows, n_images)`` entries).

    Raises:
        ValueError: if the difference has fewer than two dimensions.
    """
    diff = np.subtract(s_client, s_server)
    if diff.ndim < 2:
        raise ValueError(
            "expected SHAP values indexed [row][image], got shape {}".format(diff.shape)
        )
    n_diag = min(diff.shape[0], diff.shape[1])
    # Only the diagonal is needed, so index it directly instead of scanning
    # every (row, image) pair.
    convolution = [float(np.sum(np.ravel(diff[k, k]))) for k in range(n_diag)]
    print(convolution)
    return convolution

    


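## Configuration

Label-flipping attack on MNIST: 10 clients are poisoned and relabel every 4 in their training data as 5.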

In [7]:
# Reproducibility: seed the Python and NumPy global RNGs before any data
# loading (below) or client sampling (later cells) happens.
# NOTE(review): the training framework's own RNG (presumably torch) is not
# seeded here -- TODO confirm and seed it too for fully repeatable runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Label-flipping attack: poisoned clients relabel FROM_LABEL samples as TO_LABEL.
FROM_LABEL = 4
TO_LABEL = 5
N_POISONED_CLIENTS = 10

config = Configuration()
config.FROM_LABEL = FROM_LABEL
config.TO_LABEL = TO_LABEL
config.POISONED_CLIENTS = N_POISONED_CLIENTS
data = config.DATASET(config)
shap_util = SHAPUtil(data.test_dataloader)
observer_config = ObserverConfiguration()
visualizer = Visualizer(shap_util)

MNIST training data loaded.
MNIST test data loaded.


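## Experimental Setup

Create the server and the client plane, poison the configured clients, and sample 10 clean and 10 poisoned client indices for inspection.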

In [8]:
# Build the federated system (one server plus the client plane), then apply
# the configured poisoning to the clients.
server = Server(
    config,
    observer_config,
    data.train_dataloader,
    data.test_dataloader,
    shap_util,
)
client_plane = ClientPlane(config, observer_config, data, shap_util)
client_plane.poison_clients()

# Sample the clean and poisoned clients that the experiment below inspects.
N_INSPECTED_CLIENTS = 10
clean_idx = experiment_util.select_random_clean(client_plane, config, N_INSPECTED_CLIENTS)
poisoned_idx = experiment_util.select_poisoned(client_plane, N_INSPECTED_CLIENTS)
print(clean_idx)

Create 200 clients with dataset of size 300
Poison 10/200 clients
Flip 100.0% of the 4 labels to 5
[ 75 196  77  87 137  97  55 159  31  37]
[108, 16, 121, 162, 29, 175, 104, 127, 91, 48]


In [None]:
#len(client_plane.clients[0].train_dataloader.dataset.dataset.targets[client_plane.clients[0].train_dataloader.dataset.dataset.targets == 5])

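## Experiment

Each round does the following:
- runs one federated round and tests the server;
- pushes the server's parameters to the clients;
- retrains one sampled clean client and one sampled poisoned client;
- prints, for each of those clients, the diagonal sums of its SHAP difference from the server (`convolve_values`).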

In [16]:
# Experiment knobs: the number of federated rounds, and how many of the
# sampled clean / poisoned clients (clean_idx / poisoned_idx) are retrained
# and inspected after each round.
N_ROUNDS = 5
N_PROBED_CLIENTS = 1
# Set to True to also draw the SHAP plots (server, each probed client, and
# client-vs-server comparison). Off by default to keep the output short.
PLOT_SHAP = False


def probe_clients(client_indices, group_label, round_no, server_shap, n_probed):
    """Retrain a few sampled clients and compare their SHAP values to the server's.

    First pushes the server's current parameters to the client plane via
    ``update_clients``. Then, for each of the first ``n_probed`` indices in
    ``client_indices``, it trains that client, fetches its SHAP values and
    prints the diagonal difference sums from ``convolve_values``.

    Args:
        client_indices: indices into ``client_plane.clients`` (list or array).
        group_label: heading text, printed as ``"Client <group_label> <round_no>"``.
        round_no: 1-based round number, passed to ``client.train``.
        server_shap: the server's SHAP values for this round.
        n_probed: how many of ``client_indices`` to probe.

    Returns:
        The SHAP values of the last probed client, or ``None`` if none was probed.
    """
    client_plane.update_clients(server.get_nn_parameters())
    client_shap = None
    for client_idx in client_indices[:n_probed]:
        print("Client {} {}".format(group_label, round_no))
        client = client_plane.clients[client_idx]
        client.train(round_no)
        client_shap = client.get_shap_values()
        convolve_values(client_shap, server_shap)
        if PLOT_SHAP:
            visualizer.plot_shap_values(client_shap)
            visualizer.compare_shap_values(client_shap, server_shap)
    return client_shap


for round_no in range(1, N_ROUNDS + 1):
    experiment_util.run_round(client_plane, server, round_no)
    server.test()
    print("Server {}".format(round_no))
    server_shap = server.get_shap_values()
    if PLOT_SHAP:
        visualizer.plot_shap_values(server_shap)
    # Keep the last probed SHAP values as globals so later cells can inspect them.
    clean_client_shap = probe_clients(clean_idx, "Clean", round_no, server_shap, N_PROBED_CLIENTS)
    poisoned_client_shap = probe_clients(poisoned_idx, "Poisoned", round_no, server_shap, N_PROBED_CLIENTS)
    print("Round {} finished".format(round_no))


Test set: Average loss: 0.0003, Accuracy: 9415/10000 (94%)

Server 1
Client Clean 1
[-0.004380190790798322, -0.20036727593768427, -0.004485519679174016, -0.302339301003963, -0.0003993992544772951, 0.010438596160126878, -0.012211120367652484, -0.04504875549427023, 0.03665115551720932, -0.03939160862273189]
Client Poisoned 1
[-0.000816721596110126, -0.2057804693262631, -0.0022164908827472813, 0.049921215988741, -0.7903016121173079, -0.13275992653938307, -0.08305722087333756, -0.1777078956162419, -0.0889569565577264, -0.11322076118135027]
Round 1 finished

Test set: Average loss: 0.0003, Accuracy: 9404/10000 (94%)

Server 2
Client Clean 2
[-0.012109553039119092, -0.1625669348169787, -0.013338611212023732, -0.30611625270340515, -0.03818214033859979, -0.012931845540720133, -0.09032663451926481, -0.10528357900504126, -0.2070120151794086, 0.003908809871264296]
Client Poisoned 2
[0.0004764699953390128, -0.027580840326681733, -0.02191239834535308, 0.0702014855760753, -0.6960191056289224, -0.10