In [1]:
from federated_learning.utils import SHAPUtil
from federated_learning import ClientPlane, Configuration

In [2]:
class ObserverConfiguration:
    """
    Static settings describing this experiment run; an instance is handed to
    ClientPlane. Every value is a class attribute, so instances simply expose
    these constants.
    """

    # --- Experiment identification ---
    experiment_type = "shap_clean_run"
    experiment_id = 0
    test = True
    dataset_type = "MNIST"

    # --- Client configuration ---
    client_name = "client"
    client_type = "client"

In [3]:
# Shared project configuration and the dataset it selects.
config = Configuration()
data = config.DATASET(config)

# Experiment settings passed to the client plane (plain constants, no side effects).
observer_config = ObserverConfiguration()

# SHAP helper bound to the test split; shared by the server and the client plane.
shap_util = SHAPUtil(data.test_dataloader)

MNIST training data loaded.
MNIST test data loaded.


In [4]:
class ModelAggregator():
    """Combine per-client model parameters into one set of server parameters."""

    def model_avg(self, parameters, weights=None):
        """
        Average client state dicts parameter-by-parameter (FedAvg).

        :param parameters: one state dict per client; the keys of
            ``parameters[0]`` determine which entries are averaged
        :type parameters: list[dict]
        :param weights: optional per-client weights (e.g. local dataset sizes).
            ``None`` (default) gives every client equal weight, i.e. the plain mean.
        :type weights: list[float] or None
        :return: dict mapping each parameter name to its (weighted) average
        :rtype: dict
        :raises ValueError: if ``parameters`` is empty, if ``weights`` does not
            have one entry per client, or if the weights sum to zero
        """
        if not parameters:
            raise ValueError("model_avg needs at least one set of client parameters")

        names = parameters[0].keys()
        if weights is None:
            # Unweighted mean; `.data` reads the raw tensor without autograd history.
            count = len(parameters)
            return {
                name: sum(param[name].data for param in parameters) / count
                for name in names
            }

        if len(weights) != len(parameters):
            raise ValueError(
                "model_avg got {} weights for {} clients".format(len(weights), len(parameters))
            )
        total = float(sum(weights))
        if total == 0:
            raise ValueError("model_avg weights must not sum to zero")
        return {
            name: sum(w * param[name].data for w, param in zip(weights, parameters)) / total
            for name in names
        }

In [5]:
from numpy.random import default_rng
class ClientSelector():
    """Choose which clients participate in each federated training round."""

    def __init__(self, seed=None):
        """
        :param seed: optional seed for the random generator. ``None`` (default)
            draws fresh OS entropy, matching the previous unseeded behaviour;
            pass an int to make client selection reproducible.
        """
        # One generator per selector: re-creating it on every call made runs
        # impossible to reproduce, even when a seed was wanted.
        self.rng = default_rng(seed)

    def random_selector(self, number_of_clients, clients_per_round):
        """
        Sample distinct client indices uniformly at random.

        :param number_of_clients: total number of clients; indices are drawn
            from ``range(number_of_clients)``
        :param clients_per_round: how many distinct clients to pick
        :return: numpy array of ``clients_per_round`` distinct client indices
        :raises ValueError: (from numpy) if ``clients_per_round`` exceeds
            ``number_of_clients``
        """
        return self.rng.choice(number_of_clients, size=clients_per_round, replace=False)


In [9]:
import torch
from pathlib import Path
import os
class Server():
    """
    Federated-learning server: owns the global model, selects clients each
    round, averages their parameters into the global model, and computes SHAP
    explanations for it.
    """

    # Print an aggregation confirmation every LOG_EVERY rounds.
    LOG_EVERY = 50

    def __init__(self, config, shap_util):
        """
        :param config: project configuration; this class reads TEMP, MODELNAME,
            NETWORK, NUMBER_OF_CLIENTS and CLIENTS_PER_ROUND from it
        :param shap_util: SHAPUtil used to build the explainer and compute SHAP values
        """
        self.config = config
        self.default_model_path = os.path.join(self.config.TEMP, 'models', "{}.model".format(self.config.MODELNAME))
        self.net = self.load_default_model()
        self.aggregator = ModelAggregator()
        self.selector = ClientSelector()
        self.shap_util = shap_util
        self.rounds = 0
        # SHAP explainer, built lazily by get_shap_values() / set_explainer().
        self.e = None

    def set_rounds(self, rounds):
        """Record the current (0-based) round index; used for progress logging."""
        self.rounds = rounds

    def create_default_model(self):
        """Save the current network's state dict to ``self.default_model_path``."""
        Path(os.path.dirname(self.default_model_path)).mkdir(parents=True, exist_ok=True)
        torch.save(self.net.state_dict(), self.default_model_path)
        print("default model saved to:{}".format(os.path.dirname(self.default_model_path)))

    def load_default_model(self):
        """
        Build a ``config.NETWORK`` and load the saved state dict into it if one
        exists at ``self.default_model_path``.

        If the file is missing or cannot be loaded, the freshly initialised
        network is returned instead. Previously both of those paths crashed
        with an UnboundLocalError.

        :return: the network, switched to eval mode
        """
        model = self.config.NETWORK()
        if not os.path.exists(self.default_model_path):
            print("Could not find model: {}".format(self.default_model_path))
        else:
            try:
                # map_location lets a checkpoint saved on GPU load on a CPU-only host.
                state_dict = torch.load(self.default_model_path, map_location="cpu")
                model.load_state_dict(state_dict)
                print("Load model successfully")
            except Exception as exc:
                print("Couldn't load model: {}".format(exc))
                # A failed strict load may have copied some tensors; start clean.
                model = self.config.NETWORK()
        model.eval()
        return model

    def get_nn_parameters(self):
        """
        Return the NN's parameters.
        """
        return self.net.state_dict()

    def update_nn_parameters(self, new_params):
        """
        Update the NN's parameters.

        :param new_params: New weights for the neural network
        :type new_params: dict
        """
        self.net.load_state_dict(new_params, strict=True)

    def select_clients(self):
        """Randomly pick CLIENTS_PER_ROUND distinct indices out of NUMBER_OF_CLIENTS."""
        return self.selector.random_selector(self.config.NUMBER_OF_CLIENTS, self.config.CLIENTS_PER_ROUND)

    def aggregate_model(self, client_parameters):
        """
        Average the clients' parameters and load the result into the global model.

        :param client_parameters: one state dict per participating client
        """
        new_parameters = self.aggregator.model_avg(client_parameters)
        self.update_nn_parameters(new_parameters)
        # self.rounds is 0-based; report the 1-based round number.
        if (self.rounds + 1) % self.LOG_EVERY == 0:
            print("Model aggregation in round {} was successful".format(self.rounds + 1))

    def get_shap_values(self):
        """
        Calculate SHAP values and SHAP image predictions for the global model.

        Builds the explainer on first use and caches it in ``self.e``; results
        are stored in ``self.shap_values`` and ``self.shap_prediction``.
        """
        if self.e is None:
            self.e = self.shap_util.set_deep_explainer(self.net)
        self.shap_values = self.shap_util.get_shap_values(self.e)
        self.shap_prediction = self.shap_util.predict(self.net)

    def set_explainer(self):
        """(Re)build the SHAP explainer for the current global model into ``self.e``."""
        # NOTE(review): get_shap_values() builds the explainer with
        # shap_util.set_deep_explainer, but this calls shap_util.deep_explainer;
        # confirm which method SHAPUtil actually exposes.
        self.e = self.shap_util.deep_explainer(self.net)

In [None]:
server = Server(config, shap_util)
client_plane = ClientPlane(config, observer_config, data, shap_util)

# Number of federated training rounds to run.
NUM_ROUNDS = 200


def run_round():
    """
    Run one federated round: push the server's current weights to the clients,
    train a randomly selected subset, then average their returned parameters
    into the server's global model.
    """
    client_plane.update_clients(server.get_nn_parameters())
    selected_clients = server.select_clients()
    client_parameters = client_plane.train_selected_clients(selected_clients)
    server.aggregate_model(client_parameters)


for round_idx in range(NUM_ROUNDS):
    # Keep both sides informed of the current 0-based round index.
    client_plane.set_rounds(round_idx)
    server.set_rounds(round_idx)
    run_round()

Load model successfully
Create 200 clients with dataset of size 300
Model aggregation in round 49 was successful
Model aggregation in round 99 was successful
Model aggregation in round 149 was successful


In [None]:
# Compute SHAP values and SHAP image predictions for the trained global model;
# Server.get_shap_values stores them as server.shap_values and server.shap_prediction.
server.get_shap_values()