In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# N = number of samples
# T = number of time steps
# D = number of input features
# M = number of hidden units
# K = number of output units

In [3]:
# Make some data: one random input sequence of shape (N, T, D)
N, T, D = 1, 10, 3  # samples, time steps, input features
M, K = 5, 2         # hidden units, output units
X = np.random.randn(N, T, D)

In [4]:
# Make an RNN
class SimpleRNN(nn.Module):
    """Single-layer tanh RNN followed by a linear read-out layer.

    The linear layer is applied to the hidden state of every time step, so an
    input of shape (N, T, D) produces an output of shape (N, T, K).

    Args:
        n_inputs: number of input features per time step (D).
        n_hidden: number of hidden units (M).
        n_outputs: number of output units per time step (K).
    """

    def __init__(self, n_inputs, n_hidden, n_outputs):
        super(SimpleRNN, self).__init__()
        self.D = n_inputs
        self.M = n_hidden
        self.K = n_outputs
        self.rnn = nn.RNN(
            input_size=self.D,
            hidden_size=self.M,
            nonlinearity='tanh',
            batch_first=True
        )
        self.fc = nn.Linear(self.M, self.K)

    def forward(self, X):
        """Run the RNN over the whole sequence and project every hidden state.

        Args:
            X: input tensor of shape (N, T, D) (batch first).

        Returns:
            Tensor of shape (N, T, K) with one output vector per time step.
        """
        # Initial hidden state of shape (1, N, M), all zeros. It is created
        # from X so it shares the input's device and dtype; a plain
        # torch.zeros(...) would always be a default-dtype CPU tensor and
        # break the forward pass after model.to(device) or model.double().
        h0 = X.new_zeros((1, X.size(0), self.M))

        # get RNN unit output: hidden states for all T steps, shape (N, T, M)
        out, _ = self.rnn(X, h0)

        # Apply the read-out to every time step.
        # To keep only h(T) at the final step, use self.fc(out[:, -1, :]).
        return self.fc(out)

In [5]:
# Build the network from the dimensions chosen above
model = SimpleRNN(
    n_inputs=D,
    n_hidden=M,
    n_outputs=K,
)

In [6]:
# Get the output: convert the float64 NumPy data to a float32 tensor
# (the dtype of the model's weights) and run the forward pass
inputs = torch.as_tensor(X, dtype=torch.float32)
out = model(inputs)
out

tensor([[[ 0.0594,  0.6624],
         [-0.2094,  0.2945],
         [-0.0305,  0.2726],
         [ 0.2447, -0.1394],
         [ 0.3073,  0.0435],
         [ 0.0526,  0.5712],
         [-0.0967,  0.2954],
         [ 0.1235,  0.3925],
         [ 0.2174, -0.0390],
         [ 0.1057,  0.0733]]], grad_fn=<ViewBackward0>)

In [7]:
# One K-dimensional output per time step: expect (N, T, K)
out.shape

torch.Size([1, 10, 2])

In [8]:
# Save for later: detach from the autograd graph and convert to NumPy so the
# result can be compared against the manual computation below
Yhats_torch = out.detach().numpy()

In [9]:
# Pull out the recurrent layer's parameters by name
W_xh = model.rnn.weight_ih_l0  # input-to-hidden weights
W_hh = model.rnn.weight_hh_l0  # hidden-to-hidden weights
b_xh = model.rnn.bias_ih_l0    # input-to-hidden bias
b_hh = model.rnn.bias_hh_l0    # hidden-to-hidden bias

In [10]:
# Input-to-hidden weights are stored as (M, D), i.e. (hidden units, input features)
W_xh.shape

torch.Size([5, 3])

In [11]:
# Still a torch Parameter at this point (note requires_grad=True in the repr)
W_xh

Parameter containing:
tensor([[-0.3397,  0.3435,  0.1187],
        [ 0.1922, -0.0545, -0.3062],
        [-0.3581,  0.2287,  0.4450],
        [ 0.4157,  0.0435, -0.3450],
        [ 0.2186, -0.0077,  0.0156]], requires_grad=True)

In [12]:
# Convert the input-to-hidden weights to a NumPy array.
# Read from the model rather than from W_xh itself so this cell can be re-run
# safely: once W_xh is an ndarray, W_xh.data.numpy() would raise AttributeError.
W_xh = model.rnn.weight_ih_l0.detach().numpy()
W_xh

array([[-0.33971268,  0.34349126,  0.1186657 ],
       [ 0.19223738, -0.05445626, -0.30618227],
       [-0.35814863,  0.22867501,  0.44498515],
       [ 0.41572154,  0.04352179, -0.3450039 ],
       [ 0.21860707, -0.00772166,  0.01560703]], dtype=float32)

In [13]:
# Convert the remaining RNN parameters to NumPy arrays.
# Each value is read from the model (not from the variable being rebound),
# so re-running this cell does not fail on an already-converted ndarray.
b_xh = model.rnn.bias_ih_l0.detach().numpy()
W_hh = model.rnn.weight_hh_l0.detach().numpy()
b_hh = model.rnn.bias_hh_l0.detach().numpy()

In [14]:
# Did we do it right? Expect (M, D), (M, M), (M,), (M,)
W_xh.shape, W_hh.shape, b_xh.shape, b_hh.shape

((5, 3), (5, 5), (5,), (5,))

In [15]:
# Now get the FC (read-out) layer's weight matrix and bias by name
Wo = model.fc.weight
bo = model.fc.bias

In [16]:
# Convert the FC layer parameters to NumPy arrays.
# Read from the model rather than from Wo/bo themselves so the cell is
# idempotent: re-running it does not call .data.numpy() on an ndarray.
Wo = model.fc.weight.detach().numpy()
bo = model.fc.bias.detach().numpy()
Wo.shape, bo.shape

((2, 5), (2,))

In [17]:
# See if we can replicate the output by running the forward pass by hand
x = X[0]                  # only one sample
h_last = np.zeros(M)      # initial hidden state
Yhats = np.zeros((T, K))  # one output row per time step

for t in range(T):
    # h(t) = tanh(x(t) W_xh^T + b_xh + h(t-1) W_hh^T + b_hh)
    input_term = np.dot(x[t], W_xh.T) + b_xh
    recurrent_term = np.dot(h_last, W_hh.T)
    h = np.tanh(input_term + recurrent_term + b_hh)

    # read-out for this time step
    y = np.dot(h, Wo.T) + bo
    Yhats[t] = y

    # important: carry the hidden state forward to the next step
    h_last = h

# print the outputs for all time steps
print(Yhats)

[[ 0.05942239  0.66240166]
 [-0.20939252  0.29454571]
 [-0.03054577  0.27260296]
 [ 0.24468671 -0.13937344]
 [ 0.3072928   0.04351277]
 [ 0.05255175  0.57117018]
 [-0.09668414  0.29539873]
 [ 0.12345423  0.39252538]
 [ 0.21744961 -0.03899946]
 [ 0.10572217  0.07325389]]


In [18]:
# Check that the manual NumPy outputs match PyTorch's.
# Yhats is (T, K) while Yhats_torch is (N, T, K) with N = 1, so NumPy
# broadcasting lines the two up element-wise.
np.allclose(Yhats, Yhats_torch)

True