<a href="https://colab.research.google.com/github/mahesh-keswani/pytorch-example-notebook/blob/main/13_Pytorch_Understanding_shapes_in_RNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
N = 1 # only single sample
T = 10 # sequence length
input_dimension = 1
hidden_dimension = 5
output_dimension = 2

X = np.random.randn(N, T, input_dimension)

In [3]:
class SimpleRNN(nn.Module):
    def __init__(self, input_dimension, hidden_dimension, output_dimension):
        super(SimpleRNN, self).__init__()
        self.rnn = nn.RNN(input_dimension, hidden_dimension, batch_first=True)
        self.fc = nn.Linear(hidden_dimension, output_dimension)
        # here we  are having only  single layer in rnn by default

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), hidden_dimension)
        out, _ = self.rnn(x, h0)
        print(f"Output shape  from hidden unit {out.shape}")

        # passing the entire  out to fc, i.e we will get now output for every timestep
        out = self.fc(out)
        print(f"Output shape  from linear unit {out.shape}")
        return out

In [4]:
model = SimpleRNN(input_dimension, hidden_dimension, output_dimension)

In [5]:
inputs = torch.from_numpy(X.astype(np.float32))
outputs = model(inputs)
outputs

Output shape  from hidden unit torch.Size([1, 10, 5])
Output shape  from linear unit torch.Size([1, 10, 2])


tensor([[[-0.0581, -0.2801],
         [-0.1343, -0.3343],
         [-0.2036, -0.3577],
         [-0.1521, -0.3276],
         [-0.1256, -0.3063],
         [-0.2171, -0.3509],
         [-0.2122, -0.3553],
         [-0.1598, -0.3237],
         [-0.1526, -0.3176],
         [-0.2319, -0.3584]]], grad_fn=<ViewBackward0>)

In [6]:
# lets save outputs for later
Yhats = outputs.detach().numpy()

In [7]:
Wxh, Whh, bxh, bhh = model.rnn.parameters()

In [8]:
print(X.shape)
print(Wxh.shape)
print(bxh.shape)
print(Whh.shape)
print(bhh.shape)

(1, 10, 1)
torch.Size([5, 1])
torch.Size([5])
torch.Size([5, 5])
torch.Size([5])


In [9]:
Wxh = Wxh.data.numpy()
Whh = Whh.data.numpy()
bxh = bxh.data.numpy()
bhh = bhh.data.numpy()

In [10]:
#lets also grab parameters of the output layer
W, b = model.fc.parameters()
W, b = W.data.numpy(), b.data.numpy()
W.shape, b.shape

((2, 5), (2,))

In [11]:
X.shape

(1, 10, 1)

In [12]:
# lets see if we can replicate the output using manual calculations
h_last = np.zeros(hidden_dimension)
X = X[0] # we  will work with single sample only,therefore  X is (10,1) now
Yhats_manual = np.zeros((T, output_dimension))

for t in range(T):
    # X[t] is (1, 1). Wxh is (5,1), therefore X[t].Wxh is (5,1)+bxh = (5,1)
    # h_last is (5,1), Whh is (5,5),therefore Whh.h_last is (5,1)
    h_next = np.tanh(X[t].dot(Wxh.T) + h_last.dot(Whh.T) + bxh)

    Yhats_manual[t] = np.dot(h_next, W.T) + b
    h_last = h_next

In [13]:
Yhats

array([[[-0.05807926, -0.2801122 ],
        [-0.1342808 , -0.33428547],
        [-0.203589  , -0.35774437],
        [-0.1521302 , -0.32757244],
        [-0.12563804, -0.3062835 ],
        [-0.21710108, -0.35088778],
        [-0.2122352 , -0.35530427],
        [-0.1597626 , -0.32368603],
        [-0.15259574, -0.31762537],
        [-0.23187415, -0.35836193]]], dtype=float32)

In [14]:
Yhats_manual

array([[ 0.05203359, -0.38844685],
       [ 0.06049624, -0.41301386],
       [ 0.00201575, -0.45323665],
       [ 0.06425441, -0.42201359],
       [ 0.09402765, -0.4060147 ],
       [-0.00361719, -0.45334501],
       [-0.00809015, -0.46341526],
       [ 0.04823504, -0.42915469],
       [ 0.05922465, -0.42226637],
       [-0.02817548, -0.46625672]])

In [15]:
# check if rnn output and manual calculation values are all close
np.allclose(Yhats, Yhats_manual)

False