**Udemy 6-8. Paying Attention to Shapes - Appendix A**

Exercise: Calculate the output for multiple samples at once (N > 1)

In [27]:
import torch
import torch.nn as nn
import numpy as np

In [28]:
# Things you should automatically know and have memorized
# N: number of samples
# T: length of the sequence
# D: number of input features
# M: number of hidden units
# K: number of output units

In [29]:
# Build a random batch of sequences.
# N samples, T time steps, D input features, M hidden units, K output units.
N, T, D, M, K = 4, 10, 3, 5, 2

# Shape is (N, T, D) because the RNN below uses batch_first=True;
# with batch_first=False the layout would be (T, N, D) instead.
X = np.random.randn(N, T, D)

In [30]:
# PART 1 - RNN using Torch
# Make an RNN
class SimpleRNN(nn.Module):
    """Single-layer tanh RNN followed by a linear layer applied at every time step.

    Args:
        n_inputs: number of input features per time step (D).
        n_hidden: number of hidden units in the RNN (M).
        n_output: number of output units of the linear layer (K).
    """

    def __init__(self, n_inputs, n_hidden, n_output):
        super().__init__()
        self.D = n_inputs
        self.M = n_hidden
        self.K = n_output
        self.rnn = nn.RNN(input_size=self.D,
                          hidden_size=self.M,
                          nonlinearity='tanh',
                          batch_first=True,)
        self.fc = nn.Linear(self.M, self.K)

    def forward(self, X):
        """Map a batch of sequences to one output vector per time step.

        Args:
            X: tensor of shape N x T x D (batch first).

        Returns:
            Tensor of shape N x T x K.
        """
        # Initial hidden state: zeros of shape L x N x M (L = 1 layer).
        # Allocate it from X so it inherits X's dtype and device; a plain
        # torch.zeros(...) is always CPU/float32 and makes nn.RNN fail when
        # the model runs on another device or in another precision.
        h0 = X.new_zeros((1, X.size(0), self.M))

        # get RNN unit output
        out, _ = self.rnn(X, h0) # X: N x T x D | out: N x T x M

        # pass every hidden state (h_1 ... h_T) through the dense layer
        out = self.fc(out) # in: N x T x M | out: N x T x K
        return out

In [31]:
# Build the network with the sizes chosen in the data cell
model = SimpleRNN(
    n_inputs=D,
    n_hidden=M,
    n_output=K,
)

In [32]:
# Forward pass over the whole batch at once.
# X is float64 (NumPy default); the model's weights are float32, so convert first.
inputs = torch.from_numpy(X).float()
outputs = model(inputs)
print(outputs)
print(outputs.shape) # N x T x K

tensor([[[-3.8805e-01,  6.6007e-02],
         [ 9.2461e-02,  2.1089e-01],
         [ 2.5599e-01,  1.5064e-01],
         [ 2.6491e-03, -2.2587e-01],
         [ 5.0860e-01,  3.5218e-01],
         [ 1.5987e-01, -3.3655e-02],
         [ 3.3207e-01,  1.9994e-01],
         [-2.3496e-01, -1.1881e-01],
         [ 1.6551e-01,  5.1548e-02],
         [-1.6664e-01, -3.1746e-01]],

        [[-4.9303e-01, -8.4668e-05],
         [ 5.9000e-01,  3.2127e-01],
         [ 9.5413e-02,  1.9815e-01],
         [ 1.0299e-01,  3.9019e-01],
         [ 5.2420e-01,  4.9695e-01],
         [-6.8659e-03, -7.8936e-02],
         [ 5.9741e-01,  2.3778e-01],
         [ 2.6502e-01,  2.7343e-02],
         [ 8.0257e-02,  1.0623e-01],
         [ 4.5976e-02,  1.5480e-01]],

        [[-1.4548e-01, -7.1074e-02],
         [ 1.7801e-01,  5.0315e-01],
         [-4.1280e-01, -1.5312e-01],
         [ 4.6199e-01,  1.7224e-01],
         [ 3.4270e-01,  1.1509e-01],
         [ 1.0365e-02, -1.2205e-01],
         [-2.5549e-01,  5.6542e-02

In [33]:
# Save for later: keep the PyTorch result as a NumPy array so it can be
# compared against the hand-written NumPy implementation.
# detach() is needed because .numpy() refuses tensors that require grad.
Yhat_torch = outputs.detach().numpy()

In [34]:
# get the RNN layer parameters
# Look them up by attribute name instead of unpacking model.rnn.parameters(),
# so the code does not silently depend on the iteration order of parameters().
W_xh = model.rnn.weight_ih_l0  # input-to-hidden weights
W_hh = model.rnn.weight_hh_l0  # hidden-to-hidden weights
b_xh = model.rnn.bias_ih_l0    # input-to-hidden bias
b_hh = model.rnn.bias_hh_l0    # hidden-to-hidden bias
print("W_xh.shape", W_xh.shape)
print(W_xh)

# Convert to NumPy; detach() is the supported replacement for the legacy .data
W_xh = W_xh.detach().numpy()
W_hh = W_hh.detach().numpy()
b_xh = b_xh.detach().numpy()
b_hh = b_hh.detach().numpy()

print(W_xh.shape, b_xh.shape, W_hh.shape, b_hh.shape) # (M, D), (M,), (M, M), (M,)

W_xh.shape torch.Size([5, 3])
Parameter containing:
tensor([[ 0.1665, -0.0060,  0.3870],
        [-0.0372, -0.1850, -0.1352],
        [-0.0465,  0.4222,  0.0643],
        [-0.3622, -0.0760, -0.0403],
        [-0.2343,  0.2927,  0.3627]], requires_grad=True)
(5, 3) (5,) (5, 5) (5,)


In [35]:
# get the FC layer parameters
# Named attributes make the weight/bias roles explicit instead of relying on
# the order in which model.fc.parameters() yields them.
W_o = model.fc.weight
b_o = model.fc.bias

# Convert to NumPy; detach() is the supported replacement for the legacy .data
W_o = W_o.detach().numpy()
b_o = b_o.detach().numpy()

print(W_o.shape, b_o.shape) # (K, M), (K,)

(2, 5) (2,)


In [36]:
# PART 2 - RNN using numpy
# Same recurrence as nn.RNN, applied to all N samples in parallel:
#   h_t = tanh(x_t W_xh^T + b_xh + h_{t-1} W_hh^T + b_hh)
#   y_t = h_t W_o^T + b_o
# See if we can replicate the output
h_last = np.zeros((N, M))      # initial hidden state, one row per sample
yhats = np.zeros((N, T, K))    # outputs for every sample and time step

for t in range(T):
    # X[:, t] is the N x D slice of inputs at time step t
    h = np.tanh(np.dot(X[:, t], W_xh.T) + b_xh + np.dot(h_last, W_hh.T) + b_hh)
    y = np.dot(h, W_o.T) + b_o  # N x K output for this time step
    yhats[:, t] = y

    h_last = h  # carry the hidden state into the next time step

print(yhats)

[[[-3.88054048e-01  6.60075294e-02]
  [ 9.24613927e-02  2.10890429e-01]
  [ 2.55985599e-01  1.50635328e-01]
  [ 2.64908628e-03 -2.25871625e-01]
  [ 5.08596132e-01  3.52184731e-01]
  [ 1.59868258e-01 -3.36545826e-02]
  [ 3.32070158e-01  1.99937682e-01]
  [-2.34963059e-01 -1.18812277e-01]
  [ 1.65510139e-01  5.15483850e-02]
  [-1.66637558e-01 -3.17462970e-01]]

 [[-4.93027234e-01 -8.46709248e-05]
  [ 5.90001432e-01  3.21271478e-01]
  [ 9.54127200e-02  1.98152895e-01]
  [ 1.02994129e-01  3.90190412e-01]
  [ 5.24201017e-01  4.96949630e-01]
  [-6.86588833e-03 -7.89361623e-02]
  [ 5.97408621e-01  2.37777439e-01]
  [ 2.65015740e-01  2.73425297e-02]
  [ 8.02568865e-02  1.06232770e-01]
  [ 4.59757744e-02  1.54803534e-01]]

 [[-1.45481832e-01 -7.10739956e-02]
  [ 1.78006871e-01  5.03152182e-01]
  [-4.12801793e-01 -1.53116622e-01]
  [ 4.61987271e-01  1.72240093e-01]
  [ 3.42704716e-01  1.15092479e-01]
  [ 1.03653004e-02 -1.22052131e-01]
  [-2.55490514e-01  5.65424993e-02]
  [ 1.50902258e-01  3.71

In [37]:
# Check
# np.allclose compares element-wise within its default tolerances, which
# absorbs the small differences between the two implementations' arithmetic.
print(np.allclose(yhats, Yhat_torch)) # True => nn.RNN and our hand-written formula agree

True
