# Teacher's Assignment - Extra Credit #1

***Author:*** *Ofir Paz* $\qquad$ ***Version:*** *15.07.2024* $\qquad$ ***Course:*** *22961 - Deep Learning* \
***Extra Assignment Course:*** *20998 - Extra Assignment 3*

Welcome to the second question of the extra assignment #1 as part of the course *Deep Learning*. \
In this question we will implement an RNN block with a basic pass-through control.

## Imports

In [1]:
import torch  # pytorch.
import torch.nn as nn  # neural network module.
import torch.nn.functional as F  # neural network functional module.
from torch.utils.data import DataLoader, Dataset  # data handling.
import torchtext; torchtext.disable_torchtext_deprecation_warning()
from torchtext.vocab import build_vocab_from_iterator  # vocabulary builder.
import matplotlib.pyplot as plt  # plotting module.
import datasets as ds  # public dataset module.
from base_model import BaseModel  # base model class.

# Type hinting.
from torch import Tensor
from torchtext.vocab import Vocab
from typing import Tuple

## The Implementation

In [9]:
class CustomRNNCell(nn.Module):
    r"""
    Custom RNN cell with an input-driven pass-through gate.

    Alongside the usual tanh recurrence, a second projection of the current input
    (the pass-through gate ``r_t``) rescales the candidate hidden state element-wise:

    :math:`\hat{h}_t = \tanh(W_{ih}x_t + b_{ih} + W_{hh}h_{t-1} + b_{hh})`
    :math:`r_t = \mathrm{softmax}(W_{ir}x_t + b_{ir})`
    :math:`h_t = \hat{h}_t \odot r_t`  (element-wise multiplication)

    ``W_{ih}, b_{ih}`` belong to ``input_linear``, ``W_{hh}, b_{hh}`` to ``hidden_linear``
    and ``W_{ir}, b_{ir}`` to ``pass_through_layer``. The softmax is taken over the last
    (hidden) dimension, so for every sample the gate entries are positive and sum to 1.

    :param embed_dim: size of each input (token embedding) vector.
    :param hidden_dim: size of the hidden state.
    """
    def __init__(self, embed_dim: int, hidden_dim: int) -> None:
        super().__init__()
        # Candidate-state branch: tanh(W_ih x + b_ih + W_hh h + b_hh).
        self.input_linear = nn.Linear(embed_dim, hidden_dim)
        self.hidden_linear = nn.Linear(hidden_dim, hidden_dim)
        self.regular_activation = nn.Tanh()

        # Gate branch: depends only on the current input, normalized over the hidden dim.
        self.pass_through_layer = nn.Linear(embed_dim, hidden_dim)
        self.pass_through_activation = nn.Softmax(dim=-1)

    def forward(self, one_embedded_token: Tensor, hidden_state: Tensor) -> Tensor:
        """
        Advance the cell by one time step.

        :param one_embedded_token: input at step t; last dimension must be ``embed_dim``.
        :param hidden_state: previous hidden state h_{t-1}; last dimension must be
            ``hidden_dim`` and leading dimensions must broadcast with the input's.
        :return: the new hidden state h_t, with last dimension ``hidden_dim``.
        """
        candidate = self.regular_activation(
            self.input_linear(one_embedded_token) + self.hidden_linear(hidden_state)
        )
        gate = self.pass_through_activation(self.pass_through_layer(one_embedded_token))
        return candidate * gate

### Explanation

## Testing The Implementation

In [44]:
# Load GLUE SST-2 and keep only small leading slices of each split so the fit test runs fast.
full_dataset: ds.DatasetDict = ds.load_dataset("glue", "sst2")  # type: ignore
big_train_dataset, big_validation_dataset = (full_dataset[split] for split in ("train", "validation"))
train_dataset = big_train_dataset.select(range(500))  # first 500 training rows.
validation_dataset = big_validation_dataset.select(range(250))  # first 250 validation rows.

In [45]:
# Build a word-level vocabulary from the training slice only, using whitespace tokens;
# any word not seen there is mapped to the "<unk>" index.
train_sentence_list = train_dataset["sentence"]
vocab = build_vocab_from_iterator(
    (sentence.split() for sentence in train_sentence_list),
    specials=["<unk>"],
)
vocab.set_default_index(vocab["<unk>"])

In [46]:
# Create the dataset class.
class SST2Dataset(Dataset):
    """
    Map-style dataset of SST-2 sentences encoded as vocabulary indices.

    Each sentence is whitespace-tokenized and mapped through ``vocab`` once, at
    construction time; ``__getitem__`` returns the cached tensors. Sentences keep
    their own lengths (no padding), so the default DataLoader collate can only
    batch them with ``batch_size=1``.
    """
    def __init__(self, dataset: ds.Dataset, vocab: Vocab) -> None:
        """
        :param dataset: a split exposing a "sentence" column (strings) and a
            "label" column (integer class ids).
        :param vocab: callable vocabulary mapping a list of tokens to a list of indices.
        """
        # dtype is pinned to long: for a sentence with no tokens, torch.tensor([])
        # would otherwise default to float32, which nn.Embedding rejects as indices.
        self.sentences = [
            torch.tensor(vocab(sentence.split()), dtype=torch.long)
            for sentence in dataset["sentence"]
        ]
        self.labels = torch.tensor(dataset["label"], dtype=torch.long)

    def __len__(self) -> int:
        return len(self.sentences)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
        """Return the (token indices, label) pair at position ``idx``."""
        return self.sentences[idx], self.labels[idx]

In [47]:
# Wrap both splits as index-encoded datasets. Sentences are not padded, so with the
# default collate each batch holds a single sentence; only training order is shuffled.
train_set, validation_set = (SST2Dataset(split, vocab) for split in (train_dataset, validation_dataset))
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=1, shuffle=False)

In [49]:
class RNN(BaseModel):
    """
    Sequence classifier built on ``CustomRNNCell``.

    Tokens are embedded, fed through the cell one time step at a time starting
    from an all-zero hidden state, and the final hidden state is mapped to class
    scores by a linear layer.
    """
    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, num_classes: int) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn_cell = CustomRNNCell(embed_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, sentence_tokens: Tensor) -> Tensor:
        """
        :param sentence_tokens: integer token indices of shape (batch, seq_len).
        :return: unnormalized class scores of shape (batch, num_classes).
        """
        embedded = self.embed(sentence_tokens)  # (batch, seq_len, embed_dim)
        hidden_dim = self.rnn_cell.hidden_linear.out_features
        # new_zeros inherits the embedding's dtype and device, so the initial state
        # still matches the cell's inputs if the model is moved or cast (e.g. .double()).
        hidden_state = embedded.new_zeros(embedded.size(0), hidden_dim)
        for step_embedding in embedded.unbind(dim=1):  # one (batch, embed_dim) slice per step
            hidden_state = self.rnn_cell(step_embedding, hidden_state)
        return self.fc(hidden_state)

In [51]:
# Build the classifier: 20-dim embeddings, 100-dim hidden state, 2 output classes.
model = RNN(len(vocab), embed_dim=20, hidden_dim=100, num_classes=2)

# Two CPU training phases on the same model: 15 epochs at lr=0.001,
# then 10 further epochs at the lower lr=0.0001.
for num_epochs, lr, print_stride in ((15, 0.001, 3), (10, 0.0001, 2)):
    model.fit(train_loader, validation_loader, num_epochs=num_epochs, lr=lr,
              try_cuda=False, print_stride=print_stride)

Using CPU for training.
[epoch: 01/15] [Train loss: 0.693663  Train Accuracy: 0.478]  [Val loss: 0.692511]  Val Accuracy: 0.516]
[epoch: 04/15] [Train loss: 0.656935  Train Accuracy: 0.676]  [Val loss: 0.698796]  Val Accuracy: 0.512]
[epoch: 07/15] [Train loss: 0.506899  Train Accuracy: 0.812]  [Val loss: 0.708899]  Val Accuracy: 0.520]
[epoch: 10/15] [Train loss: 0.342418  Train Accuracy: 0.898]  [Val loss: 0.751542]  Val Accuracy: 0.476]
[epoch: 13/15] [Train loss: 0.198692  Train Accuracy: 0.952]  [Val loss: 0.872927]  Val Accuracy: 0.524]
[epoch: 15/15] [Train loss: 0.137132  Train Accuracy: 0.964]  [Val loss: 0.989995]  Val Accuracy: 0.512]
Using CPU for training.
[epoch: 16/25] [Train loss: 0.109748  Train Accuracy: 0.978]  [Val loss: 0.995595]  Val Accuracy: 0.512]
[epoch: 18/25] [Train loss: 0.103879  Train Accuracy: 0.980]  [Val loss: 1.014728]  Val Accuracy: 0.508]
[epoch: 20/25] [Train loss: 0.098271  Train Accuracy: 0.984]  [Val loss: 1.033314]  Val Accuracy: 0.508]
[epoch: