In [1]:
from federated_learning.utils import SHAPUtil, experiment_util
from federated_learning import ClientPlane, Configuration, ObserverConfiguration
from federated_learning.server import Server
from datetime import datetime
import random

In [None]:
import shap
import numpy as np
class Visualizer():
    """Plot SHAP attributions over the images held by a SHAPUtil instance."""

    def __init__(self, shap_util):
        """
        :param shap_util: helper whose ``shap_images`` attribute (converted with
            ``.numpy()``) is drawn underneath every SHAP map
        :type shap_util: SHAPUtil
        """
        self.shap_util = shap_util

    @staticmethod
    def _channels_last(array):
        """Move axis 1 to the end, e.g. (N, C, H, W) -> (N, H, W, C), the layout shap.image_plot draws."""
        return np.swapaxes(np.swapaxes(array, 1, -1), 1, 2)

    @staticmethod
    def _shap_difference(shap_values, server_shap):
        """
        Return ``shap_values - server_shap`` output by output as a new list.

        Neither argument is modified.

        :raises ValueError: if the two inputs hold a different number of outputs
        """
        if len(shap_values) != len(server_shap):
            raise ValueError("shap_values has {} outputs but server_shap has {}".format(
                len(shap_values), len(server_shap)))
        return [np.subtract(row, server_row) for row, server_row in zip(shap_values, server_shap)]

    def _shap_images_numpy(self):
        """Return the SHAPUtil images converted to numpy in channels-last layout."""
        return self._channels_last(self.shap_util.shap_images.numpy())

    @staticmethod
    def _show_or_save(shap_numpy, test_numpy, file):
        """
        Draw ``shap_numpy`` over the negated ``test_numpy`` images.

        :param file: when truthy, the figure is saved to this path (creating its
            directory if needed) and closed; otherwise it is shown
        """
        import os
        import matplotlib.pyplot as plt
        if not file:
            shap.image_plot(shap_numpy, -test_numpy)
            return
        shap.image_plot(shap_numpy, -test_numpy, show=False)
        directory = os.path.dirname(file)
        # dirname is '' for a bare filename, and os.makedirs('') would raise
        if directory:
            os.makedirs(directory, exist_ok=True)
        plt.savefig(file)
        # release the figure so repeated saves (e.g. in a loop) do not pile up open figures
        plt.close()

    def plot_shap_images(self, shap_indices, shap_images, targets=None):
        """
        Plot sample images, titled with their target labels when available.

        :param shap_indices: indices into ``shap_images`` of the images to draw
        :param shap_images: indexable images; channel 0 of ``shap_images[idx]`` is drawn in grayscale
        :param targets: labels indexed like ``shap_images``; defaults to ``self.targets``
            if that attribute has been set, otherwise the subplots are left untitled
        """
        import matplotlib.pyplot as plt
        if targets is None:
            targets = getattr(self, "targets", None)
        indices = list(shap_indices)
        n_cols = 4
        # keep the original 3x4 grid, adding rows only when more than 12 images are requested
        n_rows = max(3, -(-len(indices) // n_cols))
        fig = plt.figure()
        for position, idx in enumerate(indices, start=1):
            ax = fig.add_subplot(n_rows, n_cols, position)
            ax.imshow(shap_images[idx][0], cmap='gray', interpolation='none')
            if targets is not None:
                ax.set_title("Ground Truth: {}".format(targets[idx]))
            ax.set_xticks([])
            ax.set_yticks([])
        fig.tight_layout()
        plt.show()

    def plot_shap_values(self, shap_values, file=None):
        """
        Plot SHAP values over the SHAPUtil images.

        :param shap_values: per-output SHAP attributions, each with channels on axis 1
        :type shap_values: list
        :param file: if given, save the figure to this path instead of showing it
        :type file: str
        """
        shap_numpy = [self._channels_last(s) for s in shap_values]
        self._show_or_save(shap_numpy, self._shap_images_numpy(), file)

    def compare_shap_values(self, shap_values, server_shap, file=None):
        """
        Plot the difference between a client's and the server's SHAP values.

        ``shap_values`` is no longer overwritten with the difference.

        :param shap_values: per-output SHAP attributions of the client, channels on axis 1
        :type shap_values: list
        :param server_shap: SHAP attributions of the server, structured like ``shap_values``
        :type server_shap: list
        :param file: if given, save the figure to this path instead of showing it
        :type file: str
        :raises ValueError: if the two inputs hold a different number of outputs
        """
        difference = self._shap_difference(shap_values, server_shap)
        shap_numpy = [self._channels_last(d) for d in difference]
        self._show_or_save(shap_numpy, self._shap_images_numpy(), file)

In [None]:
class SHAPStats():
    """Placeholder for SHAP statistics helpers.

    TODO: not yet implemented. A class statement with no body is a SyntaxError
    that stops "Restart & Run All"; this docstring keeps the cell valid until the
    class is filled in or removed.
    """

## Configurations

In [None]:
# Seed the stdlib and NumPy RNGs so stdlib/NumPy-based randomness in this
# notebook is repeatable under "Restart & Run All".
# TODO: also seed torch here if federated_learning samples with torch.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

config = Configuration()
# Presumably a label-flipping attack (poisoned clients relabel FROM_LABEL as TO_LABEL) — confirm in Configuration.
config.FROM_LABEL = 5
config.TO_LABEL = 4
config.POISONED_CLIENTS = 10
data = config.DATASET(config)
shap_util = SHAPUtil(data.test_dataloader)
observer_config = ObserverConfiguration()
visualizer = Visualizer(shap_util)

## Experimental Setup

In [None]:
# Build the server and the client plane, then poison the configured clients.
server = Server(
    config,
    observer_config,
    data.train_dataloader,
    data.test_dataloader,
    shap_util,
)
client_plane = ClientPlane(config, observer_config, data, shap_util)
client_plane.poison_clients()

# Choose the clients whose SHAP values are inspected in the experiment below.
N_INSPECTED_CLIENTS = 10
clean_idx = experiment_util.select_random_clean(N_INSPECTED_CLIENTS, config, client_plane)
poisoned_idx = experiment_util.select_poisoned(N_INSPECTED_CLIENTS, client_plane)
print(clean_idx, poisoned_idx)

In [None]:
# Check the client selection from the setup cell instead of drawing it again:
# a second random draw here would make the inspected clients depend on how many
# times cells had been executed.
assert not set(clean_idx) & set(poisoned_idx), "a client was selected as both clean and poisoned"
print(clean_idx, poisoned_idx)


## Experiment

In [None]:
# First train the federation for WARMUP_ROUNDS. Then run INSPECTION_ROUNDS more
# rounds. In each of these, the server's SHAP values are plotted, and each
# selected client's SHAP values are compared with them after one further local
# training pass. Round numbers continue across both loops instead of restarting at 1.
WARMUP_ROUNDS = 20
INSPECTION_ROUNDS = 1

for round_no in range(1, WARMUP_ROUNDS + 1):
    experiment_util.run_round(round_no, client_plane, server)

for round_no in range(WARMUP_ROUNDS + 1, WARMUP_ROUNDS + INSPECTION_ROUNDS + 1):
    experiment_util.run_round(round_no, client_plane, server)
    server.test()
    print("Server {}".format(round_no))
    server_shap = server.get_shap_values()
    visualizer.plot_shap_values(server_shap)
    # hand the server's current parameters to the clients before the comparison pass
    client_plane.update_clients(server.get_nn_parameters())
    for group, selected in (("Clean", clean_idx), ("Poisoned", poisoned_idx)):
        for client_idx in selected:
            print("Client {} {} (round {})".format(group, client_idx, round_no))
            client = client_plane.clients[client_idx]
            client.train(round_no)
            visualizer.compare_shap_values(client.get_shap_values(), server_shap)
    print("Round {} finished".format(round_no))

In [None]:
# Run a further batch of federated rounds.
# NOTE(review): round numbers restart at 1 here even though earlier cells have
# already run rounds; confirm whether run_round expects a global round counter.
ADDITIONAL_ROUNDS = 10
for round_no in range(1, ADDITIONAL_ROUNDS + 1):
    experiment_util.run_round(round_no, client_plane, server)

In [None]:
# Compare the SHAP values of the first N_CLEAN_TO_INSPECT clean clients with the
# server's after one more local training pass.
INSPECT_ROUND = 10  # round number passed to Client.train, as in the original cell
N_CLEAN_TO_INSPECT = 1

server_shap = server.get_shap_values()
client_plane.update_clients(server.get_nn_parameters())
for client_idx in clean_idx[:N_CLEAN_TO_INSPECT]:
    # label with the client index; the round number alone was ambiguous
    print("Client Clean {} (round {})".format(client_idx, INSPECT_ROUND))
    client = client_plane.clients[client_idx]
    client.train(INSPECT_ROUND)
    visualizer.compare_shap_values(client.get_shap_values(), server_shap)