### Multilayer RNN from scratch

In [1]:
from typing import Callable

import torch
from torch import nn

In [2]:
def compute_new_hidden_state(
    previous_hidden_state: torch.Tensor,
    inputs: torch.Tensor,
    weight_hh: torch.Tensor,
    bias_hh: torch.Tensor,
    weight_ih: torch.Tensor,
    bias_ih: torch.Tensor,
    activation_func: Callable = torch.tanh,
) -> torch.Tensor:
    """
    Apply one recurrent update and return the resulting hidden state

    The result is
    ``activation_func(h @ weight_hh.T + bias_hh + x @ weight_ih.T + bias_ih)``
    where ``h`` is ``previous_hidden_state`` and ``x`` is ``inputs``.
    Both weight matrices are transposed before the multiplication.
    """
    # Contribution of the previous hidden state (recurrent connection).
    recurrent_term = torch.matmul(previous_hidden_state, weight_hh.T)
    # Contribution of the current input (feed-forward connection).
    input_term = torch.matmul(inputs, weight_ih.T)

    # Terms are summed left to right in this fixed order.
    pre_activation = recurrent_term + bias_hh + input_term + bias_ih
    return activation_func(pre_activation)


class MyRNN(nn.Module):
    """
    From-scratch forward pass of a multilayer RNN

    Reuses the parameters of an existing RNN module and recomputes its
    forward pass one time step and one layer at a time.

    Only the sequence-first layout is handled: ``X`` is indexed as
    ``X[time_step]`` and its second dimension is taken as the batch size.
    """

    def __init__(self, rnn):
        """
        Copy the configuration of ``rnn`` and keep references to its weights

        ``rnn`` must expose ``num_layers``, ``hidden_size``,
        ``nonlinearity`` and ``named_parameters()``; parameters are looked
        up under the names ``weight_ih_l{k}``, ``weight_hh_l{k}``,
        ``bias_ih_l{k}`` and ``bias_hh_l{k}``.
        """
        super().__init__()
        self.num_layers = rnn.num_layers
        self.hidden_size = rnn.hidden_size
        # Plain dict holding the very same parameter objects as `rnn`
        # (no copy is made).
        self.params = dict(rnn.named_parameters())
        # Anything other than "relu" falls back to tanh.
        self.activation_func = (
            torch.relu if rnn.nonlinearity == "relu" else torch.tanh
        )

    def forward(
        self, X: torch.Tensor, hidden_0: "torch.Tensor | None" = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the RNN over a whole sequence

        X: inputs indexed as ``X[time_step]``; ``X.shape[1]`` is used as
            the batch size.
        hidden_0: starting hidden state, indexed as ``hidden_0[layer]``.
            Defaults to zeros with the dtype and device of ``X``.

        Returns ``(output, h_n)`` where ``output`` stacks the last layer's
        hidden state for every time step and ``h_n`` stacks the hidden
        state of every layer after the final time step.

        Raises ValueError if ``X`` contains no time steps.
        """
        seq_length, batch_size = X.shape[0], X.shape[1]
        if seq_length == 0:
            raise ValueError("X must contain at least one time step")

        if hidden_0 is None:
            # Follow the input's dtype/device rather than hard-coding
            # float64, so float32 or non-CPU parameters work too.
            hidden_0 = torch.zeros(
                self.num_layers,
                batch_size,
                self.hidden_size,
                dtype=X.dtype,
                device=X.device,
            )

        # Parameters do not change between time steps: look them up once.
        # Missing biases (module built with bias=False) are replaced by 0.
        layer_params = [
            {
                "weight_hh": self.params.get(f"weight_hh_l{layer}"),
                "bias_hh": self.params.get(f"bias_hh_l{layer}", 0),
                "weight_ih": self.params.get(f"weight_ih_l{layer}"),
                "bias_ih": self.params.get(f"bias_ih_l{layer}", 0),
            }
            for layer in range(self.num_layers)
        ]

        # Latest hidden state of each layer. Kept in a list (instead of
        # writing into a preallocated tensor) so that no tensor that has
        # already been read is later modified in place.
        current_states = [hidden_0[layer] for layer in range(self.num_layers)]
        last_layer_outputs = []

        for step_input in X.unbind(0):
            # Layer 0 reads the sequence input; every deeper layer reads
            # the state just produced by the layer below at this step.
            layer_input = step_input
            for layer, params in enumerate(layer_params):
                current_states[layer] = compute_new_hidden_state(
                    previous_hidden_state=current_states[layer],
                    inputs=layer_input,
                    activation_func=self.activation_func,
                    **params,
                )
                layer_input = current_states[layer]
            last_layer_outputs.append(layer_input)

        return torch.stack(last_layer_outputs), torch.stack(current_states)

In [3]:
def test_rnn(
    input_size: int,
    hidden_size: int,
    num_layers: int,
    seq_length: int,
    batch_size: int,
    nonlinearity: str = "tanh",
    bias: bool = True,
) -> None:
    """
    Check MyRNN against nn.RNN on random data

    Builds a float64 nn.RNN with the given configuration, wraps it in
    MyRNN, feeds both the same random inputs and initial hidden state,
    and compares each pair of returned tensors.

    Returns None when every pair agrees to within the default
    tolerances, raises an AssertionError otherwise
    """
    reference = nn.RNN(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        nonlinearity=nonlinearity,
        bias=bias,
        dtype=torch.float64,
    )
    candidate = MyRNN(reference)

    # Random sequence and random starting hidden state
    X = torch.randn(seq_length, batch_size, input_size, dtype=torch.float64)
    hidden_0 = torch.randn(
        num_layers, batch_size, hidden_size, dtype=torch.float64
    )

    # Run the reference first, then the from-scratch version
    expected = reference(X, hidden_0)
    actual = candidate(X, hidden_0)

    # Compare the returned tensors pairwise
    for want, got in zip(expected, actual):
        assert torch.allclose(want, got)


# Fixed problem dimensions shared by every test run
input_size = 6
hidden_size = 4
seq_length = 13
batch_size = 7

# Run the comparison for every combination of depth, bias setting and
# activation, in the same order as three nested loops would visit them.
for num_layers, bias, nonlinearity in [
    (depth, use_bias, activation)
    for depth in [1, 2, 3]
    for use_bias in [True, False]
    for activation in ("relu", "tanh")
]:
    test_rnn(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        seq_length=seq_length,
        batch_size=batch_size,
        nonlinearity=nonlinearity,
        bias=bias,
    )