# Teacher's Assignment - Extra Credit #1

***Author:*** *Ofir Paz* $\qquad$ ***Version:*** *15.07.2024* $\qquad$ ***Course:*** *22961 - Deep Learning* \
***Extra Assignment Course:*** *20998 - Extra Assignment 3*

Welcome to the first question of the extra assignment #1 as part of the course *Deep Learning*. \
In this question we will train an RNN network for classification on the SST-2 dataset while dealing with the exploding gradient problem.

## Imports

In [4]:
import torch  # pytorch.
import torch.nn as nn  # neural network module.
import torch.optim as optim  # optimization module.
import torch.nn.functional as F  # functional module.
import numpy as np  # numpy.
from torch.utils.data import DataLoader, Dataset  # data handling.
import torchtext; torchtext.disable_torchtext_deprecation_warning()
from torchtext.vocab import build_vocab_from_iterator  # vocabulary builder.
import matplotlib.pyplot as plt  # plotting module.
import datasets as ds  # public dataset module.
from base_model import BaseModel  # base model class.

# Type hinting.
from torch import Tensor
from torchtext.vocab import Vocab
from typing import Tuple

## Loading & Pre-Processing

In [5]:
# Load the SST-2 dataset.
dataset: ds.DatasetDict = ds.load_dataset("glue", "sst2")  # type: ignore

train_set = dataset["train"][:3000]
validation_set = dataset["validation"][:1000]
test_set = dataset["test"]

In [6]:
# Create the vocabulary.
vocab = build_vocab_from_iterator(map(str.split, train_set["sentence"]), specials=["<unk>"], min_freq=5)
vocab.set_default_index(vocab["<unk>"])

In [7]:
# Create the SST-2 dataset.
class SST2Dataset(Dataset):
    def __init__(self, dataset: ds.Dataset, vocab: Vocab) -> None:
        self.sentences = list(map(lambda seq: torch.tensor(vocab(seq.split())), dataset["sentence"]))
        self.labels = torch.tensor(dataset["label"], dtype=torch.long)

    def __len__(self) -> int:
        return len(self.sentences)

    def __getitem__(self, idx) -> Tuple[Tensor, Tensor]:
        return self.sentences[idx], self.labels[idx]

In [8]:
train_dataset = SST2Dataset(train_set, vocab)
validation_dataset = SST2Dataset(validation_set, vocab)
test_dataset = SST2Dataset(test_set, vocab)

In [9]:
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=1, shuffle=False)

## The RNN

In [44]:
class RNNClasifer(BaseModel):
    """
    Recurrent Neural Network (RNN) classifier, designed specifically for the SST-2 dataset.
    """
    def __init__(self, vocab: Vocab, t: int, embed_dim: int, hidden_dim: int, num_classes: int,
                 RNNlayers: int = 2, **kwargs) -> None:
        super(RNNClasifer, self).__init__(**kwargs)
        self.t = t
        self.embedding = nn.Embedding(len(vocab), embed_dim)
        self.rnns = nn.RNN(embed_dim, hidden_dim, RNNlayers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)
        
    def rnn_forward(self, x: Tensor) -> Tensor:
        x, _ = self.rnns(x)
        return x
    
    def normal_forward(self, x: Tensor) -> Tensor:
        x = self.embedding(x)
        x = self.rnn_forward(x)
        x = x[:, -1, :]
        x = self.fc(x)
        return x
    
    def forward(self, x: Tensor, capture_gradients: bool = False) -> Tensor:
        assert len(x.size()) == 2 and x.size(0) == 1, f"Can only process batch size of 1. Got: {x.size()}."
        if x.size(1) <= self.t:
            return self.normal_forward(x)
        
        x = self.embedding(x)  # Shape: (batch_size, seq_len, embed_dim).

        with torch.no_grad():  # Disable gradient tracking for the first T-t tokens.
            x_no_grad = self.rnn_forward(x[:, :-self.t])  # Shape: (T-t, hidden_dim).
        

        # Last t tokens (with gradient tracking).
        x_grad = self.rnn_forward(x[:, -self.t:])  # Shape: (t, hidden_dim).

        # If we want to capture gradients, register hooks to save gradient norms.
        if capture_gradients:
            self.save_grads(x_grad)

        # Combine the outputs.
        combined_output = torch.cat((x_no_grad, x_grad), dim=1)  # Shape: (batch_size, seq_len, hidden_dim).
        
        # Take the hidden state of the last token.
        final_output = combined_output[:, -1, :]  # Shape: (batch_size, hidden_dim).

        # Pass through the fully connected layer for classification.
        final_output = self.fc(final_output)  # Shape: (batch_size, num_classes).
        
        return final_output
    
    def save_grads(self, rnn_outs: Tensor) -> None:
        # Compute the norm of the gradient and store it
        self.gradient_norms = []
        for i in range(rnn_outs.size(1)):  # Iterate over the sequence length (last t tokens)
            grad_norm = rnn_outs[:, i].backward(retain_graph=True).norm().item()
            self.gradient_norms.append(grad_norm)
    
    def plot_gradients(self) -> None:
        """
        Plots the gradient norms for each token after backpropagation.
        """
        if len(self.gradient_norms) == 0:
            print("No gradients captured yet.")
            return

        # Plot the gradient norms
        plt.figure(figsize=(10, 6))
        plt.plot(range(-self.t, 0), self.gradient_norms, marker='o', label="Gradient Norms (Last t Tokens)")
        plt.axhline(0, color='r', linestyle='--', label='No Gradient (First T-t Tokens)')
        plt.xlabel("Token Index (Relative to Last t Tokens)")
        plt.ylabel("Gradient Norm")
        plt.title("Gradient Norms for the Last t Tokens")
        plt.legend()
        plt.show()

In [11]:
# Initialize the model.
model = RNNClasifer(vocab, t=5, embed_dim=32, hidden_dim=64, num_classes=2, RNNlayers=1, 
                    task_type="classification")
# Train the model.
model.fit(train_loader, validation_loader, num_epochs=2, try_cuda=False)

Using CPU for training.
[epoch: 01/10] [Train loss: 0.699764  Train Accuracy: 0.532]  [Val loss: 0.694710]  Val Accuracy: 0.513]
[epoch: 02/10] [Train loss: 0.679523  Train Accuracy: 0.578]  [Val loss: 0.695728]  Val Accuracy: 0.526]
[epoch: 03/10] [Train loss: 0.645737  Train Accuracy: 0.620]  [Val loss: 0.708480]  Val Accuracy: 0.537]
[epoch: 04/10] [Train loss: 0.604211  Train Accuracy: 0.662]  [Val loss: 0.704215]  Val Accuracy: 0.561]
[epoch: 05/10] [Train loss: 0.556997  Train Accuracy: 0.699]  [Val loss: 0.698462]  Val Accuracy: 0.544]
[epoch: 06/10] [Train loss: 0.507165  Train Accuracy: 0.739]  [Val loss: 0.777916]  Val Accuracy: 0.547]
[epoch: 07/10] [Train loss: 0.469323  Train Accuracy: 0.761]  [Val loss: 0.792984]  Val Accuracy: 0.547]
[epoch: 08/10] [Train loss: 0.422144  Train Accuracy: 0.801]  [Val loss: 0.823008]  Val Accuracy: 0.549]
[epoch: 09/10] [Train loss: 0.389506  Train Accuracy: 0.811]  [Val loss: 0.842547]  Val Accuracy: 0.568]
[epoch: 10/10] [Train loss: 0.3

In [45]:
model = RNNClasifer(vocab, t=5, embed_dim=32, hidden_dim=64, num_classes=2, RNNlayers=1, 
                    task_type="classification")
# Plot the gradient norms
for x, y in train_loader:
    model(x, capture_gradients=True)
    model.plot_gradients()
    break

RuntimeError: grad can be implicitly created only for scalar outputs