RNN from scratch

In [2]:
# Let's use PyTorch to build a simple RNN

In [1]:
%matplotlib inline
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

In [None]:
class RNNScratch(nn.Module):
    """Single-layer tanh RNN with a linear read-out, built from raw parameters.

    At each time step the hidden state is updated as
    ``H = tanh(X @ W_xh + H @ W_hh + b_h)`` and an output is produced as
    ``Y = H @ W_hq + b_q``.

    Args:
        num_inputs: Size of the last dimension of each time-step input.
        num_hiddens: Size of the hidden state.
        num_outputs: Size of the last dimension of each time-step output.
        sigma: Scale applied to the standard-normal weight initialisation.
    """

    def __init__(self, num_inputs, num_hiddens, num_outputs, sigma=0.01):
        super().__init__()
        self.num_hiddens = num_hiddens
        # Recurrent cell parameters (input->hidden, hidden->hidden, bias).
        self.W_xh = nn.Parameter(torch.randn(num_inputs, num_hiddens) * sigma)
        self.W_hh = nn.Parameter(torch.randn(num_hiddens, num_hiddens) * sigma)
        self.b_h = nn.Parameter(torch.zeros(num_hiddens))
        # Output layer parameters (hidden->output, bias).
        self.W_hq = nn.Parameter(torch.randn(num_hiddens, num_outputs) * sigma)
        self.b_q = nn.Parameter(torch.zeros(num_outputs))

    def forward(self, inputs, state=None):
        """Run the RNN over a sequence.

        Args:
            inputs: Iterable over time steps; each step ``X`` must be
                multipliable with ``W_xh``, e.g. a tensor of shape
                ``(num_steps, batch_size, num_inputs)``.
            state: One-element tuple ``(H,)`` holding the initial hidden
                state, as returned by :meth:`begin_state`. If ``None``, a
                zero state matching the first time step's batch size and
                device is created.

        Returns:
            A pair ``(outputs, (H,))`` where ``outputs`` stacks the per-step
            outputs along dimension 0 and ``H`` is the final hidden state.
        """
        if state is None:
            H = None
        else:
            H, = state  # Unpack the one-element state tuple
        outputs = []
        for X in inputs:  # One iteration per time step
            if H is None:
                # Lazily build the zero state once the batch size is known;
                # use the parameters' dtype so matmul operands agree.
                H = torch.zeros(
                    (X.shape[0], self.num_hiddens),
                    dtype=self.W_hh.dtype,
                    device=X.device,
                )
            H = torch.tanh(
                torch.matmul(X, self.W_xh)
                + torch.matmul(H, self.W_hh)
                + self.b_h
            )
            Y = torch.matmul(H, self.W_hq) + self.b_q
            outputs.append(Y)
        # Stack per-step outputs along a new leading (sequence) dimension.
        return torch.stack(outputs, dim=0), (H,)

    def begin_state(self, batch_size, device=None):
        """Return a zero initial hidden state as a one-element tuple.

        Args:
            batch_size: Number of sequences processed in parallel.
            device: Device on which to allocate the state; ``None`` uses
                PyTorch's current default device.

        Returns:
            ``(H,)`` with ``H`` of shape ``(batch_size, num_hiddens)`` and
            the same dtype as the model's weights.
        """
        return (
            torch.zeros(
                (batch_size, self.num_hiddens),
                dtype=self.W_hh.dtype,
                device=device,
            ),
        )


In [4]:
# Toy problem dimensions for a quick smoke test of the model
batch_size = 2
num_inputs = 16
num_hiddens = 32
num_steps = 100

# One scalar output per time step
model = RNNScratch(num_inputs, num_hiddens, num_outputs=1)

In [5]:
# Create the all-zero starting hidden state on the CPU, then display its shape
state = model.begin_state(
    batch_size,
    device=torch.device('cpu'),
)
state[0].shape

torch.Size([2, 32])

In [6]:
# Run the model over a random sequence laid out as (steps, batch, features)
X = torch.rand((num_steps, batch_size, num_inputs))
Y, new_state = model(X, state)