In [1]:
import os
import random
from datetime import datetime

import numpy as np

from federated_learning import ClientPlane, Configuration, ObserverConfiguration
from federated_learning.server import Server
from federated_learning.utils import SHAPUtil, experiment_util

In [2]:
import shap
class Visualizer:
    """Plotting helpers for inspecting images and their SHAP attributions."""

    @staticmethod
    def _to_numpy(array_like):
        """
        Convert a tensor-like object to a NumPy array.

        Objects exposing ``detach`` (torch-style tensors) are detached and moved to
        the CPU before ``.numpy()`` so GPU tensors and tensors that require grad
        work; anything else goes through ``np.asarray``.
        """
        if hasattr(array_like, "detach"):
            return array_like.detach().cpu().numpy()
        return np.asarray(array_like)

    @staticmethod
    def _nchw_to_nhwc(batch):
        """
        Move the channel axis of a 4-D batch from position 1 to the end.

        ``(N, C, H, W) -> (N, H, W, C)``, i.e. channels-last layout for plotting.
        """
        return np.swapaxes(np.swapaxes(batch, 1, -1), 1, 2)

    def plot_shap_images(self, shap_indices, shap_images, targets=None):
        """
        Plot selected sample images in a grid, titled with their target labels.

        :param shap_indices: indices into ``shap_images`` (and ``targets``) to plot
        :type shap_indices: iterable of int
        :param shap_images: image batch; channel 0 of ``shap_images[idx]`` is shown
            (assumes channels-first images — TODO confirm)
        :type shap_images: Tensor or array-like
        :param targets: labels aligned with ``shap_images``; falls back to a
            ``targets`` attribute on the instance. Without either, titles are omitted.
        :type targets: sequence, optional
        """
        import matplotlib.pyplot as plt

        if targets is None:
            targets = getattr(self, "targets", None)

        indices = list(shap_indices)
        ncols = 4
        # Keep the original 3x4 layout for up to 12 images; add rows beyond that.
        nrows = max(3, (len(indices) + ncols - 1) // ncols)

        fig = plt.figure()
        for pos, idx in enumerate(indices, start=1):
            ax = fig.add_subplot(nrows, ncols, pos)
            ax.imshow(self._to_numpy(shap_images[idx][0]), cmap='gray', interpolation='none')
            if targets is not None:
                ax.set_title("Ground Truth: {}".format(targets[idx]))
            ax.set_xticks([])
            ax.set_yticks([])
        fig.tight_layout()
        plt.show()

    def plot_shap_values(self, shap_images, shap_values, file=None):
        """
        Plot SHAP attributions alongside the images they explain.

        :param shap_images: batch of explained images, channels-first
            (presumably (N, C, H, W) — TODO confirm)
        :type shap_images: Tensor or array-like
        :param shap_values: attribution arrays, each laid out like ``shap_images``
            (presumably one per output class, as returned by a SHAP explainer)
        :type shap_values: list of Tensor / ndarray
        :param file: if truthy, save the figure to this path (creating parent
            directories) instead of displaying it
        :type file: str or os.PathLike
        """
        import matplotlib.pyplot as plt

        shap_numpy = [self._nchw_to_nhwc(self._to_numpy(s)) for s in shap_values]
        test_numpy = self._nchw_to_nhwc(self._to_numpy(shap_images))
        # Pixel values are negated before plotting, as in the original code
        # (NOTE(review): presumably to invert the dark MNIST background — confirm).
        if not file:
            shap.image_plot(shap_numpy, -test_numpy)
            return

        shap.image_plot(shap_numpy, -test_numpy, show=False)
        directory = os.path.dirname(file)
        # dirname is "" for bare filenames; os.makedirs("") would raise.
        if directory:
            os.makedirs(directory, exist_ok=True)
        plt.savefig(file)
        # Release the figure; otherwise repeated saves accumulate open figures.
        plt.close()

## Configurations

In [3]:
# Seed the RNGs before any data loading / client sampling so the selected
# clients and poisoned subset are reproducible across Restart & Run All.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
# NOTE(review): if the project trains with torch, also call torch.manual_seed(RANDOM_SEED).

config = Configuration()
config.FROM_LABEL = 5        # label that poisoned clients flip ...
config.TO_LABEL = 4          # ... into this label ("Flip ... of the 5 labels to 4")
config.POISONED_CLIENTS = 10  # number of poisoned clients ("Poison 10/200 clients")
data = config.DATASET(config)
shap_util = SHAPUtil(data.test_dataloader)
observer_config = ObserverConfiguration()

MNIST training data loaded.
MNIST test data loaded.


## Experimental Setup

In [4]:
# Build the server (given the test loader) and the client plane (given the full
# dataset); the server is constructed first, as before.
server, client_plane = (
    Server(config, observer_config, data.test_dataloader, shap_util),
    ClientPlane(config, observer_config, data, shap_util),
)

# Apply label flipping to the configured subset of clients
# (see FROM_LABEL / TO_LABEL / POISONED_CLIENTS in the configuration cell).
client_plane.poison_clients()

# Pick the clean and poisoned clients whose models are tracked in the experiment.
clean_idx, poisoned_idx = (
    experiment_util.select_random_clean(1, config, client_plane),
    experiment_util.select_poisoned(1, client_plane),
)
print(clean_idx, poisoned_idx, sep=" ")

Load model successfully
Create 200 clients with dataset of size 300
Poison 10/200 clients
Flip 100.0% of the 5 labels to 4
[69, 193, 0, 159, 113, 74, 133, 66, 89, 58] [135 197  40 108 171  65  39 147 116  90]


## Experiment

In [None]:
N_ROUNDS = 25           # total federated training rounds
EVAL_EVERY = 5          # test the server and retrain tracked clients every N rounds
N_TRACKED_PAIRS = 1     # how many (clean, poisoned) client pairs to retrain at each checkpoint

for round_no in range(1, N_ROUNDS + 1):
    experiment_util.run_round(round_no, client_plane, server)
    if round_no % EVAL_EVERY == 0:
        server.test()
        # Push the server's current network parameters to the clients.
        client_plane.update_clients(server.get_nn_parameters())
        # A distinct loop variable: the original reused `i` here, which clobbered
        # the round counter so checkpoint rounds printed "Round 1 finished".
        for pair_no in range(1, N_TRACKED_PAIRS + 1):
            # NOTE(review): the second argument is the pair number (as in the
            # original, always 1), not the round number — confirm which was intended.
            experiment_util.train_client(clean_idx[pair_no - 1], pair_no, client_plane)
            experiment_util.train_client(poisoned_idx[pair_no - 1], pair_no, client_plane)
    print("Round {} finished".format(round_no))