In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Sanity check: re-implement the forward pass of a multi-layer tanh nn.RNN by
# hand and confirm that it reproduces PyTorch's result, both for the per-step
# output of the top layer and for the final hidden state of every layer.

# Define RNN parameters
input_size = 10
hidden_size = 20
num_layers = 2
seq_len = 5
batch_size = 3

# Largest absolute deviation still accepted as float32 round-off noise.
tolerance = 1e-5

# Create an RNN instance
rnn = nn.RNN(input_size, hidden_size, num_layers, nonlinearity="tanh", batch_first=False)

# Generate random input, laid out as (seq_len, batch_size, input_size)
# because the RNN was built with batch_first=False.
x = torch.randn(seq_len, batch_size, input_size)

# This is a pure numerical comparison, so both the reference call and the
# manual recurrence run without autograd bookkeeping. This keeps
# `manual_output` and `rnn_out` on equal footing (neither requires grad).
with torch.no_grad():
    # Get PyTorch RNN output. No initial state is passed, so nn.RNN starts
    # from an all-zero hidden state.
    rnn_out, rnn_hidden = rnn(x)

    # Manual implementation
    # Hidden state of every layer at the previous time step; zeros to mirror
    # nn.RNN's default initial state.
    h_t_minus_1 = torch.zeros(num_layers, batch_size, hidden_size)
    h_t = h_t_minus_1

    # The parameters do not change during the loop, so look them up once per
    # layer instead of once per (time step, layer) pair.
    layer_params = [
        (
            getattr(rnn, f'weight_ih_l{layer}'),
            getattr(rnn, f'bias_ih_l{layer}'),
            getattr(rnn, f'weight_hh_l{layer}'),
            getattr(rnn, f'bias_hh_l{layer}'),
        )
        for layer in range(num_layers)
    ]

    manual_output = []

    for t in range(seq_len):
        h_t_new = []
        for layer, (weight_ih, bias_ih, weight_hh, bias_hh) in enumerate(layer_params):
            # Layer 0 reads the input sequence; every deeper layer reads the
            # hidden state the layer below has just produced for this step.
            xin = x[t] if layer == 0 else h_t_new[layer - 1]

            # h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh)
            h_layer = torch.tanh(
                xin @ weight_ih.T
                + bias_ih
                + h_t_minus_1[layer] @ weight_hh.T
                + bias_hh
            )

            h_t_new.append(h_layer)

        h_t = torch.stack(h_t_new)  # Shape: (num_layers, batch_size, hidden_size)
        manual_output.append(h_t[-1])  # Last layer output

        h_t_minus_1 = h_t

    manual_output = torch.stack(manual_output)  # Shape: (seq_len, batch_size, hidden_size)

    # Compare outputs
    diff = torch.abs(manual_output - rnn_out)
    # After the loop `h_t` holds the state of every layer at the last time
    # step, which is what nn.RNN returns as its second value.
    hidden_diff = torch.abs(h_t - rnn_hidden)

max_diff = diff.max().item()
max_hidden_diff = hidden_diff.max().item()
print(f"Max difference: {max_diff}")
print(f"Max hidden-state difference: {max_hidden_diff}")

if max_diff > tolerance or max_hidden_diff > tolerance:
    print("Outputs do not match. Investigating...")
    for t in range(seq_len):
        print(f"Time step {t}: max difference = {diff[t].max().item()}")
    for layer in range(num_layers):
        print(f"Layer {layer}: max hidden-state difference = {hidden_diff[layer].max().item()}")