# Encoder-only transformer model for AG News classification

# TODO

1. Add Positional Encoding
2. Add LR Scheduler
3. Check why need to use `torch.nn.utils.clip_grad_norm_` to clip gradients
4. Why unsqueeze mask?
5. Can you init weights inside Encoder instead of outside?

https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html

In this notebook, I train a encoder-only transformer to do text classification on the AG_NEWS dataset.
Text classification seems to be a pretty simple task, and using transformer is probably overkill. But this is my first time implementing the transformer structure from scratch (including the self-attention module), and it was fun :-)

In [346]:
# # some commands in th is notebook require torchtext 0.12.0
# !pip install torch --upgrade --quiet
# !pip install torchtext --upgrade --quiet
# !pip install torchdata --quiet
# !pip install torchinfo --quiet
# !pip install portalocker --quiet

In [347]:
import collections
import math
from dataclasses import dataclass
from rich.pretty import pprint
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torchdata
import torchinfo
import torchtext
from torch.utils.data import DataLoader
from torchtext.datasets import AG_NEWS
from tqdm import tqdm
import os
import time
import random

from typing import Optional

In [348]:
def seed_all(seed: Optional[int] = 1992, seed_torch: bool = True) -> int:
    """
    Seed all random number generators.

    Parameters
    ----------
    seed : int, optional
        Seed number to be used, by default 1992.
    seed_torch : bool, optional
        Whether to seed PyTorch or not, by default True.
    """
    # fmt: off
    os.environ["PYTHONHASHSEED"] = str(seed)        # set PYTHONHASHSEED env var at fixed value
    np.random.seed(seed)                            # numpy pseudo-random generator
    random.seed(seed)                               # python's built-in pseudo-random generator

    if seed_torch:
        torch.manual_seed(seed)
        # torch.manual_seed may call manual_seed_all but calling it again here
        # to make sure it gets called at least once
        torch.cuda.manual_seed_all(seed)             # pytorch (both CPU and CUDA)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.enabled = False
    # fmt: on
    return seed

In [349]:
seed_all(42, seed_torch=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cpu


## Data processing

In [350]:
# One can easily modify the data processing part of this code to accommodate for   other datasets for text classification listed in https://pytorch.org/text/stable/datasets.html#text-classification
train_iter, test_iter = AG_NEWS()

num_classes = len(set([label for (label, text) in train_iter]))
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

In [351]:
# see an example of the dateset
next(iter(train_iter))

(3,
 "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")

In [352]:
# fmt: off
# convert the labels to be in range(0, num_classes)
y_train       = torch.tensor([label - 1 for (label, text) in train_iter])
y_test        = torch.tensor([label - 1 for (label, text) in test_iter])

# There are many "\\" in the texts in the AG_news dataset, we get rid of them.
train_iter    = ((label, text.replace("\\", " ")) for label, text in train_iter)
test_iter     = ((label, text.replace("\\", " ")) for label, text in test_iter)

# tokenize the texts, and truncate the number of words in each text to max_seq_len
max_seq_len   = 100
x_train_texts = [tokenizer(text.lower())[0:max_seq_len] for (label, text) in train_iter]
x_test_texts  = [tokenizer(text.lower())[0:max_seq_len] for (label, text) in test_iter]
# fmt: on

In [353]:
# build the vocabulary and word-to-integer map
counter = collections.Counter()
for text in x_train_texts:
    counter.update(text)

vocab_size = 15000
most_common_words = np.array(counter.most_common(vocab_size - 2))
vocab = most_common_words[:,0]

# indexes for the padding token, and unknown tokens
PAD = 0
UNK = 1
word_to_id = {vocab[i]: i + 2 for i in range(len(vocab))}

In [354]:
len(vocab)

14998

In [355]:
# map the words in the training and test texts to integers
x_train = [
    torch.tensor([word_to_id.get(word, UNK) for word in text]) for text in x_train_texts
]
x_test = [
    torch.tensor([word_to_id.get(word, UNK) for word in text]) for text in x_test_texts
]
x_test = torch.nn.utils.rnn.pad_sequence(x_test, batch_first=True, padding_value=PAD)


In [356]:
# constructing the dataset in order to be compatible with torch.utils.data.Dataloader
class AGNewsDataset:
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, item):
        return self.features[item], self.labels[item]


train_dataset = AGNewsDataset(x_train, y_train)
test_dataset  = AGNewsDataset(x_test, y_test)

In [357]:
# collate_fn to be used in torch.utils.data.DataLoader().
# It pads the texts in each batch such that they have the same sequence length.
def pad_sequence(batch):
    texts = [text for text, label in batch]
    labels = torch.tensor([label for text, label in batch])
    texts_padded = torch.nn.utils.rnn.pad_sequence(
        texts, batch_first=True, padding_value=PAD
    )
    return texts_padded, labels


train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, collate_fn=pad_sequence
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=128, shuffle=True, collate_fn=pad_sequence
)


In [358]:
# # test loader
# counter = 0
# for batch in train_loader:
#     print(batch[0].shape)
#     print(batch[1].shape)
#     counter += 1
#     if counter == 5:
#         break


# my

In [359]:
import copy
import math
import unittest
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import numpy as np
import rich
import torch
# from d2l import torch as d2l
from rich.pretty import pprint
from torch import nn
from dataclasses import dataclass

# from src.utils.reproducibility import seed_all


class Attention(ABC, nn.Module):
    def __init__(self, dropout: float = 0.0) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout, inplace=False)

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        ...


class ScaledDotProductAttention(Attention):
    """
    Scaled Dot-Product Attention Class.

    This class performs scaled dot-product attention following the equations:

    .. math::
        \\text{Attention}(Q, K, V) = \\text{softmax} \\left( \\frac{QK^T}{\\sqrt{d_k}} \\right) V

    Inherits from Attention class.

    Methods
    -------
    forward(query, key, value, mask)
        Forward pass for scaled dot-product attention.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.BoolTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform the forward pass for scaled dot-product attention.

        This function applies the attention mechanism on the input tensors
        `query`, `key`, and `value`. It's worth noting that for cross-attention,
        the sequence lengths of `query` and `key`/`value` may differ.
        This is because :math:`query` is usually projected from the decoder's states,
        while :math:`key` and :math:`value` are from the encoder's states.

        Notations:

        - :math:`D` : Embedding dimension
        - :math:`d_k`: Dimension of the keys, queries, and values
        - :math:`N`: Batch size
        - :math:`T`: Sequence length for `query`
        - :math:`S`: Sequence length for `key` and `value`

        NOTE: We use :math:`L` in our notes instead of :math:`T` and :math:`S`
        since we assume all query, key and value are of same length.

        Parameters
        ----------
        query : torch.Tensor
            Tensor containing query vectors for each sequence.
            Shape: :math:`(N, T, d_k)`.
        key : torch.Tensor
            Tensor containing key vectors for each sequence.
            Shape: :math:`(N, S, d_k)`.
        value : torch.Tensor
            Tensor containing value vectors for each sequence.
            Shape: :math:`(N, S, d_k)`.
        mask : torch.BoolTensor, optional
            Optional mask tensor. Used for padding and future masking.
            Shape could be :math:`(N, S)` or :math:`(T, T)`, depending on the type
            of attention (self-attention or cross-attention).

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            The context vectors and the attention weights. The context vectors are the weighted sum
            of the `value` vectors, representing the information to be attended to.
            The attention weights represent the attention probabilities.

            - Context Vectors shape: :math:`(N, T, d_k)`
            - Attention Weights shape: :math:`(N, T, S)`
        """

        # 1. Find the embedding dimension D or d_q from the query feature vector.
        #    Q = Z @ W_q \in R^{L x D}
        #    Q_h = Q @ W_q^h \in R^{L x d_q}
        d_q = query.size(dim=-1)

        # 2. Compute the dot product of the query feature vector with the key feature vector.
        #    Note since key is of dim (batch_size, L, d_k) so we operate the
        #    transpose on the last two dimensions, specified by dim0 and dim1.
        #    key.transpose(dim0=-2, dim1=-1) means let the second last dimension
        #    be the last dimension and let the last dimension be the second last dimension.
        # fmt: off
        attention_scores = torch.matmul(query, key.transpose(dim0=-2, dim1=-1)) / math.sqrt(d_q)
        # fmt: on
        torch.testing.assert_close(
            attention_scores,
            torch.matmul(query, key.transpose(dim0=-2, dim1=-1)) / math.sqrt(d_q),
            msg="attention scores from bmm and matmul should be the same.",
        )

        # 3. Apply mask to the scores if mask is not None.
        if mask is not None:
            # TODO: give example of shape of mask
            #print(f"mask.shape: {mask.shape}")
            mask = mask.squeeze(2)
            attention_scores = attention_scores.masked_fill(mask, float("-inf"))

        # 4. Apply softmax to the attention scores to obtain attention weights and context vectors.
        attention_weights = attention_scores.softmax(dim=-1)
        attention_weights = self.dropout(attention_weights)

        context_vector = torch.matmul(attention_weights, value)
        torch.testing.assert_close(
            context_vector,
            torch.matmul(attention_weights, value),
            msg="context vector from bmm and matmul should be the same.",
        )
        return context_vector, attention_weights


@dataclass
class ModelConfig:
    attention: Attention
    num_decoder_blocks: int
    vocab_size: int
    H: int
    d_model: int
    d_ff: int
    dropout: float
    max_seq_len: int
    bias: bool = True


class MultiHeadedAttention(nn.Module):
    __slots__ = [
        "d_model",
        "d_k",
        "d_q",
        "d_v",
        "H",
        "W_Q",
        "W_K",
        "W_V",
        "W_O",
        "attention",
        "dropout",
    ]

    def __init__(
        self,
        attention: Attention,
        H: int,
        d_model: int,
        dropout: float = 0.1,
        bias: bool = False,
    ) -> None:
        super().__init__()
        assert d_model % H == 0

        # fmt: off
        self.d_model   = d_model       # D
        self.d_k       = d_model // H  # stay true to notations
        self.d_q       = d_model // H
        self.d_v       = d_model // H

        self.H         = H             # number of heads

        # shadow my notations, actually they are of shape D x D.
        self.W_Q       = nn.Linear(self.d_model, self.d_q * self.H, bias=bias)  # D x D
        self.W_K       = nn.Linear(self.d_model, self.d_k * self.H, bias=bias)
        self.W_V       = nn.Linear(self.d_model, self.d_v * self.H, bias=bias)
        self.W_O       = nn.Linear(self.d_model, self.d_model, bias=bias)

        self.attention = attention
        self.dropout   = nn.Dropout(p=dropout, inplace=False)
        # fmt: on

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:

        if mask is not None:
            mask = mask.unsqueeze(1)

        # fmt: off
        Q = self.W_Q(query).contiguous() # Z @ W_Q -> LxD @ DxD = LxD
        K = self.W_K(key).contiguous()   # Z @ W_K
        V = self.W_V(value).contiguous() # Z @ W_V


        Q = self.transpose_qkv(Q)        # [B, H, L, D]
        K = self.transpose_qkv(K)
        V = self.transpose_qkv(V)

        # Attention
        # same as the other code: x = torch.matmul(p_atten, value)
        context_vector, attention_weights = self.attention(Q, K, V, mask)
        context_vector_concat = self.reverse_transpose_qkv(context_vector)

        # fmt: on
        return self.W_O(context_vector_concat)

    def transpose_qkv(self, q_or_k_or_v: torch.Tensor) -> torch.Tensor:
        """Transposition for parallel computation of multiple attention heads.
        TODO: Why does transpose allow parallel computation?
        """
        # fmt: off
        # 1. q_or_k_or_v is shape (B, L, D)
        # 2. aim to make it of shape (B, L, H, D / H = d_qkv)
        batch_size, seq_len, _ = q_or_k_or_v.shape
        q_or_k_or_v            = q_or_k_or_v.view(batch_size, seq_len, self.H, self.d_model // self.H)

        # 3. switch H from 3rd to 2nd dimension, or in python swap 2nd to 1st
        q_or_k_or_v            = q_or_k_or_v.permute(0, 2, 1, 3)
        # fmt: on
        return q_or_k_or_v

    def reverse_transpose_qkv(self, q_or_k_or_v: torch.Tensor) -> torch.Tensor:
        """Reverse the transposition operation for concatenating multiple attention heads."""
        # fmt: off
        # 1. q_or_k_or_v is shape (B, H, L, D / H = d_qkv)
        # 2. aim to make it of shape (B, L, H, D / H = d_qkv)
        q_or_k_or_v = q_or_k_or_v.permute(0, 2, 1, 3)

        # 3. Merge H and d_qkv into D
        batch_size, seq_len, _, _ = q_or_k_or_v.shape
        q_or_k_or_v = q_or_k_or_v.contiguous().view(batch_size, seq_len, self.d_model)
        # fmt: on
        return q_or_k_or_v


class ResidualConnection(nn.Module):
    """residual connection: x + dropout(sublayer(layernorm(x)))"""

    def __init__(self, d_model, dropout):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, sublayer):
        return x + self.drop(sublayer(self.norm(x)))


# I simply let the model learn the positional embeddings in this notebook, since this
# almost produces identital results as using sin/cosin functions embeddings, as claimed
# in the original transformer paper. Note also that in the original paper, they multiplied
# the token embeddings by a factor of sqrt(d_embed), which I do not do here.


class Encoder(nn.Module):
    """Encoder = token embedding + positional embedding -> a stack of N EncoderBlock -> layer norm"""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.d_model = config.d_model
        self.tok_embed = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, config.max_seq_len, config.d_model)
        )
        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock(config) for _ in range(config.num_decoder_blocks)]
        )
        self.dropout = nn.Dropout(config.dropout)
        self.norm = nn.LayerNorm(config.d_model)

        # self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)

    def forward(self, input, mask=None):
        x = self.tok_embed(input)
        x_pos = self.pos_embed[:, : x.size(1), :]
        x = self.dropout(x + x_pos)
        for layer in self.encoder_blocks:
            x = layer(x, mask)
        return self.norm(x)


class PositionWiseFFN(nn.Module):
    """The positionwise feed-forward network."""

    pass


class EncoderBlock(nn.Module):
    """EncoderBlock: self-attention -> position-wise fully connected feed-forward layer"""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.mha = MultiHeadedAttention(
            config.attention, config.H, config.d_model, config.dropout, config.bias
        )
        self.feed_forward = nn.Sequential(
            nn.Linear(config.d_model, config.d_ff),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_ff, config.d_model),
        )
        self.residual1 = ResidualConnection(config.d_model, config.dropout)
        self.residual2 = ResidualConnection(config.d_model, config.dropout)

    def forward(self, x, mask=None):
        # self-attention
        x = self.residual1(x, lambda x: self.mha(x, x, x, mask=mask))
        # position-wise fully connected feed-forward layer
        return self.residual2(x, self.feed_forward)


class Transformer(nn.Module):
    def __init__(self, config: ModelConfig, num_classes):
        super().__init__()
        self.encoder = Encoder(config)
        self.linear = nn.Linear(config.d_model, num_classes)

    def forward(self, x, pad_mask=None):
        x = self.encoder(x, pad_mask)
        return self.linear(torch.mean(x, -2))


## Building the encoder-only transformer model for text classification

In [360]:
# config = ModelConfig(
#     attention=ScaledDotProductAttention(),
#     num_decoder_blocks=6,
#     vocab_size=vocab_size,
#     H=8,
#     d_model=512,
#     d_ff=2048,
#     dropout=0.1,
#     max_seq_len=max_seq_len,
#     bias=True
# )

config = ModelConfig(
    attention=ScaledDotProductAttention(),
    num_decoder_blocks=1,
    vocab_size=vocab_size,
    H=1,
    d_model=32,
    d_ff=4*32,
    dropout=0.0,
    max_seq_len=max_seq_len,
    bias=True
)
pprint(config)

model = Transformer(config, num_classes).to(DEVICE)

# initialize model parameters
# it seems that this initialization is very important!
for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

In [361]:
print(torchinfo.summary(model))
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

Layer (type:depth-idx)                                       Param #
Transformer                                                  --
├─Encoder: 1-1                                               3,200
│    └─Embedding: 2-1                                        480,000
│    └─ModuleList: 2-2                                       --
│    │    └─EncoderBlock: 3-1                                12,704
│    └─Dropout: 2-3                                          --
│    └─LayerNorm: 2-4                                        64
├─Linear: 1-2                                                132
Total params: 496,100
Trainable params: 496,100
Non-trainable params: 0


## Train the model

In [362]:

def train_epoch(model, dataloader):
    model.train()
    losses, acc, count = [], 0, 0
    pbar = tqdm(enumerate(dataloader), total=len(dataloader))
    for idx, (x, y) in pbar:
        optimizer.zero_grad()
        features = x.to(DEVICE)
        labels = y.to(DEVICE)
        pad_mask = (features == PAD).view(features.size(0), 1, 1, features.size(-1))
        pred = model(features, pad_mask)

        loss = loss_fn(pred, labels).to(DEVICE)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        acc += (pred.argmax(1) == labels).sum().item()
        count += len(labels)
        # report progress
        if idx > 0 and idx % 50 == 0:
            pbar.set_description(
                f"train loss={loss.item():.4f}, train_acc={acc/count:.4f}"
            )
    return np.mean(losses), acc / count


def train(model, train_loader, test_loader, epochs):
    for ep in range(epochs):
        train_loss, train_acc = train_epoch(model, train_loader)
        val_loss, val_acc = evaluate(model, test_loader)
        print(f"ep {ep}: val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")


def evaluate(model, dataloader):
    model.eval()
    losses = []
    with torch.no_grad():
        pbar = tqdm(enumerate(dataloader), total=len(dataloader))
        for x, y in pbar:
            features = x_test.to(DEVICE)
            labels = y_test.to(DEVICE)
            pad_mask = (features == PAD).view(features.size(0), 1, 1, features.size(-1))
            pred = model(features, pad_mask)
            loss = loss_fn(pred, labels).to(DEVICE)
            losses.append(loss.item())
            acc = (pred.argmax(1) == labels).sum().item()
            count = len(labels)
    return np.mean(losses), acc / count


In [363]:
# config = ModelConfig(encoder_vocab_size = vocab_size,
#                      d_embed = 32,
#                      d_ff = 4*32,
#                      h = 1,
#                      N_encoder = 1,
#                      max_seq_len = max_seq_len,
#                      dropout = 0.1
#                      )
# model = make_model(config)
# print(torchinfo.summary(model))
# optimizer = torch.optim.Adam(model.parameters())
# loss_fn = nn.CrossEntropyLoss()

In [364]:
train(model, train_loader, test_loader, epochs=4)

train loss=0.2568, train_acc=0.8996: 100%|██████████| 938/938 [00:23<00:00, 39.93it/s]
100%|██████████| 60/60 [00:46<00:00,  1.29it/s]


ep 0: val_loss=0.2333, val_acc=0.9200


train loss=0.0873, train_acc=0.9441: 100%|██████████| 938/938 [00:23<00:00, 39.73it/s]
100%|██████████| 60/60 [00:46<00:00,  1.30it/s]


ep 1: val_loss=0.2481, val_acc=0.9176


train loss=0.0828, train_acc=0.9593: 100%|██████████| 938/938 [00:23<00:00, 39.80it/s]
100%|██████████| 60/60 [00:45<00:00,  1.31it/s]


ep 2: val_loss=0.2649, val_acc=0.9208


train loss=0.1102, train_acc=0.9702: 100%|██████████| 938/938 [00:23<00:00, 39.69it/s]
100%|██████████| 60/60 [00:45<00:00,  1.31it/s]

ep 3: val_loss=0.2932, val_acc=0.9171





## News classification example

In [365]:
ag_news_label = {1: "World",
                 2: "Sports",
                 3: "Business",
                 4: "Sci/Tec"}

def classify_news(news):
    x_text = tokenizer(news.lower())[0:max_seq_len]
    x_int = torch.tensor([[word_to_id.get(word, UNK) for word in x_text]]).to(DEVICE)

    model.eval()
    with torch.no_grad():
        pred = model(x_int).argmax(1).item() + 1
    print(f"This is a {ag_news_label[pred]} news")

# The model correctly classifies a theoretical physics news as Sci/Tec news, :-)
news = """The conformal bootstrapDavid Poland1,2and David Simmons-Duﬃn2*The conformal bootstrap was
proposed in the 1970s as a strategy for calculating the properties of second-order phasetransitions.
After spectacular success elucidating two-dimensional systems, little progress was made on systems in
 higher dimensions until a recent renaissance beginning in 2008. We report on some of the main results and
  ideas from thisrenaissance, focusing on new determinations of critical exponents and correlation
  functions in the three-dimensional Ising and O(N) models.
"""
classify_news(news)

This is a Sci/Tec news


In [366]:
import unittest
from abc import ABC, abstractmethod

import torch
from d2l import torch as d2l
from torch import nn

from src.utils.reproducibility import seed_all


class PositionalEncoding(ABC, nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.0) -> None:
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout, inplace=False)

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ...


class Sinusoid(PositionalEncoding):
    def __init__(
        self, d_model: int, dropout: float = 0.0, max_seq_len: int = 3
    ) -> None:
        super().__init__(d_model, dropout)
        self.max_seq_len = max_seq_len
        self.d_model = d_model

        P = self._init_positional_encoding()
        self.register_buffer("P", P, persistent=True)

    def _init_positional_encoding(self) -> torch.Tensor:
        """Initialize the positional encoding tensor."""
        P = torch.zeros((1, self.max_seq_len, self.d_model))
        position = self._get_position_vector()
        div_term = self._get_div_term_vector()
        P[:, :, 0::2] = torch.sin(position / div_term)
        P[:, :, 1::2] = torch.cos(position / div_term)
        return P

    def _get_position_vector(self) -> torch.Tensor:
        """Return a vector representing the position of each token in a sequence."""
        return torch.arange(self.max_seq_len, dtype=torch.float32).reshape(-1, 1)

    def _get_div_term_vector(self) -> torch.Tensor:
        """Return a vector representing the divisor term for positional encoding."""
        return torch.pow(
            10000,
            torch.arange(0, self.d_model, 2, dtype=torch.float32) / self.d_model,
        )

    def forward(self, Z: torch.Tensor) -> torch.Tensor:
        Z = self._add_positional_encoding(Z)
        return self.dropout(Z)

    def _add_positional_encoding(self, Z: torch.Tensor) -> torch.Tensor:
        """Add the positional encoding tensor to the input tensor."""
        return Z + self.P[:, : Z.shape[1], :].to(Z.device)


class TestPositionalEncoding(unittest.TestCase):
    def setUp(self) -> None:
        seed_all(42, seed_torch=True)

        # Initialize queries, keys, and values
        # fmt: off
        self.batch_size    = 1  # B
        self.num_heads     = 2  # H
        self.seq_len       = 60 # L
        self.d_model = 32 # D
        self.dropout       = 0.0
        self.max_seq_len       = 1000

        self.embeddings    = torch.zeros(self.batch_size, self.seq_len, self.d_model) # Z

        self.pos_encoding  = Sinusoid(d_model=self.d_model, dropout=self.dropout, max_seq_len=self.max_seq_len)
        # fmt: on

        # Initialize the attention models
        self.pos_encoding_d2l = d2l.PositionalEncoding(
            self.d_model, dropout=self.dropout, max_seq_len=self.max_seq_len
        )
        self.pos_encoding_d2l.eval()

    def test_positional_encoding_with_d2l_as_sanity_check(self) -> None:
        # fmt: off
        # d2l implementation
        Z_d2l = self.pos_encoding_d2l(self.embeddings)
        P_d2l = self.pos_encoding_d2l.P[:, : Z_d2l.shape[1], :]

        # own implementation
        Z     = self.pos_encoding(self.embeddings)
        P     = self.pos_encoding.P[:, : Z.shape[1], :]
        # fmt: on

        # Test if both are close
        self.assertTrue(torch.allclose(Z, Z_d2l))
        self.assertTrue(torch.allclose(P, P_d2l))


if __name__ == "__main__":
    unittest.main()


ModuleNotFoundError: No module named 'src'