# Tiny Transformer Summarizer from Scratch
The objective of this notebook is to train an LLM from scratch on a small dataset for text summarization. The model will be trained directly on the *text summarization* task. As last week, we will continue working on:
1. Data preparation
2. Building the LLM architecture
3. Training an LLM

However, there are two main differences with respect to the *transformer classifier* that we trained last week:
- *Encoder-Decoder architecture.* To perform the text summarization task, we will train a transformer with an *encoder-decoder* architecture (also called a Seq2Seq architecture). Since the encoder is quite similar to that of an encoder-only model, we will focus on the decoder.
- *Masking.* We will also focus on the masking technique. Masking ensures that, during training, the decoder cannot look at the token it has to predict or at any later token. We will also use masking to ignore the padding tokens that are added to bring all sequences in a batch to the same length.

We start with the usual library imports.

In [None]:
import math

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

## Encoder-decoder architecture
Since we are training a *transformer model* for the task of text summarization, an *encoder-decoder* transformer is necessary. Our encoder-decoder model will follow the classical transformer architecture, with the only difference that we will only be stacking a single encoder block and a single decoder block.

We will start by implementing all the necessary sub-modules.

### Multi-Head Attention
The *multi-head* attention layer splits its input into chunks of equal length and processes each chunk through one of its heads using the attention mechanism, i.e. by computing the queries, keys and values, and from them the attention scores and weights.
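
As a minimal sketch of this splitting (with toy sizes chosen only for illustration), the cell below reshapes a tensor of shape `(batch_size, num_tokens, hidden_dim)` into one chunk per head and then merges the chunks back, which recovers the original tensor.

```python
import torch

# Toy sizes, chosen only for illustration
batch_size, num_tokens, hidden_dim, num_heads = 2, 5, 8, 4
head_dim = hidden_dim // num_heads  # 2 features per head

x = torch.randn(batch_size, num_tokens, hidden_dim)

# Split the last dimension into (num_heads, head_dim), then move the heads forward
heads = x.view(batch_size, num_tokens, num_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 5, 2])

# Head 0 receives the first head_dim features of every token
assert torch.equal(heads[:, 0], x[:, :, :head_dim])

# Undo the split: transpose back and merge the last two dimensions
merged = heads.transpose(1, 2).reshape(batch_size, num_tokens, hidden_dim)
assert torch.equal(merged, x)
```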

We will modify the `MultiHeadAttention` class that we coded last week in order to:
- Allow for queries, keys and values coming from different inputs (important for the *decoder cross-attention* mechanism).
- Allow for *masking:* the `forward` method of the class will take an additional optional `mask` argument that we will apply to the attention scores before the softmax computation.
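
To make the two changes concrete, here is a small shape sketch under assumed toy sizes. The queries come from a decoder sequence of length `tgt_len` and the keys and values from an encoder sequence of length `src_len` (both names are illustrative), so the scores have one entry per (query position, key position) pair. A padding mask of shape `(batch_size, 1, 1, src_len)` broadcasts over heads and query positions.

```python
import math
import torch

batch_size, num_heads, head_dim = 2, 4, 8
tgt_len, src_len = 3, 6  # toy decoder and encoder sequence lengths

queries = torch.randn(batch_size, num_heads, tgt_len, head_dim)  # from the decoder
keys = torch.randn(batch_size, num_heads, src_len, head_dim)     # from the encoder
values = torch.randn(batch_size, num_heads, src_len, head_dim)   # from the encoder

# One score per (query position, key position) pair
attn_scores = queries @ keys.transpose(2, 3) / math.sqrt(head_dim)
print(attn_scores.shape)  # torch.Size([2, 4, 3, 6])

# The output has one vector per query position
attn_weights = torch.softmax(attn_scores, dim=-1)
attn_output = attn_weights @ values
print(attn_output.shape)  # torch.Size([2, 4, 3, 8])

# A padding mask (True = real token) broadcasts over heads and query positions
src_mask = torch.ones(batch_size, 1, 1, src_len, dtype=torch.bool)
src_mask[1, 0, 0, 4:] = False  # pretend the last two source tokens of sequence 1 are padding
broadcast_mask = src_mask.expand_as(attn_scores)
print(broadcast_mask.shape)  # torch.Size([2, 4, 3, 6])
```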

**Question.** Since we are masking attention scores *before* taking the softmax, what value should we replace the masked entries with?

**Exercise.** Choose an appropriate value for the masking process in the `attention` method of the class.

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads  # Compute the per-head hidden dimension

        self.W_query = nn.Linear(hidden_dim, hidden_dim) # queries weight matrix
        self.W_key = nn.Linear(hidden_dim, hidden_dim) # keys weight matrix
        self.W_value = nn.Linear(hidden_dim, hidden_dim) # values weight matrix

        self.W_out = nn.Linear(hidden_dim, hidden_dim)  # output weight matrix


    @staticmethod
    def attention(queries, keys, values, mask):
        head_dim = queries.shape[-1]

        # Compute attention scores
        attn_scores = torch.einsum("bijk, bikl -> bijl", queries, keys.transpose(2, 3))
        attn_scores = attn_scores / math.sqrt(head_dim)

        # TODO: Add appropriate value for masking
        if mask is not None:
            attn_scores.masked_fill_(mask == 0, # TODO: Add appropriate value)

        # Compute attention weights
        attn_weights = torch.softmax(attn_scores, dim=-1)

        # Attention output
        attn_output = torch.einsum("bijk, bikl -> bijl", attn_weights, values)

        # Return the attention outputs
        return attn_output


    def forward(self, q, k, v, mask=None):
        batch_size, _, hidden_dim = q.shape
        assert hidden_dim == self.hidden_dim, f"hidden_dim must be {self.hidden_dim}"

        queries = self.W_query(q) # Shape: (batch_size, num_tokens, hidden_dim)
        keys = self.W_key(k)  # Shape: (batch_size, num_tokens, hidden_dim)
        values = self.W_value(v) # Shape: (batch_size, num_tokens, hidden_dim)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (batch_size, num_tokens, hidden_dim) -> (batch_size, num_tokens, num_heads, head_dim)
        queries = queries.view(batch_size, -1, self.num_heads, self.head_dim)
        keys = keys.view(batch_size, -1, self.num_heads, self.head_dim)
        values = values.view(batch_size, -1, self.num_heads, self.head_dim)

        # Transpose: (batch_size, num_tokens, num_heads, head_dim) -> (batch_size, num_heads, num_tokens, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention)
        attn_output = self.attention(queries, keys, values, mask)

        # Transpose back: (batch_size, num_heads, num_tokens, head_dim) -> (batch_size, num_tokens, num_heads, head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Concatenate heads: (batch_size, num_tokens, num_heads, head_dim) -> (batch_size, num_tokens, hidden_dim)
        attn_output = attn_output.reshape(batch_size, -1, self.hidden_dim)

        # Compute output
        output = self.W_out(attn_output)

        return output

**Solution.** Click below to check the solution.

In [None]:
# @title
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads  # Compute the per-head hidden dimension

        self.W_query = nn.Linear(hidden_dim, hidden_dim) # queries weight matrix
        self.W_key = nn.Linear(hidden_dim, hidden_dim) # keys weight matrix
        self.W_value = nn.Linear(hidden_dim, hidden_dim) # values weight matrix

        self.W_out = nn.Linear(hidden_dim, hidden_dim)  # output weight matrix


    @staticmethod
    def attention(queries, keys, values, mask):
        head_dim = queries.shape[-1]

        # Compute attention scores
        attn_scores = torch.einsum("bijk, bikl -> bijl", queries, keys.transpose(2, 3))
        attn_scores = attn_scores / math.sqrt(head_dim)

        # Apply masking
        if mask is not None:
            attn_scores.masked_fill_(mask == 0, float("-inf"))

        # Compute attention weights
        attn_weights = torch.softmax(attn_scores, dim=-1)

        # Attention output
        attn_output = torch.einsum("bijk, bikl -> bijl", attn_weights, values)

        # Return the attention outputs
        return attn_output


    def forward(self, q, k, v, mask=None):
        batch_size, _, hidden_dim = q.shape
        assert hidden_dim == self.hidden_dim, f"hidden_dim must be {self.hidden_dim}"

        queries = self.W_query(q) # Shape: (batch_size, num_tokens, hidden_dim)
        keys = self.W_key(k)  # Shape: (batch_size, num_tokens, hidden_dim)
        values = self.W_value(v) # Shape: (batch_size, num_tokens, hidden_dim)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (batch_size, num_tokens, hidden_dim) -> (batch_size, num_tokens, num_heads, head_dim)
        queries = queries.view(batch_size, -1, self.num_heads, self.head_dim)
        keys = keys.view(batch_size, -1, self.num_heads, self.head_dim)
        values = values.view(batch_size, -1, self.num_heads, self.head_dim)

        # Transpose: (batch_size, num_tokens, num_heads, head_dim) -> (batch_size, num_heads, num_tokens, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention)
        attn_output = self.attention(queries, keys, values, mask)

        # Transpose back: (batch_size, num_heads, num_tokens, head_dim) -> (batch_size, num_tokens, num_heads, head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Concatenate heads: (batch_size, num_tokens, num_heads, head_dim) -> (batch_size, num_tokens, hidden_dim)
        attn_output = attn_output.reshape(batch_size, -1, self.hidden_dim)

        # Compute output
        output = self.W_out(attn_output)

        return output
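
The effect of the fill value used in the solution can be checked on a single row of toy scores. The sketch below compares filling masked entries with `-inf` against filling them with `0`. Only the first choice gives the masked positions a weight of exactly zero after the softmax.

```python
import torch

scores = torch.tensor([[2.0, 1.0, 0.5, 3.0]])  # toy attention scores for one query
mask = torch.tensor([[1, 1, 0, 0]])            # 1 = keep, 0 = masked out

# Filling with -inf: exp(-inf) = 0, so masked positions get no weight
masked_scores = scores.masked_fill(mask == 0, float("-inf"))
weights = torch.softmax(masked_scores, dim=-1)
print(weights)  # tensor([[0.7311, 0.2689, 0.0000, 0.0000]])

# Filling with 0 instead: exp(0) = 1, so masked positions still receive weight
zero_filled = torch.softmax(scores.masked_fill(mask == 0, 0.0), dim=-1)
print(zero_filled)
```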

### Layer Norm
The `LayerNorm` module is the same as last week.

In [None]:
class LayerNorm(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.eps = 1e-5 # small value to avoid division by zero
        self.scale = nn.Parameter(torch.ones(hidden_dim)) # scale parameter (learnable)
        self.shift = nn.Parameter(torch.zeros(hidden_dim)) # shift parameter (learnable)

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        y = self.scale * norm_x + self.shift
        return y
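
As a quick sanity check of the normalization step, the sketch below repeats the same computation with plain tensor operations on a random toy input and compares it with PyTorch's built-in `torch.nn.functional.layer_norm`, which applies no scale or shift when called without `weight` and `bias`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 8)  # toy input: (batch_size, num_tokens, hidden_dim)
eps = 1e-5

# Same steps as in the forward method above (with scale = 1 and shift = 0)
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
norm_x = (x - mean) / torch.sqrt(var + eps)

# Each token vector now has mean close to 0 and variance close to 1
print(norm_x.mean(dim=-1).abs().max())
print(norm_x.var(dim=-1, unbiased=False))

# Compare with the built-in functional version
reference = F.layer_norm(x, normalized_shape=(8,), eps=eps)
print(torch.allclose(norm_x, reference, atol=1e-5))
```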

### Feed-Forward Network
The *feed forward* module is the same as last week.

In [None]:
class FeedForward(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )

    def forward(self, x):
        return self.layers(x)

### Transformer Encoder Block
We will keep the same architecture for the `TransformerEncoderBlock` as last week. However, the `MultiHeadAttention` class has changed: it now takes separate query, key and value inputs together with a mask.

**Exercise.** Implement the `TransformerEncoderBlock` module below by completing the `TODO` tags.

In [None]:
class TransformerEncoderBlock(nn.Module):
  def __init__(self, hidden_dim, num_heads):
    super().__init__()

    self.attention = MultiHeadAttention(hidden_dim, num_heads)
    self.norm1 = nn.LayerNorm(hidden_dim)
    self.norm2 = nn.LayerNorm(hidden_dim)
    self.feed_forward = FeedForward(hidden_dim)

  def forward(self, x, mask):
      attn_output = # TODO: apply the attention module to the correct inputs
      x = x + attn_output
      x = self.norm1(x)
      ff_output = self.feed_forward(x)
      x = x + ff_output
      x = self.norm2(x)
      return x

**Solution.** Click below to check the solution.

In [None]:
# @title
class TransformerEncoderBlock(nn.Module):
  def __init__(self, hidden_dim, num_heads):
    super().__init__()

    self.attention = MultiHeadAttention(hidden_dim, num_heads)
    self.norm1 = nn.LayerNorm(hidden_dim)
    self.norm2 = nn.LayerNorm(hidden_dim)
    self.feed_forward = FeedForward(hidden_dim)

  def forward(self, x, mask):
      attn_output = self.attention(x, x, x, mask)
      x = x + attn_output
      x = self.norm1(x)
      ff_output = self.feed_forward(x)
      x = x + ff_output
      x = self.norm2(x)
      return x

### Transformer Decoder Block
We will now build the *Transformer decoder block*, which is composed of:
1. A *Multi-head attention* sub-module, which will perform the self-attention on the decoder input, and will apply both a padding mask and a causal mask.
2. A *LayerNorm* normalization.
3. A *Multi-head attention* sub-module, which will perform the cross-attention between the decoder hidden state and the encoder hidden state, and will apply the encoder padding mask.
4. A *LayerNorm* normalization.
5. A *Feed Forward* sub-module.
6. A *LayerNorm* normalization.

Moreover, three skip connections are present in the *Transformer decoder block*:
- *(A)* A skip connection that adds the original input to the output of operation *1* above.
- *(B)* A skip connection that adds the output of operation *2* to the output of operation *3* above.
- *(C)* A skip connection that adds the output of operation *4* to the output of operation *5* above.

**Exercise.** Implement the `TransformerDecoderBlock` module below by completing the `TODO` tags.

In [None]:
class TransformerDecoderBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super().__init__()

        self.self_attention = # TODO: initialize the self attention module with the appropriate dimensions
        self.norm1 = # TODO: initialize the first Layer Norm with the appropriate dimensions
        self.cross_attention = # TODO: initialize the cross attention module with the appropriate dimensions
        self.norm2 = # TODO: initialize the second Layer Norm with the appropriate dimensions
        self.feed_forward = # TODO: initialize the Feed Forward module with the appropriate dimensions
        self.norm3 = # TODO: initialize the third Layer Norm with the appropriate dimensions

    def forward(self, x, enc_hidden_state, src_mask, tgt_mask):
        tgt_mask = # TODO: add causal mask to target mask by calling self.add_causal_mask (defined below)
        # TODO: implement the rest of the forward pass of the TransformerDecoderBlock
        return x

    @staticmethod
    def add_causal_mask(mask):
        # assuming attn scores of shape (batch_size, num_heads, seq_len, seq_len)
        # assuming mask of shape (batch_size, 1, 1, seq_len)
        seq_len = mask.size(-1)
        # ones strictly above the diagonal mark the future positions; build on the same device as the mask
        causal_mask = torch.triu(torch.ones(1, 1, seq_len, seq_len, device=mask.device), diagonal=1)  # shape (1, 1, seq_len, seq_len)
        causal_mask = causal_mask == 0  # True where attention is allowed
        return mask.bool() & causal_mask  # shape (batch_size, 1, seq_len, seq_len)

**Solution.** Click below to check the solution.

In [None]:
# @title
class TransformerDecoderBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super().__init__()

        self.self_attention = MultiHeadAttention(hidden_dim, num_heads)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.cross_attention = MultiHeadAttention(hidden_dim, num_heads)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.feed_forward = FeedForward(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)

    def forward(self, x, enc_hidden_state, src_mask, tgt_mask):
        tgt_mask = self.add_causal_mask(tgt_mask) # add causal mask
        attn_output = self.self_attention(x, x, x, tgt_mask)
        x = x + attn_output
        x = self.norm1(x)
        attn_output = self.cross_attention(x, enc_hidden_state, enc_hidden_state, src_mask)
        x = x + attn_output
        x = self.norm2(x)
        ff_output = self.feed_forward(x)
        x = x + ff_output
        x = self.norm3(x)
        return x

    @staticmethod
    def add_causal_mask(mask):
        # assuming attn scores of shape (batch_size, num_heads, seq_len, seq_len)
        # assuming mask of shape (batch_size, 1, 1, seq_len)
        seq_len = mask.size(-1)
        # ones strictly above the diagonal mark the future positions; build on the same device as the mask
        causal_mask = torch.triu(torch.ones(1, 1, seq_len, seq_len, device=mask.device), diagonal=1)  # shape (1, 1, seq_len, seq_len)
        causal_mask = causal_mask == 0  # True where attention is allowed
        return mask.bool() & causal_mask  # shape (batch_size, 1, seq_len, seq_len)
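
To see what the combined target mask looks like, the sketch below builds it for a single toy sequence of length 4 whose last token is padding. It uses `torch.tril` for the causal part, which should be equivalent to the `torch.triu(..., diagonal=1) == 0` construction above. The sizes and the padding pattern are illustrative assumptions.

```python
import torch

seq_len = 4
# Padding mask for one sequence whose last token is padding: shape (1, 1, 1, seq_len)
padding_mask = torch.tensor([1, 1, 1, 0]).bool().view(1, 1, 1, seq_len)

# Lower-triangular causal mask: position i may attend to positions j <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().view(1, 1, seq_len, seq_len)

# Broadcasting gives a mask of shape (1, 1, seq_len, seq_len)
combined = padding_mask & causal_mask
print(combined[0, 0].int())
# Rows (query positions) against columns (key positions):
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 0]   <- the padded key position 3 is never attended to
```

Every row keeps at least its first entry, so no row is fully masked and the softmax stays well defined.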

### Embedding
The `Embedding` module is the same as for the encoder-only model of last week.

In [None]:
class Embedding(nn.Module):
  def __init__(self, vocab_size, max_length, hidden_dim):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, hidden_dim)
    self.position_encoding = nn.Embedding(max_length, hidden_dim)

  def forward(self, x):
    _, seq_length = x.shape
    token_embeddings = self.embedding(x)
    pos_encodings = self.position_encoding(torch.arange(seq_length, device=x.device))
    return token_embeddings + pos_encodings

### Transformer Encoder-Decoder Model for Text Summarization
We have implemented all the necessary sub-modules and are now ready to implement the *end-to-end* architecture of the *transformer encoder-decoder* model. To that end, we will make use of:
- The `Embedding` sub-modules (one for the encoder, one for the decoder).
- The `TransformerEncoderBlock` sub-module.
- The `TransformerDecoderBlock` sub-module.
- An `nn.Linear` layer to act as a classification head.

**Note.** The output of the `TransformerDecoderBlock` sub-module is a tensor of shape $(b, s, d_h)$, where $b$ is the batch size, $s$ the sequence length, and $d_h$ the hidden dimension of the model. To predict the next token, the classification head must map each of these vectors to one logit per vocabulary entry, which gives a tensor of shape $(b, s, V)$ where $V$ is the vocabulary size.

**Exercise.** Implement the `TransformerSummarizer` model by completing the `TODO` tags below.

In [None]:
class TransformerSummarizer(nn.Module):
    def __init__(self,
               vocab_size,
               encoder_max_length,
               decoder_max_length,
               hidden_dim,
               num_heads):
        super().__init__()

        self.src_embedding = # TODO: initialize the encoder embedding with the appropriate parameters
        self.tgt_embedding = # TODO: initialize the decoder embedding with the appropriate parameters
        self.encoder = # TODO: initialize the encoder with the appropriate parameters
        self.decoder = # TODO: initialize the decoder with the appropriate parameters
        self.summarizer_head = # TODO: initialize the classifier head with the appropriate parameters


    def encode(self, x_src, src_mask):
        src_mask = src_mask.unsqueeze(1).unsqueeze(2) # (b, seq_len) -> (b, 1, 1, seq_len)
        x_src = # TODO: compute the embedding of x_src
        x_src = # TODO: compute the encoder hidden state of x_src
        return x_src


    def decode(self, x_tgt, encoder_hidden_state, src_mask, tgt_mask):
        src_mask = src_mask.unsqueeze(1).unsqueeze(2) # (b, seq_len) -> (b, 1, 1, seq_len)
        tgt_mask = tgt_mask.unsqueeze(1).unsqueeze(2) # (b, seq_len) -> (b, 1, 1, seq_len)
        x_tgt = # TODO: compute the embedding of x_tgt
        x_tgt = # TODO: compute the decoder hidden state of x_tgt
        return x_tgt


    def forward(self, x_src, x_tgt, src_mask, tgt_mask):
        # TODO: implement the forward pass and return the logit values
        return logits

**Solution.** Click below to check the solution.

In [None]:
# @title
class TransformerSummarizer(nn.Module):
    def __init__(self,
               vocab_size,
               encoder_max_length,
               decoder_max_length,
               hidden_dim,
               num_heads):
        super().__init__()

        self.src_embedding = Embedding(vocab_size, encoder_max_length, hidden_dim)
        self.tgt_embedding = Embedding(vocab_size, decoder_max_length, hidden_dim)
        self.encoder = TransformerEncoderBlock(hidden_dim, num_heads)
        self.decoder = TransformerDecoderBlock(hidden_dim, num_heads)
        self.summarizer_head = nn.Linear(hidden_dim, vocab_size)


    def encode(self, x_src, src_mask):
        src_mask = src_mask.unsqueeze(1).unsqueeze(2)
        x_src = self.src_embedding(x_src)
        x_src = self.encoder(x_src, src_mask)
        return x_src


    def decode(self, x_tgt, encoder_hidden_state, src_mask, tgt_mask):
        src_mask = src_mask.unsqueeze(1).unsqueeze(2)
        tgt_mask = tgt_mask.unsqueeze(1).unsqueeze(2)
        x_tgt = self.tgt_embedding(x_tgt)
        x_tgt = self.decoder(x_tgt, encoder_hidden_state, src_mask, tgt_mask)
        return x_tgt


    def forward(self, x_src, x_tgt, src_mask, tgt_mask):
        encoder_hidden_state = self.encode(x_src, src_mask)
        x_tgt = self.decode(x_tgt, encoder_hidden_state, src_mask, tgt_mask)
        logits = self.summarizer_head(x_tgt)
        return logits
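
The shapes produced by the classification head can be checked in isolation. The sketch below uses toy sizes and a random tensor in place of the decoder output (both are assumptions for illustration): an `nn.Linear` layer is applied independently at every position, and taking the `argmax` over the last dimension gives one predicted token ID per position.

```python
import torch
import torch.nn as nn

batch_size, seq_len, hidden_dim, vocab_size = 2, 5, 16, 100  # toy sizes
decoder_output = torch.randn(batch_size, seq_len, hidden_dim)  # stand-in for the decoder hidden state

head = nn.Linear(hidden_dim, vocab_size)  # applied to the last dimension only
logits = head(decoder_output)
print(logits.shape)  # torch.Size([2, 5, 100])

# One predicted token ID per position and per sequence
predicted_ids = logits.argmax(dim=-1)
print(predicted_ids.shape)  # torch.Size([2, 5])
```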

## Data Preparation
We will train the `TransformerSummarizer` model on a text summarization task using the [XSum dataset](https://huggingface.co/datasets/EdinburghNLP/xsum). It consists of pairs of texts: a `document` and its `summary`.


In [None]:
!pip install datasets

In [None]:
from datasets import load_dataset, DatasetDict

raw_dataset = load_dataset("EdinburghNLP/xsum")
raw_dataset

**Questions.** Describe the datasets above.

**Exercise.** Print one of the elements of ``raw_dataset["train"]``, both the document and its summary.

In [None]:
# TODO: print one of the elements of raw_dataset["train"]

**Solution.** Click below to check the solution.

In [None]:
# @title
print("Document:")
print(raw_dataset['train'][0]['document'])
print("")
print("Summary:")
print(raw_dataset['train'][0]['summary'])

The above `raw_dataset` is quite large, and training the transformer on it would take too much time. To be able to carry out our experiments in a reasonable time, we will reduce the size of the dataset.

**Exercise.** Create a `tiny_dataset` by keeping the first 500 training examples, the first 100 validation examples, and the first 100 test examples.

In [None]:
# TODO: Create the tiny_dataset object
print(tiny_dataset)

**Solution.** Click below to check the solution.

In [None]:
# @title
tiny_dataset = {
    "train": raw_dataset["train"].select(range(500)),  # Keep the first 500 examples
    "validation": raw_dataset["validation"].select(range(100)),  # Keep the first 100 examples
    "test": raw_dataset["test"].select(range(100))  # Keep the first 100 examples
}

tiny_dataset = DatasetDict(tiny_dataset)
tiny_dataset

### Pre-trained Tokenizer
In order to convert the textual data in the above datasets into a format that can be used as input for our models, we first *tokenize* the text using a pre-trained tokenizer.

In [None]:
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

**Exercise.** Let's find out more about the tokenizer we will be using. Write code in order to answer the following questions:
1. What is the name of the tokenizer being used?
2. What is the size of the vocabulary?
3. What is the maximum model input length?
4. What special tokens does the tokenizer use? What are their IDs?

In [None]:
# TODO: print the necessary information about the loaded tokenizer

**Solution.** Click below to check the solution.

In [None]:
# @title
print(f"Name of the tokenizer: {tokenizer.__class__}")
print(f"Size of the vocabulary: {tokenizer.vocab_size}")
print(f"Maximum model input length: {tokenizer.model_max_length}")
print(f"Special tokens: {tokenizer.special_tokens_map}")

for key, value in tokenizer.special_tokens_map.items():
    print(f"{key}: {value}; token_id: {tokenizer.convert_tokens_to_ids(value)}")

### Pre-processing the data
We next build a pre-processing function in order to:
1. Tokenize the document and summaries.
2. Generate the appropriate inputs and outputs.

Note that our model will be trained using:
- The document as the input to the encoder.
- The summary as the input to the decoder.
- The summary shifted by one token as the target output, so that the target at each position is the next token of the summary.

To produce the decoder input-output pairs, we remove the *end-of-sequence* token from the decoder input, and the *beginning-of-sequence* token from the target output.
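
The construction of these pairs can be illustrated on a hand-written list of token IDs. In the sketch below, the special-token IDs and the summary tokens are made up for illustration. The slicing is the same as in the pre-processing function that follows.

```python
# Illustrative special-token IDs and a toy tokenized summary padded to max_len + 1 = 7
bos, eos, pad = 0, 2, 1
ids = [bos, 11, 12, 13, eos, pad, pad]

eos_index = ids.index(eos)
decoder_input_ids = ids[:eos_index] + ids[eos_index + 1:]  # remove the EOS token
labels = ids[1:]                                            # remove the BOS token

print(decoder_input_ids)  # [0, 11, 12, 13, 1, 1]
print(labels)             # [11, 12, 13, 2, 1, 1]

# At each position, the label is the token that follows the decoder input
for inp, tgt in zip(decoder_input_ids, labels):
    print(f"{inp:>3} -> {tgt}")
```

Both lists have the same length, and the last real pair teaches the model to emit the end-of-sequence token after the final summary token.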

In [None]:
VOCAB_SIZE = tokenizer.vocab_size
ENCODER_MAX_LEN = 1024
DECODER_MAX_LEN = 64

In [None]:
def preprocessing_function(example):
    model_inputs = {}
    encoder_inputs = tokenizer(example["document"], max_length=ENCODER_MAX_LEN, padding='max_length', truncation=True)
    model_inputs["encoder_input_ids"] = encoder_inputs["input_ids"]
    model_inputs["encoder_mask"] = encoder_inputs["attention_mask"]
    labels = tokenizer(example["summary"], max_length=DECODER_MAX_LEN + 1, padding='max_length', truncation=True)
    eos_index = labels["input_ids"].index(tokenizer.eos_token_id)
    model_inputs["decoder_input_ids"] = labels["input_ids"][:eos_index]+labels["input_ids"][eos_index+1:]
    model_inputs["decoder_mask"] = labels["attention_mask"][:eos_index] + labels["attention_mask"][eos_index+1:]
    model_inputs["labels"] = labels["input_ids"][1:]
    return model_inputs

In [None]:
# Apply tokenization to the entire dataset
tokenized_dataset = tiny_dataset.map(preprocessing_function, batched=False)

In [None]:
# Remove the original text fields to save memory
tokenized_dataset = tokenized_dataset.remove_columns(["document", "summary", "id"])

# Print dataset structure
print(tokenized_dataset)

### Data Loaders
Finally, we create three separate PyTorch data loaders to easily iterate through the data during the training and evaluation phases.

In [None]:
tokenized_dataset.set_format(type="torch")

train_loader = torch.utils.data.DataLoader(
    tokenized_dataset["train"],
    shuffle=True,
    batch_size=16,
)

val_loader = torch.utils.data.DataLoader(
    tokenized_dataset["validation"],
    shuffle=False,
    batch_size=16,
)

test_loader = torch.utils.data.DataLoader(
    tokenized_dataset["test"],
    shuffle=False,
    batch_size=16,
)

## Training and Evaluation
Now that we have both our model architecture defined and our data prepared, we can proceed to the last phase of the lab project: training and evaluating the model.

### Initialize model

In [None]:
HIDDEN_DIM = 256
NUM_HEADS = 8

print(f"VOCAB_SIZE: {VOCAB_SIZE}")
print(f"ENCODER_MAX_LEN: {ENCODER_MAX_LEN}")
print(f"DECODER_MAX_LEN: {DECODER_MAX_LEN}")
print(f"HIDDEN_DIM: {HIDDEN_DIM}")
print(f"NUM_HEADS: {NUM_HEADS}")

In [None]:
summarizer = TransformerSummarizer(VOCAB_SIZE, ENCODER_MAX_LEN, DECODER_MAX_LEN, HIDDEN_DIM, NUM_HEADS)
print(summarizer)

### Train Model
We can now proceed to the model training phase. In order to do so, we will define two functions:
- A `train_epoch` function that will train the model by iterating through the given data loader once.
- An `evaluate` function that will evaluate the model on the given data loader.

In [None]:
import time
from tqdm import tqdm

In [None]:
def train_epoch(model, dataloader, optimizer, criterion):
    epoch_loss = 0
    epoch_acc = 0

    model.train()

    for batch in tqdm(dataloader, desc="Processing Batches"):
        optimizer.zero_grad()

        encoder_input_ids = batch["encoder_input_ids"]
        encoder_mask = batch["encoder_mask"]
        decoder_input_ids = batch["decoder_input_ids"]
        decoder_mask = batch["decoder_mask"]
        labels = batch["labels"]

        outputs = model(encoder_input_ids, decoder_input_ids, encoder_mask, decoder_mask)
        outputs = outputs.view(-1, VOCAB_SIZE)
        labels = labels.view(-1)
        loss = criterion(outputs, labels)
        acc = (outputs.argmax(dim=1) == labels).float().mean()

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc.item()

    return epoch_loss / len(dataloader), epoch_acc / len(dataloader)

In [None]:
def evaluate(model, dataloader, criterion):
    epoch_loss = 0
    epoch_acc = 0
    model.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Processing Batches"):

            encoder_input_ids = batch["encoder_input_ids"]
            encoder_mask = batch["encoder_mask"]
            decoder_input_ids = batch["decoder_input_ids"]
            decoder_mask = batch["decoder_mask"]
            labels = batch["labels"]

            outputs = model(encoder_input_ids, decoder_input_ids, encoder_mask, decoder_mask)
            outputs = outputs.view(-1, VOCAB_SIZE)
            labels = labels.view(-1)
            loss = criterion(outputs, labels)
            acc = (outputs.argmax(dim=1) == labels).float().mean()

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    return epoch_loss / len(dataloader), epoch_acc / len(dataloader)
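
Both functions flatten the logits and the labels before computing the loss, so every token position is treated as one classification example. The sketch below reproduces this on random toy tensors. It also shows a possible variation that the functions above do not use: passing `ignore_index` to `nn.CrossEntropyLoss` so that padded label positions do not contribute to the loss. The value `pad_id = 1` is an illustrative assumption.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, vocab_size, pad_id = 2, 4, 10, 1  # toy sizes; pad_id is illustrative

logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.tensor([[5, 7, 2, pad_id],
                       [3, 2, pad_id, pad_id]])

flat_logits = logits.view(-1, vocab_size)  # (batch_size * seq_len, vocab_size)
flat_labels = labels.view(-1)              # (batch_size * seq_len,)

# Loss over all positions (padding included), as in the functions above
loss_all = nn.CrossEntropyLoss()(flat_logits, flat_labels)

# Variation: skip positions whose label is the padding ID
loss_no_pad = nn.CrossEntropyLoss(ignore_index=pad_id)(flat_logits, flat_labels)

# Accuracy restricted to non-padding positions
keep = flat_labels != pad_id
acc_no_pad = (flat_logits.argmax(dim=1)[keep] == flat_labels[keep]).float().mean()

print(loss_all.item(), loss_no_pad.item(), acc_no_pad.item())
```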

In [None]:
EPOCHS = 5
LEARNING_RATE = 1e-3

optimizer = torch.optim.Adam(summarizer.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss().to(device)

for epoch in range(EPOCHS):
    start_time = time.time()

    train_loss, train_acc = train_epoch(summarizer, train_loader, optimizer, criterion)
    valid_loss, valid_acc = evaluate(summarizer, val_loader, criterion)

    epoch_time = time.time() - start_time  # elapsed time for this epoch, in seconds

    print("")
    print(f'Epoch: {epoch+1:02} | Time: {epoch_time:.1f}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

### Evaluation
Finally, we evaluate the model on the test set.

In [None]:
test_loss, test_acc = evaluate(summarizer, test_loader, criterion)
print("")
print(f'Test Loss: {test_loss:.3f} |  Test Acc: {test_acc*100:.2f}%')