In [1]:
%matplotlib inline
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

In [2]:
class RNNScratch(d2l.Module):
    """Single-layer tanh RNN whose parameters are created by hand.

    Args:
        num_inputs: number of features in each input vector.
        num_hiddens: number of hidden units.
        sigma: scale applied to standard-normal samples when initialising
            the two weight matrices (the bias starts at zero).
    """

    def __init__(self, num_inputs, num_hiddens, sigma=0.01):
        super().__init__()
        # Expose the constructor arguments as attributes (self.num_inputs,
        # self.num_hiddens, self.sigma) -- other code reads them back.
        # NOTE(review): d2l's helper appears to inspect the caller's local
        # variables, so keep this call before any other local assignment.
        self.save_hyperparameters()
        # Input-to-hidden weights: (num_inputs, num_hiddens).
        self.W_xh = nn.Parameter(sigma * torch.randn(num_inputs, num_hiddens))
        # Hidden-to-hidden (recurrent) weights: (num_hiddens, num_hiddens).
        self.W_hh = nn.Parameter(sigma * torch.randn(num_hiddens, num_hiddens))
        # Hidden bias, zero-initialised.
        self.b_h = nn.Parameter(torch.zeros(num_hiddens))

In [5]:
@d2l.add_to_class(RNNScratch)
def forward(self, inputs, state=None):
    """Run the RNN over every time step of ``inputs``.

    Args:
        inputs: tensor laid out as (num_steps, batch_size, num_inputs); the
            loop walks the first axis one time step at a time.
        state: optional initial hidden state of shape
            (batch_size, num_hiddens), given either as a bare tensor or as a
            1-tuple/list wrapping one. Defaults to zeros.

    Returns:
        (outputs, state): ``outputs`` is a Python list holding the
        (batch_size, num_hiddens) hidden state of every time step, and
        ``state`` is the final hidden state as a bare tensor, so it can be
        passed straight back in as ``state`` on the next call.
    """
    if state is None:
        # Zero initial state; inputs.shape[1] is the batch dimension.
        state = torch.zeros((inputs.shape[1], self.num_hiddens),
                            device=inputs.device)
    elif isinstance(state, (tuple, list)):
        # 1-tuple form; a wrapper of any other length still raises ValueError.
        state, = state
    # A bare tensor is used as-is: unpacking it would iterate over the batch
    # axis, failing for batch_size > 1 and dropping that axis when it is 1.
    outputs = []
    for X in inputs:
        # H_t = tanh(X_t W_xh + H_{t-1} W_hh + b_h)
        state = torch.tanh(torch.matmul(X, self.W_xh) +
                           torch.matmul(state, self.W_hh) +
                           self.b_h)
        outputs.append(state)
    return outputs, state

In [6]:
# Toy configuration: 2 sequences of 100 time steps each,
# 16 input features per step and 32 hidden units.
batch_size = 2
num_inputs = 16
num_hiddens = 32
num_steps = 100
rnn = RNNScratch(num_inputs=num_inputs, num_hiddens=num_hiddens)
# All-ones input laid out as (num_steps, batch_size, num_inputs).
X = torch.ones(num_steps, batch_size, num_inputs)
outputs, state = rnn(X)

In [12]:
def check_len(a, n):
    """Raise AssertionError unless ``len(a) == n``.

    Args:
        a: any object supporting ``len()``.
        n: the expected length.

    Raises:
        AssertionError: if the length differs. Raised explicitly rather
            than via an ``assert`` statement so the check is not stripped
            when Python runs with ``-O``.
    """
    actual = len(a)
    if actual != n:
        raise AssertionError(
            f'list\'s length {actual} != expected length {n}')

def check_shape(a, shape):
    """Raise AssertionError unless ``a.shape`` matches ``shape``.

    Args:
        a: an array-like exposing ``.shape`` (e.g. a torch tensor).
        shape: the expected shape, as a tuple, ``torch.Size`` or list.

    Raises:
        AssertionError: if the shapes differ. Raised explicitly rather than
            via an ``assert`` statement so the check is not stripped when
            Python runs with ``-O``.
    """
    # torch.Size is a tuple subclass and never compares equal to a list, so
    # normalise both sides; check_shape(x, [2, 3]) then behaves like (2, 3).
    expected = tuple(shape) if isinstance(shape, (list, tuple)) else shape
    actual = tuple(a.shape) if isinstance(a.shape, (list, tuple)) else a.shape
    if actual != expected:
        raise AssertionError(
            f'tensor\'s shape {a.shape} != expected shape {shape}')

In [13]:
# Sanity-check the forward pass above: one hidden state per time step,
# and both a per-step output and the final state are (batch_size, num_hiddens).
check_len(outputs, n=num_steps)
check_shape(outputs[0], shape=(batch_size, num_hiddens))
check_shape(state, shape=(batch_size, num_hiddens))

In [14]:
class RNNLMScratch(d2l.Classifier):
    def __init__(self, rnn, vocab_size, lr=0.01):
        """Language model that adds an output layer on top of ``rnn``.

        Args:
            rnn: the recurrent layer; ``init_params`` reads its
                ``num_hiddens`` and ``sigma`` attributes.
            vocab_size: width of the output layer (W_hq gets vocab_size
                columns).
            lr: learning rate. NOTE(review): presumably consumed by the
                d2l.Classifier optimizer setup -- confirm.
        """
        super().__init__()
        # Store rnn / vocab_size / lr as attributes (init_params reads
        # self.rnn and self.vocab_size). NOTE(review): d2l's helper appears
        # to capture the caller's locals, so keep it before any new local.
        self.save_hyperparameters()
        self.init_params()
    
    def init_params(self):
        """Create the output layer mapping a hidden state to vocab_size scores.

        The weight matrix is standard-normal noise scaled by the RNN's
        ``sigma``; the bias starts at zero.
        """
        hidden_size, scale = self.rnn.num_hiddens, self.rnn.sigma
        weight = scale * torch.randn(hidden_size, self.vocab_size)
        self.W_hq = nn.Parameter(weight)
        self.b_q = nn.Parameter(torch.zeros(self.vocab_size))
    
    def training_step(self, batch):
        """Compute the training loss for one batch and plot its perplexity.

        The last element of ``batch`` is the target; every preceding element
        is passed positionally to the model. Returns the loss.
        """
        features, target = batch[:-1], batch[-1]
        loss = self.loss(self(*features), target)
        # exp(loss) is reported as perplexity ('ppl').
        # NOTE(review): assumes self.loss is a mean cross-entropy.
        self.plot('ppl', torch.exp(loss), train=True)
        return loss
    
    def validation_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('ppl', torch.exp(l), train=False)