---
title: "Transformer"
description: "Implementing a Transformer Architecture from scratch in PyTorch"
date: "2025-08-09"
#date-modified: "2025-02-22"
#categories: [news]
bread-crumbs: true
back-to-top-navigation: true
toc: true
toc-depth: 3
#image: images/pizza-13601_256.gif
---

The Transformer architecture is the foundation of today's large language models (LLMs).
If you work in Generative AI or natural language processing (NLP), or aspire to, you need to understand it.

The Transformer is a neural network architecture mostly used for natural language processing (NLP) tasks and was introduced in the ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762) paper by Vaswani et al. in 2017.

<!-- Something about the Transformer architecture and the Attention mechanism -->

The core idea is that for any given position in a sequence, the transformer asks "What other positions in this sequence should I pay attention to?"

This article explains the core concepts of the Transformer architecture and its Python implementation from scratch using PyTorch.

## Motivation for the Transformer

Sequence-to-sequence (seq2seq) models (Sutskever et al., 2014) transform an input sequence into an output sequence, where both sequences can be of arbitrary length.
Therefore, they are useful for many NLP tasks, such as machine translation, question answering, or text summarization.

<!-- Motivation for RNNs -->
Early approaches for these models used **Recurrent Neural Networks (RNNs)**, which process sequences step by step while maintaining a hidden state that captures information from previous inputs.
This enabled memory of past information for tasks like language modeling and time series prediction.

<!-- Limitation of RNNs -->
However, vanilla RNNs suffer from the **vanishing gradient problem**, which makes it hard for them to retain information from earlier parts of long input sequences.
During backpropagation through time, gradients are computed by repeatedly multiplying small partial derivatives as they flow backward through the network.
This chain of multiplications makes the gradients for earlier time steps vanishingly small, so the corresponding weights barely update during training.
Because the network can't effectively learn how to use information from earlier time steps, it can't remember that information when making predictions.

![](./images/translation_seq2seq.png)

<!-- Motivation for LSTMs -->
**Long Short-Term Memory networks (LSTMs)** are a type of RNN developed specifically to address this limitation. LSTMs use gating mechanisms to selectively remember and forget information over long sequences.
Although they learn long-range dependencies better than vanilla RNNs, they still process sequences sequentially, which makes training inefficient.

<!-- Motivation for Transformer -->
The Transformer architecture was motivated by the need to overcome these sequential processing limitations while maintaining the ability to capture long-term dependencies. 
<!-- TODO: Add something more here... -->

<!-- The difference between RNNs and Transformers -->
<!-- Benefits of Transformers over RNNs -->
- **Training efficiency:** The key difference between RNN-based models (vanilla RNNs, LSTMs, GRUs) and Transformer-based models is whether they process the tokens in a sequence step by step (sequentially) or simultaneously (in parallel). Parallel processing enables faster training on modern hardware.
- **Handling of long sequences:** Transformers use self-attention to relate any two positions in a sequence directly, regardless of their distance. This makes them better at capturing relationships between distant tokens.

<!-- When would you use which -->
Today, Transformers dominate most sequence-to-sequence tasks because they outperform RNNs.
Therefore, they are used in most generative AI systems.
However, thanks to their simpler architecture, RNNs are still useful for simple NLP tasks, for time series forecasting, or when computational resources are limited.

## Transformer Architecture Overview

::: {.callout-warning}
Interview Questions:

- *Walk through the forward pass of a transformer block*
:::

The original Transformer has an encoder-decoder architecture. Think of it like a human translator who first reads and comprehends the entire English sentence, then writes the French translation using that understanding:


- **Encoder**: Processes the input sequence $(x_1, \ldots, x_n)$ and creates rich representations in the form of vector embeddings $\mathbf{z} = (z_1, \ldots, z_n)$ ("I'll read the English and create a rich understanding")
- **Decoder**: Generates the output sequence $(y_1, \ldots, y_m)$ one token at a time, attending to both its own partial output and the encoder's representations ("I'll use that understanding to generate French")

![Transformer architecture overview diagram showing encoder-decoder structure.](images/transformer_encoders_decoders.png)

Both the encoder and the decoder are stacks of $N$ repeated layers: encoder layers and decoder layers, respectively.
In the original paper, $N=6$.

The output of each encoder layer is passed as input to the next encoder layer.
The output of the last encoder layer is passed as input to every decoder layer in the decoder stack.

The encoder and decoder layers are all identical in structure but don't share any weights.
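A minimal sketch of how such a stack can be built in PyTorch. `make_layer()` is a hypothetical factory that returns a plain `nn.Linear` as a stand-in for a full encoder layer; constructing each layer separately gives $N$ modules with the same structure but independently initialized, unshared weights, and `nn.ModuleList` registers all of them as parameters.

```python
import torch
import torch.nn as nn

N, d_model = 6, 512  # N = 6 layers, as in the original paper

def make_layer():
    # Hypothetical stand-in for an encoder layer; any nn.Module that maps
    # [batch, seq_len, d_model] -> [batch, seq_len, d_model] works the same way
    return nn.Linear(d_model, d_model)

# N separately constructed layers: same structure, independent weights
encoder_stack = nn.ModuleList([make_layer() for _ in range(N)])

x = torch.randn(2, 10, d_model)  # [batch, seq_len, d_model]
for layer in encoder_stack:
    x = layer(x)                 # each layer's output is the next layer's input

print(len(encoder_stack), x.shape)                         # 6 torch.Size([2, 10, 512])
print(encoder_stack[0].weight is encoder_stack[1].weight)  # False: no shared weights
```

Building the layers in a loop, rather than placing the same module object in the list $N$ times, is what keeps the weights separate: reusing one instance would share its parameters across the whole stack.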

![](images/transformer_encoders_decoder_stacks.png)

At a more detailed level, the Transformer is composed of the following core components:

![Detailed Transformer architecture with encoder and decoder layers, self-attention, and feed-forward networks](images/architecture_detailed.png)

::: {.callout-warning}
What to do with this?

- Input Embeddings (Inputs, Input embedding, output embedding, outputs shifted right)
- Positional Encodings: to maintain sequence order without recurrence
- Projection Layer (Linear and Softmax)
:::

### Encoder Layer
Each encoder layer has two sub-layers:

1. **Multi-head self-attention layer** (`MultiHeadAttention`): helps the encoder look at other words in the input sentence as it encodes a specific word.
2. **Position-wise fully connected feed-forward network** (`FeedForward`): applies the same two-layer network to each position separately and identically, further transforming each token's representation.

> We employ residual connections around each of the sub-layers, followed by layer normalization.

In code, layer normalization is implemented with `nn.LayerNorm`.


![](images/encoder-architecture.png)
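To make the encoder layer's forward pass concrete, here is a minimal sketch. It assumes PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the `MultiHeadAttention` class implemented later in this article and an `nn.Sequential` block as the feed-forward network; the class name `EncoderLayerSketch` is made up for illustration, and the default sizes follow the original paper's base model.

```python
import torch
import torch.nn as nn

class EncoderLayerSketch(nn.Module):
    """Post-norm encoder layer: LayerNorm(x + Dropout(Sublayer(x))) around each sub-layer."""

    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        # Stand-in for the from-scratch MultiHeadAttention class
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        # Sub-layer 1: multi-head self-attention (query = key = value = x)
        attn_out, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + self.dropout(attn_out))
        # Sub-layer 2: position-wise feed-forward network
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x

layer = EncoderLayerSketch()
x = torch.randn(2, 10, 512)  # [batch, seq_len, d_model]
print(layer(x).shape)        # torch.Size([2, 10, 512])
```

Because every sub-layer preserves the `[batch, seq_len, d_model]` shape, layers like this one can be stacked directly on top of each other.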

### Decoder Layer
Each decoder layer has the following sub-layers:

- **Masked multi-head self-attention**:
    - self-attention over the same (target) sequence, masked so that each position can only attend to itself and earlier positions
    - `query = key = value = target_input`
- **Multi-head cross-attention**:
    - helps the decoder focus on relevant parts of the input sentence
    - `query = target_input`: what we're generating
    - `key = value = encoder_output`: what we can look at
- **Feed-forward layer**

Similar to the encoder, we employ residual connections around each of the sub-layers, followed by layer normalization.


![](images/decoder-architecture.png)
<!-- !["Placeholder"](images/placeholder.png) -->
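A minimal sketch of a decoder layer's forward pass, again assuming `nn.MultiheadAttention` as a stand-in for the from-scratch attention class built later. The name `DecoderLayerSketch` is hypothetical. Note that `nn.MultiheadAttention` uses `True` in a boolean `attn_mask` to mark positions that may *not* be attended to, which is the opposite of the convention used by this article's own masking helpers.

```python
import torch
import torch.nn as nn

class DecoderLayerSketch(nn.Module):
    """Post-norm decoder layer: masked self-attention, cross-attention, feed-forward."""

    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model)
        )
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(3)])
        self.dropout = nn.Dropout(dropout)

    def forward(self, target, encoder_output):
        seq_len = target.size(1)
        # True above the diagonal = "future" positions that may NOT be attended to
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=target.device), diagonal=1
        )
        # Sub-layer 1: masked self-attention (query = key = value = target)
        out, _ = self.self_attn(target, target, target, attn_mask=causal_mask)
        x = self.norms[0](target + self.dropout(out))
        # Sub-layer 2: cross-attention (query = decoder, key = value = encoder output)
        out, cross_weights = self.cross_attn(x, encoder_output, encoder_output)
        x = self.norms[1](x + self.dropout(out))
        # Sub-layer 3: position-wise feed-forward network
        x = self.norms[2](x + self.dropout(self.ffn(x)))
        return x, cross_weights

layer = DecoderLayerSketch().eval()       # eval() disables dropout
target = torch.randn(2, 7, 512)           # [batch, target_len, d_model]
encoder_output = torch.randn(2, 10, 512)  # [batch, source_len, d_model]
out, cross_weights = layer(target, encoder_output)
print(out.shape, cross_weights.shape)     # torch.Size([2, 7, 512]) torch.Size([2, 7, 10])
```

The cross-attention weights have one row per target position and one column per source position, which is how the decoder can work with a source sequence of a different length.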

## Implementing the Transformer from Scratch

::: {.callout-tip}
A Note on "from scratch"

Note that PyTorch has the [class `nn.Transformer` and its components](https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/modules/transformer.py) already built-in.
The goal of this article, however, is to implement the Transformer from scratch to gain a better understanding of it.

We will, however, use some built-in PyTorch building blocks that are not specific to the Transformer:

`nn.Linear`, `nn.Embedding`, `nn.LayerNorm`, `nn.Dropout`, `torch.matmul`, and the tensor methods `relu()` and `softmax()`.
:::

In [None]:
import torch
import torch.nn as nn
import math

### Embeddings

The transformer uses **learned embeddings** to represent the discrete tokens of text as numerical input for the neural network.
Each unique token is mapped to a dense vector of fixed dimension $d_{\text{model}}$ (typically 256, 512, or 768 dimensions).
During training, these embedding vectors are learned to capture semantic relationships between tokens.

![](images/embeddings.png)


::: {.callout-tip}
Early NLP approaches created vector representations of discrete tokens by using one-hot encoding of the entire vocabulary. 
This resulted in large, sparse vectors with the dimension of the vocabulary size (e.g., 50,000 tokens).
These numerical representations were computationally inefficient and didn't carry any semantic meaning.
:::
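The relationship between the two representations can be sketched in a few lines (toy sizes chosen for illustration): multiplying a one-hot vector by an embedding matrix simply selects one row, so a learned embedding lookup gives the same result as the one-hot product without ever materializing the large, sparse vectors.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, d_model = 10, 4
embedding = nn.Embedding(vocab_size, d_model)  # dense, learned lookup table
ids = torch.tensor([3, 7, 3])                  # three tokens (token 3 appears twice)

one_hot = F.one_hot(ids, num_classes=vocab_size).float()  # [3, 10], mostly zeros
via_matmul = one_hot @ embedding.weight                   # [3, 4]
via_lookup = embedding(ids)                               # [3, 4], direct row lookup

print(torch.allclose(via_matmul, via_lookup))  # True
```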


The original paper multiplies the embeddings by $\sqrt{d_{model}}$. Common explanations for this scaling are that it:

- keeps the embedding values and the positional encoding values (introduced in the next section) at roughly the same magnitude, so that when they are added together, neither dominates the other
- helps with **training stability**

```python
src_emb = self.src_embedding(src) * math.sqrt(self.d_model)
```

The embedding matrix has shape `(vocab_size, d_model)`: one learned `d_model`-dimensional row per token in the vocabulary.
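A minimal sketch of the embedding step with illustrative sizes (a vocabulary of 100 tokens, $d_{model}=512$, and a batch of 32 sequences of 10 tokens each), including the $\sqrt{d_{model}}$ scaling:

```python
import math
import torch
import torch.nn as nn

vocab_size, d_model = 100, 512
embedding = nn.Embedding(vocab_size, d_model)  # learned lookup table
print(embedding.weight.shape)                  # torch.Size([100, 512])

token_ids = torch.randint(0, vocab_size, (32, 10))  # [batch_size, seq_len]
x = embedding(token_ids) * math.sqrt(d_model)       # scale as in the original paper
print(x.shape)                                      # torch.Size([32, 10, 512])
```

Each token ID selects one row of the matrix, so the output simply gains a trailing `d_model` dimension.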

<!-- Example 
- Suppose each embedding vector is of 512 dimension
- suppose our vocab size is 100, 
- then our embedding matrix will be of size 100x512.
This embedding matrix will be learned during training.
During inference each word will be mapped to the corresponding 512 dimensional vector.
Suppose we have batch size of 32 and sequence length of 10 words, The output will be 32x10x512.
-->

<!-- Difference between input and output embeddings -->

**Input embeddings:**

- The embedding only happens in the bottom-most encoder.
- Each encoder receives a list of vectors, each of size $d_{model} = 512$:
    - In the bottom encoder, these are the token embeddings.
    - In the other encoders, they are the outputs of the encoder directly below.
    - The length of this list is a hyperparameter we can set, typically the length of the longest sentence in the training dataset.
- The token in each position flows through its own path in the encoder:
    - There are dependencies between these paths in the self-attention layer.
    - The feed-forward layer does not have those dependencies, so the paths can be executed in parallel there.

**Output embeddings:**

1. The output of each step is fed to the bottom decoder in the next time step.
2. The decoders bubble up their decoding results, just like the encoders did.
3. Just as with the encoder inputs, we embed the decoder inputs and add positional encodings to indicate the position of each word.
4. These steps repeat until a special end-of-sequence symbol (e.g., `<EOS>`) is generated, indicating that the decoder has completed its output. This step-by-step generation is called **auto-regressive** decoding.
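The generation loop described above can be sketched as a greedy decoder. This is a minimal illustration under stated assumptions: `step_fn` is a hypothetical stand-in for a full forward pass (embedding, positional encoding, decoder stack, and output projection) that returns next-token logits, the start and end token IDs are made up, and greedy `argmax` selection is only one of several possible decoding strategies.

```python
import torch

def greedy_decode(step_fn, bos_id, eos_id, max_len=20):
    """Generate token IDs one at a time until EOS is produced or max_len is reached."""
    tokens = torch.tensor([bos_id])             # start with the begin-of-sequence token
    for _ in range(max_len):
        logits = step_fn(tokens)                # [vocab_size] scores for the next token
        next_id = int(torch.argmax(logits))     # greedy choice: highest-scoring token
        tokens = torch.cat([tokens, torch.tensor([next_id])])  # feed it back as input
        if next_id == eos_id:
            break
    return tokens.tolist()

# Toy step function: always predicts "previous ID + 1"
vocab_size = 5
def toy_step(tokens):
    logits = torch.zeros(vocab_size)
    logits[(tokens[-1].item() + 1) % vocab_size] = 1.0
    return logits

print(greedy_decode(toy_step, bos_id=0, eos_id=3))  # [0, 1, 2, 3]
```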

### Positional Encodings

Positional encodings are vectors that tell the transformer where each token sits in a sequence.
The motivation for positional encodings is that, unlike RNNs, which process tokens one by one, the Transformer looks at all tokens in a sequence at the same time.
This makes the Transformer fast, but it also means that, without positional information, it can't distinguish between "The cat sat on the mat" and "The mat sat on the cat" because it just sees the same [bag of words](bag_of_words.ipynb).

To overcome this problem, positional information is added directly to the input embeddings in the form of positional encoding vectors. These vectors have the same dimension $d_{model}$ (`d_model`) as the input embeddings, so the two can be summed element-wise.
This gives each token a combined representation that captures both its meaning and its location in the sequence.
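To see why positional information is needed, the sketch below applies plain self-attention with no learned weights (a simplifying assumption) to random vectors. Shuffling the input tokens merely shuffles the outputs in the same way, so the mechanism on its own cannot tell word orders apart. Once position-dependent vectors are added (random stand-ins here, not the sinusoidal encodings), this symmetry breaks.

```python
import math
import torch

torch.manual_seed(0)
seq_len, d_model = 5, 8
x = torch.randn(seq_len, d_model)  # token embeddings without any position information

def self_attention(x):
    # Scaled dot-product self-attention with Q = K = V = x (no learned projections)
    scores = x @ x.T / math.sqrt(x.size(-1))
    return scores.softmax(dim=-1) @ x

perm = torch.tensor([4, 2, 0, 3, 1])  # reorder the tokens ("shuffle the words")

# Without positions: shuffling the input only shuffles the output
out = self_attention(x)
print(torch.allclose(out[perm], self_attention(x[perm]), atol=1e-5))  # True

# With position vectors added, the result is no longer just a reordering
pe = torch.randn(seq_len, d_model)  # stand-in for positional encodings
out_pe = self_attention(x + pe)
out_pe_shuffled = self_attention(x[perm] + pe)
```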

<!-- TODO: also add the vector visuals? -->
![](images/positional_encodings.png)
<!-- Visual comment 
for eg: if we have batch size of 32 and seq length of 10 and let embedding dimension be 512.
Then we will have embedding vector of dimension 32 x 10 x 512.
Similarly we will have positional encoding vector of dimension 32 x 10 x 512. Then we add both.
-->

The original Transformer uses **sinusoidal positional encodings** based on sine and cosine functions of different frequencies:
$$
PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)
$$
$$
PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
$$

where $pos$ is the token position and $i$ indexes the dimension pairs, so dimensions $2i$ and $2i+1$ share the same frequency.
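A direct, loop-based translation of the two formulas for a single position, written for readability rather than speed (the helper name `positional_encoding` is made up for this illustration):

```python
import math

def positional_encoding(pos, d_model):
    """Sinusoidal encoding of one position, computed straight from the formulas."""
    pe = [0.0] * d_model
    for i in range(d_model // 2):
        angle = pos / (10000 ** (2 * i / d_model))
        pe[2 * i] = math.sin(angle)      # even dimensions use sine
        pe[2 * i + 1] = math.cos(angle)  # odd dimensions use cosine
    return pe

print(positional_encoding(0, 4))                         # [0.0, 1.0, 0.0, 1.0]
print([round(v, 4) for v in positional_encoding(1, 4)])  # [0.8415, 0.5403, 0.01, 1.0]
```

Position 0 always encodes to alternating zeros and ones, and the higher dimension pairs change much more slowly with position than the first pair does.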

> The positional encoding adds a sine wave based on position. The frequency and offset of the wave are different for each dimension.

That is, each dimension of the positional encoding corresponds to a sinusoid.  
The wavelengths form a geometric progression from $2\pi$ to $10000 \cdot 2\pi$.  
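The wavelength of dimension pair $i$ is $2\pi \cdot 10000^{2i/d_{model}}$. This quick numeric check (with $d_{model}=512$, an illustrative choice) confirms that consecutive wavelengths share a constant ratio, which is what makes the progression geometric, and that they run from $2\pi$ up to just below $10000 \cdot 2\pi$.

```python
import math

d_model = 512
# Wavelength of the sinusoid for dimension pair i
wavelengths = [2 * math.pi * 10000 ** (2 * i / d_model) for i in range(d_model // 2)]

print(round(wavelengths[0], 4))               # 6.2832, i.e. 2*pi
print(wavelengths[-1] < 10000 * 2 * math.pi)  # True: the longest stays below 10000 * 2*pi
ratios = [b / a for a, b in zip(wavelengths, wavelengths[1:])]
print(max(ratios) - min(ratios) < 1e-9)       # True: constant ratio, i.e. geometric
```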

These formulas create a unique pattern for each position by using different frequencies across the embedding dimensions.
Position 0 has one pattern, position 1 has a slightly different pattern, and so on.
The authors chose this function because they hypothesized it would let the model easily learn to attend by relative positions: for any fixed offset $k$, $PE_{pos+k}$ can be represented as a linear function of $PE_{pos}$.
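The linear relationship between $PE_{pos+k}$ and $PE_{pos}$ follows from the angle-addition identities $\sin(a+b) = \sin a \cos b + \cos a \sin b$ and $\cos(a+b) = \cos a \cos b - \sin a \sin b$: each (sine, cosine) pair is rotated by an angle that depends only on the offset $k$, never on $pos$. The sketch below checks this numerically; `sinusoidal_pe` and `shift` are helper names made up for this illustration.

```python
import math
import torch

def sinusoidal_pe(max_len, d_model):
    """Sinusoidal encodings (same formulas as above), plus the per-pair frequencies."""
    pos = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)  # [max_len, 1]
    omega = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64)
                      * -(math.log(10000.0) / d_model))            # [d_model / 2]
    pe = torch.zeros(max_len, d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(pos * omega)
    pe[:, 1::2] = torch.cos(pos * omega)
    return pe, omega

def shift(pe_row, omega, k):
    """Map PE[pos] to PE[pos + k] using only k: a 2x2 rotation per (sin, cos) pair."""
    sin_a, cos_a = pe_row[0::2], pe_row[1::2]
    cos_b, sin_b = torch.cos(omega * k), torch.sin(omega * k)
    out = torch.empty_like(pe_row)
    out[0::2] = sin_a * cos_b + cos_a * sin_b  # sin(a + b)
    out[1::2] = cos_a * cos_b - sin_a * sin_b  # cos(a + b)
    return out

pe, omega = sinusoidal_pe(max_len=100, d_model=16)
print(torch.allclose(shift(pe[10], omega, 7), pe[17]))  # True
print(torch.allclose(shift(pe[50], omega, 7), pe[57]))  # True: same map for every position
```

Whether a trained model actually exploits this property is an empirical question; the code only shows that the linear map exists.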

<!-- !["Placeholder"](images/placeholder.png) -->

The intuition is that adding these values to the embeddings provides meaningful distances between the embedding vectors once they are projected into query, key, and value (Q/K/V) vectors and used in dot-product attention.

In addition, we apply dropout to the sums of the embeddings and the positional encodings in both the encoder and decoder stacks.  
For the base model, we use a rate of $P_{drop}=0.1$.

In [None]:
class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        """
        Args:
            d_model: dimension of embeddings
            max_seq_len: maximum sequence length
            dropout: dropout rate, the original paper uses 0.1
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Precompute the positional encodings once for all positions up to max_seq_len
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1)

        # Denominator term 1 / 10000^(2i / d_model), computed in log space
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))

        # Apply sin to even indices (0, 2, 4, ...)
        pe[:, 0::2] = torch.sin(position * div_term)

        # Apply cos to odd indices (1, 3, 5, ...)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add batch dimension
        pe = pe.unsqueeze(0)

        # Register as buffer (saved with model, not trained)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x shape : [batch_size, seq_len, d_model]
        seq_len = x.size(1)
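        # requires_grad_(False) makes explicit that no gradients flow into the fixed
        # encodings; as a registered buffer, pe does not require gradients anyway
        x = x + self.pe[:, :seq_len].requires_grad_(False)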
        return self.dropout(x)

::: {.callout-tip}
Note that there are other ways to add positional information to input tokens, such as learned positional embeddings or rotary positional embeddings (RoPE).
<!-- Read more on [positional encoding](positional_encoding.ipynb) -->
:::

### Multi-Head Attention

The main component in the Transformer is the **multi-head self-attention mechanism**.

<!-- Explain the attention mechanism in simple terms -->
<!-- Attention in simple terms -->
The **attention mechanism** allows the model to look at ("pay attention to") different tokens to gather relevant information when processing a sequence.
Instead of treating every token equally, the model learns which tokens are most relevant to each other for building good representations.
It does this by computing attention weights that determine how relevant each part of the sequence is.

<!-- Example -->
Take the following sentence for example:

> “The animal didn't cross the street because **it** was too tired”

When processing the token "it", the model must determine whether it refers to the "street" or to the "animal".
The self-attention mechanism allows the model to resolve this pronoun ambiguity and associate the token "it" with the token "animal".

<!-- Attention visualization -->
![](./images/attention_visualized.png)

<!--
![](images/transformer_self-attention_visualization.png)
Be sure to check out the Tensor2Tensor notebook where you can load a Transformer model, and examine it using this interactive visualization.
- [] Add visual of formula here
-->

::: {.callout-tip}
<!-- Other notes -->
Note that attention is not a new concept. It was introduced by [Bahdanau 2014](https://arxiv.org/abs/1409.0473).

There are many different attention mechanisms:

- Content-based attention (Graves 2014)
- [Additive: Bahdanau 2015](https://arxiv.org/abs/1409.0473):
    Additive attention computes the compatibility function using a feed-forward network with a single hidden layer.
- Location-based (Luong 2015)
- General (Luong 2015)
- Dot-product (multiplicative) (Luong 2015):
    Dot-product attention is identical to the Transformer's attention, except for the scaling factor of $\frac{1}{\sqrt{d_k}}$.
- Scaled dot-product (Vaswani 2017)

The two most commonly used attention functions are additive attention, and dot-product (multiplicative) attention. 
While the two are similar in theoretical complexity, dot-product attention is much faster and more space-efficient in practice, since it can be implemented using highly optimized matrix multiplication code.
:::

#### Scaled dot-product attention

The Transformer's particular attention is called **Scaled Dot-Product Attention**.

For each of the encoder's input vectors (here, the embedding of each token), the model creates three vectors:

- a query vector and a key vector, each of dimension $d_k$
- a value vector of dimension $d_v$

<!-- How are these vectors generated -->
These vectors are created by multiplying the embedding by three weight matrices that are learned during training.
<!-- What are these vectors? -->
They're abstractions that are useful for calculating and thinking about attention.

The Transformer views the encoded representation of the input as a set of key-value pairs $(K, V)$.
An attention function can be described as mapping a query and a set of key-value pairs to an output. 
The output vector is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.


We compute the dot products of the query with all keys, divide each by $\sqrt{d_k}$, and apply a softmax function to obtain the weights on the values.
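A tiny worked example of this computation for a single query, using made-up numbers with $d_k = 2$ and one-dimensional values:

```python
import math
import torch

q = torch.tensor([[1.0, 0.0]])              # one query, d_k = 2
K = torch.tensor([[1.0, 0.0],               # key 1: points in the same direction as q
                  [0.0, 1.0],               # key 2: orthogonal to q
                  [1.0, 1.0]])              # key 3
V = torch.tensor([[10.0], [20.0], [30.0]])  # one value per key

scores = q @ K.T / math.sqrt(K.size(-1))    # dot products scaled by 1/sqrt(d_k): [0.7071, 0, 0.7071]
weights = scores.softmax(dim=-1)            # all positive, sum to 1
output = weights @ V                        # weighted sum of the values

print(weights)  # keys 1 and 3 get equal, larger weights than key 2
print(output)   # about 20.0: by symmetry, the weighted average is centered on 20
```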

<!-- Why do we scale the dot-product attention?-->
For small values of $d_k$, additive and dot-product attention perform similarly, but additive attention outperforms dot-product attention without scaling for larger values of $d_k$ [(cite)](https://arxiv.org/abs/1703.03906).
The authors suspect that for large values of $d_k$, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients.
To illustrate why the dot products get large, assume that the components of $q$ and $k$ are independent random variables with mean $0$ and variance $1$.
Then their dot product, $q \cdot k = \sum_{i=1}^{d_k} q_ik_i$, has mean $0$ and variance $d_k$.
To counteract this effect, the dot products are scaled by $\frac{1}{\sqrt{d_k}}$, which brings their variance back to $1$ regardless of $d_k$.
The scaling factor is therefore not arbitrary: it keeps the softmax out of its saturated regions, where the gradients become too small for stable training.
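The variance argument can be checked numerically. This sketch draws random queries and keys with independent standard-normal components (the same assumption as in the argument) for $d_k = 64$, the per-head dimension of the original base model:

```python
import torch

torch.manual_seed(0)
d_k, n = 64, 100_000
q = torch.randn(n, d_k)  # components ~ N(0, 1)
k = torch.randn(n, d_k)

dots = (q * k).sum(dim=-1)               # n independent dot products q . k
print(dots.var().item())                 # close to d_k = 64
print((dots / d_k ** 0.5).var().item())  # close to 1 after scaling by 1/sqrt(d_k)
```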

In practice, we compute the attention function on a set of queries simultaneously, packed together into a matrix $Q$.  
The keys and values are also packed together into matrices $K$ and $V$.  
We compute the matrix of outputs as:

<!-- - *Derive the attention formula. How does it work mathematically? -->
$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$

- Query ($Q$): What we're looking for
- Key ($K$): What we're looking at
- Value ($V$): What we actually use

1. Compute similarity scores between queries and keys using dot products.
2. Divide by $\sqrt{d_k}$ to keep the softmax from saturating (which would make its gradients very small) when $d_k$ is large.
3. Optionally apply a mask to ignore padded tokens or to implement causal attention.
4. Apply softmax to get attention weights: it normalizes the scores so they're all positive and sum to 1.
5. Use these weights to compute a weighted average of the values.

![](./images/sclaed_dot_product_attention.png)

<!-- *How would you implement multi-head attention from scratch?* -->

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    """
    Compute 'Scaled Dot Product Attention'
    Attention with optional masking
    mask shape: [batch_size, seq_len, seq_len] or broadcastable
    """

    # Compute scaled attention scores
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply mask before softmax (set masked positions to large negative value)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Softmax over the last dimension
    p_attn = scores.softmax(dim=-1)

    # Apply dropout to the attention weights after the softmax
    if dropout is not None:
        p_attn = dropout(p_attn)

    return torch.matmul(p_attn, value), p_attn

#### Multiple attention heads

- Instead of computing attention once, the multi-head attention mechanism runs scaled dot-product attention multiple times in parallel.
- The independent attention outputs are concatenated and linearly transformed into the expected dimensions.

<!-- (IMG: multi-head encoder-decoder attention layer) -->

<!-- Motivation: Why do transformers use multiple attention heads? What does each head learn? -->
<!-- (I assume the motivation is because ensembling always helps? - Lilian Weng) -->

Multi-head attention improves the performance of the attention layer in two ways:

1. It expands the model’s ability to focus on different positions. With a single head, each output vector (e.g., $z_1$ for the first token) contains a little bit of every other token's encoding, but it could be dominated by the token itself. If we’re translating a sentence like “The animal didn’t cross the street because it was too tired”, it would be useful to know which word “it” refers to.
2. It gives the attention layer multiple “representation subspaces”. As we’ll see next, with multi-headed attention we have not only one, but multiple sets of Query/Key/Value weight matrices (the Transformer uses eight attention heads, so we end up with eight sets for each encoder/decoder). Each of these sets is randomly initialized. Then, after training, each set is used to project the input embeddings (or vectors from lower encoders/decoders) into a different representation subspace.

The outputs of all heads are concatenated and then multiplied by an additional weight matrix $W^O$.

![Multi-Head Attention mechanism diagram: visualization of how multiple attention heads process input sequences in parallel, each focusing on different representation subspaces, then concatenating their outputs.](images/multihead-attention.png)

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.

$$
\begin{aligned}
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)W^O \\
\text{where}~\mathrm{head}_i &= \mathrm{Attention}(QW^Q_i, KW^K_i, VW^V_i)
\end{aligned}
$$


Where the projections are parameter matrices $W^Q_i \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W^K_i \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W^V_i \in \mathbb{R}^{d_{\text{model}} \times d_v}$ and $W^O \in \mathbb{R}^{hd_v \times d_{\text{model}}}$.

The original paper uses $h=8$ parallel attention layers, or heads.
For each head, $d_k=d_v=d_{\text{model}}/h=64$.
Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.
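A common implementation trick is to compute all $h$ per-head projections $W^Q_i$ with one large $d_{model} \times d_{model}$ linear layer and then split its output into heads. The sketch below (sizes from the original base model, random inputs) checks that this gives the same result as $h$ separate $d_{model} \times d_k$ projections, which also shows why the total parameter count does not grow with $h$.

```python
import torch
import torch.nn as nn

d_model, h = 512, 8
d_k = d_model // h                          # 64, as in the original paper
torch.manual_seed(0)
x = torch.randn(2, 10, d_model)             # [batch, seq_len, d_model]
W_q = nn.Linear(d_model, d_model, bias=False)

# Option A: one big projection, then split the last dimension into h heads
q_all = W_q(x).view(2, 10, h, d_k).transpose(1, 2)  # [batch, h, seq_len, d_k]

# Option B: h separate d_model x d_k projections W^Q_i (row slices of the same weight)
q_heads = [x @ W_q.weight[i * d_k:(i + 1) * d_k].T for i in range(h)]
q_sep = torch.stack(q_heads, dim=1)                  # [batch, h, seq_len, d_k]

print(d_k, torch.allclose(q_all, q_sep, atol=1e-5))  # 64 True
```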

**Multi-head attention** runs multiple attention mechanisms in parallel, each focusing on different aspects of the relationships, then concatenates and projects the results.

<!-- Attention vs. Multi-Head attention -->
---

The Transformer uses multi-head attention in three different ways:

![Difference between Self-Attention in Encoder and Decoder vs. Cross-Attention](images/transformer_self_attention_vs_cross_attention.png)

#### Self-Attention

The encoder contains self-attention layers.  In a self-attention layer all of the keys, values and queries come from the same place, in this case, the output of the previous layer in the encoder.  Each position in the encoder can attend to all positions in the previous layer of the encoder.

#### Masked Self-Attention
Self-attention layers in the decoder allow each position in the decoder to attend only to positions in the decoder up to and including that position. During training, the decoder sees the whole target sequence at once, so without masking, each token could simply look at the future tokens it is supposed to predict. To preserve the auto-regressive property, we need to prevent this leftward information flow. We implement this inside scaled dot-product attention by masking out (setting to $-\infty$) all values in the input of the softmax that correspond to illegal connections, that is, to future positions.

The encoder starts by processing the input sequence.
The output of the top encoder is then transformed into a set of attention vectors $K$ and $V$.
These are used by each decoder layer in its “encoder-decoder attention” layer, which helps the decoder focus on appropriate places in the input sequence.

This masking, combined with the fact that the output embeddings are offset by one position, ensures that the predictions for position $i$ can depend only on the known outputs at positions less than $i$.

> The attention mask shows which positions each target word (row) is allowed to look at (column). Words are blocked from attending to future words during training.

<!-- TODO: add visualization of mask here -->
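A small sketch that prints the causal mask for a 4-token sequence and applies it to random attention scores (the random values are a stand-in for real query-key scores). It uses `True` for allowed positions, the same convention as `create_causal_mask` and `scaled_dot_product_attention` in this article, and fills blocked positions with $-\infty$ before the softmax.

```python
import torch

seq_len = 4
# True = allowed: row i (query position) may look at columns 0..i (key positions)
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
for row in mask.int().tolist():
    print(row)
# [1, 0, 0, 0]
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [1, 1, 1, 1]

scores = torch.randn(seq_len, seq_len)
weights = scores.masked_fill(~mask, float("-inf")).softmax(dim=-1)
print(weights)  # the upper triangle is exactly 0, and every row still sums to 1
```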

<!--  
- How does causal/masked attention work in decoder models
- Self-attention vs. Masked self-attention

-->

In [None]:
# "Mask out subsequent positions."
def create_causal_mask(seq_len):
    """Create causal mask to prevent attending to future positions"""
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask == 0  # Convert to boolean mask

#### Cross-Attention

The cross-attention mechanism, also often called "encoder-decoder attention", works just like the self-attention mechanism, except for where the queries, keys, and values come from. In self-attention, the queries, keys, and values all come from the same input sequence; in cross-attention, the queries come from a different sequence than the keys and values:

- Queries: come from the previous decoder layer
- Keys and values: come from the output of the encoder stack

This allows every position in the decoder to attend over all positions in the encoder sequence.
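A minimal shape-level sketch of cross-attention with a 5-token source and a 3-token target (random vectors, and the learned $W^Q$, $W^K$, $W^V$ projections are omitted for brevity as a simplifying assumption). The attention weights form a $3 \times 5$ matrix: one row per decoder position, spread over all encoder positions.

```python
import math
import torch

torch.manual_seed(0)
d_model = 8
encoder_output = torch.randn(5, d_model)  # 5 source tokens: provide keys and values
decoder_state = torch.randn(3, d_model)   # 3 target tokens so far: provide queries

Q, K, V = decoder_state, encoder_output, encoder_output  # learned projections omitted
weights = (Q @ K.T / math.sqrt(d_model)).softmax(dim=-1)
output = weights @ V

print(weights.shape)  # torch.Size([3, 5]): each target position attends over all 5 source positions
print(output.shape)   # torch.Size([3, 8]): one d_model vector per target position
```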

<!-- Explain cross-attention vs self-attention -->

**Implementation notes:**

- Supports both self-attention and cross-attention
- Handles different sequence lengths for the encoder and decoder

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # dimension per head

        # Linear projections for queries, keys, and values
        self.W_query = nn.Linear(d_model, d_model)
        self.W_key = nn.Linear(d_model, d_model)
        self.W_value = nn.Linear(d_model, d_model)

        # Output projection
        self.W_output = nn.Linear(d_model, d_model)

        # Dropout applied to the attention weights after the softmax
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        seq_len_query = query.size(1)
        seq_len_key = key.size(1)

        # Step 1: Linear projections for all heads at once
        # Shape: [batch_size, seq_len, d_model]
        Q = self.W_query(query)
        K = self.W_key(key)
        V = self.W_value(value)

        # Step 2: Reshape to separate heads
        # From [batch_size, seq_len, d_model] to [batch_size, seq_len, num_heads, d_k]
        Q = Q.view(batch_size, seq_len_query, self.num_heads, self.d_k)
        K = K.view(batch_size, seq_len_key, self.num_heads, self.d_k)
        V = V.view(batch_size, seq_len_key, self.num_heads, self.d_k)

        # Step 3: Transpose to [batch_size, num_heads, seq_len, d_k] for efficient computation
        Q = Q.transpose(1, 2)  # [batch_size, num_heads, seq_len_query, d_k]
        K = K.transpose(1, 2)  # [batch_size, num_heads, seq_len_key, d_k]
        V = V.transpose(1, 2)  # [batch_size, num_heads, seq_len_key, d_k]

        # Step 4: Apply scaled dot-product attention to each head
        # A mask of shape [batch_size, seq_len_query, seq_len_key] gets a head
        # dimension so that it broadcasts across all heads
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)
        attention_output, attention_weights = scaled_dot_product_attention(
            Q, K, V, mask=mask, dropout=self.dropout
        )
        # attention_output: [batch_size, num_heads, seq_len_query, d_k]

        # Step 5: Concatenate heads
        # Transpose back: [batch_size, seq_len_query, num_heads, d_k]
        attention_output = attention_output.transpose(1, 2)

        # Reshape to concatenate heads: [batch_size, seq_len_query, d_model]
        attention_output = attention_output.contiguous().view(
            batch_size, seq_len_query, self.d_model
        )

        # Step 6: Final linear projection
        output = self.W_output(attention_output)

        return output, attention_weights

In [None]:
# Example usage
def example_usage():
    batch_size, seq_len, d_model = 2, 10, 512
    num_heads = 8
    dropout = 0.1
    # Create sample input
    x = torch.randn(batch_size, seq_len, d_model)

    # Initialize multi-head attention
    mha = MultiHeadAttention(d_model, num_heads, dropout)

    # Self-attention (query, key, value are all the same)
    output, weights = mha(x, x, x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention weights shape: {weights.shape}")

    # For encoder-decoder attention, you'd use different inputs:
    # output, weights = mha(decoder_hidden, encoder_output, encoder_output)

example_usage()

### Position-wise Feed-Forward Networks

In both the encoder and the decoder, the output of the attention sub-layer(s) is fed to a feed-forward neural network.

In addition to attention sub-layers, each of the layers in our
encoder and decoder contains a fully connected feed-forward network,
which is applied to each position separately and identically.  This
consists of two linear transformations with a ReLU activation in
between.

$$\mathrm{FFN}(x)=\max(0, xW_1 + b_1) W_2 + b_2$$

While the linear transformations are the same across different
positions, they use different parameters from layer to
layer. Another way of describing this is as two convolutions with
kernel size 1.  The dimensionality of input and output is
$d_{\text{model}}=512$, and the inner-layer has dimensionality
$d_{ff}=2048$.
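Both descriptions can be checked with a short sketch (random weights and inputs, sizes from the original base model): applying the FFN to the whole sequence gives the same result as applying it to each position on its own, and the first linear layer matches an `nn.Conv1d` with `kernel_size=1` that reuses its weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff, seq_len = 512, 2048, 10
w_1, w_2 = nn.Linear(d_model, d_ff), nn.Linear(d_ff, d_model)

def ffn(x):
    # FFN(x) = max(0, x W_1 + b_1) W_2 + b_2
    return w_2(torch.relu(w_1(x)))

x = torch.randn(1, seq_len, d_model)  # [batch, seq_len, d_model]
full = ffn(x)

# "Position-wise": running the FFN on each position separately gives the same result
per_position = torch.stack([ffn(x[:, t]) for t in range(seq_len)], dim=1)
print(torch.allclose(full, per_position, atol=1e-4))  # True

# Equivalent view: a convolution with kernel size 1, reusing the first layer's weights
conv = nn.Conv1d(d_model, d_ff, kernel_size=1)
with torch.no_grad():
    conv.weight.copy_(w_1.weight.unsqueeze(-1))  # [d_ff, d_model] -> [d_ff, d_model, 1]
    conv.bias.copy_(w_1.bias)
conv_out = conv(x.transpose(1, 2)).transpose(1, 2)  # Conv1d expects [batch, channels, length]
print(torch.allclose(conv_out, w_1(x), atol=1e-4))  # True
```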

In [None]:
class FeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(self.w_1(x).relu()))

### Residual Connection and Layer normalization

When we look at the encoder and decoder blocks, we see several normalization layers called Add & Norm.

> We employ a [residual connection](https://arxiv.org/abs/1512.03385) around each of the two sub-layers, followed by [layer normalization](https://arxiv.org/abs/1607.06450).


**Residual connection**: A residual connection adds a sub-layer's input to its output, i.e. it computes `x + Sublayer(x)`. The Transformer uses one around each of its sub-layers. Its purpose is to make deep stacks of layers easier to train.

- The core idea: instead of having each layer learn a completely new representation, each layer learns only the changes or refinements to apply to the existing representation.
- Think of it like editing a document:
    - Without residual connections: each editor throws away the previous version and writes a completely new document from scratch.
    - With residual connections: each editor takes the existing document and only adds their improvements to it.


In the Transformer architecture, the output of each sub-layer (self-attention, encoder-decoder attention, or feed-forward) is added to that sub-layer's input in the Add & Norm step. Because of this shortcut, also called a skip connection, the gradient can flow directly back through the addition during backpropagation, which lets the Transformer train deep stacks of layers more effectively.
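The gradient shortcut can be illustrated with a toy scalar example. Each hypothetical sub-layer here is simply `Sublayer(x) = 0.1 * x`, a stand-in for a sub-layer whose local derivative is small; real attention and FFN sub-layers are of course more complex. Without residuals the gradient is a product of small factors; with residuals each factor becomes `1 + 0.1`.

```python
import torch

def stacked_gradient(num_layers, residual):
    """Gradient of the output w.r.t. the input through a stack of toy sub-layers.

    Each toy sub-layer computes 0.1 * h (hypothetical stand-in).
    """
    x = torch.tensor(1.0, requires_grad=True)
    h = x
    for _ in range(num_layers):
        sublayer_out = 0.1 * h
        # With a residual connection the input is added back: h + Sublayer(h)
        h = h + sublayer_out if residual else sublayer_out
    h.backward()
    return x.grad.item()

print(stacked_gradient(10, residual=False))  # about 1e-10 (0.1 ** 10): vanishes
print(stacked_gradient(10, residual=True))   # about 2.594 (1.1 ** 10): survives
```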

**Layer normalization**: Layer normalization rescales the activations of each token to a consistent range, which stabilizes and speeds up training.

The output of each sub-layer is $\mathrm{LayerNorm}(x +\mathrm{Sublayer}(x))$, where $\mathrm{Sublayer}(x)$ is the function implemented by the sub-layer itself.
We apply [dropout](http://jmlr.org/papers/v15/srivastava14a.html) to the output of each sub-layer, before it is added to the sub-layer input and normalized.
To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension $d_{\text{model}}=512$.


Layer normalization operates on each token's feature vector independently.
During the forward pass, it computes the mean and standard deviation of the input over the feature dimension ($d_{\text{model}}$).
It then normalizes the input by subtracting the mean and dividing by the standard deviation, where a small constant epsilon is added to the variance to avoid division by zero.
This results in a normalized output with mean 0 and standard deviation 1 for each token.

The normalized output is then scaled by a learnable parameter alpha (called `weight` in `nn.LayerNorm`) and shifted by a learnable parameter called bias. Both are adjusted during training. The result is a layer-normalized tensor whose scale stays consistent from layer to layer, while the learned scale and shift still let the model adjust the normalized values where that helps.
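The steps above can be written as a small module. This is a sketch: the class name `LayerNormalization` and the parameter names `alpha` and `bias` follow the prose, while the variance and epsilon handling follow `nn.LayerNorm` (biased variance, epsilon added inside the square root) so the two can be compared directly.

```python
import torch
import torch.nn as nn

class LayerNormalization(nn.Module):
    """Layer normalization over the last (feature) dimension."""

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(d_model))  # learnable scale
        self.bias = nn.Parameter(torch.zeros(d_model))  # learnable shift

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        # Biased variance (divide by d_model), as nn.LayerNorm does
        var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.alpha * x_hat + self.bias

torch.manual_seed(0)
x = torch.randn(2, 3, 8)  # [batch, seq_len, d_model]
ours = LayerNormalization(8)
reference = nn.LayerNorm(8)  # default eps=1e-5; weight=1, bias=0 at init
print(torch.allclose(ours(x), reference(x), atol=1e-5))  # True
```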

Layer normalization helps the transformer learn better and faster.

Think of it like this: imagine you're trying to learn math, but every day the teacher uses completely different scales - sometimes numbers from 1-10, sometimes 1-1000, sometimes -500 to +500. It would be really hard to focus on the actual math concepts because you'd be constantly adjusting to these different scales.

Layer normalization solves this problem by keeping all the numbers in a consistent, predictable range. For each token, it looks at all of its features and normalizes them to mean 0 and standard deviation 1. This way, each layer in the transformer gets inputs that are always in the same comfortable range.

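This makes training much more stable and allows the model to learn the important patterns instead of getting distracted by wildly varying number scales.

In code, the Add & Norm step after a sub-layer (here, the encoder's self-attention) is a single line: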


```python
x = self.layer_norm_1(x + self.dropout(attention_output))
```

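Where the layer normalization sits relative to the residual connection matters for trainability. Two placements are common:

- Post-norm (original Transformer, used in this article): $z_i = \text{LN}(x_i + \text{Module}(x_i))$
- Pre-norm (used in many modern models, e.g. GPT-2): $z_i = x_i + \text{Module}(\text{LN}(x_i))$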
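The two placements differ only in where `LN` is applied (the second is usually called pre-norm). This sketch uses a single `nn.Linear` as a hypothetical stand-in for the attention or FFN sub-layer and compares the scale of the outputs for a deliberately large-scale input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention or the FFN

def post_norm_block(x):
    # Original Transformer: add the residual first, then normalize
    return norm(x + sublayer(x))

def pre_norm_block(x):
    # Pre-norm: normalize only the sub-layer's input; the residual path is untouched
    return x + sublayer(norm(x))

def per_token_std(t):
    # Standard deviation over the feature dimension (biased, like LayerNorm)
    return ((t - t.mean(-1, keepdim=True)) ** 2).mean(-1).sqrt()

x = torch.randn(2, 5, d_model) * 10  # deliberately large-scale input

with torch.no_grad():
    post = post_norm_block(x)
    pre = pre_norm_block(x)

# Post-norm: every token comes out with mean ~0 and std ~1
print(post.mean(-1).abs().max().item(), per_token_std(post).mean().item())
# Pre-norm: the output keeps the large scale carried by the residual stream
print(per_token_std(pre).mean().item())
```

Because pre-norm leaves the residual path as a pure sum, it is generally found to train deep stacks more stably; pre-norm models therefore typically add one final `LayerNorm` after the last layer.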

In [None]:
# Create some example data with different scales and ranges
torch.manual_seed(42)  # For reproducible results

# Example 1: Data with wildly different scales
batch_size, seq_len, d_model = 2, 3, 4
x = torch.tensor([
    # First sequence: small numbers
    [[0.1, 0.2, 0.3, 0.4],
     [0.2, 0.1, 0.4, 0.3],
     [0.3, 0.4, 0.1, 0.2]],

    # Second sequence: large numbers
    [[100, 200, 300, 400],
     [150, 250, 350, 450],
     [200, 300, 400, 500]]
], dtype=torch.float32)

print("Original data (notice the different scales):")
print("Sequence 1 (small numbers):")
print(x[0])
print("Sequence 2 (large numbers):")
print(x[1])
print()

# Show statistics before normalization
#print("Statistics BEFORE layer norm:")
#print(f"Sequence 1 - Mean: {x[0].mean(-1):.2f}, Std: {x[0].std(-1):.2f}")
#print(f"Sequence 2 - Mean: {x[1].mean(-1):.2f}, Std: {x[1].std(-1):.2f}")
#print()

# Apply layer normalization
layer_norm = nn.LayerNorm(d_model)
x_normalized = layer_norm(x)

print("After layer normalization:")
print("Sequence 1:")
print(x_normalized[0])
print("Sequence 2:")
print(x_normalized[1])
print()


Original data (notice the different scales):
Sequence 1 (small numbers):
tensor([[0.1000, 0.2000, 0.3000, 0.4000],
        [0.2000, 0.1000, 0.4000, 0.3000],
        [0.3000, 0.4000, 0.1000, 0.2000]])
Sequence 2 (large numbers):
tensor([[100., 200., 300., 400.],
        [150., 250., 350., 450.],
        [200., 300., 400., 500.]])

After layer normalization:
Sequence 1:
tensor([[-1.3411, -0.4470,  0.4470,  1.3411],
        [-0.4470, -1.3411,  1.3411,  0.4470],
        [ 0.4470,  1.3411, -1.3411, -0.4470]], grad_fn=<SelectBackward0>)
Sequence 2:
tensor([[-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416],
        [-1.3416, -0.4472,  0.4472,  1.3416]], grad_fn=<SelectBackward0>)
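Two details in this output are worth a closer look. First, all three rows of Sequence 2 map to the same values, because layer normalization removes each token's own shift and scale. Second, the first rows differ slightly (-1.3411 vs -1.3416) even though both are evenly spaced. The cause is epsilon: for `[0.1, 0.2, 0.3, 0.4]` the variance is only 0.0125, so `nn.LayerNorm`'s default `eps=1e-5` is not negligible, while for `[100, 200, 300, 400]` the variance is 12,500 and epsilon has no visible effect. A quick check in plain Python:

```python
import math

eps = 1e-5  # nn.LayerNorm's default epsilon

def first_value(row):
    """Normalized value of row[0], computed the way nn.LayerNorm does."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)  # biased variance
    return (row[0] - mean) / math.sqrt(var + eps)

print(round(first_value([0.1, 0.2, 0.3, 0.4]), 4))  # -1.3411 (var = 0.0125)
print(round(first_value([100, 200, 300, 400]), 4))  # -1.3416 (var = 12500)
```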



### Linear Projection

The decoder stack outputs a vector of floats for each target position.
How do we turn that into a word?

The output stage consists of two layers:

- **Linear layer:** a fully connected layer that projects the vector produced by the decoder stack into a much larger vector called a logits vector.
    Let’s assume that our model knows 10,000 unique English words (our model’s “output vocabulary”) that it has learned from its training dataset.
    This would make the logits vector 10,000 cells wide, with each cell holding the score of one unique word.
    That is how we interpret the output of the Linear layer.
- **Softmax layer:** turns those scores into probabilities (all positive, summing to 1.0).
    With greedy decoding, the cell with the highest probability is chosen, and the word associated with it is produced as the output for this time step.

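The original paper describes this stage as follows:

> We also use the usual learned linear transformation and softmax function to convert the decoder output to predicted next-token probabilities. In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation.

The implementation in this article does not tie these weights: `src_embedding`, `tgt_embedding`, and `output_projection` each have their own parameters.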
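Weight tying is possible because `nn.Embedding(vocab, d_model).weight` and `nn.Linear(d_model, vocab).weight` have the same shape, `[vocab, d_model]`. Below is a minimal sketch with illustrative sizes; it only ties the target side, since sharing with the source embedding as well would additionally require a shared source/target vocabulary.

```python
import torch.nn as nn

target_vocab_size, d_model = 1000, 512  # illustrative sizes

tgt_embedding = nn.Embedding(target_vocab_size, d_model)   # weight: [vocab, d_model]
output_projection = nn.Linear(d_model, target_vocab_size)  # weight: [vocab, d_model]

# Tie the weights: the projection reuses the embedding's Parameter object
output_projection.weight = tgt_embedding.weight

print(output_projection.weight is tgt_embedding.weight)  # True

# parameters() de-duplicates shared Parameters, so the matrix is counted once:
# 1000 * 512 (shared weight) + 1000 (projection bias) = 513,000
tied = nn.ModuleList([tgt_embedding, output_projection])
num_params = sum(p.numel() for p in tied.parameters())
print(f"{num_params:,}")  # 513,000
```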

Example:
The transformer outputs 256-dimensional vectors, but we need probabilities over 1790 French words.
The linear layer projects from 256 → 1790 dimensions, then softmax gives us a probability distribution.


![](images/transformer_decoder_output_softmax.png)


```python
# Final output projection to target vocabulary
self.output_projection = nn.Linear(d_model, target_vocab_size)

# Step 3: Project to vocabulary logits
output_logits = self.output_projection(decoder_output)
```


Notice that there is no softmax layer here.

During training, `nn.CrossEntropyLoss` applies log-softmax internally, so it takes the raw logits:

```python
loss = nn.CrossEntropyLoss()(logits.view(-1, vocab_size), targets.view(-1))
```

During inference, we apply softmax explicitly to turn the logits at the last position into next-token probabilities:

```python
probs = torch.softmax(logits[:, -1, :], dim=-1)
```
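Both claims can be checked on random logits (the sizes below are arbitrary). `nn.CrossEntropyLoss` on raw logits gives the same value as `F.nll_loss` on `F.log_softmax` outputs, and because softmax preserves the ordering of the scores, the greedy choice is the same whether we take the argmax of the probabilities or of the logits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, vocab_size = 2, 4, 10
logits = torch.randn(batch_size, seq_len, vocab_size)          # raw scores
targets = torch.randint(0, vocab_size, (batch_size, seq_len))  # token ids

# CrossEntropyLoss on raw logits ...
ce = nn.CrossEntropyLoss()(logits.reshape(-1, vocab_size), targets.reshape(-1))
# ... equals the negative log-likelihood of the log-softmax probabilities
nll = F.nll_loss(F.log_softmax(logits, dim=-1).reshape(-1, vocab_size), targets.reshape(-1))
print(torch.allclose(ce, nll))  # True

# At inference, softmax over the last position gives next-token probabilities
probs = torch.softmax(logits[:, -1, :], dim=-1)  # [batch_size, vocab_size]
next_token = probs.argmax(dim=-1)                # greedy choice per sequence
print(probs.sum(dim=-1))                         # each row sums to 1
print(torch.equal(next_token, logits[:, -1, :].argmax(dim=-1)))  # True
```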


## Transformer Model

In [None]:
class Encoder(nn.Module):
    def __init__(self,
                 d_model,
                 num_heads,
                 num_layers,
                 d_ff,
                 dropout):
        super().__init__()

        # Stack of encoder layers
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
            ])

    def forward(self, x, source_mask=None):

        # Pass the input (and mask) through each encoder layer
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x, source_mask)

        return x  # [batch, seq_len, d_model]

In [None]:
class EncoderLayer(nn.Module):
    def __init__(
            self,
            d_model,
            num_heads,
            d_ff,
            dropout
        ):
        super().__init__()

        # Sub-layer 1: Multi-head self-attention
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)

        # Sub-layer 2: Feed-forward network
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

        # Layer normalization for each sub-layer
        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(
            self,
            x,
            mask=None
        ):

        # Sub-layer 1: Multi-head self-attention with residual connection
        # For self-attention: query, key, and value are all the same input
        attention_output, attention_weights = self.self_attention(
            query=x,    # Same input
            key=x,      # Same input
            value=x,    # Same input
            mask=mask
        )

        # Post-norm: residual connection then normalize
        x = self.layer_norm_1(x + self.dropout(attention_output))

        # Sub-layer 2: Feed-forward with residual connection
        feed_forward_output = self.feed_forward(x)

        # Post-norm: residual connection then normalize
        x = self.layer_norm_2(x + self.dropout(feed_forward_output))

        return x

In [None]:
class Decoder(nn.Module):
    def __init__(self,
                 d_model,
                 num_heads,
                 num_layers,
                 d_ff,
                 dropout):
        super().__init__()

        # Stack of decoder layers
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
            ])

    def forward(self, x, encoder_output, source_mask=None, target_mask=None):

        # Pass through each decoder layer
        for decoder_layer in self.decoder_layers:
            x = decoder_layer(x, encoder_output, source_mask, target_mask)

        return x  # [batch, seq_len, d_model]; no final LayerNorm needed, each post-norm layer already ends with one

In [None]:
class DecoderLayer(nn.Module):
    def __init__(self,
                 d_model,
                 num_heads,
                 d_ff,
                 dropout
                 ):
        super().__init__()

        # Sub-layer 1: Masked multi-head self-attention
        self.masked_self_attention = MultiHeadAttention(d_model, num_heads, dropout)

        # Sub-layer 2: Multi-head encoder-decoder attention
        self.encoder_decoder_attention = MultiHeadAttention(d_model, num_heads, dropout)  # also known as cross-attention

        # Sub-layer 3: Feed-forward network
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

        # Layer normalization for each sub-layer
        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.layer_norm_3 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                x,
                encoder_output,
                source_mask=None,
                target_mask=None
                ):

        # Sub-layer 1: Masked self-attention on target sequence
        # For self-attention: query, key, and value are all the same input (target)
        masked_attention_output, masked_attention_weights = self.masked_self_attention(
            query=x,     # Same target input
            key=x,       # Same target input
            value=x,     # Same target input
            mask=target_mask        # Causal mask to prevent seeing future tokens
        )

        # Post-norm: residual connection then normalize
        x = self.layer_norm_1(x + self.dropout(masked_attention_output))

        # Sub-layer 2: Encoder-decoder attention
        # Query comes from decoder, key and value come from encoder
        encoder_attention_output, encoder_attention_weights = self.encoder_decoder_attention(
            query=x,     # What the decoder is generating
            key=encoder_output,     # What information is available from encoder
            value=encoder_output,   # What information to retrieve from encoder
            mask=source_mask        # Mask for padding tokens in source
        )

        # Post-norm: residual connection then normalize
        x = self.layer_norm_2(x + self.dropout(encoder_attention_output))

        # Sub-layer 3: Feed-forward network
        feed_forward_output = self.feed_forward(x)

        # Post-norm: residual connection then normalize
        x = self.layer_norm_3(x + self.dropout(feed_forward_output))

        return x

In [None]:
class TransformerModel(nn.Module):
    def __init__(self,
                 source_vocab_size,     # Source vocabulary size
                 target_vocab_size,     # Target vocabulary size
                 d_model=512,           # Model dimension
                 num_heads=8,           # Number of attention heads
                 num_layers=6,          # Number of encoder/decoder layers
                 d_ff=2048,             # Feed-forward dimension
                 max_seq_len=5000,      # Maximum sequence length
                 dropout=0.1):
        super().__init__()

        self.d_model = d_model

        # Input processing of embeddings and positional encodings 
        self.src_embedding = nn.Embedding(source_vocab_size, d_model) 
        self.tgt_embedding = nn.Embedding(target_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_len, dropout)

        # Encoder stack
        self.encoder = Encoder(
            d_model=d_model,
            num_heads=num_heads,
            num_layers=num_layers,
            d_ff=d_ff,
            dropout=dropout
        )

        # Decoder stack
        self.decoder = Decoder(
            d_model=d_model,
            num_heads=num_heads,
            num_layers=num_layers,
            d_ff=d_ff,
            dropout=dropout
        )

        # Output projection to target vocabulary
        self.output_projection = nn.Linear(d_model, target_vocab_size)

    def forward(self, source_tokens, target_tokens, source_mask=None, target_mask=None):
        """
        Forward pass for training (teacher forcing)

        source_tokens: [batch_size, source_seq_len] - source token ids
        target_tokens: [batch_size, target_seq_len] - target token ids
        """

        # Scale embeddings
        src_emb = self.src_embedding(source_tokens) * math.sqrt(self.d_model)
        tgt_emb = self.tgt_embedding(target_tokens) * math.sqrt(self.d_model)

        # Add positional encoding
        src_emb = self.pos_encoder(src_emb)
        tgt_emb = self.pos_encoder(tgt_emb)

        # Step 1: Encode source sequence
        encoder_output = self.encoder(src_emb, 
                                      source_mask) # Shape: [batch_size, source_seq_len, d_model]

        # Step 2: Decode target sequence
        decoder_output = self.decoder(
            tgt_emb,
            encoder_output=encoder_output,
            source_mask=source_mask,
            target_mask=target_mask
        )
        # Shape: [batch_size, target_seq_len, d_model]

        # Step 3: Project to vocabulary logits
        output_logits = self.output_projection(decoder_output)
        # Shape: [batch_size, target_seq_len, target_vocab_size]

        return output_logits
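As written, `forward` only applies the masks the caller passes in, and the training loop below calls `model(src, tgt_input)` without any. Unless `MultiHeadAttention` masks positions on its own (it receives `mask=None` here), the decoder's self-attention can then look at future target tokens, and padding is attended to. Below is a minimal sketch of two mask helpers. The names `make_padding_mask` and `make_causal_mask` are hypothetical, and the sketch assumes the `MultiHeadAttention` defined earlier blocks positions where the mask is `0`/`False` (e.g. via `masked_fill(mask == 0, ...)`) and accepts masks that broadcast to `[batch, heads, query_len, key_len]`; check that implementation before relying on it.

```python
import torch

def make_padding_mask(tokens, pad_idx=0):
    """[batch, seq_len] -> [batch, 1, 1, seq_len]; True where the token is not padding."""
    return (tokens != pad_idx).unsqueeze(1).unsqueeze(2)

def make_causal_mask(tgt_tokens, pad_idx=0):
    """[batch, tgt_len] -> [batch, 1, tgt_len, tgt_len]; True where attending is allowed.

    Combines the target padding mask with a lower-triangular (causal) mask,
    so position i can only attend to positions <= i.
    """
    tgt_len = tgt_tokens.size(1)
    causal = torch.tril(torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt_tokens.device))
    # [batch, 1, 1, tgt_len] & [tgt_len, tgt_len] broadcasts to [batch, 1, tgt_len, tgt_len]
    return make_padding_mask(tgt_tokens, pad_idx) & causal

src = torch.tensor([[5, 6, 7, 0]])  # one padded source sentence
tgt = torch.tensor([[1, 8, 9]])     # decoder input starting with <sos> (id 1)
print(make_padding_mask(src).shape)     # torch.Size([1, 1, 1, 4])
print(make_causal_mask(tgt)[0, 0].int())  # lower-triangular: row i allows columns 0..i
```

Under those assumptions, a training step would pass `source_mask=make_padding_mask(src)` and `target_mask=make_causal_mask(tgt_input)` to `model(...)`. The same source padding mask also fits the encoder-decoder attention, where the keys come from the source sequence.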

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create a dummy model instance with placeholder vocab sizes
# Replace with actual vocab sizes if available
model = TransformerModel(source_vocab_size=100, 
                         target_vocab_size=100,
                         ).to(device)

print(f"Model created with {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters")
print(f"Model device: {next(model.parameters()).device}")

Let's visualize the implemented model.

In [None]:
%%capture
%pip install torchview

In [None]:
from torchview import draw_graph

# Create sample inputs
batch_size = 2
seq_len = 10

sample_source = torch.randint(1, 16, (batch_size, seq_len)).to(device) # Move to device
sample_target = torch.randint(1, 16, (batch_size, seq_len)).to(device) # Move to device

# Visualize model
model_graph = draw_graph(
    model, # Use the dummy model instance
    input_data=[sample_source, sample_target],
    expand_nested=True
)

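# Saves images/transformer_model_pytorch.png (Graphviz appends the format extension)
model_graph.visual_graph.render('images/transformer_model_pytorch', format='png')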



'images/transformer_model_pytorch.png'

![](images/transformer_model_pytorch.png)

In [None]:
%%capture
%pip install torchinfo

In [None]:
from torchinfo import summary

print(summary(model,
              input_data=[sample_source, sample_target],
              col_names=["input_size", "output_size", "num_params"],
              depth=4))


Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Kernel Shape
TransformerModel                              [2, 10]                   [2, 10, 100]              --                        --
├─Embedding: 1-1                              [2, 10]                   [2, 10, 512]              51,200                    --
├─Embedding: 1-2                              [2, 10]                   [2, 10, 512]              51,200                    --
├─PositionalEncoding: 1-3                     [2, 10, 512]              [2, 10, 512]              --                        --
│    └─Dropout: 2-1                           [2, 10, 512]              [2, 10, 512]              --                        --
├─PositionalEncoding: 1-4                     [2, 10, 512]              [2, 10, 512]              --                        --
│    └─Dropout: 2-2                           [2, 10, 512]              [2, 10, 512]              -- 

## Training and Inference

Now that we’ve covered the entire forward pass through an untrained Transformer, let's look at how to train the model.

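Since the original encoder-decoder Transformer was designed for machine translation, we will train it on a small English-French sentence-pair dataset from Kaggle (`devicharith/language-translation-englishfrench`).
For larger-scale experiments, the WMT 2014 English-German dataset is an option:
https://www.kaggle.com/datasets/mohamedlotfy50/wmt-2014-english-german?select=wmt14_translate_de-en_train.csv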

In [None]:
# Set random seeds for reproducibility
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

### Dataset Creation

We load an English-French translation dataset from Kaggle, shuffle it, and keep the first 1,000 sentence pairs so that training stays fast. This can be replaced with larger datasets such as WMT or Multi30k later.

In [None]:
# Download data
import kagglehub
from kagglehub import KaggleDatasetAdapter

# Set the path to the file you'd like to load
file_path = "eng_-french.csv"

# Load the latest version
df = kagglehub.load_dataset(
  KaggleDatasetAdapter.PANDAS,
  "devicharith/language-translation-englishfrench",
  file_path,
  # Provide any additional arguments like
  # sql_query or pandas_kwargs. See the
  # documentation for more information:
  # https://github.com/Kaggle/kagglehub/blob/main/README.md#kaggledatasetadapterpandas
)

# Shuffle the dataframe with a fixed seed
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

df = df[:1000]
print("First 5 records:", df.head())

In [None]:
# Use the downloaded data from the dataframe 'df'
# Assuming 'df' has columns named 'English words/sentences' and 'French words/sentences'
translation_pairs = []
for index, row in df.iterrows():
    english_sentence = row['English words/sentences']
    french_sentence = row['French words/sentences']
    translation_pairs.append((english_sentence, french_sentence))

# Split into training and validation sets
train_size = int(0.8 * len(translation_pairs))
train_pairs = translation_pairs[:train_size]
val_pairs = translation_pairs[train_size:]

print(f"Total pairs: {len(translation_pairs)}")
print(f"Training pairs: {len(train_pairs)}")
print(f"Validation pairs: {len(val_pairs)}")
print(f"\nExample pairs:")
for i, (en, fr) in enumerate(train_pairs[:3]):
    print(f"{i+1}. English: '{en}' -> French: '{fr}'")

### Data Preprocessing and Tokenization

We'll create vocabularies and tokenization functions for both English and French.

Tokens can be words, subwords, or characters. Here we keep it simple and use whitespace-separated words.
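To see how the granularity changes the token sequence, here is a tiny comparison on a made-up sentence (the sentence is illustrative). Subword tokenizers such as BPE fall between these two extremes but require a trained tokenizer, so they are not shown here.

```python
sentence = "I don't like rain."

# Word-level: lowercase and split on whitespace, as the Vocabulary class below does.
# Punctuation stays attached, so "rain." and "rain" would be different tokens.
word_tokens = sentence.lower().split()

# Character-level: every character (including spaces) is a token
char_tokens = list(sentence.lower())

print(word_tokens)       # ['i', "don't", 'like', 'rain.']
print(len(char_tokens))  # 18
```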

In [None]:
class Vocabulary:
    def __init__(self):
        self.word2idx = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
        self.idx2word = {0: "<pad>", 1: "<sos>", 2: "<eos>", 3: "<unk>"}
        self.word_count = {}
        self.n_words = 4  # Count default tokens

    def add_sentence(self, sentence):
        # Simple split by space for now - can be improved with a proper tokenizer
        for word in sentence.lower().split():
            self.add_word(word)

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.n_words
            self.idx2word[self.n_words] = word
            self.word_count[word] = 1
            self.n_words += 1
        else:
            self.word_count[word] += 1

    def sentence_to_indices(self, sentence, add_eos=True):
        indices = [self.word2idx["<sos>"]]
        for word in sentence.lower().split():
            if word in self.word2idx:
                indices.append(self.word2idx[word])
            else:
                indices.append(self.word2idx["<unk>"])
        if add_eos:
            indices.append(self.word2idx["<eos>"])
        return indices

    def indices_to_sentence(self, indices):
        words = []
        for idx in indices:
            if idx == self.word2idx["<eos>"]:
                break
            if idx not in [self.word2idx["<pad>"], self.word2idx["<sos>"]]:
                words.append(self.idx2word[idx])
        return " ".join(words)

# Create vocabularies
english_vocab = Vocabulary()
french_vocab = Vocabulary()

# Build vocabularies from training data
print("Building vocabularies...")
for en_sentence, fr_sentence in train_pairs:
    english_vocab.add_sentence(en_sentence)
    french_vocab.add_sentence(fr_sentence)

print(f"English vocabulary size: {english_vocab.n_words}")
print(f"French vocabulary size: {french_vocab.n_words}")

# Show some vocabulary examples
print(f"\nEnglish words: {list(english_vocab.word2idx.keys())[:10]}")
print(f"French words: {list(french_vocab.word2idx.keys())[:10]}")

# Test tokenization
test_en = "Hello world"
test_fr = "Bonjour le monde"
en_indices = english_vocab.sentence_to_indices(test_en)
fr_indices = french_vocab.sentence_to_indices(test_fr)

print("\nTokenization test:")
print(f"'{test_en}' -> {en_indices}")
print(f"'{test_fr}' -> {fr_indices}")
print(f"Back to text: '{english_vocab.indices_to_sentence(en_indices)}'")
print(f"Back to text: '{french_vocab.indices_to_sentence(fr_indices)}'")

In [None]:
class TranslationDataset(Dataset):
    def __init__(self, pairs, src_vocab, tgt_vocab):
        self.pairs = pairs
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src_sentence, tgt_sentence = self.pairs[idx]

        # Convert to indices
        src_indices = self.src_vocab.sentence_to_indices(src_sentence)
        tgt_indices = self.tgt_vocab.sentence_to_indices(tgt_sentence)

        return {
            'src': torch.tensor(src_indices, dtype=torch.long),
            'tgt': torch.tensor(tgt_indices, dtype=torch.long),
            'src_text': src_sentence,
            'tgt_text': tgt_sentence
        }

def collate_fn(batch):
    """Custom collate function to pad sequences in a batch"""
    src_sequences = [item['src'] for item in batch]
    tgt_sequences = [item['tgt'] for item in batch]
    src_texts = [item['src_text'] for item in batch]
    tgt_texts = [item['tgt_text'] for item in batch]

    # Pad sequences
    src_padded = pad_sequence(src_sequences, batch_first=True, padding_value=0)
    tgt_padded = pad_sequence(tgt_sequences, batch_first=True, padding_value=0)

    return {
        'src': src_padded,
        'tgt': tgt_padded,
        'src_text': src_texts,
        'tgt_text': tgt_texts
    }

# Create datasets
train_dataset = TranslationDataset(train_pairs, english_vocab, french_vocab)
val_dataset = TranslationDataset(val_pairs, english_vocab, french_vocab)

# Create data loaders
batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Test the data loader
print("\nTesting data loader:")
for batch in train_loader:
    print(f"Source batch shape: {batch['src'].shape}")
    print(f"Target batch shape: {batch['tgt'].shape}")
    print(f"Source texts: {batch['src_text']}")
    print(f"Target texts: {batch['tgt_text']}")
    break

In [None]:
# Create model
model = TransformerModel(
    source_vocab_size=english_vocab.n_words,
    target_vocab_size=french_vocab.n_words,
    d_model=512,
    num_heads=8,
    num_layers=6,
    d_ff=2048,
    dropout=0.1
).to(device)

print(f"Model created with {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters")
print(f"Model device: {next(model.parameters()).device}")
# No need to call model.eval() here: eval() switches layers such as Dropout to
# inference behavior, and the training loop below toggles model.train()/model.eval() itself.


In [None]:

# Helper function to pad sequences
def pad_sequences(sequences, pad_token=0):
    max_len = max(len(seq) for seq in sequences)
    padded = []
    for seq in sequences:
        padded.append(seq + [pad_token] * (max_len - len(seq)))
    return torch.tensor(padded, dtype=torch.long)

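# Create padded tensors for a few training examples
english_ids = [english_vocab.sentence_to_indices(en) for en, _ in train_pairs[:3]]
french_ids = [french_vocab.sentence_to_indices(fr) for _, fr in train_pairs[:3]]

source_tensor = pad_sequences(english_ids)
target_tensor = pad_sequences(french_ids)

print(f"Source tensor shape: {source_tensor.shape}")
print(f"Target tensor shape: {target_tensor.shape}")
print(f"Source tensor:\n{source_tensor}")
print(f"Target tensor:\n{target_tensor}")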

### Training Loop
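The loop below trains with teacher forcing: the decoder receives the target sequence without its last position, and the label at each position is the next token. A small example with made-up token ids (`<sos>`=1, `<eos>`=2, padding=0, matching the `Vocabulary` class) shows the shift:

```python
import torch

# One target sentence: <sos>, three word ids, <eos>, then padding
tgt = torch.tensor([[1, 10, 11, 12, 2, 0]])

tgt_input = tgt[:, :-1]   # fed to the decoder
tgt_output = tgt[:, 1:]   # labels: the next token at each position

for inp, label in zip(tgt_input[0].tolist(), tgt_output[0].tolist()):
    print(f"input {inp:>2} -> predict {label}")
```

The last pair predicts padding (id 0), which `nn.CrossEntropyLoss(ignore_index=0)` excludes from the loss.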

In [None]:
# Training loop
# Training configuration
learning_rate = 0.0001
num_epochs = 1  # increase (e.g. to 10) for better translations
#patience = 10  # For early stopping

# Loss function and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding tokens
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

# Training history
train_losses = []
val_losses = []
best_val_loss = float('inf')
patience_counter = 0

print(f"Starting training for {num_epochs} epochs...")
print(f"Learning rate: {learning_rate}")
#print(f"Patience: {patience}")

for epoch in range(num_epochs):
    # Training phase
    model.train()
    total_train_loss = 0
    num_train_steps = 0

    for batch_idx, batch in enumerate(train_loader):
        src = batch['src'].to(device)
        tgt = batch['tgt'].to(device)

        # Prepare input and target for training
        tgt_input = tgt[:, :-1]  # Decoder input: drop the last position (<eos> or padding)
        tgt_output = tgt[:, 1:]  # Labels: drop the first token (<sos>)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        output = model(src, tgt_input)

        # Calculate loss
        loss = criterion(output.reshape(-1, output.shape[-1]), tgt_output.reshape(-1))

        # Backward pass
        loss.backward()

        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Update parameters
        optimizer.step()

        total_train_loss += loss.item()
        num_train_steps += 1

        # Print progress
        if (batch_idx + 1) % 100 == 0:
            avg_loss = total_train_loss / num_train_steps
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {avg_loss:.4f}")

    # Calculate average training loss
    avg_train_loss = total_train_loss / num_train_steps
    train_losses.append(avg_train_loss)

    # Validation phase
    model.eval()
    total_val_loss = 0
    num_val_steps = 0

    with torch.no_grad():
        for batch in val_loader:
            src = batch['src'].to(device)
            tgt = batch['tgt'].to(device)

            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]

            output = model(src, tgt_input)
            loss = criterion(output.reshape(-1, output.shape[-1]), tgt_output.reshape(-1))

            total_val_loss += loss.item()
            num_val_steps += 1

    avg_val_loss = total_val_loss / num_val_steps if num_val_steps > 0 else float('inf')
    val_losses.append(avg_val_loss)

    # Learning rate scheduling
    scheduler.step(avg_val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    # Track the best validation loss seen so far
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0
        # Uncomment to save the best model:
        # torch.save(model.state_dict(), 'best_transformer_model.pth')
        print(f"New best validation loss: {best_val_loss:.4f}")
    else:
        patience_counter += 1
        # Optional early stopping (uncomment `patience` above first):
        # if patience_counter >= patience:
        #     print(f"Early stopping triggered after {epoch + 1} epochs")
        #     break

    print("-" * 50)

print("Training completed!")

# Load best model
#model.load_state_dict(torch.load('best_transformer_model.pth'))
print(f"Best validation loss: {best_val_loss:.4f}")

### Evaluation and Metrics

Let's visualize the training progress and evaluate the model performance.

In [None]:
# Plot training and validation loss
plt.figure(figsize=(12, 4))

# Plot the loss curves (markers keep single-epoch runs visible)
plt.plot(train_losses, 'b-o', label='Training Loss')
plt.plot(val_losses, 'r-o', label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.show()

print(f"Final training loss: {train_losses[-1]:.4f}")
print(f"Final validation loss: {val_losses[-1]:.4f}")

# Calculate accuracy on validation set
def calculate_accuracy(model, data_loader, vocab):
    model.eval()
    correct_tokens = 0
    total_tokens = 0

    with torch.no_grad():
        for batch in data_loader:
            src = batch['src'].to(device)
            tgt = batch['tgt'].to(device)

            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]

            output = model(src, tgt_input)
            predictions = output.argmax(dim=-1)

            # Only count non-padding tokens
            mask = (tgt_output != 0)
            correct_tokens += ((predictions == tgt_output) & mask).sum().item()
            total_tokens += mask.sum().item()

    return correct_tokens / total_tokens if total_tokens > 0 else 0

val_accuracy = calculate_accuracy(model, val_loader, french_vocab)
print(f"Validation accuracy: {val_accuracy:.4f} ({val_accuracy*100:.2f}%)")
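The accuracy above is token-level and measured with teacher forcing: each prediction is compared against the reference token at the same position, and padding positions are ignored. It can be higher than what free-running greedy translation achieves. A worked example of the masking with made-up ids:

```python
import torch

predictions = torch.tensor([[4, 7, 2, 5]])
targets = torch.tensor([[4, 6, 2, 0]])  # last position is padding (id 0)

mask = targets != 0  # ignore padding positions
correct = ((predictions == targets) & mask).sum().item()
total = mask.sum().item()
print(correct, total, correct / total)  # 2 correct out of 3 real tokens
```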

In [None]:
def translate_sentence(model, sentence, src_vocab, tgt_vocab, max_length=50):
    """
    Translate a single sentence using the trained model
    """
    model.eval()

    # Convert input sentence to indices
    src_indices = src_vocab.sentence_to_indices(sentence, add_eos=True)
    src_tensor = torch.tensor([src_indices], dtype=torch.long).to(device)

    # Initialize target with SOS token
    tgt_indices = [tgt_vocab.word2idx["<sos>"]]

    with torch.no_grad():
        for _ in range(max_length):
            tgt_tensor = torch.tensor([tgt_indices], dtype=torch.long).to(device)

            # Get model prediction
            output = model(src_tensor, tgt_tensor)

            # Get the prediction for the last token
            next_token_logits = output[0, -1, :]
            next_token = next_token_logits.argmax().item()

            # Add predicted token to target sequence
            tgt_indices.append(next_token)

            # Stop if we predict EOS token
            if next_token == tgt_vocab.word2idx["<eos>"]:
                break

    # Convert indices back to sentence
    translated_sentence = tgt_vocab.indices_to_sentence(tgt_indices)
    return translated_sentence

# Test on validation examples
for i, (en_sentence, fr_sentence) in enumerate(val_pairs[:5]):
    translation = translate_sentence(model, en_sentence, english_vocab, french_vocab)
    print(f"Example {i+1}:")
    print(f"English: '{en_sentence}'")
    print(f"Expected: '{fr_sentence}'")
    print(f"Generated: '{translation}'")
    print("-" * 30)

## Discussion

### Architecture Variations
<!-- What are the differences between encoder-only, decoder-only, and encoder-decoder transformers? -->
Transformer-based models come in three primary variants, each with architectural differences that make it suitable for specific tasks:

- Encoder-only
    - Architecture: a stack of encoder blocks without a decoder; self-attention is bidirectional, so every token can attend to every other token in the input
    - How it works: processes the input sequence as a whole and makes predictions about it
    - Use cases: tasks that require understanding the overall meaning of a text, such as sentiment analysis, sentence classification, and named entity recognition (NER)
    - Examples: Google's BERT, Meta's RoBERTa
- Decoder-only
    - Architecture: a stack of decoder blocks with masked (causal) self-attention and no cross-attention, so each token can only attend to itself and earlier tokens
    - How it works: specifically designed to generate new sequences; processes the input sequence and generates the output iteratively, one token at a time, feeding each prediction back in as input
    - Use cases: text generation
    - Examples: most LLMs, such as OpenAI's GPT-4, Meta's Llama, Google's Gemini, Anthropic's Claude, and xAI's Grok
- Encoder-decoder
    - Architecture: the original Transformer architecture, introduced in the paper "Attention Is All You Need"; the decoder attends to the encoder's output via cross-attention
    - How it works: the encoder processes the input sequence, and the decoder uses this processed information to generate the output sequence
    - Use cases: tasks where the output is a transformation of the input, such as machine translation
    - Examples: Meta's BART, Google's T5
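
The key difference between the encoder-only and decoder-only variants shows up in the attention mask. Below is a minimal PyTorch sketch, assuming single-head attention and no padding: encoder-only models let every token attend to every other token, while decoder-only models apply a causal (lower-triangular) mask. The function names `bidirectional_mask`, `causal_mask`, and `masked_attention_weights` are illustrative helpers, not part of PyTorch.

```python
import torch


def bidirectional_mask(seq_len: int) -> torch.Tensor:
    # Encoder-only: every position may attend to every position.
    return torch.ones(seq_len, seq_len, dtype=torch.bool)


def causal_mask(seq_len: int) -> torch.Tensor:
    # Decoder-only: position i may attend only to positions 0..i
    # (lower-triangular matrix, including the diagonal).
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))


def masked_attention_weights(q: torch.Tensor, k: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Scaled dot-product scores of shape (seq_len, seq_len).
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
    # Disallowed positions get -inf, so softmax assigns them exactly zero weight.
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1)


torch.manual_seed(0)
seq_len, d_k = 4, 8
q = torch.randn(seq_len, d_k)
k = torch.randn(seq_len, d_k)

enc_weights = masked_attention_weights(q, k, bidirectional_mask(seq_len))
dec_weights = masked_attention_weights(q, k, causal_mask(seq_len))

print(enc_weights)  # every entry is positive
print(dec_weights)  # entries above the diagonal are exactly 0
```

In the decoder-only case, the first token can only attend to itself, so its attention weight on position 0 is 1. An encoder-decoder model combines both: a bidirectional mask in the encoder and a causal mask in the decoder's self-attention.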

![](./images/encoder-only_vs_decoder-only_transformer.png)

### Limitations of the Attention mechanism

The main limitation of the original Transformer architecture and its derivatives is the self-attention mechanism's [**quadratic memory and computational requirements**](https://research.google/blog/constructing-transformers-for-longer-sequences-with-sparse-attention-methods/) ($O(N^2)$) with respect to the input sequence length ($N$). 

<!-- Why -->
- Memory: storing the $N \times N$ attention matrix requires $O(N^2)$ space
- Compute: computing all pairwise attention scores requires $O(N^2)$ operations
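
A minimal sketch of naive (unfused) scaled dot-product attention in PyTorch makes both costs visible: the intermediate `scores` tensor has one entry per pair of positions, so doubling the sequence length quadruples its size. The helper `naive_attention` is illustrative and assumes a single head with no batch dimension.

```python
import torch


def naive_attention(q, k, v):
    # q, k, v: (N, d). The score matrix has shape (N, N): one entry per
    # (query, key) pair, which is where the O(N^2) memory and compute come from.
    scores = q @ k.T / q.size(-1) ** 0.5
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, scores


d = 16
for n in (256, 512, 1024):
    q, k, v = (torch.randn(n, d) for _ in range(3))
    out, scores = naive_attention(q, k, v)
    # The output keeps the input shape (N, d), but scores grows as N * N.
    print(n, tuple(out.shape), scores.numel())
# 256 (256, 16) 65536
# 512 (512, 16) 262144
# 1024 (1024, 16) 1048576
```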

![](./images/attention_visualized.png)

<!-- Why do we want to process long sequences? -->
This quadratic scaling was historically a major limitation that made Transformers expensive or unsuitable for tasks requiring long input sequences. For example, they couldn't effectively process entire articles or books for document summarization or long-form question answering.

<!-- How to solve it -->
Various techniques have since been introduced to reduce the cost of attention:

- There have been many $O(N)$ approximations, such as the Linear Transformer, but they come with a [speed and quality trade-off](https://arxiv.org/abs/2011.04006) (2020).
- [FlashAttention](https://arxiv.org/abs/2205.14135) (2022) and [FlashAttention-2](https://arxiv.org/abs/2307.08691) (2023) remove the quadratic memory requirement: they compute exact attention block by block without ever storing the full $N \times N$ matrix, although the computation is still $O(N^2)$.
- Mamba (2023), a state space model architecture that replaces attention and scales linearly with sequence length
- For an overview, see "Efficient Transformers: A Survey" by Yi Tay, Mostafa Dehghani, Dara Bahri, and Donald Metzler.
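- Grouped-query attention (GQA), which shares key and value heads across groups of query heads to reduce memory use during inference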
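
The core idea behind FlashAttention's memory savings can be sketched in plain PyTorch: process the keys and values in blocks and keep a running maximum and running softmax denominator per query (the "online softmax"), so only an $N \times B$ slice of the score matrix exists at any time. This is a simplified illustration under our own assumptions (single head, no masking, no kernel fusion); the actual FlashAttention also tiles over queries and fuses everything into a single GPU kernel, which is where its speedups come from. The function `tiled_attention` and its `block_size` parameter are hypothetical names.

```python
import torch


def tiled_attention(q, k, v, block_size=64):
    # q: (N, d), k and v: (M, d). Computes exact softmax attention while only
    # materializing an (N, block_size) slice of the score matrix at a time.
    scale = q.size(-1) ** -0.5
    n = q.size(0)
    out = torch.zeros(n, v.size(-1), dtype=q.dtype, device=q.device)  # unnormalized output
    row_max = torch.full((n, 1), float("-inf"), dtype=q.dtype, device=q.device)  # running max
    row_sum = torch.zeros(n, 1, dtype=q.dtype, device=q.device)  # running softmax denominator

    for start in range(0, k.size(0), block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = (q @ k_blk.T) * scale  # (N, block)

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        # Rescale previous accumulators to the new max (exp(-inf) = 0 on the first block).
        correction = torch.exp(row_max - new_max)
        p = torch.exp(scores - new_max)

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_blk
        row_max = new_max

    return out / row_sum


torch.manual_seed(0)
q, k, v = (torch.randn(300, 32) for _ in range(3))
reference = torch.softmax(q @ k.T / 32 ** 0.5, dim=-1) @ v
tiled = tiled_attention(q, k, v, block_size=64)
print(torch.allclose(tiled, reference, atol=1e-5))  # True if both agree
```

The result is the same as ordinary attention up to floating-point error: tiling changes how much memory is needed, not the math, and the number of operations remains quadratic.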

As an example of a long input, an average novel has about 100,000 words.
A common rule of thumb is that one English word corresponds to about 1.3 tokens,
so an average novel has about 130,000 tokens.
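
Plugging this into the quadratic cost gives a rough sense of scale. Assuming a naive implementation that stores the full attention matrix in 32-bit floats (4 bytes per entry), a single attention head in a single layer needs:

$$
N^2 = \left(1.3 \times 10^5\right)^2 = 1.69 \times 10^{10} \text{ entries}, \qquad
1.69 \times 10^{10} \times 4\,\text{bytes} \approx 6.8 \times 10^{10}\,\text{bytes} \approx 68\,\text{GB}
$$

Multiple heads and layers multiply this further if the matrices are kept in memory.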


## References
- [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762)
- [The Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)
- [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/)
<!-- https://www.youtube.com/watch?v=rBCqOTEfxvg -->
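- [Transformers from scratch (Peter Bloem)](https://peterbloem.nl/blog/transformers)
- [Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch (Sebastian Raschka, 2023)](https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html)
- [Google Slides presentation](https://docs.google.com/presentation/d/1ZXFIhYczos679r70Yu8vV9uO6B1J0ztzeDxbnBxD1S0/edit?slide=id.g13dd67c5ab8_0_2543#slide=id.g13dd67c5ab8_0_2543)
- [Attention? Attention! (Lilian Weng, 2018)](https://lilianweng.github.io/posts/2018-06-24-attention/)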

In [None]:
#| echo: false
#| output: false
import math
import torch.nn as nn

class Embeddings(nn.Module):
    def __init__(self, vocab_size, d_model):
        """
        Args:
            vocab_size: size of vocabulary
            d_model: dimension of embeddings
        """
        super(Embeddings, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        """
        Args:
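            x: tensor of token indices, e.g. of shape (batch_size, seq_len)
        Returns:
            out: embeddings scaled by sqrt(d_model), e.g. of shape (batch_size, seq_len, d_model)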
        """
        # Scale by sqrt(d_model) from original paper
        return self.embedding(x) * math.sqrt(self.d_model)

In [None]:
#| echo: false
#| output: false
class ResidualConnection(nn.Module):
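    """
    A residual connection around a sublayer, combined with layer normalization.
    Note: for code simplicity, the norm is applied first (pre-norm), i.e. to the
    sublayer input, as opposed to last (post-norm) as in the original paper.
    """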

    def __init__(self, size, dropout):
        super(ResidualConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))


class LayerNorm(nn.Module):
    "Layer normalization over the last dimension with a learnable gain (a_2) and bias (b_2)."

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


In [None]:
#| echo: false
#| output: false
# Methods intended for the Transformer model class (note the `self` argument)
def encode(self, source_tokens, source_mask=None):
    """Encode source sequence (for inference)"""
    return self.encoder(source_tokens, source_mask)

def decode_step(self, target_tokens, encoder_output, source_mask=None, target_mask=None):
    """Decode one step (for autoregressive generation)"""
    decoder_output = self.decoder(target_tokens, encoder_output, source_mask, target_mask)
    return self.output_projection(decoder_output)