In [1]:
import tensorflow as tf
import numpy as np

In [2]:
input_dim = 1
seq_size = 6
batch_size = 2
input_shape = (batch_size, seq_size, input_dim)
x = tf.random.normal(shape=input_shape, dtype=tf.float32)

In [3]:
def make_cell(state_dim):
    return tf.keras.layers.LSTMCell(state_dim)

cell = make_cell(state_dim=10)
lstm_layer = tf.keras.layers.RNN(cell, input_shape=input_shape, return_sequences=True)

outputs, states = lstm_layer(x)

In [4]:
print(outputs)
print(states)
print(np.shape(outputs))
print(np.shape(states))

tf.Tensor(
[[-8.45988169e-02  1.11620054e-01 -3.02392263e-02 -1.22468159e-01
   1.33334056e-01  6.90495446e-02  6.17534295e-02  8.95454362e-02
  -8.47920179e-02 -2.12975265e-03]
 [-3.48110646e-02  1.45503059e-02 -1.00789405e-02 -4.27137427e-02
   5.84185645e-02  2.70296745e-02  7.20029138e-03  9.98209417e-03
  -4.03444991e-02 -1.63298042e-03]
 [-2.00411770e-02 -9.19348281e-03  1.01534228e-04 -1.15191005e-02
   1.93205774e-02  2.38736765e-03 -8.68689641e-03 -1.50571847e-02
  -2.16582678e-02 -1.48856791e-03]
 [-2.18446590e-02 -3.22688458e-04  8.38491251e-04 -1.49279973e-02
   1.94979291e-02  3.03098187e-03 -2.78164051e-03 -8.68288707e-03
  -1.97505467e-02 -8.96198151e-04]
 [-2.31807139e-02  7.27013499e-03  6.57327997e-04 -1.97691992e-02
   2.32450124e-02  5.11133997e-03  2.79000867e-03 -2.71375477e-03
  -1.93995126e-02 -3.06556612e-04]
 [-4.59365658e-02  4.38651666e-02 -7.40416721e-03 -5.97635135e-02
   6.81029633e-02  2.58890744e-02  2.54415236e-02  2.53195297e-02
  -4.13494334e-02 -5.0

In [5]:
cell2 = make_cell(state_dim=10)
lstm_layer2 = tf.keras.layers.RNN([cell, cell2], return_sequences=True, input_shape=input_shape)
outputs2, states2 = lstm_layer2(x)

In [6]:
print(outputs2)
print(states2)
print(np.shape(outputs2))
print(np.shape(states2))

tf.Tensor(
[[-0.01078437 -0.01857414 -0.00766134 -0.01318614 -0.00606433 -0.00415463
   0.03530656 -0.00818623  0.01483619  0.01870922]
 [-0.00899756 -0.01763183 -0.00665104 -0.01423258 -0.00533178  0.00034458
   0.03449577 -0.00829276  0.01351705  0.02694262]
 [-0.00421184 -0.0120158  -0.002818   -0.01266827 -0.00253892  0.00345639
   0.02565332 -0.00576628  0.0083866   0.02790515]
 [-0.0016041  -0.00902793  0.00019809 -0.01213416 -0.00080121  0.00422876
   0.02165543 -0.00402911  0.00532845  0.02709346]
 [-0.00049569 -0.00806458  0.00227091 -0.01229546  0.00015167  0.00392908
   0.02078773 -0.00326917  0.00391078  0.0262211 ]
 [-0.0030129  -0.01393169  0.00170807 -0.01670164 -0.00094634  0.00231666
   0.03178665 -0.0058312   0.0076776   0.03091326]], shape=(6, 10), dtype=float32)
tf.Tensor(
[[ 3.6875019e-03  7.3912689e-03  2.1604097e-03  4.3328642e-03
   2.0346756e-03  9.1667642e-04 -1.2154214e-02  3.7458951e-03
  -4.9172565e-03 -5.9206211e-03]
 [-6.9996850e-03 -6.4164936e-03 -5.0243

In [7]:
def make_multi_cell(state_dim, num_layers):
    cells = [make_cell(state_dim) for _ in range(num_layers)]
    return tf.keras.layers.RNN(cells, return_sequences=True, input_shape=input_shape)

multi_cell = make_multi_cell(state_dim=10, num_layers=4)
outputs4, states4 = multi_cell(x)

In [8]:
print(outputs4)
print(states4)
print(np.shape(outputs4))
print(np.shape(states4))

tf.Tensor(
[[ 5.45344665e-04  2.38843113e-05 -1.57655304e-04  3.62023682e-04
  -1.51362270e-04  3.85025414e-05 -1.99394504e-04 -1.20650611e-05
  -6.04397792e-04  1.70747997e-04]
 [ 1.52576622e-03 -2.10644248e-05 -5.07967721e-04  8.76861566e-04
  -4.89718281e-04  1.13365699e-04 -6.07929134e-04 -4.42237397e-05
  -1.58654689e-03  4.82244533e-04]
 [ 2.63723778e-03 -1.85123878e-04 -9.87036619e-04  1.30665558e-03
  -9.24839638e-04  1.97630972e-04 -1.13853044e-03 -1.08651664e-04
  -2.61168880e-03  8.17572116e-04]
 [ 3.68119963e-03 -4.36515606e-04 -1.51178089e-03  1.58201018e-03
  -1.36233016e-03  2.68521253e-04 -1.70967879e-03 -2.13465901e-04
  -3.53279640e-03  1.10653881e-03]
 [ 4.56626853e-03 -7.24965474e-04 -2.01697601e-03  1.71759271e-03
  -1.74691330e-03  3.16266640e-04 -2.26318371e-03 -3.55665397e-04
  -4.30818601e-03  1.32487039e-03]
 [ 5.42294187e-03 -9.96877556e-04 -2.50543375e-03  1.86141336e-03
  -2.10309494e-03  3.56446952e-04 -2.81781214e-03 -5.27597149e-04
  -5.11379959e-03  1.5