# Coding Attention Mechanisms

- The reasons for using attention mechanisms in neural networks
- A basic self-attention framework, progressing to an enhanced self-attention mechanism 
- A causal attention module that allows LLMs to generate one token at a time
- Masking randomly selected attention weights with dropout to reduce overfitting
- Stacking multiple causal attention modules into a multi-head attention module

Before the advent of transformers, recurrent neural networks (RNNs) were the most popular encoder-decoder architecture for language translation. An RNN is a type of neural network in which outputs from previous steps are fed as inputs to the current step, making it well suited for sequential data like text.

In an encoder-decoder RNN, the input text is fed into the encoder, which processes it sequentially. The encoder updates its hidden state (the internal values at the hidden layers) at each step, trying to capture the entire meaning of the input sentence in the final hidden state.

The decoder then takes this final hidden state to start generating the translated sentence, one word at a time. It also updates its hidden state at each step, which is supposed to carry the context necessary for the next-word prediction.

Before the advent of transformer models, encoder-decoder RNNs were a popular choice for machine translation. The encoder takes a sequence of tokens from the source language as input, where a hidden state (an intermediate neural network layer) of the encoder encodes a compressed representation of the entire input sequence. Then, the decoder uses its current hidden state to begin the translation, token by token.

While we don't need to know the inner workings of these encoder-decoder RNNs, the key idea here is that the encoder part processes the entire input text into a hidden state (memory cell). The decoder then takes in this hidden state to produce the output. You can think of this hidden state as an embedding vector.

Self-attention is a mechanism that allows each position in the input sequence to consider the relevancy of, or "attend to," all other positions in the same sequence when computing the representation of a sequence. Self-attention is a key component of contemporary LLMs based on the transformer architecture, such as the GPT series.

# The meaning of "self"

In self-attention, "self" refers to computing attention **within the same sequence**. Specifically:

- Each element in the sequence establishes relationships with **all other elements in that same sequence** (including itself)
- For example: when processing a sentence, each word attends to all other words in that sentence
- "Self" emphasizes **attending to itself**, meaning the relationships are computed among elements within the input sequence itself

**Example**: The sentence "I love eating apples"
- The word "apples" will attend to "I", "love", "eating", and "apples" itself
- All these relationships are established **within the same** input sentence

## What are "traditional attention mechanisms"?

Traditional attention mechanisms primarily refer to **attention used in sequence-to-sequence models**:

- Attention is computed **between two different sequences**
- Typical application: machine translation
  - **Encoder sequence** (source language): English sentence
  - **Decoder sequence** (target language): Chinese sentence
  - Each Chinese character in the decoder attends to all English words in the encoder

**Key difference**:
- **Traditional attention**: Establishes relationships between two different sequences (Sequence A → Sequence B)
- **Self-attention**: Establishes relationships within a single sequence (among elements within Sequence A itself)
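
The difference shows up directly in the shape of the score matrix. The following is a minimal plain-Python sketch with made-up 2-dimensional vectors; the names `source`, `target`, and `score_matrix` are illustrative and not part of the original text. Attention between two sequences yields one row per target element and one column per source element, while self-attention yields a square matrix.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def score_matrix(queries, keys):
    """Unnormalized attention scores: one row per query, one column per key."""
    return [[dot(q, k) for k in keys] for q in queries]

# Hypothetical toy embeddings (values chosen only for illustration).
source = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # e.g. 3 source-language tokens
target = [[0.5, 0.5], [1.0, -1.0]]             # e.g. 2 target-language tokens

# Traditional attention: each target element attends to the source sequence.
cross_scores = score_matrix(target, source)
# Self-attention: each source element attends to the source sequence itself.
self_scores = score_matrix(source, source)

print(len(cross_scores), "x", len(cross_scores[0]))  # 2 x 3
print(len(self_scores), "x", len(self_scores[0]))    # 3 x 3
```

Without trainable weights, the self-attention score matrix is also symmetric, because the score of element i against element j is the same dot product as j against i.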

# A simple self-attention mechanism without trainable weights

In [2]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
    [0.55, 0.87, 0.66], # journey  (x^2)
    [0.57, 0.85, 0.64], # starts   (x^3)
    [0.22, 0.58, 0.33], # with     (x^4)
    [0.77, 0.25, 0.10], # one      (x^5)
    [0.05, 0.80, 0.55]] # step     (x^6)
)

![attention-score](../images/attention-score.png)

In [3]:
query = inputs[1] # the second input token served as query (journey)
print("query vector:", query)

print("inputs shape:", inputs.shape[0])
attn_scores_2 = torch.empty(inputs.shape[0])
print("attn_scores_2:", attn_scores_2)

for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)

print("attn_scores_2:", attn_scores_2)

query vector: tensor([0.5500, 0.8700, 0.6600])
inputs shape: 6
attn_scores_2: tensor([0., 0., 0., 0., 0., 0.])
attn_scores_2: tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


The dot product is a measure of similarity because it quantifies how closely two vectors are aligned: a higher dot product indicates a greater degree of alignment or similarity between the vectors. In the context of self-attention mechanisms, the dot product determines the extent to which each element in a sequence focuses on, or "attends to," any other element: the higher the dot product, the higher the similarity and attention score between two elements.

Raku code to compute attention scores using dot products:

```raku
# Define the input tensor (6 tokens, each a 3-dimensional vector)
my @inputs = (
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55]   # step     (x^6)
);

# Select the second token as the query vector (index 1)
my @query = @inputs[1].flat;  # use .flat to flatten the array
say "query vector: [{@query.join(', ')}]";
say "inputs shape: {@inputs.elems}";

# Create an empty array to hold the attention scores (Raku arrays grow automatically)
my @attn_scores_2;
say "attn_scores_2: []";

# Use the hyper operator »*« and the reduction operator [+] to compute
# the dot product of each input vector with the query vector
my @attn-scores = @inputs.map: -> @x { [+] @x »*« @query };
say "attn_scores: [{@attn-scores.join(', ')}]";
```

## Normalization

Next, we normalize each of the attention scores we computed previously. The main goal behind the normalization is to obtain attention weights that sum up to 1. This normalization is a convention that is useful for interpretation and maintaining training stability in an LLM. Here's a straightforward method for achieving this normalization step:

In [31]:
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("attention weights:", attn_weights_2_tmp)
print("sum of attention weights:", attn_weights_2_tmp.sum())

attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
sum of attention weights: tensor(1.0000)


```raku
# Normalization: divide the attention scores by their sum to obtain attention weights (which sum to 1)
my $sum = [+] @attn-scores;
my @attn_weights_2_tmp = @attn-scores »/» $sum;
say "attention weights: {@attn_weights_2_tmp}";
say "sum of attention weights: {[+] @attn_weights_2_tmp}";
```

In practice, it's more common and advisable to use the softmax function for normalization. This approach is better at managing extreme values and offers more favorable gradient properties during training. The following is a basic implementation of the softmax function for normalizing the attention scores:

In [32]:
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)
print("attention weights (naive softmax):", attn_weights_2_naive)
print("sum of attention weights (naive softmax):", attn_weights_2_naive.sum())

attention weights (naive softmax): tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
sum of attention weights (naive softmax): tensor(1.)


The softmax function ensures that the attention weights are always positive. This makes the output interpretable as probabilities or relative importance, where higher weights indicate greater importance.
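
A small plain-Python sketch can illustrate this point. The scores below are hypothetical (the attention scores computed in this section are all positive, so the problem does not show up with them): once a score is negative, dividing by the sum produces a negative "weight", whereas softmax keeps every weight positive.

```python
import math

def sum_normalize(scores):
    """Divide each score by the sum of all scores."""
    total = sum(scores)
    return [s / total for s in scores]

def softmax(scores):
    """Exponentiate each score, then divide by the sum of the exponentials."""
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scores = [-1.0, 1.0, 2.0]  # hypothetical scores, one of them negative

by_sum = sum_normalize(scores)
by_softmax = softmax(scores)

print("sum-normalized:", by_sum)      # [-0.5, 0.5, 1.0]
print("softmax:       ", by_softmax)  # all entries positive, summing to 1
```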

Note that this naive softmax implementation (softmax_naive) may encounter numerical instability problems, such as overflow and underflow, when dealing with large or small input values. Therefore, in practice, it's advisable to use the PyTorch implementation of softmax, which has been extensively optimized for performance.

In [33]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("attention weights:", attn_weights_2)
print("sum of attention weights:", attn_weights_2.sum())

attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
sum of attention weights: tensor(1.)
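
To make the instability of the naive version concrete, here is a plain-Python sketch. It uses a common stabilization, subtracting the maximum score before exponentiating; this is shown as one standard technique, not as a description of what PyTorch does internally. The function names and the extreme input values are illustrative assumptions. In plain Python the naive version raises `OverflowError` for large inputs, which is used here to detect the failure.

```python
import math

def softmax_naive(scores):
    """Direct translation of exp(x) / sum(exp(x))."""
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(scores):
    """Subtract the maximum first; the common factor exp(-max) cancels
    in the ratio, so the result is mathematically unchanged."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

large_scores = [1000.0, 1001.0, 1002.0]  # hypothetical extreme values

try:
    softmax_naive(large_scores)
    naive_overflowed = False
except OverflowError:  # math.exp(1000.0) exceeds the float range
    naive_overflowed = True

stable = softmax_stable(large_scores)
print("naive overflowed:", naive_overflowed)
print("stable weights:", stable)
```

Because only differences between scores matter, the stable version returns the same weights for `[1000.0, 1001.0, 1002.0]` as for `[0.0, 1.0, 2.0]`.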


We now calculate the context vector z(2) by multiplying the embedded input tokens, x(i), with the corresponding attention weights and then summing the resulting vectors. Thus, context vector z(2) is the weighted sum of all input vectors, obtained by multiplying each input vector by its corresponding attention weight:

In [42]:
query = inputs[1] # the second input token served as query (journey)
context_vec_2 = torch.zeros(inputs.shape[1])
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
print("context vector:", context_vec_2)

context vector: tensor([0.4419, 0.6515, 0.5683])


## Computing attention weights for all input tokens

In [None]:
# step1: compute attention scores for all queries
attn_scores = torch.empty(6, 6)

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)

print("attn_scores:", attn_scores)

attn_scores: tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


Each element in the tensor represents an attention score between each pair of inputs.

When computing the preceding attention score tensor, we used for loops in Python. However, for loops are generally slow, and we can achieve the same results using matrix multiplication:

In [36]:
attn_scores = inputs @ inputs.T
print("attn_scores (matrix multiplication):", attn_scores)

attn_scores (matrix multiplication): tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [37]:
# step2: compute attention weights for all queries
attn_weights = torch.softmax(attn_scores, dim=-1)
print("attn_weights:", attn_weights)

attn_weights: tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In the context of using PyTorch, the dim parameter in functions like `torch.softmax` specifies the dimension of the input tensor along which the function will be computed. By setting `dim=-1`, we are instructing the `softmax` function to apply the normalization along the last dimension of the attn_scores tensor. If attn_scores is a two-dimensional tensor (for example, with a shape of [rows, columns]), it will normalize across the columns so that the values in each row (summing over the column dimension) sum up to 1.

In [39]:
row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print("row 2 sum:", row_2_sum)
print("all row sums:", attn_weights.sum(dim=-1))

row 2 sum: 1.0
all row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
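
The effect of choosing the dimension can be mimicked in plain Python on nested lists. This is only an analogy under stated assumptions: `softmax_rows` and `softmax_cols` are hypothetical helpers standing in for `dim=-1` and `dim=0` on a two-dimensional tensor, and the scores are made up.

```python
import math

def softmax(values):
    """Softmax over a flat list of numbers."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_rows(matrix):
    """Analogue of dim=-1 on a 2D tensor: normalize within each row."""
    return [softmax(row) for row in matrix]

def softmax_cols(matrix):
    """Analogue of dim=0 on a 2D tensor: normalize within each column."""
    cols = [softmax(list(col)) for col in zip(*matrix)]
    return [list(row) for row in zip(*cols)]  # transpose back

scores = [[1.0, 2.0, 3.0],
          [1.0, 1.0, 1.0]]  # hypothetical 2 x 3 score matrix

rows = softmax_rows(scores)
cols = softmax_cols(scores)

print([round(sum(r), 6) for r in rows])        # every row sums to 1
print([round(sum(c), 6) for c in zip(*cols)])  # every column sums to 1
```

For attention weights, the row-wise variant is the one we want, since each row holds the weights of one query over all inputs.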


In [40]:
# step3: compute context vectors for all queries
all_context_vecs = attn_weights @ inputs
print("all_context_vecs:", all_context_vecs)

all_context_vecs: tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


In [43]:
print("previous context vector for query 2:", context_vec_2)

previous context vector for query 2: tensor([0.4419, 0.6515, 0.5683])


# Implementing self-attention with trainable weights

In [7]:
x_2 = inputs[1]         # the second input element
d_in = inputs.shape[1]  # the input embedding size, d_in=3
d_out = 2               # the output embedding size, d_out=2

In [59]:
# initialize weight matrix Wq, Wk, and Wv
torch.manual_seed(123)

W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

We set `requires_grad=False` to reduce clutter in the outputs, but if we were to use the weight matrices for model training, we would set `requires_grad=True` to update these matrices during model training.

Next, we compute the query, key, and value vectors:

In [60]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print("query_2:", query_2)
print("key_2:", key_2)
print("value_2:", value_2)

query_2: tensor([0.4306, 1.4551])
key_2: tensor([0.4433, 1.1419])
value_2: tensor([0.3951, 1.0037])


Weight parameters vs. attention weights 


In the weight matrices W, the term "weight" is short for "weight parameters," the values of a neural network that are optimized during training. This is not to be confused with the attention weights. As we already saw, attention weights determine the extent to which a context vector depends on the different parts of the input (i.e., to what extent the network focuses on different parts of the input).


In summary, weight parameters are the fundamental, learned coefficients that define the network's connections, while attention weights are dynamic, context-specific values.
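
The distinction can be sketched in plain Python. In this hypothetical example, the weight parameters `W_query` and `W_key` are held fixed (as they would be after training), while the attention weights are recomputed for each input and therefore differ between the two made-up inputs. All values and helper names are illustrative assumptions.

```python
import math

def matvec(x, W):
    """Row vector x (length d_in) times matrix W (d_in x d_out)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(tokens, W_query, W_key, query_index):
    """Attention weights of one query token over all tokens."""
    query = matvec(tokens[query_index], W_query)
    keys = [matvec(t, W_key) for t in tokens]
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    return softmax(scores)

# Weight parameters: the same for every input (hypothetical values).
W_query = [[0.2, 0.8], [0.5, 0.1]]
W_key = [[0.6, 0.3], [0.4, 0.9]]

# Two different inputs that share the same first token.
input_a = [[1.0, 0.0], [0.0, 1.0]]
input_b = [[1.0, 0.0], [1.0, 1.0]]

weights_a = attention_weights(input_a, W_query, W_key, query_index=0)
weights_b = attention_weights(input_b, W_query, W_key, query_index=0)

print("attention weights for input_a:", weights_a)
print("attention weights for input_b:", weights_b)
```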

In [61]:
keys = inputs @ W_key
values = inputs @ W_value

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


In [62]:
keys_2 = keys[1]
attn_scores_22 = query_2.dot(keys_2)
print("attn_scores_22:", attn_scores_22)

attn_scores_22: tensor(1.8524)


We can generalize this computation to all attention scores via matrix multiplication:

In [63]:
attn_scores_2 = query_2 @ keys.T
print("attn_scores_2:", attn_scores_2)

attn_scores_2: tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


As a quick check, we can see that the second element in the output matches the attn_scores_22 we computed previously.

Now, we want to go from the attention scores to the attention weights, as illustrated in figure 3.16. We compute the attention weights by scaling the attention scores and applying the softmax function. This time, we scale the attention scores by dividing them by the square root of the embedding dimension of the keys (taking the square root is mathematically the same as exponentiating by 0.5).

After computing the attention scores ω, the next step is to normalize these scores using the softmax function to obtain the attention weights α.

![self-attention-weights](../images/self-attention-weights.png)

We first compute the scaled attention weights; the context vector then follows by multiplying these attention weights with the value vectors:

In [64]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print("attn_weights_2:", attn_weights_2)

attn_weights_2: tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


The rationale behind scaled dot-product attention


The reason for the normalization by the embedding dimension size is to improve the training performance by avoiding small gradients. For instance, when scaling up the embedding dimension, which is typically greater than 1,000 for GPT-like LLMs, large dot products can result in very small gradients during backpropagation due to the softmax function applied to them. As dot products increase, the softmax function behaves more like a step function, resulting in gradients nearing zero. These small gradients can drastically slow down learning or cause training to stagnate.


The scaling by the square root of the embedding dimension is the reason why this self-attention mechanism is also called scaled dot-product attention.
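
The saturation effect can be explored with a small plain-Python experiment. This is a sketch under an explicit assumption: query and key components are drawn independently from a standard normal distribution, which is a simplification and not a claim about real LLM embeddings. The helper names and the chosen dimensions are illustrative.

```python
import math
import random

def softmax(values):
    """Softmax over a list, shifted by the maximum for numerical safety."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def random_dot_products(d, n_keys, rng):
    """Dot products of one random query with n_keys random keys.
    Assumption: components are independent standard-normal draws."""
    query = [rng.gauss(0.0, 1.0) for _ in range(d)]
    return [sum(q * rng.gauss(0.0, 1.0) for q in query) for _ in range(n_keys)]

rng = random.Random(0)  # fixed seed so the run is repeatable
largest_weight = {}     # d -> (largest unscaled weight, largest scaled weight)

for d in (4, 64, 1024):
    scores = random_dot_products(d, n_keys=6, rng=rng)
    unscaled = softmax(scores)
    scaled = softmax([s / math.sqrt(d) for s in scores])
    largest_weight[d] = (max(unscaled), max(scaled))
    print(f"d={d:5d}  max weight unscaled={max(unscaled):.3f}  scaled={max(scaled):.3f}")
```

Dividing by the square root of d can only flatten the softmax output, so the largest scaled weight never exceeds the largest unscaled one. Under the assumption above, the unscaled weights for large d tend to collapse toward a one-hot vector, which is the step-function behavior described in this sidebar.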

Similar to when we computed the context vector as a weighted sum over the input vectors (see section 3.3), we now compute the context vector as a weighted sum over the value vectors. Here, the attention weights serve as a weighting factor that weighs the respective importance of each value vector.

In [65]:
context_vec_2 = attn_weights_2 @ values
print("context_vec_2:", context_vec_2)

context_vec_2: tensor([0.3061, 0.8210])


So far, we've only computed a single context vector, z(2). Next, we will generalize the code to compute all context vectors in the input sequence, z(1) to z(T).

Why query, key, and value?


The terms "key," "query," and "value" in the context of attention mechanisms are borrowed from the domain of information retrieval and databases, where similar concepts are used to store, search, and retrieve information.


A query is analogous to a search query in a database. It represents the current item (e.g., a word or token in a sentence) the model focuses on or tries to understand. The query is used to probe the other parts of the input sequence to determine how much attention to pay to them.


The key is like a database key used for indexing and searching. In the attention mechanism, each item in the input sequence (e.g., each word in a sentence) has an associated key. These keys are used to match the query. 


The value in this context is similar to the value in a key-value pair in a database. It represents the actual content or representation of the input items. Once the model determines which keys (and thus which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values.
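
The database analogy can be made concrete with a "soft lookup" in plain Python. A regular dictionary returns exactly one value for an exactly matching key; in the sketch below, the query is compared against every key, and the result is a blend of all values weighted by similarity. The function `soft_lookup` and all numbers are hypothetical and serve only to illustrate the analogy.

```python
import math

def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def soft_lookup(query, keys, values):
    """Blend `values` according to how well each key matches the query."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    d = len(values[0])
    blended = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(d)]
    return weights, blended

keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]  # one key per stored item
values = [[10.0], [20.0], [30.0]]             # content stored under each key
query = [5.0, 0.0]                            # strongly aligned with the first key

weights, blended = soft_lookup(query, keys, values)
print("weights:", weights)  # almost all weight on the first key
print("blended:", blended)  # close to the first value, 10.0
```

Because the query matches the first key far better than the others, the retrieved result is close to the first value, yet the other values still contribute a little.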

## Implementing a compact self-attention Python class

In [10]:
# a compact self-attention class
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec  

In this PyTorch code, SelfAttention_v1 is a class derived from nn.Module, which is a fundamental building block of PyTorch models that provides necessary functionalities for model layer creation and management. 


The __init__ method initializes trainable weight matrices (W_query, W_key, and W_value) for queries, keys, and values, each transforming the input dimension d_in to an output dimension d_out. 


During the forward pass, using the forward method, we compute the attention scores (attn_scores) by multiplying queries and keys and then normalize these scores using softmax. Finally, we create a context vector by weighting the values with these normalized attention scores.

In [11]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


As a quick check, notice that the second row ([0.3061, 0.8210]) matches the contents of context_vec_2 in the previous section.

Self-attention involves the trainable weight matrices Wq, Wk, and Wv. These matrices transform input data into queries, keys, and values, respectively, which are crucial components of the attention mechanism. As the model is exposed to more data during training, it adjusts these trainable weights.

![self-attention-class](../images/self-attention-v1.png)

We can improve the SelfAttention_v1 implementation further by utilizing PyTorch's nn.Linear layers, which effectively perform matrix multiplication when the bias units are disabled. Additionally, a significant advantage of using nn.Linear instead of manually implementing nn.Parameter(torch.rand(...)) is that nn.Linear has an optimized weight initialization scheme, contributing to more stable and effective model training.

In [17]:
# a self-attention class using PyTorch's nn.Linear layers
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec

In [20]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


Exercise 3.1 Comparing SelfAttention_v1 and SelfAttention_v2


Note that nn.Linear in SelfAttention_v2 uses a different weight initialization scheme than the nn.Parameter(torch.rand(d_in, d_out)) used in SelfAttention_v1, which causes the two mechanisms to produce different results. To check that both implementations, SelfAttention_v1 and SelfAttention_v2, are otherwise similar, we can transfer the weight matrices from a SelfAttention_v2 object to a SelfAttention_v1 object, such that both objects then produce the same results.


Your task is to correctly assign the weights from an instance of SelfAttention_v2 to an instance of SelfAttention_v1. To do this, you need to understand the relationship between the weights in both versions. (Hint: nn.Linear stores the weight matrix in a transposed form.) After the assignment, you should observe that both instances produce the same outputs.

In [15]:
import torch
import torch.nn as nn

# [Your class definitions here]

# Create instances with same dimensions
torch.manual_seed(123)
d_in, d_out = 3, 2

sa_v1 = SelfAttention_v1(d_in, d_out)
sa_v2 = SelfAttention_v2(d_in, d_out)

# Create sample input
x = torch.rand(2, d_in)

# Before weight transfer - different outputs
print("Before weight transfer:")
print("v1 output:\n", sa_v1(x))
print("v2 output:\n", sa_v2(x))
print()

# Transfer weights from v2 to v1 (TRANSPOSE is the key!)
sa_v1.W_query = nn.Parameter(sa_v2.W_query.weight.T)
sa_v1.W_key = nn.Parameter(sa_v2.W_key.weight.T)
sa_v1.W_value = nn.Parameter(sa_v2.W_value.weight.T)

# After weight transfer - identical outputs
print("After weight transfer:")
print("v1 output:\n", sa_v1(x))
print("v2 output:\n", sa_v2(x))
print()

# Verify they're equal
print("Outputs are equal:", torch.allclose(sa_v1(x), sa_v2(x)))

Before weight transfer:
v1 output:
 tensor([[0.1671, 0.3726],
        [0.1677, 0.3746]], grad_fn=<MmBackward0>)
v2 output:
 tensor([[0.3038, 0.2414],
        [0.3047, 0.2418]], grad_fn=<MmBackward0>)

After weight transfer:
v1 output:
 tensor([[0.3038, 0.2414],
        [0.3047, 0.2418]], grad_fn=<MmBackward0>)
v2 output:
 tensor([[0.3038, 0.2414],
        [0.3047, 0.2418]], grad_fn=<MmBackward0>)

Outputs are equal: True


please see https://claude.ai/chat/97eef876-1d16-4a7c-a8c3-2620c1de2c3a

Next, we will make enhancements to the self-attention mechanism, focusing specifically on incorporating causal and multi-head elements. The causal aspect involves modifying the attention mechanism to prevent the model from accessing future information in the sequence, which is crucial for tasks like language modeling, where each word prediction should only depend on previous words. 


The multi-head component involves splitting the attention mechanism into multiple "heads." Each head learns different aspects of the data, allowing the model to simultaneously attend to information from different representation subspaces at different positions. This improves the model's performance in complex tasks.

## Hiding future words with causal attention

For many LLM tasks, you will want the self-attention mechanism to consider only the tokens that appear prior to the current position when predicting the next token in a sequence. Causal attention, also known as masked attention, is a specialized form of self-attention. It restricts a model to only consider previous and current inputs in a sequence when computing attention scores for any given token. This is in contrast to the standard self-attention mechanism, which allows access to the entire input sequence at once.


Now, we will modify the standard self-attention mechanism to create a causal attention mechanism, which is essential for developing an LLM in the subsequent chapters. To achieve this in GPT-like LLMs, for each token processed, we mask out the future tokens, which come after the current token in the input text. We mask out the attention weights above the diagonal, and we normalize the nonmasked attention weights such that the attention weights sum to 1 in each row. Later, we will implement this masking and normalization procedure in code.

In [21]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print("Attention weights from sa_v2:\n", attn_weights)

Attention weights from sa_v2:
 tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


We can implement the second step using PyTorch's tril function to create a mask where the values above the diagonal are zero:

In [22]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print("Simple lower-triangular mask:\n", mask_simple)

Simple lower-triangular mask:
 tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


Now, we can multiply this mask with the attention weights to zero out the values above the diagonal:

In [23]:
masked_simple = attn_weights * mask_simple
print("Masked attention weights (simple):\n", masked_simple)

Masked attention weights (simple):
 tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


The third step is to renormalize the attention weights to sum up to 1 again in each row. We can achieve this by dividing each element in each row by the sum in each row:

In [24]:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_normalized = masked_simple / row_sums
print("Normalized masked attention weights (simple):\n", masked_simple_normalized)

Normalized masked attention weights (simple):
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


A more efficient way to obtain the masked attention weights is to mask the attention scores with negative infinity values before applying the softmax function, so that masking and normalization happen in a single step. The softmax function converts its inputs into a probability distribution. When negative infinity values (-∞) are present in a row, the softmax function treats them as zero probability. (Mathematically, this is because e^(-∞) approaches 0.)


We can implement this more efficient masking "trick" by creating a mask with 1s above the diagonal and then replacing these 1s with negative infinity (-inf) values:

In [25]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print("Masked attention scores (with -inf):\n", masked)

Masked attention scores (with -inf):
 tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [26]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print("Attention weights after softmax with masking:\n", attn_weights)

Attention weights after softmax with masking:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
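
The agreement between the two masking routes is not a coincidence, and it can be checked in plain Python. The sketch below uses a made-up 3 x 3 score matrix; `route_1` follows the softmax, mask, renormalize sequence, and `route_2` fills future positions with negative infinity before a single softmax. It relies on `math.exp(float("-inf"))` evaluating to `0.0`.

```python
import math

NEG_INF = float("-inf")

def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

scores = [[0.3, 0.1, 0.2],
          [0.5, 0.2, 0.4],
          [0.1, 0.6, 0.3]]  # hypothetical attention scores
n = len(scores)

# Route 1: softmax over the full row, zero out future positions, renormalize.
route_1 = []
for i, row in enumerate(scores):
    weights = softmax(row)
    kept = [weights[j] if j <= i else 0.0 for j in range(n)]
    total = sum(kept)
    route_1.append([k / total for k in kept])

# Route 2: replace future scores with -inf, then apply softmax once.
route_2 = []
for i, row in enumerate(scores):
    masked = [row[j] if j <= i else NEG_INF for j in range(n)]
    route_2.append(softmax(masked))

for r1, r2 in zip(route_1, route_2):
    print([round(a, 4) for a in r1], [round(b, 4) for b in r2])
```

Both routes give the same rows because renormalizing after masking cancels the contribution of the masked exponentials from the denominator, which is exactly what excluding them up front achieves.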


We could now use the modified attention weights to compute the context vectors via `context_vec = attn_weights @ values`, as in section 3.4. However, we will first cover another minor tweak to the causal attention mechanism that is useful for reducing overfitting when training LLMs.

In [27]:
context_vec = attn_weights @ sa_v2.W_value(inputs)
print("Context vectors with masking:\n", context_vec)

Context vectors with masking:
 tensor([[-0.0872,  0.0286],
        [-0.0991,  0.0501],
        [-0.0999,  0.0633],
        [-0.0983,  0.0489],
        [-0.0514,  0.1098],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


## Masking additional attention weights with dropout

Dropout in deep learning is a technique where randomly selected hidden layer units are ignored during training, effectively "dropping" them out. This method helps prevent overfitting by ensuring that a model does not become overly reliant on any specific set of hidden layer units. It's important to emphasize that dropout is only used during training and is disabled afterward.


In the transformer architecture, including models like GPT, dropout in the attention mechanism is typically applied at two specific times: after calculating the attention weights or after applying the attention weights to the value vectors. Here we will apply the dropout mask after computing the attention weights, because it's the more common variant in practice.

In [28]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5) # we choose a dropout rate of 50%

example = torch.ones(6, 6)      # we create a matrix of 1s
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])
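
The surviving entries are 2 rather than 1 because, during training, dropout scales the values it keeps by 1 / (1 - p), which is 2 for p = 0.5; this compensates for the dropped entries so that the expected value stays the same. The following plain-Python sketch imitates that behavior. The `dropout` function here is a hypothetical stand-in written for illustration, not PyTorch's implementation, and its random pattern will not match the tensor above.

```python
import random

def dropout(values, p, rng):
    """Inverted dropout: zero each value with probability p and scale the
    survivors by 1 / (1 - p) so the expected value is unchanged."""
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * scale for v in values]

rng = random.Random(123)  # fixed seed so the run is repeatable
ones = [1.0] * 10_000
dropped = dropout(ones, p=0.5, rng=rng)

print(sorted(set(dropped)))         # only 0.0 and 2.0 occur
print(sum(dropped) / len(dropped))  # close to 1.0 on average
```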


In [29]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)


## Implementing a compact causal attention class



In [32]:
# Two inputs with six tokens each; each token has embedding dimension 3
batch = torch.stack((inputs, inputs), dim=0) 

print("batch shape:", batch.shape)

batch shape: torch.Size([2, 6, 3])


In [33]:
# a compact causal attention class
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout) # add a dropout layer

        # create an upper-triangular mask (1s above the diagonal mark future tokens)
        self.register_buffer(
            "mask", 
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We transpose dimensions 1 and 2, keeping the batch dimension at the first position (0)
        attn_scores = queries @ keys.transpose(1, 2)   

        # In PyTorch, operations with a trailing underscore are performed in-place, 
        # avoiding unnecessary memory copies.
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

In [35]:
torch.manual_seed(123)
context_length = batch.shape[1]

ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs shape:", context_vecs.shape)
print("context_vecs:", context_vecs)

context_vecs shape: torch.Size([2, 6, 2])
context_vecs: tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)


# Extending single-head attention to multi-head attention

The term "multi-head" refers to dividing the attention mechanism into multiple "heads," each operating independently. In this context, a single causal attention module can be considered single-head attention, where there is only one set of attention weights processing the input sequentially.