# How To Fine-Tune Decoder-Only Models For Sequence Classification Using Last Token Pooling?

[![Twitter Handle](https://img.shields.io/badge/Twitter-@gaohongnan-blue?style=social&logo=twitter)](https://twitter.com/gaohongnan)
[![LinkedIn Profile](https://img.shields.io/badge/@gaohongnan-blue?style=social&logo=linkedin)](https://linkedin.com/in/gao-hongnan)
[![GitHub Profile](https://img.shields.io/badge/GitHub-gao--hongnan-lightgrey?style=social&logo=github)](https://github.com/gao-hongnan)
![Tag](https://img.shields.io/badge/Tag-Brain_Dump-red)
![Tag](https://img.shields.io/badge/Level-Beginner-green)
[![Code](https://img.shields.io/badge/View-Code-blue?style=flat-square&logo=github)](https://github.com/gao-hongnan/omniverse/tree/main/omnivault/transformer)

```{contents}
:local:
```

Firstly, if you have not read my
[Generative Pre-trained Transformers (GPT) series](https://www.gaohongnan.com/influential/generative_pretrained_transformer/03_concept.html),
please have a read first to establish some basic understand on what a
decoder-only model entails.

## Dependencies

```bash
pip install -U omniverse
```

In [2]:
from __future__ import annotations

import logging
from collections import Counter, OrderedDict
from functools import partial
from typing import Any, Dict, List, Literal, Tuple, cast, overload

import evaluate
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import seaborn as sns
import torch
from datasets import Dataset, load_dataset
from IPython.display import clear_output
from rich.pretty import pprint
from sklearn.metrics import classification_report
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchinfo import summary
from tqdm.notebook import tqdm  # Use notebook version for better UI in notebooks
from transformers import (
    AutoModelForMaskedLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    DistilBertTokenizer,
    GPT2Tokenizer,
    LlamaForSequenceClassification,
    PreTrainedModel,
    PreTrainedTokenizer,
    RobertaConfig,
    RobertaModel,
    Trainer,
    TrainingArguments,
)

from omnivault.transformer.config.decoder import (
    AddNormConfig,
    DecoderBlockConfig,
    DecoderConfig,
    MultiHeadedAttentionConfig,
    PositionwiseFeedForwardConfig,
)
from omnivault.transformer.modules.attention.core import MultiHeadedAttention, ScaledDotProductAttention
from omnivault.transformer.modules.layers.addnorm import AddNorm
from omnivault.transformer.modules.layers.mlp import PositionwiseFeedForward
from omnivault.utils.reproducibility.seed import seed_all

## Setting Up

In [3]:
seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)

2024

In [4]:
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
LOGGER.addHandler(handler)

In [5]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

MAX_LENGTH = 64
PADDING = "longest"
BATCH_SIZE = 32
TRUNCATION = True
RETURN_TENSORS = "pt"

## Dataset

In [6]:
dataset = load_dataset('financial_phrasebank', 'sentences_allagree', trust_remote_code=True)["train"]
dataset

Dataset({
    features: ['sentence', 'label'],
    num_rows: 2264
})

In [7]:
def count_labels(labels: List[int]) -> Dict[int, int]:
    label_counts = Counter(labels)
    ordered_label_counts = OrderedDict(sorted(label_counts.items()))
    return dict(ordered_label_counts)


sentences_allagree = dataset['sentence']
labels_allagree = dataset['label']

label_counts = count_labels(labels_allagree)
pprint(label_counts)

In [8]:
train_valid_split = dataset.train_test_split(test_size=0.1, shuffle=True, stratify_by_column='label')
train_dataset = train_valid_split['train']
valid_dataset = train_valid_split['test']

We create our own `Dataset` just for understanding!

In [9]:
train_df = train_dataset.to_pandas()
valid_df = valid_dataset.to_pandas()

In [10]:
class FinancialDataset(Dataset):
    def __init__(self, df: pd.DataFrame, tokenizer: PreTrainedTokenizer, **tokenizer_kwargs: Any) -> None:
        self.tokenizer = tokenizer
        self.tokenizer_kwargs = tokenizer_kwargs
        self.inputs = df["sentence"].tolist()
        self.labels = df["label"].tolist()

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        input_ids = self.tokenizer.encode(text=self.inputs[index], **self.tokenizer_kwargs).long()
        labels = torch.tensor(self.labels[index]).long()
        return {
            "input_ids": input_ids,
            "labels": labels,
        }

We will create the causal mask in the collator.

## Tokenizer

In [11]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
pprint(tokenizer.special_tokens_map)

tokenizer.pad_token = tokenizer.eos_token
pprint(tokenizer)

## Data Collator And DataLoader

In [12]:
train_dataset = FinancialDataset(train_df, tokenizer=tokenizer, max_length=3, padding=PADDING, truncation=TRUNCATION, return_tensors=RETURN_TENSORS)
valid_dataset = FinancialDataset(valid_df, tokenizer=tokenizer, max_length=3, padding=PADDING, truncation=TRUNCATION, return_tensors=RETURN_TENSORS)

In [13]:
def construct_dummy_batch_causal_masks(batch_size: int, seq_len: int) -> torch.BoolTensor:
    """Broadcast future mask from shape (L, L) to (B, L, L) then (B, 1, L, L)."""
    # Create a lower triangular mask for a single sequence
    future_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool), diagonal=0).to(torch.bool)
    future_mask = future_mask.contiguous()
    # broadcast future mask from shape (L, L) to (B, L, L)
    causal_masks = future_mask.unsqueeze(0).expand(batch_size, -1, -1)
    # broadcast future mask from shape (B, L, L) to (B, 1, L, L)
    causal_masks = causal_masks.unsqueeze(1)
    return torch.BoolTensor(causal_masks)

def collate_for_unidirectional(
    batch: List[Dict[str, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    max_length = max(item["input_ids"].size(1) for item in batch) # 25
    input_ids = torch.zeros((len(batch), max_length), dtype=torch.long)
    labels = torch.zeros(len(batch), dtype=torch.long)

    # do padding manually
    for index, item in enumerate(batch):
        seq_len = item["input_ids"].size(1)
        input_ids[index, :seq_len] = item["input_ids"]
        labels[index] = item["labels"]

    batch_size, seq_len = input_ids.size()

    causal_masks = construct_dummy_batch_causal_masks(batch_size, seq_len)
    return input_ids, labels, causal_masks

In [14]:
seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_for_unidirectional)
valid_dataloader = DataLoader(valid_dataset, batch_size=2, shuffle=False, collate_fn=collate_for_unidirectional)

for batch in train_dataloader:
    input_ids, labels, causal_masks = batch
    pprint(input_ids)
    pprint(labels)
    pprint(causal_masks)
    break

## Model Architecture

In [15]:
class DecoderForSequenceClassificationConfig(DecoderConfig):
    num_labels: int
    head_bias: bool = False


class GPTPretrainedModel(nn.Module):
    def _init_weights(self, module: nn.Module) -> None:
        normal_init_modules = (nn.Linear, nn.Embedding)
        if isinstance(module, normal_init_modules):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, "bias") and module.bias is not None:
                torch.nn.init.zeros_(module.bias)


class GPTDecoderBlock(nn.Module):
    """GPTDecoderBlock focuses on masked self-attention and feed-forward layers.

    The architecture follows the GPT-style decoder, which only has masked
    self-attention and position-wise feed-forward layers, omitting the
    encoder-decoder cross-attention.
    """

    def __init__(self, config: DecoderConfig) -> None:
        super().__init__()
        self.masked_self_attention_mha = MultiHeadedAttention(
            **config.decoder_block.masked_self_attention_mha.model_dump(mode="python")
        )
        self.feed_forward = PositionwiseFeedForward(**config.decoder_block.feed_forward.model_dump(mode="python"))
        self.add_norm_1 = AddNorm(**config.decoder_block.add_norm_1.model_dump(mode="python"))
        self.add_norm_2 = AddNorm(**config.decoder_block.add_norm_2.model_dump(mode="python"))

    def forward(self, z: torch.Tensor, causal_masks: torch.BoolTensor) -> torch.Tensor:
        """
        Parameters
        ----------
        z:              Input sequence.
                        type:  torch.Tensor
                        shape: (B, S or T, D)

        Returns
        -------
        z:              Output tensor after masked self-attention and feed-forward layers.
                        type:  torch.Tensor
                        shape: (B, S or T, D)
        """
        z = self.add_norm_1(
            z,
            lambda z: self.masked_self_attention_mha(query=z, key=z, value=z, mask=causal_masks),
        )
        z = self.add_norm_2(z, self.feed_forward)
        return z


class GPTBackbone(GPTPretrainedModel):
    def __init__(self, config: DecoderConfig) -> None:
        super().__init__()
        self.d_model: int = config.d_model
        self.tok_embed: nn.Embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embed: nn.Parameter = nn.Parameter(torch.zeros(1, config.context_length, config.d_model))
        self.decoder_blocks: nn.ModuleList = nn.ModuleList(
            [GPTDecoderBlock(config) for _ in range(config.num_decoder_blocks)]
        )  # PyTorch did not make ModuleList a proper container, maybe open a PR to make it inherit Generic[T]???

        self.dropout: nn.Dropout = nn.Dropout(config.dropout)
        self.layer_norm: nn.LayerNorm = nn.LayerNorm(config.d_model)

        self.apply(self._init_weights)

        context_projections = ("context_projection.weight", "W_O.weight")
        # apply special scaled init to the residual projections, per GPT-2 paper
        for parameter_name, parameter in self.named_parameters():
            # NOTE: W_O is also projection but I did not have foresight to name it as such.
            if parameter_name.endswith(context_projections):
                mean = 0.0
                std_dev = 0.02 / torch.sqrt(torch.tensor(2 * config.num_decoder_blocks, dtype=torch.float))
                torch.nn.init.normal_(parameter, mean=mean, std=std_dev)

    def forward(
        self, input_tokens: torch.LongTensor, *, causal_masks: torch.BoolTensor
    ) -> torch.FloatTensor:
        seq_len: int = input_tokens.size(1)  # note seq_len <= context_length in decoder
        causal_masks = causal_masks.to(input_tokens.device)  # type: ignore[assignment]

        z = self.tok_embed(input_tokens)  # TODO: * math.sqrt(self.d_model) for better optimization landscape
        z = z + self.pos_embed[:, :seq_len, :]
        z = self.dropout(z)

        for decoder_block in self.decoder_blocks:
            z = decoder_block(z, causal_masks=causal_masks)

        z = self.layer_norm(z)
        return z

class LastTokenPooling(nn.Module):
    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        # pooling to extract the last token's logits
        # logits: [B, T, C] -> [B, C] the -1 takes the last token's logits along
        # the feature dimension
        pooled_logits = logits[:, -1, :]
        return pooled_logits

class GPTForSequenceClassification(GPTPretrainedModel):
    def __init__(self, config: DecoderForSequenceClassificationConfig) -> None:
        super().__init__()
        self.backbone = GPTBackbone(config)
        self.head = nn.Linear(config.d_model, config.num_labels, bias=config.head_bias)
        self.pooler = LastTokenPooling()

        self.apply(self._init_weights)

        # apply special scaled init to the residual projections, per GPT-2 paper
        for parameter_name, parameter in self.named_parameters():
            if parameter_name.endswith("context_projection.weight"):
                mean = 0.0
                std_dev = 0.02 / torch.sqrt(torch.tensor(2 * config.num_decoder_blocks, dtype=torch.float))
                torch.nn.init.normal_(parameter, mean=mean, std=std_dev)

    def forward(
        self,
        input_tokens: torch.LongTensor,
        *,  # force keyword only arguments to prevent errors
        causal_masks: torch.BoolTensor,
    ) -> torch.FloatTensor:
        """
        Notations
        ---------
        B:      Batch size
        S or L: Source sequence length
        T or L: Target sequence length
        D:      Embedding dimension
        C:      Vocabulary size (Class size)

        Parameters
        ----------
        input_tokens:           Input sequence.
                                type:  torch.Tensor
                                shape: (B, T)
        causal_masks:           Future mask.
                                type:  torch.BoolTensor
                                shape: (B, 1, T, T)

        Variables
        ---------
        z:                      Input sequence after token and position embedding.
                                type:  torch.Tensor
                                shape: (B, T, D)
        causal_masks:           Target mask.
                                type:  torch.BoolTensor
                                shape: (B, 1, T, T)
        logits:                 Output logits.
                                type:  torch.FloatTensor
                                shape: (B, T, C)
        pooled_logits:          Pooled logits.
                                type:  torch.FloatTensor
                                shape: (B, C)
        """

        backbone_last_layer_hidden_state = self.backbone(input_tokens, causal_masks=causal_masks)
        logits: torch.FloatTensor = self.head(backbone_last_layer_hidden_state)
        pooled_logits = self.pooler(logits)
        return pooled_logits


In [16]:
model_config = DecoderForSequenceClassificationConfig(
    d_model=32,
    vocab_size=tokenizer.vocab_size,
    context_length=MAX_LENGTH,
    num_decoder_blocks=1,
    dropout=0.0,
    decoder_block=DecoderBlockConfig(
        masked_self_attention_mha=MultiHeadedAttentionConfig(
            attention=ScaledDotProductAttention(), d_model=32, H=1, dropout=0.0
        ),
        feed_forward=PositionwiseFeedForwardConfig(
            d_model=32, d_ff=32 * 2, activation=nn.GELU(approximate="tanh"), dropout=0.0, bias=True
        ),
        add_norm_1=AddNormConfig(feature_dim=32, dropout=0.0),
        add_norm_2=AddNormConfig(feature_dim=32, dropout=0.0),
    ),
    num_labels=3,
)

In [17]:
model = GPTForSequenceClassification(model_config).to(DEVICE)

In [18]:
pprint(model)

## Dry Run

In our dry run, we make the following assumptions:

- `batch_size = 2` which means we have $2$ samples in a batch.
- `MAX_LEN = 3` which means the context length $T$ is $3$.
- `d_model = 4` which means the model dimension is $4$ for hidden layers.
- Consequently the final output dimension of the backbone is $\mathcal{B} \times T \times D \rightarrow 2\times 3 \times 4$.

In [19]:
seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)

dry_run_model_config = DecoderForSequenceClassificationConfig(
    d_model=4,
    vocab_size=tokenizer.vocab_size,
    context_length=MAX_LENGTH,
    num_decoder_blocks=1,
    dropout=0.0,
    decoder_block=DecoderBlockConfig(
        masked_self_attention_mha=MultiHeadedAttentionConfig(
            attention=ScaledDotProductAttention(), d_model=4, H=1, dropout=0.0
        ),
        feed_forward=PositionwiseFeedForwardConfig(
            d_model=4, d_ff=4 * 2, activation=nn.GELU(approximate="tanh"), dropout=0.0, bias=True
        ),
        add_norm_1=AddNormConfig(feature_dim=4, dropout=0.0),
        add_norm_2=AddNormConfig(feature_dim=4, dropout=0.0),
    ),
    num_labels=3,
)

dry_run_model = GPTForSequenceClassification(dry_run_model_config).to(DEVICE)
pprint(dry_run_model)

In [20]:
seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)

batch = next(iter(train_dataloader))
input_ids, labels, causal_masks = batch

First, we see the input ids, labels and causal masks to be of the below format.

In [21]:
pprint(input_ids)
pprint(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
pprint(labels)
pprint(causal_masks)

In [23]:
dry_run_backbone = dry_run_model.backbone
dry_run_backbone_last_layer_hidden_state = dry_run_backbone(input_ids, causal_masks=causal_masks)
dry_run_backbone_last_layer_hidden_state = dry_run_backbone_last_layer_hidden_state.detach().cpu()
pprint(dry_run_backbone_last_layer_hidden_state)
pprint(dry_run_backbone_last_layer_hidden_state.shape)

Here indeed the output of the backbone is of shape `[2, 3, 4]`. More concretely,
we have for the first sequence/example in the batch to be `[47117,   351,   262]`
with the underlying text to be `'Relations with the'` and the corresponding
label to be `0`. Now you see there are 3 tokens in the sequence, it is 
normal because if we do autoregressive modelling, we need to predict the next
token given the previous tokens. However, when we move on to sequence level
classification, we actually want to predict the label for the entire sequence
and not just say, given the first token, predict the second token and so on.
Fundamentally, the backbone is not designed for this task. Currently, the
backbone outputs the hidden states for each token in the sequence. 

For example,

```python
[ 1.0732,  0.3017, -1.6063,  0.2314] # -> token embedding for `Relations`
[-0.6302,  1.6116, -0.0369, -0.9445] # -> token embedding for `with`
[-0.0674,  1.2399, -1.4665,  0.2940] # -> token embedding for `the`
```

We introduce the idea of pooling the hidden states to get a single representation
for the entire sequence. You can think of it as transforming the hidden states
of all 3 tokens in the sequence to 1 single sentence/sequence representation.

```python
[-0.0674,  1.2399, -1.4665,  0.2940] # -> pooled embedding for `Relations with the`
```

However, in decoder only models, we do not have the `[CLS]` token to pool the
hidden states. However, recall the causal mask format for the first sequence.

```python
[ True, False, False]
[ True,  True, False]
[ True,  True,  True]
```

Oh, so the last token in the sequence is the one that is not masked - which
defaults to _cross attention_ since it has information of _every token_ in the
sequence. So, we can simply pool the last token to get the sequence
representation.

```python
class LastTokenPooling(nn.Module):
    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        # pooling to extract the last token's logits
        # logits: [B, T, C] -> [B, C] the -1 takes the last token's logits along
        # the feature dimension
        pooled_logits = logits[:, -1, :]
        return pooled_logits
```

In [25]:
dry_run_pooler = dry_run_model.pooler
dry_run_pooler_output = dry_run_pooler(dry_run_backbone_last_layer_hidden_state)
dry_run_pooler_output = dry_run_pooler_output.detach().cpu()
pprint(dry_run_pooler_output)
pprint(dry_run_pooler_output.shape)

In [None]:
logits = dry_run_model(input_ids, causal_masks=causal_masks)
logits = logits.detach().cpu()
pprint(logits)

In [19]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0048)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=6, eta_min=0.0)
num_epochs = 6

In [20]:
train_dataset = FinancialDataset(train_df, tokenizer=tokenizer, max_length=MAX_LENGTH, padding=PADDING, truncation=TRUNCATION, return_tensors=RETURN_TENSORS)
valid_dataset = FinancialDataset(valid_df, tokenizer=tokenizer, max_length=MAX_LENGTH, padding=PADDING, truncation=TRUNCATION, return_tensors=RETURN_TENSORS)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_for_unidirectional, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=False, collate_fn=collate_for_unidirectional, pin_memory=True)

In [21]:
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    correct_predictions = 0

    for batch in train_dataloader:
        input_ids, labels, causal_masks = [x.to(DEVICE) for x in batch]
        # print(input_ids.size(), labels.size(), causal_masks.size(), target_padding_masks.size())
        optimizer.zero_grad()
        outputs = model(input_tokens=input_ids, causal_masks=causal_masks)
        #print(outputs.size())
        #print(labels.size())
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, preds = torch.max(outputs, dim=1)
        correct_predictions += torch.sum(preds == labels).item()

    avg_loss = total_loss / len(train_dataloader)
    accuracy = correct_predictions / len(train_dataset)

    scheduler.step()

    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"Training loss: {avg_loss:.4f}, Training accuracy: {accuracy:.4f}")

    # Validation loop
    model.eval()
    val_loss = 0.0
    val_correct_predictions = 0

    with torch.no_grad():
        for batch in valid_dataloader:
            input_ids, labels, causal_masks = [x.to(DEVICE) for x in batch]

            outputs = model(input_tokens=input_ids, causal_masks=causal_masks)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            _, preds = torch.max(outputs, dim=1)
            val_correct_predictions += torch.sum(preds == labels).item()

    avg_val_loss = val_loss / len(valid_dataloader)
    val_accuracy = val_correct_predictions / len(valid_dataset)

    print(f"Validation loss: {avg_val_loss:.4f}, Validation accuracy: {val_accuracy:.4f}")

Epoch 1/6
Training loss: 0.7015, Training accuracy: 0.7148
Validation loss: 0.6304, Validation accuracy: 0.6784
Epoch 2/6
Training loss: 0.3970, Training accuracy: 0.8409
Validation loss: 0.7096, Validation accuracy: 0.7048
Epoch 3/6
Training loss: 0.2276, Training accuracy: 0.9185
Validation loss: 0.7362, Validation accuracy: 0.8238
Epoch 4/6
Training loss: 0.0784, Training accuracy: 0.9759
Validation loss: 0.6097, Validation accuracy: 0.8722
Epoch 5/6
Training loss: 0.0288, Training accuracy: 0.9921
Validation loss: 0.7513, Validation accuracy: 0.8678
Epoch 6/6
Training loss: 0.0160, Training accuracy: 0.9971
Validation loss: 0.7333, Validation accuracy: 0.8634


## Using HuggingFace

In [22]:
def preprocess_function(
    batch: Dict[Literal["sentence", "label"], List[str | int]], **kwargs: Any
) -> Dict[Literal["input_ids", "attention_mask"], List[int]]:
    return tokenizer(batch["sentence"], **kwargs)

In [23]:
tokenized_train_dataset = train_dataset.map(
    preprocess_function,
    fn_kwargs={"truncation": True, "padding": "longest", "max_length": 32},
    batched=True,
    num_proc=psutil.cpu_count(logical=True),
    batch_size=1000,
)

tokenized_valid_dataset = valid_dataset.map(
    preprocess_function,
    fn_kwargs={"truncation": True, "padding": "longest", "max_length": 32},
    batched=True,
    num_proc=psutil.cpu_count(logical=True),
    batch_size=1000,
)

tokenized_train_dataset_torch = tokenized_train_dataset.with_format("torch", DEVICE=DEVICE)
tokenized_valid_dataset_torch = tokenized_valid_dataset.with_format("torch", DEVICE=DEVICE)

AttributeError: 'FinancialDataset' object has no attribute 'map'

In [None]:
tokenized_valid_dataset_torch[0]['attention_mask']