In [67]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [68]:
# N = number of samples
# T = sequence length
# D = number of input features
# M = number of hidden units
# K = number of output units

In [69]:
# Synthetic Data
# Fixed seed so that Restart & Run All regenerates exactly the same inputs.
SEED = 0
rng = np.random.default_rng(SEED)

N = 1   # number of samples
T = 10  # sequence length
D = 3   # number of input features
M = 5   # number of hidden units
K = 2   # number of output units

# Standard-normal inputs laid out batch-first: (N, T, D)
X = rng.standard_normal((N, T, D))

In [70]:
# Make an RNN
class SimpleRNN(nn.Module):
  """Single-layer tanh RNN followed by a dense layer applied at every time step.

  Args:
    n_inputs: number of input features per time step (D).
    n_hidden: number of hidden units in the recurrent layer (M).
    n_outputs: number of output units of the dense layer (K).
  """

  def __init__(self, n_inputs, n_hidden, n_outputs):
    super().__init__()
    self.D = n_inputs
    self.M = n_hidden
    self.K = n_outputs
    self.rnn = nn.RNN(
        input_size=self.D,
        hidden_size=self.M,
        nonlinearity="tanh",
        batch_first=True
    )
    self.fc = nn.Linear(self.M, self.K)

  def forward(self, X):
    """Run the RNN over a batch of sequences.

    Args:
      X: batch-first input tensor of shape (N, T, D).

    Returns:
      Tensor of shape (N, T, K): the dense layer applied to the hidden
      state of every time step.
    """
    # Initial hidden state (1 x N x M). new_zeros inherits X's dtype and
    # device, so the module keeps working after .double() / .to(device).
    h0 = X.new_zeros(1, X.size(0), self.M)

    # Hidden states for all time steps; the final-state return value is unused
    out, _ = self.rnn(X, h0)

    # Pass every hidden state through the final dense layer
    return self.fc(out)

In [71]:
# Instantiate the model
# Seed torch first: the layer weights are randomly initialised, so without a
# seed every tensor printed below changes from run to run.
torch.manual_seed(0)
model = SimpleRNN(n_inputs=D, n_hidden=M, n_outputs=K)

In [72]:
# Forward pass: cast the NumPy data to float32 (the dtype of the model's
# parameters) before handing it to torch
X_float32 = X.astype(np.float32)
inputs = torch.from_numpy(X_float32)
out = model(inputs)
out

tensor([[[ 0.2815,  0.2051],
         [ 1.0332,  0.3559],
         [ 0.2034, -0.3969],
         [ 0.4116,  0.3697],
         [ 0.9502, -0.0609],
         [ 0.5670,  0.2924],
         [ 0.3360, -0.4081],
         [ 0.4898,  0.4840],
         [ 0.6353,  0.1821],
         [ 0.7645, -0.1035]]], grad_fn=<ViewBackward0>)

In [73]:
# out has shape (N, T, K): one K-dimensional prediction per sample per time step
out.shape

torch.Size([1, 10, 2])

In [74]:
# Keep PyTorch's predictions as a NumPy array for the comparison later on;
# detach() drops the autograd graph so the tensor can be converted
out_detached = out.detach()
Yhats_torch = out_detached.numpy()

In [75]:
# Grab the recurrent layer's parameters by attribute name
Wxh = model.rnn.weight_ih_l0  # input-to-hidden weights
Whh = model.rnn.weight_hh_l0  # hidden-to-hidden weights
bxh = model.rnn.bias_ih_l0    # input-to-hidden bias
bhh = model.rnn.bias_hh_l0    # hidden-to-hidden bias

In [76]:
# Input-to-hidden weights are stored as (M, D) = (hidden units, input features)
Wxh.shape

torch.Size([5, 3])

In [77]:
# Inspect the raw parameter: at this point it is still a torch Parameter that tracks gradients
Wxh

Parameter containing:
tensor([[ 0.1883,  0.3353, -0.2365],
        [-0.0915,  0.3221,  0.4183],
        [-0.3326,  0.0902, -0.2205],
        [ 0.3723, -0.0542, -0.1157],
        [-0.3244, -0.1606, -0.4423]], requires_grad=True)

In [78]:
# Convert the input-to-hidden weights to a NumPy array.
# Reading from the model (instead of rebinding Wxh from itself) keeps this cell
# re-runnable, and detach() is the supported replacement for the legacy .data.
Wxh = model.rnn.weight_ih_l0.detach().numpy()
Wxh

array([[ 0.1883184 ,  0.33534142, -0.23647648],
       [-0.09146273,  0.3221107 ,  0.4182537 ],
       [-0.3325894 ,  0.09023554, -0.22051273],
       [ 0.372298  , -0.05421357, -0.11572257],
       [-0.32436353, -0.16062526, -0.4423163 ]], dtype=float32)

In [79]:
# Convert the remaining recurrent-layer parameters to NumPy arrays.
# Each one is read from the model by name, so the cell can be re-run safely
# (the original self-rebinding failed the second time it was executed).
bxh = model.rnn.bias_ih_l0.detach().numpy()
bhh = model.rnn.bias_hh_l0.detach().numpy()
Whh = model.rnn.weight_hh_l0.detach().numpy()

In [80]:
# Shapes: Wxh is (M, D), Whh is (M, M), and each bias is a length-M vector
Wxh.shape, Whh.shape, bxh.shape, bhh.shape

((5, 3), (5, 5), (5,), (5,))

In [81]:
# Dense (fully connected) output layer: weight matrix and bias vector
Wo = model.fc.weight
bo = model.fc.bias

In [82]:
# Convert the dense-layer parameters to NumPy arrays.
# Read them from the model by name so the cell stays re-runnable, and use
# detach() rather than the legacy .data attribute.
Wo = model.fc.weight.detach().numpy()
bo = model.fc.bias.detach().numpy()
Wo.shape, bo.shape

((2, 5), (2,))

In [83]:
# Re-create the PyTorch forward pass by hand with NumPy
x = X[0]                  # the one and only sample, shape (T, D)
Yhats = np.zeros((T, K))  # one K-dimensional output per time step
h_last = np.zeros(M)      # zero initial hidden state, matching h0 in the model

for t, x_t in enumerate(x):
  # Recurrent update: h_t = tanh(x_t Wxh^T + bxh + h_{t-1} Whh^T + bhh)
  h = np.tanh(np.dot(x_t, Wxh.T) + bxh + np.dot(h_last, Whh.T) + bhh)

  # Dense readout; the output of every step is stored, not just the last one
  y = np.dot(h, Wo.T) + bo
  Yhats[t] = y

  # Carry the hidden state forward to the next time step
  h_last = h

print(Yhats)

[[ 0.2814659   0.20505499]
 [ 1.03323575  0.35593529]
 [ 0.20338952 -0.39690347]
 [ 0.411601    0.36973532]
 [ 0.95017998 -0.06087913]
 [ 0.566959    0.29240908]
 [ 0.3360443  -0.40814406]
 [ 0.48977529  0.48398769]
 [ 0.63526584  0.18214776]
 [ 0.76447335 -0.10353298]]


In [84]:
# Check the hand-rolled NumPy outputs against PyTorch's for the single sample.
# Yhats is (T, K) while Yhats_torch is (N, T, K), so index the batch axis
# explicitly instead of relying on broadcasting. PyTorch computes in float32,
# so use an absolute tolerance suited to that precision; the default
# atol=1e-8 can reject correct results when an output is close to zero.
np.allclose(Yhats, Yhats_torch[0], atol=1e-6)

True