# Chapter 3: Coding Attention Mechanisms

In [1]:
from importlib.metadata import version
print("torch version:", version("torch"))

torch version: 2.9.1


## 3.3 Attending to Different Parts of the Input with Self-Attention

### 3.3.1 A Simple Self-Attention Mechanism without Trainable Weights

Assume we have an input sequence, denoted as *x*, consisting of *T* elements represented as *x(1)* to *x(T)*. For example, in natural language processing, these elements could be the words or tokens of a sentence that have already been transformed into embeddings.

Consider an input text like "Your journey starts with one step." Each element of the sequence, such as *x(1)*, corresponds to a *d*-dimensional embedding vector representing a specific token, such as "Your".

In self-attention, our goal is to calculate a context vector *z(i)* for each element *x(i)* in the input sequence (in this simplified version, *z(i)* and *x(i)* have the same dimension). A **context vector** can be interpreted as an enriched embedding that captures not only the information from the token itself but also relevant information from the other tokens in the sequence.

The concept of context vectors is essential in LLMs, which need to understand the relationships and relevance of words in a sentence to each other. A context vector *z(i)* is a weighted sum over the inputs *x(1)* to *x(T)*.

For example, suppose we focus on the embedding vector *x(2)*, which corresponds to the token "journey". The corresponding context vector *z(2)* is a sum over all inputs *x(1)* to *x(T)*, where each input is weighted by its relevance to the second input element *x(2)*.

The attention weights are the weights that determine how much each of the input elements contributes to the weighted sum when computing *z(2)*.

By convention, the unnormalized attention weights are referred to as **attention scores**, whereas the normalized weights are called **attention weights**.

Suppose we have the following input sentence, already embedded as 3-dimensional vectors:

In [1]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
     [0.55, 0.87, 0.66], # journey  (x^2)
     [0.57, 0.85, 0.64], # starts   (x^3)
     [0.22, 0.58, 0.33], # with     (x^4)
     [0.77, 0.25, 0.10], # one      (x^5)
     [0.05, 0.80, 0.55]] # step     (x^6)
)

**Step 1**: compute unnormalized attention scores *w*.

Suppose we use *x(2)* as the query *q(2)*. Then we can compute the unnormalized attention scores via dot products:
- $w(2,1) = x(1)q(2)^T$
- $w(2,2) = x(2)q(2)^T$
- $w(2,3) = x(3)q(2)^T$
- ...
- $w(2,T) = x(T)q(2)^T$

where $w(2,1)$ denotes the attention score obtained when input element 2 serves as the query and is compared against input element 1.

Now we can compute the unnormalized attention scores by taking the dot product between the query *x(2)* and every input token, including *x(2)* itself:

In [3]:
query = inputs[1] # 2nd input token "journey"

attn_scores_2 = torch.empty(inputs.shape[0])
print("Shape of attn_scores_2:", attn_scores_2.shape)

Shape of attn_scores_2: torch.Size([6])


In [4]:
for i, x_i in enumerate(inputs):
    # dot product
    # (transpose not necessary here since x_i and query are 1D tensors)
    attn_scores_2[i] = torch.dot(x_i, query)

print("Attention scores for the 2nd input token 'journey':")
print(attn_scores_2)

Attention scores for the 2nd input token 'journey':
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [None]:
# Equivalent manual computation of the first score (dot product of x(1) and the query):
res = 0.

for idx, element in enumerate(inputs[0]):
    res += element * query[idx]

print("Dot product computed manually:", res)
print("Dot product computed with torch.dot:", torch.dot(inputs[0], query))

Dot product computed manually: tensor(0.9544)
Dot product computed with torch.dot: tensor(0.9544)


**Step 2**: normalize the unnormalized attention scores so that they sum up to 1.

This normalization is a convention that is useful for interpretation and for maintaining training stability in an LLM.

In [6]:
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()

print("Attention weights for the 2nd input token 'journey':")
print(attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights for the 2nd input token 'journey':
tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


However, in practice we use the **softmax** function for normalization, which handles extreme values better and has more desirable gradient properties during training.

In [7]:
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)


attn_weights_2_naive = softmax_naive(attn_scores_2)

print("Attention weights for the 2nd input token 'journey' using naive softmax:")
print(attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())

Attention weights for the 2nd input token 'journey' using naive softmax:
tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


Softmax ensures the attention weights are always positive and sum to 1, making the output interpretable as probabilities or relative importance.
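To see why plain sum-normalization can be problematic, here is a minimal sketch with hypothetical scores that include negative values (dot products become negative when two vectors point in roughly opposite directions). The numbers are made up for illustration and are not taken from the `inputs` tensor above.

```python
import torch

# Hypothetical attention scores, including negative values
scores = torch.tensor([2.0, -1.0, 0.5, -1.2])

# Plain sum-normalization: the result sums to 1, but entries can be
# negative or larger than 1, and it blows up if the sum is close to 0
sum_norm = scores / scores.sum()

# Softmax: exponentiation makes every entry positive before normalizing
soft = torch.softmax(scores, dim=0)

print("sum-normalized:", sum_norm)
print("softmax:       ", soft)
```

Here the scores sum to only 0.3, so dividing by that sum produces negative "weights" with magnitudes well above 1, which cannot be read as relative importance. Softmax keeps every weight in (0, 1).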

To avoid numerical instability (overflow/underflow) when computing the exponential of large or small values, we prefer the PyTorch built-in softmax function:

In [8]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

print("Attention weights for the 2nd input token 'journey' using PyTorch softmax:")
print(attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights for the 2nd input token 'journey' using PyTorch softmax:
tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
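To make the numerical-stability concern mentioned above concrete, the sketch below feeds deliberately large, hypothetical scores to the naive softmax. It also shows the standard max-subtraction trick that numerically stable softmax implementations commonly rely on. `softmax_stable` is a helper written for this illustration and is not a PyTorch function.

```python
import torch

def softmax_naive(x):
    # Same naive formula as above: exp(x) / sum(exp(x))
    return torch.exp(x) / torch.exp(x).sum(dim=0)

def softmax_stable(x):
    # Subtracting the maximum does not change the result mathematically,
    # because exp(x - c) / sum(exp(x - c)) == exp(x) / sum(exp(x)),
    # but the largest exponent becomes exp(0) = 1, so nothing overflows.
    shifted = x - x.max()
    return torch.exp(shifted) / torch.exp(shifted).sum(dim=0)

large_scores = torch.tensor([1000.0, 999.0, 998.0])

print("naive: ", softmax_naive(large_scores))   # exp(1000) overflows to inf, so inf / inf gives nan
print("stable:", softmax_stable(large_scores))
print("torch: ", torch.softmax(large_scores, dim=0))
```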


**Step 3**: compute the context vector *z(2)* by multiplying each embedded input token *x(1)* to *x(T)* by its attention weight *a(2,1)* to *a(2,T)* and summing the resulting vectors:

In [9]:
query = inputs[1]  # 2nd input token "journey"

context_vector_2 = torch.zeros(query.shape) # initialize context vector

for i, x_i in enumerate(inputs):
    context_vector_2 += attn_weights_2[i] * x_i

print("Context vector for the 2nd input token 'journey':")
print(context_vector_2)

Context vector for the 2nd input token 'journey':
tensor([0.4419, 0.6515, 0.5683])


### 3.3.2 Computing Attention Weights for All Input Tokens

We have computed the attention weights and context vector for the 2nd input token "journey".

In [10]:
print("Attention weights for the 2nd input token 'journey':")
print(attn_weights_2)

Attention weights for the 2nd input token 'journey':
tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


In [11]:
print("Context vector for the 2nd input token 'journey':")
print(context_vector_2)

Context vector for the 2nd input token 'journey':
tensor([0.4419, 0.6515, 0.5683])


Now we generalize this process to compute attention weights and context vectors for all input tokens in the sequence. This involves repeating the steps outlined above for each token in the input sequence, treating each token as a query in turn.

In self-attention, this process starts with the calculation of attention scores, which are subsequently normalized to derive attention weights that sum to one. Later, these attention weights are used to generate the context vectors through a weighted sum of the input embeddings.
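For reference, the three steps can be written as equations in the notation used so far (with $x^{(i)}$ standing for *x(i)*, and the query for token *i* being $x^{(i)}$ itself):

$$
w(i,j) = x^{(j)} \left(x^{(i)}\right)^T, \qquad
a(i,j) = \frac{\exp\left(w(i,j)\right)}{\sum_{k=1}^{T} \exp\left(w(i,k)\right)}, \qquad
z^{(i)} = \sum_{j=1}^{T} a(i,j)\, x^{(j)}
$$

Stacking the inputs as the rows of a $T \times d$ matrix $X$ (a symbol introduced here for convenience; it corresponds to the `inputs` tensor), all scores, weights, and context vectors follow in matrix form, with the softmax applied to each row:

$$
\text{scores} = X X^T, \qquad A = \operatorname{softmax}\left(X X^T\right), \qquad Z = A X
$$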

**Step 1**: compute the unnormalized attention scores *w* for each input token as a query against all input tokens.

In [13]:
attn_scores = torch.empty((inputs.shape[0], inputs.shape[0]))

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)

print("Attention scores matrix:")
print(attn_scores)

Attention scores matrix:
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


Each element of this matrix is the attention score between one pair of input tokens. We can obtain the same result more efficiently with matrix multiplication:

In [14]:
attn_scores = inputs @ inputs.T

print("Attention scores matrix using matrix multiplication:")
print(attn_scores)

Attention scores matrix using matrix multiplication:
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


**Step 2**: normalize each row so that the values in each row sum to 1.

In [15]:
attn_weights = torch.softmax(attn_scores, dim=-1)

print("Attention weights matrix:")
print(attn_weights)

Attention weights matrix:
tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


We can verify that the rows all sum to 1:

In [16]:
row_2_sum = sum(attn_weights[1])
print("Sum of attention weights for the 2nd input token 'journey':", row_2_sum)

all_rows_sum = attn_weights.sum(dim=-1)
print("Sum of attention weights for all input tokens:", all_rows_sum)

Sum of attention weights for the 2nd input token 'journey': tensor(1.)
Sum of attention weights for all input tokens: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


**Step 3**: compute all context vectors.

In [17]:
all_context_vecs = attn_weights @ inputs

print("All context vectors:")
print(all_context_vecs)

All context vectors:
tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


In [18]:
print("Previous context vector for the 2nd input token 'journey':")
print(context_vector_2)

print("Context vector from all_context_vecs for the 2nd input token 'journey':")
print(all_context_vecs[1])

Previous context vector for the 2nd input token 'journey':
tensor([0.4419, 0.6515, 0.5683])
Context vector from all_context_vecs for the 2nd input token 'journey':
tensor([0.4419, 0.6515, 0.5683])


## 3.4 Implementing Self-Attention with Trainable Weights

### 3.4.1 Computing the Attention Weights Step-by-Step

We will implement the self-attention mechanism by introducing three trainable weight matrices: $W_{query}$, $W_{key}$, and $W_{value}$, which are used to project the input embeddings, $x^{(i)}$, into three different spaces: queries, keys, and values, respectively. These weight matrices are learned and can be used to capture different aspects of the input data:
- **Query vector**: $q^{(i)} = x^{(i)} W_{query}$
- **Key vector**: $k^{(i)} = x^{(i)} W_{key}$
- **Value vector**: $v^{(i)} = x^{(i)} W_{value}$

For example, if we again use the input text "Your journey starts with one step.", focus on the second token "journey", and let the embedding vector $x^{(2)}$ correspond to "journey", we can compute the query, key, and value vectors as follows:
- $q^{(2)} = x^{(2)} W_{query}$
- $k^{(2)} = x^{(2)} W_{key}$
- $v^{(2)} = x^{(2)} W_{value}$

For the other tokens, we only need their key and value vectors (since they are not used as queries in this step):
- For the first token "Your":
    - $k^{(1)} = x^{(1)} W_{key}$
    - $v^{(1)} = x^{(1)} W_{value}$
- For the last token "step.":
    - $k^{(T)} = x^{(T)} W_{key}$
    - $v^{(T)} = x^{(T)} W_{value}$

In [2]:
# second input token "journey"
x_2 = inputs[1]

# the input embedding size, d_in=3
d_in = inputs.shape[1]
# the output embedding size, d_out=2
d_out = 2

For demo purposes, we assume the output embedding size is 2, so the weight matrices have the following shapes:
- $W_{query}$: (3, 2)
- $W_{key}$: (3, 2)
- $W_{value}$: (3, 2)

But in GPT-like models, the output embedding size is usually the same as the input embedding size.

In [3]:
torch.manual_seed(0)
# Initialize weight matrices with random values
# (Set `requires_grad=False` for demo purposes)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

# Compute query, key, and value vectors for x_2
query_2 = x_2 @ W_query  # shape: (d_out,)
key_2 = x_2 @ W_key      # shape: (d_out,)
value_2 = x_2 @ W_value  # shape: (d_out,)

print("Query vector for 'journey':", query_2)
print("Key vector for 'journey':", key_2)
print("Value vector for 'journey':", value_2)

Query vector for 'journey': tensor([0.5528, 0.9559])
Key vector for 'journey': tensor([0.8962, 1.3083])
Value vector for 'journey': tensor([0.7284, 1.0720])


NOTE: These weight matrices are not the attention weights; they are trainable parameters used to compute the query, key, and value vectors from the input embeddings. The attention weights are computed later using the dot products of the query and key vectors.

Next, we will compute the key and value vectors for all input tokens as they are needed to compute the attention weights with respect to the query vector.

In [4]:
keys = inputs @ W_key    # shape: (num_tokens, d_out)
values = inputs @ W_value  # shape: (num_tokens, d_out)

print("keys shape:", keys.shape)
print("values shape:", values.shape)

keys shape: torch.Size([6, 2])
values shape: torch.Size([6, 2])


We successfully projected the six 3-dimensional input embeddings into 2-dimensional key and value vectors. The next step is to compute the attention scores by taking the dot product between the query and each key vector.

In [5]:
keys_2 = keys[1]  # key vector for "journey"
attn_scores_22 = query_2.dot(keys_2)

print("Attention score for 'journey' with itself:", attn_scores_22)

Attention score for 'journey' with itself: tensor(1.7460)


We can then generalize this computation to obtain all attention scores for the query vector corresponding to the token "journey".

In [9]:
attn_scores_2 = query_2 @ keys.T # shape: (num_tokens,)
print("Attention scores for 'journey' with all tokens:\n", attn_scores_2)

Attention scores for 'journey' with all tokens:
 tensor([1.1268, 1.7460, 1.7399, 0.9351, 1.1402, 1.0587])


Then we need to normalize the attention scores by scaling them and using the softmax function. Here, we will scale the attention scores by dividing them by the square root of the embedding dimension of the key vectors.

In [8]:
d_k = keys.shape[1] # embedding dimension of the key vectors
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print("Attention weights for 'journey' with all tokens:\n", attn_weights_2)

Attention weights for 'journey' with all tokens:
 tensor([0.1443, 0.2236, 0.2227, 0.1261, 0.1457, 0.1376])


The reason for scaling by the square root of the key dimension is to **improve training performance by avoiding small gradients.** When the embedding dimension is scaled up (it is typically greater than 1,000 for GPT-like LLMs), the dot products can become large, and the softmax applied to them then produces very small gradients during backpropagation. As the dot products increase, the softmax function behaves more like a *step function*, resulting in gradients **near zero**.
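The following sketch illustrates both halves of this argument under a simplifying assumption: query and key entries are drawn i.i.d. from a standard normal distribution, which is not how trained LLM weights behave but makes the effect easy to see. Part 1 shows that the spread of the dot products grows like $\sqrt{d_k}$ and that dividing by $\sqrt{d_k}$ brings it back to about 1. Part 2 multiplies a small, hypothetical score vector by increasing factors and measures the gradient of the largest softmax weight with autograd. `top_weight_and_grad_norm` is a helper defined only for this illustration.

```python
import torch

# Part 1: dot products of random vectors grow with the dimension d_k.
# Assumption: query and key entries are i.i.d. standard normal (mean 0, variance 1).
torch.manual_seed(0)
stds = {}
for d_k in (4, 1024):
    q = torch.randn(1000, d_k)
    k = torch.randn(1000, d_k)
    dots = (q * k).sum(dim=-1)          # 1,000 sample dot products
    scaled = dots / d_k**0.5            # the scaling used in the attention code above
    stds[d_k] = (dots.std().item(), scaled.std().item())
    print(f"d_k={d_k:4d}  std of dot products={stds[d_k][0]:6.2f}  after scaling={stds[d_k][1]:.2f}")

# Part 2: larger scores push softmax toward a step function with vanishing gradients.
def top_weight_and_grad_norm(scores):
    """Return the largest softmax weight and the norm of its gradient w.r.t. the scores."""
    s = scores.clone().requires_grad_(True)
    weights = torch.softmax(s, dim=-1)
    top = weights.max()
    top.backward()                      # populates s.grad with d(top weight) / d(scores)
    return top.item(), s.grad.norm().item()

base_scores = torch.tensor([1.0, 2.0, 3.0])
results = []
for factor in (1.0, 10.0, 100.0):
    # Multiplying by a factor mimics the larger dot products of a larger d_k
    top_w, grad_norm = top_weight_and_grad_norm(base_scores * factor)
    results.append((factor, top_w, grad_norm))
    print(f"factor={factor:5.0f}  top weight={top_w:.4f}  grad norm={grad_norm:.2e}")
```

As the factor grows, the largest weight approaches 1 (a "step") and the gradient flowing back to the scores shrinks by orders of magnitude, which is the small-gradient problem the scaling is meant to avoid.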

The final step is to compute the context vector as a weighted sum over the value vectors. The attention weights serve as weighting factors that determine the relative importance of each value vector.

In [10]:
context_vec_2 = attn_weights_2 @ values
print("Context vector for 'journey':\n", context_vec_2)

Context vector for 'journey':
 tensor([0.5780, 0.8419])


The concepts of "key", "query", and "value" in the self-attention mechanism are inspired by information retrieval systems, where similar terminology is used to "store", "search", and "retrieve" information.
- *Key* is like a database key used for indexing and searching. Each item in the input sequence (e.g., each word in a sentence) has an associated key. These keys are used to match the query.
- *Query* is like a search query in a database. It represents the current item (e.g., a word or token in a sentence) the model focuses on or tries to understand. It is used to determine how much attention to pay to other items in the sequence.
- *Value* is like the value in a key-value pair in a database. It represents the actual content or representation of the input items. Once the model determines which keys (and thus which parts of the input) are the most relevant to the query (the current focus item), it retrieves the corresponding values to construct a meaningful representation (context vector) for that item.
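The database analogy can be made concrete with a small "soft lookup": instead of returning the single value whose key matches exactly, attention returns a mix of all values, weighted by how well each key matches the query. The keys, values, and the `temperature` parameter below are hypothetical and only serve the illustration; the attention code in this chapter effectively uses a temperature of $\sqrt{d_k}$.

```python
import torch

# A tiny hypothetical "database": 3 keys (2-dimensional) and one value per key
keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [1.0, 1.0]])
values = torch.tensor([[10.0],
                       [20.0],
                       [30.0]])

def soft_lookup(query, keys, values, temperature=1.0):
    # Similarity of the query to every key (dot product, as in attention)
    scores = keys @ query
    # Softmax turns similarities into weights that sum to 1
    weights = torch.softmax(scores / temperature, dim=-1)
    # Return a weighted mix of all values instead of a single value
    return weights @ values, weights

# A query that strongly resembles the first key
q1 = torch.tensor([5.0, -5.0])
out_soft, w_soft = soft_lookup(q1, keys, values)
# A small temperature sharpens the weights toward an exact (hard) lookup
out_sharp, w_sharp = soft_lookup(q1, keys, values, temperature=0.1)

print("soft weights: ", w_soft, "->", out_soft)
print("sharp weights:", w_sharp, "->", out_sharp)
```

With the default temperature, the output is dominated by the first value but still blends in a little of the others; with the sharper setting, it effectively becomes a classic key-value lookup that returns 10.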

### 3.4.2 Implementing a Compact Self-Attention Class

In [11]:
import torch.nn as nn

In [12]:
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()

        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key      # shape: (num_tokens, d_out)
        queries = x @ self.W_query  # shape: (num_tokens, d_out)
        values = x @ self.W_value   # shape: (num_tokens, d_out)

        attn_scores = queries @ keys.T # shape: (num_tokens, num_tokens)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5,
            dim=-1
        )

        context_vec = attn_weights @ values  # shape: (num_tokens, d_out)
        return context_vec

In [13]:
torch.manual_seed(0)

sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)
print("Context vectors from SelfAttention_v1:\n", sa_v1(inputs))

Context vectors from SelfAttention_v1:
 tensor([[0.5762, 0.8392],
        [0.5780, 0.8419],
        [0.5780, 0.8420],
        [0.5628, 0.8183],
        [0.5704, 0.8302],
        [0.5635, 0.8196]], grad_fn=<MmBackward0>)


Self-attention involves the trainable weight matrices $W_{query}$, $W_{key}$, and $W_{value}$, which are used to project the input embeddings into query, key, and value vectors, respectively.
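In matrix form, and writing $X$ for the `num_tokens` $\times$ `d_in` input matrix (a symbol introduced here for convenience), the computation performed by `SelfAttention_v1.forward` can be summarized as follows, with the softmax applied to each row and $d_k$ denoting the key dimension (`keys.shape[-1]`):

$$
Q = X W_{query}, \qquad K = X W_{key}, \qquad V = X W_{value}, \qquad
Z = \operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V
$$

This form is commonly referred to as scaled dot-product attention.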

We can further improve the `SelfAttention_v1` class by using PyTorch's `nn.Linear` layers, which effectively perform matrix multiplications when their bias units are disabled. A significant advantage of `nn.Linear` over manually creating `nn.Parameter(torch.rand(...))` weights is that `nn.Linear` uses an optimized weight initialization scheme, which contributes to more stable and efficient training.

In [14]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)     # shape: (num_tokens, d_out)
        queries = self.W_query(x)  # shape: (num_tokens, d_out)
        values = self.W_value(x)   # shape: (num_tokens, d_out)

        attn_scores = queries @ keys.T # shape: (num_tokens, num_tokens)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5,
            dim=-1
        )

        context_vec = attn_weights @ values  # shape: (num_tokens, d_out)
        return context_vec

In [15]:
torch.manual_seed(0)

sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)
print("Context vectors from SelfAttention_v2:\n", sa_v2(inputs))

Context vectors from SelfAttention_v2:
 tensor([[-0.5844,  0.3235],
        [-0.5871,  0.3269],
        [-0.5871,  0.3269],
        [-0.5873,  0.3265],
        [-0.5876,  0.3276],
        [-0.5870,  0.3259]], grad_fn=<MmBackward0>)
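The outputs of `sa_v1` and `sa_v2` differ only because the two classes initialize their weights differently. The sketch below checks this by copying the weights of a `SelfAttention_v2` instance into a `SelfAttention_v1` instance. One detail matters: `nn.Linear` stores its weight with shape `(d_out, d_in)` and computes `x @ weight.T`, so the matching `(d_in, d_out)` matrix for `SelfAttention_v1` is the transpose. Both classes are restated in condensed form so that the block runs on its own.

```python
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys, queries, values = x @ self.W_key, x @ self.W_query, x @ self.W_value
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys, queries, values = self.W_key(x), self.W_query(x), self.W_value(x)
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

# Same 6 x 3 input embeddings as earlier in the chapter
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], [0.55, 0.87, 0.66], [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33], [0.77, 0.25, 0.10], [0.05, 0.80, 0.55]]
)

torch.manual_seed(0)
sa_v1 = SelfAttention_v1(d_in=3, d_out=2)
sa_v2 = SelfAttention_v2(d_in=3, d_out=2)

# Transfer the weights: nn.Linear's (d_out, d_in) weight, transposed, is the (d_in, d_out) matrix
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

same = torch.allclose(sa_v1(inputs), sa_v2(inputs), atol=1e-6)
print("Outputs match after weight transfer:", same)
```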


## 3.5 Hiding Future Words with Causal Attention

### 3.5.1 Applying a Causal Attention Mask

Causal self-attention ensures that the model's prediction for a given position in a sequence depends only on the tokens at that position and the positions before it, never on future positions. To achieve this, for each token we mask out the future tokens (the ones that come after the current token in the input text).

As a first step, we reuse the query and key weight matrices from the previous section to compute the attention scores and attention weights:

In [16]:
# Reuse the query and key weight matrices from sa_v2
# to compute the attention scores matrix and attention weights matrix
queries = sa_v2.W_query(inputs)  # shape: (num_tokens, d_out)
keys = sa_v2.W_key(inputs)       # shape: (num_tokens, d_out)

# Attention scores matrix
attn_scores = queries @ keys.T
# Attention weights matrix
attn_weights = torch.softmax(
    attn_scores / keys.shape[-1]**0.5,
    dim=-1
)

print("Attention weights matrix from SelfAttention_v2:\n", attn_weights)

Attention weights matrix from SelfAttention_v2:
 tensor([[0.1762, 0.1616, 0.1619, 0.1662, 0.1712, 0.1630],
        [0.1665, 0.1678, 0.1675, 0.1669, 0.1614, 0.1699],
        [0.1664, 0.1679, 0.1676, 0.1670, 0.1612, 0.1700],
        [0.1654, 0.1679, 0.1677, 0.1669, 0.1631, 0.1689],
        [0.1645, 0.1691, 0.1686, 0.1671, 0.1595, 0.1713],
        [0.1666, 0.1671, 0.1670, 0.1668, 0.1648, 0.1678]],
       grad_fn=<SoftmaxBackward0>)


One simple way to mask out future attention weights is to create a mask with PyTorch's `tril` function, in which the elements on and below the main diagonal are set to 1 and the elements above the main diagonal are set to 0.

In [17]:
context_length = attn_scores.shape[0]  # number of tokens in the input sequence

# `tril` creates a lower triangular matrix
mask_simple = torch.tril(
    torch.ones((context_length, context_length))
)
print("Simple causal attention mask:\n", mask_simple)

Simple causal attention mask:
 tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


Multiply this mask with the attention weights to zero-out the values above the main diagonal, effectively preventing the model from attending to future tokens.

In [18]:
masked_simple = attn_weights * mask_simple
print("Masked attention weights matrix:\n", masked_simple)

Masked attention weights matrix:
 tensor([[0.1762, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1665, 0.1678, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1664, 0.1679, 0.1676, 0.0000, 0.0000, 0.0000],
        [0.1654, 0.1679, 0.1677, 0.1669, 0.0000, 0.0000],
        [0.1645, 0.1691, 0.1686, 0.1671, 0.1595, 0.0000],
        [0.1666, 0.1671, 0.1670, 0.1668, 0.1648, 0.1678]],
       grad_fn=<MulBackward0>)


The next step is to **re-normalize** the attention weights to sum up to 1 again in each row after applying the causal mask. This step is crucial because the masking process alters the original attention weights, and re-normalization ensures that the remaining weights still represent a valid probability distribution over the allowed tokens.

In [19]:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_normalized = masked_simple / row_sums
print("Re-normalized masked attention weights matrix:\n", masked_simple_normalized)

Re-normalized masked attention weights matrix:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4981, 0.5019, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3316, 0.3345, 0.3339, 0.0000, 0.0000, 0.0000],
        [0.2476, 0.2514, 0.2511, 0.2498, 0.0000, 0.0000],
        [0.1985, 0.2040, 0.2035, 0.2016, 0.1925, 0.0000],
        [0.1666, 0.1671, 0.1670, 0.1668, 0.1648, 0.1678]],
       grad_fn=<DivBackward0>)


We can still improve the implementation of the causal attention.

The softmax function converts its inputs into a probability distribution. When negative infinity values are present in a row, the softmax function treats them as zero probabilities: $e^{-\infty} = 0$. Thus, we can implement a more efficient masking "trick" by masking the unnormalized attention scores above the main diagonal with negative infinity before they enter the softmax function.

In [21]:
mask = torch.triu(
    torch.ones(context_length, context_length),
    diagonal=1
)
print(mask)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])


In [22]:
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print("Masked attention scores matrix with -inf:\n", masked)

Masked attention scores matrix with -inf:
 tensor([[-0.0021,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.0191,  0.0297,    -inf,    -inf,    -inf,    -inf],
        [ 0.0196,  0.0319,  0.0293,    -inf,    -inf,    -inf],
        [ 0.0109,  0.0323,  0.0306,  0.0234,    -inf,    -inf],
        [ 0.0231,  0.0619,  0.0585,  0.0451, -0.0203,    -inf],
        [ 0.0068,  0.0113,  0.0104,  0.0086, -0.0085,  0.0175]],
       grad_fn=<MaskedFillBackward0>)


Now all we need to do is apply the softmax function to these masked results:

In [23]:
attn_weights = torch.softmax(
    masked / keys.shape[-1]**0.5,
    dim=-1
)
print("Attention weights matrix after applying -inf mask:\n", attn_weights)

Attention weights matrix after applying -inf mask:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4981, 0.5019, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3316, 0.3345, 0.3339, 0.0000, 0.0000, 0.0000],
        [0.2476, 0.2514, 0.2511, 0.2498, 0.0000, 0.0000],
        [0.1985, 0.2040, 0.2035, 0.2016, 0.1925, 0.0000],
        [0.1666, 0.1671, 0.1670, 0.1668, 0.1648, 0.1678]],
       grad_fn=<SoftmaxBackward0>)


As the output shows, the values in each row already sum to 1, so no further re-normalization is needed.
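Both masking approaches give the same weights. Softmax followed by re-normalization over the unmasked entries cancels the full-row normalizer, leaving $\exp(s_j) / \sum_{k \le i} \exp(s_k)$, which is exactly what softmax computes when the future scores are $-\infty$. The sketch below checks this numerically on random, hypothetical scores rather than the `attn_scores` computed above.

```python
import torch

torch.manual_seed(0)
num_tokens, d_k = 6, 2
# Hypothetical unnormalized attention scores (stand-in for `attn_scores` above)
attn_scores = torch.randn(num_tokens, num_tokens)

# Approach 1: softmax, zero out future positions, then re-normalize each row
weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
mask_simple = torch.tril(torch.ones(num_tokens, num_tokens))
masked_simple = weights * mask_simple
approach_1 = masked_simple / masked_simple.sum(dim=-1, keepdim=True)

# Approach 2: fill future positions with -inf, then apply a single softmax
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
approach_2 = torch.softmax(masked / d_k**0.5, dim=-1)

same_result = torch.allclose(approach_1, approach_2, atol=1e-6)
print("Both masking approaches agree:", same_result)
```

The $-\infty$ version needs one softmax and no extra division, which is why it is the variant used in the class implementation later in this section.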

### 3.5.2 Masking Additional Attention Weights with Dropout

Dropout helps prevent overfitting by randomly setting a fraction of the input units to zero at each update during training, which helps to break up co-adaptations among neurons. It is only active during training and is disabled at inference time.

In the transformer architecture, dropout in the attention mechanism is typically applied at one of two points:
- after calculating the attention weights, or
- after applying the attention weights to the value vectors.

For demo purposes, we will apply the dropout after computing the attention weights with a dropout rate of 50%.

When applying dropout to an attention weight matrix with a dropout rate of 50%, half of the elements in the matrix are randomly set to zero. To compensate for the reduction in active elements, the values of the remaining elements in the matrix are scaled up by a factor of `1 / (1 - dropout_rate)` which in this case is `1 / (1 - 0.5) = 2`.

In [24]:
torch.manual_seed(0)

dropout = torch.nn.Dropout(p=0.5) # Apply dropout with a rate of 50%
example = torch.ones(6, 6) # Example attention weights matrix
dropped_out = dropout(example)

print("Attention weights matrix after applying dropout:\n", dropped_out)

Attention weights matrix after applying dropout:
 tensor([[0., 0., 2., 0., 0., 0.],
        [2., 2., 0., 2., 2., 2.],
        [2., 0., 2., 2., 2., 2.],
        [0., 0., 2., 0., 2., 0.],
        [2., 0., 0., 0., 0., 2.],
        [0., 0., 2., 2., 2., 0.]])


Apply dropout to the attention weights matrix:

In [25]:
torch.manual_seed(0)

print("Apply dropout to the attention weights matrix:")
print(dropout(attn_weights))

Apply dropout to the attention weights matrix:
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9963, 1.0037, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6632, 0.0000, 0.6678, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5022, 0.0000, 0.0000, 0.0000],
        [0.3969, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.3340, 0.3335, 0.3295, 0.0000]],
       grad_fn=<MulBackward0>)
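The scaling behavior described above is often called "inverted dropout". The sketch below spells it out with a hypothetical helper, `inverted_dropout`, which is written for illustration and is not PyTorch's actual implementation. It also checks that `nn.Dropout` becomes a no-op once the module is switched to evaluation mode with `.eval()`.

```python
import torch

def inverted_dropout(x, p, training=True):
    """Hypothetical helper mirroring what nn.Dropout does conceptually (assumes 0 <= p < 1)."""
    assert 0.0 <= p < 1.0
    if not training or p == 0.0:
        return x
    # Each element is kept with probability 1 - p (1 = keep, 0 = drop)
    keep = (torch.rand_like(x) >= p).to(x.dtype)
    # Rescale the survivors so that the expected value of every element is unchanged
    return x * keep / (1.0 - p)

torch.manual_seed(0)
x = torch.ones(1000, 1000)
y = inverted_dropout(x, p=0.5)
print("distinct values:", y.unique())   # only 0 (dropped) and 2 (kept and rescaled)
print("mean:", y.mean())                # close to the original mean of 1.0

dropout = torch.nn.Dropout(p=0.5)
dropout.eval()                          # switch to evaluation (inference) mode
print("unchanged in eval mode:", torch.equal(dropout(x), x))
```

Because of the `1 / (1 - p)` rescaling, the expected activation is the same during training and inference, so nothing needs to be rescaled when dropout is turned off.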


### 3.5.3 Implementing a Compact Causal Attention Class

We will now implement a compact `CausalAttention` class that incorporates the causal masking and dropout mechanisms into the self-attention process. We need to make sure this class can handle batches consisting of more than one input sequence so that it supports the batch outputs produced by the data loader during training.

In [26]:
# For demo purposes we duplicate the input text sequence
batch = torch.stack((inputs, inputs), dim=0)  # shape: (batch_size, num_tokens, d_in)
print("Input batch shape:", batch.shape)

Input batch shape: torch.Size([2, 6, 3])


This means that we have a 3-dimensional tensor consisting of *two* input sequences with *six* tokens each, where each token is represented by a *three*-dimensional embedding vector.
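With a batch dimension in front, the key matrix can no longer be transposed with `.T`: on a 3-D tensor, `.T` reverses all of its dimensions (a use that PyTorch deprecates for tensors that are not 2-D). Instead, we swap only the last two dimensions with `transpose(1, 2)`, and `@` then performs one matrix multiplication per batch element. The random queries and keys below are hypothetical stand-ins with the shapes used in this section.

```python
import torch

torch.manual_seed(0)
b, num_tokens, d_out = 2, 6, 2
# Hypothetical batched queries and keys with shape (b, num_tokens, d_out)
queries = torch.rand(b, num_tokens, d_out)
keys = torch.rand(b, num_tokens, d_out)

# Swap only the last two dimensions: (b, num_tokens, d_out) -> (b, d_out, num_tokens)
keys_t = keys.transpose(1, 2)
attn_scores = queries @ keys_t   # batched matmul -> (b, num_tokens, num_tokens)
print(attn_scores.shape)         # torch.Size([2, 6, 6])

# Each batch element matches the independent 2-D computation used earlier
print(torch.allclose(attn_scores[0], queries[0] @ keys[0].T))
```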

In [27]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()

        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Dropout layer
        self.dropout = nn.Dropout(p=dropout)
        # Causal mask
        self.register_buffer(
            'mask',
            torch.triu(
                torch.ones(context_length, context_length),
                diagonal=1
            )
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape # batch size, number of tokens, input embedding size

        # For inputs where `num_tokens` > `context_length`, this will result in errors
        # in the mask creation further below.
        # In practice, this is not a problem since the LLM ensures that
        # inputs do not exceed `context_length` before reaching this forward method.
        keys = self.W_key(x)       # shape: (b, num_tokens, d_out)
        queries = self.W_query(x)   # shape: (b, num_tokens, d_out)
        values = self.W_value(x)    # shape: (b, num_tokens, d_out)

        attn_scores = queries @ keys.transpose(1, 2) # shape: (b, num_tokens, num_tokens)
        # New operations for causal masking
        # `_` ops are in-place operations
        attn_scores.masked_fill_(
            # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_length
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf
        )
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5,
            dim=-1
        )
        attn_weights = self.dropout(attn_weights) # Apply dropout to attention weights

        context_vec = attn_weights @ values # shape: (b, num_tokens, d_out)

        return context_vec

In PyTorch, `register_buffer` stores tensors that are not parameters of the model but should still be part of the model's state. This is particularly useful for tensors that are used in computations but do not require gradients, such as masks or fixed constants. A registered buffer is included in the model's state dictionary, so it is saved and loaded along with the model's parameters.

For example, when we use the `CausalAttention` class in an LLM, buffers are automatically moved to the appropriate device (CPU or GPU) along with our model, which will be relevant when training our LLM, so that we don't need to manually ensure these tensors are on the same device as our model parameters, avoiding device mismatch errors.
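The difference between a buffer and a plain tensor attribute can be checked with a small, hypothetical module, `MaskHolder`, that stores the same mask both ways. To keep the example runnable without a GPU, it converts the module's dtype with `.to(torch.float64)`; moving to a device with `.to("cuda")` goes through the same mechanism and affects parameters and buffers in the same way.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):
    """Hypothetical minimal module contrasting a registered buffer with a plain attribute."""
    def __init__(self, context_length):
        super().__init__()
        # Registered buffer: tracked by the module, but not a trainable parameter
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )
        # Plain tensor attribute: NOT tracked by the module
        self.plain_mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

m = MaskHolder(context_length=4)
print(list(m.state_dict().keys()))            # ['mask']
print([n for n, _ in m.named_parameters()])   # []: buffers are not trainable parameters
print([n for n, _ in m.named_buffers()])      # ['mask']

# .to() converts buffers along with parameters, but ignores plain attributes
m = m.to(torch.float64)
print(m.mask.dtype, m.plain_mask.dtype)       # torch.float64 torch.float32
```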

In [29]:
torch.manual_seed(0)

context_length = batch.shape[1] # number of tokens in the input sequence
ca = CausalAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.0
)
context_vecs = ca(batch)
print("Context vectors from CausalAttention:\n", context_vecs)
print("Context vectors shape:", context_vecs.shape)

Context vectors from CausalAttention:
 tensor([[[-0.5063,  0.3518],
         [-0.6503,  0.3955],
         [-0.6976,  0.4064],
         [-0.6289,  0.3677],
         [-0.6131,  0.3179],
         [-0.5870,  0.3259]],

        [[-0.5063,  0.3518],
         [-0.6503,  0.3955],
         [-0.6976,  0.4064],
         [-0.6289,  0.3677],
         [-0.6131,  0.3179],
         [-0.5870,  0.3259]]], grad_fn=<UnsafeViewBackward0>)
Context vectors shape: torch.Size([2, 6, 2])


## 3.6 Extending Single-Head Attention to Multi-Head Attention

### 3.6.1 Stacking Multiple Single-Head Attention Layers

The main idea behind multi-head attention is to run the attention mechanism multiple times (in parallel) with different, learned linear projections of the queries, keys, and values. Each of these parallel attention mechanisms is referred to as a "head."

We can achieve this by implementing a simple `MultiHeadAttentionWrapper` class that stacks multiple `CausalAttention` layers:

In [30]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                CausalAttention(
                    d_in=d_in,
                    d_out=d_out,
                    context_length=context_length,
                    dropout=dropout,
                    qkv_bias=qkv_bias
                )
                for _ in range(num_heads)
            ]
        )

    def forward(self, x):
        # Concatenate the outputs from all heads
        head_outputs = torch.cat(
            [head(x) for head in self.heads],
            dim=-1
        )
        return head_outputs

In [31]:
torch.manual_seed(0)

context_length = batch.shape[1] # number of tokens in the input sequence
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.0,
    num_heads=2
)

context_vecs = mha(batch)
print("Context vectors from MultiHeadAttentionWrapper:\n", context_vecs)
print("Context vectors shape:", context_vecs.shape)

Context vectors from MultiHeadAttentionWrapper:
 tensor([[[-0.5063,  0.3518, -0.3550, -0.6560],
         [-0.6503,  0.3955, -0.1536, -0.7514],
         [-0.6976,  0.4064, -0.0853, -0.7803],
         [-0.6289,  0.3677, -0.0297, -0.7015],
         [-0.6131,  0.3179, -0.0417, -0.6247],
         [-0.5870,  0.3259, -0.0040, -0.6322]],

        [[-0.5063,  0.3518, -0.3550, -0.6560],
         [-0.6503,  0.3955, -0.1536, -0.7514],
         [-0.6976,  0.4064, -0.0853, -0.7803],
         [-0.6289,  0.3677, -0.0297, -0.7015],
         [-0.6131,  0.3179, -0.0417, -0.6247],
         [-0.5870,  0.3259, -0.0040, -0.6322]]], grad_fn=<CatBackward0>)
Context vectors shape: torch.Size([2, 6, 4])


Here we set `num_heads=2` and `d_out=2`, meaning that each head will produce an output embedding of size 2. Since we have 2 heads, the combined output embedding size will be `2 heads * 2 dimensions/head = 4 dimensions`. This is why the final output shape is `(2, 6, 4)`, reflecting the batch size of 2, sequence length of 6, and combined output embedding size of 4.
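To see where the 4 comes from, the sketch below replaces the two heads with random placeholder tensors of the shape a single `CausalAttention` head returns here, `(2, 6, 2)`, and concatenates them along the last dimension as `MultiHeadAttentionWrapper.forward` does. The values are arbitrary; only the shapes and the column layout matter.

```python
import torch

batch_size, num_tokens, d_out, num_heads = 2, 6, 2, 2

# Placeholder outputs standing in for each head's context vectors
head_outputs = [torch.rand(batch_size, num_tokens, d_out) for _ in range(num_heads)]

combined = torch.cat(head_outputs, dim=-1)
print(combined.shape)  # torch.Size([2, 6, 4])

# The first d_out columns come from head 0, the next d_out from head 1
assert torch.equal(combined[..., :d_out], head_outputs[0])
assert torch.equal(combined[..., d_out:], head_outputs[1])
```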

### 3.6.2 Implementing Multi-Head Attention with Weight Splits

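Instead of maintaining a list of separate `CausalAttention` modules, we will now implement a single `MultiHeadAttention` class that computes all heads at once, which makes the multi-head attention mechanism more efficient.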

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()

        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Each head gets head_dim = d_out // num_heads dimensions, so the concatenated heads have size d_out
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
        
        self.dropout = nn.Dropout(p=dropout)
        self.register_buffer(
            'mask',
            torch.triu(
                torch.ones(context_length, context_length),
                diagonal=1
            )
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape # batch size, number of tokens, input embedding size
        # As in `CausalAttention`, inputs where `num_tokens` exceeds `context_length`
        # cause a shape mismatch when the mask is applied further below.
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x)      # shape: (b, num_tokens, d_out)
        queries = self.W_query(x)  # shape: (b, num_tokens, d_out)
        values = self.W_value(x)   # shape: (b, num_tokens, d_out)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dimension:
        # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose:
        # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(
            mask_bool,
            -torch.inf
        )

        attn_weights = torch.softmax(
            attn_scores / self.head_dim**0.5,
            dim=-1
        )
        attn_weights = self.dropout(attn_weights) # Apply dropout to attention weights

        # Compute context vectors for each head, then transpose back:
        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        # Optional output projection applied to the combined heads
        context_vec = self.out_proj(context_vec)

        return context_vec

This implementation involves many reshaping (`.view`) and transposing (`.transpose`) operations. The idea is to combine the weight matrices of all heads into one large weight matrix each for the queries, keys, and values, and then split the projected results into multiple heads after the linear transformations. This replaces several small matrix multiplications (one per head) with a single larger one per projection, which can be more efficient in terms of computation and memory usage. The `MultiHeadAttention` class implements the same concept as the `MultiHeadAttentionWrapper`, but more efficiently.

Compared to the `MultiHeadAttentionWrapper`, where we stacked multiple `CausalAttention` layers, the `MultiHeadAttention` class starts with a multi-head layer and then internally splits this layer into individual attention heads for processing.

The splitting of the query, key, and value tensors is done using tensor reshaping and transposing operations (`.view` and `.transpose`).

The key operation is to split the `d_out` dimension into `num_heads` and `head_dim`, where `head_dim = d_out // num_heads`.

The tensors are then transposed to bring the `num_heads` dimension before the `num_tokens` dimension, resulting in a shape of `(batch_size, num_heads, num_tokens, head_dim)`. This arrangement lets each head process its portion of the data independently: the queries, keys, and values are aligned per head, so a single batched matrix multiplication computes the attention scores for all heads at once.
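The split can be checked on a small tensor. The sketch below uses a hypothetical projected tensor with `d_out = 4` and `num_heads = 2` (so `head_dim = 2`), filled with consecutive numbers so that each column is easy to track, and confirms that after `.view` and `.transpose(1, 2)`, head `h` holds exactly the columns `h * head_dim` up to (but excluding) `(h + 1) * head_dim` of the original `d_out` dimension.

```python
import torch

b, num_tokens, d_out, num_heads = 1, 3, 4, 2
head_dim = d_out // num_heads  # 2

# Stand-in for the output of a projection such as self.W_query(x)
queries = torch.arange(b * num_tokens * d_out, dtype=torch.float32)
queries = queries.view(b, num_tokens, d_out)

# (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
split = queries.view(b, num_tokens, num_heads, head_dim)
# (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
per_head = split.transpose(1, 2)
print(per_head.shape)  # torch.Size([1, 2, 3, 2])

# Head h sees a contiguous slice of the d_out columns
for h in range(num_heads):
    expected = queries[:, :, h * head_dim:(h + 1) * head_dim]
    assert torch.equal(per_head[:, h], expected)
```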

To demonstrate the batched matrix multiplication, suppose we have the following tensor:

In [33]:
# (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],

                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])
print("Tensor a shape:", a.shape)

Tensor a shape: torch.Size([1, 2, 3, 4])


Now we perform a batched matrix multiplication between the tensor `a` itself and a view of the tensor where we transposed the last two dimensions, `num_tokens` and `head_dim`:

In [34]:
print("Batched matrix multiplication of a and a transposed:")
print(a @ a.transpose(2, 3))

Batched matrix multiplication of a and a transposed:
tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])


In this case, PyTorch's batched matrix multiplication treats the leading dimensions of the 4-dimensional tensor as batch dimensions: it multiplies the matrices formed by the last two dimensions, `(num_tokens, head_dim)` times `(head_dim, num_tokens)`, and repeats this for each individual head.

In other words, the batched multiplication above is a more compact way of computing the matrix multiplication for each head separately:

In [35]:
first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("First head matrix multiplication result:\n", first_res)

second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("\nSecond head matrix multiplication result:\n", second_res)

First head matrix multiplication result:
 tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])

Second head matrix multiplication result:
 tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])


In [36]:
# This is exactly equivalent to the batched matrix multiplication above.
print(a @ a.transpose(2, 3))

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])
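Rather than comparing printed values by eye, the equivalence can be checked programmatically. The sketch below uses a random tensor with the same shape as `a` (placeholder values), computes the per-head products in a Python loop, stacks them, and compares the result with the batched product using `torch.allclose`. It also shows `torch.einsum` as an alternative notation for the same per-head product; this is an optional alternative, not what the chapter's code uses.

```python
import torch

torch.manual_seed(123)
a = torch.rand(1, 2, 3, 4)  # (b, num_heads, num_tokens, head_dim)

# Batched: one call multiplies the last two dims for every (batch, head) pair
batched = a @ a.transpose(2, 3)  # (1, 2, 3, 3)

# Loop version: one 2D matrix multiplication per head
looped = torch.stack(
    [a[0, h] @ a[0, h].T for h in range(a.shape[1])]
).unsqueeze(0)  # restore the batch dimension -> (1, 2, 3, 3)

# einsum: contract over head_dim (d), keep batch (b), heads (h), tokens (i, j)
einsum_res = torch.einsum("bhid,bhjd->bhij", a, a)

print(torch.allclose(batched, looped, atol=1e-6))
print(torch.allclose(batched, einsum_res, atol=1e-6))
```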


Continuing with `MultiHeadAttention`, after computing the attention weights and applying them to the value vectors for each head, we need to combine the outputs from all heads back into a single tensor with shape `(batch_size, num_tokens, d_out)`. 
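Combining the heads reverses the split: transpose `num_heads` and `num_tokens` back, then merge `(num_heads, head_dim)` into `d_out`. The sketch below uses placeholder sizes (`b=2`, `num_tokens=6`, `num_heads=2`, `head_dim=3`) and mimics the freshly computed `attn_weights @ values` result with `.contiguous()`. It shows why `MultiHeadAttention.forward` calls `.contiguous()` before `.view`: after `.transpose`, the tensor's memory layout no longer allows the last two dimensions to be merged, and a plain `.view` raises a `RuntimeError`. (`.reshape` is an alternative that copies only when needed.)

```python
import torch

b, num_tokens, num_heads, head_dim = 2, 6, 2, 3
d_out = num_heads * head_dim

x = torch.rand(b, num_tokens, d_out)

# Split as in MultiHeadAttention.forward; .contiguous() mimics the fresh
# tensor produced by (attn_weights @ values), shape (b, num_heads, num_tokens, head_dim)
per_head = x.view(b, num_tokens, num_heads, head_dim).transpose(1, 2).contiguous()

# Combine step 1: (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
merged = per_head.transpose(1, 2)
print("Contiguous after transpose:", merged.is_contiguous())  # False

# A plain .view cannot merge (num_heads, head_dim) on this non-contiguous layout
view_failed = False
try:
    merged.view(b, num_tokens, d_out)
except RuntimeError:
    view_failed = True
print("Plain .view failed:", view_failed)  # True

# Combine step 2: copy into contiguous memory, then merge into d_out
combined = merged.contiguous().view(b, num_tokens, d_out)
print(combined.shape)            # torch.Size([2, 6, 6])
print(torch.equal(combined, x))  # True
```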

Additionally, we added an output projection layer, `self.out_proj`, to the `MultiHeadAttention` class, which is applied after combining the heads. This output projection is not strictly necessary, but it is commonly used in LLM architectures. Because it is a learned linear layer applied to the concatenated head outputs, each of its output dimensions can combine information from all heads.

Testing the `MultiHeadAttention` class:

In [37]:
torch.manual_seed(0)

batch_size, context_length, d_in = batch.shape # batch size, number of tokens, input embedding size
d_out = 2

mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.0,
    num_heads=2
)
context_vecs = mha(batch)
print("Context vectors from MultiHeadAttention:\n", context_vecs)
print("Context vectors shape:", context_vecs.shape)

Context vectors from MultiHeadAttention:
 tensor([[[-0.0111,  0.6056],
         [ 0.0435,  0.5950],
         [ 0.0630,  0.5891],
         [ 0.0426,  0.5836],
         [ 0.0496,  0.5591],
         [ 0.0353,  0.5700]],

        [[-0.0111,  0.6056],
         [ 0.0435,  0.5950],
         [ 0.0630,  0.5891],
         [ 0.0426,  0.5836],
         [ 0.0496,  0.5591],
         [ 0.0353,  0.5700]]], grad_fn=<ViewBackward0>)
Context vectors shape: torch.Size([2, 6, 2])


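The output dimension is now directly controlled by the `d_out` parameter specified when creating the `MultiHeadAttention` instance. Unlike `MultiHeadAttentionWrapper`, whose output size was `num_heads * d_out`, `MultiHeadAttention` splits `d_out` across the heads (here, `head_dim = 2 // 2 = 1`), so the output shape is `(2, 6, 2)`.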
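As an optional cross-check, not part of the chapter's code: PyTorch 2.x provides `torch.nn.functional.scaled_dot_product_attention`, which with `is_causal=True` applies the same kind of causal mask and the same `1 / sqrt(head_dim)` scaling that `MultiHeadAttention.forward` implements by hand. The sketch below uses random placeholder queries, keys, and values in the `(b, num_heads, num_tokens, head_dim)` layout and compares the manual computation with the built-in function. Dropout is disabled (`dropout_p=0.0`) so the comparison is deterministic; the projections and `out_proj` are left out.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, num_heads, num_tokens, head_dim = 2, 2, 6, 3
queries = torch.randn(b, num_heads, num_tokens, head_dim)
keys = torch.randn(b, num_heads, num_tokens, head_dim)
values = torch.randn(b, num_heads, num_tokens, head_dim)

# Manual version, mirroring MultiHeadAttention.forward (without dropout)
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
attn_scores = queries @ keys.transpose(2, 3)
attn_scores = attn_scores.masked_fill(mask, -torch.inf)
attn_weights = torch.softmax(attn_scores / head_dim**0.5, dim=-1)
manual = attn_weights @ values

# Built-in implementation (PyTorch >= 2.0); default scale is 1 / sqrt(head_dim)
builtin = F.scaled_dot_product_attention(
    queries, keys, values, dropout_p=0.0, is_causal=True
)

print(manual.shape)  # torch.Size([2, 2, 6, 3])
print(torch.allclose(manual, builtin, atol=1e-5))
```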