In [1]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, GRU
import numpy as np

In [2]:
# Show the TensorFlow version this notebook is running against.
print(tf.__version__)

2.0.0


In [3]:
# Problem dimensions used by every cell below.
T = 8 # sequence length (number of timesteps)
D = 2 # input dimensionality (features per timestep)
M = 3 # hidden layer size (units in the recurrent layer)

In [4]:
# One random input sequence: a batch of 1 "sentence" made of T timesteps
# ("words"), each a D-dimensional vector, drawn from a standard normal.
X = np.random.standard_normal(size=(1, T, D))

In [5]:
def lstm(return_sequences=False):
    """Run X through a freshly built single-layer LSTM and print what it returns.

    The layer is created with ``return_state=True``, so the model yields
    three arrays: the layer output, the final hidden state and the final
    cell state.

    Args:
        return_sequences: Forwarded to the ``LSTM`` layer. Defaults to
            ``False`` so the signature matches ``gru``.

    Returns:
        None. The three arrays are printed with the prefixes "o:", "h:"
        and "c:".
    """
    input_ = Input(shape=(T, D))
    rnn = LSTM(M, return_state=True, return_sequences=return_sequences)
    # Because of return_state=True, calling the layer gives three tensors.
    x = rnn(input_)

    model = Model(inputs=input_, outputs=x)
    o, h, c = model.predict(X)
    print("o:", o) # layer output
    print("h:", h) # final hidden state
    print("c:", c) # final cell state

In [6]:
def gru(return_sequences=False):
    """Build a one-layer GRU model, run it on X and print its two outputs.

    The layer is configured with ``return_state=True``, so prediction
    yields the layer output followed by the final hidden state.

    Args:
        return_sequences: Forwarded unchanged to the ``GRU`` layer.

    Returns:
        None. The arrays are printed with the prefixes "o:" and "h:".
    """
    inputs = Input(shape=(T, D))
    layer = GRU(M, return_sequences=return_sequences, return_state=True)
    model = Model(inputs=inputs, outputs=layer(inputs))

    output, hidden = model.predict(X)
    print("o:", output)
    print("h:", hidden)

In [7]:
# Display the input sequence that is fed to the models below.
X

array([[[-1.58101302, -0.50853706],
        [ 0.48220456, -0.61128178],
        [-1.00432767, -1.36041766],
        [-0.48961686, -0.93750145],
        [ 1.11000796, -1.00821446],
        [ 0.62150429, -0.51635786],
        [-0.5338402 , -0.13564563],
        [-1.51934683,  0.7882209 ]]])

In [8]:
lstm(return_sequences=False)
# Expect the hidden state 'h' to be identical to the output 'o':
# with return_sequences=False the output is just the last hidden state.

o: [[ 0.05730367 -0.04569765 -0.2714537 ]]
h: [[ 0.05730367 -0.04569765 -0.2714537 ]]
c: [[ 0.18275177 -0.13815974 -0.5394716 ]]


In [9]:
lstm(return_sequences=True)
# Same as before, except that 'o' now holds one row per timestep (sequence length T = 8).
# Note that 'h' is the same as the last element of 'o';
# thus 'h' and 'c' represent the last state of the LSTM (i.e. the last timestep).

o: [[[-0.09656115  0.19171457  0.07158016]
  [ 0.02977696  0.04933275  0.10633902]
  [-0.02770934  0.11342802  0.10969745]
  [-0.03649454  0.0972841   0.19043373]
  [ 0.16736402 -0.11725976  0.23386152]
  [ 0.19705722 -0.19162004  0.19689865]
  [ 0.06550689 -0.09403286  0.19696847]
  [-0.08422926  0.13640538  0.1767208 ]]]
h: [[-0.08422926  0.13640538  0.1767208 ]]
c: [[-0.17140211  0.32462227  0.5022582 ]]


In [10]:
# gru() prints only 'o' and 'h' (the layer returns no separate cell state);
# 'h' should equal the last row of 'o'.
gru(return_sequences=True)

o: [[[ 0.1490269   0.06900228  0.56701374]
  [-0.03951652  0.2667033   0.43546468]
  [-0.103096    0.47720635  0.64681345]
  [-0.0856227   0.5200608   0.7124688 ]
  [-0.23634195  0.641216    0.5925705 ]
  [-0.20280662  0.5694688   0.46542   ]
  [ 0.01270336  0.38558966  0.531096  ]
  [ 0.34961498  0.13376813  0.54545957]]]
h: [[0.34961498 0.13376813 0.54545957]]


In [11]:
# Same call as the previous cell. The numbers differ, presumably because gru()
# builds a brand-new GRU layer (with freshly initialised weights) on every call.
gru(return_sequences=True)

o: [[[ 0.23264034  0.19821157  0.0483178 ]
  [-0.11523146  0.27577135 -0.04154351]
  [-0.28109428  0.49534118 -0.03242477]
  [-0.3425638   0.43852633 -0.00331145]
  [-0.552319    0.39653492 -0.17713967]
  [-0.528515    0.2527525  -0.13412812]
  [-0.20332986  0.12062275 -0.01676751]
  [ 0.26767033 -0.10721889  0.119253  ]]]
h: [[ 0.26767033 -0.10721889  0.119253  ]]
