<a href="https://colab.research.google.com/github/dhrits/LLM-Engineering-Foundations-to-SLMs/blob/main/03_Attention/Focusing_in_on_Attention_Transformer_from_Scratch_Hardmode_Version.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Encoder-Decoder Transformer Model from Scratch in PyTorch

In today's notebook, we'll be focusing on the main engine of the Transformer - ATTENTION!

In [43]:
!pip install -qU torch=='2.4.1+cu121' torchvision=='0.19.1+cu121' torchaudio=='2.4.1+cu121' --index-url https://download.pytorch.org/whl/cu121

In [44]:
!pip install -qU flash-attn

# 🤝 BREAKOUT ROOM #1:

# The Building Block Fundamentals of Transformer Architecture

We're going to start with an example of an encoder-decoder model - the kind found in the classic paper:

[Attention is All You Need](https://arxiv.org/pdf/1706.03762.pdf).

We'll walk through each step in code - leveraging the [PyTorch]() library heavily - in order to get an idea of how these models work.

While this example notebook could be extended to a sincere usecase - we'll be using a toy dataset, and we will not fully train the model until it converges (under-train), as the full training process might take many days!

## The Desired Architecture

![image](https://i.imgur.com/YPjbqW6.png)

We'll skip over the diagram for now, and talk through each component in detail!

In [45]:
import torch
import torch.nn as nn
import math
from typing import Optional
from dataclasses import dataclass
from enum import Enum

## Embedding

![image](https://i.imgur.com/sFlEZ2e.png)

The first step will be do convert our tokenized sequence of inputs into an embedding vector. This allows use to understand a rich amount of information about input sequences and their semantic meanings.

As the embedding layer will be training along side the rest of the model - it will allow us to have an excellent vector-representation of the tokens in our dataset.

Let's see how it looks in code!

In [46]:
class InputEmbeddings(nn.Module):
  def __init__(self, d_model: int, vocab_size: int, verbose=False) -> None:
    """
    vocab_size - the size of our vocabulary
    d_model - the dimension of our embeddings and the input dimension for our model
    """
    super().__init__()
    self.vocab_size = vocab_size
    self.d_model = d_model
    self.embedding = nn.Embedding(vocab_size, d_model)
    self.verbose = verbose

  def forward(self, x):
    if self.verbose:
      print(f"Embedding Vector (1st 5 elements): {self.embedding(x)[:5] * math.sqrt(self.d_model)}")
    return self.embedding(x) * math.sqrt(self.d_model) # scale embeddings by square root of d_model

### Test Embedding Layer

We'll set up a sample Embedding Layer and then test that it does what we'd expect!

In [47]:
def test_input_embeddings_with_example():
    # Create a small embedding layer
    embed = InputEmbeddings(d_model=512, vocab_size=1000)

    # Example sentence tokens (simplified)
    tokens = torch.tensor([[1, 2, 3, 4, 5]])  # "The cat sat down quickly"

    output = embed(tokens)
    print(f"Input shape: {tokens.shape}")
    print(f"Output shape: {output.shape}")
    print("\nExample shows how words are converted to high-dimensional vectors")

    # Run technical test
    assert output.shape == (1, 5, 512), f"Expected shape (1, 5, 512), got {output.shape}"
    print("✓ Input Embeddings Test Passed")

In [48]:
test_input_embeddings_with_example()

Input shape: torch.Size([1, 5])
Output shape: torch.Size([1, 5, 512])

Example shows how words are converted to high-dimensional vectors
✓ Input Embeddings Test Passed


## Positional Encoding

![image](https://i.imgur.com/IIA3NK3.png)

We need to impart information about where each token is in the sequence, but we aren't using any recurrence or convolutions - the easiest way to encode positional information is to inject positional information into our input embeddings.

We're going to use the process outlined in the paper to do this - which is to use a specific combination of functions to add positional information to the embeddings.

In [49]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model: int, seq_len: int, dropout: float, verbose=False) -> None:
    super().__init__()
    self.d_model = d_model
    self.seq_len = seq_len
    self.dropout = nn.Dropout(dropout)
    self.verbose=verbose

    positional_embeddings = torch.zeros(seq_len, d_model)
    positional_sequence_vector = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    positional_model_vector = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    positional_embeddings[:, 0::2] = torch.sin(positional_sequence_vector * positional_model_vector)
    positional_embeddings[:, 1::2] = torch.cos(positional_sequence_vector * positional_model_vector)
    positional_embeddings = positional_embeddings.unsqueeze(0)

    self.register_buffer('positional_embeddings', positional_embeddings)

  def forward(self, x):
    x = x + (self.positional_embeddings[:, :x.shape[1], :]).requires_grad_(False)
    if self.verbose:
      print(f"Positional Encodings (1st 5 elements): {x}")
    return self.dropout(x)

### Test Positional Encoding Layer

We'll set up a sample Positional Encoding Layer and then test that it does what we'd expect!

In [50]:
def test_positional_encoding_with_example():
    pos = PositionalEncoding(d_model=512, seq_len=10, dropout=0.1)

    # Create sample embeddings for "The cat sat"
    x = torch.randn(1, 3, 512)

    output = pos(x)
    print("Input tokens position:  [1, 2, 3]")
    print("Added position info to each word's embedding")
    print(f"Output maintains shape: {output.shape}")

    # Verify position information was added
    assert not torch.allclose(output, x), "Position information should modify embeddings"
    print("✓ Positional Encoding Test Passed")

In [51]:
test_positional_encoding_with_example()

Input tokens position:  [1, 2, 3]
Added position info to each word's embedding
Output maintains shape: torch.Size([1, 3, 512])
✓ Positional Encoding Test Passed


## Add & Norm

Next we'll tackle the Add & Norm Block of the diagram.

![image](https://i.imgur.com/otdEq4D.png)

### Layer Normalization

The first step is to add layer normalization. You can read more about it [here](https://paperswithcode.com/method/layer-normalization)!

The basic idea is that it makes training the model a bit easier, and allows the model to generalize a bit better.

In [52]:
class LayerNormalization(nn.Module):
  def __init__(self, features: int, epsilon:float=10**-6) -> None:
    super().__init__()
    self.epsilon = epsilon
    self.gamma = nn.Parameter(torch.ones(features))
    self.beta = nn.Parameter(torch.zeros(features))

  def forward(self, x):
    mean = x.mean(dim = -1, keepdim = True)
    standard_deviation = x.std(dim = -1, keepdim = True)
    return self.gamma * (x - mean) / (standard_deviation + self.epsilon) + self.beta

### Test Layer Normalization

We'll set up a sample Layer Normalization and then test that it does what we'd expect!

In [53]:
def test_layer_normalization_with_example():
    layer_norm = LayerNormalization(features=3)  # Smaller feature size for example

    # Simulate word embeddings with different magnitudes
    word_embeddings = torch.tensor([
        [2.5, 4.1, -3.2],  # "The" (high magnitude)
        [0.1, 0.2, -0.1],  # "cat" (low magnitude)
        [8.2, -6.1, 5.5]   # "sat" (very high magnitude)
    ]).unsqueeze(0)

    normalized = layer_norm(word_embeddings)

    print("Before normalization (magnitudes vary greatly):")
    print(word_embeddings[0])
    print("\nAfter normalization (values scaled to similar ranges):")
    print(normalized[0])

    # Verify statistical properties
    mean = normalized.mean(dim=-1)
    var = normalized.var(dim=-1)
    assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-5)
    assert torch.allclose(var, torch.ones_like(var), atol=1e-5)
    print("✓ Layer Normalization Test Passed")

In [54]:
test_layer_normalization_with_example()

Before normalization (magnitudes vary greatly):
tensor([[ 2.5000,  4.1000, -3.2000],
        [ 0.1000,  0.2000, -0.1000],
        [ 8.2000, -6.1000,  5.5000]])

After normalization (values scaled to similar ranges):
tensor([[ 0.3562,  0.7732, -1.1293],
        [ 0.2182,  0.8729, -1.0911],
        [ 0.7459, -1.1363,  0.3905]], grad_fn=<SelectBackward0>)
✓ Layer Normalization Test Passed


### Residual Connection

Another technique that makes model training easier, we add a Residual connection to the outputs of the Attention Block - this helps to prevent vanishing gradient.

In [55]:
class ResidualConnection(nn.Module):
  def __init__(self, features: int, dropout: float = 0.1) -> None:
    super().__init__()
    self.dropout = nn.Dropout(dropout)
    self.layernorm = LayerNormalization(features)

  def forward(self, x, sublayer):
    return x + self.dropout(sublayer(self.layernorm(x)))

### Testing Residual Connection

We'll set up a sample Residual Connection and then test that it does what we'd expect!

In [56]:
def test_residual_connection_with_example():
    residual = ResidualConnection(features=3, dropout=0.1)

    # Original input "The cat"
    x = torch.tensor([
        [1.0, 1.0, 1.0],
        [2.0, 2.0, 2.0]
    ]).unsqueeze(0)

    # Sublayer that makes meaningful changes
    def sublayer(x):
        return torch.nn.functional.relu(x + 0.5) # Non-linear transformation

    output = residual(x, sublayer)

    print("Original input:")
    print(x[0])
    print("\nAfter residual connection (combines original + transformed):")
    print(output[0])

    # Verify output changed but maintained shape
    assert output.shape == x.shape
    assert torch.any(torch.abs(output - x) > 1e-6), "Output should differ from input"
    print("✓ Residual Connection Test Passed")

In [57]:
test_residual_connection_with_example()

Original input:
tensor([[1., 1., 1.],
        [2., 2., 2.]])

After residual connection (combines original + transformed):
tensor([[1.5556, 1.5556, 1.5556],
        [2.5556, 2.0000, 2.5556]], grad_fn=<SelectBackward0>)
✓ Residual Connection Test Passed


## Feed Forward Network

![image](https://i.imgur.com/woEqBjQ.png)

Moving onto the next component, we have our feed forward network.

The feed forward networks servers two purposes in our model:

1. It reforms the attention outputs into a format that works with the next block.

2. It helps add complexity to prevent each attention block acting in a similar fashion.

In [58]:
class FeedForwardBlock(nn.Module):
  def __init__(self, d_model: int, d_ff: int = 2048, dropout: float = 0.1) -> None:
    """
    d_model - dimension of model
    d_ff - dimension of feed forward network
    dropout - regularization measure
    """
    super().__init__()
    self.linear_1 = nn.Linear(d_model, d_ff)
    self.dropout = nn.Dropout(dropout)
    self.linear_2 = nn.Linear(d_ff, d_model)

  def forward(self, x):
    return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

### Testing the Feed-forward Block

Let's test!

In [59]:
def test_feed_forward_block_with_example():
   ff_block = FeedForwardBlock(d_model=3, d_ff=8)  # Small dimensions for demonstration

   # Input: Word embeddings for "The cat"
   x = torch.tensor([
       [1.0, 0.5, 0.2],  # "The"
       [2.0, -0.3, 1.1]  # "cat"
   ]).unsqueeze(0)

   output = ff_block(x)

   print("Input embeddings:")
   print(x[0])
   print("\nAfter feed-forward transformation:")
   print(output[0])

   # First linear layer expands to d_ff dimensions
   # ReLU keeps only positive values
   # Second linear layer projects back to d_model dimensions
   assert output.shape == x.shape
   assert not torch.allclose(output, x)
   print("✓ Feed Forward Block Test Passed")

In [60]:
test_feed_forward_block_with_example()

Input embeddings:
tensor([[ 1.0000,  0.5000,  0.2000],
        [ 2.0000, -0.3000,  1.1000]])

After feed-forward transformation:
tensor([[-0.0450, -0.1149, -0.0714],
        [ 0.1952, -0.3175,  0.1493]], grad_fn=<SelectBackward0>)
✓ Feed Forward Block Test Passed


## Task #1: Multi-Head Attention

![image](https://i.imgur.com/4qOT46y.png)

Next up is the heart and soul of the Transformer - Multi-Head Attention.

We'll break it down into the basic building blocks in code in the following section!

### Task 2: Multi-Head Attention Class



In [61]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_model: int = 512, num_heads: int = 8, dropout: float = 0.1) -> None:
    super().__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    assert d_model % num_heads == 0, "d_model is not divisible by h"

    self.d_k = d_model // num_heads

    self.w_q = nn.Linear(d_model, d_model, bias=False)
    self.w_k = nn.Linear(d_model, d_model, bias=False)
    self.w_v = nn.Linear(d_model, d_model, bias=False)

    self.w_o = nn.Linear(d_model, d_model, bias=False)

    self.dropout = nn.Dropout(dropout)

### 👪❓Discussion Question 1:

Describe, in your own words, intuitively what each of `Q`, `K`, and `V` represent - beyond their names and functions.

Discuss among your group!

There is no single correct answer - and if you get stuck - feel free to ask your favourite Large Language Model!

I've listed out a more pedantic definition at the bottom, but have also tried to describe `query`, `key` and `value` and in terms of various analogies and metaphors.

#### Amazon Search
* `query` - When searching for a product on Amazon, the `query` is the search term a user enters into the search box.
* `key` - These are the descriptions of various objects on the search results page against which the `query` is matched by the user to find the objects of interest.
* `value` - This is the actual object description on its dedicated page (or the object itself).

The above also applies to any search system.

#### Table of Contents of a Book
* `query` - This might be a topic of interest which a user is looking for in a book.
* `key` - This is the table of contents of the book which the user will look in to try and match the topic of interest.
* `value` - This is the actual passages/chapters in the book with the information.


More formally:

* Q: or `query` Represents a transformation of the current token looking for relevant information in the rest of the sequence.
* K: or `key` can be thought of as the address (a transformation) of the rest of (really all of) the sequence which can be matched against Q to find similarities.
* V: or `value` represents a transformation of the rest of (all of) of sequence which holds the actual information needed by the query Q.
O: or output is a final linear transformation applied after attention.


### Testing Multi-Head Attention

Let's test it out!

### Task #3: Scaled Dot-Product Attention

![image](https://i.imgur.com/Yp48DuB.png)

#### 🏗️ Activity #1:

Implement the above!

In [62]:
def attention(query, key, value, mask, d_k, dropout: nn.Dropout = None):

  ### YOUR CODE HERE
  attention_scores = (query @ key.transpose(-2, -1))/math.sqrt(d_k)

  if mask is not None:
    attention_scores = attention_scores.masked_fill(mask == 0, -1e9)

  attention_scores = attention_scores.softmax(dim=-1)

  if dropout is not None:
    attention_scores = dropout(attention_scores)

  return (attention_scores @ value), attention_scores

### ❓Question 1:

Describe the above code that defines the attension mechanism.

Write, in natural language, what each step is doing.

Describe the above code that defines the attension mechanism.

Write, in natural language, what each step is doing.

1. The very first step is a matrix multiplication (or dot product) between the query and the key which finds similarity between the current tokens and all other tokens. The dot-product is a measure of vector similarity.
2. We also scale the magnitude of the dot-product, `attention_scores` by dividing it by `sqrt(d_k)` which ensures the magnitude of the dot product doesn't overwhelm the softmax calculation which comes after.
3. We then apply the `mask`. This has many functions. In the case of the decoder, it prevents attention scores from looking forward in the sequence when modeling auto-regressing behavior. It may also mask out `[PAD]` or other unnecessary tokens.
4. A `softmax` is applied over the attention scores which creates a valid probability distribution.
5. `dropout` is applied as a means of regularization to prevent overfitting.
6. Finally a matmul between `attention_scores` and `value` extracts the relevant attention output.

### Task #4: Forward Method

This is code is required to do a forward pass with our model.

In [63]:
def forward(self, query, key, value, mask):
  query = self.w_q(query)
  key = self.w_k(key)
  value = self.w_v(value)

  query = query.view(query.shape[0], -1, self.num_heads, self.d_k).transpose(1, 2).contigous()
  key = key.view(key.shape[0], -1, self.num_heads, self.d_k).transpose(1, 2).contigous()
  value = value.view(value.shape[0], -1, self.num_heads, self.d_k).transpose(1, 2).contigous()

  x, self.attention_scores = attention(query, key, value, mask, self.d_k, self.dropout)

  x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.d_model)

  return self.w_o(x)

### Task #5: Combining it All Together - and Introducing Flash Attention

We'll start with implementing a simple set of `configs` - and then get straight into an Attention implementation!

In [64]:
class AttentionType(Enum):
    REGULAR = "regular"
    FLASH = "flash"

@dataclass
class AttentionConfig:
    attention_type: AttentionType
    d_model: int = 512
    num_heads: int = 8
    dropout: float = 0.1

All we need to do is tie together the above steps into one big block!

In [89]:
class ModularAttention(nn.Module):
    def __init__(self, config: AttentionConfig):
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        self.num_heads = config.num_heads
        self.d_k = config.d_model // config.num_heads

        self.w_q = nn.Linear(config.d_model, config.d_model, bias=False)
        self.w_k = nn.Linear(config.d_model, config.d_model, bias=False)
        self.w_v = nn.Linear(config.d_model, config.d_model, bias=False)
        self.w_o = nn.Linear(config.d_model, config.d_model, bias=False)
        self.dropout = config.dropout
        if isinstance(self.dropout, float):
          self.dropout_layer = nn.Dropout(self.dropout)
        elif isinstance(self.dropout, nn.Dropout):
          self.dropout_layer = self.dropout

    def _regular_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         mask: Optional[torch.Tensor]) -> torch.Tensor:
        d_k = q.shape[-1]
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = torch.softmax(scores, dim=-1)
        if self.training:
            attention_weights = torch.dropout(attention_weights, self.dropout, self.training)

        return attention_weights @ v

    def _flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    mask: Optional[torch.Tensor]) -> torch.Tensor:
        try:
            from flash_attn import flash_attn_func

            if not all(x.dtype == torch.float16 for x in [q, k, v]):
                print("Warning: Inputs must be in float16 for Flash Attention")
                return self._regular_attention(q, k, v, mask)

            if not all(x.is_cuda for x in [q, k, v]):
                print("Warning: Inputs must be on CUDA for Flash Attention")
                return self._regular_attention(q, k, v, mask)

            if mask is not None:
                mask = mask.to(dtype=torch.float32)

            dropout_p = self.dropout if self.training else 0.0

            q = q.contiguous()
            k = k.contiguous()
            v = v.contiguous()

            return flash_attn_func(
                q, k, v,
                dropout_p=dropout_p,
                causal=False,
                softmax_scale=None
            )

        except ImportError:
            print("Flash Attention not available, falling back to regular attention")
            return self._regular_attention(q, k, v, mask)
        except Exception as e:
            print(f"Error using Flash Attention: {str(e)}")
            print("Falling back to regular attention")
            return self._regular_attention(q, k, v, mask)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size = query.shape[0]

        q = self.w_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.w_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.w_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        if self.config.attention_type == AttentionType.FLASH:
            attn_output = self._flash_attention(q, k, v, mask)
        else:
            attn_output = self._regular_attention(q, k, v, mask)

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.w_o(attn_output)

#### 👪❓Discussion Question 2:

Work with your breakout room to build a visualization of the above `forward` pass - including dummy inputs.

You can use drawing programs (like [Excalidraw](https://excalidraw.com/)), or write a visualization in code.

> NOTE: LLMs like Claude 3.5 Sonnet is a fantastic tool to produce (and test) visualizations in it's [Web UI](https://claude.ai/new

![attention visualization](./attention_visualization.png)

### Testing MultiHeadAttention

Let's test it out!

In [90]:
def test_attention_mechanisms():
    attention_configs = [
        AttentionConfig(attention_type=AttentionType.REGULAR, d_model=6, num_heads=2, dropout=0.1),
        AttentionConfig(attention_type=AttentionType.FLASH, d_model=6, num_heads=2, dropout=0.1)
    ]

    # Create sequence in float16
    seq = torch.tensor([
        [1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 1.0, 1.0]
    ], dtype=torch.float16).unsqueeze(0).cuda()

    mask = None

    for config in attention_configs:
        try:
            print(f"\nTesting {config.attention_type.value} attention:")
            print("-" * 50)

            mha = ModularAttention(config).cuda().half()

            # Convert model weights to float16
            for param in mha.parameters():
                param.data = param.data.half()

            output = mha(seq, seq, seq, mask)

            print(f"Input shape: {seq.shape}")
            print("Input values:")
            print(seq[0])
            print("\nOutput values:")
            print(output[0])

            assert output.shape == seq.shape
            assert not torch.allclose(output, seq)
            print(f"✓ {config.attention_type.value.title()} Attention Test Passed")

        except ImportError:
            print(f"Flash Attention not available, skipping {config.attention_type.value} test")
        except Exception as e:
            print(f"Error testing {config.attention_type.value} attention: {str(e)}")

        # Cleanup to prevent CUDA memory issues
        torch.cuda.empty_cache()

In [91]:
test_attention_mechanisms()


Testing regular attention:
--------------------------------------------------
Input shape: torch.Size([1, 3, 6])
Input values:
tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)

Output values:
tensor([[ 0.1278, -0.2284,  0.0298, -0.3015,  0.0834,  0.0563],
        [ 0.1577, -0.2035, -0.0206, -0.2812,  0.0796,  0.0403],
        [ 0.0676, -0.1810,  0.0102, -0.2306,  0.0589,  0.0561]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SelectBackward0>)
✓ Regular Attention Test Passed

Testing flash attention:
--------------------------------------------------
Input shape: torch.Size([1, 3, 6])
Input values:
tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)

Output values:
tensor([[ 0.2380, -0.4624,  0.0812,  0.1268,  0.1678, -0.4475],
        [-0.0828,  0.0969,  0.0195, -0.0183, -0.0431,  0.0547],
       

In [92]:
class TransformerBlock(nn.Module):
    def __init__(self,
                 features: int,
                 attn_config: AttentionConfig,
                 feed_forward: nn.Module,
                 dropout: float,
                 is_decoder: bool = False):
        super().__init__()
        self.is_decoder = is_decoder
        self.self_attention = ModularAttention(attn_config)
        self.feed_forward = feed_forward

        num_connections = 3 if is_decoder else 2
        self.residual_connections = nn.ModuleList([
            ResidualConnection(features, dropout) for _ in range(num_connections)
        ])

        if is_decoder:
            self.cross_attention = ModularAttention(attn_config)

    def forward(self, x: torch.Tensor,
                encoder_output: Optional[torch.Tensor] = None,
                self_mask: Optional[torch.Tensor] = None,
                cross_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.residual_connections[0](x, lambda x: self.self_attention(x, x, x, self_mask))

        if self.is_decoder:
            if encoder_output is None:
                raise ValueError("Decoder block requires encoder_output")
            x = self.residual_connections[1](x, lambda x: self.cross_attention(x, encoder_output, encoder_output, cross_mask))
            return self.residual_connections[2](x, self.feed_forward)

        return self.residual_connections[1](x, self.feed_forward)

## Encoder

When we pass information through our model - the first thing we will do is Encode it by passing it through our Encoder Blocks.


### Encoder Block

![image](https://i.imgur.com/nwNYZAT.png)

The encoder takes in the source language sentence (e.g. English). Each word is converted into a vector representation using an embedding layer. Then a positional encoder adds information about the position of each word. This goes through multiple self-attention layers, where each word vector attends to all other word vectors to build contextual representations.

In [93]:
class EncoderBlock(nn.Module):
  def __init__(self, features: int, self_attention_block: MultiHeadAttention, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
    super().__init__()
    self.self_attention_block = self_attention_block
    self.feed_forward_block = feed_forward_block
    self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

  def forward(self, x, input_mask):
    x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, input_mask))
    x = self.residual_connections[1](x, self.feed_forward_block)
    return x

### Testing the EncoderBlock

Testing time!

In [94]:
def test_encoder_blocks():
    # Use float32 for regular attention, float16 only for flash attention
    input_seq = torch.tensor([
        [1.0, 1.0, 0.0, 0.0, 0.0, 0.0],  # "The"
        [0.0, 0.0, 1.0, 1.0, 0.0, 0.0],  # "cat"
        [0.0, 0.0, 0.0, 0.0, 1.0, 1.0]   # "sleeps"
    ]).unsqueeze(0).cuda()

    mask = torch.ones(1, 1, 3, 3, dtype=torch.bool).cuda()

    attention_configs = [
        AttentionConfig(AttentionType.REGULAR, d_model=6, num_heads=2, dropout=0.1),
        AttentionConfig(AttentionType.FLASH, d_model=6, num_heads=2, dropout=0.1)
    ]

    for config in attention_configs:
        try:
            print(f"\nTesting Encoder with {config.attention_type.value} attention:")
            print("-" * 50)

            attention = ModularAttention(config).cuda()
            ff = FeedForwardBlock(d_model=6, d_ff=12).cuda()
            encoder = TransformerBlock(
                features=6,
                attn_config=config,
                feed_forward=ff,
                dropout=0.1
            ).cuda()

            # Convert to float16 only for flash attention
            if config.attention_type == AttentionType.FLASH:
                input_seq = input_seq.half()
                encoder = encoder.half()
                for param in encoder.parameters():
                    param.data = param.data.half()

            output = encoder(input_seq, self_mask=mask)

            print("Input sequence:")
            print(input_seq[0])
            print("\nAfter encoder processing:")
            print(output[0])

            assert output.shape == input_seq.shape
            assert not torch.allclose(output, input_seq)
            print(f"✓ Encoder Block with {config.attention_type.value} attention Test Passed")

        except ImportError:
            print(f"Flash Attention not available, skipping {config.attention_type.value} test")
        except Exception as e:
            print(f"Error testing encoder with {config.attention_type.value} attention: {str(e)}")

        # Reset to float32 for next iteration
        input_seq = input_seq.float()
        torch.cuda.empty_cache()

In [95]:
test_encoder_blocks()


Testing Encoder with regular attention:
--------------------------------------------------
Input sequence:
tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1.]], device='cuda:0')

After encoder processing:
tensor([[ 1.5171,  1.1446,  0.3976, -0.9150, -0.2350, -0.4203],
        [ 0.3453, -0.1121,  1.3127,  0.9664,  0.2854,  0.0253],
        [ 0.1391,  0.6917,  0.1248, -0.3540,  1.3931,  0.9762]],
       device='cuda:0', grad_fn=<SelectBackward0>)
✓ Encoder Block with regular attention Test Passed

Testing Encoder with flash attention:
--------------------------------------------------
Input sequence:
tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)

After encoder processing:
tensor([[ 0.7783,  1.0234, -0.3586, -0.0851, -0.3760, -0.1392],
        [ 0.5840,  0.1870,  0.8926,  1.1494,  0.5498,  0.4712],
        [ 0.0084,  0.0208, -0.6611, -0.3779,  0.

### Encoder Stack

Following along from the original paper - we will organize these blocks into a set of 6.

These 6 Encoder Blocks (each with 8 Attention Heads) will comprise our Encoding Stack.

In [96]:
class EncoderStack(nn.Module):
  def __init__(self, features: int, layers: nn.ModuleList) -> None:
    super().__init__()
    self.layers = layers
    self.norm = LayerNormalization(features)

  def forward(self, x, mask):
    for layer in self.layers:
      x = layer(x, mask)
    return self.norm(x)

## Decoder

Next, we will take the encoded sequence and decode it through our Decoder Blocks.

### Decoder Block

![image](https://i.imgur.com/HtAAXZc.png)

The decoder takes in the target language sentence (e.g. Italian). It also converts words to vectors and adds positional info. Then it goes through self-attention layers. Here, a mask is applied so each word can only see the words before it, not after.

The decoder also does attention over the encoder output. This allows each French word to find relevant connections with the English words.

In [97]:
class DecoderBlock(nn.Module):
  def __init__(self, features: int, self_attention_block: MultiHeadAttention, cross_attention_block: MultiHeadAttention, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
    super().__init__()
    self.self_attention_block = self_attention_block
    self.cross_attention_block = cross_attention_block
    self.feed_forward_block = feed_forward_block
    self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

  def forward(self, x, encoder_output, input_mask, target_mask):
    x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, target_mask))
    x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, input_mask))
    x = self.residual_connections[2](x, self.feed_forward_block)
    return x

### Testing DecoderBlock

You know what's up next...testing!

In [98]:
def test_decoder_blocks():
    x = torch.tensor([
        [1.0, 1.0, 0.0, 0.0, 0.0, 0.0],  # "El"
        [0.0, 0.0, 1.0, 1.0, 0.0, 0.0],  # "gato"
    ]).unsqueeze(0).cuda()

    encoder_output = torch.tensor([
        [1.0, 1.0, 0.0, 0.0, 0.0, 0.0],  # "The"
        [0.0, 0.0, 1.0, 1.0, 0.0, 0.0],  # "cat"
    ]).unsqueeze(0).cuda()

    cross_mask = torch.ones(1, 1, 2, 2, dtype=torch.bool).cuda()
    self_mask = torch.tril(torch.ones(1, 1, 2, 2, dtype=torch.bool)).cuda()

    attention_configs = [
        AttentionConfig(AttentionType.REGULAR, d_model=6, num_heads=2, dropout=0.1),
        AttentionConfig(AttentionType.FLASH, d_model=6, num_heads=2, dropout=0.1)
    ]

    for config in attention_configs:
        try:
            print(f"\nTesting Decoder with {config.attention_type.value} attention:")
            print("-" * 50)

            ff = FeedForwardBlock(d_model=6, d_ff=12).cuda()
            decoder = TransformerBlock(
                features=6,
                attn_config=config,
                feed_forward=ff,
                dropout=0.1,
                is_decoder=True
            ).cuda()

            if config.attention_type == AttentionType.FLASH:
                x = x.half()
                encoder_output = encoder_output.half()
                decoder = decoder.half()
                for param in decoder.parameters():
                    param.data = param.data.half()

            output = decoder(x, encoder_output=encoder_output,
                           self_mask=self_mask, cross_mask=cross_mask)

            print("Input target sequence:")
            print(x[0])
            print("\nSource (encoder) sequence:")
            print(encoder_output[0])
            print("\nDecoder output:")
            print(output[0])

            assert output.shape == x.shape
            assert not torch.allclose(output, x)
            print(f"✓ Decoder Block with {config.attention_type.value} attention Test Passed")

        except ImportError:
            print(f"Flash Attention not available, skipping {config.attention_type.value} test")
        except Exception as e:
            print(f"Error testing decoder with {config.attention_type.value} attention: {str(e)}")

        x = x.float()
        encoder_output = encoder_output.float()
        torch.cuda.empty_cache()

In [99]:
test_decoder_blocks()


Testing Decoder with regular attention:
--------------------------------------------------
Input target sequence:
tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0.]], device='cuda:0')

Source (encoder) sequence:
tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0.]], device='cuda:0')

Decoder output:
tensor([[ 0.8096,  0.7881, -0.4624,  0.0876,  0.0038, -0.4498],
        [-0.2040, -0.0768,  1.0516,  1.4122,  0.2071,  0.0495]],
       device='cuda:0', grad_fn=<SelectBackward0>)
✓ Decoder Block with regular attention Test Passed

Testing Decoder with flash attention:
--------------------------------------------------
Input target sequence:
tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0.]], device='cuda:0', dtype=torch.float16)

Source (encoder) sequence:
tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0.]], device='cuda:0', dtype=torch.float16)

Decoder output:
tensor([[ 0.5659,  0.4185,  0.1379,  0.2859,  0.2703, -0.4829],
 

### Decoder Stack

We'll use the same number of Decoder Blocks as we did Encoder Blocks - leaving us with 6 Deocder Blocks in our Decoder Stack.

In [100]:
class DecoderStack(nn.Module):
  def __init__(self, features: int, layers: nn.ModuleList) -> None:
    super().__init__()
    self.layers = layers
    self.norm = LayerNormalization(features)

  def forward(self, x, encoder_output, input_mask, target_mask):
    for layer in self.layers:
      x = layer(x, encoder_output, input_mask, target_mask)
    return self.norm(x)

## Linear Projection Layer

After the decoder's self-attention and encoder-decoder attention layers, we have a context vector representing each Italian word. This context vector has a high dimension (e.g. 512 or 1024).

We want to take this context vector and generate a probability distribution over the French vocabulary so we can pick the next translated word.

The linear projection layer helps with this. It projects the context vector into a much larger vector called the vocabulary distribution - one entry per word in the vocabulary.

For example, if our Italian vocabulary has 50,000 words, the vocabulary distribution will have 50,000 dimensions. Each dimension corresponds to the probability of that Italian word being the correct translation.

In [101]:
class LinearProjectionLayer(nn.Module):
  def __init__(self, d_model, vocab_size) -> None:
    super().__init__()
    self.proj = nn.Linear(d_model, vocab_size)

  def forward(self, x) -> None:
    return self.proj(x)

## The Transformer

At this point, all we need to do is create a class that represents our model!

In [102]:
class Transformer(nn.Module):
  def __init__(self, encoder: EncoderBlock, decoder: DecoderBlock, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: LinearProjectionLayer) -> None:
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.src_embed = src_embed
    self.tgt_embed = tgt_embed
    self.src_pos = src_pos
    self.tgt_pos = tgt_pos
    self.projection_layer = projection_layer

  def encode(self, src, src_mask):
    src = self.src_embed(src)
    src = self.src_pos(src)
    return self.encoder(src, src_mask)

  def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
    tgt = self.tgt_embed(tgt)
    tgt = self.tgt_pos(tgt)
    return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

  def project(self, x):
    return self.projection_layer(x)

## Building Our Transformer

Now that we have each of our components - we need to construct an actual model!

We'll use this helper function to aid in our goal and set up our Encoder/Decoder Stacks!

In [103]:
def build_transformer(config: dict, attention_type: AttentionType = AttentionType.REGULAR) -> nn.Module:
    attn_config = AttentionConfig(
        attention_type=attention_type,
        d_model=config['d_model'],
        num_heads=config.get('num_heads', 8),
        dropout=config.get('dropout', 0.1)
    )

    input_embeddings = InputEmbeddings(config['d_model'], config['input_vocab_size'])
    target_embeddings = InputEmbeddings(config['d_model'], config['target_vocab_size'])

    input_position = PositionalEncoding(config['d_model'], config['seq_len'], config.get('dropout', 0.1))
    target_position = PositionalEncoding(config['d_model'], config['seq_len'], config.get('dropout', 0.1))

    encoder_blocks = [
        TransformerBlock(
            features=config['d_model'],
            attn_config=attn_config,
            feed_forward=FeedForwardBlock(config['d_model'], config.get('d_ff', 2048)),
            dropout=config.get('dropout', 0.1)
        ) for _ in range(config.get('N', 6))
    ]

    decoder_blocks = [
        TransformerBlock(
            features=config['d_model'],
            attn_config=attn_config,
            feed_forward=FeedForwardBlock(config['d_model'], config.get('d_ff', 2048)),
            dropout=config.get('dropout', 0.1),
            is_decoder=True
        ) for _ in range(config.get('N', 6))
    ]

    model = Transformer(
        encoder=EncoderStack(config['d_model'], nn.ModuleList(encoder_blocks)),
        decoder=DecoderStack(config['d_model'], nn.ModuleList(decoder_blocks)),
        src_embed=input_embeddings,
        tgt_embed=target_embeddings,
        src_pos=input_position,
        tgt_pos=target_position,
        projection_layer=LinearProjectionLayer(config['d_model'], config['target_vocab_size'])
    )

    # Initialize parameters
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model

# 🤝 BREAKOUT ROOM #2:

## Benchmarking Flash Attention Against Naive Attention

In this Breakout Room, we're going to explore the differences in Flash Attention and Naive Attention by way of benchmarking one against the other.

We'll start with some verification helper functions to make sure things are working as expected.


### Task 1: Verification Functions

Let's start by making sure we *can* use Flash Attention.

Then confirm that we have access to a GPU environment - and then confirm we're running in reduced precision (required for Flash Attention).

In [104]:
def check_flash_attn_available():
    try:
        import flash_attn
        return True
    except ImportError:
        return False

def verify_flash_attention(model):
    """Verify Flash Attention is actually being used"""
    if not check_flash_attn_available():
        print("Flash Attention is not available")
        return False

    # Check if CUDA is available
    if not torch.cuda.is_available():
        print("CUDA is required for Flash Attention")
        return False

    # Verify model precision
    if not any(p.dtype == torch.float16 for p in model.parameters()):
        print("Model needs to be in float16 for Flash Attention")
        return False

    return True

def profile_attention_call(model, input_data):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()
    start_event.record()

    output = model(input_data)

    end_event.record()
    torch.cuda.synchronize()

    return start_event.elapsed_time(end_event)

### Task 2: Implement Benchmarking Suite

This step should be easy enough - we just need to set up a class that can help us benchmark our model using the two forms of Attention we care about!

#### 🏗️ Activity #1:

Implement the single run benchmark code!

In [105]:
import time
import pandas as pd
from typing import List, Tuple
import numpy as np
import gc
from tqdm import tqdm
from dataclasses import dataclass
import torch

@dataclass
class BenchmarkResult:
    time_ms: float

class AttentionBenchmark:
    def __init__(
        self,
        seq_lengths: List[int] = [128, 256, 512, 1024, 2048, 4096],
        num_trials: int = 3,
        num_warmup: int = 5,
        num_iterations: int = 100,
        device: str = "cuda"
    ):
        self.seq_lengths = seq_lengths
        self.num_trials = num_trials
        self.num_warmup = num_warmup
        self.num_iterations = num_iterations
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

    def _clear_memory(self):
        """Clear GPU and CPU memory."""
        if self.device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    def _run_single_test(self, attention_type, seq_len: int, d_model: int = 256, num_heads: int = 4) -> Tuple[torch.Tensor, BenchmarkResult]:
        """Run a single test for the specified attention type and sequence length."""

        ### YOUR CODE HERE
        self._clear_memory()

        if self.device.type == "cuda":
            torch.cuda.reset_max_memory_allocated()

        batch_size = 1

        x = torch.randn(batch_size, seq_len, d_model).to(self.device)
        mask = torch.ones(batch_size, 1, seq_len, seq_len).bool().to(self.device)

        attn_config = AttentionConfig(
            attention_type=attention_type,
            d_model=d_model,
            num_heads=num_heads,
            dropout=0.1
        )
        model = ModularAttention(attn_config).to(self.device)

        if attention_type == AttentionType.FLASH:
            if not check_flash_attn_available():
                print(f"Flash Attention not available for seq_len {seq_len}. Skipping...")
                return None, None
            model = model.half()
            x = x.half()
            if not verify_flash_attention(model):
                print(f"Flash Attention verification failed for seq_len {seq_len}. Skipping...")
                return None, None

        try:
            for _ in range(self.num_warmup):
                _ = model(x, x, x, mask)
                torch.cuda.synchronize()

            total_time_ms = 0

            for _ in range(self.num_iterations):
                torch.cuda.synchronize()
                start_time = time.perf_counter()

                _ = model(x, x, x, mask)

                torch.cuda.synchronize()
                end_time = time.perf_counter()

                iteration_time_ms = (end_time - start_time) * 1000
                total_time_ms += iteration_time_ms

            avg_time_ms = total_time_ms / self.num_iterations

            del x, mask, model
            self._clear_memory()

            return None, BenchmarkResult(time_ms=avg_time_ms)

        except Exception as e:
            print(f"Error during {attention_type.value} test for seq_len {seq_len}: {str(e)}")
            return None, None


    def run_benchmark(self) -> pd.DataFrame:
        """Run the complete benchmark suite."""
        results = []

        for seq_len in tqdm(self.seq_lengths, desc="Testing sequence lengths"):
            batch_size = 4
            trial_results = {attention_type: [] for attention_type in AttentionType}

            for _ in range(self.num_trials):
                for attention_type in AttentionType:
                    _, result = self._run_single_test(
                        attention_type=attention_type,
                        seq_len=seq_len
                    )
                    if result is None:
                        print(f"Skipping {attention_type.value} for seq_len {seq_len} due to failure.")
                        continue

                    trial_results[attention_type].append(result)

                self._clear_memory()

            for attention_type in AttentionType:
                trial_data = trial_results[attention_type]
                if not trial_data:
                    continue

                results.append({
                    'Sequence Length': seq_len,
                    'Attention Type': attention_type.value,
                    'Batch Size': batch_size,
                    'Time per Iteration (ms)': np.mean([r.time_ms for r in trial_data]),
                    'Time Std Dev (ms)': np.std([r.time_ms for r in trial_data])
                })

            self._clear_memory()

        if not results:
            print("No successful benchmark results obtained.")
            return pd.DataFrame()

        return pd.DataFrame(results)

#### 🏗️ Activity #2:

Write out, in natural language, how this benchmark class works!

The benchmark class computes average time taken to compute attention outputs for different types of attentions. It is parameterized on a few parameters:
1. `sequence_lengths` - The sequence lengths over which different types of attentions are computed. The time taken for computing the forward pass is measured.
2. `num_trials` - The number of times computation is repeated for each sequence length
3. `num_warmup` - Number of times we run a "warmup" computation to ensure the GPU is ready for more accurate measurements.
4. `num_iterations` - The number of times attention is computed for a given sequence length. The computed time is averaged over `num_iterations` to get a more accurate measurement of time.

Essentially, the benchmark class loops over **every sequence length** then computes regular attention and flash attention using dummy inputs of the same sequence length *several times* to get an accurate measurement of the time taken to compute attention of the specific type for that sequence length. Several measurements are taken over the same sequence length for a better average measurement of the time taken. Threafter, we store the time taken to compute attion for each sequence length and attention type. These result tuples of sequence length, attention type, batch size, measured time and standard deviation for measured time are then stored in a pandas dataframe.

To avoid inconsistency in measurement while the GPU "warms up", several warm-up iterations are performed which are not measured as part of the average time taken.

### Task 3: Run Benchmark!

Easily enough, all we need to do is actually fire the benchmark off!

In [106]:
benchmark = AttentionBenchmark(
    seq_lengths=[128, 256, 512, 1024, 2048, 4096, 8192, 16392],
    num_trials=5,
)
results_df = benchmark.run_benchmark()

Testing sequence lengths: 100%|██████████| 8/8 [01:12<00:00,  9.07s/it]


### Task 4: Plotting the Results

Now that we have a benchmark created - we'll use Plotly to...well, plot our results!

In [107]:
import plotly.graph_objects as go

def plot_benchmark_results(df: pd.DataFrame) -> go.Figure:
    """Visualize benchmark results focusing on time per iteration."""
    fig = go.Figure()

    # Colors for different attention types
    colors = {
        'regular': 'blue',
        'flash': 'red'
    }

    # Plot: Time per Iteration vs Sequence Length
    for att_type in df['Attention Type'].unique():
        data = df[df['Attention Type'] == att_type]
        fig.add_trace(
            go.Scatter(
                x=data['Sequence Length'],
                y=data['Time per Iteration (ms)'],
                name=f"{att_type.title()} Time",
                line=dict(color=colors.get(att_type, 'black')),
                mode='lines+markers'
            )
        )

    # Update layout
    fig.update_layout(
        title='Attention Mechanism Benchmark Results',
        xaxis_title='Sequence Length',
        yaxis_title='Time per Iteration (ms)',
        height=600,
        showlegend=True,
        hovermode='x unified'
    )

    return fig


In [108]:
fig = plot_benchmark_results(results_df)
fig.show()

![benchmark results](./benchmark_results.png)

#### 👪❓Discussion Question 3:

Explore with your group what these results show!


#### Regular Attention
The time taken for regular attention grows very quickly with sequence length in piecewise linear chunks, although the overall trend looks quadratic. As the sequence length increases, the time taken to compute attention also grows quickly.

#### Flash Attention
While the time taken for flash attention does grow with the sequence length, it grows much more slowly compared to linear attention. In fact on the comparative plot above, it *almost* looks flat (although it is growing).

#### Acknowledgements

This notebook is heavily adapted from a number of incredible resources on Transformers, including but not limited to:

- https://blog.floydhub.com/the-transformer-in-pytorch/
- https://arxiv.org/pdf/1706.03762.pdf
- https://txt.cohere.com/what-are-transformer-models/
- https://jalammar.github.io/illustrated-transformer/