In [1]:
import tensorflow.compat.v2 as tf

import collections
import warnings

import numpy as np
from keras import activations
from keras import backend
from keras import constraints
from keras import initializers
from keras import regularizers
from keras.engine.base_layer import Layer
from keras.engine.input_spec import InputSpec
from keras.saving.saved_model import layer_serialization
from keras.utils import control_flow_util
from keras.utils import generic_utils
from keras.utils import tf_utils
from keras.layers import RNN
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls

In [2]:
class DropoutRNNCellMixin(object):
    """Object that hold dropout related fields for RNN Cell.

    This class is not a standalone RNN cell. It suppose to be used with a RNN cell
    by multiple inheritance. Any cell that mix with class should have following
    fields:
    dropout: a float number within range [0, 1). The ratio that the input
      tensor need to dropout.
    recurrent_dropout: a float number within range [0, 1). The ratio that the
      recurrent state weights need to dropout.
    This object will create and cache created dropout masks, and reuse them for
    the incoming data, so that the same mask is used for every batch input.
    """

    def __init__(self, *args, **kwargs):
        self._create_non_trackable_mask_cache()
        super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def _create_non_trackable_mask_cache(self):
        """Create the cache for dropout and recurrent dropout mask.

        Note that the following two masks will be used in "graph function" mode,
        e.g. these masks are symbolic tensors. In eager mode, the `eager_*_mask`
        tensors will be generated differently than in the "graph function" case,
        and they will be cached.

        Also note that in graph mode, we still cache those masks only because the
        RNN could be created with `unroll=True`. In that case, the `cell.call()`
        function will be invoked multiple times, and we want to ensure same mask
        is used every time.

        Also the caches are created without tracking. Since they are not picklable
        by python when deepcopy, we don't want `layer._obj_reference_counts_dict`
        to track it by default.
        """
        self._dropout_mask_cache = backend.ContextValueCache(
            self._create_dropout_mask)
        self._recurrent_dropout_mask_cache = backend.ContextValueCache(
            self._create_recurrent_dropout_mask)

    def reset_dropout_mask(self):
        """Reset the cached dropout masks if any.

        This is important for the RNN layer to invoke this in it `call()` method so
        that the cached mask is cleared before calling the `cell.call()`. The mask
        should be cached across the timestep within the same batch, but shouldn't
        be cached between batches. Otherwise it will introduce unreasonable bias
        against certain index of data within the batch.
        """
        self._dropout_mask_cache.clear()

    def reset_recurrent_dropout_mask(self):
        """Reset the cached recurrent dropout masks if any.

        This is important for the RNN layer to invoke this in it call() method so
        that the cached mask is cleared before calling the cell.call(). The mask
        should be cached across the timestep within the same batch, but shouldn't
        be cached between batches. Otherwise it will introduce unreasonable bias
        against certain index of data within the batch.
        """
        self._recurrent_dropout_mask_cache.clear()

    def _create_dropout_mask(self, inputs, training, count=1):
        return _generate_dropout_mask(
            tf.ones_like(inputs),
            self.dropout,
            training=training,
            count=count)

    def _create_recurrent_dropout_mask(self, inputs, training, count=1):
        return _generate_dropout_mask(
            tf.ones_like(inputs),
            self.recurrent_dropout,
            training=training,
            count=count)

    def get_dropout_mask_for_cell(self, inputs, training, count=1):
        """Get the dropout mask for RNN cell's input.

        It will create mask based on context if there isn't any existing cached
        mask. If a new mask is generated, it will update the cache in the cell.

        Args:
          inputs: The input tensor whose shape will be used to generate dropout
            mask.
          training: Boolean tensor, whether its in training mode, dropout will be
            ignored in non-training mode.
          count: Int, how many dropout mask will be generated. It is useful for cell
            that has internal weights fused together.
        Returns:
          List of mask tensor, generated or cached mask based on context.
        """
        if self.dropout == 0:
            return None
        init_kwargs = dict(inputs=inputs, training=training, count=count)
        return self._dropout_mask_cache.setdefault(kwargs=init_kwargs)

    def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1):
        """Get the recurrent dropout mask for RNN cell.

        It will create mask based on context if there isn't any existing cached
        mask. If a new mask is generated, it will update the cache in the cell.

        Args:
          inputs: The input tensor whose shape will be used to generate dropout
            mask.
          training: Boolean tensor, whether its in training mode, dropout will be
            ignored in non-training mode.
          count: Int, how many dropout mask will be generated. It is useful for cell
            that has internal weights fused together.
        Returns:
          List of mask tensor, generated or cached mask based on context.
        """
        if self.recurrent_dropout == 0:
            return None
        init_kwargs = dict(inputs=inputs, training=training, count=count)
        return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs)

    def __getstate__(self):
        # Used for deepcopy. The caching can't be pickled by python, since it will
        # contain tensor and graph.
        state = super(DropoutRNNCellMixin, self).__getstate__()
        state.pop('_dropout_mask_cache', None)
        state.pop('_recurrent_dropout_mask_cache', None)
        return state

    def __setstate__(self, state):
        state['_dropout_mask_cache'] = backend.ContextValueCache(
            self._create_dropout_mask)
        state['_recurrent_dropout_mask_cache'] = backend.ContextValueCache(
            self._create_recurrent_dropout_mask)
        super(DropoutRNNCellMixin, self).__setstate__(state)

In [3]:
#2 Mode implementation

class MultiModalLSTMCell(DropoutRNNCellMixin, Layer):
    def __init__(self,
               units,
               activation='softmax',
               recurrent_activation='hard_sigmoid',
               use_bias=True,
               kernel_initializer='glorot_uniform',
               recurrent_initializer='orthogonal',
               bias_initializer='zeros',
               unit_forget_bias=True,
               kernel_regularizer=None,
               recurrent_regularizer=None,
               bias_regularizer=None,
               kernel_constraint=None,
               recurrent_constraint=None,
               bias_constraint=None,
               dropout=0.,
               recurrent_dropout=0.,
               **kwargs):
        if tf.compat.v1.executing_eagerly_outside_functions():
            self._enable_caching_device = kwargs.pop('enable_caching_device', True)
        else:
            self._enable_caching_device = kwargs.pop('enable_caching_device', False)
        super(MultiModalLSTMCell, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        implementation = kwargs.pop('implementation', 1)
        if self.recurrent_dropout != 0 and implementation != 1:
            logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
            self.implementation = 1
        else:
            self.implementation = implementation
            
        self.state_size = [self.units, self.units]
        self.output_size = self.units
        
    def build(self, input_shape):
        print(input_shape)
        default_caching_device = _caching_device(self)
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 9),
            name='kernel',
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
            caching_device=default_caching_device)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 9),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint,
            caching_device=default_caching_device)

        if self.use_bias:
            if self.unit_forget_bias:
                def bias_initializer(_, *args, **kwargs):
                    return backend.concatenate([
                      self.bias_initializer((self.units,), *args, **kwargs),
                      initializers.get('ones')((self.units,), *args, **kwargs),
                      self.bias_initializer((self.units * 2,), *args, **kwargs),
                    ])
            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(
                shape=(self.units * 9,),
                name='bias',
                initializer=bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
                caching_device=default_caching_device)
        else:
            self.bias = None
        self.built = True
            
    def _compute_carry_and_output(self, x, h_tm1, m1_tm1, m2_tm1):
        """Computes carry and output using split kernels."""
        x_i1, x_i2, x_f1, x_f2, x_m1, x_m2, x_o = x
        h_tm1_i1, h_tm1_i2, h_tm1_f1, h_tm1_f2, h_tm1_m1, h_tm1_m2, h_tm1_o = h_tm1
        i1 = self.recurrent_activation(x_i1 + backend.dot(h_tm1_i1, self.recurrent_kernel[:, :self.units]))
        i2 = self.recurrent_activation(x_i2 + backend.dot(h_tm1_i2, self.recurrent_kernel[:, self.units:self.units*2]))
        f1 = self.recurrent_activation(x_f1 + backend.dot(h_tm1_f1, self.recurrent_kernel[:, self.units*2:self.units * 3]))
        f2 = self.recurrent_activation(x_f2 + backend.dot(h_tm1_f2, self.recurrent_kernel[:, self.units*3:self.units * 4]))
        m1 = f1 * m1_tm1 + i1 * self.activation(x_m1 + backend.dot(h_tm1_m1, self.recurrent_kernel[:, self.units * 4:self.units * 5]))
        m2 = f2 * m2_tm1 + i2 * self.activation(x_m2 + backend.dot(h_tm1_m2, self.recurrent_kernel[:, self.units * 5:self.units * 6]))
        w1 = m1 * self.recurrent_activation(self.recurrent_kernel[:, self.units * 6 : self.units * 7])
        w2 = m2 * self.recurrent_activation(self.recurrent_kernel[:, self.units * 7 : self.units * 8]) 
        o = w1 + w2 + self.activation(x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 8:]))
        return m1, m2, o
    
    def _compute_carry_and_output_fused(self, z, m1_tm1, m2_tm1):
        """Computes carry and output using fused kernels."""
        z0, z1, z2, z3, z4, z5, z6, z7, z8 = z
        i1 = self.recurrent_activation(z0)
        i2 = self.recurrent_activation(z1)
        f1 = self.recurrent_activation(z2)
        f2 = self.recurrent_activation(z3)
        m1 = f1 * m1_tm1 + i1 * self.activation(z4)
        m2 = f2 * m2_tm1 + i2 * self.activation(z5)
        w1 = m1 * self.recurrent_activation(z6)
        w2 = m2 * self.recurrent_activation(z7)
        o = w1 + w2 + self.activation(z8)
        return m1, m2, o

    def call(self, inputs, states, training=None):
        h_tm1 = states[0]  # previous memory state
        m2_tm1 = states[1]  # previous carry state
        m1_tm1 = states[2]

        dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=9)
        rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        h_tm1, training, count=9)

        if self.implementation == 1:
            if 0 < self.dropout < 1.:
                inputs_i1 = inputs * dp_mask[0]
                inputs_i2 = inputs * dp_mask[1]
                inputs_f1 = inputs * dp_mask[2]
                inputs_f2 = inputs * dp_mask[3]
                inputs_m1 = inputs * dp_mask[4]
                inputs_m2 = inputs * dp_mask[5]
                inputs_w1 = inputs * dp_mask[6]
                inputs_w2 = inputs * dp_mask[7]
                inputs_o = inputs * dp_mask[8]
            else:
                inputs_i1 = inputs
                inputs_i2 = inputs
                inputs_f1 = inputs
                inputs_f2 = inputs
                inputs_m1 = inputs
                inputs_m2 = inputs
                inputs_w1 = inputs
                inputs_w2 = inputs
                inputs_o = inputs
            k_i1, k_i2, k_f1, k_f2, k_m1, k_m2, k_w1, k_w2, k_o = tf.split(self.kernel, num_or_size_splits=9, axis=1)
            x_i1 = backend.dot(inputs_i1, k_i1)
            x_i2 = backend.dot(inputs_i2, k_i2)
            x_f1 = backend.dot(inputs_f1, k_f1)
            x_f2 = backend.dot(inputs_f2, k_f2)
            x_m1 = backend.dot(inputs_m1, k_m1)
            x_m2 = backend.dot(inputs_m2, k_m2)
            x_w1 = backend.dot(inputs_w1, k_w1)
            x_w2 = backend.dot(inputs_w2, k_w2)
            x_o = backend.dot(inputs_o, k_o)
            if self.use_bias:
                b_i1, b_i2, b_f1, b_f2, b_m1, b_m2, b_w1, b_w2, b_o = tf.split(self.bias, num_or_size_splits=9, axis=0)
                x_i1 = backend.bias_add(x_i1, b_i1)
                x_i2 = backend.bias_add(x_i2, b_i2)
                x_f1 = backend.bias_add(x_f1, b_f1)
                x_f2 = backend.bias_add(x_f2, b_f2)
                x_m1 = backend.bias_add(x_m1, b_m1)
                x_m2 = backend.bias_add(x_m2, b_m2)
                x_w1 = backend.bias_add(x_w1, b_w1)
                x_w2 = backend.bias_add(x_w2, b_w2)
                x_o = backend.bias_add(x_o, b_o)

            if 0 < self.recurrent_dropout < 1.:
                h_tm1_i1 = h_tm1 * rec_dp_mask[0]
                h_tm1_i2 = h_tm1 * rec_dp_mask[1]
                h_tm1_f1 = h_tm1 * rec_dp_mask[2]
                h_tm1_f2 = h_tm1 * rec_dp_mask[3]
                h_tm1_m1 = h_tm1 * rec_dp_mask[4]
                h_tm1_m2 = h_tm1 * rec_dp_mask[5]
                h_tm1_w1 = h_tm1 * rec_dp_mask[6]
                h_tm1_w2 = h_tm1 * rec_dp_mask[7]
                h_tm1_o = h_tm1 * rec_dp_mask[8]
            else:
                h_tm1_i1 = h_tm1
                h_tm1_i2 = h_tm1
                h_tm1_f1 = h_tm1
                h_tm1_f2 = h_tm1
                h_tm1_m1 = h_tm1
                h_tm1_m2 = h_tm1
                h_tm1_w1 = h_tm1
                h_tm1_w2 = h_tm1
                h_tm1_o = h_tm1
            x = (x_i1, x_i2, x_f1, x_f2, x_m1, x_m2, x_w1, x_w2, x_o)
            h_tm1 = (h_tm1_i1, h_tm1_i2, h_tm1_f1, h_tm1_f2, h_tm1_m1, h_tm1_m2, h_tm1_w1, h_tm1_w2, h_tm1_o)
            m1, m2, o = self._compute_carry_and_output(x, h_tm1, m1_tm1, m2_tm1)
        else:
            if 0. < self.dropout < 1.:
                inputs = inputs * dp_mask[0]
            z = backend.dot(inputs, self.kernel)
            z += backend.dot(h_tm1, self.recurrent_kernel)
            if self.use_bias:
                z = backend.bias_add(z, self.bias)
            
            z = tf.split(z, num_or_size_splits=9, axis=1)
            m1, m2, o = self._compute_carry_and_output_fused(z, c_tm1)
        
        h = o
        
        return h

    def get_config(self):
        config = {
            'units':
                self.units,
            'activation':
                activations.serialize(self.activation),
            'recurrent_activation':
                activations.serialize(self.recurrent_activation),
            'use_bias':
                self.use_bias,
            'kernel_initializer':
                initializers.serialize(self.kernel_initializer),
            'recurrent_initializer':
                initializers.serialize(self.recurrent_initializer),
            'bias_initializer':
                initializers.serialize(self.bias_initializer),
            'unit_forget_bias':
                self.unit_forget_bias,
            'kernel_regularizer':
                regularizers.serialize(self.kernel_regularizer),
            'recurrent_regularizer':
                regularizers.serialize(self.recurrent_regularizer),
            'bias_regularizer':
                regularizers.serialize(self.bias_regularizer),
            'kernel_constraint':
                constraints.serialize(self.kernel_constraint),
            'recurrent_constraint':
                constraints.serialize(self.recurrent_constraint),
            'bias_constraint':
                constraints.serialize(self.bias_constraint),
            'dropout':
                self.dropout,
            'recurrent_dropout':
                self.recurrent_dropout,
            'implementation':
                self.implementation
        }
        config.update(_config_for_enable_caching_device(self))
        base_config = super(MultiModalLSTMCell, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        return list(_generate_zero_filled_state_for_cell(
            self, inputs, batch_size, dtype))


In [4]:
class MultiModalLSTM(RNN):
    def __init__(self,
                 units,
                 activation='softmax',
                 recurrent_activation='hard_sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 return_sequences=False,
                 return_state=False,
                 go_backwards=False,
                 stateful=False,
                 unroll=False,
                 **kwargs):
        implementation = kwargs.pop('implementation', 1)
        if implementation == 0:
            logging.warning('`implementation=0` has been deprecated, '
                            'and now defaults to `implementation=1`.'
                            'Please update your layer call.')
        if 'enable_caching_device' in kwargs:
            cell_kwargs = {'enable_caching_device':
                           kwargs.pop('enable_caching_device')}
        else:
            cell_kwargs = {}
        cell = MultiModalLSTMCell(units,
                        activation=activation,
                        recurrent_activation=recurrent_activation,
                        use_bias=use_bias,
                        kernel_initializer=kernel_initializer,
                        recurrent_initializer=recurrent_initializer,
                        unit_forget_bias=unit_forget_bias,
                        bias_initializer=bias_initializer,
                        kernel_regularizer=kernel_regularizer,
                        recurrent_regularizer=recurrent_regularizer,
                        bias_regularizer=bias_regularizer,
                        kernel_constraint=kernel_constraint,
                        recurrent_constraint=recurrent_constraint,
                        bias_constraint=bias_constraint,
                        dropout=dropout,
                        recurrent_dropout=recurrent_dropout,
                        implementation=implementation,
                        dtype=kwargs.get('dtype'),
                        trainable=kwargs.get('trainable', True),
                        **cell_kwargs)
        super(MultiModalLSTM, self).__init__(
            cell,
            return_sequences=return_sequences,
            return_state=return_state,
            go_backwards=go_backwards,
            stateful=stateful,
            unroll=unroll,
            **kwargs)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.input_spec = [InputSpec(ndim=3)]
        
    def call(self, inputs, mask=None, training=None, initial_state=None):
        return super(MultiModalLSTM, self).call(inputs, mask=mask, training=training, initial_state=initial_state)
    
    @property
    def units(self):
        return self.cell.units

    @property
    def activation(self):
        return self.cell.activation

    @property
    def recurrent_activation(self):
        return self.cell.recurrent_activation

    @property
    def use_bias(self):
        return self.cell.use_bias

    @property
    def kernel_initializer(self):
        return self.cell.kernel_initializer

    @property
    def recurrent_initializer(self):
        return self.cell.recurrent_initializer

    @property
    def bias_initializer(self):
        return self.cell.bias_initializer

    @property
    def unit_forget_bias(self):
        return self.cell.unit_forget_bias

    @property
    def kernel_regularizer(self):
        return self.cell.kernel_regularizer

    @property
    def recurrent_regularizer(self):
        return self.cell.recurrent_regularizer

    @property
    def bias_regularizer(self):
        return self.cell.bias_regularizer

    @property
    def kernel_constraint(self):
        return self.cell.kernel_constraint

    @property
    def recurrent_constraint(self):
        return self.cell.recurrent_constraint

    @property
    def bias_constraint(self):
        return self.cell.bias_constraint

    @property
    def dropout(self):
        return self.cell.dropout

    @property
    def recurrent_dropout(self):
        return self.cell.recurrent_dropout

    @property
    def implementation(self):
        return self.cell.implementation

    def get_config(self):
        config = {
            'units': 
                self.units,
            'activation':
                activations.serialize(self.activation),
            'recurrent_activation':
                activations.serialize(self.recurrent_activation),
            'use_bias':
                self.use_bias,
            'kernel_initializer':
                initializers.serialize(self.kernel_initializer),
            'recurrent_initializer':
                initializers.serialize(self.recurrent_initializer),
            'bias_initializer':
                initializers.serialize(self.bias_initializer),
            'unit_forget_bias':
                self.unit_forget_bias,
            'kernel_regularizer':
                regularizers.serialize(self.kernel_regularizer),
            'recurrent_regularizer':
                regularizers.serialize(self.recurrent_regularizer),
            'bias_regularizer':
                regularizers.serialize(self.bias_regularizer),
            'activity_regularizer':
                regularizers.serialize(self.activity_regularizer),
            'kernel_constraint':
                constraints.serialize(self.kernel_constraint),
            'recurrent_constraint':
                constraints.serialize(self.recurrent_constraint),
            'bias_constraint':
                constraints.serialize(self.bias_constraint),
            'dropout':
                self.dropout,
            'recurrent_dropout':
                self.recurrent_dropout,
            'implementation':
                self.implementation
        }
        config.update(_config_for_enable_caching_device(self.cell))
        base_config = super(MultiModalLSTM, self).get_config()
        del base_config['cell']
        return dict(list(base_config.items()) + list(config.items()))
    
    @classmethod
    def from_config(cls, config):
        if 'implementation' in config and config['implementation'] == 0:
            config['implementation'] = 1
        return cls(**config)
    
def _generate_dropout_mask(ones, rate, training=None, count=1):
    def dropped_inputs():
        return backend.dropout(ones, rate)

    if count > 1:
        return [
            backend.in_train_phase(dropped_inputs, ones, training=training)
            for _ in range(count)
        ]
    
    return backend.in_train_phase(dropped_inputs, ones, training=training)


def _standardize_args(inputs, initial_state, constants, num_constants):
    """Standardizes `__call__` to a single list of tensor inputs.

    When running a model loaded from a file, the input tensors
    `initial_state` and `constants` can be passed to `RNN.__call__()` as part
    of `inputs` instead of by the dedicated keyword arguments. This method
    makes sure the arguments are separated and that `initial_state` and
    `constants` are lists of tensors (or None).

    Args:
    inputs: Tensor or list/tuple of tensors. which may include constants
      and initial states. In that case `num_constant` must be specified.
    initial_state: Tensor or list of tensors or None, initial states.
    constants: Tensor or list of tensors or None, constant tensors.
    num_constants: Expected number of constants (if constants are passed as
      part of the `inputs` list.

    Returns:
    inputs: Single tensor or tuple of tensors.
    initial_state: List of tensors or None.
    constants: List of tensors or None.
    """
    if isinstance(inputs, list):
        # There are several situations here:
        # In the graph mode, __call__ will be only called once. The initial_state
        # and constants could be in inputs (from file loading).
        # In the eager mode, __call__ will be called twice, once during
        # rnn_layer(inputs=input_t, constants=c_t, ...), and second time will be
        # model.fit/train_on_batch/predict with real np data. In the second case,
        # the inputs will contain initial_state and constants as eager tensor.
        #
        # For either case, the real input is the first item in the list, which
        # could be a nested structure itself. Then followed by initial_states, which
        # could be a list of items, or list of list if the initial_state is complex
        # structure, and finally followed by constants which is a flat list.
        assert initial_state is None and constants is None
        if num_constants:
            constants = inputs[-num_constants:]
            inputs = inputs[:-num_constants]
        if len(inputs) > 1:
            initial_state = inputs[1:]
            inputs = inputs[:1]
            
        if len(inputs) > 1:
            inputs = tuple(inputs)
        else:
            inputs = inputs[0]

    def to_list_or_none(x):
        if x is None or isinstance(x, list):
            return x
        if isinstance(x, tuple):
            return list(x)
        return [x]

    initial_state = to_list_or_none(initial_state)
    constants = to_list_or_none(constants)
    
    return inputs, initial_state, constants


def _is_multiple_state(state_size):
    """Check whether the state_size contains multiple states."""
    return (hasattr(state_size, '__len__') and
            not isinstance(state_size, tf.TensorShape))


def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
    if inputs is not None:
        batch_size = tf.shape(inputs)[0]
        dtype = inputs.dtype
    return _generate_zero_filled_state(batch_size, cell.state_size, dtype)


def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
    """Generate a zero filled tensor with shape [batch_size, state_size]."""
    if batch_size_tensor is None or dtype is None:
        raise ValueError(
            'batch_size and dtype cannot be None while constructing initial state: '
            'batch_size={}, dtype={}'.format(batch_size_tensor, dtype))

    def create_zeros(unnested_state_size):
        flat_dims = tf.TensorShape(unnested_state_size).as_list()
        init_state_size = [batch_size_tensor] + flat_dims
        return tf.zeros(init_state_size, dtype=dtype)

    if tf.nest.is_nested(state_size):
        return tf.nest.map_structure(create_zeros, state_size)
    else:
        return create_zeros(state_size)


def _caching_device(rnn_cell):
    """Returns the caching device for the RNN variable.

    This is useful for distributed training, when variable is not located as same
    device as the training worker. By enabling the device cache, this allows
    worker to read the variable once and cache locally, rather than read it every
    time step from remote when it is needed.

    Note that this is assuming the variable that cell needs for each time step is
    having the same value in the forward path, and only gets updated in the
    backprop. It is true for all the default cells (SimpleRNN, GRU, LSTM). If the
    cell body relies on any variable that gets updated every time step, then
    caching device will cause it to read the stall value.

    Args:
    rnn_cell: the rnn cell instance.
    """
    if tf.executing_eagerly():
        # caching_device is not supported in eager mode.
        return None
    if not getattr(rnn_cell, '_enable_caching_device', False):
        return None
    # Don't set a caching device when running in a loop, since it is possible that
    # train steps could be wrapped in a tf.while_loop. In that scenario caching
    # prevents forward computations in loop iterations from re-reading the
    # updated weights.
    if control_flow_util.IsInWhileLoop(tf.compat.v1.get_default_graph()):
        logging.warning(
            'Variable read device caching has been disabled because the '
            'RNN is in tf.while_loop loop context, which will cause '
            'reading stalled value in forward path. This could slow down '
            'the training due to duplicated variable reads. Please '
            'consider updating your code to remove tf.while_loop if possible.')
        return None
    if (rnn_cell._dtype_policy.compute_dtype !=
        rnn_cell._dtype_policy.variable_dtype):
        logging.warning(
            'Variable read device caching has been disabled since it '
            'doesn\'t work with the mixed precision API. This is '
            'likely to cause a slowdown for RNN training due to '
            'duplicated read of variable for each timestep, which '
            'will be significant in a multi remote worker setting. '
            'Please consider disabling mixed precision API if '
            'the performance has been affected.')
        return None
    # Cache the value on the device that access the variable.
    return lambda op: op.device


def _config_for_enable_caching_device(rnn_cell):
    """Return the dict config for RNN cell wrt to enable_caching_device field.

    Since enable_caching_device is a internal implementation detail for speed up
    the RNN variable read when running on the multi remote worker setting, we
    don't want this config to be serialized constantly in the JSON. We will only
    serialize this field when a none default value is used to create the cell.
    Args:
    rnn_cell: the RNN cell for serialize.

    Returns:
    A dict which contains the JSON config for enable_caching_device value or
    empty dict if the enable_caching_device value is same as the default value.
    """
    default_enable_caching_device = tf.compat.v1.executing_eagerly_outside_functions()
    if rnn_cell._enable_caching_device != default_enable_caching_device:
        return {'enable_caching_device': rnn_cell._enable_caching_device}
    return {}

In [5]:
from keras.models import Sequential
from keras.layers import Dense, Dropout
import numpy as np
import random

In [6]:
import pandas as pd

df = pd.read_csv("the-circor-digiscope-phonocardiogram-dataset-1.0.3/training_data.csv")

In [7]:
patient_id = list(df["Patient ID"])
recording_loc = list(df["Recording locations:"])
murmur = list(df["Murmur"])
murmur_loc = list(df["Murmur locations"])
systolic_murmur_timing = list(df["Systolic murmur timing"])

In [8]:
recording_loc = [x.split("+") for x in recording_loc]

for i in range(len(murmur_loc)):
    if murmur_loc[i] is np.nan:
        murmur_loc[i] = []
    else:
        murmur_loc[i] = murmur_loc[i].split("+")

In [9]:
X = []
y = []

for i in range(len(patient_id)):
    if murmur[i] == "Absent":
        for recloc in recording_loc[i]:
            full_recording = np.loadtxt("SSE/" + str(patient_id[i]) + "_" + recloc + "_features.csv", delimiter=',')
            initial = 0
            while (initial+400) <= len(full_recording):
                X.append(full_recording[initial:initial+400])
                y.append([1, 0, 0, 0, 0])
                initial += 200
    elif murmur[i] == "Present":
        for recloc in recording_loc[i]:
            if recloc in murmur_loc[i]:
                full_recording = np.loadtxt("SSE/" + str(patient_id[i]) + "_" + recloc + "_features.csv", delimiter=',')
                initial = 0
                if systolic_murmur_timing[i] == "Holosystolic":
                    while (initial+400) <= len(full_recording):
                        X.append(full_recording[initial:initial+400])
                        y.append([0, 1, 0, 0, 0])
                        initial += 200
                elif systolic_murmur_timing[i] == "Early-systolic":
                    while (initial+400) <= len(full_recording):
                        X.append(full_recording[initial:initial+400])
                        y.append([0, 0, 1, 0, 0])
                        initial += 200
                elif systolic_murmur_timing[i] == "Mid-systolic":
                    while (initial+400) <= len(full_recording):
                        X.append(full_recording[initial:initial+400])
                        y.append([0, 0, 0, 1, 0])
                        initial += 200
                elif systolic_murmur_timing[i] == "Late-systolic":
                    while (initial+400) <= len(full_recording):
                        X.append(full_recording[initial:initial+400])
                        y.append([0, 0, 0, 0, 1])
                        initial += 200
                else:
                    print("Error on patient ID:", patient_id[i])
            else:
                full_recording = np.loadtxt("SSE/" + str(patient_id[i]) + "_" + recloc + "_features.csv", delimiter=',')
                initial = 0
                while (initial+400) <= len(full_recording):
                    X.append(full_recording[initial:initial+400])
                    y.append([1, 0, 0, 0, 0])
                    initial += 200

Error on patient ID: 85119
Error on patient ID: 85119


In [10]:
time_steps = 400
feature_length = 18
batch_size = 3

In [11]:
from sklearn.model_selection import train_test_split as tts
X_train, X_test, y_train, y_test = tts(X, y, test_size = 0.25)

In [12]:
X_train_final = []
y_train_final = []

for i in range(len(X_train)):
    if y_train[i] == [1, 0, 0, 0, 0]:
        #randnum = random.uniform(0, 1)
        randnum = random.random()
        if randnum >= 0.66:
            X_train_final.append(X_train[i])
            y_train_final.append(y_train[i])
    else:
        X_train_final.append(X_train[i])
        y_train_final.append(y_train[i])

In [13]:
print(len(X_train), len(y_train))
print(len(X_train_final), len(y_train_final))
print(len(X_test), len(y_test))

22341 22341
10121 10121
7447 7447


In [14]:
X_train = np.array(X_train)
y_train = np.array(y_train)
X_train_final = np.array(X_train_final)
y_train_final = np.array(y_train_final)
X_test = np.array(X_test)
y_test = np.array(y_test)

In [15]:
model = Sequential()
model.add(MultiModalLSTM(256, input_shape = (400, 18), return_sequences=True))
model.add(Dropout(0.2))
model.add(MultiModalLSTM(128, return_sequences=True))
model.add(Dropout(0.2))
model.add(MultiModalLSTM(64, return_sequences=False))
model.add(Dense(5, activation='softmax'))
model.summary()

ValueError: The initial value's shape ((1024,)) is not compatible with the explicitly supplied `shape` argument ((2304,)).