# Chapter 3: CODING ATTENTION MECHANISMS

## What is "Attention", what is it used for?

The goal of the attention mechanism is to **learn better context vectors**. What is a "contex vector"? Keep reading.

As of writing, I am a aware of two types of "attention":

1. Self-attention
2. Cross-attention

## "Attending" to different parts of the input with self-attention

### A simple self-attention mechanism WITHOUT trainable weights

This section aims to introduce us to the topic and its motivation. Let's go step by step through an example. Let's consider the following sentence:

> "Your journey starts with one step."

Each element $i$ of this sequence (i.e. token) has an associate *token embedding* $x^{(i)}$ which is a $d$-dimensional vector representing the $i^{th}$ element of the input sequence.

For instance, $x^{(1)}$ corresponds to a $d$-dimensional vector representing the token "Your" in the our example.

In self-attention, our goal is **calculate a "context vector" $z^{(i)}$ for EACH element $x^{(i)}$ in the input sequence**. A "context vector" can be interpreted as some form of enriched embedding vector.

> What is the difference between a "token embedding" and a "context vector"? I mean... why do we need to "enrich" an already existing embedding vector? 🤷

It's to allow elements in the input sequence to adjust their representation (i.e. embedding vector, i.e. behavior in the embedding space) based on the absence/presence of other elements in the same sequence (self-attention) or another sequence (cross-attention). That way, their representation better aligns with the *context* the element finds itself in. Do you follow?

Let me give you a tangible example. When I was kid, I had a way of behaving when my parents, especially my dad, left home for work. But when he comes back, his presence causes me to adjust my behavior accordingly. If my dad and I were two vectors, I would be updating my internal representation (i.e. my embedding vector) based on the presence/absence of my dad's embedding vector. We say that I (meaning my vector) would be "attending" to my dad's. 

Another example can I think of. When I am with the boys, I can be goofier, than when I am with the lady I am in love with. Here again, my behavior changes slightly depending on the "*context*" I find myself in.

The same idea applies to elements in our original input sequence. They adjust (or should I say enrich) their representation (i.e. embedding vector) based on the presence/absence of other tokens to better capture the "meaning" of the sequence they're in. If i am the word "fire 🔥", and I find myself the following two sentences:

-  "This is a fire"
-  "This is fire"

I, the word fire 🔥, mean two different things. I am expressing two different ideas in those sentences. So, I need to adjust my embedding to reflect that. How? by scouring my surrounding, which in this case, is the sentence I am in. After my update my representation better reflects the context I am in. This new representation is what I called my "**context vector**".

> Hum... okay, okay... but how do they "adjust" their representation? Like how does it happen exactly.

It's through the attention mechanism that, input sequence elements, form their corresponding context vectors. In the case of self-attention, each element in the input sequence incorporate information from ALL other element in the same sequence. When it's cross-attention, element in the input sequence incorporate information from all other elements in *another* sequence. 

To illustrate this concept, let’s focus on the embedding vector of the second input element, $x^{(2)}$ (which corresponds to the token “journey”) and the corresponding context vector, $z^{(2)}$. This enhanced context vector,$z^{(2)}$, is an embedding that contains information about $x^{(2)}$ and all other input elements, $x^{(1)}$ to $x^{(T)}$. 

![Computing x(2) context vector](imgs/3/0.png)

Also, notice the weight values on the edges. These weights, which we call attention scores, indicate how much information each token should incorporate from its surrounding tokens. For example, a higher attention score between "journey" and another word means "journey" will incorporate more information from that word into its context vector. 

Let's see this in action by coding those three steps.

#### Step 1: Compute attention scores

The first step of implementing self-attention is to **compute the "attention scores" $w$**. These are intermediate values obtained by **computing the dot product between $x^{(2)}$ and all other elements of the input sequence $x^{(1)}$ to $x^{(T)}$**. By the way, while we're using $x^{(2)}$ as our running example here, keep in mind that this same process happens for every single token in our input sequence - each token gets its chance to be the star of the show and query all other tokens!

Computing the attention scores looks like this:

![Computing attention scores](imgs/3/1.png)

Since we are computing the attention scores of $x^{(2)}$, we call this vector the "query"; I think we call it that because this $x^{(2)}$ is asking: "What information in the input sequence is relevant to me?" or "How much of each element in the input sequence is relevant to me?". It's "querying" the input sequence in a sense. The other vectors we are dotting $x^{(2)}$'s representation with (marked as "inputs" on the image) are called: "keys". A "key", in this context, refers a numerical representation that each element of the input sequence "exposes" for the query vector to extract information from. A "key" is information the element of the input sequence exposes about itself.

Quick note: In this simple version of self-attention, we're actually using the *same* vector to serve as "query", "key", and "value" (I talk about "value" a little down there) for each token. So $x^{(2)}$ is playing all these roles! In later sections, we'll see how we can create separate vectors for each of these roles, which gives our attention mechanism more flexibility and power.

#### Step 2: Compute attention weights

The second step in computing self-attention consists in obtaining the **attention weights** by **normalizing the attention scores** obtained in the first step. The main goal behind this step is so that the attention scores *sum to 1*. We do this using the softmax function. This normalization step is super important - it ensures we're taking weighted averages of our vectors rather than potentially ending up with exploding values that could mess up our computations.

![Normalizing attentions scores](imgs/3/2.png)

#### Step 3: Compute context vectors

The third and final step in implementing self-attention is **scaling the "value" vectors (not the "key" vectors we used earlier!) with the attention weights, and adding them together**. While we used "key" vectors to figure out how relevant to the "query" vector each element in the input sequence was through dot products, we now use these attention scores to mix together a different set of vectors called "value" vectors. It's like at this stage the "query" vector is saying: "Now, that I know how much each input is relevant to me (attention scores), to enhance my representation and get a context-aware representation, I want $0.2$% of the value vector of the first input element, then $0.6$% of the value vector of the second input element, etc."

![Adding scaled value vectors](imgs/3/3.png)

Let's implement those three steps we described so far. Remember that in this simplified example, the "query", "key" and "value" vector for the $i^{th}$ element in the sequence is $x^{i}$. $x^{i}$ plays all THREE roles. It will change in the future.

#### Implementing a 'non-trainable' self-attention mechanism

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# STEP 0: Define random 3d vectors to serve as embeddings
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],    # Your      (x^1)
     [0.55, 0.87, 0.66],    # journey   (x^2)
     [0.57, 0.85, 0.64],    # starts    (x^3)
     [0.22, 0.58, 0.33],    # with      (x^4)
     [0.77, 0.25, 0.10],    # one       (x^5)
     [0.05, 0.80, 0.55]]    # step      (x^6)
)

In [2]:
# STEP 1: Compute attention scores
attn_scores = torch.matmul(inputs, inputs.T)
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

The first row of `attn_scores` matrix for tells us how relevant the first word, finds every other word in the input sequence. For instance, there is a `0.4753` relevancy score between the first word and the fourth word in the sequence.

Notes to myself:

- Dot Product Depends on Both **Magnitude** and **Direction**: The dot product measures both the similarity in direction (alignment) and the magnitudes of the two vectors. If a vector $v$ has a relatively small magnitude compared to another vector $u$, the dot product $u \cdot v$ can be larger than $v \cdot v$, even if $u$ and $v$ are not perfectly aligned.

In [3]:
# STEP 2: Compute attention weights by normalizing attention scores (using softmax)
attn_weights = F.softmax(attn_scores, dim=1)
print(attn_weights)

## Just me verifying that each row sums to one
# attn_weights_sum = torch.sum(attn_weights, dim=1, keepdim=True)
# print(attn_weights_sum)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [4]:
# STEP 3: Compute context vectors
ctx_vectors = attn_weights @ inputs
ctx_vectors

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

Earlier, I said that the context vectors are doing a weighted sum of the input sequence using the attention weights. The matrix multiplication `attn_weights @ inputs` is compact way of expressing this idea.

During the matrix multiplication, the first row of `attn_weights` is used to compute the first context vector by effectively taking a weighted combination of the rows in `inputs`. Specifically, each weight in the first row of `attn_weights` scales the corresponding row in `inputs`. This means that the \( $n^{th}$ \) weight in the first row of `attn_weights` determines how much the \( $n^{th}$ \) vector in `inputs` contributes to the resulting context vector.

Mathematically, the computation for the first context vector \( $z^1$ \) can be expressed as:

$$
z^1 = w_{1,1} \cdot x^1 + w_{1,2} \cdot x^2 + w_{1,3} \cdot x^3 + w_{1,4} \cdot x^4 + w_{1,5} \cdot x^5 + w_{1,6} \cdot x^6
$$

Here:
- \( $w_{1,1}, w_{1,2}, \dots, w_{1,6}$ \) are the weights from the first row of `attn_weights` (remember what this row means).
- \( $x^1, x^2, \dots, x^6$ \) are the row vectors from `inputs`.

This process is repeated for each row in `attn_weights`, where the \( $i^{th}$ \) row produces the \( $i^{th}$ \) context vector:

$$
z^i = \sum_{j=1}^{6} w_{i,j} \cdot x^j
$$

where \( $w_{i,j}$ \) is the weight from the \( $i^{th}$ \) row and \( $j^{th}$ \) column of `attn_weights`, and \( $x^j$ \) is the \( $j^{th}$ \) row in `inputs`.

By representing this operation compactly using matrix multiplication (`attn_weights @ inputs`), we are simultaneously computing all the context vectors, where each one is the weighted sum of the input vectors, weighted by the respective row of attention scores.

This highlights the intuition: the attention weights determine how much "focus" each input vector gets in constructing the context vector. The higher the weight, the more influence that particular input vector has on the resulting context vector.


### A simple self-attention mechanism WITH trainable weights

Earlier, we were implementing our baby self-attention mechanism without trainable weights. For this, we made the following key assumption:

> In this simple version of self-attention, we're actually using the *same* vector to serve as "query", "key", and "value" (I talk about "value" a little down there) for each token. So $x^{(2)}$ is playing all these roles! In later sections, we'll see how we can create separate vectors for each of these roles, which gives our attention mechanism more flexibility and power.

With a basic intuition of how to compute self-attention, and what "query", "key", and "value" vectors represent at the conceptual level, we are now in a position where we can introduce trainable weights.

> How can we do this?

By introducing three trainable weight matrices: $W_q$, $W_k$, $W_v$. Those three weight matrices are used the project the embedded input tokens $x^(i)$ into "query", "key" and "value" vectors respectively.

![Trainable weight matrices](imgs/3/4.png)

> But why are we doing this?

So that through backpropagation, the network "learns" how to create *better and more nuanced* "query", "key" and "value" numerical representations for the input tokens over time. We are no longer using the static original token embedding as earlier.

> How do you get the "query", "key", and "value" representations for $x^{2}$?

By multiplying $x^{2}$ with $W_q$, $W_k$, $W_v$ 🙂. That's it.

That's the major difference.

The three steps we described earlier:

1. Compute attention scores
2. Compute the attention weights. Here, we must scale the attention scores, by dividing them by the square root of the embedding dimension of the keys.
3. Compute context vectors

Stay the same.

![Attention with trainable weights](imgs/3/5.png)

#### Implementing self-attention as a Pytorch module

In [5]:
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out):
        """
        d_in: Input embedding size
        d_out: output embedding size
        """
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
    
    def forward(self, x):
        T, d_in = x.size()  # sequence length, embedding dimensionality

        # STEP 0: Get query, key, value representations
        queries = x @ self.W_query  # (T, d_out)
        keys    = x @ self.W_key    # (T, d_out)
        values  = x @ self.W_value  # (T, d_out)

        # STEP 1: Compute attention scores
        attn_scores = queries @ keys.T  # (T, T)

        # STEP 2: Compute attention weights
        d_key = keys.shape[1]
        attn_weights = F.softmax(attn_scores / d_key**0.5, dim=1)

        # STEP 3: Compute context vectors (T, T) x (T, d_out) -> (T, d_out)
        ctx_vectors = attn_weights @ values

        return ctx_vectors

In [6]:
torch.manual_seed(123)
self_attention = SelfAttention(d_in=3, d_out=2)

self_attention(inputs)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

**Note**: Instead of manually implementing `nn.Parameter(torch.rand(...))`, we could have used `torch.nn.Linear` and set the bias to `False`; in such case `torch.Linear` behaves like a simple matrix multiply. `nn.Linear` has a better weight initialization scheme, contributing to more stable and effective model training.

### Causal attention

The first self attention module we implemented, `SelfAttention` works well. The next step is to implement causal attention mask in code.

> Why are we doing this?

**Causal attention**, also known as *masked attention*, is a specialized form of self-attention. It restricts a model to only consider previous and current inputs in a sequence when processing any given token when computing attention scores. This is in contrast to the standard self-attention mechanism, which allows access to the entire input sequence at once.

When training a model to autoregressively predict the next word, we want the model to only use previous words to "guess" the next word. In other words, we allow words in the input sequence to "attend" to **previous words ONLY**, so only the last word in the sequence gets to "see" the entire sequence. Causal attention is crucial for developing a language model.

![Causal attention](imgs/3/6.png)

> How do you make input tokens NOT pay attention to future token in the input sequence? In roughly two steps:

1. Mask with `-inf` attention scores above the diagonal. The result? A masked (unnormalized) attention scores.
2. We apply softmax to the masked attention score. Softmax zeroes out the `-inf`. The result? Masked normalized attention weights. The gray squares on the above picture are places where attention weights have been zeroed out.

Let's see in this in action in an example:

In [7]:
attn_weights_example = torch.tensor(
    [[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
     [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
     [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
     [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
     [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
     [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]]
)
context_length = attn_weights_example.shape[0]

In [8]:
# STEP 1: Create a mask
mask = torch.triu(
    torch.ones(context_length, context_length),diagonal=1
)
masked = attn_weights_example.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[0.1921,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.2041, 0.1659,   -inf,   -inf,   -inf,   -inf],
        [0.2036, 0.1659, 0.1662,   -inf,   -inf,   -inf],
        [0.1869, 0.1667, 0.1668, 0.1571,   -inf,   -inf],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658,   -inf],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]])

In [9]:
# STEP 2: Apply softmax
masked_attn_weights = torch.softmax(
    masked, dim=1
)
masked_attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5095, 0.4905, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3417, 0.3291, 0.3292, 0.0000, 0.0000, 0.0000],
        [0.2544, 0.2493, 0.2493, 0.2469, 0.0000, 0.0000],
        [0.2030, 0.1997, 0.1997, 0.1981, 0.1995, 0.0000],
        [0.1712, 0.1666, 0.1666, 0.1646, 0.1666, 0.1644]])

The above is the **masked** attention weights matrix. Notice how positions with `-inf` were turned into zeros.

This matrix represents how much each word in the sequence contributes to the context representation of another word. For instance, the first row indicates how the first word attends to itself and previous words in the sequence. Since we only allow words to "look at" previous words (causality), the first word has no preceding words to attend to. As a result, it pays 100% attention to itself, meaning its context vector is entirely composed of its own value vector.

Now, consider the second row. This row represents how the second word forms its context vector. Here, the second word can attend to itself as well as the first word. From the values, we see that it splits its attention between the first word (about 50.95%) and itself (about 49.05%). This distribution reflects how the model combines information from the first word's value vector and its own value vector to form the context vector of the second word in context.

Another thing technique which makes to introduce now is: **Dropout**. Dropout in deep learning is a technique where randomly selected hidden layer units are ignored during training, effectively “dropping” them out. This method helps prevent overfitting by ensuring that a model does not become overly reliant on any specific set of hidden layer units.

In the context of the attention mechanism we are dealing with here, dropout involves randomly zeroing out random spots in the attention weights matrix. For instance, When applying dropout to an attention weight matrix with a rate of 50%, half of the elements in the matrix are randomly set to zero. To compensate for the reduction in active elements, the values of the remaining elements in the matrix are scaled up by a factor of $\frac{1}{1 - \text{dropout rate}}$. In our case, it's $\frac{1}{1 - 0.5} = 2$. All of the zeroing out and the scaling stuff is done internally by the `torch.nn.Dropout` layer.

#### Implementing causal attention as a Pytorch module

In [10]:
class SingleHeadCausalSelfAttention(nn.Module):
    def __init__(self, d_in, d_out, dropout_rate, context_length, qkv_bias=False):
        """
        Implementation of a SINGLE head causal self-attention

        Params
        ------
        d_in: input embedding dimensionality 
        d_out: attention head output embedding dimensionality
        """
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) 
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)  
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.attn_dropout = nn.Dropout(dropout_rate)
        self.context_length = context_length

        # causal mask to ensure that attention is only applied to the left in the input sequence
        # NOTE: 'register_buffer' is used to register a tensor as a
        # buffer in the model. Buffers are NOT considered
        # parameters, thus won't be updated during training.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
    
    def forward(self, x):
        T, d_in = x.size() # sequence length, embedding dimension

        queries = self.W_query(x)   # (T, d_in) x (d_in, d_out) -> (T, d_out)
        keys = self.W_key(x)        # (T, d_in) x (d_in, d_out) -> (T, d_out)
        values = self.W_value(x)    # (T, d_in) x (d_in, d_out) -> (T, d_out)

        # STEP 1: Compute the masked attention scores
        attn_scores = (queries @ keys.T) * (keys.shape[1]**0.5)      # (T, d_out) x (d_out, T) -> (T, T)
        masked_attn_scores = attn_scores.masked_fill(
            self.mask.bool()[:T, :T], # adjusting mask size based on sequence length
            -torch.inf
        )

        # STEP 2: Compute masked attention weights + dropout
        masked_attn_weights = F.softmax(masked_attn_scores, dim=1)
        masked_attn_weights = self.attn_dropout(masked_attn_weights)

        # STEP 3: Compute context vectors
        ctx_vectors = masked_attn_weights @ values  # (T, T) x (T, d_out) -> (T, d_out)

        return ctx_vectors

In [11]:
single_head_causal_self_attention = SingleHeadCausalSelfAttention(
    d_in=3,
    d_out=2,
    context_length=10,
    dropout_rate=0.0
)

single_head_causal_self_attention(inputs)

tensor([[0.4772, 0.1063],
        [0.6007, 0.3485],
        [0.6293, 0.4037],
        [0.5499, 0.3662],
        [0.5331, 0.3455],
        [0.5087, 0.3554]], grad_fn=<MmBackward0>)

The above implementation is just an illustration purposes, it has several limitations:

1. It does not have support for "batched" inputs, meaning it can only process one sentence at a time. In real-world applications, models typically handle multiple sentences (or sequences) simultaneously for efficiency.
2. This implementation is a "single head" self-attention because it performs attention just once on the input sequence. In practice, it is common to perform attention multiple times in parallel on the same sequence, which is referred to as multi-head attention.

### What is "Multi-head self-attention"?

Earlier, I we briefly defined multi-head attention. Multi-head attention extends the single-head mechanism by performing attention computations multiple times in parallel. Each of these "heads" independently focuses on a different aspect of the input data, which significantly improves the model's capacity to understand complex relationships.

> Hum... okay, okay... but how does it work? How is it different from single head attention?

Let's recall that multi-head attention is an extension of single head attention. Here is how it works:

1. **Multiple Heads:**
   - Instead of a single set of projections for queries, keys, and values, multi-head attention uses multiple independent sets of parameters, one for each "head." These projections transform the input sequence into different subspaces, enabling each head to focus on different patterns or relationships.
   - For example:
     - One head might attend to long-range dependencies in the sequence.
     - Another might focus on local relationships (e.g., adjacent tokens).
     - Yet another could specialize in recognizing specific semantic roles (e.g., subject-object relationships in sentences).

2. **Independent Attention Computations:**
   - Each head computes its attention scores and context vectors independently, just like single-head attention does. The result is a set of context vectors, one for each head, where each captures a different "perspective" on the input data.

3. **Concatenation and Final Projection:**
   - The context vectors from all heads are concatenated along the feature dimension, producing a combined representation of the sequence.
   - This concatenated output is then passed through a final linear layer, which projects it back into the original embedding space.


> But why use multi-head attention?

Multi-head attention addresses the limitations of single-head attention by allowing the model to learn and combine diverse relationships simultaneously. Let’s consider an example to illustrate this:

> **Example Sentence:** *"The cat sat on the mat."*

In this sentence:
- One head might focus on the relationship between "cat" and "sat" (subject-action).
- Another head might look at "sat" and "mat" (action-location).
- A third head might capture the broader syntactic structure, such as identifying "The cat" as a noun phrase.

By combining these diverse perspectives, multi-head attention creates a richer representation of the sequence, enabling the model to understand it more deeply. Imagine single-head attention as a single spotlight scanning a scene. It can only focus on one thing at a time. In contrast, multi-head attention is like having multiple spotlights, each illuminating a different part of the scene. Together, they provide a more complete understanding.

Alright, now let's implement all this.

#### Multihead self-attention implementation

In [12]:
# Inspired by Andrej's self-attention implementation
# see: NanoGPT repository, model.py
import math

class MultiHeadCausalSelfAttention(nn.Module):
    def __init__(self, embed_dim, 
                 n_head, dropout_rate, context_length, qkv_bias=False):
        super().__init__()
        assert embed_dim % n_head == 0, "The embedding dimension must be divisible by the number of heads"

        self.embed_dim = embed_dim 
        self.n_head = n_head
        self.dropout_rate = dropout_rate
        self.context_length = context_length
        self.qkv_bias = qkv_bias
        self.attn_dropout = nn.Dropout(dropout_rate)

        # For "query", "key", and "value" projections
        self.W_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias)
        # For the output projection
        self.output_proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
            .unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, T, T)
        )

    def forward(self, x):
        B, T, C = x.size() # Batch size, Sequence length, embedding dimension

        # calculate tokens "query", "key", "values" projections in all heads 
        qkv_proj = self.W_qkv(x) # (B, T, C) x (C, C*3) -> (B, T, C*3)

        # Since the Q, K, V projections of each token is concatenated along
        # dim=2 in tensor qkv_proj, we want to separate the projections
        # in individual q, k, and v matrices. All of the same shape
        q, k, v = qkv_proj.split(self.embed_dim, dim=2) # (B, T, C)

        # We split the embedding dimension into multiple "heads"
        # In other words, each "head" gets a *slice* of the embedding dimension (C)
        q = q.view(B, T, self.n_head, C // self.n_head) # (B, T, nh, hs)
        k = k.view(B, T, self.n_head, C // self.n_head) # (B, T, nh, hs)
        v = v.view(B, T, self.n_head, C // self.n_head) # (B, T, nh, hs)

        # Why the transpose? Read the explanation below.
        q = q.transpose(1, 2) # (B, nh, T, hs)
        k = k.transpose(1, 2) # (B, nh, T, hs)
        v = v.transpose(1, 2) # (B, nh, T, hs)

        # Attention-calc-step 1: Compute masked attention scores in each head
        attn_scores = (q @ k.transpose(2, 3)) * ( 1 / math.sqrt(k.size(-1))) # (B, nh, T, T)
        masked_attn_scores = attn_scores.masked_fill(
            self.mask[:, :, :T, :T] == 1,
            -torch.inf
        )
        
        # Attention-calc-step 2: Compute masked attention weights along last dimension
        masked_attn_weights = F.softmax(masked_attn_scores, dim=-1)
        masked_attn_weights = self.attn_dropout(masked_attn_weights)
        
        # Attention-calc-step 3: Compute context vectors
        ctx_vectors = masked_attn_weights @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # Re-assemble all head outputs side by side to form
        # a single context vector for each token in each sequence
        # in the batch
        y = ctx_vectors.transpose(1, 2).contiguous().view(B, T, C)

        # Project the context vectors back into the original space
        y = self.output_proj(y)
        return y

In [13]:
# Define input parameters
batch_size = 2   # Number of sequences in a batch
seq_length = 10  # Sequence length (context length)
embed_dim = 16   # Embedding dimension
n_heads = 4      # Number of attention heads
dropout_rate = 0.1

# Initialize the MultiHeadCausalSelfAttention module
attention_module = MultiHeadCausalSelfAttention(
    embed_dim=embed_dim,
    n_head=n_heads,
    dropout_rate=dropout_rate,
    context_length=seq_length,
    qkv_bias=True
)

# Generate random input tensor of shape (batch_size, seq_length, embed_dim)
x = torch.randn(batch_size, seq_length, embed_dim)

# Forward pass through the attention module
output = attention_module(x)

# Print the input and output shape
print("Input shape:", x.shape)
print("Output shape:", output.shape)

Input shape: torch.Size([2, 10, 16])
Output shape: torch.Size([2, 10, 16])


Notice our the input and output have the same shape, but they do NOT contain the same thing. `output` contains *context vectors*, and we talked about what those are a little earlier in this notebook.

## Few explanations about the `MultiHeadCausalSelfAttention` module

#### What does `self.W_qkv` represent?

This linear layer produces a vector three times the dimension of its input. Another way of looking at this larger vector is as 3 concatenated vectors of the *same* size. Effectively getting the three projections for each of our tokens in one go, using one matrix, instead of three different ones.

```txt
Input: [____] (dimension C)
Output: [____|____|____] (dimension 3C = Q|K|V)
```

#### Why did we do `q, k, v = torch.split(qkv_proj, dim=2)`?

As mentioned before, we used a single linear layer to obtain the three projections of each token in each input sequence. And those projections are concatenated along `dim=2`, the 3rd dimension of our tensor. So `torch.split`, as its name suggests, splits (along the 3rd dimension) our `qkv_proj` matrix of shape `(B, T, C*3)` into three separate matrices `q`, `k`, and `v`... each of shape `(B, T, C)`.

#### Can you tell me about the reshape using `view(B, T, self.n_head, C // self.n_head)`. 

Remember, we are trying to perform attention in parallel, that is in different "heads". We know that C (embedding dimension) must be divisible by number of heads (that's why there's the assert statement in `__init__` method).

So if C = 768 and n_head = 12, each head will get 768/12 = 64 dimensions
The view operation splits the last dimension (C) into two dimensions: (n_head, C//n_head)
This preserves the total number of elements: B * T * C = B * T * n_head * (C//n_head)

The resulting `q`, `k`, and `v` matrices are of this shape: $(B \times T \times \text{nh} \times \text{ns})$ where:

- `B`   = Batch size
- `T`   = Sequence length
- `nh`  = Number of heads
- `hs`  = Head size (the embedding dimension inside each head)

The question you might be asking yourself now is how **to interpret this 4D tensor**?. Unfortunately, I do not have a nice picture draw to help you visualize, but would like to share my intuition. Personally, I think of 4D tensors as a *list of of 3D tensors*. In our context, the tensor of the `(B, T, nh, hs)` is a list of `B` elements where each element is a 3D tensor of shape `(T, nh, hs)`. So in the case of the `q` matrix, this 4D tensor tells me that each batch element (i.e, input sequence), has `T` elements (tokens) and that each token in the sequence has `nh` vectors of `hs` dimensions. It tells me that the `q` matrix hosts the "query" projection that, EACH token in the input sequence, has in EACH head. The same reasoning applies to `k` and `v` matrices, but they hold the "key" and "value" projections respectively.

#### Why did we do `transpose(1,2)`?

- First let talk about the new tensor shape `(B, nh, T, hs)` and what is means. It's still a 4D tensor, but this time the `nh` dimension comes first. This tensor shape tells us that are grouping the projections *by* "head", and not by "sequence" like before. Essentially, when we "open" a 3D tensor we will see that it has `nh` "heads" and in a given head, there are the `T`, `hs`-dimensional representation of the token in the sequence. So, we grouping the "query", "key", and "value" representation in the batch *per* head. 

- Now let's talk about **why** the transpose was necessary. I would like us to remember we are trying to compute attention in each "head" *independently* and at the *same time*. It just so happens that when the tensor dimension is like this `(B, T, nh, hs)`, it's less convenient for attention to be computed in parallel. **Without** the transpose, we have this shape `(B, T, nh, hs)`. To compute attention for each head independently, the framework would need to: (1) First iterate through the batch dimension (2) Then through the sequence dimension `T` (3) Only then reach the head dimension to parallelize. **With** the transpose, `nh` is positioned as the batch dimension in each 3D tensor, so we are ready start the computation earlier, than without the transpose. Consider the following snippet for illustration: 
  ```py
    # Without transpose: (B, T, n_head, head_size)
    for b in range(batch_size):
        for t in range(seq_length):
            for h in range(n_heads):  # <- Parallelization happens here (deep in the loops)
                compute_attention()

    # With transpose: (B, n_head, T, head_size)
    for b in range(batch_size):
        for h in range(n_heads):  # <- Parallelization happens earlier
            compute_attention()
  ```
  Also, torch expects batch dimensions to be first, when it's the case it performs operations more efficiently on inner dimensions.

#### Why did we need an "output" projection?

The output projection serves several important purposes:

1. **Dimension Reunification**: Remember that we split our embedding dimension (`C`) across multiple heads, where each head worked with a smaller dimension (`C//n_head`). After the attention computation, we need to combine these separate head outputs back into our original embedding dimension size. The output projection helps transform this concatenated multi-head output back into the expected embedding dimension.

2. **Learned Integration**: The output projection learns how to best combine the information from different attention heads. Think of each attention head as looking at different aspects of the relationships between tokens (like how some heads might focus on nearby words while others look at broader context). The output projection learns how to weigh and mix these different types of attention patterns.

3. **Maintaining Network Depth**: Without this projection, we would be directly using the raw attention outputs. The additional linear transformation adds another layer of learned parameters, allowing the network to perform more complex transformations on the attention outputs before passing them to the next layer.

Remember: The output projection gets better at doing the above ☝️ *during* training.

To visualize this:
- **Before** output projection: Each token has `n_head` different representations of size `head_size`.

- **After** output projection: Each token has a single, unified representation of size embed_dim

This is why the output projection takes input of shape `(B, T, C)` and produces output of the same shape - it's consolidating the multi-head attention information while preserving the original dimensionality of our token embeddings.