#### Implementation of a plain RNN basic operations equivalent to Pytorch RNN

In [1]:
import torch
from torch import nn
rnn = nn.RNN(input_size=4, hidden_size=8, num_layers=1, nonlinearity='relu')
input = torch.rand(size=(5, 4))

In [4]:
print('Weight input-hidden', rnn.weight_ih_l0.shape)
print('Weight hidden-hidden', rnn.weight_hh_l0.shape)
print('Bias input-hidden', rnn.bias_ih_l0.shape)
print('Bias hidden-hidden', rnn.bias_hh_l0.shape)

Weight input-hidden torch.Size([8, 4])
Weight hidden-hidden torch.Size([8, 8])
Bias input-hidden torch.Size([8])
Bias hidden-hidden torch.Size([8])


##### Output calculation
last_h is the last hidden state of the RNN. It is the same that the last element of the hidden_states list (output)

In [3]:
output, last_h = rnn(input)
last_h

tensor([[0.2044, 0.3398, 0.1788, 0.0464, 0.5647, 0.0000, 0.0000, 0.0000]],
       grad_fn=<SqueezeBackward1>)

#### Classic calculation (Pytorch way)
It uses 2 matrices: one for the input and one for the hidden state. The output is the activation of the sum of the bias plus matrix multiplication of the input and the hidden state.

In [7]:
last_hidden = torch.zeros((8,))

for i in range(input.shape[0]):
    ih = torch.matmul(input[i], rnn.weight_ih_l0.T)
    hh = torch.matmul(last_hidden, rnn.weight_hh_l0.T)

    ih = ih + rnn.bias_ih_l0
    hh = hh + rnn.bias_hh_l0

    h0 = torch.relu(ih + hh)
    last_hidden = h0 

display(last_hidden)   
print("Final hidden state is the same: ", torch.allclose(last_hidden, last_h.squeeze()))

tensor([0.2044, 0.3398, 0.1788, 0.0464, 0.5647, 0.0000, 0.0000, 0.0000],
       grad_fn=<ReluBackward0>)

Final hidden state is the same:  True


##### Alternative calculation (One single matrix)

It uses a single matrix to calculate the output. The matrix is the concatenation of the input and hidden state matrices. The output is the activation of the sum of the bias plus matrix multiplication of the input and the hidden state.

In [8]:

single_weight_matrix = torch.cat([rnn.weight_ih_l0, rnn.weight_hh_l0], dim=1)
all_b = rnn.bias_ih_l0 + rnn.bias_hh_l0
last_hidden_single = torch.zeros((8,))

for i in range(input.shape[0]):
    input_i = torch.cat([input[i].reshape(1, 4), last_hidden_single.reshape(1, 8)], dim=1)
    ih = torch.matmul(input_i, single_weight_matrix.T)
    h0 = torch.relu(ih + all_b)
    last_hidden_single = h0 

display(last_hidden_single)
print("Final hidden state is the same: ", torch.allclose(last_hidden_single, last_h.squeeze()))

tensor([[0.2044, 0.3398, 0.1788, 0.0464, 0.5647, 0.0000, 0.0000, 0.0000]],
       grad_fn=<ReluBackward0>)

Final hidden state is the same:  True
