RNN from scratch

In [2]:
# Let's use PyTorch to build a simple RNN

In [1]:
%matplotlib inline
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

In [None]:
class RNNScratch(nn.Module):
    """Single-layer tanh RNN with a linear readout, built from raw parameters.

    At every time step the hidden state is updated as
    ``H = tanh(X @ W_xh + H @ W_hh + b_h)`` and an output is emitted as
    ``Y = H @ W_hq + b_q``.

    Args:
        num_inputs: Size of the last dimension of each per-step input.
        num_hiddens: Width of the hidden state.
        num_outputs: Size of the last dimension of each per-step output.
        sigma: Scale applied to the standard-normal weight initialisation.
    """

    def __init__(self, num_inputs, num_hiddens, num_outputs, sigma=0.01):
        super().__init__()
        self.num_hiddens = num_hiddens
        # Keep the order of the randn calls fixed so that a given seed always
        # yields the same initial weights.
        self.W_xh = nn.Parameter(torch.randn(num_inputs, num_hiddens) * sigma)
        self.W_hh = nn.Parameter(torch.randn(num_hiddens, num_hiddens) * sigma)
        self.b_h = nn.Parameter(torch.zeros(num_hiddens))
        self.W_hq = nn.Parameter(torch.randn(num_hiddens, num_outputs) * sigma)
        self.b_q = nn.Parameter(torch.zeros(num_outputs))

    def forward(self, inputs, state=None):
        """Run the recurrence over a sequence.

        Args:
            inputs: Iterable of per-step inputs whose last dimension is
                ``num_inputs`` (e.g. a tensor iterated along its first axis).
            state: One-element tuple ``(H,)`` holding the initial hidden
                state. When ``None``, a zero state matching the leading
                dimensions of the first step is created.

        Returns:
            A pair ``(outputs, (H,))``: the per-step outputs stacked along a
            new leading dimension, and the final hidden state as a tuple.
        """
        if state is None:
            H = None  # Created lazily once the first step's shape is known
        else:
            H, = state  # Unpack the state
        outputs = []
        for X in inputs:  # Iterate through the sequence
            if H is None:
                # Match the parameters so the matmul below never sees a
                # dtype/device mismatch between state and weights.
                H = torch.zeros((*X.shape[:-1], self.num_hiddens),
                                device=self.W_hh.device, dtype=self.W_hh.dtype)
            H = torch.tanh(torch.matmul(X, self.W_xh) + torch.matmul(H, self.W_hh) + self.b_h)
            Y = torch.matmul(H, self.W_hq) + self.b_q
            outputs.append(Y)
        return torch.stack(outputs, dim=0), (H,)  # Stack outputs along the sequence dimension

    def begin_state(self, batch_size, device=None):
        """Return a zero initial hidden state as a one-element tuple.

        Args:
            batch_size: Number of rows in the state.
            device: Device for the state; defaults to the device that holds
                the model parameters.

        Returns:
            ``(H,)`` where ``H`` is a zero tensor of shape
            ``(batch_size, num_hiddens)`` in the parameters' dtype.
        """
        if device is None:
            device = self.W_hh.device
        return (torch.zeros((batch_size, self.num_hiddens),
                            device=device, dtype=self.W_hh.dtype),)


In [4]:
# Toy dimensions for a quick smoke test of the scratch RNN.
batch_size = 2
num_inputs = 16
num_hiddens = 32
num_steps = 100

# A single output value is produced for every (step, example) pair.
model = RNNScratch(num_inputs, num_hiddens, num_outputs=1)

In [5]:
# Create the all-zero starting hidden state on the CPU, then show its shape.
state = model.begin_state(batch_size=batch_size, device=torch.device('cpu'))
state[0].size()

torch.Size([2, 32])

In [6]:
# Run the model once on random data laid out as (num_steps, batch_size, num_inputs);
# iterating over X therefore feeds one (batch_size, num_inputs) slice per time step.
X = torch.rand((num_steps, batch_size, num_inputs))
(Y, new_state) = model(X, state)

In [7]:
# Display the per-step outputs stacked along the leading (time-step) dimension.
Y

tensor([[[-0.0023],
         [-0.0031]],

        [[-0.0027],
         [-0.0018]],

        [[-0.0020],
         [-0.0013]],

        [[-0.0026],
         [-0.0017]],

        [[-0.0011],
         [-0.0016]],

        [[-0.0021],
         [-0.0024]],

        [[-0.0019],
         [-0.0025]],

        [[-0.0023],
         [-0.0021]],

        [[-0.0027],
         [-0.0021]],

        [[-0.0024],
         [-0.0024]],

        [[-0.0023],
         [-0.0028]],

        [[-0.0025],
         [-0.0021]],

        [[-0.0014],
         [-0.0024]],

        [[-0.0006],
         [-0.0024]],

        [[-0.0018],
         [-0.0025]],

        [[-0.0023],
         [-0.0029]],

        [[-0.0029],
         [-0.0020]],

        [[-0.0009],
         [-0.0017]],

        [[-0.0020],
         [-0.0030]],

        [[-0.0022],
         [-0.0011]],

        [[-0.0022],
         [-0.0013]],

        [[-0.0017],
         [-0.0020]],

        [[-0.0034],
         [-0.0009]],

        [[-0.0017],
         [-0.0

In [8]:
# Display the final hidden state, returned as a one-element tuple (H,).
new_state

(tensor([[-0.0277,  0.0053,  0.0239, -0.0068, -0.0131,  0.0260,  0.0087, -0.0304,
          -0.0156, -0.0503,  0.0577, -0.0184,  0.0450,  0.0485,  0.0219,  0.0194,
           0.0553, -0.0187,  0.0153,  0.0032,  0.0294, -0.0015,  0.0202, -0.0024,
          -0.0071, -0.0502, -0.0125, -0.0106,  0.0036,  0.0098,  0.0046, -0.0021],
         [-0.0057, -0.0057,  0.0017, -0.0181, -0.0404,  0.0132, -0.0112, -0.0068,
          -0.0362, -0.0389,  0.0231, -0.0090,  0.0244,  0.0335,  0.0281,  0.0030,
           0.0856, -0.0248,  0.0265,  0.0112,  0.0262, -0.0011,  0.0335, -0.0280,
           0.0069, -0.0387,  0.0016,  0.0017, -0.0239,  0.0182,  0.0084, -0.0146]],
        grad_fn=<TanhBackward0>),)