In [20]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Bidirectional, GRU
import numpy as np
import warnings

warnings.filterwarnings('ignore')

In [33]:
B = 2
T = 5
D = 2
U = 3

X = np.random.randn(B, T, D)
X.shape, X

((2, 5, 2),
 array([[[ 1.44491378, -0.37883838],
         [ 1.136315  ,  1.04057081],
         [-0.21683899, -1.55339637],
         [ 0.01449   , -0.94652504],
         [ 0.85272905,  0.03959404]],
 
        [[ 1.28862533, -1.30228365],
         [-0.62834433, -0.33669056],
         [ 0.27760031, -1.46619091],
         [ 1.43595219,  1.14082813],
         [-1.11618955,  0.862076  ]]]))

In [22]:
def lstm(data=X, return_sequences=False):
  inp = Input(shape=(T, D))
  out = LSTM(U, return_sequences=return_sequences)(inp)
  model = Model(inputs=inp, outputs=out)
  return model.predict(X)

In [23]:
print('----return_sequences=False----')
lstm_out = lstm()
lstm_out.shape, lstm_out

----return_sequences=False----


((2, 3),
 array([[ 0.2305741 , -0.03407818,  0.06759147],
        [ 0.21218197,  0.03541003,  0.12211575]], dtype=float32))

In [24]:
print('----return_sequences=True')
lstm_out = lstm(return_sequences=True)
lstm_out.shape, lstm_out

----return_sequences=True


((2, 5, 3),
 array([[[ 0.05036221,  0.03548711, -0.16693464],
         [ 0.36756083, -0.21324664,  0.14096953],
         [ 0.38542596, -0.2888657 ,  0.27480155],
         [ 0.1270014 , -0.07893647,  0.17077789],
         [ 0.22793996, -0.21785195,  0.27632308]],
 
        [[-0.00777443,  0.04540763, -0.07044486],
         [-0.06497024, -0.3009203 ,  0.0902238 ],
         [-0.03417871, -0.06164884,  0.07642759],
         [-0.03974312, -0.19663633,  0.12427132],
         [-0.17726132,  0.01890991,  0.04646362]]], dtype=float32))

In [25]:
def lstm(return_state=False):
  inp = Input(shape=(T, D))
  out = LSTM(U, return_state=return_state)(inp)
  model = Model(inputs=inp, outputs=out)
  if return_state:
    o, h, c = model.predict(X)
    print('o:', o.shape)
    print(o)
    print('h:', h.shape)
    print(h)
    print('c:', c.shape)
    print(c)
  else:
    o = model.predict(X)
    print('o:', o.shape)
    print(o)

print('----return_state=False')
lstm()
print('----return_state=True')
lstm(return_state=True)


----return_state=False
o: (2, 3)
[[ 0.14960548 -0.19208433  0.13237055]
 [-0.19101419 -0.04614019 -0.192797  ]]
----return_state=True
o: (2, 3)
[[ 0.22299464 -0.1519229  -0.18894227]
 [ 0.23327349 -0.18648033  0.06530386]]
h: (2, 3)
[[ 0.22299464 -0.1519229  -0.18894227]
 [ 0.23327349 -0.18648033  0.06530386]]
c: (2, 3)
[[ 0.6513296  -0.34467825 -0.31529662]
 [ 0.33073503 -0.42570293  0.28122306]]


In [26]:
T, D, U

(5, 2, 3)

In [27]:
def bi_lstm(return_sequences=False, return_state=False):
  inp = Input(shape=(T, D))
  out = Bidirectional(
      LSTM(U, return_state=return_state, return_sequences=return_sequences)
  )(inp)
  model = Model(inputs=inp, outputs=out)
  if return_state:
    o, h1, c1, h2, c2 = model.predict(X)
    print('o: ', o.shape)
    print('h1: ', h1.shape)
    print('c1: ', c1.shape)
    print('h2: ', h2.shape)
    print('c2: ', c2.shape)
  else:
    o = model.predict(X)
    print('o: ', o.shape)

In [28]:
print('---return_sequences=False---')
bi_lstm()
print('---return_sequences=True---')
bi_lstm(return_sequences=True)
print('---return_sequences=True, return_state=True---')
bi_lstm(return_sequences=True, return_state=True)

---return_sequences=False---
o:  (2, 6)
---return_sequences=True---
o:  (2, 5, 6)
---return_sequences=True, return_state=True---
o:  (2, 5, 6)
h1:  (2, 3)
c1:  (2, 3)
h2:  (2, 3)
c2:  (2, 3)


In [29]:
def gru(return_sequences=False, return_state=False):
  inp = Input(shape=(T, D))
  out = GRU(U, return_state=return_state, return_sequences=return_sequences)(inp)
  model = Model(inputs=inp, outputs=out)
  if return_state:
    o, h = model.predict(X)
    print('o: ', o.shape)
    print('h: ', h.shape)
  else:
    o = model.predict(X)
    print('o: ', o.shape)

In [30]:
print('---many to one---')
gru()
print('---many to many---')
gru(return_sequences=True)
print('---many to many with states---')
gru(return_sequences=True, return_state=True)

---many to one---
o:  (2, 3)
---many to many---
o:  (2, 5, 3)
---many to many with states---
o:  (2, 5, 3)
h:  (2, 3)


In [31]:
def bi_gru(return_sequences=False, return_state=False):
  inp = Input(shape=(T, D))
  out = Bidirectional(
      GRU(U, return_state=return_state, return_sequences=return_sequences)
  )(inp)
  model = Model(inputs=inp, outputs=out)
  if return_state:
    o, h1, h2 = model.predict(X)
    print('o: ', o.shape)
    print('h1: ', h1.shape)
    print('h2: ', h2.shape)
  else:
    o = model.predict(X)
    print('o: ', o.shape)

In [32]:
print('---many to one---')
bi_gru()
print('---many to many---')
bi_gru(return_sequences=True)
print('---many to many with states---')
bi_gru(return_sequences=True, return_state=True)

---many to one---
o:  (2, 6)
---many to many---
o:  (2, 5, 6)
---many to many with states---
o:  (2, 5, 6)
h1:  (2, 3)
h2:  (2, 3)
