# Environment Setup

In [None]:
# Mount Google Drive at /content/drive so that files stored on Drive are
# reachable through the local filesystem for the rest of the notebook.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
"""
Change directory to where this file is located
"""
# The quoted text below is a placeholder: replace it with the real directory
# path before running this cell.
%cd 'COPY&PASTE FILE DIRECTORY HERE'

In [None]:
# Install torchdata into the notebook environment. It is not imported directly
# in this notebook; presumably torchtext's dataset loaders need it -- TODO confirm.
! pip install torchdata

In [None]:
import math
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchtext
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
from torchtext.data.functional import to_map_style_dataset

In [None]:
"""
import modules you need
"""


In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Report the library versions and the selected device.
print("Using PyTorch version: {}, Device: {}".format(torch.__version__, DEVICE))  # expected: 1.11.0 and cuda
print("Using torchtext version: {}".format(torchtext.__version__))  # expected: 0.12.0

# Load Data

In [None]:
"""
Load AG_NEWS dataset and set up the tokenizer and encoder pipeline.

Do NOT modify.
"""

# Load the AG_NEWS train/test splits, using ./data as the dataset root.
train_data, test_data = torchtext.datasets.AG_NEWS(root='./data')

tokenizer = get_tokenizer('basic_english')


def tokens(data_iter):
    """Yield the token list of every text in an iterable of (label, text) pairs."""
    for _label, text in data_iter:
        yield tokenizer(text)


# Build the vocabulary from the training texts only; any token that is not in
# the vocabulary is mapped to the index of "<unk>".
encoder = build_vocab_from_iterator(tokens(train_data), specials=["<unk>"])
encoder.set_default_index(encoder["<unk>"])


def text_pipeline(x):
    """Tokenize a raw text and return the list of its vocabulary indices."""
    return encoder(tokenizer(x))


def label_pipeline(x):
    """Convert a raw label to an int and shift it down by one."""
    return int(x) - 1

In [None]:
def collate_batch(batch):
    """
    Creates a batch of encoded text, label and token length tensors.

    Question (a)
    - The sequence length of a batch is the average token length of the
      sequences in that batch (floored, and never less than 1).
      Longer sequences are truncated to that length; shorter ones are
      zero-padded at the end.
    - Text tensors are stacked with dimension (TOKEN_LENGTH, BATCH),
      for easier processing in the RNN model.
    - Token length tensors are used to index the last valid hidden state
      for classification, so each length is capped at TOKEN_LENGTH.

    Inputs
    - list of tuples, each containing an integer label and a text input
    - number of tuples in the list == BATCH SIZE
    Returns
    - text_list: batch of encoded long type text tensors with size (TOKEN_LENGTH, BATCH)
    - label_list: batch of label tensors with size (BATCH)
    - len_list: batch of token length tensors with size (BATCH)
    Raises
    - ValueError if the batch is empty (no average length can be computed)
    """
    if not batch:
        raise ValueError("collate_batch received an empty batch")

    labels = [label_pipeline(label) for label, _ in batch]
    encoded = [text_pipeline(text) for _, text in batch]
    raw_lengths = [len(token_ids) for token_ids in encoded]

    # Floor of the mean token count; keep at least one time step so the
    # resulting tensor is never empty along the sequence axis.
    seq_len = max(1, sum(raw_lengths) // len(raw_lengths))

    # Start from an all-zero (padding) tensor and copy each, possibly
    # truncated, sequence into its own column. Padded positions are excluded
    # from classification through len_list.
    text_list = torch.zeros((seq_len, len(batch)), dtype=torch.long)
    for column, token_ids in enumerate(encoded):
        kept = token_ids[:seq_len]
        if kept:
            text_list[:len(kept), column] = torch.tensor(kept, dtype=torch.long)

    label_list = torch.tensor(labels, dtype=torch.long)
    # Number of valid (non-padding) steps per example after truncation.
    len_list = torch.tensor([min(n, seq_len) for n in raw_lengths],
                            dtype=torch.long)

    assert text_list.size(1) == len(batch)

    return text_list, label_list, len_list

In [None]:
"""
Load the data loader.

Do NOT modify.
"""

BATCH_SIZE = 512

# Convert both splits to map-style datasets so the DataLoader can shuffle them.
train_dataset, test_dataset = map(to_map_style_dataset, (train_data, test_data))

# Both loaders shuffle and build their batches with collate_batch.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_batch,
)
valid_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_batch,
)

In [None]:
"""
Print out the first batch in the train loader.
Check if the collate function is implemented correctly.

Do NOT modify.
"""

# Take a single batch from the training loader and print the first ten
# entries along dim 0 of each returned tensor.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    batch_x, batch_y, len_x = first_batch
    for preview in (batch_x, batch_y, len_x):
        print(preview[:10])

In [None]:
"""
Plot the sequence length distribution of the batches in the train dataloader.
Make sure that the sequence length varies across batches.

Do NOT modify.
"""

# dim 0 of each text tensor is the sequence length of that batch, following
# the (TOKEN_LENGTH, BATCH) layout documented for collate_batch.
batch_len = [padded_text.size(0) for padded_text, _, _ in train_dataloader]
plt.hist(batch_len)

# Model

In [None]:
class RNN(nn.Module):
    """
    Single-layer tanh RNN text classifier built from raw weight tensors.

    Recurrence, for each time step t:
        h_t = tanh(W_xh x_t + b_xh + W_hh h_{t-1} + b_hh)
    Classification, from the hidden state of the last valid token:
        p = softmax(W_hy h_last + b_hy)
    """

    def __init__(self, vocab_size, input_size, hidden_size, num_class):
        """
        Define the model weight parameters and initialize the weights.

        Question (b)
        - Every weight is stored as (out_features, in_features) and applied
          with F.linear. With this layout the fan-in computed in
          init_parameters (the size of dim 1) is the true input width.

        Inputs
        - vocab_size: number of rows of the embedding table
        - input_size: embedding dimension (width of x_t)
        - hidden_size: width of the hidden state h_t
        - num_class: number of output classes
        """
        super(RNN, self).__init__()

        whh_size = (hidden_size, hidden_size)
        wxh_size = (hidden_size, input_size)
        why_size = (num_class, hidden_size)
        bhh_size = (hidden_size,)
        bxh_size = (hidden_size,)
        bhy_size = (num_class,)

        kwargs = {'device': DEVICE, 'dtype': torch.float}
        self.hidden = hidden_size
        self.num_class = num_class
        # Create the embedding on the same device/dtype as the raw parameters
        # so the module works even when the caller forgets `.to(DEVICE)`.
        self.embedding = nn.Embedding(vocab_size, input_size, **kwargs)
        self.W_hh = nn.parameter.Parameter(torch.empty(whh_size, **kwargs))
        self.W_xh = nn.parameter.Parameter(torch.empty(wxh_size, **kwargs))
        self.W_hy = nn.parameter.Parameter(torch.empty(why_size, **kwargs))
        self.b_hh = nn.parameter.Parameter(torch.empty(bhh_size, **kwargs))
        self.b_xh = nn.parameter.Parameter(torch.empty(bxh_size, **kwargs))
        self.b_hy = nn.parameter.Parameter(torch.empty(bhy_size, **kwargs))

        self.init_parameters()

    def init_parameters(self):
        """
        Initialize the parameters with Kaiming uniform initialization.

        Do NOT modify this method.
        """
        nn.init.kaiming_uniform_(self.W_hh, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_hh)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.b_hh, -bound, bound)
        nn.init.kaiming_uniform_(self.W_xh, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_xh)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.b_xh, -bound, bound)
        nn.init.kaiming_uniform_(self.W_hy, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_hy)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.b_hy, -bound, bound)

    def forward(self, inputs, length):
        """
        Question (c)
        - Passes a sequence of tokens through the recurrent network.
        - h_0 is drawn from a standard normal distribution on every call,
          so two calls on the same input differ unless the RNG is seeded.
        - The hidden state used for classification is the one produced by the
          last valid token of each example (index length - 1), never one
          produced by a zero-padded position.

        Inputs
        - inputs: a batch of encoded token sequences with shape (SEQ_LEN, BATCH_SIZE)
        - length: a batch of token lengths with shape (BATCH_SIZE)
        Returns
        - Softmax probabilities for each class with shape (BATCH_SIZE, NUM_CLASS)
        Raises
        - ValueError if SEQ_LEN is zero
        """
        seq_len, batch_size = inputs.shape
        if seq_len == 0:
            raise ValueError("forward() needs at least one time step")

        embedded = self.embedding(inputs)                       # (T, B, input)
        # The input projection does not depend on the recurrence, so compute
        # it for all time steps in one call.
        input_proj = F.linear(embedded, self.W_xh, self.b_xh)   # (T, B, hidden)

        h = torch.randn(batch_size, self.hidden,
                        device=embedded.device, dtype=embedded.dtype)
        states = []
        for t in range(seq_len):
            h = torch.tanh(input_proj[t] + F.linear(h, self.W_hh, self.b_hh))
            states.append(h)
        states = torch.stack(states)                            # (T, B, hidden)

        # Position of the last valid token per example. Clamping keeps a
        # zero length (or one larger than SEQ_LEN) inside the valid range.
        last_step = length.to(states.device).long().clamp(min=1, max=seq_len) - 1
        batch_index = torch.arange(batch_size, device=states.device)
        final_state = states[last_step, batch_index]            # (B, hidden)

        logits = F.linear(final_state, self.W_hy, self.b_hy)    # (B, num_class)
        softmax_probs = F.softmax(logits, dim=1)

        return softmax_probs

    def compute_loss(self, prediction, label):
        """
        Question (d)
        - Computes the cross entropy loss and the number of correct predictions
          without using a loss function from torch.nn.
        - The loss is the mean over the batch of -sum(one_hot * log(p)).

        Inputs
        - prediction: output from self.forward(inputs) with shape (BATCH_SIZE, NUM_CLASS)
        - label: integer labels of the batch inputs with shape (BATCH_SIZE)
        Returns
        - loss: batch-mean cross entropy as a 0-dim tensor, so that
          loss.backward() works; use loss.item() for the plain float
        - correct: number of correct predictions (int)
        """
        label = label.long()
        one_hot = F.one_hot(label, num_classes=self.num_class).to(prediction.dtype)

        # Clamp before the log so a probability of exactly 0 cannot give -inf.
        log_probs = torch.log(prediction.clamp_min(1e-12))
        loss = -(one_hot * log_probs).sum(dim=1).mean()

        correct = int((prediction.argmax(dim=1) == label).sum().item())

        return loss, correct

# Training Modules

In [None]:
class ScheduledOptim:
    """
    Optimizer wrapper with linear warm-up followed by exponential decay.

    Each call to update() sets the learning rate of every parameter group:
    - while n_steps < n_warmup_steps: n_steps / n_warmup_steps * initial_lr
    - when n_steps == n_warmup_steps: initial_lr
    - afterwards: the previous learning rate multiplied by decay_rate
    """

    def __init__(self, optimizer, n_warmup_steps, decay_rate):
        # The first parameter group's learning rate is the warm-up target.
        base_lr = optimizer.param_groups[0]['lr']

        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.decay = decay_rate
        self.n_steps = 0
        self.initial_lr = base_lr
        self.current_lr = base_lr

    def zero_grad(self):
        """Clear the gradients through the wrapped optimizer."""
        self._optimizer.zero_grad()

    def step(self):
        """Apply one step of the wrapped optimizer."""
        self._optimizer.step()

    def get_lr(self):
        """Return the learning rate set by the most recent update()."""
        return self.current_lr

    def update(self):
        """Advance the schedule by one step and push the new rate to all groups."""
        step_index = self.n_steps

        if step_index > self.n_warmup_steps:
            new_lr = self.current_lr * self.decay
        elif step_index == self.n_warmup_steps:
            new_lr = self.initial_lr
        else:
            new_lr = step_index / self.n_warmup_steps * self.initial_lr

        self.current_lr = new_lr
        for group in self._optimizer.param_groups:
            group['lr'] = new_lr

        self.n_steps = step_index + 1

In [None]:
"""
Functions for training and evaluating the model.

Question (e)
- There have been minor changes to the model's forward operation and loss computation.
  Compare the updates with the train, evaluate functions that we have previously used,
  and complete the train and evaluate function that works for the current model architecture.
- Use the methods of the scheduler to perform necessary operations on the optimizer.
- Do NOT change the arguments given to the train, evaluate functions.
"""

def train(model, train_loader, scheduler):
    """
    Train the model for one pass over train_loader.

    For every batch the learning rate is advanced with scheduler.update(),
    the gradients are cleared, the batch is pushed through
    model(text, length), and the loss from model.compute_loss is
    back-propagated before scheduler.step() applies the optimizer step.
    update() runs before step(), so with a warm-up the very first batch is
    taken at the schedule's first value.

    Assumes model.compute_loss returns (batch-mean loss as a differentiable
    tensor, number of correct predictions) -- confirm against the model.

    Inputs
    - model: module called as model(text, length) that provides compute_loss
    - train_loader: iterable yielding (text, label, length) batches
    - scheduler: object exposing update(), zero_grad() and step()
    Returns
    - train_loss: average loss per example over the processed batches
    - train_acc: accuracy in percent over the processed examples
    """
    model.train()
    train_loss = 0
    correct = 0
    seen = 0

    tqdm_bar = tqdm(train_loader)

    for text, label, length in tqdm_bar:
        text = text.to(DEVICE)
        label = label.to(DEVICE)
        length = length.to(DEVICE)

        scheduler.update()
        scheduler.zero_grad()

        prediction = model(text, length)
        loss, batch_correct = model.compute_loss(prediction, label)
        loss.backward()
        scheduler.step()

        # Weight the batch-mean loss by the batch size so the final figure is
        # a per-example average even though the last batch may be smaller.
        batch_size = label.size(0)
        train_loss += loss.item() * batch_size
        correct += batch_correct
        seen += batch_size

    # Normalise by the examples actually processed; max() guards an empty loader.
    seen = max(seen, 1)
    train_loss /= seen
    train_acc = 100. * correct / seen

    return train_loss, train_acc

def evaluate(model, test_loader):
    """
    Evaluate the model on every batch of test_loader without updating it.

    Gradient tracking is disabled for the whole pass. Each batch is pushed
    through model(text, length) and scored with model.compute_loss.

    Assumes model.compute_loss returns (batch-mean loss tensor, number of
    correct predictions) -- confirm against the model.

    Inputs
    - model: module called as model(text, length) that provides compute_loss
    - test_loader: iterable yielding (text, label, length) batches
    Returns
    - test_loss: average loss per example over the processed batches
    - test_acc: accuracy in percent over the processed examples
    """
    model.eval()
    test_loss = 0
    correct = 0
    seen = 0

    with torch.no_grad():
        for text, label, length in test_loader:
            text = text.to(DEVICE)
            label = label.to(DEVICE)
            length = length.to(DEVICE)

            prediction = model(text, length)
            loss, batch_correct = model.compute_loss(prediction, label)

            # Weight by batch size so a smaller final batch is averaged correctly.
            batch_size = label.size(0)
            test_loss += loss.item() * batch_size
            correct += batch_correct
            seen += batch_size

    # Normalise by the examples actually processed; max() guards an empty loader.
    seen = max(seen, 1)
    test_loss /= seen
    test_acc = 100. * correct / seen

    return test_loss, test_acc

# Model Training

In [None]:
"""
Question (f)
- Train your RNN model and obtain a test accuracy of 70%.
- Select an input size and a hidden size of your choice.
- Try various optimizer types, learning rates and scheduler options for the best performance.
"""

### COMPLETE HERE ###
# Hyper-parameters. These are starting values and have not been verified
# against the 70% target: tune them if validation accuracy falls short.
EPOCHS = 5
vocab_size = len(encoder)   # one embedding row per vocabulary entry
input_size = 128            # embedding dimension
hidden_size = 128           # width of the recurrent hidden state
num_class = 4               # AG_NEWS has four news categories

# .to(DEVICE) moves every sub-module, including the embedding, to the device
# that the batches are sent to during training and evaluation.
model = RNN(vocab_size, input_size, hidden_size, num_class).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# One scheduler step per call to update(): linear warm-up for 100 steps, then
# the learning rate is multiplied by 0.999 on every later step.
scheduler = ScheduledOptim(optimizer, n_warmup_steps=100, decay_rate=0.999)

for epoch in range(1, EPOCHS + 1):
    loss_train, accu_train = train(model, train_dataloader, scheduler)
    loss_val, accu_val = evaluate(model, valid_dataloader)
    lr = scheduler.get_lr()
    print('-' * 83)
    print('| end of epoch {:2d} | lr: {:5.4f} | train accuracy: {:8.3f} | '
          'valid accuracy {:8.3f} '.format(epoch, lr, accu_train, accu_val))
    print('-' * 83)