In [None]:
from utils.Encoder import Encoder
from utils.Classifier import Classifier
from utils.Detector import Detector
from utils.Segmenter import Segmenter

In [None]:
import keras.backend as K
from keras.layers import Layer, Dense, TimeDistributed, Concatenate, InputSpec,  RNN
from keras.layers.wrappers import Wrapper
import numpy as np
import tensorflow as tf

class ScaledDotProductAttention(Layer):
    """Scaled dot-product attention as a Keras layer.

    Implementation according to:
        "Attention is all you need" by A Vaswani, N Shazeer, N Parmar (2017)

    The layer is called on a list ``[Q, K, V]`` of rank-3 tensors and computes
    ``softmax(Q . K^T / sqrt(d_k)) . V``.

    # Arguments
        return_attention: if True the layer returns ``[output, attention_weights]``
            instead of only ``output``.

    # Input shapes (enforced by ``_validate_input_shape``)
        Q: (batch, T, d_k), K: (batch, T, d_k), V: (batch, T, d_v) -- all three
        inputs share batch size and length; Q and K share the feature size.

    # Output shapes
        output: (batch, T_q, d_v); optional attention weights: (batch, T_q, T_k).
    """

    def __init__(self, return_attention=False, **kwargs):
        super(ScaledDotProductAttention, self).__init__(**kwargs)
        self._return_attention = return_attention
        # Set after Layer.__init__, which resets supports_masking to False.
        self.supports_masking = True

    def compute_output_shape(self, input_shape):
        self._validate_input_shape(input_shape)
        q_shape, k_shape, v_shape = input_shape

        output_shape = (q_shape[0], q_shape[1], v_shape[2])
        if not self._return_attention:
            return output_shape
        # one weight per (query position, key position) pair
        return [output_shape, (q_shape[0], q_shape[1], k_shape[1])]

    def _validate_input_shape(self, input_shape):
        """Check that ``input_shape`` describes three compatible inputs ``[Q, K, V]``.

        # Raises
            ValueError: if there are not exactly three inputs, if batch sizes or
                lengths differ, or if Q and K (contracted by the dot product)
                have different feature sizes.
        """
        if len(input_shape) != 3:
            raise ValueError("Layer received an input shape {0} but expected three inputs (Q, K, V).".format(input_shape))
        q_shape, k_shape, v_shape = input_shape
        if q_shape[0] != k_shape[0] or k_shape[0] != v_shape[0]:
            raise ValueError("All three inputs (Q, K, V) have to have the same batch size; received batch sizes: {0}, {1}, {2}".format(q_shape[0], k_shape[0], v_shape[0]))
        if q_shape[1] != k_shape[1] or k_shape[1] != v_shape[1]:
            raise ValueError("All three inputs (Q, K, V) have to have the same length; received lengths: {0}, {1}, {2}".format(q_shape[1], k_shape[1], v_shape[1]))
        if q_shape[2] != k_shape[2]:
            raise ValueError("Input shapes of Q {0} and K {1} do not match.".format(q_shape, k_shape))

    def build(self, input_shape):
        self._validate_input_shape(input_shape)

        super(ScaledDotProductAttention, self).build(input_shape)

    @staticmethod
    def _select_mask(mask):
        """Reduce ``mask`` to a single tensor or None.

        When the layer is called on a list, Keras passes one mask per input
        (often all None); missing entries are ignored.

        # Raises
            ValueError: if more than one non-None mask is supplied.
        """
        if isinstance(mask, (list, tuple)):
            present = [m for m in mask if m is not None]
            if len(present) > 1:
                raise ValueError("mask can only be a Tensor or a list of length 1 containing a tensor.")
            return present[0] if present else None
        return mask

    def call(self, x, mask=None):
        q, k, v = x
        d_k = q.shape.as_list()[2]

        # in pure tensorflow:
        # weights = tf.matmul(q, tf.transpose(k, perm=[0, 2, 1]))
        # normalized_weights = tf.nn.softmax(weights / sqrt(d_k))
        # output = tf.matmul(normalized_weights, v)

        # (batch, T_q, d_k) . (batch, T_k, d_k) -> (batch, T_q, T_k)
        weights = K.batch_dot(q, k, axes=[2, 2])

        mask = self._select_mask(mask)
        if mask is not None:
            # boolean masks cannot take part in arithmetic directly
            mask = K.cast(mask, K.floatx())
            if K.ndim(mask) == 2:
                # a (batch, T_k) key mask applies to every query position
                mask = K.expand_dims(mask, axis=1)
            # masked positions get a huge negative logit -> ~0 weight after softmax
            weights += -1e10 * (1.0 - mask)

        normalized_weights = K.softmax(weights / np.sqrt(d_k))
        # (batch, T_q, T_k) . (batch, T_v, d_v) -> (batch, T_q, d_v)
        output = K.batch_dot(normalized_weights, v, axes=[2, 1])

        if self._return_attention:
            return [output, normalized_weights]
        return output

    def get_config(self):
        config = {'return_attention': self._return_attention}
        base_config = super(ScaledDotProductAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

class MultiHeadAttention():
    """Multi-head attention assembled from per-head projections and ScaledDotProductAttention.

    Implementation according to:
        "Attention is all you need" by A Vaswani, N Shazeer, N Parmar (2017)

    This is a plain callable, not a Keras ``Layer``. Calling it on ``[Q, K, V]``
    (re)creates the projection sub-layers in ``build`` and wires them into the
    graph. Every call therefore creates *new* sub-layers, so two calls do not
    share weights.

    # Arguments
        h: number of heads (int >= 2).
        d_k: size of each head's Q/K projection; defaults to K's feature size.
        d_v: size of each head's V projection; defaults to V's feature size.
        d_model: size of the final output projection; defaults to K's feature size.
        activation: activation of the per-head Q/K/V projections (None = linear).
        return_attention: if True, calling returns ``[output, attention]``, where
            ``attention`` is the per-head attention maps concatenated on the last axis.
    """

    def __init__(self, h, d_k=None, d_v=None, d_model=None, activation=None, return_attention=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)

        if type(h) is not int or h < 2:
            raise ValueError("You have to set `h` to an int >= 2.")
        self._h = h

        if d_model and (type(d_model) is not int or d_model < 1):
            raise ValueError("You have to set `d_model` to an int >= 1.")
        self._d_model = d_model

        if d_k and (type(d_k) is not int or d_k < 1):
            raise ValueError("You have to set `d_k` to an int >= 1.")
        self._d_k = d_k

        if d_v and (type(d_v) is not int or d_v < 1):
            raise ValueError("You have to set `d_v` to an int >= 1.")
        self._d_v = d_v

        # honour the constructor argument (it used to be silently replaced by None)
        self._activation = activation
        self._return_attention = return_attention

    @staticmethod
    def _dim_value(dim):
        """Unwrap a TF1 ``tf.Dimension`` to an int/None; plain ints pass through (TF2-safe)."""
        return getattr(dim, "value", dim)

    def compute_output_shape(self, input_shape):
        self._validate_input_shape(input_shape)

        batch = self._dim_value(input_shape[0][0])
        t_q = self._dim_value(input_shape[0][1])
        t_k = self._dim_value(input_shape[1][1])
        d_model = self._d_model if self._d_model else self._dim_value(input_shape[1][-1])

        output_shape = (batch, t_q, d_model)
        if self._return_attention:
            # h attention maps of shape (batch, T_q, T_k) concatenated on the last axis
            attention_width = None if t_k is None else self._h * t_k
            return [output_shape, (batch, t_q, attention_width)]
        return output_shape

    def _validate_input_shape(self, input_shape):
        """Check that ``input_shape`` describes three compatible inputs ``[Q, K, V]``.

        # Raises
            ValueError: on a wrong number of inputs, mismatching batch sizes or
                lengths, or different feature sizes of Q and K.
        """
        if len(input_shape) != 3:
            raise ValueError("Layer received an input shape {0} but expected three inputs (Q, K, V).".format(input_shape))
        q_shape, k_shape, v_shape = input_shape
        if q_shape[0] != k_shape[0] or k_shape[0] != v_shape[0]:
            raise ValueError("All three inputs (Q, K, V) have to have the same batch size; received batch sizes: {0}, {1}, {2}".format(q_shape[0], k_shape[0], v_shape[0]))
        if q_shape[1] != k_shape[1] or k_shape[1] != v_shape[1]:
            raise ValueError("All three inputs (Q, K, V) have to have the same length; received lengths: {0}, {1}, {2}".format(q_shape[1], k_shape[1], v_shape[1]))
        if q_shape[2] != k_shape[2]:
            raise ValueError("Input shapes of Q {0} and K {1} do not match.".format(q_shape, k_shape))

    def _projection(self, units):
        """Bias-free, time-distributed projection of every time step to ``units`` features."""
        return TimeDistributed(Dense(units, activation=self._activation, use_bias=False))

    def build(self, input_shape):
        self._validate_input_shape(input_shape)

        k_dim = self._dim_value(input_shape[1][-1])
        v_dim = self._dim_value(input_shape[2][-1])
        d_k = self._d_k if self._d_k else k_dim
        d_model = self._d_model if self._d_model else k_dim
        # without this default, an unset d_v reached Dense(None) and failed
        d_v = self._d_v if self._d_v else v_dim

        self._q_layers = []
        self._k_layers = []
        self._v_layers = []
        self._sdp_layer = ScaledDotProductAttention(return_attention=self._return_attention)

        # create Q, K, V projections head by head (keeps auto-generated layer names stable)
        for _ in range(self._h):
            self._q_layers.append(self._projection(d_k))
            self._k_layers.append(self._projection(d_k))
            self._v_layers.append(self._projection(d_v))

        self._output = TimeDistributed(Dense(d_model))

    def __call__(self, x, mask=None):
        if isinstance(x, (list, tuple)):
            self.build([it.shape for it in x])
        else:
            self.build(x.shape)

        q, k, v = x

        outputs = []
        attentions = []
        for q_layer, k_layer, v_layer in zip(self._q_layers, self._k_layers, self._v_layers):
            head_inputs = [q_layer(q), k_layer(k), v_layer(v)]
            if self._return_attention:
                output, attention = self._sdp_layer(head_inputs, mask=mask)
                attentions.append(attention)
            else:
                output = self._sdp_layer(head_inputs, mask=mask)
            outputs.append(output)

        output = self._output(Concatenate()(outputs))

        if self._return_attention:
            return [output, Concatenate()(attentions)]
        return output

# https://wanasit.github.io/attention-based-sequence-to-sequence-in-keras.html
# https://arxiv.org/pdf/1508.04025.pdf
class SequenceAttention(Layer):
    """
        Takes two inputs of the shape (batch_size, T, dim1) and (batch_size, T, dim2),
        whereby the first item is the source data and the second one the key (query) data.
        For each batch element and each query time step i, this layer computes a softmax
        attention vector over the source time steps j from similarity[b, i, j]. The source
        data is then averaged with these weights, so the key data is used to interpret the
        source data and the output has the same shape as the source data.

        # Arguments
            similarity: "additive", "multiplicative", or a callable
                ``f(source, query) -> (batch_size, T, T)`` tensor of scores.
                Note: the built-in "additive" score only combines source and query
                at the same time step, so its scores do not vary with i.
            kernel_initializer: initializer for the attention weights.
    """

    ALLOWED_SIMILARITIES = ["additive", "multiplicative"]

    def __init__(self, similarity, kernel_initializer="glorot_uniform", **kwargs):
        super(SequenceAttention, self).__init__(**kwargs)
        if isinstance(similarity, str):
            if similarity not in self.ALLOWED_SIMILARITIES:
                raise ValueError("`similarity` has to be either a callable or one of the following: {0}".format(self.ALLOWED_SIMILARITIES))
            self._similarity = getattr(self, "_" + similarity + "_similarity")
        elif callable(similarity):
            self._similarity = similarity
        else:
            # ALLOWED_SIMILARITIES is a class attribute so this branch no longer raises NameError
            raise ValueError("`similarity` has to be either a callable or one of the following: {0}".format(self.ALLOWED_SIMILARITIES))

        # keep the constructor argument so get_config() can be fed back to from_config()
        self._similarity_spec = similarity
        self._kernel_initializer = kernel_initializer

    def build(self, input_shape):
        super(SequenceAttention, self).build(input_shape)
        self._validate_input_shape(input_shape)

        self._weights = {}
        if self._similarity == self._additive_similarity:
            # maps concat([source_t, query_t]) to the source feature size
            self._weights["w_a"] = self.add_weight(
                name='w_a',
                shape=(input_shape[0][-1] + input_shape[1][-1], input_shape[0][-1]),
                initializer=self._kernel_initializer,
                trainable=True
            )

            self._weights["v_a"] = self.add_weight(
                name='v_a',
                shape=(1, input_shape[0][-1]),
                initializer=self._kernel_initializer,
                trainable=True
            )

        elif self._similarity == self._multiplicative_similarity:
            # bilinear form query . W . source
            self._weights["w_a"] = self.add_weight(
                name='w_a',
                shape=(input_shape[1][-1], input_shape[0][-1]),
                initializer=self._kernel_initializer,
                trainable=True
            )

        self.built = True

    def compute_output_shape(self, input_shape):
        self._validate_input_shape(input_shape)

        return input_shape[0]

    def _validate_input_shape(self, input_shape):
        """Require exactly two inputs (source, query) with equal batch size and length.

        # Raises
            ValueError: if the number of inputs, batch sizes or lengths do not match.
        """
        if len(input_shape) != 2:
            raise ValueError("Layer received an input shape {0} but expected two inputs (source, query).".format(input_shape))
        if input_shape[0][0] != input_shape[1][0]:
            raise ValueError("Both two inputs (source, query) have to have the same batch size; received batch sizes: {0}, {1}".format(input_shape[0][0], input_shape[1][0]))
        if input_shape[0][1] != input_shape[1][1]:
            raise ValueError("Both inputs (source, query) have to have the same length; received lengths: {0}, {1}".format(input_shape[0][1], input_shape[1][1]))

    def call(self, x):
        source, query = x

        similarity = self._similarity(source, query)
        source_shape = source.shape.as_list()
        expected_similarity_shape = [source_shape[0], source_shape[1], source_shape[1]]
        actual_similarity_shape = similarity.shape.as_list()

        if actual_similarity_shape != expected_similarity_shape:
            raise RuntimeError("The similarity function has returned a similarity with shape {0}, but expected {1}".format(actual_similarity_shape, expected_similarity_shape))

        # similarity[b, i, j] scores query step i against source step j; normalize over j
        score = K.softmax(similarity)
        # contract over j (the normalized axis): (bs, T, T) . (bs, T, dim) -> (bs, T, dim)
        output = K.batch_dot(score, source, axes=[2, 1])

        return output

    def _additive_similarity(self, source, query):
        """v_a . tanh(W_a [source_t; query_t]), broadcast to a (bs, T, T) score matrix."""
        concatenation = K.concatenate([source, query], axis=2)
        nonlinearity = K.tanh(K.dot(concatenation, self._weights["w_a"]))

        # tile the weight vector (1, 1, dim) for each time step and each element of the batch -> (bs, T, dim)
        source_shape = K.shape(source)
        vaeff = K.tile(K.expand_dims(self._weights["v_a"], 0), [source_shape[0], source_shape[1], 1])

        similarity = K.batch_dot(K.permute_dimensions(vaeff, [0, 2, 1]), nonlinearity, axes=[1, 2])

        return similarity

    def _multiplicative_similarity(self, source, query):
        """query_i . W_a . source_j for every pair (i, j) -> (bs, T, T)."""
        qp = K.dot(query, self._weights["w_a"])
        similarity = K.batch_dot(K.permute_dimensions(qp, [0, 2, 1]), source, axes=[1, 2])

        return similarity

    def get_config(self):
        # store the constructor argument (string or callable), not the bound method
        config = {'similarity': self._similarity_spec, 'kernel_initializer': self._kernel_initializer}
        base_config = super(SequenceAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

class AttentionRNNWrapper(Wrapper):
    """
        The idea of the implementation is based on the paper:
            "Effective Approaches to Attention-based Neural Machine Translation" by Luong et al.
        This layer is an attention layer, which can be wrapped around arbitrary RNN layers.
        After each time step an additive attention vector is calculated from the current
        hidden state of the wrapped RNN and the entire input time series. This attention
        vector is used to take a weighted sum of the input time steps, which is then
        concatenated to the next input time step's data and linearly projected back into
        the input space before being fed into the RNN cell again.
        This technique is similar to the input-feeding method described in the paper cited.

        # Arguments
            layer: the wrapped `RNN` layer.
            weight_initializer: initializer for all attention weights and biases.
    """

    def __init__(self, layer, weight_initializer="glorot_uniform", **kwargs):
        assert isinstance(layer, RNN)
        self.layer = layer
        self.weight_initializer = weight_initializer

        super(AttentionRNNWrapper, self).__init__(layer, **kwargs)
        # Set after Layer.__init__, which resets supports_masking to False.
        self.supports_masking = True

    def _validate_input_shape(self, input_shape):
        if len(input_shape) != 3:
            raise ValueError("Layer received an input with shape {0} but expected a Tensor of rank 3.".format(input_shape))

    def _rnn_output_dim(self, input_shape):
        """Feature size of the wrapped RNN's main output for a (bs, T, dim) input shape."""
        output_shape = self.layer.compute_output_shape(input_shape)
        # RNN.compute_output_shape returns [output_shape, *state_shapes] when return_state is set
        if isinstance(output_shape, list):
            output_shape = output_shape[0]
        return output_shape[-1]

    def build(self, input_shape):
        self._validate_input_shape(input_shape)

        self.input_spec = InputSpec(shape=input_shape)

        if not self.layer.built:
            self.layer.build(input_shape)
            self.layer.built = True

        input_dim = input_shape[-1]
        output_dim = self._rnn_output_dim(input_shape)

        # W1/b2: projection of the whole input sequence (precomputed once in get_constants)
        self._W1 = self.add_weight(shape=(input_dim, input_dim), name="{}_W1".format(self.name), initializer=self.weight_initializer)
        # W2: projection of the current hidden state
        self._W2 = self.add_weight(shape=(output_dim, input_dim), name="{}_W2".format(self.name), initializer=self.weight_initializer)
        # W3/b3: maps [x_t; attended context] back into the input space
        self._W3 = self.add_weight(shape=(2*input_dim, input_dim), name="{}_W3".format(self.name), initializer=self.weight_initializer)
        self._b2 = self.add_weight(shape=(input_dim,), name="{}_b2".format(self.name), initializer=self.weight_initializer)
        self._b3 = self.add_weight(shape=(input_dim,), name="{}_b3".format(self.name), initializer=self.weight_initializer)
        # V: reduces each additive score vector to one scalar per time step
        self._V = self.add_weight(shape=(input_dim,1), name="{}_V".format(self.name), initializer=self.weight_initializer)

        super(AttentionRNNWrapper, self).build()

    def compute_output_shape(self, input_shape):
        self._validate_input_shape(input_shape)

        return self.layer.compute_output_shape(input_shape)

    @property
    def trainable_weights(self):
        return self._trainable_weights + self.layer.trainable_weights

    @property
    def non_trainable_weights(self):
        return self._non_trainable_weights + self.layer.non_trainable_weights

    def step(self, x, states):
        """One recurrence step: attend over the full input sequence, then run the cell.

        `states` holds the RNN states followed by the constants assembled in `call`;
        its last two entries are the input sequence X and K.dot(X, W1) + b2.
        """
        h = states[0]

        # precomputed K.dot(X, self._W1) + self._b2 with X.shape=[bs, T, input_dim]
        total_x_prod = states[-1]
        # comes from the constants (equals the input sequence)
        X = states[-2]

        # project the current hidden state and broadcast it over all time steps
        hw = K.expand_dims(K.dot(h, self._W2), 1)
        # without the tanh the h-term is constant over time and cancels in the softmax,
        # which would make the attention independent of the hidden state
        additive_atn = K.tanh(total_x_prod + hw)
        # attention weights normalized over the time axis
        attention = K.softmax(K.dot(additive_atn, self._V), axis=1)
        x_weighted = K.sum(attention * X, [1])

        # input feeding: merge the current input with the attended context
        x = K.dot(K.concatenate([x, x_weighted], 1), self._W3) + self._b3

        h, new_states = self.layer.cell.call(x, states[:-2])

        return h, new_states

    def call(self, x, constants=None, mask=None, initial_state=None):
        # input shape: (n_samples, time (padded with zeros), input_dim)
        input_shape = self.input_spec.shape

        if self.layer.stateful:
            initial_states = self.layer.states
        elif initial_state is not None:
            initial_states = initial_state
            if not isinstance(initial_states, (list, tuple)):
                initial_states = [initial_states]

            base_initial_state = self.layer.get_initial_state(x)
            if len(base_initial_state) != len(initial_states):
                raise ValueError("initial_state does not have the correct length. Received length {0} but expected {1}".format(len(initial_states), len(base_initial_state)))
            # check each state's shape against the layer's default state
            for given, expected in zip(initial_states, base_initial_state):
                if not given.shape.is_compatible_with(expected.shape):
                    raise ValueError("initial_state does not match the default base state of the layer. Received {0} but expected {1}".format([s.shape for s in initial_states], [s.shape for s in base_initial_state]))
        else:
            initial_states = self.layer.get_initial_state(x)

        # copy so that a caller-supplied list is not mutated by the append below
        constants = list(constants) if constants else []
        constants += self.get_constants(x)

        last_output, outputs, states = K.rnn(
            self.step,
            x,
            initial_states,
            go_backwards=self.layer.go_backwards,
            mask=mask,
            constants=constants,
            unroll=self.layer.unroll,
            input_length=input_shape[1]
        )

        if self.layer.stateful:
            # NOTE(review): newer Keras versions expose `updates` as a read-only
            # property; `self.add_update(...)` may be required there -- confirm.
            self.updates = []
            for i in range(len(states)):
                self.updates.append((self.layer.states[i], states[i]))

        if self.layer.return_sequences:
            output = outputs
        else:
            output = last_output

        # Properly set learning phase
        if getattr(last_output, '_uses_learning_phase', False):
            output._uses_learning_phase = True
            for state in states:
                state._uses_learning_phase = True

        if self.layer.return_state:
            if not isinstance(states, (list, tuple)):
                states = [states]
            else:
                states = list(states)
            return [output] + states
        return output

    def get_constants(self, x):
        # precompute the sequence-wide terms once instead of in every step
        constants = [x, K.dot(x, self._W1) + self._b2]

        return constants

    def get_config(self):
        config = {'weight_initializer': self.weight_initializer}
        base_config = super(AttentionRNNWrapper, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

class ExternalAttentionRNNWrapper(Wrapper):
    """
        The basic idea of the implementation is based on the paper:
            "Effective Approaches to Attention-based Neural Machine Translation" by Luong et al.
        This layer is an attention layer, which can be wrapped around arbitrary RNN layers.
        After each time step an additive attention vector is calculated from the current
        hidden state of the wrapped RNN and a second, time-independent input (like image
        clues), as described in:
            Show, Attend and Tell: Neural Image Caption Generation with Visual Attention
            https://arxiv.org/abs/1502.03044
        The attention vector is used to take a weighted sum over the positions of that
        static input. The result is concatenated to the next input time step's data and
        linearly projected back into the temporal input space before being fed into the
        RNN cell again (similar to Luong et al.'s input feeding).

        # Arguments
            layer: the wrapped `RNN` layer.
            weight_initializer: initializer for all attention weights and biases.
            return_attention: if True, the per-time-step attention distributions over the
                static positions are returned as an extra output of shape (bs, T, L).

        # Inputs
            [temporal_input (bs, T, temporal_dim), static_input (bs, L, static_dim),
             *optional initial states]
    """
    def __init__(self, layer, weight_initializer="glorot_uniform", return_attention=False, **kwargs):
        assert isinstance(layer, RNN)
        self.layer = layer
        self.weight_initializer = weight_initializer
        self.return_attention = return_attention
        self._num_constants = None

        super(ExternalAttentionRNNWrapper, self).__init__(layer, **kwargs)

        # Both set after Layer.__init__, which resets them.
        self.supports_masking = True
        self.input_spec = [InputSpec(ndim=3), InputSpec(ndim=3)]

    def _validate_input_shape(self, input_shape):
        if len(input_shape) < 2:
            raise ValueError("Layer has to receive at least 2 inputs: the temporal signal and the external signal which is constant for all time steps")
        if len(input_shape[0]) != 3:
            raise ValueError("Layer received a temporal input with shape {0} but expected a Tensor of rank 3.".format(input_shape[0]))
        if len(input_shape[1]) != 3:
            raise ValueError("Layer received a time-independent input with shape {0} but expected a Tensor of rank 3.".format(input_shape[1]))

    def _rnn_output_dim(self, temporal_input_shape):
        """Feature size of the wrapped RNN's main output for the temporal input shape."""
        output_shape = self.layer.compute_output_shape(temporal_input_shape)
        # RNN.compute_output_shape returns [output_shape, *state_shapes] when return_state is set
        if isinstance(output_shape, list):
            output_shape = output_shape[0]
        return output_shape[-1]

    def build(self, input_shape):
        self._validate_input_shape(input_shape)

        for i, x in enumerate(input_shape):
            self.input_spec[i] = InputSpec(shape=x)

        if not self.layer.built:
            self.layer.build(input_shape)
            self.layer.built = True

        temporal_input_dim = input_shape[0][-1]
        static_input_dim = input_shape[1][-1]
        output_dim = self._rnn_output_dim(input_shape[0])

        # W1/b2: projection of the static input (precomputed once in get_constants)
        self._W1 = self.add_weight(shape=(static_input_dim, temporal_input_dim), name="{}_W1".format(self.name), initializer=self.weight_initializer)
        # W2: projection of the current hidden state
        self._W2 = self.add_weight(shape=(output_dim, temporal_input_dim), name="{}_W2".format(self.name), initializer=self.weight_initializer)
        # W3/b3: maps [x_t; attended static context] back into the temporal input space
        self._W3 = self.add_weight(shape=(temporal_input_dim + static_input_dim, temporal_input_dim), name="{}_W3".format(self.name), initializer=self.weight_initializer)
        self._b2 = self.add_weight(shape=(temporal_input_dim,), name="{}_b2".format(self.name), initializer=self.weight_initializer)
        self._b3 = self.add_weight(shape=(temporal_input_dim,), name="{}_b3".format(self.name), initializer=self.weight_initializer)
        # V: reduces each additive score vector to one scalar per static position
        self._V = self.add_weight(shape=(temporal_input_dim, 1), name="{}_V".format(self.name), initializer=self.weight_initializer)

        super(ExternalAttentionRNNWrapper, self).build()

    @property
    def trainable_weights(self):
        return self._trainable_weights + self.layer.trainable_weights

    @property
    def non_trainable_weights(self):
        return self._non_trainable_weights + self.layer.non_trainable_weights

    def compute_output_shape(self, input_shape):
        self._validate_input_shape(input_shape)

        output_shape = self.layer.compute_output_shape(input_shape[0])

        if self.return_attention:
            if not isinstance(output_shape, list):
                output_shape = [output_shape]
            # one attention distribution over the L static positions per time step
            output_shape = output_shape + [(input_shape[0][0], input_shape[0][1], input_shape[1][1])]

        return output_shape

    def step(self, x, states):
        """One recurrence step: attend over the static input, then run the cell.

        `states` holds the RNN states followed by the constants assembled in `call`;
        its last two entries are the static input and K.dot(static_x, W1) + b2.
        The returned output is concatenate([h, attention]) so the attention can be
        recovered from the RNN outputs in `call`.
        """
        h = states[0]

        # comes from the constants
        X_static = states[-2]
        # equals K.dot(static_x, self._W1) + self._b2 with X.shape=[bs, L, static_input_dim]
        total_x_static_prod = states[-1]

        # project the current hidden state and broadcast it over all static positions
        hw = K.expand_dims(K.dot(h, self._W2), 1)
        # without the tanh the h-term is constant over positions and cancels in the
        # softmax, which would make the attention independent of the hidden state
        additive_atn = K.tanh(total_x_static_prod + hw)
        attention = K.softmax(K.dot(additive_atn, self._V), axis=1)
        static_x_weighted = K.sum(attention * X_static, [1])

        x = K.dot(K.concatenate([x, static_x_weighted], 1), self._W3) + self._b3

        h, new_states = self.layer.cell.call(x, states[:-2])

        # append attention to the output to "smuggle" it out of the RNN loop
        attention = K.squeeze(attention, -1)
        h = K.concatenate([h, attention])

        return h, new_states

    def call(self, x, constants=None, mask=None, initial_state=None):
        # input shape: (n_samples, time (padded with zeros), input_dim)
        input_shape = self.input_spec[0].shape

        if len(x) > 2:
            # initial states (and, if registered, constants) were appended to the inputs
            if self._num_constants:
                initial_state = x[2:-self._num_constants]
            else:
                initial_state = x[2:]
            if len(initial_state) == 0:
                initial_state = None
            x = x[:2]

        # Keras passes one mask per input; only the temporal input's mask drives the loop
        if isinstance(mask, (list, tuple)):
            mask = mask[0]

        static_x = x[1]
        x = x[0]

        if self.layer.stateful:
            initial_states = self.layer.states
        elif initial_state is not None:
            initial_states = initial_state
            if not isinstance(initial_states, (list, tuple)):
                initial_states = [initial_states]
        else:
            initial_states = self.layer.get_initial_state(x)

        # copy so that a caller-supplied list is not mutated by the append below
        constants = list(constants) if constants else []
        constants += self.get_constants(static_x)

        last_output, outputs, states = K.rnn(
            self.step,
            x,
            initial_states,
            go_backwards=self.layer.go_backwards,
            mask=mask,
            constants=constants,
            unroll=self.layer.unroll,
            input_length=input_shape[1]
        )

        # step() emitted concatenate([real_output, attention]) on the feature axis;
        # split the two parts again (slicing features, not the batch axis)
        output_dim = self._rnn_output_dim(input_shape)
        last_output = last_output[:, :output_dim]

        attentions = outputs[:, :, output_dim:]
        outputs = outputs[:, :, :output_dim]

        if self.layer.stateful:
            # NOTE(review): newer Keras versions expose `updates` as a read-only
            # property; `self.add_update(...)` may be required there -- confirm.
            self.updates = []
            for i in range(len(states)):
                self.updates.append((self.layer.states[i], states[i]))

        if self.layer.return_sequences:
            output = outputs
        else:
            output = last_output

        # Properly set learning phase
        if getattr(last_output, '_uses_learning_phase', False):
            output._uses_learning_phase = True
            for state in states:
                state._uses_learning_phase = True

        if self.layer.return_state:
            if not isinstance(states, (list, tuple)):
                states = [states]
            else:
                states = list(states)
            output = [output] + states

        if self.return_attention:
            if not isinstance(output, list):
                output = [output]
            output = output + [attentions]

        return output

    def _standardize_args(self, inputs, initial_state, constants, num_constants):
        """Standardize `__call__` to a single list of tensor inputs.
        When running a model loaded from file, the input tensors
        `initial_state` and `constants` can be passed to `__call__` as part
        of `inputs` instead of by the dedicated keyword arguments. This method
        makes sure the arguments are separated and that `initial_state` and
        `constants` are lists of tensors (or None).
        # Arguments
        inputs: list/tuple of tensors (temporal input, static input, *extras)
        initial_state: tensor or list of tensors or None
        constants: tensor or list of tensors or None
        num_constants: number of trailing inputs that are constants, or None
        # Returns
        inputs: list of the two regular input tensors
        initial_state: list of tensors or None
        constants: list of tensors or None
        """
        if isinstance(inputs, tuple):
            inputs = list(inputs)
        # Extras beyond the two regular inputs used to be truncated before this
        # check, which silently dropped initial states passed inside `inputs`.
        if isinstance(inputs, list) and len(inputs) > 2:
            assert initial_state is None and constants is None
            if num_constants is not None:
                constants = inputs[-num_constants:]
                inputs = inputs[:-num_constants]
            if len(inputs) > 2:
                initial_state = inputs[2:]
            inputs = inputs[:2]

        def to_list_or_none(x):
            if x is None or isinstance(x, list):
                return x
            if isinstance(x, tuple):
                return list(x)
            return [x]

        initial_state = to_list_or_none(initial_state)
        constants = to_list_or_none(constants)

        return inputs, initial_state, constants

    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
        inputs, initial_state, constants = self._standardize_args(
            inputs, initial_state, constants, self._num_constants)

        if initial_state is None and constants is None:
            return super(ExternalAttentionRNNWrapper, self).__call__(inputs, **kwargs)

        # If any of `initial_state` or `constants` are specified and are Keras
        # tensors, then add them to the inputs and temporarily modify the
        # input_spec to include them.

        additional_inputs = []
        additional_specs = []
        if initial_state is not None:
            kwargs['initial_state'] = initial_state
            additional_inputs += initial_state
            self.state_spec = [InputSpec(shape=K.int_shape(state))
                               for state in initial_state]
            additional_specs += self.state_spec
        if constants is not None:
            kwargs['constants'] = constants
            additional_inputs += constants
            self.constants_spec = [InputSpec(shape=K.int_shape(constant))
                                   for constant in constants]
            self._num_constants = len(constants)
            additional_specs += self.constants_spec
        # at this point additional_inputs cannot be empty
        is_keras_tensor = K.is_keras_tensor(additional_inputs[0])
        for tensor in additional_inputs:
            if K.is_keras_tensor(tensor) != is_keras_tensor:
                raise ValueError('The initial state or constants of an ExternalAttentionRNNWrapper'
                                 ' layer cannot be specified with a mix of'
                                 ' Keras tensors and non-Keras tensors'
                                 ' (a "Keras tensor" is a tensor that was'
                                 ' returned by a Keras layer, or by `Input`)')

        if is_keras_tensor:
            # Compute the full input spec, including state and constants
            full_input = inputs + additional_inputs
            full_input_spec = self.input_spec + additional_specs
            original_spec_len = len(self.input_spec)
            # Perform the call with temporarily replaced input_spec
            self.input_spec = full_input_spec
            try:
                output = super(ExternalAttentionRNNWrapper, self).__call__(full_input, **kwargs)
            finally:
                # keep the (possibly build()-refined) specs of the two regular inputs,
                # even if the call raised
                self.input_spec = self.input_spec[:original_spec_len]
            return output
        return super(ExternalAttentionRNNWrapper, self).__call__(inputs, **kwargs)

    def get_constants(self, x):
        # precompute the static-input terms once instead of in every step
        constants = [x, K.dot(x, self._W1) + self._b2]
        return constants

    def get_config(self):
        config = {'return_attention': self.return_attention, 'weight_initializer': self.weight_initializer}
        base_config = super(ExternalAttentionRNNWrapper, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [None]:
import keras



class SeqSelfAttention(keras.layers.Layer):
    """Sequence self-attention layer (additive or multiplicative scoring).

    Input is indexed as (batch, length, features) by the weight shapes built from
    `input_shape[2]` and the (batch, length, length) emission matrix. The output `v`
    has the same shape as the input; with `return_attention=True` the
    (batch, length, length) attention matrix is returned as a second output.

    NOTE(review): the backend functions come from the module-level name `K`, which
    a later notebook cell rebinds to `tensorflow.keras.backend` while this class
    derives from standalone `keras` — confirm the two are compatible.
    """

    ATTENTION_TYPE_ADD = 'additive'
    ATTENTION_TYPE_MUL = 'multiplicative'

    def __init__(self,
                 units=32,
                 attention_width=None,
                 attention_type=ATTENTION_TYPE_ADD,
                 return_attention=False,
                 history_only=False,
                 kernel_initializer='glorot_normal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 use_additive_bias=True,
                 use_attention_bias=True,
                 attention_activation=None,
                 attention_regularizer_weight=0.0,
                 **kwargs):
        """Layer initialization.
        For additive attention, see: https://arxiv.org/pdf/1806.01264.pdf
        :param units: The dimension of the vectors that used to calculate the attention weights.
        :param attention_width: The width of local attention.
        :param attention_type: 'additive' or 'multiplicative'.
        :param return_attention: Whether to return the attention weights for visualization.
        :param history_only: Only use historical pieces of data (position t may only attend to t' <= t).
        :param kernel_initializer: The initializer for weight matrices.
        :param bias_initializer: The initializer for biases.
        :param kernel_regularizer: The regularization for weight matrices.
        :param bias_regularizer: The regularization for biases.
        :param kernel_constraint: The constraint for weight matrices.
        :param bias_constraint: The constraint for biases.
        :param use_additive_bias: Whether to use bias while calculating the relevance of inputs features
                                  in additive mode.
        :param use_attention_bias: Whether to use bias while calculating the weights of attention.
        :param attention_activation: The activation used for calculating the weights of attention.
        :param attention_regularizer_weight: The weights of attention regularizer.
        :param kwargs: Parameters for parent class.
        :raises NotImplementedError: if `attention_type` is neither 'additive' nor 'multiplicative'.
        """
        super(SeqSelfAttention, self).__init__(**kwargs)
        self.supports_masking = True
        self.units = units
        self.attention_width = attention_width
        self.attention_type = attention_type
        self.return_attention = return_attention
        self.history_only = history_only
        if history_only and attention_width is None:
            # An effectively unbounded window: the causal bound alone restricts attention.
            self.attention_width = int(1e9)

        self.use_additive_bias = use_additive_bias
        self.use_attention_bias = use_attention_bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.bias_initializer = keras.initializers.get(bias_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.bias_regularizer = keras.regularizers.get(bias_regularizer)
        self.kernel_constraint = keras.constraints.get(kernel_constraint)
        self.bias_constraint = keras.constraints.get(bias_constraint)
        self.attention_activation = keras.activations.get(attention_activation)
        self.attention_regularizer_weight = attention_regularizer_weight
        self._backend = keras.backend.backend()

        if attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            self.Wx, self.Wt, self.bh = None, None, None
            self.Wa, self.ba = None, None
        elif attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            self.Wa, self.ba = None, None
        else:
            raise NotImplementedError('No implementation for attention type : ' + attention_type)

    def get_config(self):
        """Return the constructor arguments (serialised) merged over the base Layer config."""
        config = {
            'units': self.units,
            'attention_width': self.attention_width,
            'attention_type': self.attention_type,
            'return_attention': self.return_attention,
            'history_only': self.history_only,
            'use_additive_bias': self.use_additive_bias,
            'use_attention_bias': self.use_attention_bias,
            'kernel_initializer': keras.initializers.serialize(self.kernel_initializer),
            'bias_initializer': keras.initializers.serialize(self.bias_initializer),
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': keras.regularizers.serialize(self.bias_regularizer),
            'kernel_constraint': keras.constraints.serialize(self.kernel_constraint),
            'bias_constraint': keras.constraints.serialize(self.bias_constraint),
            'attention_activation': keras.activations.serialize(self.attention_activation),
            'attention_regularizer_weight': self.attention_regularizer_weight,
        }
        base_config = super(SeqSelfAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def build(self, input_shape):
        """Create the weights for the configured attention type."""
        if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            self._build_additive_attention(input_shape)
        elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            self._build_multiplicative_attention(input_shape)
        super(SeqSelfAttention, self).build(input_shape)

    def _build_additive_attention(self, input_shape):
        """Weights for h = tanh(x_t Wt + x_t' Wx + bh), e = h Wa + ba."""
        feature_dim = int(input_shape[2])

        self.Wt = self.add_weight(shape=(feature_dim, self.units),
                                  name='{}_Add_Wt'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        self.Wx = self.add_weight(shape=(feature_dim, self.units),
                                  name='{}_Add_Wx'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_additive_bias:
            self.bh = self.add_weight(shape=(self.units,),
                                      name='{}_Add_bh'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

        self.Wa = self.add_weight(shape=(self.units, 1),
                                  name='{}_Add_Wa'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.ba = self.add_weight(shape=(1,),
                                      name='{}_Add_ba'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

    def _build_multiplicative_attention(self, input_shape):
        """Weights for e = x_t Wa x_t'^T + ba."""
        feature_dim = int(input_shape[2])

        self.Wa = self.add_weight(shape=(feature_dim, feature_dim),
                                  name='{}_Mul_Wa'.format(self.name),
                                  initializer=self.kernel_initializer,
                                  regularizer=self.kernel_regularizer,
                                  constraint=self.kernel_constraint)
        if self.use_attention_bias:
            self.ba = self.add_weight(shape=(1,),
                                      name='{}_Mul_ba'.format(self.name),
                                      initializer=self.bias_initializer,
                                      regularizer=self.bias_regularizer,
                                      constraint=self.bias_constraint)

    def call(self, inputs, mask=None, **kwargs):
        """Attend over the sequence and return the attention-weighted inputs.

        Scores outside the local/causal window and scores of padded key positions
        get a large negative offset before the softmax, so they receive ~0 weight.
        """
        input_len = K.shape(inputs)[1]

        if self.attention_type == SeqSelfAttention.ATTENTION_TYPE_ADD:
            e = self._call_additive_emission(inputs)
        elif self.attention_type == SeqSelfAttention.ATTENTION_TYPE_MUL:
            e = self._call_multiplicative_emission(inputs)

        if self.attention_activation is not None:
            e = self.attention_activation(e)
        if self.attention_width is not None:
            # Row t may only look at columns in [lower_t, upper_t).
            if self.history_only:
                lower = K.arange(0, input_len) - (self.attention_width - 1)
            else:
                lower = K.arange(0, input_len) - self.attention_width // 2
            lower = K.expand_dims(lower, axis=-1)
            upper = lower + self.attention_width
            indices = K.expand_dims(K.arange(0, input_len), axis=0)
            e -= 10000.0 * (1.0 - K.cast(lower <= indices, K.floatx()) * K.cast(indices < upper, K.floatx()))
        if mask is not None:
            # Mask padded *keys* for every query: shape (batch, 1, length) broadcasts
            # over the query axis. The previous (1 - m_t) * (1 - m_t') term only
            # penalised pad->pad pairs, letting valid positions attend to padding.
            key_mask = K.expand_dims(K.cast(mask, K.floatx()), axis=1)
            e -= 10000.0 * (1.0 - key_mask)

        # a_{t} = \text{softmax}(e_t)  (max-subtracted for numerical stability)
        e = K.exp(e - K.max(e, axis=-1, keepdims=True))
        a = e / K.sum(e, axis=-1, keepdims=True)

        # l_t = \sum_{t'} a_{t, t'} x_{t'}
        v = K.batch_dot(a, inputs)
        if self.attention_regularizer_weight > 0.0:
            self.add_loss(self._attention_regularizer(a))

        if self.return_attention:
            return [v, a]
        return v

    def _call_additive_emission(self, inputs):
        """Additive scores e_{t,t'}, returned as a (batch, length, length) tensor."""
        input_shape = K.shape(inputs)
        batch_size, input_len = input_shape[0], input_shape[1]

        # h_{t, t'} = \tanh(x_t^T W_t + x_{t'}^T W_x + b_h)
        q = K.expand_dims(K.dot(inputs, self.Wt), 2)
        k = K.expand_dims(K.dot(inputs, self.Wx), 1)
        if self.use_additive_bias:
            h = K.tanh(q + k + self.bh)
        else:
            h = K.tanh(q + k)

        # e_{t, t'} = W_a h_{t, t'} + b_a
        if self.use_attention_bias:
            e = K.reshape(K.dot(h, self.Wa) + self.ba, (batch_size, input_len, input_len))
        else:
            e = K.reshape(K.dot(h, self.Wa), (batch_size, input_len, input_len))
        return e

    def _call_multiplicative_emission(self, inputs):
        """Bilinear scores e_{t,t'} = x_t^T W_a x_{t'} (+ b_a)."""
        e = K.batch_dot(K.dot(inputs, self.Wa), K.permute_dimensions(inputs, (0, 2, 1)))
        if self.use_attention_bias:
            e += self.ba[0]
        return e

    def compute_output_shape(self, input_shape):
        """Output has the input's shape; attention is (batch, length, length)."""
        output_shape = input_shape
        if self.return_attention:
            attention_shape = (input_shape[0], output_shape[1], input_shape[1])
            return [output_shape, attention_shape]
        return output_shape

    def compute_mask(self, inputs, mask=None):
        """Propagate the input mask to `v`; the attention output carries no mask."""
        if self.return_attention:
            return [mask, None]
        return mask

    def _attention_regularizer(self, attention):
        """Penalty weight * ||A A^T - I||^2 averaged over the batch."""
        batch_size = K.cast(K.shape(attention)[0], K.floatx())
        input_len = K.shape(attention)[-1]
        indices = K.expand_dims(K.arange(0, input_len), axis=0)
        diagonal = K.expand_dims(K.arange(0, input_len), axis=-1)
        eye = K.cast(K.equal(indices, diagonal), K.floatx())
        return self.attention_regularizer_weight * K.sum(K.square(K.batch_dot(
            attention,
            K.permute_dimensions(attention, (0, 2, 1))) - eye)) / batch_size

    @staticmethod
    def get_custom_objects():
        """Mapping for `custom_objects=` when loading a saved model."""
        return {'SeqSelfAttention': SeqSelfAttention}

# Train

In [None]:
from utils.Encoder import Encoder
from utils.Classifier import Classifier
from utils.Detector import Detector
from utils.Segmenter import Segmenter


# NOTE: the encoder and the three task heads are built in the cell right before
# `create_model` is called. This cell used to build an identical, never-used copy
# first (an extra model held in memory), so that duplicate setup was removed.

In [None]:
#from kulc.attention import ExternalAttentionRNNWrapper
from tensorflow.keras.layers import Input,Embedding,Lambda,Dense,TimeDistributed,LSTM,Reshape,Dropout
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
from tensorflow.keras.utils import plot_model

def create_model(encoder_model, vocabulary_size, embedding_size, T, L, D, task_heads=None):
    """Build the multi-task training model and the two report-generation inference models.

    Args:
        encoder_model: project `Encoder`; `encoder_model.model.input` is the image
            input and `encoder_model.model.output` the feature map, reshaped here
            to (L, channels).
        vocabulary_size: caption vocabulary size (Embedding input dim, softmax width).
        embedding_size: embedding width; also the LSTM units and the width of the
            projected initial LSTM states.
        T: caption length for `Input(shape=(T,))`; `None` allows variable length.
        L: number of spatial positions in the encoder feature map (e.g. 8*8).
        D: width each spatial feature is projected to before the external attention.
        task_heads: the three task heads whose `.model` attributes become the first
            three outputs of the training / initial-state models. Defaults to the
            notebook-global `heads` (the previous, implicit behaviour).

    Returns:
        (training_model, inference_model, initial_state_inference_model)
    """
    if task_heads is None:
        task_heads = heads  # backward compatibility: used to read the global directly
    head_outputs = [task_heads[0].model, task_heads[1].model, task_heads[2].model]

    # Flatten the encoder feature map into L spatial positions (channel count inferred).
    image_features_input = encoder_model.model.output
    image_features_input = Reshape((L, -1))(image_features_input)

    captions_input = Input(shape=(T,), name="captions_input")
    captions = Embedding(vocabulary_size, embedding_size, input_length=T)(captions_input)

    # Initial LSTM state is projected from the spatially averaged image features.
    averaged_image_features = Lambda(lambda x: K.mean(x, axis=1))(image_features_input)
    initial_state_h = Dense(embedding_size)(averaged_image_features)
    initial_state_c = Dense(embedding_size)(averaged_image_features)

    image_features = TimeDistributed(Dense(D, activation="relu"))(image_features_input)

    encoder = LSTM(embedding_size, return_sequences=True, return_state=True, recurrent_dropout=0.1)
    attented_encoder = ExternalAttentionRNNWrapper(encoder, return_attention=True)
    # history_only=True makes the self-attention causal. Training uses teacher
    # forcing over the whole caption, so non-causal attention let position t read
    # LSTM outputs of later positions, i.e. the very tokens it must predict.
    # NOTE(review): at inference the model is fed one token per call, so this
    # layer only ever sees a length-1 sequence there — a remaining train/inference gap.
    self_attention_layer = SeqSelfAttention(attention_activation='relu', history_only=True)

    output = TimeDistributed(Dense(vocabulary_size, activation="softmax"), name="output")

    # Training graph: full caption sequence, initial state from the image.
    attented_encoder_training_data, _, _, _ = attented_encoder(
        [captions, image_features],
        initial_state=[initial_state_h, initial_state_c])
    training_output_data = output(self_attention_layer(attented_encoder_training_data))

    training_model = Model(inputs=[encoder_model.model.input, captions_input],
                           outputs=head_outputs + [training_output_data])

    # Step 1 of inference: image -> task-head outputs, projected features, initial state.
    initial_state_inference_model = Model(
        inputs=[encoder_model.model.input],
        outputs=head_outputs + [image_features, initial_state_h, initial_state_c])

    # Step 2 of inference: one decoding step from explicitly supplied state/features.
    inference_initial_state_h = Input(shape=(embedding_size,))
    inference_initial_state_c = Input(shape=(embedding_size,))
    image_input_feat = Input(shape=(L, D))

    (attented_encoder_inference_data, inference_encoder_state_h,
     inference_encoder_state_c, inference_attention) = attented_encoder(
        [captions, image_input_feat],
        initial_state=[inference_initial_state_h, inference_initial_state_c])
    inference_output_data = output(self_attention_layer(attented_encoder_inference_data))

    inference_model = Model(
        inputs=[image_input_feat, captions_input, inference_initial_state_h, inference_initial_state_c],
        outputs=[inference_output_data, inference_encoder_state_h, inference_encoder_state_c, inference_attention]
    )

    return training_model, inference_model, initial_state_inference_model

In [None]:
from data_loader.MTL_dataloader import get_train_validation_generator

# Dataset roots (Kaggle input mounts). Relocating a dataset means editing one line.
RSNA_DIR = "/kaggle/input/rsna-pneumonia-detection-challenge"
SIIM_DIR = "/kaggle/input/siim-acr-pneumothorax-segmentation-data"
INDIANA_DIR = "/kaggle/input/chest-xrays-indiana-university"

# Detection (RSNA) and segmentation (SIIM) labels and images.
det_csv_path = f"{RSNA_DIR}/stage_2_train_labels.csv"
seg_csv_path = f"{SIIM_DIR}/train-rle.csv"
det_images_path = f"{RSNA_DIR}/stage_2_train_images/"
seg_images_path = f"{SIIM_DIR}/dicom-images-train/"

# Report generation (Indiana University chest X-rays).
report_csv1_path = f"{INDIANA_DIR}/indiana_reports.csv"
report_csv2_path = f"{INDIANA_DIR}/indiana_projections.csv"
report_images_path = f"{INDIANA_DIR}/images/images_normalized/"

In [None]:
# Multi-task train/validation generators (detection, segmentation, reports).
train_gen, val_gen = get_train_validation_generator(
    det_csv_path,
    seg_csv_path,
    det_images_path,
    seg_images_path,
    report_csv1_path,
    report_csv2_path,
    report_images_path,
    augmentation=True,
    hist_eq=True,
    normalize=True,
    only_positive=False,
    batch_positive_portion=0.5,
)

In [None]:
# Pull one batch from the training generator and show the input/target shapes.
X, Y = next(iter(train_gen))
for input_idx in range(2):
    print(X[input_idx].shape)

for target_idx in range(4):
    print(Y[target_idx].shape)

In [None]:
from utils.Encoder import Encoder
from utils.Classifier import Classifier
from utils.Detector import Detector
from utils.Segmenter import Segmenter

img_size, n_classes = 256, 1

# Shared backbone (randomly initialised) plus the three task heads built on it.
encoder = Encoder(weights=None)
classifier = Classifier(encoder)
detector = Detector(encoder, img_size, n_classes)
segmenter = Segmenter(encoder)
heads = [classifier, detector, segmenter]

In [None]:
# Report-decoder hyperparameters.
embedding_size = 512
T = None     # caption length; None = variable
L = 8 * 8    # spatial positions in the encoder feature map
D = 512      # projected width of each spatial feature

vocab_size = train_gen.report_gen.vocab_size

training_model, inference_model, initial_state_inference_model = create_model(
    encoder_model=encoder,
    vocabulary_size=vocab_size,
    embedding_size=embedding_size,
    T=T,
    L=L,
    D=D,
)

In [None]:
# Resume from a previously saved checkpoint of the full multi-output training model.
PRETRAINED_WEIGHTS = '../input/mtl-with-report-weights/7.hdf5'
training_model.load_weights(PRETRAINED_WEIGHTS)

In [None]:
from tensorflow.keras.losses import sparse_categorical_crossentropy
import tensorflow as tf

def loss_d(y_true,y_pred):
    """Sparse categorical cross-entropy for report tokens, ignoring `-1` labels.

    Positions labelled -1 contribute zero loss. The previous version only skipped
    the loss when *every* label in the batch was -1; a batch mixing real labels
    and -1 passed -1 straight to the cross-entropy, which is an out-of-range class
    index (NaN on GPU, an error on CPU).

    Args:
        y_true: integer class ids (possibly float-typed), -1 meaning "no label".
        y_pred: class probabilities with the class axis last.

    Returns:
        Per-position loss with shape `y_pred.shape[:-1]`; zero where unlabelled.
    """
    valid = tf.math.not_equal(y_true, -1)
    # Replace -1 by a valid id so the cross-entropy is finite, then zero it out.
    safe_true = tf.where(valid, y_true, tf.zeros_like(y_true))
    per_position = sparse_categorical_crossentropy(safe_true, y_pred)
    # Reshape handles y_true given as (..., 1) as well as (...,).
    weights = tf.cast(tf.reshape(valid, tf.shape(per_position)), per_position.dtype)
    return per_position * weights

In [None]:
from tensorflow.keras.losses import categorical_crossentropy

def class_loss(y_true,y_pred):
    """Categorical cross-entropy that ignores samples whose one-hot target contains -1.

    The previous version only skipped the loss when *every* entry in the batch
    was -1, so a batch mixing labelled and -1 samples computed the cross-entropy
    against -1 targets. Now each unlabelled sample contributes exactly zero.

    Returns:
        Per-sample loss with shape `y_pred.shape[:-1]`.
    """
    valid = tf.math.logical_not(tf.math.reduce_any(tf.math.equal(y_true, -1), axis=-1))
    # Zero the targets of unlabelled samples (their CE is then 0) before masking.
    safe_true = tf.where(tf.expand_dims(valid, axis=-1), y_true, tf.zeros_like(y_true))
    per_sample = categorical_crossentropy(safe_true, y_pred)
    return per_sample * tf.cast(valid, per_sample.dtype)

In [None]:
from tensorflow.keras.optimizers import Adam
epochs = 1
lr = 1e-4
# One loss per model output, in output order: classifier, detector, segmenter, report tokens.
task_losses = [classifier.loss, detector.loss, segmenter.loss, loss_d]
training_model.compile(loss=task_losses, optimizer=Adam(learning_rate=lr))


In [None]:
 callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="./",
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_freq=1)

In [None]:
# Keep `nb_iteration` small; the original author noted NaNs don't occur in training this way.
ITERATIONS_PER_EPOCH = 4
train_gen.nb_iteration = ITERATIONS_PER_EPOCH

In [None]:
# Train in repeated short `fit` calls. `Model.fit` accepts generators/Sequences
# directly; `fit_generator` is deprecated in TF2.
N_TRAINING_ROUNDS = 200
for _ in range(N_TRAINING_ROUNDS):
    training_model.fit(train_gen,
                       epochs=epochs,
                       callbacks=[callback])

In [None]:
# Persist the weights after the training loop.
OUTPUT_WEIGHTS = "8.hdf5"
training_model.save_weights(OUTPUT_WEIGHTS)

# Validate

In [None]:
import numpy as np

def tokens_to_text(tokens,tok,end_token='endseq'):
    """Decode token ids into a space-separated sentence.

    Decoding stops at the first id equal to 0 or at the first id whose word is
    `end_token`; the stopping token itself is not included.

    Args:
        tokens: iterable of token ids.
        tok: object with an `index_word` mapping (id -> word).
        end_token: word that terminates the sentence.

    Returns:
        The decoded sentence, stripped of surrounding whitespace.
    """
    words = []
    for token_id in tokens:
        if token_id == 0:
            break
        word = tok.index_word[token_id]
        if word == end_token:
            break
        words.append(word)
    return " ".join(words).strip()

def predict(image,tok, initial_state_inference_model, inference_model,Y,start_token='startseq',end_token='endseq', max_len=100, rng=None):
    """Sample a report for one image, token by token.

    Fixes vs. the previous version:
      * the `image` argument is actually used (it used to read the global
        `X[0][0]`, so every image in a batch received the same caption);
      * the next token is sampled from the predicted probabilities directly
        (it used to pass softmax probabilities to `tf.random.categorical`,
        which interprets its input as logits).

    Args:
        image: a single image (no batch axis).
        tok: tokenizer with a `word_index` mapping (word -> id).
        initial_state_inference_model: returns 6 outputs; the last three are the
            projected image features and the initial h / c states.
        inference_model: maps [features, [[token]], h, c] to
            [probabilities, h, c, attention] for one decoding step.
        Y: unused; kept for backward compatibility.
        start_token / end_token: words marking the start / end of a report.
        max_len: maximum number of decoding steps.
        rng: optional `numpy.random.Generator` for reproducible sampling.

    Returns:
        List of sampled token ids (int), excluding the end token.
    """
    if rng is None:
        rng = np.random.default_rng()

    _, _, _, image_features, state_h, state_c = initial_state_inference_model(np.expand_dims(image, axis=0))
    image_features = np.asarray(image_features)  # constant across steps; convert once
    state_h = np.asarray(state_h)
    state_c = np.asarray(state_c)

    end_id = tok.word_index[end_token]
    current_token = tok.word_index[start_token]

    predictions = []
    for _ in range(max_len):
        probs, state_h, state_c, _attention = inference_model([image_features,
                                                               np.array([[current_token]]),
                                                               np.asarray(state_h),
                                                               np.asarray(state_c)])
        # Renormalise in float64 so `choice` accepts the float32 softmax output.
        probs = np.asarray(probs, dtype=np.float64).reshape(-1)
        probs = probs / probs.sum()
        current_token = int(rng.choice(probs.size, p=probs))

        if current_token == end_id:
            break
        predictions.append(current_token)

    return predictions
    
    
def get_sentence_preds(image,tok, initial_state_inference_model, inference_model,Y,start_token='startseq',end_token='endseq', max_len=100):
    """Generate a report for `image` and decode it to text.

    `start_token`, `end_token` and `max_len` are now forwarded to `predict` and
    `tokens_to_text`; previously they were accepted but silently ignored.

    Returns:
        The decoded report string.
    """
    tokens = predict(image, tok, initial_state_inference_model, inference_model, Y,
                     start_token=start_token, end_token=end_token, max_len=max_len)
    return tokens_to_text(tokens, tok, end_token=end_token)



In [None]:
# This loader has the same name as the multi-task loader imported earlier in the
# notebook; alias it so re-running earlier cells still calls the right function.
from MultiCheXNet.data_loader.indiana_dataloader import get_train_validation_generator as get_indiana_train_validation_generator
from tensorflow.keras.applications.densenet import preprocess_input
max_vocab_size=10000
max_len=100

csv_path1  ="/kaggle/input/chest-xrays-indiana-university/indiana_reports.csv"
csv_path2  ="/kaggle/input/chest-xrays-indiana-university/indiana_projections.csv"
img_path   ="/kaggle/input/chest-xrays-indiana-university/images/images_normalized/"
batch_sz = 8
validation_split = 0.2

# The vocabulary size is stored under its own name so the `vocab_size` used to
# build the model earlier is not overwritten.
# NOTE(review): `tok` comes from this loader, whereas the model was trained with
# `train_gen.report_gen`'s vocabulary — confirm both map ids to the same words.
# NOTE(review): augmentation / shuffle_GT_sentences / over_sample are enabled; check
# whether they also affect `val_dataloader`, which is used for BLEU evaluation.
train_dataloader, val_dataloader, report_vocab_size, tok = get_indiana_train_validation_generator(
    csv_path1, csv_path2, img_path, max_vocab_size, max_len,
    normalize=True, hist_eq=True, augmentation=True, batch_size=batch_sz,
    validation_split=validation_split, shuffle_GT_sentences=True, over_sample=True)

In [None]:
# Collect ground-truth and generated reports for the first 81 validation batches
# (indices 0..80).
GTs, preds = [], []
for index, (X, Y) in enumerate(val_dataloader):
    print(index)
    for img, y in zip(X[0], Y):
        GT = tokens_to_text(list(y), tok)
        pred = get_sentence_preds(img, tok, initial_state_inference_model, inference_model, None)
        GTs.append(GT)
        preds.append(pred)
    if index == 80:
        break

In [None]:
from nltk.translate.bleu_score import corpus_bleu
def calculate_bleu_evaluation(GT_sentences, predicted_sentences):
    """Corpus-level cumulative BLEU-1..4 of predicted reports against ground truth.

    NLTK's `corpus_bleu` expects, per example, a *list of tokenised references*
    and a *tokenised hypothesis*. Passing plain strings (as the notebook does)
    makes it score characters instead of words, so strings are now split on
    whitespace, each ground-truth string becoming a single reference. Inputs
    already in NLTK's tokenised format are passed through unchanged.

    Returns:
        (BLEU_1, BLEU_2, BLEU_3, BLEU_4)
    """
    references = [[gt.split()] if isinstance(gt, str) else gt for gt in GT_sentences]
    hypotheses = [p.split() if isinstance(p, str) else p for p in predicted_sentences]

    BLEU_1 = corpus_bleu(references, hypotheses, weights=(1.0, 0, 0, 0))
    BLEU_2 = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0))
    # Cumulative weights must sum to 1 (previously 0.3 each, i.e. 0.9 in total).
    BLEU_3 = corpus_bleu(references, hypotheses, weights=(1 / 3, 1 / 3, 1 / 3, 0))
    BLEU_4 = corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25))

    return BLEU_1, BLEU_2, BLEU_3, BLEU_4

In [None]:
# Display corpus BLEU-1..4 for the collected validation reports.
bleu_scores = calculate_bleu_evaluation(GTs, preds)
bleu_scores