In [2]:
from keras.models import Model
from keras.layers import Input, LSTM, GRU
import numpy as np
import matplotlib.pyplot as plt

import keras.backend as K
# Use the cuDNN-fused RNN layers when a GPU is visible, rebinding the LSTM/GRU
# names imported from keras.layers. `_get_available_gpus` is a private helper of
# the Keras TensorFlow backend, and `K.tensorflow_backend` may be absent (e.g. on
# a non-TensorFlow backend or a different Keras version), so both are looked up
# defensively: if either is missing, the generic layers are kept instead of the
# cell failing with AttributeError.
# NOTE(review): the cuDNN layers may not reproduce the generic layers' numbers
# exactly — confirm before comparing printed outputs across machines.
_tf_backend = getattr(K, "tensorflow_backend", None)
_get_gpus = getattr(_tf_backend, "_get_available_gpus", None)
if _get_gpus is not None and len(_get_gpus()) > 0:
    from keras.layers import CuDNNLSTM as LSTM
    from keras.layers import CuDNNGRU as GRU

In [3]:
# Toy dimensions for inspecting what Keras recurrent layers return.
T = 8  # sequence length (number of time steps)
D = 2  # input feature dimension per time step
M = 3  # number of hidden units in each RNN layer
SEED = 1234  # fixed seed so X (and hence the printed outputs) is reproducible

# Seed the legacy global RNG (not a local Generator) so any other consumer of
# np.random's global state is seeded as well.
np.random.seed(SEED)

# A single input sequence of shape (1, T, D): batch of one.
X = np.random.randn(1, T, D)

In [4]:
def _run_rnn(rnn, state_names):
    """Wrap `rnn` in a one-input Keras model, run it on the global X, print each output.

    The layer is applied to an ``Input`` of shape ``(T, D)``, wrapped in a
    ``Model`` and evaluated with ``model.predict(X)``. The layer's main output
    is printed as ``o:``, followed by one line per returned state.

    Args:
        rnn: a Keras recurrent layer constructed with ``return_state=True``.
        state_names: labels for the state arrays the model returns after its
            main output, in order (``("h", "c")`` for LSTM, ``("h",)`` for GRU).

    Raises:
        ValueError: if the model does not return exactly
            ``1 + len(state_names)`` arrays (mirrors what explicit tuple
            unpacking such as ``o, h, c = ...`` would raise).
    """
    input_ = Input(shape=(T, D))
    model = Model(input_, rnn(input_))
    outputs = model.predict(X)

    labels = ("o",) + tuple(state_names)
    if len(outputs) != len(labels):
        raise ValueError(
            "expected %d outputs (%s), got %d"
            % (len(labels), ", ".join(labels), len(outputs))
        )
    for label, value in zip(labels, outputs):
        print(label + ":", value)


def lstm1():
    """LSTM with return_state=True, return_sequences=False.

    Prints the layer output o and the final hidden/cell states h and c.
    Without return_sequences, o is only the last time step's output, which is
    why it coincides with h.
    """
    _run_rnn(LSTM(M, return_state=True), ("h", "c"))


def lstm2():
    """LSTM with return_state=True, return_sequences=True.

    Prints the full output sequence o (one row per time step) and the final
    states h and c; h equals the last row of o.
    """
    _run_rnn(LSTM(M, return_state=True, return_sequences=True), ("h", "c"))


def gru1():
    """GRU with return_state=True, return_sequences=False.

    A GRU has no separate cell state, so only o and the final hidden state h
    are printed; o is the last time step's output and coincides with h.
    """
    _run_rnn(GRU(M, return_state=True), ("h",))


def gru2():
    """GRU with return_state=True, return_sequences=True.

    Prints the full output sequence o (one row per time step) and the final
    hidden state h, which equals the last row of o.
    """
    _run_rnn(GRU(M, return_state=True, return_sequences=True), ("h",))

In [5]:
# Run every RNN demo in turn, each preceded by a header line naming it.
for header, demo in (
    ("lstm1:", lstm1),
    ("lstm2:", lstm2),
    ("gru1:", gru1),
    ("gru2:", gru2),
):
    print(header)
    demo()

lstm1:
o: [[ 0.03838445  0.13209935 -0.10795344]]
h: [[ 0.03838445  0.13209935 -0.10795344]]
c: [[ 0.05685591  0.20994118 -0.14142275]]
lstm2:
o: [[[ 0.11599305  0.1256925  -0.10781699]
  [ 0.11521567  0.17942551 -0.10310015]
  [ 0.0266189   0.18865916 -0.05122662]
  [-0.06087825 -0.09563553  0.06982227]
  [-0.33823293 -0.24805805  0.21074349]
  [-0.05284936 -0.15866783  0.08730863]
  [-0.04282744 -0.17751427  0.06523596]
  [ 0.06137247  0.02683462 -0.03220469]]]
h: [[ 0.06137247  0.02683462 -0.03220469]]
c: [[ 0.18248907  0.06017476 -0.0766963 ]]
gru1:
o: [[ 0.06176398 -0.30909953 -0.10846783]]
h: [[ 0.06176398 -0.30909953 -0.10846783]]
gru2:
o: [[[ 0.21512292  0.2948208  -0.00963136]
  [ 0.34962517  0.290185   -0.00330868]
  [ 0.08177797  0.1534144  -0.00652864]
  [-0.6382818  -0.1426416  -0.00543295]
  [-0.5924832   0.07882399 -0.05546368]
  [-0.4101083  -0.0459121  -0.07528275]
  [-0.18622403  0.11382141 -0.06107236]
  [ 0.02845139  0.19716881 -0.03319811]]]
h: [[ 0.02845139  0.197