# Assignment 8

In the module, the attention mechanism is implemented for a single input sequence and iterates over every token individually. This implementation matches the block diagram in the module.

Convert the implementation to batched processing using `BATCH_SIZE` to take advantage of PyTorch and a GPGPU device. Code updates and hints toward batch-by-batch processing are provided throughout the implementation, including in the dataset class `NMTDataset`.

General steps towards batching in an RNN attention neural network:

- Each sequence is created as a fixed-length tensor; the positions after the EOS token are filled with padding up to that length.
- Training has to input the sequences batch by batch.
- It is permissible to step through the tokens one at a time, as long as each step processes a whole batch.
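- The `batch_first` argument of RNN layers controls the dimension order of the input and output tensors: batch-sequence-features when `True` and sequence-batch-features when `False`. The hidden-state tensor is not affected by `batch_first`; its order is always layers-batch-features, i.e. `(num_layers * num_directions, batch, hidden)`.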

Show that the batched implementation is faster and produces the same error (loss) convergence.

In [None]:
# Import necessary modules
import sys
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from nmt_dataset.nmt_dataset import \
    NMTDataset, GO_token, EOS_token, SEQ_MXLEN

In [None]:
# Select the PyTorch device: prefer CUDA, fall back to Apple MPS, then CPU
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
device

In [None]:
# Global hyperparameters, followed by the dataset and its batch loader
HIDDEN_N = 128
BATCH_SIZE = 1  # TODO: Needs to be increased
# Round 100000 down to a multiple of BATCH_SIZE so that (assuming NMTDataset
# returns DATASET_SIZE items) the loader only ever yields full batches
DATASET_SIZE = (100000 // BATCH_SIZE) * BATCH_SIZE
dataset = NMTDataset('nmt_dataset/eng-fra.txt', DATASET_SIZE)
dataset_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
print('Example of an item from the dataset:')
dataset[0]

In [None]:
# Create an encoder model
class Encoder(nn.Module):
    """Embedding + single-layer GRU encoder that consumes one token per call.

    Args:
        n_input: Source vocabulary size (rows of the embedding table).
        n_hidden: Embedding width and GRU hidden size.
    """

    def __init__(self, n_input: int, n_hidden: int) -> None:
        super().__init__()
        self.n_input, self.n_hidden = n_input, n_hidden

        self.embedding_layer = nn.Embedding(n_input, n_hidden)
        self.rnn_cell = nn.GRU(n_hidden, n_hidden, batch_first=True)

    def forward(self, _x, _hn):
        """Advance the GRU by one time step.

        Args:
            _x: Token indices for one time step, one index per batch element
                (e.g. shape ``(batch,)`` or ``(batch, 1)``).
            _hn: Hidden state of shape ``(1, batch, n_hidden)``.

        Returns:
            ``(output, hn)`` from the GRU, where output is
            ``(batch, 1, n_hidden)`` (batch_first) and hn is
            ``(1, batch, n_hidden)``.
        """
        # Infer the batch size from the input rather than the global
        # BATCH_SIZE, so the module works for single samples and full batches.
        _x_embedded = self.embedding_layer(_x).view(-1, 1, self.n_hidden)
        return self.rnn_cell(_x_embedded, _hn)

    def init_hidden(self, batch_size: int = 1):
        """Return a zero hidden state of shape ``(1, batch_size, n_hidden)``.

        The layout is ``(num_layers, batch, hidden)`` regardless of
        ``batch_first``. The tensor is created on the same device as the
        module's parameters, so it follows ``.to(device)``.
        """
        return torch.zeros(
            1, batch_size, self.n_hidden,
            device=self.embedding_layer.weight.device)

encoder = Encoder(
    n_input=dataset.input_lang.n_words,
    n_hidden=HIDDEN_N,
).to(device=device)
encoder

In [None]:
# Create a decoder model
class Decoder(nn.Module):
    """GRU decoder with attention over encoder outputs, one token per call.

    At each step the new GRU state is projected by ``attention``. The
    projection is dotted with every encoder output, the scores are
    softmax-normalised, and the weighted sum of encoder outputs (the context
    vector) is concatenated with the state. The result passes through
    ``w_c`` + tanh and finally ``w_y`` + log_softmax.

    Args:
        n_hidden: Embedding width and GRU hidden size.
        n_output: Target vocabulary size.
        dropout_rate: Dropout probability applied to the input embedding.
    """

    def __init__(
        self, n_hidden: int, n_output: int,
        dropout_rate: float = 0.1
    ) -> None:
        super().__init__()
        self.n_hidden, self.n_output = n_hidden, n_output

        self.embedding = nn.Embedding(self.n_output, self.n_hidden)
        self.dropout = nn.Dropout(dropout_rate)
        self.attention = nn.Linear(n_hidden, n_hidden)
        self.w_c = nn.Linear(n_hidden * 2, n_hidden)
        self.rnn_cell = nn.GRU(n_hidden, n_hidden, batch_first=True)
        self.w_y = nn.Linear(n_hidden, n_output)

    def forward(self, _x, _hn, _encoder_outputs):
        """Decode one time step.

        Args:
            _x: Previous target token indices, one per batch element
                (e.g. shape ``(batch,)``, ``(batch, 1)`` or ``(1, 1)``).
            _hn: Decoder hidden state, ``(1, batch, n_hidden)``.
            _encoder_outputs: Encoder outputs to attend over. Either
                ``(src_len, n_hidden)``, shared by every batch element, or
                ``(batch, src_len, n_hidden)``, one set per element.

        Returns:
            ``(log_probs, hn, attention_weights)`` with shapes
            ``(batch, n_output)``, ``(1, batch, n_hidden)`` and
            ``(batch, src_len)``.
        """
        # Infer the batch size from the input instead of the global BATCH_SIZE
        _x_embedded = self.embedding(_x).view(-1, 1, self.n_hidden)
        _x_embedded = self.dropout(_x_embedded)
        _, _hn = self.rnn_cell(_x_embedded, _hn)
        _h_t = _hn[0]  # (batch, n_hidden)

        # torch.matmul broadcasts, so one expression handles both the 2-D
        # (shared) and 3-D (per-sample) encoder-output layouts.
        _query = self.attention(_h_t).unsqueeze(1)  # (batch, 1, n_hidden)
        _alignment_scores = torch.matmul(
            _query, _encoder_outputs.transpose(-2, -1)).squeeze(1)
        # NOTE(review): no masking is applied, so zero-padded rows in
        # _encoder_outputs still receive softmax weight.
        _attention_weights = nn.functional.softmax(
            _alignment_scores, dim=1)
        _c_t = torch.matmul(
            _attention_weights.unsqueeze(1), _encoder_outputs).squeeze(1)

        _hidden_s_t = torch.tanh(self.w_c(torch.cat([_h_t, _c_t], dim=1)))
        _output = nn.functional.log_softmax(self.w_y(_hidden_s_t), dim=1)
        return _output, _hn, _attention_weights

    def init_hidden(self, batch_size: int = 1):
        """Return a zero hidden state of shape ``(1, batch_size, n_hidden)``.

        The tensor is created on the same device as the module's parameters.
        """
        return torch.zeros(
            1, batch_size, self.n_hidden,
            device=self.embedding.weight.device)

decoder = Decoder(
    n_hidden=HIDDEN_N,
    n_output=dataset.output_lang.n_words,
).to(device=device)
decoder

In [None]:
# Create the loss function and optimizers for training the above models
# NLLLoss expects log-probabilities, matching the decoder's log_softmax output;
# each model gets its own Adam optimizer with default settings.
loss_function = nn.NLLLoss()
encoder_optimizer, decoder_optimizer = (
    torch.optim.Adam(model.parameters()) for model in (encoder, decoder))

In [None]:
# Create and run a function to train the above models
def train(
    _encoder: Encoder, _decoder: Decoder, debug_step: "int | None" = None
) -> tuple[Encoder, Decoder]:
    """Run one teacher-forced training pass over ``dataset_loader``.

    For every batch, the encoder consumes ``seq1`` token by token, writing
    its outputs into a ``SEQ_MXLEN``-row buffer. The decoder is then started
    from ``GO_token`` and fed the ground-truth ``seq2`` tokens (teacher
    forcing). The summed NLL loss is back-propagated and the module-level
    ``encoder_optimizer`` / ``decoder_optimizer`` are stepped once per batch.

    Args:
        _encoder: Encoder to train (its parameters are updated in place).
        _decoder: Decoder to train (its parameters are updated in place).
        debug_step: If truthy, after every ``debug_step`` batches write the
            mean per-token loss of those batches to stdout.

    Returns:
        The trained ``(encoder, decoder)`` pair.
    """
    n_batches = len(dataset_loader)
    total_loss = 0.0
    for batch_index, (seq1, seq2) in enumerate(dataset_loader):
        # squeeze(0) only drops the loader's batch dimension when it has
        # size 1, so these per-token loops assume BATCH_SIZE == 1.
        seq1 = seq1.to(device).squeeze(0)
        seq2 = seq2.to(device).squeeze(0)
        len_seq1 = seq1.size(0)
        len_seq2 = seq2.size(0)

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_hn = _encoder.init_hidden()
        # Fixed SEQ_MXLEN-row buffer; rows at index >= len_seq1 stay zero
        encoder_outputs = torch.zeros(
            SEQ_MXLEN, _encoder.n_hidden, device=device)
        loss = torch.zeros((), device=device)

        with torch.set_grad_enabled(True):
            for encoder_index in range(len_seq1):
                encoder_output, encoder_hn = _encoder(
                    seq1[encoder_index], encoder_hn)
                encoder_outputs[encoder_index] = encoder_output[0, 0]

            # Decoding starts from GO_token, seeded with the final encoder state
            decoder_input = torch.tensor([[GO_token]], device=device)
            decoder_hn = encoder_hn

            for decoder_index in range(len_seq2):
                decoder_output, decoder_hn, _ = _decoder(
                    decoder_input, decoder_hn, encoder_outputs)
                loss += loss_function(decoder_output, seq2[decoder_index])
                # Teacher forcing: feed the ground-truth token, not a prediction
                decoder_input = seq2[decoder_index]

            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()

        # Accumulate this batch's mean per-token loss
        total_loss += loss.item() / len_seq2

        # Report only after a complete window of debug_step batches, so the
        # printed average always covers exactly debug_step batches. The
        # denominator counts batches (not samples), which stays correct
        # when BATCH_SIZE > 1.
        if debug_step and (batch_index + 1) % debug_step == 0:
            sys.stdout.write(
                f'\r{(batch_index + 1) // debug_step:3d}/'
                f'{n_batches // debug_step:3d} | '
                f'Loss: {total_loss / debug_step:3.2f}')
            sys.stdout.flush()
            total_loss = 0.0

    return _encoder, _decoder

# One pass over the dataset, reporting the running loss every 1000 batches;
# the trailing semicolon keeps Jupyter from echoing the call's return value.
train(encoder, decoder, debug_step=1000);