# Teacher's Assignment - Extra Credit #1

***Author:*** *Ofir Paz* $\qquad$ ***Version:*** *15.07.2024* $\qquad$ ***Course:*** *22961 - Deep Learning* \
***Extra Assignment Course:*** *20998 - Extra Assignment 3*

Welcome to the second question of extra assignment #1 in the course *Deep Learning*. \
In this question, we implement an RNN cell with a basic pass-through control.

## Imports

In [1]:
import torch  # pytorch.
import torch.nn as nn  # neural network module.
import torch.nn.functional as F  # neural network functional module.
from torch.utils.data import DataLoader, Dataset  # data handling.
import torchtext  # text utilities.
torchtext.disable_torchtext_deprecation_warning()  # silence the torchtext deprecation warning.
from torchtext.vocab import build_vocab_from_iterator  # vocabulary builder.
import matplotlib.pyplot as plt  # plotting module.
import datasets as ds  # public dataset module.
from base_model import BaseModel  # base model class.

# Type hinting.
from torch import Tensor
from torchtext.vocab import Vocab
from typing import Tuple

## The Implementation

In [2]:
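class CustomRNNCell(nn.Module):
    """
    Custom RNN cell class.

    Uses an additional pass-through vector, computed from the current token only,
    to control (scale) the hidden state produced by the RNN cell.
    """
    def __init__(self, embed_dim: int, hidden_dim: int) -> None:
        super(CustomRNNCell, self).__init__()
        # Standard RNN path: input and recurrent projections followed by tanh.
        self.input_linear = nn.Linear(embed_dim, hidden_dim)
        self.hidden_linear = nn.Linear(hidden_dim, hidden_dim)
        self.regular_activation = nn.Tanh()

        # Pass-through path: a learned projection of the current token, squashed into (0, 1).
        self.pass_through_layer = nn.Linear(embed_dim, hidden_dim)
        self.pass_through_activation = nn.Softmax(dim=-1)

    def forward(self, one_embedded_token: Tensor, hidden_state: Tensor) -> Tensor:
        # Candidate hidden state from the current token and the previous hidden state.
        Z1 = self.input_linear(one_embedded_token)
        Z2 = self.hidden_linear(hidden_state)
        h_hat_t = self.regular_activation(Z1 + Z2)

        # Pass-through factor R_t, computed from the current token only.
        r_t = self.pass_through_activation(self.pass_through_layer(one_embedded_token))
        # Element-wise scaling: entries of R_t below 1 attenuate ("forget") parts of the signal.
        new_hidden_state = h_hat_t * r_t

        return new_hidden_state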

### Explanation

We can implement "forgetting" by multiplying the data we want to forget by a factor smaller than 1: this attenuates the signal, so it is effectively forgotten. This can be beneficial when we want to pass through only part of the signal, for example when the other part is irrelevant to the rest of the text or consists of noise. \
In this question, we multiply the candidate hidden state element-wise by a pass-through factor $R_t$, which is computed from the current token only through a learned linear layer.
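Written out, the forward pass of `CustomRNNCell` computes the following. The weight names are chosen here for illustration only: $W_x, b_x$ for `input_linear`, $W_h, b_h$ for `hidden_linear`, and $W_r, b_r$ for `pass_through_layer`.

$$
\begin{aligned}
\hat{h}_t &= \tanh\left(W_x x_t + b_x + W_h h_{t-1} + b_h\right) \\
R_t &= \operatorname{softmax}\left(W_r x_t + b_r\right) \\
h_t &= \hat{h}_t \odot R_t
\end{aligned}
$$

Here $x_t$ is the embedded token, $h_{t-1}$ is the incoming hidden state, the softmax is taken over the hidden dimension, and $\odot$ denotes element-wise multiplication.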

I chose the softmax activation function for $R_t$ because every entry of its output lies in $(0, 1)$, so $R_t$ can only attenuate the signal and never amplify it.
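One property worth keeping in mind: besides lying in $(0, 1)$, the softmax entries are coupled, since they sum to 1 over the hidden dimension. With `hidden_dim=100`, as in the training cell below, the gate values therefore average exactly $1/100$. The short sketch below checks this on random logits that stand in for `pass_through_layer(x_t)` (an assumption for illustration), and contrasts them with an element-wise sigmoid, which is shown only for comparison and is not what `CustomRNNCell` uses.

```python
import torch

torch.manual_seed(0)

hidden_dim = 100  # same size as in the training cell
logits = torch.randn(1, hidden_dim)  # random stand-in for pass_through_layer(x_t)

softmax_gate = torch.softmax(logits, dim=-1)  # the activation CustomRNNCell uses
sigmoid_gate = torch.sigmoid(logits)          # element-wise alternative, for comparison only

# Both gates lie strictly inside (0, 1) ...
print(softmax_gate.min().item() > 0, softmax_gate.max().item() < 1)  # True True
# ... but softmax entries are normalized across units: they sum to 1,
# so their mean is 1 / hidden_dim regardless of the logits.
print(softmax_gate.sum().item())
print(softmax_gate.mean().item())
# Sigmoid entries are independent of each other and not tied to hidden_dim.
print(sigmoid_gate.mean().item())
```

The design consequence is that, with softmax, opening the gate for one hidden unit necessarily closes it for the others, while a sigmoid gate lets each unit pass or attenuate its signal independently.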

In my opinion, we could also implement forgetting with an exponential decay of the previous signal, i.e., by simply multiplying the hidden state by some constant $0<r<1$. This approach uses fewer parameters, and the signal is forgotten in a more natural way.
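A minimal sketch of that alternative follows, assuming one particular reading of the idea: a leaky update $h_t = r\,h_{t-1} + (1-r)\,\hat{h}_t$ with a fixed, non-learned $r$, so that along the direct path a state from $k$ steps ago is scaled by $r^k$. The class name `DecayRNNCell` and the default `r=0.9` are hypothetical and are not part of the assignment.

```python
import torch
import torch.nn as nn
from torch import Tensor


class DecayRNNCell(nn.Module):
    """
    Hypothetical RNN cell that forgets through a fixed exponential decay.

    h_t = r * h_{t-1} + (1 - r) * tanh(W_x x_t + b_x + W_h h_{t-1} + b_h),
    so along the direct path a state from k steps ago is scaled by r**k.
    """

    def __init__(self, embed_dim: int, hidden_dim: int, r: float = 0.9) -> None:
        super().__init__()
        if not 0.0 < r < 1.0:
            raise ValueError("r must satisfy 0 < r < 1")
        self.input_linear = nn.Linear(embed_dim, hidden_dim)
        self.hidden_linear = nn.Linear(hidden_dim, hidden_dim)
        self.r = r  # fixed decay factor (not learned), so it adds no parameters

    def forward(self, one_embedded_token: Tensor, hidden_state: Tensor) -> Tensor:
        # Candidate state, computed the same way as h_hat_t in CustomRNNCell.
        h_hat_t = torch.tanh(self.input_linear(one_embedded_token) + self.hidden_linear(hidden_state))
        # Keep a fraction r of the old state and mix in the new candidate.
        return self.r * hidden_state + (1.0 - self.r) * h_hat_t


torch.manual_seed(0)
cell = DecayRNNCell(embed_dim=20, hidden_dim=100)
h = torch.zeros(1, 100)
for x_t in torch.randn(5, 1, 20):  # five embedded tokens, batch size 1
    h = cell(x_t, h)

print(h.shape)  # torch.Size([1, 100])
print(sum(p.numel() for p in cell.parameters()))  # 12200
```

With `embed_dim=20` and `hidden_dim=100`, this cell has 12,200 parameters, while `CustomRNNCell` with the same sizes has 14,300, since its `pass_through_layer` adds $20 \cdot 100 + 100 = 2{,}100$. The trade-off is that the decay rate is the same for every token and every hidden unit, whereas the learned $R_t$ can depend on the current token.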

## Testing The Implementation

In [3]:
# Load a dataset to try to fit on.
full_dataset: ds.DatasetDict = ds.load_dataset("glue", "sst2")  # type: ignore
big_train_dataset = full_dataset["train"]
big_validation_dataset = full_dataset["validation"]
train_dataset = big_train_dataset.select(range(500))  # small dataset for testing.
validation_dataset = big_validation_dataset.select(range(250))  # small dataset for testing.

In [4]:
# Create the vocabulary.
train_sentence_list = train_dataset["sentence"]
vocab = build_vocab_from_iterator(map(str.split, train_sentence_list), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

In [5]:
# Create the dataset class.
class SST2Dataset(Dataset):
    def __init__(self, dataset: ds.Dataset, vocab: Vocab) -> None:
        self.sentences = list(map(lambda seq: torch.tensor(vocab(seq.split())), dataset["sentence"]))
        self.labels = torch.tensor(dataset["label"], dtype=torch.long)

    def __len__(self) -> int:
        return len(self.sentences)

    def __getitem__(self, idx) -> Tuple[Tensor, Tensor]:
        return self.sentences[idx], self.labels[idx]

In [6]:
# Create the dataloaders.
train_set = SST2Dataset(train_dataset, vocab)
validation_set = SST2Dataset(validation_dataset, vocab)
train_loader = DataLoader(train_set, batch_size=1, shuffle=True)
validation_loader = DataLoader(validation_set, batch_size=1, shuffle=False)
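The loaders above use `batch_size=1`, which sidesteps a real constraint: sentences have different lengths, so PyTorch's default collate function cannot stack several of them into one tensor. Below is a minimal sketch of a padding collate function. `pad_collate` and `PAD_IDX` are hypothetical names, and the sketch assumes a dedicated padding token at index 0. The vocabulary built above only defines `<unk>`, so a `<pad>` special would have to be added first (for example `specials=["<pad>", "<unk>"]`).

```python
import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from typing import List, Tuple

PAD_IDX = 0  # hypothetical: assumes a dedicated "<pad>" token at index 0


def pad_collate(batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]:
    """Pad variable-length token sequences to the longest one in the batch."""
    sentences, labels = zip(*batch)
    # Keep the true lengths so the model can ignore padded positions.
    lengths = torch.tensor([len(s) for s in sentences], dtype=torch.long)
    padded = pad_sequence(list(sentences), batch_first=True, padding_value=PAD_IDX)
    return padded, torch.stack(labels), lengths


# Two toy (sentence, label) pairs of different lengths.
batch = [(torch.tensor([4, 7, 2]), torch.tensor(1)), (torch.tensor([5, 9]), torch.tensor(0))]
tokens, labels, lengths = pad_collate(batch)
print(tokens.tolist())   # [[4, 7, 2], [5, 9, 0]]
print(labels.tolist())   # [1, 0]
print(lengths.tolist())  # [3, 2]
```

It would be passed as `DataLoader(train_set, batch_size=32, shuffle=True, collate_fn=pad_collate)`. Padding alone is not enough for this model, though: `RNN.forward` keeps updating the hidden state over padded positions. The loop would therefore also need to freeze sequences that have already ended, for example with `hidden_state = torch.where((i < lengths).unsqueeze(1), new_state, hidden_state)`.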

In [7]:
class RNN(BaseModel):
    """
    RNN model class.

    The RNN model class uses the custom RNN cell to create a custom RNN model.
    """
    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, num_classes: int) -> None:
        super(RNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn_cell = CustomRNNCell(embed_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, sentence_tokens: Tensor) -> Tensor:
        batch_size = sentence_tokens.size(0)
        hidden_state = torch.zeros(batch_size, self.rnn_cell.hidden_linear.out_features, 
                                   device=sentence_tokens.device)
        embedded = self.embed(sentence_tokens)
        
        for i in range(embedded.size(1)):
            hidden_state = self.rnn_cell(embedded[:, i, :], hidden_state)
        
        return self.fc(hidden_state)

In [8]:
# Train the model.
model = RNN(len(vocab), embed_dim=20, hidden_dim=100, num_classes=2)

model.fit(train_loader, validation_loader, num_epochs=15, lr=0.001, try_cuda=False, print_stride=3)
model.fit(train_loader, validation_loader, num_epochs=10, lr=0.0001, try_cuda=False, print_stride=2)

Using CPU for training.
[epoch: 01/15] [Train loss: 0.693001  Train Accuracy: 0.510]  [Val loss: 0.692427]  Val Accuracy: 0.516]
[epoch: 04/15] [Train loss: 0.661650  Train Accuracy: 0.702]  [Val loss: 0.704553]  Val Accuracy: 0.520]
[epoch: 07/15] [Train loss: 0.503337  Train Accuracy: 0.802]  [Val loss: 0.698057]  Val Accuracy: 0.524]
[epoch: 10/15] [Train loss: 0.309917  Train Accuracy: 0.908]  [Val loss: 0.770862]  Val Accuracy: 0.520]
[epoch: 13/15] [Train loss: 0.172037  Train Accuracy: 0.958]  [Val loss: 1.052298]  Val Accuracy: 0.516]
[epoch: 15/15] [Train loss: 0.114313  Train Accuracy: 0.980]  [Val loss: 1.105240]  Val Accuracy: 0.540]
Using CPU for training.
[epoch: 16/25] [Train loss: 0.091820  Train Accuracy: 0.984]  [Val loss: 1.133645]  Val Accuracy: 0.536]
[epoch: 18/25] [Train loss: 0.086085  Train Accuracy: 0.986]  [Val loss: 1.176048]  Val Accuracy: 0.544]
[epoch: 20/25] [Train loss: 0.081168  Train Accuracy: 0.988]  [Val loss: 1.206758]  Val Accuracy: 0.544]
[epoch: