# Attention Mechanism



In [None]:
%load_ext watermark
%watermark -v -p numpy,pandas,polars,torch,lightning --conda

In [None]:
# Built-in library
from pathlib import Path
import re
import json
from typing import Any, Optional, Union
import logging
import warnings

# Standard imports
import numpy as np
import numpy.typing as npt
from pprint import pprint
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

# Rich console for styled output; the theme defines three named styles
# ("info", "warning", "error") mapped to hex colours.
_theme_colours: dict = {
    "info": "#76FF7B",
    "warning": "#FBDDFE",
    "error": "#FF0000",
}
custom_theme = Theme(_theme_colours)
console = Console(theme=custom_theme)

# Visualization
import matplotlib.pyplot as plt


# Pandas settings
pd.set_option("display.max_rows", 1_000)
pd.set_option("display.max_columns", 1_000)
pd.set_option("display.max_colwidth", 600)

# NOTE(review): this silences *every* warning, which may include PyTorch
# deprecation warnings (e.g. for `Tensor.T` on non-2-D tensors used later in
# this notebook). Consider a narrower filter with an explicit `category=`.
warnings.filterwarnings("ignore")


# Black code formatter (Optional)
%load_ext lab_black

# auto reload imports
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch import nn, Tensor
import torch.nn.functional as F

### Context Vector

- A context vector is the weighted sum of the input vectors. It captures the information from the entire sequence that is relevant to a given position, i.e. it can be thought of as an enriched embedding vector of the input token.

#### Calculate Context Vector

- Attention Score: 
  - It's calculated as the dot product of the token's query vector with the key vector of every token in the sequence (including the token itself).
  - The scores are normalized using softmax to produce the attention weights.
- Multiply the embedded input tokens with their corresponding attention weights and sum the resulting vectors to get the context vector.
- This is done for each position in the sequence to get the context vector for the entire sequence.

## Without Trainable Parameters (Simplified Version)

In [None]:
seed: int = 5

# Toy input: the 6 tokens of "Your journey starts with one step", each
# embedded as a 3-dimensional vector. Row i holds embedding x^(i+1).
inputs: Tensor = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your (x^1)
        [0.55, 0.87, 0.66],  # journey (x^2)
        [0.57, 0.85, 0.64],  # starts (x^3)
        [0.22, 0.58, 0.33],  # with (x^4)
        [0.77, 0.25, 0.10],  # one (x^5)
        [0.05, 0.80, 0.55],  # step (x^6)
    ]
)

# Context vector for the 2nd token (x^2), step 1: attention scores.
# In this simplified version x^2 itself is the query, and the score for each
# token is the dot product between that token's embedding and the query.
query: Tensor = inputs[1]
attn_scores_index_1: Tensor = torch.stack(
    [torch.dot(token_embedding, query) for token_embedding in inputs]
)

print(f"{attn_scores_index_1 = }")

In [None]:
# 2: Turn the attention scores into attention weights with softmax; the
# weights are positive and sum to 1.
attn_scores_weights_1: Tensor = attn_scores_index_1.softmax(dim=-1)
print(f"{attn_scores_weights_1 = }")
# Sanity check: the displayed sum should be 1.
attn_scores_weights_1.sum(dim=-1)

In [None]:
attn_scores_weights_1_shape: tuple = tuple(attn_scores_weights_1.size())
inputs_shape: tuple = tuple(inputs.size())
print(f"{attn_scores_weights_1_shape = } AND {inputs_shape = }")

# 3: The context vector is the attention-weighted sum of the input embeddings:
# (6,) @ (6, 3) -> (3,). No transpose is needed: the weight vector's length
# already matches the number of rows (tokens) in `inputs`.
context_vector_1: Tensor = torch.matmul(attn_scores_weights_1, inputs)
context_vector_1

#### Calculate The Attention Weights Of The Sequence

In [None]:
# Step 1: attention scores for every (query, key) pair in one matmul:
# (6, 3) @ (3, 6) -> (6, 6), where entry [i, j] = dot(x^(i+1), x^(j+1)).
print(f"{inputs.shape = } AND {inputs.T.shape = }")
attn_scores: Tensor = torch.matmul(inputs, inputs.T)
print(f"\n{attn_scores.shape = }")

attn_scores

In [None]:
# Step 2: row-wise softmax (dim=-1, i.e. across the keys of each query)
# turns each row of scores into attention weights that sum to 1.
attn_weights: Tensor = attn_scores.softmax(dim=-1)
print(f"\n{attn_weights.shape = }")
attn_weights

In [None]:
# Step 3: row i of the result is the weighted sum of all input embeddings,
# i.e. the context vector of token i: (6, 6) @ (6, 3) -> (6, 3).
print(f"{attn_weights.shape = } AND {inputs.shape = }")

context_vector: Tensor = torch.matmul(attn_weights, inputs)
context_vector

<br><hr>

## Implement Self-Attention With Trainable Parameters

- AKA **Scaled Dot-Product Attention**
- Add weight matrices that are updated during training.
- The attention scores are scaled by the square root of the key/query dimension before the softmax. This keeps the softmax from saturating, which would otherwise produce very small gradients and slow down training.

In [None]:
# Walk through the trainable version for a single token (the 2nd one, x^2).
x_1: Tensor = inputs[1]
print(f"{x_1.shape = }")
d_in: int = x_1.size(-1)  # size of each input embedding
d_out: int = 2  # size of the projected query/key/value vectors

In [None]:
torch.manual_seed(seed)

# Projection matrices of shape (d_in, d_out). requires_grad=False only keeps
# the printed output uncluttered; in a real model they would be trained.
# The creation order fixes which random draw each matrix receives.
W_query: Tensor = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_key: Tensor = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_value: Tensor = nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)

# Query/key/value of the single token x^2: (d_in,) @ (d_in, d_out) -> (d_out,)
query_1: Tensor = x_1 @ W_query
key_1: Tensor = x_1 @ W_key
value_1: Tensor = x_1 @ W_value

# Query/key/value of every token: (6, d_in) @ (d_in, d_out) -> (6, d_out)
query: Tensor = inputs @ W_query
key: Tensor = inputs @ W_key
value: Tensor = inputs @ W_value

print(f"{query_1.shape = }")
query_1

In [None]:
# Projected queries of all 6 tokens: inputs (6, d_in) @ W_query -> (6, d_out).
query

In [None]:
print(f"{query_1.shape = } | {key_1.shape = } | {value_1.shape = }")
print(f"{query.shape = } | {key.shape = } | {value.shape = }")
print()

# Attention score = query . key.
# The single query of x^2 against all 6 keys: (d_out,) @ (d_out, 6) -> (6,)
attn_score_1: Tensor = query_1 @ key.T

# Every query against every key: (6, d_out) @ (d_out, 6) -> (6, 6)
attn_scores: Tensor = query @ key.T

print(f"{attn_score_1.shape = } | {attn_score_1 = }")
print()
print(f"{attn_scores.shape = } | {attn_scores = }")

In [None]:
# Scaled attention weights: divide the scores by sqrt(d_out) before the
# softmax. Unscaled, large dot products push the softmax towards a near
# one-hot output whose gradients are very small.
attn_weights_1: Tensor = attn_score_1.div(d_out**0.5).softmax(dim=-1)
attn_weights: Tensor = attn_scores.div(d_out**0.5).softmax(dim=-1)


print(f"{attn_weights_1.shape =} | {attn_weights_1 = }")
print()
print(f"{attn_weights.shape = } | {attn_weights = }")

In [None]:
# Reminder of the raw input shape: (num_tokens, d_in).
inputs.shape

In [None]:
print(f"{attn_weights_1.shape = } | {value.shape = }")
print(f"{attn_weights.shape = } | {value.shape = }")

# Context vectors = attention-weighted sums of the value vectors.
# (6,) @ (6, d_out) -> (d_out,)
context_vector_1: Tensor = torch.matmul(attn_weights_1, value)
# (6, 6) @ (6, d_out) -> (6, d_out)
context_vector: Tensor = torch.matmul(attn_weights, value)
print()
print(f"{context_vector_1 = }\n\n")

print(f"{context_vector = }\n\n")

### Query, Key and Value

- **Query** : The query is analogous to a `search` in a `database`. It represents the current item/token the model focuses on.
- **Key** : The key is analogous to the `index` in a `database`. It represents the item/token that the model compares the query to.
- **Value** : The value is analogous to the `value` in a `key-value` pair. It represents the actual content or representation of the item/token.

In [None]:
class SelfAttention_v1(nn.Module):
    """Scaled dot-product self-attention with hand-made `nn.Parameter` weights.

    Each projection is a (d_in, d_out) matrix applied as ``x @ W``. When
    ``qkv_bias`` is True, a separate (d_out,) bias vector is added after each
    projection (``x @ W + b``), mirroring what ``nn.Linear`` does.

    Args:
        d_in: Size of each input token embedding (last dim of ``x``).
        d_out: Size of the query/key/value vectors and of the output.
        qkv_bias: Whether to add a trainable bias to the query, key and value
            projections.
    """

    def __init__(self, d_in: int, d_out: int, qkv_bias: bool = False) -> None:
        super().__init__()

        self.d_out = d_out

        self.W_query = self._init_weights(d_in, d_out)
        self.W_key = self._init_weights(d_in, d_out)
        self.W_value = self._init_weights(d_in, d_out)

        # A bias must be a separate vector added *after* the matmul: adding a
        # random vector to the weight matrix itself only yields another weight
        # matrix, i.e. no bias term at all. Registering `None` when disabled
        # mirrors nn.Linear(bias=False) and keeps `self.b_*` always defined.
        for bias_name in ("b_query", "b_key", "b_value"):
            bias = nn.Parameter(torch.randn(d_out)) if qkv_bias else None
            self.register_parameter(bias_name, bias)

    def forward(self, x: Tensor) -> Tensor:
        """Return one context vector per token.

        Args:
            x: Token embeddings whose last dim is ``d_in``: either
                (num_tokens, d_in) or batched (batch, num_tokens, d_in).

        Returns:
            Context vectors with the same leading dims as ``x`` and last dim
            ``d_out``.
        """
        queries: Tensor = self._project(x, self.W_query, self.b_query)
        keys: Tensor = self._project(x, self.W_key, self.b_key)
        values: Tensor = self._project(x, self.W_value, self.b_value)
        # Swap only the last two axes; `keys.T` reverses *all* axes and breaks
        # batched (3-D) inputs.
        attn_scores: Tensor = queries @ keys.transpose(-2, -1)
        attn_weights: Tensor = torch.softmax(attn_scores / self.d_out**0.5, dim=-1)
        context_vector: Tensor = torch.matmul(attn_weights, values)
        return context_vector

    @staticmethod
    def _project(x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        """Compute ``x @ weight`` and add ``bias`` when one is configured."""
        projected = torch.matmul(x, weight)
        return projected if bias is None else projected + bias

    @staticmethod
    def _init_weights(d_in: int, d_out: int, qkv_bias: bool = False) -> nn.Parameter:
        """Return a trainable (d_in, d_out) weight matrix drawn from N(0, 1).

        ``qkv_bias`` is accepted only for backward compatibility and is
        ignored; bias vectors are created separately in ``__init__``.
        """
        return nn.Parameter(torch.randn(d_in, d_out))

In [None]:
# Re-seed so the random weight initialisation inside the module is reproducible.
torch.manual_seed(seed)

self_attn_v1 = SelfAttention_v1(d_in, d_out)
print(f"{self_attn_v1 = }")
# One d_out-dimensional context vector per input token.
print(self_attn_v1(inputs))

### Update

- Improve the `SelfAttention_v1` implementation by using PyTorch's `nn.Linear` layers instead of manually created `nn.Parameter` weight matrices.

- This is because:
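  - With the bias disabled, `nn.Linear` performs the same matrix multiplication (it computes `x @ W.T`), so the two implementations are equivalent.
  - `nn.Linear` has a better-scaled default weight initialization scheme than plain `torch.randn`, which leads to more stable training.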

In [None]:
class SelfAttention_v2(nn.Module):
    """Scaled dot-product self-attention using `nn.Linear` projections.

    Args:
        d_in: Size of each input token embedding (last dim of ``x``).
        d_out: Size of the query/key/value vectors and of the output.
        qkv_bias: Whether the query/key/value projections have a bias.
    """

    def __init__(self, d_in: int, d_out: int, qkv_bias: bool = False) -> None:
        super().__init__()

        self.d_out = d_out

        # nn.Linear stores its weight as (d_out, d_in) and computes x @ W.T (+ b).
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x: Tensor) -> Tensor:
        """Return one context vector per token.

        Args:
            x: Token embeddings whose last dim is ``d_in``: either
                (num_tokens, d_in) or batched (batch, num_tokens, d_in).

        Returns:
            Context vectors with the same leading dims as ``x`` and last dim
            ``d_out``.
        """
        queries: Tensor = self.W_query(x)
        keys: Tensor = self.W_key(x)
        values: Tensor = self.W_value(x)
        # Swap only the last two axes; `keys.T` reverses *all* axes and breaks
        # batched (3-D) inputs.
        attn_scores: Tensor = queries @ keys.transpose(-2, -1)
        attn_weights: Tensor = torch.softmax(attn_scores / self.d_out**0.5, dim=-1)
        context_vector: Tensor = torch.matmul(attn_weights, values)

        return context_vector

In [None]:
# Re-seed so nn.Linear's random weight initialisation is reproducible.
torch.manual_seed(seed)
self_attn_v2 = SelfAttention_v2(d_in, d_out, qkv_bias=False)
print(f"{self_attn_v2 = }")
# One d_out-dimensional context vector per input token.
print(self_attn_v2(inputs))

In [None]:
# v1 multiplies by a (d_in, d_out) matrix, while nn.Linear stores its weight as
# (out_features, in_features) = (d_out, d_in); compare the printed shapes.
print(f"{self_attn_v1.W_query.shape = } | {self_attn_v2.W_query.weight.shape = }")
self_attn_v1.W_query, self_attn_v2.W_query.weight

In [None]:
# Transposing v2's (d_out, d_in) weight yields the (d_in, d_out) layout v1 uses.
nn.Parameter(self_attn_v2.W_query.weight.T)

In [None]:
# Load v2's projection weights into v1 (transposed, because nn.Linear stores
# (d_out, d_in) while v1 multiplies by (d_in, d_out)). With identical weights
# and no biases, v1 should reproduce v2's context vectors printed above.
for proj_name in ("W_query", "W_key", "W_value"):
    linear_layer = getattr(self_attn_v2, proj_name)
    setattr(self_attn_v1, proj_name, nn.Parameter(linear_layer.weight.T))

self_attn_v1(inputs)

### Hiding Future Words With Causal Attention

- Causal attention AKA Masked Attention.
- It restricts the model to only attend to past and current tokens in the input sequence.

<img src="../08-Makemore/images/causal-attention.png" width="600">

[image source](https://livebook.manning.com/book/build-a-large-language-model-from-scratch/chapter-3/v-7/198)

### Applying A Causal Attention Mask

- One way of obtaining causally masked attention weights is to apply softmax to the attention scores, zero out the elements above the diagonal, and then re-normalize each row so that it sums to 1 again.

In [None]:
torch.manual_seed(seed)

# Step 1: a freshly seeded SelfAttention_v2 supplies the projection layers.
self_attn_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)

# Step 2: project the inputs to queries, keys and values, each (6, d_out).
queries: Tensor = self_attn_v2.W_query(inputs)
keys: Tensor = self_attn_v2.W_key(inputs)
values: Tensor = self_attn_v2.W_value(inputs)

print(f"{queries.shape=} | {keys.shape=} | {values.shape=}")

# Step 3: scaled but still *unmasked* attention weights (masking comes next).
attn_scores: Tensor = queries @ keys.T
attn_weights: Tensor = (attn_scores / self_attn_v2.d_out**0.5).softmax(dim=-1)
attn_weights

In [None]:
# Step 4: lower-triangular mask of ones; row i keeps positions 0..i.
context_length: int = attn_scores.shape[1]
mask_simple: Tensor = torch.ones(context_length, context_length).tril()
print(f"mask_simple: \n{mask_simple }\n")

# Step 5: zero out the weights of future tokens (above the diagonal).
masked_simple: Tensor = attn_weights.mul(mask_simple)
print(f"masked_simple: \n{masked_simple }\n")

# Step 6: re-normalize so that every row sums to 1 again.
row_sums: Tensor = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm: Tensor = masked_simple.div(row_sums)
print(f"masked_simple_norm: \n{masked_simple_norm }\n")

### A More Efficient Implementation of Causal Self-Attention

- Calculate the attention scores.
- Create a mask that marks the positions above the diagonal (the future tokens).
- Replace the masked attention scores with `-inf`.
- Calculate the attention weights by applying softmax. `-inf` becomes a weight of exactly 0, so no re-normalization is needed.
- Calculate the context vector by multiplying the attention weights with the value vectors.

In [None]:
# Upper-triangular mask: 1 strictly above the diagonal (future tokens),
# 0 on and below it.
mask: Tensor = torch.ones(context_length, context_length).triu(diagonal=1)

# Replace the scores of future tokens (mask == True) with -inf and keep the
# others. softmax maps -inf to a weight of exactly 0, so no re-normalization
# step is needed afterwards.
masked: Tensor = attn_scores.masked_fill(mask.bool(), float("-inf"))

# Scale by sqrt(d_k) and normalize each row (last dim of this 2-D matrix).
attn_weights: Tensor = masked.div(keys.shape[-1] ** 0.5).softmax(dim=-1)

print(f"mask: \n{mask}\n")
print(f"mask.bool(): \n{mask.bool()}\n")
print(f"masked: \n{masked}\n")
print(f"attn_weights: \n{attn_weights}\n")

### Masking Additional Attention Weights With Dropout

- `Dropout` in deep learning is a technique to prevent overfitting. It works by randomly turning off some neurons in a layer during training.
- This forces the network to learn features that are independent of any specific neuron, making it more robust and adaptable to unseen data.
- Dropout is `ONLY` used during training and is turned off during evaluation or inference.
- Dropout can be applied in transformer architectures in the following phases:
  - after calculating the attention scores.
  - after applying the softmax to normalize the attention scores (after computing the `attn_weights`).

In [None]:
torch.manual_seed(seed)
# Probability that each element is zeroed.
dropout_pct: float = 0.5
dropout = nn.Dropout(dropout_pct)
example: Tensor = torch.ones(6, 6)
print(f"example: \n{example}\n")
# A freshly built module is in training mode, so dropout is active: surviving
# elements are scaled by 1 / (1 - p) = 2 to keep the expected sum unchanged.
print(f"dropout: \n{dropout(example)}\n")
# Applied to the causal attention weights from the previous cell. Each call
# draws a new random mask, so the order of these calls affects the output.
print(f"dropout_attn_weight: \n{dropout(attn_weights)}\n")

In [None]:
# Simulate a batch by stacking the same 6-token input twice:
# (6, 3) -> (2, 6, 3) = (batch, num_tokens, d_in).
batch_input: Tensor = torch.stack((inputs, inputs))
print(f"{batch_input.shape = }\n")
print(f"batch_input: \n{batch_input}\n")

In [None]:
# Why attention code uses `transpose(-1, -2)` instead of `.T` on batched data:
# `Tensor.T` reverses *every* axis and is deprecated for tensors with ndim != 2,
# so its effect is shown here with the equivalent explicit `permute`.
new: Tensor = torch.randn((2, 6, 3), dtype=torch.float32)
print(f"{new.shape = }\n")
print(f"new: {new}\n")
# Reversing all axes: (2, 6, 3) -> (3, 6, 2); the batch axis is no longer first.
print(f"new.permute(2, 1, 0) (what `new.T` does): \n{new.permute(2, 1, 0).shape}\n")
# Swapping only the last two axes keeps the batch axis: (2, 6, 3) -> (2, 3, 6).
print(f"new.transpose(-1, -2): \n{new.transpose(-1, -2).shape}\n")
print(
    f"Shape `new.transpose(-1, -2)`: {tuple(new.transpose(-1, -2).shape)}\n{new.transpose(-1, -2)}\n"
)

In [None]:
class CausalAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Each token attends only to itself and to earlier tokens: scores above the
    diagonal are set to -inf before the softmax, so their weights become 0.

    Args:
        d_in: Size of each input token embedding (last dim of ``x``).
        d_out: Size of each output context vector.
        context_length: Maximum sequence length supported by the mask.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections have a bias.
    """

    def __init__(
        self,
        d_in: int,
        d_out: int,
        context_length: int,
        dropout: float = 0.0,
        qkv_bias: bool = False,
    ) -> None:
        super().__init__()

        self.d_out = d_out
        self.dropout = nn.Dropout(dropout)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # self.mask: ones strictly above the diagonal mark the "future" tokens
        # that must not be attended to. Registered as a buffer so it follows
        # the module across devices without being a trainable parameter.
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return one causal context vector per token.

        Args:
            x: Token embeddings, batched (batch, num_tokens, d_in) or
                unbatched (num_tokens, d_in); num_tokens must be
                <= context_length.

        Returns:
            Context vectors with the same leading dims as ``x`` and last dim
            ``d_out``.
        """
        # Sequence length is the second-to-last axis for both 2-D and 3-D input.
        num_tokens: int = x.shape[-2]

        queries: Tensor = self.W_query(x)
        keys: Tensor = self.W_key(x)
        values: Tensor = self.W_value(x)
        # Swap only the last two dimensions so batched inputs stay aligned.
        attn_scores: Tensor = queries @ keys.transpose(-1, -2)
        # In-place fill of the future positions; the mask is truncated to the
        # actual sequence length.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], float("-inf")
        )
        attn_weights: Tensor = torch.softmax(attn_scores / self.d_out**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vector: Tensor = attn_weights @ values

        return context_vector

In [None]:
torch.manual_seed(seed)

d_in: int = 3
d_out: int = 2
batch_input: Tensor = torch.stack([inputs, inputs], dim=0)  # (2, 6, 3)
context_length: int = batch_input.shape[1]  # tokens per sequence
# NOTE: this rebinds `dropout` (an nn.Dropout module in an earlier cell) to a
# plain probability; 0.0 disables dropout so the output is deterministic.
dropout: float = 0.0

causal_attn = CausalAttention(d_in, d_out, context_length, dropout, qkv_bias=False)
print(f"causal_attn: \n{causal_attn}\n")

# One d_out-dimensional context vector per token per sequence.
context_vectors: Tensor = causal_attn(batch_input)
print(f"context_vectors: \n{context_vectors.shape}\n")

In [None]:
# Causal context vectors: (batch, num_tokens, d_out).
context_vectors

### Extending Single-head Attention To Multi-head Attention

<img src="../08-Makemore/images/multi-head.png" width="600">

[image source](https://livebook.manning.com/book/build-a-large-language-model-from-scratch/chapter-3/v-7/282)

<br>

```text
Source: Gemini

Multi-head attention is a core component in Transformer models that allows them to focus on specific parts of an input sequence. Here's a breakdown:

- Attention mechanism: It attends to relevant parts of the sequence, like focusing on specific words in a sentence.

- Multiple heads: Instead of just one attention mechanism, it has multiple "heads" working in parallel.

- Different perspectives: Each head learns to attend to the sequence from a slightly different angle, capturing various relationships between words.

- Combined power: The outputs from all heads are combined, giving the model a richer understanding of the sequence.

Think of it like having multiple analysts examining the same text. Each analyst focuses on slightly different aspects, and by combining their insights, you get a more comprehensive understanding of the content. This allows Transformers to deal with complex relationships within sequential data.
```

- The goal of multi-head attention is to run the attention mechanism multiple times in parallel, each time with different learned linear projections (i.e. each head projects the input with its own learned query, key and value weight matrices).

- In code, it can be achieved by implementing a simple `MultiHeadAttentionWrapper` class that stacks multiple instances of the previously implemented `CausalAttention` class.

In [None]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent `CausalAttention` heads.

    Every head owns its query/key/value projections. The heads' outputs are
    concatenated along the last axis, so the output size is
    ``d_out * num_heads``.

    Args:
        d_in: Size of each input token embedding.
        d_out: Output size of *each* head.
        context_length: Maximum sequence length supported by the causal mask.
        num_heads: Number of heads to stack.
        dropout: Dropout probability applied to each head's attention weights.
        qkv_bias: Whether the query/key/value projections have a bias.
    """

    def __init__(
        self,
        d_in: int,
        d_out: int,
        context_length: int,
        num_heads: int,
        dropout: float = 0.1,
        qkv_bias: bool = False,
    ) -> None:
        super().__init__()

        heads: list = []
        for _ in range(num_heads):
            heads.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )
        self.heads = nn.ModuleList(heads)

    def forward(self, x: Tensor) -> Tensor:
        """Run the heads one after another on ``x`` and concatenate their outputs.

        The result's last dimension is ``d_out * num_heads``.
        """
        outputs: list = []
        for head in self.heads:
            outputs.append(head(x))
        return torch.cat(outputs, dim=-1)

In [None]:
# With 2 `CausalAttention` heads that each output d_out=2 dimensions, the
# wrapper concatenates the heads, giving 4-dimensional context vectors
# (d_out * num_heads = 4).

torch.manual_seed(seed)

d_in: int = 3
d_out: int = 2
num_heads: int = 2
batch_input: Tensor = torch.stack([inputs, inputs], dim=0)
context_length: int = batch_input.shape[1]
dropout: float = 0.0

multi_head_attn = MultiHeadAttentionWrapper(
    d_in,
    d_out,
    context_length,
    num_heads,
    dropout=dropout,
    qkv_bias=False,
)
print(f"multi_head_attn: \n{multi_head_attn}\n")

context_vectors: Tensor = multi_head_attn(batch_input)
print(f"context_vectors: \n{context_vectors.shape}\n")

In [None]:
# Shape: (2, 6, 4).
# - dim 0 = 2: the batch holds 2 (identical, duplicated) input texts.
# - dim 1 = 6: each input has 6 tokens.
# - dim 2 = 4: each of the 2 heads outputs a d_out=2 context vector, and the
#   wrapper concatenates them along the last axis (2 * 2 = 4).
context_vectors

### Comment

- Shape: (2, 6, 4). The 1st dimension of the result is 2 because we have 2 input texts which have been duplicated.
- The 2nd dimension is 6 because we have 6 tokens in each input. 
- The 3rd dimension is 4 because each of the 2 heads produces a 2-dimensional context vector (`d_out=2`) and the heads' outputs are concatenated (`2 * 2 = 4`).

In [None]:
# Ex 3.2: return 2-dimensional context vectors by giving each of the 2 heads
# a 1-dimensional output (d_out * num_heads = 1 * 2 = 2).

torch.manual_seed(seed)

d_in: int = 3
d_out: int = 1
num_heads: int = 2
batch_input: Tensor = torch.stack([inputs, inputs], dim=0)
context_length: int = batch_input.shape[1]
dropout: float = 0.0

multi_head_attn = MultiHeadAttentionWrapper(
    d_in,
    d_out,
    context_length,
    num_heads,
    dropout=dropout,
    qkv_bias=False,
)
print(f"Shape of batch_input: \n{batch_input.shape}\n")
print(f"multi_head_attn: \n{multi_head_attn}\n")

context_vectors: Tensor = multi_head_attn(batch_input)
print(f"Shape of context_vectors: \n{context_vectors.shape}\n")
print(f"context_vectors: \n{context_vectors}\n")

<hr>

### Implementing Multi-head Attention With Weight Splits

- Efficient implementation of `MultiHeadAttentionWrapper` class.
- The projected query, key and value tensors are split into multiple heads by reshaping them. The results from these heads are combined again after computing attention.

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with fused query/key/value projections.

    Instead of one `CausalAttention` module per head, a single d_in -> d_out
    projection is computed for each of queries, keys and values. Each
    projection's output is then split into `num_heads` heads of size
    `d_out // num_heads`.

    Args:
        d_in: Size of each input token embedding.
        d_out: Total output size across all heads; must be divisible by
            `num_heads`.
        context_length: Maximum sequence length supported by the causal mask.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections have a bias.
    """

    def __init__(
        self,
        d_in: int,
        d_out: int,
        context_length: int,
        dropout: float,
        num_heads: int,
        qkv_bias: bool = False,
    ) -> None:
        super().__init__()
        assert d_out % num_heads == 0, "`d_out` should be divisible by `num_heads`"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Sub-modules are registered in this order on purpose: it fixes the
        # order of the random weight draws and of the printed module repr.
        self.dropout = nn.Dropout(dropout)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)

        # Ones strictly above the diagonal mark the future positions to hide.
        future_mask = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer("mask", future_mask)

    def _split_heads(self, t: Tensor, batch_size: int, num_tokens: int) -> Tensor:
        """Reshape (b, num_tokens, d_out) into (b, num_heads, num_tokens, head_dim)."""
        per_head = t.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x: Tensor) -> Tensor:
        """Return causal multi-head context vectors.

        Args:
            x: Batched token embeddings of shape (batch, num_tokens, d_in),
                with num_tokens <= context_length.

        Returns:
            Tensor of shape (batch, num_tokens, d_out).
        """
        batch_size, num_tokens, _ = x.shape

        queries = self._split_heads(self.W_query(x), batch_size, num_tokens)
        keys = self._split_heads(self.W_key(x), batch_size, num_tokens)
        values = self._split_heads(self.W_value(x), batch_size, num_tokens)

        # Batched over (batch, head): (..., n, head_dim) @ (..., head_dim, n).
        scores = queries @ keys.transpose(-2, -1)
        future = self.mask.bool()[:num_tokens, :num_tokens]
        scores = scores.masked_fill(future, float("-inf"))

        weights = torch.softmax(scores / keys.size(-1) ** 0.5, dim=-1)
        weights = self.dropout(weights)

        # (b, heads, n, head_dim) -> (b, n, heads, head_dim) -> (b, n, d_out)
        per_token = (weights @ values).transpose(1, 2)
        merged = per_token.reshape(batch_size, num_tokens, self.d_out)
        # Optional linear output projection mixing information across heads.
        return self.out_proj(merged)

### Comment


<img src="../08-Makemore/images/ch03_multihead_attn.png" width=600>

[image source](https://livebook.manning.com/book/build-a-large-language-model-from-scratch/chapter-3/v-7/304)

- In the `MultiHeadAttention` class, we initialize one larger weight matrix `Wq` and perform only one matrix multiplication with the inputs to obtain a query matrix `Q`. We then split the query matrix into `Q1` and `Q2`, as shown at the top of this figure.

- The same is done for the keys and values, which are not shown to reduce visual clutter.

- The splitting of the query, key and value tensors as shown in the diagram above is achieved through tensor reshaping and transposing operations using `view` and `transpose` methods.

- The key operation is to split the `d_out` dimension into `num_heads` and `head_dim`, where `head_dim = d_out // num_heads`. This is why `d_out` must be divisible by `num_heads`.

- The splitting is then achieved using the `.view` method. i.e. (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim).

- The tensors are transposed such that `num_heads` dimension is moved to the front. i.e. (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim).

- This is important for correctly aligning the queries, keys, and values across the different heads and performing batched matrix multiplications efficiently.

In [None]:
# Batched matrix multiplication example.
# A has shape (batch_size=1, num_heads=2, sequence_length=3, dim_size=4).
A: Tensor = torch.tensor(
    [
        [
            [
                [0.2745, 0.6584, 0.2775, 0.8573],
                [0.8993, 0.0390, 0.9268, 0.7388],
                [0.7179, 0.7058, 0.9156, 0.4340],
            ],
            [
                [0.0772, 0.3565, 0.1479, 0.5331],
                [0.4066, 0.2318, 0.4545, 0.9737],
                [0.4606, 0.5159, 0.4220, 0.5786],
            ],
        ]
    ]
)

print(f"A.shape: {tuple(A.shape)}")
print(f"A.transpose(2, 3).shape: {tuple(A.transpose(2, 3).shape)}")

# Transposing only the last two axes (indexes 2 and 3) gives (1, 2, 4, 3).
# matmul then broadcasts over the leading (batch, head) axes:
# (1, 2, 3, 4) @ (1, 2, 4, 3) -> (1, 2, 3, 3), one 3x3 matrix per head.
torch.matmul(A, A.transpose(2, 3))

In [None]:
# The same computation one head at a time: each (3, 4) head slice times its
# own transpose reproduces that head's (3, 3) block of the batched result.
first_head: Tensor = A[0, 0]
first_res: Tensor = torch.matmul(first_head, first_head.transpose(0, 1))
print(f"first_head.shape: {tuple(first_head.shape)}")
print(f"first_res: \n{first_res}\n")

second_head: Tensor = A[0, 1]
second_res: Tensor = torch.matmul(second_head, second_head.transpose(0, 1))
print(f"second_head.shape: {tuple(second_head.shape)}")
print(f"second_res: \n{second_res}\n")

### Comment

- After the computation of the attention weights and the context vectors, the context vectors from all the heads are transposed back to the shape (b, num_tokens, num_heads, head_dim).
- They are then concatenated (flattened) into the shape (b, num_tokens, d_out) effectively combining the outputs from all the heads.
- An optional output projection layer (`self.out_proj`) is added after combining the heads.
- Although the `MultiHeadAttention` class is more complicated than the `MultiHeadAttentionWrapper` class due to the additional reshaping and transposition of the tensors, it's more efficient.
- It's more efficient because it needs ONLY one matrix multiplication each to compute the queries, keys, and values for all heads at once.
- In `MultiHeadAttentionWrapper`, we need to repeat this multiplication `num_heads` times which is computationally expensive.

In [None]:
torch.manual_seed(seed)


batch_input: Tensor = torch.stack([inputs, inputs], dim=0)
# (2, 6, 3): 2 sequences of 6 tokens with 3-dimensional embeddings.
batch_size, context_length, d_in = batch_input.shape
d_out: int = 2  # split across 2 heads -> head_dim = 1
dropout: float = 0.0

# Keyword arguments on purpose: MultiHeadAttention takes `dropout` *before*
# `num_heads`, unlike MultiHeadAttentionWrapper.
multi_head_attn = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=dropout,
    num_heads=2,
    qkv_bias=False,
)

print(f"Shape of batch_input: \n{batch_input.shape}\n")
print(f"multi_head_attn: \n{multi_head_attn}\n")

context_vectors: Tensor = multi_head_attn(batch_input)
print(f"Shape of context_vectors: \n{context_vectors.shape}\n")
print(f"context_vectors: \n{context_vectors}\n")