## TensorFlow - RNN

In this notebook we implement the forward pass of an RNN layer by hand in NumPy and verify that it reproduces the output of TensorFlow's `SimpleRNN` layer.
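For reference, this is the recurrence that the NumPy loop further down implements. It is read directly off that code: $x_t$ is the input row at step $t$ (length $D$), $W_x$ is $D \times M$, $W_h$ is $M \times M$, $b_h$ has length $M$, and the Dense layer maps the last hidden state through $W_o$ ($M \times K$) and $b_o$ (length $K$). The hidden state starts at zero, and `tanh` is assumed to be the activation, since it is the `SimpleRNN` default.

$$h_t = \tanh\left(x_t W_x + h_{t-1} W_h + b_h\right), \qquad h_0 = 0, \qquad t = 1, \dots, T$$

$$\hat{y} = h_T W_o + b_o$$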

In [3]:
import numpy as np
import tensorflow as tf

from tensorflow.keras.layers import Input, SimpleRNN, Dense
from tensorflow.keras.models import Model

In [4]:
# Let's define these variables
# N = number of samples
# T = sequence length
# D = number of input features
# M = number of hidden units
# K = number of output units

np.random.seed(42)
N = 1
T = 10
D = 3
K = 2
X = np.random.randn(N, T, D)

In [5]:
# Make an RNN
M = 5 # number of hidden units
i = Input(shape=(T, D))
x = SimpleRNN(M)(i)
x = Dense(K)(x)

model = Model(i, x)

In [6]:
# Get the output
Yhat = model.predict(X)
print(Yhat)

[[ 0.22453962 -0.8722438 ]]


In [8]:
model.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 10, 3)]           0         
_________________________________________________________________
simple_rnn (SimpleRNN)       (None, 5)                 45        
_________________________________________________________________
dense (Dense)                (None, 2)                 12        
Total params: 57
Trainable params: 57
Non-trainable params: 0
_________________________________________________________________
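The parameter counts in the summary follow from the weight shapes. `SimpleRNN` has an input kernel ($D \times M$), a recurrent kernel ($M \times M$) and a bias ($M$). `Dense` has a weight matrix ($M \times K$) and a bias ($K$). The sketch below reproduces the arithmetic. The helper names are hypothetical, and both layers are assumed to use their default `use_bias=True`.

```python
# Hypothetical helpers: parameter counts for SimpleRNN(M) followed by Dense(K),
# assuming both layers keep the default use_bias=True.
def simple_rnn_params(D: int, M: int) -> int:
    # input kernel (D x M) + recurrent kernel (M x M) + bias (M)
    return D * M + M * M + M


def dense_params(M: int, K: int) -> int:
    # weight matrix (M x K) + bias (K)
    return M * K + K


D, M, K = 3, 5, 2
rnn = simple_rnn_params(D, M)    # 15 + 25 + 5
dense = dense_params(M, K)       # 10 + 2
print(rnn, dense, rnn + dense)   # 45 12 57
```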


In [9]:
# Check the shapes of the SimpleRNN weights (layers[0] is the InputLayer)
a, b, c = model.layers[1].get_weights()
print(a.shape, b.shape, c.shape)

(3, 5) (5, 5) (5,)


In [10]:
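# SimpleRNN weights: input kernel Wx (D, M), recurrent kernel Wh (M, M), bias bh (M,)
Wx, Wh, bh = model.layers[1].get_weights()
# Dense weights: Wo (M, K), bo (K,)
Wo, bo = model.layers[2].get_weights()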

In [11]:
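# Initialize the hidden state h_0 to zeros
h_last = np.zeros(M)
# Take the first (and only) sample; shape (T, D)
x = X[0]
Yhats = []

for t in range(T):
    # RNN recurrence: h_t = tanh(x_t Wx + h_{t-1} Wh + bh)
    h = np.tanh(x[t].dot(Wx) + h_last.dot(Wh) + bh)
    # Dense output computed from the hidden state at this step
    y = h.dot(Wo) + bo
    Yhats.append(y)

    # Carry the hidden state forward to the next time step
    h_last = h

# SimpleRNN (return_sequences=False) passes only the last hidden state
# to Dense, so the last entry is the one to compare with model.predict
print(Yhats[-1])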

[ 0.22453974 -0.87224379]


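This output matches the `SimpleRNN` model's prediction, `[[ 0.22453962 -0.8722438 ]]`, to about six decimal places. The tiny remaining difference comes from TensorFlow computing in float32, while this NumPy loop promotes to float64 because `X` is float64.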
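The loop above handles one sample at a time. The same recurrence can be applied to a whole batch of shape `(N, T, D)` by keeping one hidden-state row per sample. The sketch below is a minimal, NumPy-only illustration. The function name `simple_rnn_forward` is hypothetical, and the weights are random stand-ins rather than the trained Keras weights, so it checks only that the batched and per-sample computations agree.

```python
import numpy as np


def simple_rnn_forward(X, Wx, Wh, bh):
    """Return the final hidden state for each sample.

    X: (N, T, D), Wx: (D, M), Wh: (M, M), bh: (M,) -> returns (N, M).
    """
    N, T, _ = X.shape
    h = np.zeros((N, Wh.shape[0]))  # h_0 = 0, one row per sample
    for t in range(T):
        # Same recurrence as the per-sample loop, applied to all samples at once
        h = np.tanh(X[:, t, :] @ Wx + h @ Wh + bh)
    return h


# Random stand-in weights (assumption: not the Keras model's weights)
rng = np.random.default_rng(0)
N, T, D, M, K = 4, 10, 3, 5, 2
X = rng.standard_normal((N, T, D))
Wx = rng.standard_normal((D, M))
Wh = rng.standard_normal((M, M))
bh = rng.standard_normal(M)
Wo = rng.standard_normal((M, K))
bo = rng.standard_normal(K)

# Batched forward pass followed by the Dense layer: shape (N, K)
Y_batched = simple_rnn_forward(X, Wx, Wh, bh) @ Wo + bo

# Reference: the one-sample-at-a-time loop, repeated for every sample
Y_loop = []
for n in range(N):
    h_last = np.zeros(M)
    for t in range(T):
        h_last = np.tanh(X[n, t].dot(Wx) + h_last.dot(Wh) + bh)
    Y_loop.append(h_last.dot(Wo) + bo)
Y_loop = np.array(Y_loop)

print(Y_batched.shape, np.allclose(Y_batched, Y_loop))  # (4, 2) True
```

Comparing with `np.allclose` rather than by eye also works for the Keras check above, for example `np.allclose(Yhat[0], Yhats[-1], atol=1e-5)`. The tolerance is needed because of the float32 versus float64 difference.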