In [1]:
from typing import Callable

import torch
from torch import nn

In [2]:
def compute_new_hidden_state(
    previous_hidden_state: torch.Tensor,
    inputs: torch.Tensor,
    weight_hh: torch.Tensor,
    bias_hh: torch.Tensor,
    weight_ih: torch.Tensor,
    bias_ih: torch.Tensor,
    activation_func: Callable = torch.tanh,
) -> torch.Tensor:
    """
    Apply one recurrent update and return the resulting hidden state.

    The result is
    ``activation_func(h @ W_hh.T + b_hh + x @ W_ih.T + b_ih)``
    where ``h`` is ``previous_hidden_state`` and ``x`` is ``inputs``.

    Args:
        previous_hidden_state: Hidden state carried over from the previous
            step; right-multiplied by ``weight_hh.T``.
        inputs: Input for the current step; right-multiplied by
            ``weight_ih.T``.
        weight_hh: Hidden-to-hidden weight matrix.
        bias_hh: Bias added to the hidden-to-hidden term.
        weight_ih: Input-to-hidden weight matrix.
        bias_ih: Bias added to the input-to-hidden term.
        activation_func: Callable applied to the summed pre-activation.
            Defaults to ``torch.tanh``.

    Returns:
        The activated sum, i.e. the new hidden state.
    """
    recurrent_term = torch.matmul(previous_hidden_state, weight_hh.T)
    input_term = torch.matmul(inputs, weight_ih.T)
    # Same left-to-right summation order as the formula above.
    pre_activation = recurrent_term + bias_hh + input_term + bias_ih
    return activation_func(pre_activation)


class MyRNN(nn.Module):
    """
    Step-by-step re-implementation of the forward pass of a wrapped RNN.

    The wrapped module supplies ``num_layers``, ``hidden_size`` and the
    per-layer parameters ``weight_hh_l{k}``, ``bias_hh_l{k}``,
    ``weight_ih_l{k}`` and ``bias_ih_l{k}``. No parameters are copied, so
    this module always computes with the wrapped module's current weights.

    Limitations visible in the code: the input is indexed as
    ``(seq_length, batch_size, ...)`` (no ``batch_first`` handling), the
    bias parameters are always looked up (so they must exist), and only a
    single direction without dropout is computed.
    """

    def __init__(self, rnn):
        super().__init__()
        self.rnn = rnn

    def forward(
        self, X: torch.Tensor, hidden_0: torch.Tensor = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the recurrence over every step of ``X``.

        Args:
            X: Input sequence; dimension 0 is the sequence step and
                dimension 1 is the batch.
            hidden_0: Initial hidden state, indexed by layer along
                dimension 0. When ``None``, zeros of shape
                ``(num_layers, batch_size, hidden_size)`` are used, created
                with the dtype and device of ``X``.

        Returns:
            A tuple ``(output, hidden_n)``: ``output`` stacks the last
            layer's hidden state for every step, and ``hidden_n`` stacks
            every layer's hidden state after the final step.

        Raises:
            ValueError: If ``X`` contains no sequence steps.
        """
        seq_length = X.shape[0]
        batch_size = X.shape[1]
        if seq_length == 0:
            raise ValueError("X must contain at least one sequence step")

        if hidden_0 is None:
            # Follow the input's dtype/device rather than hard-coding
            # float64, so float32 (or non-CPU) models work out of the box.
            hidden_0 = torch.zeros(
                self.rnn.num_layers,
                batch_size,
                self.rnn.hidden_size,
                dtype=X.dtype,
                device=X.device,
            )

        # Use the same nonlinearity as the wrapped module when it exposes
        # one; fall back to tanh otherwise.
        if getattr(self.rnn, "nonlinearity", "tanh") == "relu":
            activation_func = torch.relu
        else:
            activation_func = torch.tanh

        # Look the parameters up once instead of on every step.
        layer_params = [
            {
                "weight_hh": self.rnn.get_parameter(f"weight_hh_l{layer}"),
                "bias_hh": self.rnn.get_parameter(f"bias_hh_l{layer}"),
                "weight_ih": self.rnn.get_parameter(f"weight_ih_l{layer}"),
                "bias_ih": self.rnn.get_parameter(f"bias_ih_l{layer}"),
            }
            for layer in range(self.rnn.num_layers)
        ]

        # Most recent hidden state of each layer. Kept in a Python list
        # (rather than written in place into a preallocated tensor) so that
        # tensors saved for the backward pass are never overwritten.
        layer_states = [hidden_0[layer] for layer in range(self.rnn.num_layers)]
        outputs = []

        for idx in range(seq_length):
            # Layer 0 reads from the sequence; every other layer reads the
            # hidden state just produced by the layer below it.
            inputs = X[idx]
            for layer, params in enumerate(layer_params):
                layer_states[layer] = compute_new_hidden_state(
                    previous_hidden_state=layer_states[layer],
                    inputs=inputs,
                    activation_func=activation_func,
                    **params,
                )
                inputs = layer_states[layer]
            # After the inner loop ``inputs`` is the top layer's state.
            outputs.append(inputs)

        return torch.stack(outputs), torch.stack(layer_states)

In [3]:
# Fix the seed so the random weights, inputs and initial hidden state (and
# therefore this check) are reproducible after a kernel restart.
torch.manual_seed(0)

input_size = 6
hidden_size = 4
num_layers = 3
bias = True
batch_first = False

# Reference implementation. float64 keeps the comparison below tight.
rnn = nn.RNN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    nonlinearity="tanh",  # default
    batch_first=batch_first,  # default
    bias=bias,
    dtype=torch.float64,
)

# Hand-written implementation sharing the reference module's parameters.
my_rnn = MyRNN(rnn)

seq_length = 13
batch_size = 7

# Random inputs
X = torch.randn(seq_length, batch_size, input_size, dtype=torch.float64)

# Random starting hidden state
hidden_0 = torch.randn(num_layers, batch_size, hidden_size, dtype=torch.float64)

# Check both an explicit initial hidden state and the default (None) path.
# No gradients are needed for a pure forward comparison.
with torch.no_grad():
    for initial_hidden in (hidden_0, None):
        expected_outputs = rnn(X, initial_hidden)
        actual_outputs = my_rnn(X, initial_hidden)
        assert len(actual_outputs) == len(expected_outputs)
        for rnn_output, my_rnn_output in zip(expected_outputs, actual_outputs):
            # assert_close also checks shape/dtype and reports the largest
            # mismatch on failure.
            torch.testing.assert_close(my_rnn_output, rnn_output)
print("Tests passed")

Tests passed
