In this notebook, we will implement the forward function of a simple RNN and compare it with the PyTorch implementation in torch.nn.RNN to verify that we get the same results. First, recall that an RNN needs two sets of weight matrices and bias vectors: one to get from the input layer to the hidden layer, and one to get from the hidden layer of the previous time step to the hidden layer of the current time step. Let us start by fixing the dimensions of the input layer and the hidden layer.

In [1]:
import torch
d_hidden = 3
d_in = 5

The matrix $W_{ih}$ from the input layer to the hidden layer therefore needs to map from 5 dimensions to 3 dimensions, i.e. it has shape (3, 5). Let us initialize it with random values, along with the corresponding bias vector.

In [2]:
w_ih = torch.randn((d_hidden, d_in))
b_ih = torch.randn(d_hidden)

The matrix $W_{hh}$ that maps between the values of the hidden layer at consecutive time steps is of course square, with shape (3, 3). Let us initialize it as well, along with the respective bias.

In [3]:
w_hh = torch.randn((d_hidden, d_hidden))
b_hh = torch.randn(d_hidden)

We can now implement the forward function of the network. Recall that at each time step $t$, we need to apply the formula
$$
h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1} W_{hh}^T + b_{hh})
$$
where the superscript $T$ denotes the transpose (not the time step). In addition, we want our forward function to optionally accept the hidden layer from a previous step, starting from zeros if none is given, and to return the new value of the hidden layer along with the output.

In [4]:
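def forward(x, h = None):
    # the first dimension of x is the sequence dimension
    L = x.shape[0]
    if h is None:
        # start from a zero hidden layer if none is passed in
        h = torch.zeros(d_hidden)
    out = []
    for t in range(L):
        # apply the recurrence, using the global weights and biases
        h = torch.tanh(x[t] @ w_ih.t() + b_ih + h @ w_hh.t() + b_hh)
        out.append(h)
    # out holds the hidden layer of every time step, h is the last one
    return torch.stack(out), h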
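
The forward function above reads w_ih, b_ih, w_hh and b_hh from the enclosing scope. As a minimal sketch, the same recurrence can also be written with the parameters passed in explicitly. The name rnn_forward and its signature are assumptions made for this illustration and are not part of the notebook.

```python
import torch

def rnn_forward(x, w_ih, b_ih, w_hh, b_hh, h=None):
    """Hypothetical variant of forward that takes its parameters as arguments."""
    if h is None:
        # Infer the hidden size from the bias instead of a global variable
        h = torch.zeros(b_hh.shape[0])
    out = []
    for x_t in x:  # iterate over the first (sequence) dimension
        h = torch.tanh(x_t @ w_ih.t() + b_ih + h @ w_hh.t() + b_hh)
        out.append(h)
    return torch.stack(out), h

# Usage with freshly drawn random parameters
torch.manual_seed(0)
d_in, d_hidden, L = 5, 3, 4
params = (torch.randn(d_hidden, d_in), torch.randn(d_hidden),
          torch.randn(d_hidden, d_hidden), torch.randn(d_hidden))
out, h = rnn_forward(torch.randn(L, d_in), *params)
print(out.shape, h.shape)  # torch.Size([4, 3]) torch.Size([3])
```

Passing the parameters explicitly avoids surprises when the global variables are reassigned between cells.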

Let us run this for a simple example.

In [5]:
L = 5
x = torch.randn((L, d_in))
out, hidden = forward(x)
print(out)
print(hidden)

tensor([[-0.4771,  0.8883,  0.8964],
        [ 0.7838, -0.9487, -0.9940],
        [ 0.9864, -0.6825, -0.9588],
        [ 0.9998, -0.6754,  0.8989],
        [ 0.9993,  0.9369, -0.9899]])
tensor([ 0.9993,  0.9369, -0.9899])


To verify that this is correct, let us compare this to the implementation that comes with PyTorch. For that purpose, we initialize a PyTorch RNN, replace our random weights and biases with those of the RNN, apply both the RNN and our forward function to x, and compare.

In [6]:
rnn = torch.nn.RNN(input_size = d_in, hidden_size = d_hidden)
w_hh = rnn.weight_hh_l0
w_ih = rnn.weight_ih_l0
b_ih = rnn.bias_ih_l0
b_hh = rnn.bias_hh_l0
assert w_hh.shape == (d_hidden, d_hidden)
assert w_ih.shape == (d_hidden, d_in)
_out, _hidden = rnn(x)
out, hidden = forward(x)
print(f"Match of outputs: {torch.allclose(_out, out)}")
print(f"Match of hidden layers: {torch.allclose(_hidden, hidden)}")

Match of outputs: True
Match of hidden layers: True
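
One detail is easy to miss in this comparison: for an unbatched input, torch.nn.RNN returns its final hidden state with an additional leading dimension (one entry per layer and direction), whereas our forward function returns a plain vector. The comparison still succeeds because the two shapes are broadcastable. The following sketch uses its own small RNN and input, not the notebook's variables, to make the shapes visible.

```python
import torch

torch.manual_seed(0)
d_in, d_hidden, L = 5, 3, 5
rnn = torch.nn.RNN(input_size=d_in, hidden_size=d_hidden)
x = torch.randn(L, d_in)  # unbatched input: (sequence length, input size)

with torch.no_grad():
    out, h_n = rnn(x)

print(out.shape)  # torch.Size([5, 3])
print(h_n.shape)  # torch.Size([1, 3])

# For a single-layer, unidirectional RNN, the last output row is the final hidden state
last_matches = torch.allclose(out[-1], h_n[0])
print(last_matches)
```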


Let us do one more time step with a new input, this time passing the previously computed hidden values back into the model.

In [7]:
x = torch.randn(1, d_in)
_out, _hidden = rnn(x, _hidden)
out, hidden = forward(x, hidden)
print(f"Match of outputs: {torch.allclose(_out, out)}")
print(f"Match of hidden layers: {torch.allclose(_hidden, hidden)}")

Match of outputs: True
Match of hidden layers: True
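
Passing the hidden values back in is what makes it possible to process a long sequence in pieces. As a sketch (with its own RNN and an arbitrary split point, both chosen purely for illustration), running the first four steps and then the remaining two with the carried-over hidden state should reproduce a single pass over all six steps.

```python
import torch

torch.manual_seed(0)
rnn = torch.nn.RNN(input_size=5, hidden_size=3)
x = torch.randn(6, 5)  # six time steps, unbatched

with torch.no_grad():
    full_out, full_h = rnn(x)        # one pass over the whole sequence
    out1, h1 = rnn(x[:4])            # first chunk, starting from zeros
    out2, h2 = rnn(x[4:], h1)        # second chunk, continuing from h1

chunked_out = torch.cat([out1, out2])
same_out = torch.allclose(full_out, chunked_out, atol=1e-6)
same_hidden = torch.allclose(full_h, h2, atol=1e-6)
print(same_out, same_hidden)
```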


A word about batching. The way we have extracted the value of the current time step from the input x (simply indexing by t) works as long as the sequence dimension is the first dimension. Therefore, the batch dimension needs to be the second dimension to make this work. Luckily, this is also the layout that PyTorch expects by default for an RNN (batch_first = False). Let us repeat our check with batched input.

In [9]:
B = 4
L = 5
x = torch.randn(L, B, d_in)
_out, _hidden = rnn(x)
out, hidden = forward(x)
print(f"Match of outputs: {torch.allclose(_out, out)}")
print(f"Match of hidden layers: {torch.allclose(_hidden, hidden)}")

Match of outputs: True
Match of hidden layers: True
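
If the data comes with the batch dimension first, torch.nn.RNN can be created with batch_first = True. The sketch below, which builds its own two RNNs with shared weights for illustration, suggests how the two layouts relate: the outputs are transposes of each other in the first two dimensions, while the returned hidden state keeps its layout of (layers, batch, hidden size) in both cases.

```python
import torch

torch.manual_seed(0)
d_in, d_hidden, L, B = 5, 3, 5, 4
rnn = torch.nn.RNN(input_size=d_in, hidden_size=d_hidden)
rnn_bf = torch.nn.RNN(input_size=d_in, hidden_size=d_hidden, batch_first=True)
rnn_bf.load_state_dict(rnn.state_dict())  # use identical weights in both RNNs

x = torch.randn(L, B, d_in)  # sequence dimension first

with torch.no_grad():
    out, h = rnn(x)
    out_bf, h_bf = rnn_bf(x.transpose(0, 1))  # batch dimension first

print(out.shape, out_bf.shape)  # torch.Size([5, 4, 3]) torch.Size([4, 5, 3])
print(h.shape, h_bf.shape)      # torch.Size([1, 4, 3]) torch.Size([1, 4, 3])
outputs_agree = torch.allclose(out.transpose(0, 1), out_bf, atol=1e-6)
print(outputs_agree)
```

Our forward function indexes the first dimension as time, so a batch-first input would need to be transposed before being passed to it.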
