In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Notation used throughout (worth knowing by heart)
# N = number of samples
# T = sequence length
# D = number of input features
# M = number of hidden units (similar to hidden units in an ANN)
# K = number of output units (K > 1 is used both for multi-class classification and for multi-output regression)

In [3]:
# Make some data
# Seed both RNGs so that Restart & Run All reproduces the same X and the same
# randomly initialized model weights (the model is built with torch below).
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

N = 1   # number of samples
T = 10  # sequence length
D = 3   # input features per time step
M = 5   # hidden units
K = 2   # outputs per time step
X = rng.standard_normal((N, T, D))

In [4]:
# Make an RNN
class SimpleRNN(nn.Module):
  """Single-layer tanh RNN followed by a linear read-out applied at every time step.

  Args:
    n_inputs: D, number of input features per time step.
    n_hidden: M, number of hidden units in the RNN.
    n_outputs: K, number of outputs produced per time step.

  forward() takes batch-first input of shape (N, T, D) and returns (N, T, K).
  """

  def __init__(self, n_inputs, n_hidden, n_outputs):
    super().__init__()
    self.D = n_inputs
    self.M = n_hidden
    self.K = n_outputs
    self.rnn = nn.RNN(
        input_size=self.D,
        hidden_size=self.M,
        nonlinearity='tanh',
        batch_first=True)
    self.fc = nn.Linear(self.M, self.K)

  def forward(self, X):
    """Run the RNN over X of shape (N, T, D); return per-step outputs (N, T, K)."""
    # Initial hidden state h(0) = 0, shape (num_layers=1, N, M).
    # Built on X's device/dtype so the model still works after .to(device) or .double().
    h0 = torch.zeros(1, X.size(0), self.M, device=X.device, dtype=X.dtype)

    # out holds the hidden state h(t) for every time step: (N, T, M)
    out, _ = self.rnn(X, h0)

    # Apply the linear read-out at every time step -> (N, T, K).
    # For a many-to-one model that only uses h(T), use self.fc(out[:, -1, :]) instead.
    out = self.fc(out)
    return out

In [5]:
# Build the model: D input features -> M hidden units -> K outputs per time step
model = SimpleRNN(
    n_inputs=D,
    n_hidden=M,
    n_outputs=K,
)

In [6]:
# Display the raw input data, shape (N, T, D)
X

array([[[ 0.45559783, -0.16725984,  1.00291235],
        [ 0.77751218,  0.79736996, -0.86861649],
        [ 0.39083605, -0.38366753, -0.8010485 ],
        [-1.49520601, -0.15733797,  0.01884234],
        [-0.14028233, -2.31122115, -1.13235073],
        [ 0.74420014,  1.21497542,  0.47943072],
        [-0.69169127,  0.46438335, -0.94206014],
        [ 0.9346199 ,  0.2526913 ,  0.97838357],
        [-0.04884489,  0.04625612,  0.90359591],
        [ 1.10330945,  0.69241498, -1.05888735]]])

In [7]:
# Get the output
# The model's parameters are float32, so cast the float64 NumPy data before the forward pass
inputs = torch.from_numpy(X).float()
out = model(inputs)
out

tensor([[[ 0.0068,  0.2706],
         [-0.0469,  0.3847],
         [-0.0536,  0.4782],
         [-0.0869,  0.4981],
         [ 0.0505,  0.7588],
         [-0.0368,  0.2807],
         [-0.2325,  0.2890],
         [-0.0180,  0.1703],
         [-0.1668,  0.1577],
         [ 0.0773,  0.4717]]], grad_fn=<ViewBackward0>)

In [8]:
# (N, T, K): the read-out layer produces a K-dimensional output at every time step
out.shape

torch.Size([1, 10, 2])

In [9]:
# Save the PyTorch outputs as a NumPy array, for comparison with the manual pass later.
# detach() drops the autograd graph (out carries a grad_fn); cpu() is a no-op on CPU
# but keeps this working if the model and inputs are ever moved to a GPU.
Yhats_torch = out.detach().cpu().numpy()

In [10]:
# Grab the RNN's parameters by name instead of relying on the order of
# model.rnn.parameters(). For layer 0, PyTorch names them
# weight_ih_l0 (M x D), weight_hh_l0 (M x M), bias_ih_l0 (M,), bias_hh_l0 (M,).
W_xh = model.rnn.weight_ih_l0
W_hh = model.rnn.weight_hh_l0
b_xh = model.rnn.bias_ih_l0
b_hh = model.rnn.bias_hh_l0

In [11]:
# Input-to-hidden weights are stored as (M, D): the transpose of the D x M matrix
# in the x @ W convention, which is why the manual pass uses W_xh.T
W_xh.shape

torch.Size([5, 3])

In [12]:
# Hidden-to-hidden weights are (M, M), likewise stored transposed relative to
# h @ W, hence W_hh.T in the manual pass
W_hh.shape

torch.Size([5, 5])

In [13]:
# PyTorch keeps two separate biases per RNN layer: (1) the input-to-hidden bias, shape (M,)
b_xh

Parameter containing:
tensor([ 0.4368,  0.1357, -0.1267, -0.1542, -0.1866], requires_grad=True)

In [14]:
# PyTorch keeps two separate biases per RNN layer: (2) the hidden-to-hidden bias, shape (M,).
# Both are simply added inside the tanh, so only their sum affects the result.
b_hh

Parameter containing:
tensor([-0.2114,  0.0012, -0.4216, -0.0020,  0.3457], requires_grad=True)

In [15]:
# The raw input-to-hidden weight Parameter (a tensor with requires_grad=True)
W_xh

Parameter containing:
tensor([[ 0.2619, -0.4260,  0.1387],
        [ 0.1793,  0.2145,  0.1761],
        [ 0.4203,  0.3563,  0.2998],
        [-0.1127, -0.0014,  0.2368],
        [-0.1317, -0.0281, -0.2126]], requires_grad=True)

In [16]:
# NumPy version of the input-to-hidden weights for the manual forward pass.
# Read from the model rather than rebinding W_xh from itself, so re-running this
# cell is safe (W_xh.data.numpy() fails once W_xh is already an ndarray).
# Use .detach() rather than .data to get a tensor outside the autograd graph.
W_xh = model.rnn.weight_ih_l0.detach().numpy()
W_xh

array([[ 0.2619323 , -0.42599908,  0.13866633],
       [ 0.17934269,  0.21446586,  0.17613333],
       [ 0.42032683,  0.3563159 ,  0.29977322],
       [-0.11274099, -0.00143483,  0.2367993 ],
       [-0.1317223 , -0.02805439, -0.21260928]], dtype=float32)

In [17]:
# NumPy versions of the remaining RNN parameters, read straight from the model
# so this cell stays idempotent (safe to re-run).
b_xh = model.rnn.bias_ih_l0.detach().numpy()
W_hh = model.rnn.weight_hh_l0.detach().numpy()
b_hh = model.rnn.bias_hh_l0.detach().numpy()

In [18]:
# Sanity-check the extracted shapes: weights are (out_features, in_features), biases are (M,)
assert W_xh.shape == (M, D) and b_xh.shape == (M,), (W_xh.shape, b_xh.shape)
assert W_hh.shape == (M, M) and b_hh.shape == (M,), (W_hh.shape, b_hh.shape)
W_xh.shape, b_xh.shape, W_hh.shape, b_hh.shape

((5, 3), (5,), (5, 5), (5,))

In [19]:
# Now get the FC (read-out) layer weights, by name rather than by parameters() order
Wo, bo = model.fc.weight, model.fc.bias

In [20]:
# NumPy versions of the read-out layer: Wo is (K, M), bo is (K,).
# Read from the model so re-running this cell is safe.
Wo = model.fc.weight.detach().numpy()
bo = model.fc.bias.detach().numpy()
Wo.shape, bo.shape

((2, 5), (2,))

In [21]:
# Manual RNN forward pass in NumPy; see if we can replicate PyTorch's output:
#   h_t = tanh(x_t W_xh^T + b_xh + h_{t-1} W_hh^T + b_hh)
#   y_t = h_t Wo^T + bo
x = X[0]  # the one and only sample, shape (T, D)
h_last = np.zeros(M)  # initial hidden state h_0 = 0, matching the model's h0
Yhats = np.zeros((len(x), K))  # one output row per time step

for t, x_t in enumerate(x):
  h = np.tanh(x_t.dot(W_xh.T) + b_xh + h_last.dot(W_hh.T) + b_hh)
  # the model applies fc at every time step, so keep every y_t (not only the last)
  Yhats[t] = h.dot(Wo.T) + bo

  # important: carry the hidden state forward to the next step
  h_last = h

# print the outputs for all time steps
print(Yhats)

[[ 0.00684706  0.27062538]
 [-0.04693386  0.38468926]
 [-0.05355885  0.47819874]
 [-0.08685729  0.4981236 ]
 [ 0.05051763  0.75876122]
 [-0.03683864  0.28071162]
 [-0.23248158  0.28903397]
 [-0.01796208  0.17027288]
 [-0.16684348  0.15766129]
 [ 0.07731097  0.47172901]]


In [22]:
# Check: the manual NumPy pass should reproduce PyTorch's outputs for sample 0.
# Compare against Yhats_torch[0] with an explicit shape check: np.allclose broadcasts,
# so comparing (T, K) with (1, T, K) directly could hide a shape mismatch.
assert Yhats.shape == Yhats_torch[0].shape, (Yhats.shape, Yhats_torch.shape)
np.allclose(Yhats, Yhats_torch[0])

True

In [23]:
# Bonus exercise: calculate the output for multiple samples at once (N > 1)