# Teacher's Assignment - Extra Credit #1

***Author:*** *Ofir Paz* $\qquad$ ***Version:*** *15.07.2024* $\qquad$ ***Course:*** *22961 - Deep Learning* \
***Extra Assignment Course:*** *20998 - Extra Assignment 3*

Welcome to the first question of extra assignment #1, part of the *Deep Learning* course. \
In this question, we will train an RNN for classification on the SST-2 dataset while dealing with the exploding gradient problem.

## Imports

In [None]:
import torch  # pytorch.
import torch.nn as nn  # neural network module.
import torch.optim as optim  # optimization module.
import torch.nn.functional as F  # functional module.
import numpy as np  # numpy.
from torch.utils.data import DataLoader, Dataset  # data handling.
import torchtext; torchtext.disable_torchtext_deprecation_warning()
from torchtext.vocab import build_vocab_from_iterator  # vocabulary builder.
import matplotlib.pyplot as plt  # plotting module.
import datasets as ds  # public dataset module.
from base_model import BaseModel  # base model class.

# Type hinting.
from torch import Tensor
from torchtext.vocab import Vocab
from typing import Tuple

## Loading & Pre-Processing

In [None]:
# Load the "sst2" configuration of the GLUE benchmark.
dataset: ds.DatasetDict = ds.load_dataset("glue", "sst2")  # type: ignore

# Bind each of the three splits to its own name.
train_dataset, validation_dataset, test_dataset = (
    dataset[split_name] for split_name in ("train", "validation", "test")
)

In [None]:
# Build the vocabulary from the whitespace-tokenized training sentences.
# "<unk>" and "<pad>" are registered as special tokens, tokens seen fewer
# than 5 times are left out, and out-of-vocabulary lookups fall back to "<unk>".
train_sentence_list = train_dataset["sentence"]
vocab = build_vocab_from_iterator(
    (sentence.split() for sentence in train_sentence_list),
    specials=["<unk>", "<pad>"],
    min_freq=5,
)
vocab.set_default_index(vocab["<unk>"])

In [None]:
# Create the SST-2 dataset.
class SST2Dataset(Dataset):
    """
    Map-style dataset that stores SST-2 sentences as token-index tensors
    together with their labels.

    Args:
        dataset: A split exposing a "sentence" column (strings) and a
            "label" column (integers).
        vocab: Callable vocabulary that maps a list of tokens to a list of
            integer indices.
    """
    def __init__(self, dataset: ds.Dataset, vocab: Vocab) -> None:
        # Tokenize on whitespace and encode every sentence once, up front.
        # The dtype is pinned explicitly: without it, a sentence that yields
        # no tokens would become `torch.tensor([])`, which defaults to
        # float32 and cannot be used as indices for an embedding layer.
        self.sentences = [
            torch.tensor(vocab(sentence.split()), dtype=torch.long)
            for sentence in dataset["sentence"]
        ]
        self.labels = torch.tensor(dataset["label"], dtype=torch.long)

    def __len__(self) -> int:
        """Return the number of sentences in the split."""
        return len(self.sentences)

    def __getitem__(self, idx) -> Tuple[Tensor, Tensor]:
        """Return the `(token_indices, label)` pair stored at position `idx`."""
        return self.sentences[idx], self.labels[idx]

In [None]:
# Wrap every split in an SST2Dataset that shares the training vocabulary.
train_set, validation_set, test_set = (
    SST2Dataset(split, vocab)
    for split in (train_dataset, validation_dataset, test_dataset)
)

## The RNN

In [None]:
class RNNClasifer(BaseModel):
    """
    Sentence classifier made of an embedding table, a stacked `nn.RNN` and a
    linear output layer.

    Args:
        vocab: Vocabulary; its length sets the number of embeddings and its
            "<pad>" index is used as the embedding's `padding_idx`.
        embed_dim: Size of each embedding vector (the RNN input size).
        hidden_dim: Size of the RNN hidden state (the linear layer's input size).
        num_classes: Number of output units of the linear layer.
        RNNlayers: Number of stacked recurrent layers.
    """
    def __init__(self, vocab: Vocab, embed_dim: int, hidden_dim: int, num_classes: int,
                 RNNlayers: int = 2) -> None:
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=len(vocab),
            embedding_dim=embed_dim,
            padding_idx=vocab["<pad>"],
        )
        self.rnns = nn.RNN(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=RNNlayers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """
        Embed the token indices, run them through the RNN and feed the output
        at the final position of the sequence axis to the linear layer.
        """
        embedded = self.embedding(x)
        outputs, _ = self.rnns(embedded)
        # The RNN is batch-first, so axis 1 is the sequence axis; keep only
        # its last position for every item in the batch.
        # NOTE(review): if batches are right-padded, this position is a padding
        # step for the shorter sequences - confirm against the batching code.
        final_step = outputs[:, -1, :]
        return self.fc(final_step)