In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# N = number of samples
# T = sequence length
# D = number of input features
# M = number of hidden units
# K = number of output units

In [3]:
# Synthetic Data
N = 10
T = 10
D = 3
M = 5
K = 2
X = np.random.randn(N, T, D)

In [4]:
# Make an RNN
class SimpleRNN(nn.Module):
  def __init__(self, n_inputs, n_hidden, n_outputs):
    super(SimpleRNN, self).__init__()
    self.D = n_inputs
    self.M = n_hidden
    self.K = n_outputs
    self.rnn = nn.RNN(
        input_size=self.D,
        hidden_size=self.M,
        nonlinearity="tanh",
        batch_first=True
    )
    self.fc = nn.Linear(self.M, self.K)

  def forward(self, X):
    # initial hidden states (1 x N x M)
    h0 = torch.zeros(1, X.size(0), self.M)

    # get RNN unit output
    out, _ = self.rnn(X, h0)

    # Take all the hidden states and pass it to the final dense layer
    out = self.fc(out)
    return out

In [5]:
# Instantiate the model
model = SimpleRNN(n_inputs=D, n_hidden=M, n_outputs=K)

In [6]:
# Get the output
inputs = torch.from_numpy(X.astype(np.float32))
out = model(inputs)
out

tensor([[[-0.3595, -0.4377],
         [-0.0964, -0.1917],
         [-0.0313, -0.5381],
         [-0.3981, -0.3038],
         [ 0.1957, -0.3274],
         [-0.4587, -0.3525],
         [-0.5403, -0.1035],
         [-0.5833, -0.5002],
         [-0.3776, -0.3689],
         [-0.7410, -0.3071]],

        [[-0.4235, -0.4075],
         [-0.4947, -0.2603],
         [-0.1921, -0.2827],
         [-0.2645, -0.2836],
         [-0.0574, -0.1189],
         [-0.1244, -0.2939],
         [-0.5746, -0.4755],
         [-0.5824, -0.3386],
         [-0.3307, -0.3745],
         [-0.1177, -0.3466]],

        [[ 0.2445, -0.3096],
         [-0.2507, -0.1250],
         [-0.5935, -0.4780],
         [-0.3377, -0.5198],
         [-0.2780, -0.2746],
         [-0.3523, -0.3557],
         [-0.6158, -0.4737],
         [-0.2584, -0.1584],
         [-0.0710, -0.4629],
         [-0.2711, -0.1994]],

        [[ 0.2459, -0.0601],
         [ 0.2271, -0.2342],
         [-0.2175, -0.3616],
         [-0.6630, -0.5427],
        

In [7]:
# out is of size (N x T x K)
out.shape

torch.Size([10, 10, 2])

In [8]:
Yhats_torch = out.detach().numpy()

In [9]:
Wxh, Whh, bxh, bhh = model.rnn.parameters()

In [10]:
Wxh.shape

torch.Size([5, 3])

In [11]:
Wxh

Parameter containing:
tensor([[-0.3948, -0.3095, -0.0984],
        [ 0.1909, -0.0851, -0.3178],
        [-0.3491, -0.1400, -0.1770],
        [ 0.3057, -0.4102, -0.4345],
        [ 0.2492,  0.0297, -0.2726]], requires_grad=True)

In [12]:
Wxh = Wxh.data.numpy()
Wxh

array([[-0.39480132, -0.30950844, -0.09838684],
       [ 0.19090745, -0.08509136, -0.31777355],
       [-0.3491203 , -0.14001457, -0.17701289],
       [ 0.3056773 , -0.41021654, -0.43447495],
       [ 0.2491674 ,  0.02969623, -0.27261898]], dtype=float32)

In [13]:
bxh = bxh.data.numpy()
bhh = bhh.data.numpy()
Whh = Whh.data.numpy()

In [14]:
Wxh.shape, Whh.shape, bxh.shape, bhh.shape

((5, 3), (5, 5), (5,), (5,))

In [15]:
# FC parameter weights
Wo, bo = model.fc.parameters()

In [16]:
Wo = Wo.data.numpy()
bo = bo.data.numpy()
Wo.shape, bo.shape

((2, 5), (2,))

In [33]:
# See if we can replicate the output
Yhats = np.zeros((N, T, K)) # the outputs
h_last = np.zeros(M) # initial hidden state
for i in range(N):
  x = X[i]
  for t in range(T):
    h = np.tanh(x[t].dot(Wxh.T) + bxh + h_last.dot(Whh.T) + bhh)
    y = h.dot(Wo.T) + bo # we only care about this value in the last iteration
    Yhats[t] = y

    h_last = h

print(Yhats)

[[[-0.5872389  -0.42979765]
  [-0.5872389  -0.42979765]
  [-0.5872389  -0.42979765]
  [-0.5872389  -0.42979765]
  [-0.5872389  -0.42979765]
  [-0.5872389  -0.42979765]
  [-0.5872389  -0.42979765]
  [-0.5872389  -0.42979765]
  [-0.5872389  -0.42979765]
  [-0.5872389  -0.42979765]]

 [[-0.39539058 -0.385592  ]
  [-0.39539058 -0.385592  ]
  [-0.39539058 -0.385592  ]
  [-0.39539058 -0.385592  ]
  [-0.39539058 -0.385592  ]
  [-0.39539058 -0.385592  ]
  [-0.39539058 -0.385592  ]
  [-0.39539058 -0.385592  ]
  [-0.39539058 -0.385592  ]
  [-0.39539058 -0.385592  ]]

 [[-0.04223876 -0.43245241]
  [-0.04223876 -0.43245241]
  [-0.04223876 -0.43245241]
  [-0.04223876 -0.43245241]
  [-0.04223876 -0.43245241]
  [-0.04223876 -0.43245241]
  [-0.04223876 -0.43245241]
  [-0.04223876 -0.43245241]
  [-0.04223876 -0.43245241]
  [-0.04223876 -0.43245241]]

 [[-0.45235257 -0.38318635]
  [-0.45235257 -0.38318635]
  [-0.45235257 -0.38318635]
  [-0.45235257 -0.38318635]
  [-0.45235257 -0.38318635]
  [-0.45235257

In [32]:
np.allclose(Yhats, Yhats_torch)

False