# Transformers

## Introduction

The **Transformer** architecture, introduced by Vaswani et al. in the paper "Attention Is
All You Need" (2017), represents a paradigm shift in sequence processing with neural
networks. Unlike recurrent architectures (RNNs and LSTMs), which process a sequence one
token at a time, Transformers use attention mechanisms that process all positions of the
sequence in parallel and connect any two positions through a constant number of
operations, which makes long-range dependencies easier to capture.

The architecture is based on the **multi-head attention** mechanism, which allows the
model to simultaneously attend to different positions of the input sequence from multiple
representation subspaces. Combined with residual connections, layer normalization, and
position-wise feed-forward networks, this design has become the dominant architecture in
natural language processing and has been adapted to computer vision (the Vision
Transformer, for example, treats image patches as a sequence of tokens) and to other
domains that can be framed as sequence modeling.

## Auxiliary functions for masking

Masking constitutes an essential component in the Transformer architecture, serving two
main functions. The **causal mask** prevents the decoder from accessing future tokens
during training, preserving the autoregressive nature of sequence generation. The
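**padding mask** prevents attention to the padding tokens that are added so that all
sequences in a batch have the same length. In both masks, a value of 0 (or `False`) marks a
blocked position: the attention function below fills the corresponding scores with a large
negative number (-1e9) before the softmax, so those positions receive essentially zero weight.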

In [1]:
import torch
import math
from torch import nn
from torch.nn import functional as F

def create_causal_mask(size: int) -> torch.Tensor:
    """
    Creates a causal mask to prevent the decoder from attending to future tokens during training.

    Args:
        size: Length of the sequence.

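    Returns:
        Boolean causal mask of shape (size, size); entry (i, j) is True when
        position i may attend to position j (that is, j <= i).
    """
    return torch.tril(torch.ones(size, size, dtype=torch.bool))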

def create_padding_mask(seq: torch.Tensor, pad_token: int = 0) -> torch.Tensor:
    """
    Creates a mask to ignore padding tokens in a sequence.

    Args:
        seq: Sequence of tokens, shape (B, seq_len).
        pad_token: Padding token value.

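    Returns:
        Boolean padding mask of shape (B, 1, 1, seq_len), True for real tokens, which
        broadcasts over the head and query dimensions of the attention scores.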
    """
    return (seq != pad_token).unsqueeze(1).unsqueeze(1)
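
For decoder self-attention the two masks are usually combined, so that a position can attend neither to future tokens nor to padding. The sketch below restates both helpers with boolean outputs (bitwise `&` is defined for `torch.bool` tensors but not for float ones) and relies on broadcasting; the token ids and the padding id 0 are illustrative assumptions.

```python
import torch


def create_causal_mask(size: int) -> torch.Tensor:
    # True on and below the diagonal: position i may attend to positions j <= i
    return torch.tril(torch.ones(size, size, dtype=torch.bool))


def create_padding_mask(seq: torch.Tensor, pad_token: int = 0) -> torch.Tensor:
    # (B, seq_len) -> (B, 1, 1, seq_len), True for real tokens
    return (seq != pad_token).unsqueeze(1).unsqueeze(1)


tgt = torch.tensor([[5, 7, 9, 0],    # last token is padding
                    [3, 0, 0, 0]])   # three padding tokens
T = tgt.shape[1]

# (B, 1, 1, T) & (T, T) broadcasts to (B, 1, T, T): query i may attend to key j
# only if j is not in the future AND token j is not padding.
tgt_mask = create_padding_mask(tgt) & create_causal_mask(T)
print(tgt_mask.shape)  # torch.Size([2, 1, 4, 4])
print(tgt_mask[0, 0].int())
```

Row `i` of `tgt_mask[b, 0]` lists the key positions that query `i` may attend to; in the first example the last column is blocked in every row because that token is padding.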

## Input embeddings and positional encoding

Mapping discrete tokens to continuous vectors is the first step of a Transformer. **Input
embeddings** map each token of the vocabulary to a dense vector of dimension $d_{model}$
that is learned during training. Following the original paper, the embeddings are
multiplied by $\sqrt{d_{model}}$, which sets their scale relative to the positional
encodings that are added next.

Since Transformers lack the intrinsic notion of sequential order present in RNNs, it is
necessary to explicitly inject positional information. **Positional encoding** adds
deterministic vectors to the embeddings, calculated using sinusoidal functions that allow
the model to distinguish relative positions:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

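where $pos$ is the position in the sequence and $i$ indexes dimension pairs: dimensions
$2i$ and $2i+1$ share the frequency $1/10000^{2i/d_{model}}$, for
$i = 0, \dots, d_{model}/2 - 1$. The implementation below computes these frequencies in log
space, as $\exp\left(-2i \ln(10000)/d_{model}\right)$, and assumes $d_{model}$ is even.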
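
The original paper notes that for any fixed offset $k$, $PE_{pos+k}$ is a linear function of $PE_{pos}$: by the angle-addition identities, each $(\sin, \cos)$ pair is rotated by an angle that depends only on $k$ and that pair's frequency. The sketch below checks this numerically by transcribing the formulas with plain loops; the function names `sinusoidal_pe` and `shift_encoding` are illustrative and not part of the notebook, and an even $d_{model}$ is assumed.

```python
import math

import torch


def sinusoidal_pe(num_positions: int, d_model: int) -> torch.Tensor:
    """Direct (loop) transcription of the PE formulas; d_model is assumed even."""
    pe = torch.zeros(num_positions, d_model)
    for pos in range(num_positions):
        for i in range(d_model // 2):
            angle = pos / (10000 ** (2 * i / d_model))
            pe[pos, 2 * i] = math.sin(angle)
            pe[pos, 2 * i + 1] = math.cos(angle)
    return pe


def shift_encoding(pe_row: torch.Tensor, k: int, d_model: int) -> torch.Tensor:
    """Maps PE[pos] to PE[pos + k] using only k: one fixed rotation per (sin, cos) pair."""
    out = torch.empty_like(pe_row)
    for i in range(d_model // 2):
        w = 1.0 / (10000 ** (2 * i / d_model))  # frequency of pair i
        s, c = pe_row[2 * i], pe_row[2 * i + 1]
        # sin(a + b) = sin a cos b + cos a sin b
        out[2 * i] = s * math.cos(w * k) + c * math.sin(w * k)
        # cos(a + b) = cos a cos b - sin a sin b
        out[2 * i + 1] = c * math.cos(w * k) - s * math.sin(w * k)
    return out


d_model, k = 8, 3
pe = sinusoidal_pe(num_positions=20, d_model=d_model)
print(torch.allclose(shift_encoding(pe[5], k, d_model), pe[5 + k], atol=1e-5))  # True
```

The rotation applied by `shift_encoding` uses no information about `pos` itself, which is what makes relative offsets expressible as a fixed linear map.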

In [2]:
class InputEmbedding(nn.Module):
    """Embeds input tokens into vectors of dimension d_model."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """
        Initializes input embedding layer.

        Args:
            d_model: Dimensionality of the embedding vectors.
            vocab_size: Size of the vocabulary.
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the embedding layer.

        Args:
            input_tensor: Input tensor of token indices.

        Returns:
            Tensor of embedded input scaled by sqrt(d_model).
        """
        return self.embedding(input_tensor) * math.sqrt(self.d_model)

class PositionalEncoding(nn.Module):
    """Adds positional encoding to input embeddings."""

    def __init__(self, d_model: int, sequence_length: int, dropout_rate: float) -> None:
        """
        Initializes positional encoding layer.

        Args:
            d_model: Dimensionality of the embedding vectors.
            sequence_length: Maximum sequence length.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        self.d_model = d_model
        self.sequence_length = sequence_length
        self.dropout = nn.Dropout(dropout_rate)

        pe_matrix = torch.zeros(size=(self.sequence_length, self.d_model))
        position = torch.arange(0, self.sequence_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe_matrix[:, 0::2] = torch.sin(position * div_term)
        pe_matrix[:, 1::2] = torch.cos(position * div_term)
        pe_matrix = pe_matrix.unsqueeze(0)

        self.register_buffer(name="pe_matrix", tensor=pe_matrix)

    def forward(self, input_embedding: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to add positional encoding.

        Args:
            input_embedding: Tensor of input embeddings.

        Returns:
            Tensor of embeddings with added positional encoding.
        """
        x = input_embedding + (
            self.pe_matrix[:, : input_embedding.shape[1], :]  # type: ignore
        ).requires_grad_(False)
        return self.dropout(x)

## Layer normalization and feed-forward networks

**Layer normalization** stabilizes training by normalizing activations across features
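for each individual example (in a Transformer, for each token). Unlike batch normalization,
which normalizes each feature across the batch, layer normalization does not depend on batch
composition, sequence length, or padding positions, and it behaves the same at training and
inference time, which makes it better suited to variable-length sequential data. The
transformation is defined as: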

$$\text{LayerNorm}(x) = \alpha \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where $\mu$ and $\sigma^2$ are the mean and variance calculated over the features,
$\alpha$ and $\beta$ are learnable parameters, and $\epsilon$ is a small constant for
numerical stability.
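
The formula can be checked against PyTorch's built-in `torch.nn.functional.layer_norm`, which also uses the biased (population) variance. A minimal sketch, with $\alpha = 1$ and $\beta = 0$ as at initialization and arbitrary illustrative tensor sizes:

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
features = 8
x = torch.randn(2, 5, features)  # (batch, seq_len, features)
alpha, beta, eps = torch.ones(features), torch.zeros(features), 1e-6

mu = x.mean(dim=-1, keepdim=True)                  # per-token mean over features
var = x.var(dim=-1, keepdim=True, unbiased=False)  # per-token population variance
manual = alpha * (x - mu) / torch.sqrt(var + eps) + beta

reference = F.layer_norm(x, normalized_shape=(features,), weight=alpha, bias=beta, eps=eps)
print(torch.allclose(manual, reference, atol=1e-5))  # True
```

After normalization every token vector has mean close to 0 and variance close to 1 over its features, independently of the other tokens in the batch.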

**Feed-forward networks** apply non-linear transformations independently to each position
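in the sequence, using the same weights at every position. They consist of two linear
transformations with a ReLU activation in between, expanding from $d_{model}$ to an inner
dimension $d_{ff}$ and projecting back ($d_{model} = 512$ and $d_{ff} = 2048$ in the paper's
base model):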

$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$
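
A short sketch connecting the formula to `nn.Linear` and showing the position-wise property: applying the network to the whole sequence gives the same result as applying it to each position separately. Dropout is omitted here (the `FeedForward` class below adds it after the ReLU), and the sizes are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, d_ff = 8, 32
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
x = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)

# The formula, with nn.Linear's (out, in) weights transposed into W_1 and W_2
W1, b1 = ffn[0].weight.T, ffn[0].bias
W2, b2 = ffn[2].weight.T, ffn[2].bias
formula = torch.clamp(x @ W1 + b1, min=0) @ W2 + b2

# Position-wise: running each position on its own gives the same result
per_position = torch.stack([ffn(x[:, t]) for t in range(x.shape[1])], dim=1)

out = ffn(x)
print(out.shape)  # torch.Size([2, 5, 8])
print(torch.allclose(out, formula, atol=1e-5), torch.allclose(out, per_position, atol=1e-5))  # True True
```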

In [3]:
class LayerNormalization(nn.Module):
    """Applies layer normalization to input embeddings."""

    def __init__(self, features: int, eps: float = 1e-6) -> None:
        """
        Initializes layer normalization.

        Args:
            features: Number of features in the input.
            eps: Small constant for numerical stability.
        """
        super().__init__()
        self.features = features
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(self.features))
        self.bias = nn.Parameter(torch.zeros(self.features))

    def forward(self, input_embedding: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for layer normalization.

        Args:
            input_embedding: Tensor of input embeddings.

        Returns:
            Normalized tensor.
        """
        mean = torch.mean(input=input_embedding, dim=-1, keepdim=True)
        var = torch.var(input=input_embedding, dim=-1, keepdim=True, unbiased=False)
        return (
            self.alpha * ((input_embedding - mean) / (torch.sqrt(var + self.eps)))
            + self.bias
        )

class FeedForward(nn.Module):
    """Feed-forward neural network layer."""

    def __init__(self, d_model: int, d_ff: int, dropout_rate: float) -> None:
        """
        Initializes feed-forward network.

        Args:
            d_model: Dimensionality of model embeddings.
            d_ff: Dimensionality of feed-forward layer.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.ffn = nn.Sequential(
            nn.Linear(in_features=self.d_model, out_features=self.d_ff),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(in_features=self.d_ff, out_features=self.d_model),
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through feed-forward network.

        Args:
            input_tensor: Tensor of input embeddings.

        Returns:
            Tensor processed by feed-forward network.
        """
        return self.ffn(input_tensor)

## Multi-head attention mechanism

The **multi-head attention** mechanism constitutes the central component of the
Transformer architecture. It allows the model to attend to information from different
representation subspaces at different positions simultaneously. Scaled dot-product
attention is calculated as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

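where $Q$, $K$, and $V$ are the query, key, and value matrices and $d_k$ is the
dimensionality of the keys. The scaling factor $\frac{1}{\sqrt{d_k}}$ keeps the scores from
growing with $d_k$: if the components of a query and a key are independent with mean 0 and
variance 1, their dot product has variance $d_k$, and such large scores push the softmax into
regions where its gradients are extremely small.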

Multi-head attention linearly projects queries, keys, and values $h$ times with different
learned projections, applies the attention function in parallel, and concatenates the
results:

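$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O$$

where $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$, with
$W_i^Q, W_i^K \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$,
and $W^O \in \mathbb{R}^{hd_v \times d_{model}}$. Like the paper, the implementation below
uses $d_k = d_v = d_{model}/h$.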
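
Using a single $d_{model} \times d_{model}$ projection and then reshaping, as `MultiHeadAttention` does below, is equivalent to $h$ separate projections $W_i^Q$ of size $d_{model} \times d_k$: head $i$ simply uses its own block of $d_k$ output rows of the weight matrix. A small check with illustrative sizes:

```python
import torch
from torch import nn

torch.manual_seed(0)
B, T, d_model, h = 2, 4, 12, 3
d_k = d_model // h
W_Q = nn.Linear(d_model, d_model, bias=False)  # all h query projections fused
x = torch.randn(B, T, d_model)

# Fused: project once, split the last dim into (h, d_k), move heads forward -> (B, h, T, d_k)
fused = W_Q(x).view(B, T, h, d_k).transpose(1, 2)

# Separate: head i multiplies by its own W_i^Q, i.e. rows i*d_k..(i+1)*d_k of the weight
per_head = torch.stack(
    [x @ W_Q.weight[i * d_k:(i + 1) * d_k].T for i in range(h)], dim=1
)
print(fused.shape)  # torch.Size([2, 3, 4, 4])
print(torch.allclose(fused, per_head, atol=1e-5))  # True
```

The fused form does the same arithmetic with one large matrix multiplication instead of $h$ small ones, which is usually faster on GPUs.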

In [4]:
class MultiHeadAttention(nn.Module):
    """Applies multi-head attention mechanism."""

    def __init__(self, d_model: int, h: int, dropout_rate: float) -> None:
        """
        Initializes multi-head attention layer.

        Args:
            d_model: Dimensionality of model embeddings.
            h: Number of attention heads.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()

        if d_model % h != 0:
            raise ValueError("d_model must be divisible by h")

        self.d_model = d_model
        self.h = h
        self.dropout = nn.Dropout(dropout_rate)
        self.d_k = self.d_model // self.h
        self.d_v = self.d_model // self.h

        self.W_K = nn.Linear(
            in_features=self.d_model, out_features=self.d_model, bias=False
        )
        self.W_Q = nn.Linear(
            in_features=self.d_model, out_features=self.d_model, bias=False
        )
        self.W_V = nn.Linear(
            in_features=self.d_model, out_features=self.d_model, bias=False
        )
        self.W_OUTPUT_CONCAT = nn.Linear(
            in_features=self.d_model, out_features=self.d_model, bias=False
        )

    @staticmethod
    def attention(
        k: torch.Tensor,
        q: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor | None = None,
        dropout: nn.Dropout | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Computes scaled dot-product attention.

        Args:
            k: Key tensor.
            q: Query tensor.
            v: Value tensor.
            mask: Optional mask tensor.
            dropout: Optional dropout layer.

        Returns:
            Tuple of attention output and scores.
        """
        matmul_q_k = q @ k.transpose(-2, -1)
        d_k = k.shape[-1]
        matmul_q_k_scaled = matmul_q_k / math.sqrt(d_k)

        if mask is not None:
            matmul_q_k_scaled.masked_fill_(mask == 0, -1e9)

        attention_scores = F.softmax(input=matmul_q_k_scaled, dim=-1)

        if dropout is not None:
            attention_scores = dropout(attention_scores)

        return (attention_scores @ v), attention_scores

    def forward(
        self,
        k: torch.Tensor,
        q: torch.Tensor,
        v: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Forward pass through multi-head attention.

        Args:
            k: Key tensor.
            q: Query tensor.
            v: Value tensor.
            mask: Optional mask tensor.

        Returns:
            Tensor after attention and concatenation.
        """
        key_prima = self.W_K(k)
        query_prima = self.W_Q(q)
        value_prima = self.W_V(v)

        key_prima = key_prima.view(
            key_prima.shape[0], key_prima.shape[1], self.h, self.d_k
        ).transpose(1, 2)
        query_prima = query_prima.view(
            query_prima.shape[0], query_prima.shape[1], self.h, self.d_k
        ).transpose(1, 2)
        value_prima = value_prima.view(
            value_prima.shape[0], value_prima.shape[1], self.h, self.d_k
        ).transpose(1, 2)

        attention, attention_scores = MultiHeadAttention.attention(
            k=key_prima,
            q=query_prima,
            v=value_prima,
            mask=mask,
            dropout=self.dropout,
        )

        attention = attention.transpose(1, 2)
        b, seq_len, h, d_k = attention.size()
        attention_concat = attention.contiguous().view(b, seq_len, h * d_k)

        return self.W_OUTPUT_CONCAT(attention_concat)
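
As a sanity check, the scaled dot-product attention used above can be compared with PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later), which treats a boolean `attn_mask` as True = may attend. The helper `masked_attention` is an illustrative restatement of `MultiHeadAttention.attention` without dropout, taking its arguments in (q, k, v) order.

```python
import math

import torch
from torch.nn import functional as F


def masked_attention(q, k, v, mask=None):
    """Same computation as MultiHeadAttention.attention, without dropout."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(k.shape[-1])
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    return F.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
B, h, T, d_k = 2, 4, 5, 8
q, k, v = (torch.randn(B, h, T, d_k) for _ in range(3))
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))  # True = may attend

ours = masked_attention(q, k, v, mask=causal)
reference = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)  # PyTorch >= 2.0
print(torch.allclose(ours, reference, atol=1e-5))  # True
```

Because of the causal mask, the first query can only see the first key, so its output equals the first value vector exactly.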

## Residual connections and encoder-decoder blocks

**Residual connections** allow gradients to flow directly through the network,
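facilitating the training of deep architectures. Each sublayer is wrapped in a residual
connection combined with layer normalization. The original paper uses the post-norm
arrangement $\text{LayerNorm}(x + \text{Sublayer}(x))$; the implementation below uses the
pre-norm variant, which normalizes the sublayer input, applies dropout to the sublayer
output, and leaves the residual path untouched:

$$\text{Output} = x + \text{Dropout}(\text{Sublayer}(\text{LayerNorm}(x)))$$

Pre-norm is a widely used alternative that is often reported to train more stably in deep
stacks.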

**Encoder blocks** apply self-attention followed by a feed-forward network, each wrapped in
a residual connection. **Decoder blocks** contain three sublayers: masked self-attention over
the target sequence; cross-attention, in which the queries come from the decoder and the keys
and values come from the encoder output, so that each decoder position can attend to all
positions of the input sequence; and a feed-forward network.

In [5]:
from collections.abc import Callable

class ResidualConnection(nn.Module):
    """Applies a pre-norm residual connection around a sublayer."""

    def __init__(self, features: int, dropout_rate: float) -> None:
        """
        Initializes residual connection layer.

        Args:
            features: Number of features in the input.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm = LayerNormalization(features=features)

    def forward(
        self,
        input_tensor: torch.Tensor,
        sublayer: Callable[[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        """
        Forward pass using residual connection.

        Args:
            input_tensor: Input tensor to the residual layer.
            sublayer: Callable (a module or a lambda wrapping one) applied to the
                normalized input.

        Returns:
            Tensor with residual connection applied.
        """
        # Pre-norm: normalize the input, run the sublayer, apply dropout, then add the
        # result back to the unnormalized input so the residual path stays an identity.
        return input_tensor + self.dropout(sublayer(self.layer_norm(input_tensor)))

class EncoderBlock(nn.Module):
    """Encoder block with attention and feed-forward layers."""

    def __init__(self, d_model: int, d_ff: int, h: int, dropout_rate: float) -> None:
        """
        Initializes encoder block.

        Args:
            d_model: Dimensionality of model embeddings.
            d_ff: Dimensionality of feed-forward layer.
            h: Number of attention heads.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.h = h
        self.dropout_rate = dropout_rate

        self.multi_head_attention_layer = MultiHeadAttention(
            d_model=self.d_model, h=self.h, dropout_rate=self.dropout_rate
        )
        self.residual_layer_1 = ResidualConnection(
            features=d_model, dropout_rate=self.dropout_rate
        )
        self.feed_forward_layer = FeedForward(
            d_model=self.d_model, d_ff=self.d_ff, dropout_rate=self.dropout_rate
        )
        self.residual_layer_2 = ResidualConnection(
            features=d_model, dropout_rate=self.dropout_rate
        )

    def forward(
        self, input_tensor: torch.Tensor, mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Forward pass through encoder block.

        Args:
            input_tensor: Input tensor to the encoder block.
            mask: Optional mask tensor.

        Returns:
            Tensor after processing by the encoder block.
        """
        input_tensor = self.residual_layer_1(
            input_tensor,
            lambda x: self.multi_head_attention_layer(k=x, q=x, v=x, mask=mask),
        )
        input_tensor = self.residual_layer_2(
            input_tensor, lambda x: self.feed_forward_layer(x)
        )
        return input_tensor

class DecoderBlock(nn.Module):
    """Decoder block with masked attention, cross-attention, and feed-forward layers."""

    def __init__(self, d_model: int, d_ff: int, h: int, dropout_rate: float) -> None:
        """
        Initializes decoder block.

        Args:
            d_model: Dimensionality of model embeddings.
            d_ff: Dimensionality of feed-forward layer.
            h: Number of attention heads.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff
        self.h = h
        self.dropout_rate = dropout_rate

        self.masked_multi_head_attention_layer = MultiHeadAttention(
            d_model=self.d_model, h=self.h, dropout_rate=self.dropout_rate
        )
        self.residual_layer_1 = ResidualConnection(
            features=d_model, dropout_rate=self.dropout_rate
        )
        self.multi_head_attention_layer = MultiHeadAttention(
            d_model=self.d_model, h=self.h, dropout_rate=self.dropout_rate
        )
        self.residual_layer_2 = ResidualConnection(
            features=d_model, dropout_rate=self.dropout_rate
        )
        self.feed_forward_layer = FeedForward(
            d_model=self.d_model, d_ff=self.d_ff, dropout_rate=self.dropout_rate
        )
        self.residual_layer_3 = ResidualConnection(
            features=d_model, dropout_rate=self.dropout_rate
        )

    def forward(
        self,
        decoder_input: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor | None = None,
        tgt_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Forward pass through decoder block.

        Args:
            decoder_input: Input tensor to the decoder block.
            encoder_output: Output tensor from the encoder.
            src_mask: Optional source mask tensor.
            tgt_mask: Optional target mask tensor.

        Returns:
            Tensor after processing by the decoder block.
        """
        decoder_input = self.residual_layer_1(
            decoder_input,
            lambda x: self.masked_multi_head_attention_layer(
                k=x, q=x, v=x, mask=tgt_mask
            ),
        )
        decoder_input = self.residual_layer_2(
            decoder_input,
            lambda x: self.multi_head_attention_layer(
                k=encoder_output, q=x, v=encoder_output, mask=src_mask
            ),
        )
        decoder_output = self.residual_layer_3(
            decoder_input, lambda x: self.feed_forward_layer(x)
        )
        return decoder_output
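
In the cross-attention sublayer of `DecoderBlock`, queries come from the decoder while keys and values come from the encoder output, so the score matrix is rectangular ($T_{tgt} \times T_{src}$) and a source padding mask of shape (B, 1, 1, T_src) broadcasts over heads and query positions. A shape-level sketch, with random tensors standing in for the projected queries, keys, and values; the sizes and token ids are illustrative.

```python
import math

import torch
from torch.nn import functional as F

torch.manual_seed(0)
B, h, d_k = 2, 4, 8
T_src, T_tgt = 6, 3
q = torch.randn(B, h, T_tgt, d_k)  # projected decoder states (queries)
k = torch.randn(B, h, T_src, d_k)  # projected encoder output (keys)
v = torch.randn(B, h, T_src, d_k)  # projected encoder output (values)

src = torch.tensor([[4, 8, 15, 16, 0, 0],   # two padding tokens (id 0)
                    [23, 42, 7, 1, 2, 3]])  # no padding
src_mask = (src != 0).unsqueeze(1).unsqueeze(1)  # (B, 1, 1, T_src)

scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (B, h, T_tgt, T_src)
scores = scores.masked_fill(src_mask == 0, -1e9)   # broadcast over heads and queries
weights = F.softmax(scores, dim=-1)
out = weights @ v                                  # (B, h, T_tgt, d_k)

print(scores.shape, out.shape)  # torch.Size([2, 4, 3, 6]) torch.Size([2, 4, 3, 8])
print(weights[0, :, :, 4:].max().item())  # 0.0: padded source tokens get no weight
```

The output has one vector per target position, whatever the source length, which is why cross-attention can feed directly into the next residual sublayer of the decoder.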

## Complete Transformer architecture

The complete Transformer architecture integrates all the described components into an
encoder-decoder structure. The **encoder** processes the input sequence through multiple
identical layers, each applying self-attention and feed-forward transformations. The
**decoder** generates the output sequence autoregressively, using both masked
self-attention and cross-attention over the encoder output. A final projection layer maps
the decoder representations to unnormalized scores (logits) over the target vocabulary; a
softmax over the last dimension turns them into probabilities.
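
Since `ProjectionLayer` returns logits, probabilities are obtained with a softmax over the last dimension, and training losses such as `torch.nn.functional.cross_entropy` take the logits directly because they apply the log-softmax internally. A sketch with random logits standing in for the model output; the shapes are illustrative.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
B, T, vocab_size = 2, 4, 10
logits = torch.randn(B, T, vocab_size)  # stand-in for the projection layer output
probs = F.softmax(logits, dim=-1)       # probabilities over the target vocabulary
print(torch.allclose(probs.sum(dim=-1), torch.ones(B, T)))  # True

targets = torch.randint(0, vocab_size, (B, T))
# cross_entropy expects (N, C) logits and applies log-softmax internally
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
manual = -torch.log(probs.gather(-1, targets.unsqueeze(-1))).mean()
print(torch.allclose(loss, manual, atol=1e-5))  # True
```

Passing logits rather than probabilities to the loss is also numerically safer, since the log-softmax is computed in one stable step.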

In [6]:
class ProjectionLayer(nn.Module):
    """Converts d_model dimensions back to vocab_size."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """
        Initializes projection layer.

        Args:
            d_model: Dimensionality of model embeddings.
            vocab_size: Size of the vocabulary.
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.projection_layer = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through projection layer.

        Args:
            input_tensor: Input tensor to the projection layer.

        Returns:
            Logits of shape (..., vocab_size).
        """
        return self.projection_layer(input_tensor)

class Transformer(nn.Module):
    """Transformer model with encoder and decoder blocks."""

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        src_seq_len: int,
        tgt_seq_len: int,
        num_encoders: int,
        num_decoders: int,
        d_model: int,
        d_ff: int,
        h: int,
        dropout_rate: float,
    ) -> None:
        """
        Initializes transformer model.

        Args:
            src_vocab_size: Size of source vocabulary.
            tgt_vocab_size: Size of target vocabulary.
            src_seq_len: Maximum source sequence length.
            tgt_seq_len: Maximum target sequence length.
            num_encoders: Number of encoder blocks.
            num_decoders: Number of decoder blocks.
            d_model: Dimensionality of model embeddings.
            d_ff: Dimensionality of feed-forward layer.
            h: Number of attention heads.
            dropout_rate: Rate of dropout for regularization.
        """
        super().__init__()

        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.src_seq_len = src_seq_len
        self.tgt_seq_len = tgt_seq_len
        self.num_encoders = num_encoders
        self.num_decoders = num_decoders
        self.d_model = d_model
        self.d_ff = d_ff
        self.h = h
        self.dropout_rate = dropout_rate

        self.src_embedding = InputEmbedding(
            d_model=self.d_model, vocab_size=self.src_vocab_size
        )
        self.tgt_embedding = InputEmbedding(
            d_model=self.d_model, vocab_size=self.tgt_vocab_size
        )
        self.src_positional_encoding = PositionalEncoding(
            d_model=self.d_model,
            sequence_length=self.src_seq_len,
            dropout_rate=self.dropout_rate,
        )
        self.tgt_positional_encoding = PositionalEncoding(
            d_model=self.d_model,
            sequence_length=self.tgt_seq_len,
            dropout_rate=self.dropout_rate,
        )

        self.encoder_layers = nn.ModuleList(
            [
                EncoderBlock(
                    d_model=self.d_model,
                    d_ff=self.d_ff,
                    h=self.h,
                    dropout_rate=self.dropout_rate,
                )
                for _ in range(self.num_encoders)
            ]
        )

        self.decoder_layers = nn.ModuleList(
            [
                DecoderBlock(
                    d_model=self.d_model,
                    d_ff=self.d_ff,
                    h=self.h,
                    dropout_rate=self.dropout_rate,
                )
                for _ in range(self.num_decoders)
            ]
        )

        self.projection_layer = ProjectionLayer(
            d_model=self.d_model, vocab_size=self.tgt_vocab_size
        )

    def encode(
        self, encoder_input: torch.Tensor, src_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Encodes source input tensor using encoder blocks.

        Args:
            encoder_input: Input tensor to the encoder.
            src_mask: Optional source mask tensor.

        Returns:
            Encoded tensor.
        """
        x = self.src_embedding(encoder_input)
        x = self.src_positional_encoding(x)

        for encoder_layer in self.encoder_layers:
            x = encoder_layer(input_tensor=x, mask=src_mask)

        return x

    def decode(
        self,
        decoder_input: torch.Tensor,
        encoder_output: torch.Tensor,
        src_mask: torch.Tensor | None = None,
        tgt_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Decodes target input tensor using decoder blocks.

        Args:
            decoder_input: Input tensor to the decoder.
            encoder_output: Output tensor from the encoder.
            src_mask: Optional source mask tensor.
            tgt_mask: Optional target mask tensor.

        Returns:
            Decoded tensor.
        """
        x = self.tgt_embedding(decoder_input)
        x = self.tgt_positional_encoding(x)

        for decoder_layer in self.decoder_layers:
            x = decoder_layer(
                decoder_input=x,
                encoder_output=encoder_output,
                src_mask=src_mask,
                tgt_mask=tgt_mask,
            )

        return x

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor | None = None,
        tgt_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Processes input and target sequences through the encoder
        and decoder, applying optional source and target masks.

        Args:
            src: Input sequence tensor.
            tgt: Target sequence tensor.
            src_mask: Optional mask for the input sequence.
            tgt_mask: Optional mask for the target sequence.

        Returns:
            Logits of shape (batch, target length, tgt_vocab_size); no softmax is applied.
        """
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        return self.projection_layer(decoder_output)