In [1]:
import tensorflow as tf
import tensorflow.keras as tk
from chexpert_parser import load_dataset
# Keep TensorFlow from reserving all of the GPU memory up front.
# This has to run before the GPUs are initialized, i.e. right after the first import of tf.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    try:
        # Memory growth has to be configured identically on every visible GPU,
        # so it is enabled on all of them rather than only on the first one.
        for gpu in physical_devices:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as error:
        # Raised when the GPUs were already initialized (e.g. the cell was re-run):
        # the setting can no longer be changed, so report it and carry on.
        print(error)
else:
    print("No GPU found, model running on CPU")


In [23]:
class SplitModel():
    '''
    One participant of a Split Learning setup: either the server or a client.

    A client owns a "bottom" model (fed with the raw data) and optionally a "top" model
    (producing the predictions in the U topology). A server owns the section of the
    network that sits between the two split layers.
    '''
    def __init__(self, configuration, model_folder='./models/'):
        ''' 
            Represents an Entity participating in Split Learning (either a Server or a Client).

            configuration: dict describing the entity. Keys read here: 'type', 'name',
                'split_layer_top', 'split_layer_bottom', 'input_shape', 'output_shape', plus
                'server_model' when type is 'server', or 'model_bottom' and 'model_top' otherwise.
                Any 'type' other than 'server' is treated as a client.
            model_folder: stored on the instance; not used by the code in this class yet.
        '''
        self.model_folder = model_folder
        if configuration['type'] == 'server':
            self.is_server = True
            self.is_client = False
            self.base_server_model = configuration['server_model']

        else:
            self.is_server = False
            self.is_client = True
            self.base_model_bottom = configuration['model_bottom']
            self.base_model_top = configuration['model_top']

        self.split_layer_top = configuration['split_layer_top']
        self.split_layer_bottom = configuration['split_layer_bottom']
        self.name = configuration['name']

        self.input_shape = configuration['input_shape'] # Shape of input data
        self.output_shape = configuration['output_shape'] # Shape of predictions
        self.bottom_output_shape = None # Will be computed according to the model
        self.top_input_shape = None # Will be computed according to the model

        self.server_model = None # Model that is owned by the server
        self.model_bottom = None # Model that receives the data as input
        self.model_top = None # Optional: Model that produces the predictions in the U topology

        # Datasets are attached later by load_datasets()
        self.training_dataset = None
        self.validation_dataset = None
        self.testing_dataset = None
        self._training_iterator = None # Persistent iterator, so that every round reads a new batch

        self.current_epoch = 1
        self.current_batch = None # Stores the current batch that is used for training the model

    def _cutmodel(self, model_arch, cut_input=None, cut_output=None):
        ''' Given a base model (from tf.keras.applications.*) creates a new model 
            by truncating the original model between cut_input (included) and cut_output (excluded) layers id.
            If one of the cuts is None, then the cut will be considered as either the original input or output layer.

            model_arch: callable invoked as model_arch(input_shape=..., include_top=False, weights='imagenet').
            Returns the truncated model as a tk.models.Sequential. '''

        # FIXME: THIS WORKS ONLY WITH SEQUENTIAL MODELS - WE NEED TO FIND A WAY TO MAKE IT WORK WITH PRETRAINED NETWORKS
        base_model = model_arch(input_shape=self.input_shape, include_top=False, weights='imagenet')
        model = tk.models.Sequential() # TODO: Check if the result is actually the same with all common architectures
        for l, layer in enumerate(base_model.layers[cut_input:cut_output]):
            try:
                model.add(layer)
            except Exception:
                # Best effort: a layer that cannot be appended is reported and skipped.
                # Only regular errors are caught, so KeyboardInterrupt/SystemExit still propagate.
                print("{}:{} ERROR: {}".format(l, layer.name, layer.input_shape))
        # Pick the input shape for THIS model: the data shape when the cut starts at the
        # original input, otherwise the input shape reported by the first layer kept.
        # NOTE(review): self.input_shape carries no batch dimension while layer.input_shape
        # presumably does - confirm that build() accepts both forms.
        input_shape = self.input_shape if cut_input is None or cut_input == 0 else base_model.layers[cut_input].input_shape
        model.build(input_shape=input_shape)
        return model

    def setup(self, load_model='last', seed=1234567890):
        '''
        Setup the client or server according to the current configuration.
        - Loads the architecture(s) and optimizers, according to the provided split layers
        - TODO: Load a previous checkpoint (if any)
        - TODO: Initialize Tensorboard Writers

        load_model, seed: accepted but currently unused.
        '''
        # Loading and splitting the models

        if self.is_client:
            # We are a client. We have a bottom model:
            self.model_bottom = self._cutmodel(self.base_model_bottom, cut_input=None, cut_output=self.split_layer_bottom)

            # We can also have a top model:
            if self.base_model_top is not None:
                self.model_top = self._cutmodel(self.base_model_top, cut_input=self.split_layer_top, cut_output=None)

        if self.is_server:
            # We are the server. If the topology is not a U, split_layer_top will be None (since the server will send predictions)
            self.server_model = self._cutmodel(self.base_server_model, cut_input=self.split_layer_bottom, cut_output=self.split_layer_top)

        # TODO: Eventually move these into the config dict instead of hardcoding them here.
        self.loss_fn = tk.losses.BinaryCrossentropy()
        self.optimizer = tk.optimizers.SGD(1e-3)
        self.metrics = [tk.metrics.BinaryCrossentropy()]

        # TODO: Initialize tensorboard and model checkpointing

    def load_datasets(self, dataset_dict):
        '''
            Loads the datasets according to their presence in the provided dictionary.

            dataset_dict: maps 'training' / 'validation' / 'testing' to a dataset description
                (missing or None entries are skipped). Each description is handed to
                load_dataset() as-is. When the description is a dict holding a 'take_only'
                key, its value is forwarded as the `take` argument (presumably to truncate
                the dataset - confirm against chexpert_parser.load_dataset).
            The loaded datasets are stored as self.<split>_dataset.
        '''
        for split in ('training', 'validation', 'testing'):
            if split not in dataset_dict or dataset_dict[split] is None:
                continue
            spec = dataset_dict[split]
            # Only mappings can carry options; a plain path must not be searched for
            # the substring 'take_only'.
            take_only = spec.get('take_only') if isinstance(spec, dict) else None
            setattr(self, '{}_dataset'.format(split), load_dataset(spec, take=take_only))
            print("{}: Loaded {} dataset: {}".format(self.name, split, spec))
            if split == 'training':
                # A new training dataset invalidates the iterator of the previous one
                self._training_iterator = None

    def initiate_training(self):
        '''
        Initiate a new training round by providing the intermediate layer output for a new batch:
        The client privately selects a batch of data and sends the model output to the server
        Can only be called on a client.

        Returns a tuple (bottom model output for the batch, batch['y']).
        Every call reads the NEXT batch of the training dataset; when the dataset is
        exhausted it is restarted and self.current_epoch is incremented.
        Raises ValueError if no training dataset was loaded or if it yields no batch.
        '''
        assert self.is_client, "Servers cannot initiate training!"
        # TODO: Also implement for validation and testing
        if getattr(self, 'training_dataset', None) is None:
            raise ValueError("{}: no training dataset loaded, call load_datasets() first".format(self.name))

        # The iterator is kept between calls: building a new one every time would
        # restart from the beginning and return the first batch at each round.
        iterator = getattr(self, '_training_iterator', None)
        if iterator is None:
            iterator = iter(self.training_dataset)
        try:
            row = next(iterator)
        except StopIteration:
            # End of the data: start a new pass over the dataset
            iterator = iter(self.training_dataset)
            try:
                row = next(iterator)
            except StopIteration:
                raise ValueError("{}: the training dataset is empty".format(self.name)) from None
            self.current_epoch += 1
        self._training_iterator = iterator

        self.current_batch = row # We store the current batch for later.. when we have to train our section of the network
        return self.model_bottom(row['x'], training=True), row['y']

    def backpropagation(self, input_=None, gt=None, gradients=None, network_to_train='bottom'):
        '''
        Perform backpropagation given either an input tensor, a ground truth, or gradients for remote layers (depending on the context).
        input_: The input of the network. It's usually received from the server or the client, depending on the topology
        gt: It has a value only when the client bottom_model is sending the intermediate output to the server. It's necessary for the server to compute the loss in the I topology
        gradients: Gradients received from the remote side. Not used by any implemented branch yet.
        network_to_train: either 'bottom' or 'top', which network to train in the client, for the U topology. Ignored on a server.

        Returns the gradients of the loss with respect to the trainable weights of the network that was trained.
        Raises NotImplementedError for the client 'bottom' network, ValueError for an unknown network_to_train.
        '''
        if self.is_server:
            # We are receiving intermediate outputs from a client, we need to train the server and return the gradients
            with tf.GradientTape() as tape:
                predictions = self.server_model(input_, training=True)
                loss = self.loss_fn(gt, predictions)
            server_gradients = tape.gradient(loss, self.server_model.trainable_weights)
            # Train our section of the network
            self.optimizer.apply_gradients(zip(server_gradients, self.server_model.trainable_weights))
            # TODO: Log training accuracy
            # NOTE(review): these are gradients w.r.t. the server weights; a remote section
            # presumably needs the gradient w.r.t. input_ to continue the chain - confirm.
            return server_gradients
        elif self.is_client:
            if network_to_train == 'bottom':
                # TODO: Apply the gradients to bottom network using self.current_batch as target
                raise NotImplementedError("TODO: Implement backpropagation starting from gradients received by the server")
            elif network_to_train == 'top':
                with tf.GradientTape() as tape:
                    predictions = self.model_top(input_, training=True)
                    loss = self.loss_fn(self.current_batch['y'], predictions)
                # Differentiate w.r.t. the network that produced the predictions:
                # a client owns no server_model (it is None here).
                top_gradients = tape.gradient(loss, self.model_top.trainable_weights)
                # Train our section of the network, mirroring the server branch
                self.optimizer.apply_gradients(zip(top_gradients, self.model_top.trainable_weights))
                return top_gradients
            else:
                raise ValueError("Unknown network_to_train: {!r} (expected 'bottom' or 'top')".format(network_to_train))

    def forward_pass(self, intermediate_output_tensor):
        '''
        Run the received intermediate output through our section of the network.
        On a server: returns the server model output (computed with training=True).
        On a client: returns the predictions of the top model (computed with training=False).
        '''
        if self.is_server:
            # We are receiving intermediate outputs from a client, we need to provide the model output
            return self.server_model(intermediate_output_tensor, training=True)
        elif self.is_client:
            # We received an intermediate_output from the server for calculating predictions
            predictions = self.model_top(intermediate_output_tensor, training=False)
            # TODO: Do something more with the predictions (e.g. metrics); for now hand them back
            return predictions
            
        

In [24]:
class SplitTraining():
    '''
    Orchestrates a Split Learning run: builds one server and any number of clients
    from a list of configurations and drives the training rounds for a given topology.
    '''
    def __init__(self, configuration, topology, model_folder='./models/'):
        ''' Initialize the Server/Clients given a Server/Client definition and a topology. Orchestrates the training process and the aggregation according to the topology

            configuration: iterable of dicts, one per entity; each is forwarded to SplitModel.
            topology: 'I', 'U' or 'V' (see split_training).
            model_folder: forwarded to every SplitModel.
        '''
        self.model_folder = model_folder
        self.configuration = configuration
        self.topology = topology
        self.server = None
        self.clients = dict()

    def setup(self):
        '''
        Create the server and the clients described by the configuration and build their models.
        Entries whose 'type' is neither 'server' nor 'client' are ignored.
        Raises ValueError when the configuration holds no server entry.
        '''
        # Creating the server and the clients
        for config in self.configuration:
            if config['type'] == 'server':
                self.server = SplitModel(configuration=config, model_folder=self.model_folder)
            elif config['type'] == 'client':
                self.clients[config['name']] = SplitModel(configuration=config, model_folder=self.model_folder)
        # Count the server only when one was actually created
        print("Created {} server and {} clients".format(int(self.server is not None), len(self.clients)))
        if self.server is None:
            raise ValueError("The configuration does not define any entity of type 'server'")

        print("Building the models...")
        self.server.setup()
        for c_name, client in self.clients.items():
            print("Building {}...".format(c_name))
            client.setup()

    def print_config(self):
        ''' Print the summary of the server model and of every client model (bottom, and top when present). '''
        print("Server Model:")
        self.server.server_model.summary()
        for c_name, client in self.clients.items():
            print("{}: Bottom Model".format(c_name))
            client.model_bottom.summary()
            if client.model_top is not None:
                print("{}: Top Model".format(c_name))
                client.model_top.summary()

    def load_datasets(self, datasets_dict):
        ''' Attach the datasets to the clients. datasets_dict maps a client name to the dict handed to that client's load_datasets(). '''
        for client_name, datasets in datasets_dict.items():
            self.clients[client_name].load_datasets(datasets)

    def split_training(self):
        '''
        Run one training round for every client, one client at a time.

        'I': the client sends its intermediate output and labels to the server, which trains
             and returns its gradients; these are then handed to the client bottom model.
        'U': the server only forwards; the client top model computes the loss, and the
             gradients travel back through the server to the client bottom model.
        'V': not implemented (raises NotImplementedError).
        Raises ValueError for any other topology.
        '''
        # TODO: Eventually define more complex strategies, like:
        # Client selection, Output/Gradient aggregation, etc.
        # For the moment we train one client at a time

        for current_client_name, current_client in self.clients.items():
            # The client provides its intermediate output
            client_intermediate_output, client_gt = current_client.initiate_training()
            print("Client {} sending a tensor with shape: {}".format(current_client_name, client_intermediate_output.shape))

            if self.topology == 'I':
                # Sending the intermediate output to the server, which trains its section of the network and return its gradient
                server_gradient = self.server.backpropagation(input_=client_intermediate_output, gt=client_gt)
                # Sending the server gradient back to the client, which trains its section of the network
                current_client.backpropagation(gradients=server_gradient, network_to_train='bottom')

            elif self.topology == 'U':
                # Sending the intermediate output to the server, which returns its intermediate output
                server_intermediate_output = self.server.forward_pass(client_intermediate_output)
                # Sending the server output to the client's top model to compute the gradient
                client_top_gradient = current_client.backpropagation(input_=server_intermediate_output, network_to_train='top')
                # Use the gradient to train the server model, obtaining the server gradient
                # NOTE(review): the server branch of SplitModel.backpropagation trains from
                # input_/gt and does not read `gradients`, so this step needs server-side
                # support before the U topology can actually run.
                server_gradient = self.server.backpropagation(gradients=client_top_gradient)
                # Use the server gradient to train the client bottom model
                current_client.backpropagation(gradients=server_gradient, network_to_train='bottom')

            elif self.topology == 'V':
                # TODO: Implement the aggregation for V topology
                raise NotImplementedError("V Topology not implemented yet")

            else:
                raise ValueError("Unknown topology: {!r} (expected 'I', 'U' or 'V')".format(self.topology))
            
        

In [25]:
class DummyModel(tk.Model):
    ''' Placeholder network, used until pre-trained models can be split correctly.

        Stacks three Conv2D + BatchNormalization blocks, a global average pooling and a
        14-unit sigmoid Dense layer. The constructor accepts the keyword arguments that
        the callers pass to a model architecture (input_shape, include_top, weights);
        include_top and weights are accepted but not used.
    '''
    def __init__(self, input_shape,  include_top=False, weights='imagenet'):
        model_input = tk.layers.Input(shape=input_shape)

        # Chain the blocks, each one consuming the output of the previous one
        features = model_input
        for block_index in range(3):
            conv_layer = tk.layers.Conv2D(3, 3, activation='relu', name='block_{}_conv'.format(block_index))
            features = conv_layer(features)
            bn_layer = tk.layers.BatchNormalization(name='block_{}_bn'.format(block_index))
            features = bn_layer(features)

        pooled = tk.layers.GlobalAveragePooling2D()(features)
        model_output = tk.layers.Dense(14, activation='sigmoid')(pooled)

        # Functional-style construction: the graph is defined by its input and output tensors
        super().__init__(inputs=model_input, outputs=model_output)

In [29]:
# Testing

from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
from tensorflow.keras.applications.densenet import DenseNet121

# Definitions:
# Model Bottom: the model of the client that receives the data as input.
# Model Top: the last part of the network (the one that gives predictions). It is missing in the I and V topologies, as it is part of the server model.

# Topology (from https://arxiv.org/pdf/1812.00564.pdf):
# I: Vanilla Split Learning - It is implemented as a U model with a missing top (i.e. predictions are always sent to the client that made the request.)
# U: SL Without Label Sharing
# V: SL for Vertically Partitioned Data
BASE_DATASET_PATH = './datasets/CheXpertFederated/'
# Test with Densenet (53, 313)
SL_BOTTOM = 3
SL_TOP = None
input_shape = (224, 224, 3) # Shape of the input of client models (data shape). Can be different for each client.
output_shape = (14,) # Shape of the outputs. Can be different for each client when using the U topology.
client_model_top = None
client_model_bottom = DummyModel
server_model = DummyModel
NUM_CLIENTS = 3 # Clients are named client_1 .. client_N; client_i reads the dataset partition i-1


# One server entry followed by one entry per client. All the clients share the same
# settings, so their entries are generated instead of being written out one by one.
configuration = [
    {'type': 'server', 'name': 'server', 'server_model': server_model, 'split_layer_top': SL_TOP, 'split_layer_bottom': SL_BOTTOM, 'input_shape': input_shape, 'output_shape': output_shape},
] + [
    {'type': 'client', 'name': 'client_{}'.format(client_id), 'model_bottom': client_model_bottom, 'split_layer_bottom': SL_BOTTOM, 'model_top': client_model_top, 'split_layer_top': SL_TOP, 'input_shape': input_shape, 'output_shape': output_shape}
    for client_id in range(1, NUM_CLIENTS + 1)
]

# Each client trains on its own partition of the training set; no validation/testing data yet.
split_datasets = {
    'client_{}'.format(client_id): {
        'training': BASE_DATASET_PATH + 'train_def-part-{}.tfrecord'.format(client_id - 1),
        'validation': None,
        'testing': None,
    }
    for client_id in range(1, NUM_CLIENTS + 1)
}

splitprocess = SplitTraining(configuration, topology='I')


In [30]:
# Create the server and the clients from the configuration and build their models.
splitprocess.setup()
# Uncomment to print the summary of every model:
# splitprocess.print_config()
# Attach the datasets to the clients, then run one training round per client.
splitprocess.load_datasets(split_datasets)
splitprocess.split_training()

Created 1 server and 3 clients
Building the models...
Building client_1...
Building client_2...
Building client_3...
client_1: Loaded training dataset: ./datasets/CheXpertFederated/train_def-part-0.tfrecord
client_2: Loaded training dataset: ./datasets/CheXpertFederated/train_def-part-1.tfrecord
client_3: Loaded training dataset: ./datasets/CheXpertFederated/train_def-part-2.tfrecord
