In [1]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, GRU
import numpy as np

In [2]:
print(tf.__version__)

2.0.0


In [3]:
T = 8 # sequence length
D = 2 # input dimensionality
M = 3 # hidden layer size

In [4]:
# random single sentence of (T=8) words (of) vectors (of Dimensionality 2)
X = np.random.randn(1, T, D)

In [5]:
def lstm(return_sequences):
    input_ = Input(shape=(T, D))
    rnn = LSTM(M, return_state=True, return_sequences=return_sequences)
    x = rnn(input_)
    
    model = Model(inputs=input_, outputs=x)
    o, h, c = model.predict(X)
    print("o:", o) # output state
    print("h:", h) # hidden state
    print("c:", c) # cell state

In [6]:
def gru(return_sequences=False):
    input_ = Input(shape=(T, D))
    rnn = GRU(M, return_state=True, return_sequences=return_sequences)
    x = rnn(input_)
    
    model = Model(inputs=input_, outputs=x)
    o, h = model.predict(X)
    print("o:", o)
    print("h:", h)

In [7]:
X

array([[[ 0.42534953, -2.05203147],
        [ 0.45866982,  0.2726821 ],
        [-1.11217119, -0.45322218],
        [ 1.67636628,  0.40574839],
        [-0.30272344, -1.54502383],
        [ 0.31148515,  0.32733057],
        [-0.41662199, -0.63395068],
        [-0.93057375,  0.52341877]]])

In [8]:
print('lstm(return_sequences=False)')
lstm(return_sequences=False)
# should see that 'h'idden is actually the 'o'utput

lstm(return_sequences=False)
o: [[-0.13610387 -0.03827858  0.01187589]]
h: [[-0.13610387 -0.03827858  0.01187589]]
c: [[-0.25883973 -0.08570366  0.02630636]]


In [9]:
print('lstm(return_sequences=True)')
lstm(return_sequences=True)
# same as before, except that sequence lenght is of asize T (8)
# see that 'h' is the same as last element of 'o' 
# thus, 'h' and 'c' represent the last state of LSTM (ie last timestep)

lstm(return_sequences=True)
o: [[[ 0.05537746  0.02390799  0.18071148]
  [ 0.12642565 -0.01081144  0.14512156]
  [ 0.08062249  0.02655875  0.12119054]
  [ 0.15234677 -0.06671423  0.11360002]
  [ 0.10770918 -0.03812384  0.26930055]
  [ 0.17586416 -0.04390792  0.17320353]
  [ 0.13328935 -0.0153063   0.19501005]
  [ 0.11396513  0.00746261  0.09307939]]]
h: [[0.11396513 0.00746261 0.09307939]]
c: [[0.19826348 0.01449099 0.14929938]]


In [10]:
print('gru(return_sequences=False)')
gru(return_sequences=True)

gru(return_sequences=False)
o: [[[ 0.3429041   0.46695912  0.19873063]
  [ 0.04636335  0.06543563 -0.01808175]
  [ 0.54613626  0.3499534   0.12801264]
  [ 0.28386107 -0.17762409 -0.29888126]
  [ 0.5421966   0.4044663  -0.07630345]
  [ 0.12815757  0.02214268 -0.09808372]
  [ 0.34162334  0.26588506  0.08139671]
  [ 0.3102122   0.15861422  0.03027977]]]
h: [[0.3102122  0.15861422 0.03027977]]


In [11]:
print('gru(return_sequences=True)')
gru(return_sequences=True)

gru(return_sequences=True)
o: [[[-0.3739081   0.37828812 -0.06370363]
  [ 0.00392233  0.22754148  0.03395929]
  [-0.05687128  0.0427503  -0.1035232 ]
  [ 0.15019408  0.09942636  0.1825223 ]
  [-0.18950835  0.43134612  0.05081913]
  [ 0.04437899  0.22612762  0.07575828]
  [-0.04645113  0.25450534 -0.03373754]
  [ 0.00423348 -0.24768388 -0.08856485]]]
h: [[ 0.00423348 -0.24768388 -0.08856485]]
