# Recurrent Neural Networks

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from utils import d2l
%matplotlib inline

## 9.5 Recurrent Neural Networks from Scratch

In [49]:
class RNNScratch(nn.Module):
    def __init__(self, n_inputs, n_hidden, n_outputs, device='cpu'):
        super().__init__()
        # Set device
        self.device = device

        # Set hyperparameters
        self.n_inputs = n_inputs # Length of input sequence
        self.n_hidden = n_hidden # Number of hidden dimensions
        self.n_outputs = n_outputs # Length of output dimensions

        # Initialize latent layer parameters
        self.W_xh = self._init_parameter((n_inputs, n_hidden))
        self.W_hh = self._init_parameter((n_hidden, n_hidden))
        self.b_h = self._init_zeros((n_hidden,))

        # Initialize output layer parameters
        self.W_hq = self._init_parameter((n_hidden, n_outputs))
        self.b_q = self._init_zeros((n_outputs,))
    
    def _init_parameter(self, shape: tuple):
        # Initialize parameters with Xavier distribution
        return nn.Parameter(torch.nn.init.xavier_normal_(torch.empty(shape, device=self.device)))
    
    def _init_zeros(self, shape: tuple):
        # Initialize bias with zeros
        return nn.Parameter(torch.zeros(shape, device=self.device))

    def forward(self, X: torch.Tensor, hidden_state: torch.Tensor = None, training: bool = True):
        if training and (hidden_state is None):
            # Hidden state does not require gradient
            self.hidden_state = torch.zeros((X.shape[0], self.n_hidden), device=self.device, requires_grad=False)
        
        # Compute the new hidden state
        L_cat = torch.cat((X, self.hidden_state), 1)
        R_cat = torch.cat((self.W_xh, self.W_hh), 0)
        self.hidden_state = torch.tanh(torch.mm(L_cat, R_cat) + self.b_h)
        
        # Compute output
        output = torch.mm(self.hidden_state, self.W_hq) + self.b_q
        
        return output, self.hidden_state

In [50]:
def train_rnn_epoch(net, train_iterator, loss, optimizer):
    """Train a model for one epoch."""
    # Set the model to training mode
    net.train()
    # Initialize the total loss and number of samples
    total_loss, num_samples = 0, 0

    hidden_state = None
    for X, y in train_iterator:
        # Move data to the appropriate device
        X, y = X.to(net.device), y.to(net.device)

        # Forward pass
        y_hat, hidden_state = net(X, hidden_state)

        # Compute the loss
        l = loss(y_hat, y[:, -1].long())
        
        # Backward pass
        optimizer.zero_grad()
        l.backward(retain_graph=True)
        optimizer.step()
        
        # Update the total loss and number of samples
        total_loss += l.item() * y.shape[0]
        num_samples += y.shape[0]
    return total_loss / num_samples

In [51]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [52]:
seq_len = 40

In [53]:
time_machine = d2l.TimeMachine(64, seq_len, 10112, 5056)
train_data = time_machine.get_dataloader(False)
test_data = time_machine.get_dataloader(True)

In [54]:
net = RNNScratch(n_inputs=seq_len, n_hidden=512, n_outputs=len(time_machine.vocab), device=device)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

In [55]:
num_epochs = 200
for epoch in range(num_epochs):
    train_loss = train_rnn_epoch(net, train_data, loss, optimizer)
    print(f'Epoch {epoch + 1}, Loss: {train_loss:.4f}')

Epoch 1, Loss: 2.9843
Epoch 2, Loss: 2.7560
Epoch 3, Loss: 2.6404
Epoch 4, Loss: 2.5428
Epoch 5, Loss: 2.4350
Epoch 6, Loss: 2.3342
Epoch 7, Loss: 2.2483
Epoch 8, Loss: 2.1761
Epoch 9, Loss: 2.1292
Epoch 10, Loss: 2.1013
Epoch 11, Loss: 2.0783
Epoch 12, Loss: 2.0217
Epoch 13, Loss: 1.9936
Epoch 14, Loss: 1.9313
Epoch 15, Loss: 1.8846
Epoch 16, Loss: 1.8441
Epoch 17, Loss: 1.8392
Epoch 18, Loss: 1.7996
Epoch 19, Loss: 1.7684
Epoch 20, Loss: 1.7034
Epoch 21, Loss: 1.6301
Epoch 22, Loss: 1.5975
Epoch 23, Loss: 1.5893
Epoch 24, Loss: 1.5297
Epoch 25, Loss: 1.4526
Epoch 26, Loss: 1.4264
Epoch 27, Loss: 1.3930
Epoch 28, Loss: 1.3600
Epoch 29, Loss: 1.3353
Epoch 30, Loss: 1.3095
Epoch 31, Loss: 1.2898
Epoch 32, Loss: 1.2726
Epoch 33, Loss: 1.2813
Epoch 34, Loss: 1.1605
Epoch 35, Loss: 1.0478
Epoch 36, Loss: 1.0398
Epoch 37, Loss: 1.0277
Epoch 38, Loss: 0.9523
Epoch 39, Loss: 0.9336
Epoch 40, Loss: 0.9034
Epoch 41, Loss: 0.8376
Epoch 42, Loss: 0.7923
Epoch 43, Loss: 0.7955
Epoch 44, Loss: 0.79

In [57]:
net.hidden_state.shape

torch.Size([64, 512])

In [58]:
with torch.no_grad():
    net.eval()
    x, y = next(iter(test_data))
    # Move data to device
    x, y = x.to(device), y.to(device)
    
    pred, _ = net(x, None, training=False)
    pred = F.softmax(pred, dim=1)
    pred = torch.argmax(pred, dim=1)
    
    pred = torch.cat((x, pred.unsqueeze(1)), dim=1)
    
    # Move tensors back to CPU for printing
    pred_cpu = pred.cpu()
    y_cpu = y.cpu()
    
    for hat, actual in zip(pred_cpu, y_cpu):
        print("".join(time_machine.vocab.to_tokens(hat.tolist())))
        print("".join([' '] + time_machine.vocab.to_tokens(actual.tolist())))

eller with a slight accession of cheerfu 
 ller with a slight accession of cheerful
 two dimensions but how about up and dow 
 two dimensions but how about up and down
raced and caressed us rather than submit 
 aced and caressed us rather than submitt
ciety said i erected on a strictly commu 
 iety said i erected on a strictly commun
al existence there i object said filby oi
 l existence there i object said filby of
ong it but some foolish people have got c
 ng it but some foolish people have got h
rimental verification said the time trave
 imental verification said the time trave
tter because it happens that our conscio 
 ter because it happens that our consciou
 time as we move about in the other dimes
 time as we move about in the other dimen
 the table was a small shaded lamp the bi
 the table was a small shaded lamp the br
it is spoken of as having three dimensio 
 t is spoken of as having three dimension
 he walked slowly out of the room and wei
 he walked slowly out of the room 

In [None]:
d2l.RNNLMScratch