In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Things you should automatically know and have memorized
# N = number of samples
# T = sequence length
# D = number of input features
# M = number of hidden units
# K = number of output units

In [3]:
# Make some data
# A fixed seed makes the random inputs identical after Restart Kernel -> Run All.
# NOTE: outputs recorded before this seed was added will not match a fresh run.
SEED = 0
N = 1   # number of samples
T = 10  # sequence length
D = 3   # number of input features
M = 5   # number of hidden units
K = 2   # number of output units
rng = np.random.default_rng(SEED)
X = rng.standard_normal((N, T, D))  # standard-normal inputs, shape (N, T, D)

In [4]:
# sanity check: the data should be laid out as (N, T, D)
np.shape(X)

(1, 10, 3)

In [5]:
# Make an RNN
class SimpleRNN(nn.Module):
    """Single-layer tanh RNN followed by a linear layer applied at every time step.

    Args:
        n_inputs: number of input features per time step (D).
        n_hidden: number of hidden units in the recurrent layer (M).
        n_outputs: number of output units of the linear layer (K).
    """

    def __init__(self, n_inputs, n_hidden, n_outputs):
        super().__init__()
        self.D = n_inputs
        self.M = n_hidden
        self.K = n_outputs
        # batch_first=True: inputs and outputs are laid out as (N, T, features)
        self.rnn = nn.RNN(
            input_size=self.D,
            hidden_size=self.M,
            nonlinearity='tanh',
            batch_first=True)
        self.fc = nn.Linear(self.M, self.K)

    def forward(self, X):
        """Run the RNN over a batch of sequences.

        Args:
            X: tensor of shape (N, T, D).

        Returns:
            Tensor of shape (N, T, K): the linear layer applied to the hidden
            state at every time step.
        """
        # Initial hidden state: zeros of shape (num_layers=1, N, M). It is created
        # with the input's dtype and device so the model keeps working after
        # .to(device) or .double(); a plain torch.zeros(...) is always CPU float32.
        h0 = torch.zeros(1, X.size(0), self.M, dtype=X.dtype, device=X.device)

        # out holds the hidden state at every time step, shape (N, T, M)
        out, _ = self.rnn(X, h0)

        # Map every hidden state to an output. To keep only the final time step,
        # apply the layer to out[:, -1, :] instead.
        return self.fc(out)

In [6]:
# Instantiate the model
# Seed PyTorch first: the layer weights are initialized randomly, so without a
# seed every kernel restart produces different parameters and outputs.
# NOTE: outputs recorded before this seed was added will not match a fresh run.
torch.manual_seed(0)
model = SimpleRNN(n_inputs=D, n_hidden=M, n_outputs=K)

In [7]:
# Run the forward pass: cast the float64 NumPy data to float32 (the dtype of
# the model weights), wrap it as a tensor, and feed it through the model
X_f32 = X.astype(np.float32)
inputs = torch.from_numpy(X_f32)
out = model(inputs)
out

tensor([[[-0.2980, -0.1512],
         [-0.6675, -0.2178],
         [-0.1684,  0.2758],
         [-0.4829, -0.1307],
         [-0.1445, -0.2362],
         [-0.7084, -0.3117],
         [-0.2216, -0.0057],
         [-0.1214, -0.3488],
         [-0.4838, -0.1101],
         [-0.4085, -0.4424]]], grad_fn=<ViewBackward0>)

In [8]:
# one K-dimensional output per time step: (N, T, K)
out.size()

torch.Size([1, 10, 2])

In [9]:

# Keep the PyTorch result as a NumPy array for the comparison at the end.
# The tensor tracks gradients, so it is detached before converting.
out_detached = out.detach()
Yhats_torch = out_detached.numpy()

In [10]:
# Recurrent-layer parameters, fetched by attribute name (layer 0):
# input-to-hidden / hidden-to-hidden weights and their biases
W_xh, W_hh = model.rnn.weight_ih_l0, model.rnn.weight_hh_l0
b_xh, b_hh = model.rnn.bias_ih_l0, model.rnn.bias_hh_l0

In [11]:
# expect (M, D): one row of input weights per hidden unit
W_xh.shape

torch.Size([5, 3])

In [12]:
# the raw input-to-hidden weights, still a torch Parameter at this point
W_xh

Parameter containing:
tensor([[ 0.1229, -0.1610, -0.3194],
        [-0.4353, -0.4396, -0.2349],
        [-0.3445, -0.0574, -0.3666],
        [ 0.1775,  0.2865, -0.2325],
        [-0.3590,  0.2284,  0.0559]], requires_grad=True)

In [13]:
# Convert the input-to-hidden weights to a NumPy array.
# Reading from the module attribute (instead of re-wrapping the W_xh name) keeps
# this cell safe to re-run: the old form failed on a second run because W_xh
# was already an ndarray. detach() replaces the discouraged .data access.
W_xh = model.rnn.weight_ih_l0.detach().numpy()
W_xh

array([[ 0.12291408, -0.1610299 , -0.3193888 ],
       [-0.43527004, -0.43961504, -0.23493043],
       [-0.34448075, -0.0573996 , -0.36662662],
       [ 0.17749107,  0.28654283, -0.23246112],
       [-0.358966  ,  0.2283746 ,  0.0558511 ]], dtype=float32)

In [14]:
# Convert the remaining recurrent-layer parameters to NumPy arrays.
# They are read from the module attributes rather than from the names being
# rebound, so re-running this cell does not fail on already-converted arrays.
b_xh = model.rnn.bias_ih_l0.detach().numpy()
W_hh = model.rnn.weight_hh_l0.detach().numpy()
b_hh = model.rnn.bias_hh_l0.detach().numpy()

In [15]:
# Shape check, in order: W_xh (M, D), b_xh (M,), W_hh (M, M), b_hh (M,)
tuple(param.shape for param in (W_xh, b_xh, W_hh, b_hh))

((5, 3), (5,), (5, 5), (5,))

In [16]:
# Fully connected (output) layer: weight matrix and bias, fetched by name
Wo, bo = model.fc.weight, model.fc.bias

In [17]:
# Convert the output-layer parameters to NumPy arrays.
# They are read from the module attributes so the cell can be re-run safely
# (re-wrapping the Wo/bo names failed once they were already ndarrays), and
# detach() replaces the discouraged .data access.
Wo = model.fc.weight.detach().numpy()
bo = model.fc.bias.detach().numpy()
Wo.shape, bo.shape

((2, 5), (2,))

In [18]:
# Reproduce the PyTorch forward pass by hand with NumPy
x = X[0]                  # the one and only sample: one row of D features per time step
h_last = np.zeros(M)      # hidden state before the first step (all zeros, like h0)
Yhats = np.zeros((T, K))  # one K-dimensional output per time step

for t in range(T):
    # recurrent update: h(t) = tanh(W_xh x(t) + b_xh + W_hh h(t-1) + b_hh)
    pre_activation = x[t].dot(W_xh.T) + b_xh + h_last.dot(W_hh.T) + b_hh
    h = np.tanh(pre_activation)

    # linear output layer, applied at every step (the model returns all T outputs)
    y = h.dot(Wo.T) + bo
    Yhats[t] = y

    # carry the new hidden state into the next step
    h_last = h

# show the outputs for all T time steps
print(Yhats)

[[-0.2980089  -0.15115902]
 [-0.66750363 -0.21781065]
 [-0.16840073  0.27575078]
 [-0.48289198 -0.13070377]
 [-0.14450986 -0.23621114]
 [-0.70837086 -0.31172306]
 [-0.22159285 -0.00574892]
 [-0.12139882 -0.34876806]
 [-0.48375626 -0.11013   ]
 [-0.40848929 -0.44240513]]


In [19]:
# Check: every hand-computed output must agree with the PyTorch result
# (element-wise comparison with NumPy's default tolerances)
bool(np.isclose(Yhats, Yhats_torch).all())

True