<a href="https://colab.research.google.com/github/linhoangce/llm_from_scratch/blob/main/chapter3_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 3.3 Attending to different parts of the input with self-attention

**THE "SELF" IN SELF-ATTENTION**

In self-attention, the "self" refers to the mechanism's ability to compute attention weights by relating different positions within a single input sequence. It assesses and learns the relationships and dependencies between various parts of the input itself, such as words in a sentence or pixels in an image.

This is in contrast to traditional attention mechanisms, where the focus is on the relationships between elements of two different sequences, such as in sequence-to-sequence models where the attention might be between an input sequence and an output sequence.
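
The difference shows up directly in the shape of the score matrix. The following is a minimal sketch, assuming two hypothetical toy sequences (`source` and `target` are illustrative names, not part of this chapter) and plain dot-product scores without trainable weights:

```python
import torch

torch.manual_seed(0)
# Hypothetical toy sequences: 6 source tokens and 4 target tokens, 3-dim embeddings
source = torch.rand(6, 3)
target = torch.rand(4, 3)

# Self-attention: every source token is compared with every source token
self_scores = source @ source.T     # shape (6, 6)

# Attention across two sequences: every target token is compared with every source token
cross_scores = target @ source.T    # shape (4, 6)

print(self_scores.shape, cross_scores.shape)
```

The self-attention score matrix is square because both sides of the comparison come from the same sequence.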

### 3.3.1 A simple self-attention mechanism without trainable weights

The goal of self-attention is to compute a context vector for each input element that combines information from all other input elements. A *context vector* can be interpreted as an enriched embedding vector/representation of each element in an input sequence (like a sentence) by incorporating information from all other elements in the sequence.



In [81]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
     [0.55, 0.87, 0.66], # journey  (x^2)
     [0.57, 0.85, 0.64], # starts   (x^3)
     [0.22, 0.58, 0.33], # with     (x^4)
     [0.77, 0.25, 0.10], # one      (x^5)
     [0.05, 0.80, 0.55]  # step     (x^6)
     ]
)


# calculate intermediate attention scores between
# query token and each input token with dot product
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
  attn_scores_2[i] = torch.dot(x_i, query)

attn_scores_2

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

**Dot Product**: Beyond being a mathematical operation that combines two vectors into a scalar value, the dot product is a measure of similarity because it quantifies how closely two vectors are aligned: a higher dot product indicates a greater degree of alignment, or similarity, between the vectors. In the context of self-attention mechanisms, the dot product determines the extent to which each element in a sequence focuses on, or "attends to", any other element: the higher the dot product, the higher the similarity and attention score between the two elements.

**Mental model to remember**

* Vectors = meanings

* Dot product = alignment of meanings

* High alignment = strong attention

* Self-attention = each token asking “who is most relevant to me?”
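
As a small check of this mental model, the sketch below reuses three of the embeddings from `inputs` and computes the dot product by hand. Note that the dot product also grows with vector length, so it measures alignment only loosely when the vectors have different magnitudes.

```python
import torch

a = torch.tensor([0.55, 0.87, 0.66])  # "journey" (x^2), the query used above
b = torch.tensor([0.57, 0.85, 0.64])  # "starts"  (x^3), nearly the same direction
c = torch.tensor([0.77, 0.25, 0.10])  # "one"     (x^5), a different direction

# dot product = element-wise multiplication followed by a sum
manual = (a * b).sum()
assert torch.allclose(manual, torch.dot(a, b))

score_ab = torch.dot(a, b)
score_ac = torch.dot(a, c)
print(score_ab)  # larger score: closely aligned vectors
print(score_ac)  # smaller score: less aligned vectors
```

These two values correspond to the third and fifth entries of `attn_scores_2` printed earlier.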

In [82]:
# normalize attention scores for interpretation and stability
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print(f'Attention weights: {attn_weights_2_tmp}')
print(f'Sum: {attn_weights_2_tmp.sum()}')

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: 1.0000001192092896


In [83]:
# normalize using softmax
def softmax_naive(x):
  return torch.exp(x) / torch.exp(x).sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)
print(f'Attention weights: {attn_weights_2_naive}')
print(f'Sum: {attn_weights_2_naive.sum()}')

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: 1.0
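
The naive implementation can overflow for large inputs because `torch.exp` quickly exceeds the float32 range. The sketch below shows one common remedy, subtracting the maximum before exponentiating. `softmax_stable` is a hypothetical helper written for illustration; in practice `torch.softmax` is the function to use.

```python
import torch

def softmax_naive(x):
  return torch.exp(x) / torch.exp(x).sum(dim=0)

def softmax_stable(x):
  # hypothetical helper: shifting by the max does not change the result
  # mathematically, but keeps exp() from overflowing
  exp = torch.exp(x - x.max())
  return exp / exp.sum(dim=0)

large = torch.tensor([1000.0, 1001.0, 1002.0])
print(softmax_naive(large))         # exp(1000.) overflows to inf, and inf / inf is nan
print(softmax_stable(large))        # finite values that sum to 1
print(torch.softmax(large, dim=0))  # built-in version for comparison
```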


In [84]:
# calculate context vector by multiplying embedded input tokens
# with corresponding attention weights, then sum results
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
  context_vec_2 += attn_weights_2_tmp[i] * x_i

context_vec_2

tensor([0.4355, 0.6451, 0.5680])

In [85]:
context_vec_2 = torch.matmul(attn_weights_2_tmp, inputs)
context_vec_2

tensor([0.4355, 0.6451, 0.5680])

### 3.3.2 Computing attention weights for all input tokens

In [86]:
# pairwise dot products between all inputs
attn_scores = torch.matmul(inputs, inputs.T) # same as using @
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [87]:
inputs @ inputs.T

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [88]:
# setting dim=-1 applies the normalization along the last dimension,
# e.g. for (rows, columns), each row is normalized across its columns
attn_weights = torch.softmax(attn_scores, dim=-1)
attn_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [89]:
# compute context vectors for all inputs
all_context_vecs = attn_weights @ inputs
all_context_vecs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

## 3.4 Implementing self-attention with trainable weights

### 3.4.1 Computing the attention weights step by step

In [90]:
x_2 = inputs[1]
d_in = inputs.shape[1] # input embedding size
d_out = 2 # output embedding size

In [91]:
from torch.nn import Parameter

torch.manual_seed(123)

# initialize weight matrices
W_query = Parameter(torch.rand(d_in, d_out),
                     requires_grad=False)
W_key = Parameter(torch.rand(d_in, d_out),
                   requires_grad=False)
W_value = Parameter(torch.rand(d_in, d_out),
                     requires_grad=False)

In [92]:
# compute query, key, value vectors
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

query_2

tensor([0.4306, 1.4551])

In [93]:
# get all keys and values
keys = inputs @ W_key
values = inputs @ W_value
print(f'keys.shape: {keys.shape}')
print(f'values.shape: {values.shape}')

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


**Attention score**: the dot product between a *query* vector and a *key* vector, each obtained by transforming the inputs with the respective weight matrix.

In [94]:
# compute attention score omega_22
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
attn_score_22

tensor(1.8524)

In [95]:
# attention scores for query 2 with all inputs
attn_scores_2 = query_2 @ keys.T
attn_scores_2

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

In [96]:
# scale attention scores and apply softmax
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5,
                               dim=-1)
attn_weights_2

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])

**THE RATIONALE BEHIND SCALED DOT-PRODUCT ATTENTION**

The reason for scaling by the square root of the embedding dimension of the keys (`d_k` in the code above) is to improve training performance by avoiding small gradients. For instance, when scaling up the embedding dimension, which is typically greater than 1,000 for GPT-like LLMs, large dot products can result in very small gradients during backpropagation due to the softmax function applied to them. As dot products increase, the softmax function behaves more like a step function, resulting in gradients nearing zero. These small gradients can drastically slow down learning or cause training to stagnate.

The scaling by the square root of the embedding dimension is the reason why this self-attention mechanism is also called scaled dot-product attention.
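
The effect can be sketched with a handful of made-up scores. Multiplying them by 8 stands in for the larger dot products that higher-dimensional embeddings tend to produce, and `d_k = 64` is a hypothetical key dimension chosen so that its square root is 8. Both are assumptions for illustration only.

```python
import torch

scores = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])

soft = torch.softmax(scores, dim=-1)
sharp = torch.softmax(scores * 8, dim=-1)  # stand-in for large dot products

print(soft)   # relatively even distribution
print(sharp)  # most of the probability mass sits on the largest score

d_k = 64  # hypothetical key dimension, sqrt(d_k) = 8
rescaled = torch.softmax(scores * 8 / d_k**0.5, dim=-1)
print(rescaled)  # dividing by sqrt(d_k) brings back the softer distribution
```

The sharper the distribution, the closer softmax is to a step function, which is the regime where gradients shrink.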

In [97]:
# compute context vectors
context_vec_2 = attn_weights_2 @ values
context_vec_2

tensor([0.3061, 0.8210])

### 3.4.2 Implementing a compact self-attention Python class

In [98]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
  def __init__(self, d_in, d_out):
    super().__init__()
    self.W_query = nn.Parameter(torch.rand(d_in, d_out))
    self.W_key = nn.Parameter(torch.rand(d_in, d_out))
    self.W_value = nn.Parameter(torch.rand(d_in, d_out))

  def forward(self, x):
    keys = x @ self.W_key
    queries = x @ self.W_query
    values = x @ self.W_value
    attn_scores = queries @ keys.T # omega
    attn_weights = torch.softmax(
        attn_scores / keys.shape[-1]**0.5, dim=-1
    )
    context_vec = attn_weights @ values
    return context_vec


In [99]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
sa_v1(inputs)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

In [100]:
class SelfAttention_v2(nn.Module):
  def __init__(self, d_in, d_out, qkv_bias=False):
    super().__init__()
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

  def forward(self, x):
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)
    attn_scores = queries @ keys.T
    attn_weights = torch.softmax(
        attn_scores / keys.shape[-1]**0.5, dim=-1
    )
    context_vec = attn_weights @ values
    return context_vec


In [101]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
sa_v2(inputs) # calling the module invokes forward()

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

In [102]:
sa_v2.W_query, sa_v1.W_query

(Linear(in_features=3, out_features=2, bias=False),
 Parameter containing:
 tensor([[0.2961, 0.5166],
         [0.2517, 0.6886],
         [0.0740, 0.8665]], requires_grad=True))

In [103]:
# Assign the (transposed) nn.Linear weights to the nn.Parameter objects
sa_v1.W_query.data = sa_v2.W_query.weight.data.T
sa_v1.W_key.data = sa_v2.W_key.weight.data.T
sa_v1.W_value.data = sa_v2.W_value.weight.data.T

sa_v1(inputs), sa_v2(inputs)

(tensor([[-0.0739,  0.0713],
         [-0.0748,  0.0703],
         [-0.0749,  0.0702],
         [-0.0760,  0.0685],
         [-0.0763,  0.0679],
         [-0.0754,  0.0693]], grad_fn=<MmBackward0>),
 tensor([[-0.0739,  0.0713],
         [-0.0748,  0.0703],
         [-0.0749,  0.0702],
         [-0.0760,  0.0685],
         [-0.0763,  0.0679],
         [-0.0754,  0.0693]], grad_fn=<MmBackward0>))
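
The transpose in the weight assignment is needed because `nn.Linear` stores its weight matrix with shape `(out_features, in_features)` and multiplies by its transpose. A short sketch with a throwaway layer (`linear` and `x` are illustrative names):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(3, 2, bias=False)
x = torch.rand(6, 3)

print(linear.weight.shape)  # (out_features, in_features) = (2, 3)

# nn.Linear computes x @ weight.T, so a (d_in, d_out) matrix used as
# `x @ W` must be the transpose of the nn.Linear weight
manual = x @ linear.weight.T
assert torch.allclose(manual, linear(x))
```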

## 3.5 Hiding future words with causal attention

**Causal attention**, also known as *masked attention*, restricts a model to only consider previous and current inputs in a sequence when computing attention scores for any given token. This is in contrast to the standard self-attention mechanism, which allows access to the entire input sequence at once.

### 3.5.1 Applying a causal attention mask

In [104]:
keys.shape[-1]

2

In [105]:
keys.shape[-1]**0.5

1.4142135623730951

In [106]:
# compute attention weights using softmax
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(
    attn_scores / keys.shape[-1]**0.5, dim=-1
)
attn_weights

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

In [107]:
attn_scores

tensor([[ 0.2899,  0.0716,  0.0760, -0.0138,  0.1344, -0.0511],
        [ 0.4656,  0.1723,  0.1751,  0.0259,  0.1771,  0.0085],
        [ 0.4594,  0.1703,  0.1731,  0.0259,  0.1745,  0.0090],
        [ 0.2642,  0.1024,  0.1036,  0.0186,  0.0973,  0.0122],
        [ 0.2183,  0.0874,  0.0882,  0.0177,  0.0786,  0.0144],
        [ 0.3408,  0.1270,  0.1290,  0.0198,  0.1290,  0.0078]],
       grad_fn=<MmBackward0>)

In [108]:
# create mask of value 0 above diagonal
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length,
                                    context_length))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [109]:
# zero out values above diagonal
masked_simple = attn_weights * mask_simple
masked_simple

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)

In [110]:
# renormalize attention weights to sum to 1 per row
row_sums = masked_simple.sum(dim=-1, keepdims=True)
masked_simple_norm = masked_simple / row_sums
masked_simple_norm

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)

**INFORMATION LEAKAGE**

When we apply a mask and then renormalize the attention weights, it might initially appear that information from future tokens (which we intend to mask) could still influence the current token because their values are part of the softmax calculation. However, the key insight is that when we renormalize the attention weights after masking, what we're essentially doing is recalculating the softmax over a smaller subset (since masked positions don't contribute to the softmax value).

The mathematical elegance of softmax is that despite initially including all positions in the denominator, after masking and renormalizing, the effect of the masked positions is nullified - they don't contribute to the softmax score in any meaningful way.

In simpler terms, after masking and renormalization, the distribution of attention weights is as if it was calculated only among the unmasked positions to begin with. This ensures there's no information leakage from future (or otherwise masked) tokens as we intended.
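
This equivalence can be checked numerically. The sketch below takes the second row of `attn_scores` (copied by hand, so rounded to four decimals) and compares the two routes. The variable names are illustrative.

```python
import torch

# second row of attn_scores from above, scaled by sqrt(d_k) with d_k = 2
scores = torch.tensor([0.4656, 0.1723, 0.1751, 0.0259, 0.1771, 0.0085]) / 2**0.5
visible = 2  # the second token may only attend to tokens 1 and 2

# Route 1: softmax over all positions, zero out the future, renormalize
full = torch.softmax(scores, dim=-1)
masked = full.clone()
masked[visible:] = 0.0
renormalized = masked / masked.sum()

# Route 2: softmax over the visible positions only
subset = torch.softmax(scores[:visible], dim=-1)

print(renormalized[:visible])
print(subset)
```

Both routes give the same weights, which also agree with the second row of `masked_simple_norm` up to rounding.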

The softmax converts its inputs into a probability distribution. When negative infinity values are present in a row, the softmax function treats them as zero probability (mathematically, this is because $e^{-\infty}$ approaches 0).

We can implement this more efficient masking "trick" by creating a mask with 1s above the diagonal and then replacing these 1s with negative infinity values:


In [111]:
mask = torch.triu(torch.ones(context_length, context_length),
                  diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)

In [112]:
# apply softmax to these results
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5,
                             dim=-1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

### 3.5.2 Masking additional attention weights with dropout

In the transformer architecture, including models like GPT, dropout in the attention mechanism is typically applied at two specific times: after calculating the attention weights or after applying the attention weights to the value vectors. Here we will apply the dropout mask after computing the attention weights because it's more common in practice.

In [113]:
# apply dropout rate of 50%
torch.manual_seed(123)

dropout = nn.Dropout(0.5)
example = torch.ones(6, 6)
dropout(example)

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])

When applying dropout to an attention weight matrix with a rate of 50%, each element is set to zero independently with probability 0.5, so roughly half of the elements in the matrix are zeroed. To compensate for the reduction in active elements, the values of the remaining elements in the matrix are scaled up by a factor of $1/0.5 = 2$. This scaling is crucial to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phases.
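
A quick sketch of this scaling on a larger all-ones matrix (the size is arbitrary, chosen only so the average is stable). It also shows that `nn.Dropout` does nothing in evaluation mode:

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
p = 0.5
dropout = nn.Dropout(p)
x = torch.ones(1000, 1000)

dropout.train()          # dropout is only active in training mode
dropped = dropout(x)
print(dropped.unique())  # entries are either 0 or 1 / (1 - p) = 2
print(dropped.mean())    # close to 1, so the expected value is preserved

dropout.eval()           # in evaluation mode dropout passes inputs through
unchanged = dropout(x)
assert torch.equal(unchanged, x)
```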

In [114]:
torch.manual_seed(123)
dropout(attn_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)

### 3.5.3 Implementing a compact causal attention class

In [115]:
# simulate batch inputs
batch = torch.stack((inputs, inputs), dim=0)
batch.shape

torch.Size([2, 6, 3])

In [116]:
class CausalAttention(nn.Module):
  """
  1. Adds a dropout layer
  2. register_buffer is not strictly necessary but offers an advantage:
     the buffer is automatically moved to the appropriate device along with the model
  3. Transposes dimensions 1 and 2, keeping the batch dimension at position 0
  4. In-place operation (trailing underscore) for memory efficiency
  """
  def __init__(self, d_in, d_out, context_length,
               dropout, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.dropout = nn.Dropout(dropout) # 1
    self.register_buffer( # 2
        'mask',
        torch.triu(torch.ones(context_length, context_length),
                   diagonal=1),
    )

  def forward(self, x):
    b, num_tokens, d_in = x.shape
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    attn_scores = queries @ keys.transpose(1, 2) # 3
    attn_scores.masked_fill_( # 4
        self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
    )
    attn_weights = torch.softmax(
        attn_scores / keys.shape[-1]**0.5, dim=-1
    )
    attn_weights = self.dropout(attn_weights)

    context_vec = attn_weights @ values
    return context_vec
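
The behavior of `register_buffer` can be inspected in isolation. The sketch below uses a hypothetical minimal module, `MaskHolder`, that registers only the mask. Because a GPU may not be available, it uses a dtype conversion as a stand-in for moving the module to another device, since both are applied to buffers in the same way.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):  # hypothetical module, for illustration only
  def __init__(self, context_length):
    super().__init__()
    self.register_buffer(
        'mask',
        torch.triu(torch.ones(context_length, context_length), diagonal=1),
    )

holder = MaskHolder(4)
print(list(holder.parameters()))         # empty: the mask is not trainable
print(list(holder.state_dict().keys()))  # contains 'mask': it is saved with the module

# module-level conversions such as .to(device) or .double() also apply to buffers
holder = holder.double()
print(holder.mask.dtype)
```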

In [117]:
torch.manual_seed(123)

context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
context_vecs.shape

torch.Size([2, 6, 2])

## 3.6 Extending single-head attention to multi-head attention

The term "multi-head" refers to dividing the attention mechanism into multiple "heads", each operating independently. In this context, a single causal attention module can be considered single-head attention, where there is only one set of attention weights processing the input sequentially.

### 3.6.1 Stacking multiple single-head attention layers

The main idea behind multi-head attention is to run the attention mechanism multiple times (in parallel) with different, learned linear projections, that is, the results of multiplying the input data (like the query, key, and value vectors in attention mechanisms) by a weight matrix.

In [118]:
class MultiHeadAttentionWrapper(nn.Module):
  def __init__(self, d_in, d_out, context_length,
               dropout, num_heads, qkv_bias=False):
    super().__init__()
    self.heads = nn.ModuleList(
        [CausalAttention(
            d_in, d_out, context_length, dropout, qkv_bias
        )
        for _ in range(num_heads)]
    )

  def forward(self, x):
    return torch.cat([head(x) for head in self.heads], dim=-1)


In [119]:
torch.manual_seed(123)

context_length = batch.shape[1] # number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)

print(context_vecs)
print(f'context_vecs.shape: {context_vecs.shape}')

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


In [120]:
mha = MultiHeadAttentionWrapper(
    d_in, 1, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)
context_vecs, context_vecs.shape

(tensor([[[0.0189, 0.2729],
          [0.2181, 0.3037],
          [0.2804, 0.3125],
          [0.2830, 0.2793],
          [0.2476, 0.2541],
          [0.2748, 0.2513]],
 
         [[0.0189, 0.2729],
          [0.2181, 0.3037],
          [0.2804, 0.3125],
          [0.2830, 0.2793],
          [0.2476, 0.2541],
          [0.2748, 0.2513]]], grad_fn=<CatBackward0>),
 torch.Size([2, 6, 2]))

### 3.6.2 Implementing multi-head attention with weight splits

In [133]:
class MultiHeadAttention(nn.Module):
  """
  Integrates the multi-head functionality within a single
  class. It splits the input into multiple heads by reshaping
  the projected query, key, and value tensors and then combines
  the results from these heads after computing attention.

  1. Reduces the projection dim to match the desired output dim
  2. Uses a Linear layer to combine head outputs
  3. Tensor shape: (b, num_tokens, d_out)
  4. Implicitly splits the matrix by adding a num_heads
  dimension. Then unrolls the last dim: (b, num_tokens, d_out)
  -> (b, num_tokens, num_heads, head_dim)
  5. Transposes from shape (b, num_tokens, num_heads, head_dim)
  to (b, num_heads, num_tokens, head_dim)
  6. Computes dot product for each head
  7. Masks truncated to the number of tokens
  8. Uses the mask to fill attention scores
  9. Tensor shape: (b, num_tokens, num_heads, head_dim)
  10. Combines heads, where self.d_out = self.num_heads * self.head_dim
  11. Adds an optional linear projection
  """
  def __init__(self, d_in, d_out, context_length,
               dropout, num_heads, qkv_bias=False):
    super().__init__()
    assert (d_out % num_heads == 0), \
        "d_out must be divisible by num_heads"

    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads # 1
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.out_proj = nn.Linear(d_out, d_out) # 2
    self.dropout = nn.Dropout(dropout)
    self.register_buffer(
        'mask',
        torch.triu(torch.ones(context_length, context_length),
                   diagonal=1)
    )

  def forward(self, x):
    b, num_tokens, d_in = x.shape
    keys = self.W_key(x) # 3
    queries = self.W_query(x) # 3
    values = self.W_value(x) # 3

    keys = keys.view(b, num_tokens, self.num_heads,
                      self.head_dim) # 4
    values = values.view(b, num_tokens,
                         self.num_heads, self.head_dim)
    queries = queries.view(b, num_tokens,
                           self.num_heads, self.head_dim)

    keys = keys.transpose(1, 2) # 5
    queries = queries.transpose(1, 2)
    values = values.transpose(1, 2)

    attn_scores = queries @ keys.transpose(2, 3) # 6
    mask_bool = self.mask.bool()[:num_tokens, :num_tokens] # 7

    attn_scores.masked_fill_(mask_bool, -torch.inf) # 8

    attn_weights = torch.softmax(
        attn_scores / keys.shape[-1]**0.5, dim=-1
    )
    attn_weights = self.dropout(attn_weights)

    context_vec = (attn_weights @ values).transpose(1, 2) # 9
    context_vec = context_vec.contiguous().view(
        b, num_tokens, self.d_out
    ) # 10
    context_vec = self.out_proj(context_vec) # 11
    return context_vec
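
To make steps 4, 5, 9, and 10 of the docstring concrete, here is a small sketch on a toy tensor whose values (0 to 11) make the memory layout easy to follow. The sizes are arbitrary and the tensor stands in for a projected query, key, or value tensor.

```python
import torch

b, num_tokens, num_heads, head_dim = 1, 3, 2, 2
d_out = num_heads * head_dim

# toy "projected" tensor of shape (b, num_tokens, d_out)
x = torch.arange(b * num_tokens * d_out, dtype=torch.float32)
x = x.view(b, num_tokens, d_out)

# step 4: split the last dimension into (num_heads, head_dim)
split = x.view(b, num_tokens, num_heads, head_dim)
# step 5: move the head dimension in front of the token dimension
per_head = split.transpose(1, 2)  # (b, num_heads, num_tokens, head_dim)
print(per_head[0, 0])  # head 0 holds columns 0-1 of every token
print(per_head[0, 1])  # head 1 holds columns 2-3 of every token

# steps 9-10: undo the transpose and merge the heads again
# (.contiguous() is kept here only to mirror the class above)
merged = per_head.transpose(1, 2).contiguous().view(b, num_tokens, d_out)
assert torch.equal(merged, x)
```

Each head therefore works on its own slice of the projected features, and merging the heads restores the original `(b, num_tokens, d_out)` layout.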

In [134]:
# Illustrate batched matrix multiplication
# shape (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],

                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])

a.transpose(2, 3), a.transpose(2, 3).shape

(tensor([[[[0.2745, 0.8993, 0.7179],
           [0.6584, 0.0390, 0.7058],
           [0.2775, 0.9268, 0.9156],
           [0.8573, 0.7388, 0.4340]],
 
          [[0.0772, 0.4066, 0.4606],
           [0.3565, 0.2318, 0.5159],
           [0.1479, 0.4545, 0.4220],
           [0.5331, 0.9737, 0.5786]]]]),
 torch.Size([1, 2, 4, 3]))

In [135]:
a @ a.transpose(2, 3)

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])

In [136]:
first_head = a[0, 0, :, :]
second_head = a[0, 1, :, :]
first_head, second_head

(tensor([[0.2745, 0.6584, 0.2775, 0.8573],
         [0.8993, 0.0390, 0.9268, 0.7388],
         [0.7179, 0.7058, 0.9156, 0.4340]]),
 tensor([[0.0772, 0.3565, 0.1479, 0.5331],
         [0.4066, 0.2318, 0.4545, 0.9737],
         [0.4606, 0.5159, 0.4220, 0.5786]]))

In [137]:
first_res = first_head @ first_head.T
second_res = second_head @ second_head.T
first_res, second_res


(tensor([[1.3208, 1.1631, 1.2879],
         [1.1631, 2.2150, 1.8424],
         [1.2879, 1.8424, 2.0402]]),
 tensor([[0.4391, 0.7003, 0.5903],
         [0.7003, 1.3737, 1.0620],
         [0.5903, 1.0620, 0.9912]]))

In [138]:
torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0,
                         num_heads=2)
context_vecs = mha(batch)
context_vecs, context_vecs.shape

(tensor([[[0.3190, 0.4858],
          [0.2943, 0.3897],
          [0.2856, 0.3593],
          [0.2693, 0.3873],
          [0.2639, 0.3928],
          [0.2575, 0.4028]],
 
         [[0.3190, 0.4858],
          [0.2943, 0.3897],
          [0.2856, 0.3593],
          [0.2693, 0.3873],
          [0.2639, 0.3928],
          [0.2575, 0.4028]]], grad_fn=<ViewBackward0>),
 torch.Size([2, 6, 2]))

In [139]:
gpt_mha = MultiHeadAttention(d_in=768, d_out=768,
                             context_length=1024,
                             dropout=0.0,
                             num_heads=12)
gpt_mha

MultiHeadAttention(
  (W_query): Linear(in_features=768, out_features=768, bias=False)
  (W_key): Linear(in_features=768, out_features=768, bias=False)
  (W_value): Linear(in_features=768, out_features=768, bias=False)
  (out_proj): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)

In [148]:
batch_gpt = torch.rand(size=(2, 768, 768))
gpt_mha(batch_gpt)

tensor([[[-0.0929, -0.0228, -0.0636,  ..., -0.0095, -0.1502, -0.0183],
         [-0.0523, -0.0455,  0.0185,  ...,  0.0080, -0.0402, -0.1689],
         [-0.0413, -0.0100,  0.0419,  ..., -0.0295, -0.0688, -0.1732],
         ...,
         [ 0.0696,  0.0380, -0.0127,  ...,  0.0112, -0.0339, -0.0346],
         [ 0.0695,  0.0381, -0.0128,  ...,  0.0114, -0.0340, -0.0345],
         [ 0.0695,  0.0380, -0.0126,  ...,  0.0112, -0.0336, -0.0344]],

        [[ 0.1281, -0.0059, -0.0568,  ...,  0.0641, -0.2194, -0.1864],
         [ 0.0844,  0.0588, -0.0159,  ...,  0.0928, -0.0769, -0.0961],
         [ 0.1025,  0.1080, -0.0090,  ...,  0.1047, -0.0564, -0.0802],
         ...,
         [ 0.0753,  0.0463, -0.0138,  ...,  0.0127, -0.0429, -0.0418],
         [ 0.0752,  0.0462, -0.0139,  ...,  0.0127, -0.0428, -0.0417],
         [ 0.0755,  0.0465, -0.0138,  ...,  0.0123, -0.0428, -0.0422]]],
       grad_fn=<ViewBackward0>)

## Summary

* Attention mechanisms transform input elements into enhanced context vector representations that incorporate information about all inputs.

* A self-attention mechanism computes the context vector representation as a weighted sum over the inputs.

* In a simplified attention mechanism, the attention weights are computed via dot products.

* A dot product is a concise way of multiplying two vectors element-wise and then summing the products.

* Matrix multiplications, while not strictly required, help us implement computations more efficiently and compactly by replacing nested *for* loops.

* In self-attention mechanisms used in LLMs, also called scaled dot-product attention, we include trainable weight matrices to compute intermediate transformations of the inputs: queries, values, and keys.

* When working with LLMs that read and generate text from left to right, we add a causal attention mask to prevent the LLM from accessing future tokens.

* In addition to causal attention masks that zero out attention weights, we can add a dropout mask to reduce overfitting in LLMs.

* The attention modules in transformer-based LLMs involve multiple instances of causal attention, which is called multi-head attention.

* We can create a multi-head attention module by stacking multiple instances of causal attention modules.

* A more efficient way of creating multi-head attention modules involves batched matrix multiplications.