In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import SimpleRNN, LSTM, Bidirectional

In [4]:
#batch size = 1
#string length = 4 (timesteps)
#word dimension = 5 (input dimension)
train_X = [[[0.1, 4.2, 1.5, 1.1, 2.8], 
           [1.0, 3.1, 2.5, 0.7, 1.1], 
           [0.3, 2.1, 1.5, 2.1, 0.1], 
           [2.2, 1.4, 0.5, 0.9, 1.1]]]

train_X = np.array(train_X, dtype=np.float32)

print(train_X.shape)

(1, 4, 5)


## RNN

In [5]:
rnn = SimpleRNN(3)
hidden_state = rnn(train_X)

print('hidden state : {}, shape: {}'.format(hidden_state, hidden_state.shape))

hidden state : [[ 0.21673541  0.11997614 -0.4476763 ]], shape: (1, 3)


In [7]:
rnn = SimpleRNN(3, return_sequences=True)
hidden_states = rnn(train_X)

print('hidden states : \n{}\n shape: {}'.format(hidden_states, hidden_states.shape))

hidden states : 
[[[ 0.99664104  0.50199795 -0.9252158 ]
  [ 0.9942719   0.36647075 -0.1466478 ]
  [ 0.9954034   0.87517875 -0.23367552]
  [ 0.9619131   0.892164   -0.79656494]]]
 shape: (1, 4, 3)


In [8]:
rnn = SimpleRNN(3, return_sequences=False, return_state=True)
hidden_state, last_state = rnn(train_X)

print('hidden state : {}, shape: {}'.format(hidden_state, hidden_state.shape))
print('last hidden state : {}, shape: {}'.format(last_state, last_state.shape))

hidden state : [[-0.98200625  0.9830647  -0.60710776]], shape: (1, 3)
last hidden state : [[-0.98200625  0.9830647  -0.60710776]], shape: (1, 3)


## LSTM

In [9]:
lstm = LSTM(3, return_sequences=False, return_state=True)
hidden_state, last_state, last_cell_state = lstm(train_X)

print('hidden state : {}, shape: {}'.format(hidden_state, hidden_state.shape))
print('last hidden state : {}, shape: {}'.format(last_state, last_state.shape))
print('last cell state : {}, shape: {}'.format(last_cell_state, last_cell_state.shape))

hidden state : [[-0.18132937  0.1757587  -0.28994292]], shape: (1, 3)
last hidden state : [[-0.18132937  0.1757587  -0.28994292]], shape: (1, 3)
last cell state : [[-0.47738892  0.33070382 -0.52610266]], shape: (1, 3)


In [10]:
lstm = LSTM(3, return_sequences=True, return_state=True)
hidden_states, last_hidden_state, last_cell_state = lstm(train_X)

print('hidden states : {}, shape: {}'.format(hidden_states, hidden_states.shape))
print('last hidden state : {}, shape: {}'.format(last_hidden_state, last_hidden_state.shape))
print('last cell state : {}, shape: {}'.format(last_cell_state, last_cell_state.shape))

hidden states : [[[-0.00999284  0.05313414 -0.16249251]
  [ 0.00461868  0.10489163 -0.58243823]
  [-0.05330745  0.09463956 -0.6670565 ]
  [-0.05722552  0.1392165  -0.79728633]]], shape: (1, 4, 3)
last hidden state : [[-0.05722552  0.1392165  -0.79728633]], shape: (1, 3)
last cell state : [[-0.14565682  0.4232592  -1.3951226 ]], shape: (1, 3)


## Bidirectional LSTM

In [11]:
k_init = tf.keras.initializers.Constant(value=0.1)
b_init = tf.keras.initializers.Constant(value=0)
r_init = tf.keras.initializers.Constant(value=0.1)

In [12]:
bilstm = Bidirectional(LSTM(3, return_sequences=False, return_state=True, \
                            kernel_initializer=k_init, bias_initializer=b_init, recurrent_initializer=r_init))
hidden_states, forward_h, forward_c, backward_h, backward_c = bilstm(train_X)

print('hidden states : {}, shape: {}'.format(hidden_states, hidden_states.shape))
print('forward state : {}, shape: {}'.format(forward_h, forward_h.shape))
print('backward state : {}, shape: {}'.format(backward_h, backward_h.shape))

hidden states : [[0.6303138 0.6303138 0.6303138 0.7038734 0.7038734 0.7038734]], shape: (1, 6)
forward state : [[0.6303138 0.6303138 0.6303138]], shape: (1, 3)
backward state : [[0.7038734 0.7038734 0.7038734]], shape: (1, 3)


In [14]:
bilstm = Bidirectional(LSTM(3, return_sequences=True, return_state=True, \
                            kernel_initializer=k_init, bias_initializer=b_init, recurrent_initializer=r_init))
hidden_states, forward_h, forward_c, backward_h, backward_c = bilstm(train_X)

print('hidden states : {}, shape: {}'.format(hidden_states, hidden_states.shape))
print('forward state : {}, shape: {}'.format(forward_h, forward_h.shape))
print('backward state : {}, shape: {}'.format(backward_h, backward_h.shape))

hidden states : [[[0.35906473 0.35906473 0.35906473 0.7038734  0.7038734  0.7038734 ]
  [0.55111325 0.55111325 0.55111325 0.58863586 0.58863586 0.58863586]
  [0.59115744 0.59115744 0.59115744 0.3951699  0.3951699  0.3951699 ]
  [0.6303138  0.6303138  0.6303138  0.21942244 0.21942244 0.21942244]]], shape: (1, 4, 6)
forward state : [[0.6303138 0.6303138 0.6303138]], shape: (1, 3)
backward state : [[0.7038734 0.7038734 0.7038734]], shape: (1, 3)
