In [84]:
import tensorflow as tf
from tensorflow import keras

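## Standard LSTM

A textbook LSTM cell built from four separate `Dense` gate layers (forget, input, output, candidate). Each gate is applied to the concatenation `[x, h]` of the current input and the previous hidden state.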

In [214]:
class iLSTMCell(keras.layers.Layer):
    """Textbook LSTM cell built from four independent Dense gate layers.

    Each gate sees the concatenation ``[x, h]`` of the current input and the
    previous hidden state. Intended to be wrapped in ``keras.layers.RNN``,
    which reads ``state_size`` and calls ``call(inputs, states)`` per step.

    Args:
        units: Size of both the hidden state ``h`` and the cell state ``c``.
        seed: Stored on the instance; not used by this cell.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, units=300, seed=42, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Two recurrent states of width `units`: (h, c).
        self.state_size = (units, units)
        self.seed = seed

        # Forget-gate bias starts at 1 (common trick to favour retaining state early in training).
        self.f_layer = keras.layers.Dense(units, use_bias=True, activation='sigmoid', kernel_initializer='glorot_uniform', bias_initializer='ones', name='f_gate')
        self.i_layer = keras.layers.Dense(units, use_bias=True, activation='sigmoid', kernel_initializer='glorot_uniform', name='i_gate')
        self.o_layer = keras.layers.Dense(units, use_bias=True, activation='sigmoid', kernel_initializer='glorot_uniform', name='o_gate')
        # Candidate cell values; layer name 'input_layer' kept unchanged so existing weights/checkpoints still match.
        self.c_layer = keras.layers.Dense(units, use_bias=True, activation='tanh', kernel_initializer='glorot_uniform', name='input_layer')

    def call(self, x, state):
        """Advance the cell by one time step.

        Args:
            x: Input for the current step (assumed (batch, features) — TODO confirm).
            state: Sequence ``[h_prev, c_prev]``.

        Returns:
            ``(h, [h, c])``: the step output and the new recurrent states.
        """
        h_prev = state[0]
        c_prev = state[1]
        inputs = tf.concat([x, h_prev], axis=-1)

        f_gate = self.f_layer(inputs)
        i_gate = self.i_layer(inputs)
        o_gate = self.o_layer(inputs)
        c_candidate = self.c_layer(inputs)

        c = f_gate * c_prev + i_gate * c_candidate
        # Fix: previously referenced an undefined `new_c` (NameError) and omitted
        # the tanh applied to the cell state in a standard LSTM.
        h = o_gate * tf.tanh(c)
        return h, [h, c]

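## Efficient LSTM

This is the same recurrence, with the four gates fused into a single kernel, recurrent kernel, and bias. Each step therefore needs only one matmul on the input and one on the previous hidden state. Note that this cell reuses the name `iLSTMCell`, so running it replaces the standard cell defined above.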

In [None]:
from tensorflow.keras import layers
from tensorflow.keras import initializers, activations

class iLSTMCell(layers.Layer):
    """LSTM cell whose four gates share one fused kernel, recurrent kernel and bias.

    Each step does one matmul on the input and one on the previous hidden
    state, adds the bias, applies dropout, then splits the result into the
    gate pre-activations in the order [input, forget, candidate, output].
    Meant to be wrapped in ``layers.RNN``.

    NOTE: this definition reuses the name ``iLSTMCell`` and replaces the
    standard cell defined earlier in the notebook.

    Args:
        units: Size of both the hidden state ``h`` and the cell state ``c``.
        seed: Op-level seed for the random initial states created in
            ``get_initial_state`` (``seed`` for h, ``seed + 1`` for c).
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, units=300, seed=42, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Two recurrent states of width `units`: (h, c).
        self.state_size = (units, units)
        self.seed = seed

        # Dropout (rate 0.2) on the fused gate pre-activations.
        self.dropout = layers.Dropout(0.2)
        self.activation = activations.get('tanh')
        self.recurrent_activation = activations.get('sigmoid')

    def build(self, input_shape):
        """Create the fused weights for inputs whose last dimension is ``input_shape[-1]``."""
        super().build(input_shape)
        input_dim = input_shape[-1]

        # Columns are laid out as [i | f | candidate | o], each `units` wide.
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            name="kernel",
            initializer="glorot_uniform",
            regularizer=tf.keras.regularizers.L2(0.01),
        )
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name="recurrent_kernel",
            initializer="glorot_uniform",
            regularizer=tf.keras.regularizers.L2(0.01),
        )

        def bias_initializer(_, *args, **kwargs):
            # Zeros everywhere except the forget-gate slice, which starts at 1.
            return tf.concat(
                [
                    initializers.Zeros()((self.units,), *args, **kwargs),
                    initializers.get("ones")((self.units,), *args, **kwargs),
                    initializers.Zeros()((self.units * 2,), *args, **kwargs),
                ], -1
            )
        # NOTE(review): the L1 penalty also pulls the forget-gate bias from its
        # initial value of 1 toward 0 — confirm that is intended.
        self.bias = self.add_weight(
            shape=(self.units * 4,),
            name="bias",
            initializer=bias_initializer,
            regularizer=tf.keras.regularizers.L1(0.01),
        )
        self.built = True

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        """Return random-normal initial states ``[h0, c0]``, each (batch, units).

        Batch size and dtype come from ``inputs`` when it is given, otherwise
        from ``batch_size`` / ``dtype``. Some Keras versions call this with
        ``batch_size`` only; in that case the layer's compute dtype is used
        instead of passing ``dtype=None`` to ``tf.random.normal``.

        NOTE(review): states are random rather than zeros. With only op-level
        seeds, eager calls likely draw different values each time (see
        ``tf.random.set_seed``), so repeated forward passes may differ.
        """
        if inputs is not None:
            batch_size = tf.shape(inputs)[0]
            dtype = inputs.dtype
        if dtype is None:
            dtype = self.compute_dtype

        return [
            tf.random.normal((batch_size, self.state_size[0]), dtype=dtype, seed=self.seed),
            tf.random.normal((batch_size, self.state_size[1]), dtype=dtype, seed=self.seed + 1),
        ]

    def call(self, inputs, states, training=None):
        """Advance the cell by one time step.

        Args:
            inputs: Input for the current step (assumed (batch, input_dim) — TODO confirm).
            states: Sequence ``[h_prev, c_prev]``.
            training: Forwarded to the dropout layer so it only drops in training.

        Returns:
            ``(h, [h, c])``: the step output and the new recurrent states.
        """
        h0 = states[0]
        c0 = states[1]

        z = tf.matmul(inputs, self.kernel)
        z += tf.matmul(h0, self.recurrent_kernel)
        z = tf.nn.bias_add(z, self.bias)
        # Pass `training` explicitly instead of relying on Keras call-context
        # propagation, which RNN's direct `cell.call` invocation may bypass.
        z = self.dropout(z, training=training)

        z0, z1, z2, z3 = tf.split(z, num_or_size_splits=4, axis=-1)
        i = self.recurrent_activation(z0)
        f = self.recurrent_activation(z1)
        c1 = f * c0 + i * self.activation(z2)
        o = self.recurrent_activation(z3)

        h1 = o * self.activation(c1)
        return h1, [h1, c1]

In [None]:
# Functional-API model around the custom cell. `inputs` and `input_layer` were
# never defined anywhere in this notebook (NameError on a fresh kernel), so
# they are built explicitly here.
# NOTE(review): feature dim 2 matches the toy tensor `a` used below — adjust for real data.
inputs = keras.Input(shape=(None, 2))
input_layer = layers.RNN(iLSTMCell(units=3), return_sequences=True)
outputs = input_layer(inputs)
model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model")
model.summary()

In [196]:
# Scratch check: concatenating two rank-3 tensors along the last axis doubles
# that axis, as the [x, h] concat in the standard cell does.
a = tf.constant([[[0.1, 0.2], [0.01, 0.02]], [[0.05, 0.2], [0.05, 0.02]]])
b = tf.constant([[[0.1, 0.2], [0.01, 0.02]], [[0.05, 0.2], [0.05, 0.02]]])
c = tf.concat((a, b), -1)
print(*(t.shape for t in (a, b, c)))

# A Dense layer creates its variables on first call, so this unused one
# reports an empty weight list.
d = layers.Dense(5)
d.weights

(2, 2, 2) (2, 2, 2) (2, 2, 4)


[]

In [227]:
from keras.models import Sequential

# Sequential wrapper around the custom cell. Uses `keras.Sequential` from the
# already-imported `tensorflow.keras` so the model and `layers.RNN` come from the
# same Keras package (mixing standalone `keras` with `tensorflow.keras` objects
# can fail on some TF/Keras version combinations).
model = keras.Sequential()
model.add(layers.RNN(
    iLSTMCell(units=3),
    return_sequences=True
))

In [229]:
# Inference-mode forward pass on the toy batch `a`; training=False makes explicit
# that the cell's dropout is inactive here.
model(a, training=False)

<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[0., 0., 0.],
        [0., 0., 0.]],

       [[0., 0., 0.],
        [0., 0., 0.]]], dtype=float32)>