In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import SimpleRNN, LSTM, Bidirectional

In [10]:
train_X = [[
    [0.1, 4.2, 1.5, 1.1, 2.8],
    [1.0, 3.1, 2.5, 0.7, 1.1],
    [0.3, 2.1, 1.5, 2.1, 0.1],
    [2.2, 1.4, 0.5, 0.9, 1.1]
]]
train_X = np.array(train_X, dtype=np.float32)
# print(np.shape(train_X))
print(train_X.shape)

(1, 4, 5)


# Simple RNN

In [24]:
rnn = SimpleRNN(3,return_sequences = True, return_state = True)
hidden_state, last_state = rnn(train_X)

print('hidden state : {}, shape : {}\n'.format(hidden_state, hidden_state.shape))
print('last hidden state : {}, shape : {}\n'.format(last_state, last_state.shape))


hidden state : [[[-0.9832972   0.20854129 -0.9622653 ]
  [-0.5839553  -0.9145554  -0.1803549 ]
  [-0.76298046 -0.42132413 -0.9220408 ]
  [-0.98496413 -0.99511445 -0.9254847 ]]], shape : (1, 4, 3)

last hidden state : [[-0.98496413 -0.99511445 -0.9254847 ]], shape : (1, 3)



# LSTM

In [26]:
lstm = LSTM(3,return_sequences = True, return_state = True)
hidden_state, last_hidden_state, last_cell_state = lstm(train_X)

print('hidden state : {}, shape : {}\n'.format(hidden_state, hidden_state.shape))
print('last hidden state : {}, shape : {}\n'.format(last_state, last_state.shape))
print('last cell state : {}, shape : {}\n'.format(last_cell_state, last_cell_state.shape))


hidden state : [[[-0.00410205 -0.0448423   0.19561355]
  [-0.01596891 -0.07216485  0.5621325 ]
  [-0.05578514 -0.12985231  0.460378  ]
  [-0.13154697 -0.06703199  0.39312464]]], shape : (1, 4, 3)

last hidden state : [[-0.98496413 -0.99511445 -0.9254847 ]], shape : (1, 3)

last cell state : [[-0.80375516 -0.268283    2.0221992 ]], shape : (1, 3)



# Bidirectional LSTM

In [30]:
k_init = tf.keras.initializers.Constant(value=0.1)
b_init = tf.keras.initializers.Constant(value=0)
r_init = tf.keras.initializers.Constant(value=0.1)

In [39]:
bilstm = Bidirectional(LSTM(3,return_sequences=False,\
                         return_state = True,\
                        bias_initializer = b_init,\
                        recurrent_initializer=r_init))
hidden_states, forward_h, forward_c, backward_h, backward_c = bilstm(train_X)
print('hidden state : {}, shape : {}\n'.format(hidden_states, hidden_state.shape))

print('foward state : {}, shape : {}\n'.format(forward_h, hidden_state.shape))
print('backward state : {}, shape : {}\n'.format(backward_h, hidden_state.shape))

print('foward state : {}, shape : {}\n'.format(forward_c, hidden_state.shape))
print('backward state : {}, shape : {}\n'.format(backward_c, hidden_state.shape))


hidden state : [[-0.10009428 -0.21149424  0.17957386 -0.0047524  -0.45616934 -0.0389562 ]], shape : (1, 4, 3)

foward state : [[-0.10009428 -0.21149424  0.17957386]], shape : (1, 4, 3)

backward state : [[-0.0047524  -0.45616934 -0.0389562 ]], shape : (1, 4, 3)

foward state : [[-1.0290873  -1.451037    0.23582846]], shape : (1, 4, 3)

backward state : [[-0.03199977 -1.2557712  -0.33501065]], shape : (1, 4, 3)



In [40]:
bilstm = Bidirectional(LSTM(3,return_sequences=True,\
                         return_state = True,\
                        bias_initializer = b_init,\
                        recurrent_initializer=r_init))
hidden_states, forward_h, forward_c, backward_h, backward_c = bilstm(train_X)
print('hidden state : {}, shape : {}\n'.format(hidden_states, hidden_state.shape))

print('foward state : {}, shape : {}\n'.format(forward_h, hidden_state.shape))
print('backward state : {}, shape : {}\n'.format(backward_h, hidden_state.shape))

print('foward state : {}, shape : {}\n'.format(forward_c, hidden_state.shape))
print('backward state : {}, shape : {}\n'.format(backward_c, hidden_state.shape))


hidden state : [[[-0.18485835  0.11145597 -0.04052968  0.48256007  0.01597929
   -0.6401246 ]
  [-0.04677725 -0.14995725 -0.07835873  0.4538554   0.0215591
   -0.63838285]
  [-0.16335599 -0.14277561 -0.32572478  0.42411008  0.02753893
   -0.47895408]
  [-0.23449121  0.4636655  -0.13165762  0.31614307  0.04366803
    0.04542973]]], shape : (1, 4, 3)

foward state : [[-0.23449121  0.4636655  -0.13165762]], shape : (1, 4, 3)

backward state : [[ 0.48256007  0.01597929 -0.6401246 ]], shape : (1, 4, 3)

foward state : [[-0.32241893  0.5670044  -0.6864971 ]], shape : (1, 4, 3)

backward state : [[ 0.7599268   0.51883084 -0.9874033 ]], shape : (1, 4, 3)

