# Coding the Attention Mechanisms

## Attending to different parts of the input with self-attention


### A simple self-attention mechanism without trainable weights

In [193]:
import torch
inputs = torch.tensor(
  [
    [0.43, 0.15, 0.89], # Your   (x^1)
    [0.55, 0.87, 0.66], # journey (x^2)
    [0.57, 0.85, 0.64], # starts  (x^3)
    [0.22, 0.58, 0.33], # with    (x^4)
    [0.77, 0.25, 0.10], # one     (x^5)
    [0.05, 0.80, 0.55]  # step    (x^6)
  ]
)
inputs.shape

torch.Size([6, 3])

In [196]:
input_query = inputs[1]
input_query

tensor([0.5500, 0.8700, 0.6600])

In [197]:
input_1 = inputs[0]
input_1

tensor([0.4300, 0.1500, 0.8900])

In [198]:
0.55 * 0.43 + 0.87 * 0.15 + 0.66 * 0.89

0.9544

In [199]:
torch.dot(input_query, input_1)

tensor(0.9544)

In [211]:
res = 0.
i = 3 # index of the input vector

for idx, element in enumerate(inputs[i]):
  res += element * input_query[idx]

res

tensor(0.8434)

In [210]:
# the same as above but using dot product

i = 3 # index of the input vector

res = torch.dot(inputs[i], input_query)
res


tensor(0.8434)

In [214]:
inputs.shape[0]

6

In [212]:
# now do them all at once in a loop

query = inputs[1]

attn_scores_2 = torch.empty(inputs.shape[0])

for i, x_i in enumerate(inputs):
  attn_scores_2[i] = torch.dot(x_i, query) # dot product of x_i and query

attn_scores_2


tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [None]:
# normalize the scores but with a simplified formula (no softmax)
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
attn_weights_2_tmp




tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])

In [216]:
attn_weights_2_tmp.sum()

tensor(1.0000)

In [None]:
# now do it with a softmax

def softmax_naive(x):
  return torch.exp(x) / torch.exp(x).sum()

softmax_naive(attn_scores_2)

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

In [None]:
# and now with a softmax from pytorch (it is recommended to use this over the naive implementation)

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print(attn_weights_2)

attn_weights_2.sum()


tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])


tensor(1.)

#### now calculate the context vector

In [223]:
query = inputs[1] # the second input vector is the query

context_vec_2 = torch.zeros(query.shape)

for i, x_i in enumerate(inputs):
  print(f"{attn_weights_2[i]} ---> {x_i}")
  context_vec_2 += attn_weights_2[i] * x_i

context_vec_2


0.13854756951332092 ---> tensor([0.4300, 0.1500, 0.8900])
0.2378913015127182 ---> tensor([0.5500, 0.8700, 0.6600])
0.23327402770519257 ---> tensor([0.5700, 0.8500, 0.6400])
0.12399158626794815 ---> tensor([0.2200, 0.5800, 0.3300])
0.10818186402320862 ---> tensor([0.7700, 0.2500, 0.1000])
0.15811361372470856 ---> tensor([0.0500, 0.8000, 0.5500])


tensor([0.4419, 0.6515, 0.5683])

## A simple self-attention mechanism without trailable weights

### now calculating attention weights for all inputs 

In [225]:
inputs = torch.tensor(
  [
    [0.43, 0.15, 0.89], # Your   (x^1)
    [0.55, 0.87, 0.66], # journey (x^2)
    [0.57, 0.85, 0.64], # starts  (x^3)
    [0.22, 0.58, 0.33], # with    (x^4)
    [0.77, 0.25, 0.10], # one     (x^5)
    [0.05, 0.80, 0.55]  # step    (x^6)
  ]
)


In [228]:
inputs.shape[0]

6

In [234]:
attn_scores = torch.empty(inputs.shape[0], inputs.shape[0])

for i, x_i in enumerate(inputs):
  # print(f"i: {i}")
  # print(f"x_i: {x_i}")
  # print("@")
  for j, x_j in enumerate(inputs):
    # print(f"x_j: {x_j}")
    attn_scores[i, j] = torch.dot(x_i, x_j) 

attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [235]:
# faster way to do the same thing with matrix multiplication
attn_scores = inputs @ inputs.T
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [236]:
# now add the softmax to the attention scores
attn_weights = torch.softmax(attn_scores, dim=1)
attn_weights



tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [243]:
# this is the transposed matrix of inputs
print("inputs:")
print(inputs)
print("inputs.T:")
print(inputs.T)


inputs:
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])
inputs.T:
tensor([[0.4300, 0.5500, 0.5700, 0.2200, 0.7700, 0.0500],
        [0.1500, 0.8700, 0.8500, 0.5800, 0.2500, 0.8000],
        [0.8900, 0.6600, 0.6400, 0.3300, 0.1000, 0.5500]])


In [237]:
# so two lines of code to do the same thing
attn_scores = inputs @ inputs.T
attn_weights = torch.softmax(inputs @ inputs.T, dim=1)

attn_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [238]:
# now calculate the context vector
context_vec = attn_weights @ inputs

print(context_vec.shape)
context_vec


torch.Size([6, 3])


tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

---

## Implementing self-attention with **trainable weights**

### Computing the attention weights step by step



![Self Attention Weights Diagram](./images/self-attention-weights.png.png)

In Chapter 3, the transition from "simple self-attention" to "scaled dot-product attention" involves the introduction of **trainable weight matrices** ($W_q, W_k, W_v$). These matrices are used to transform your input tokens ($x$) into three distinct representational spaces: **Query ($q$)**, **Key ($k$)**, and **Value ($v$)** 1\.  
Here is an explanation of what these values represent and how they work:

#### 1\. The Trainable Weight Matrices

Instead of using the raw input vector ($x$) for every calculation, the model uses three separate weight matrices:

* **$W_q$ (Query weights):** Multiplied by the input to create the Query vector ($q$) 2\.  
* **$W_k$ (Key weights):** Multiplied by the input to create the Key vector ($k$) 2\.  
* **$W_v$ (Value weights):** Multiplied by the input to create the Value vector ($v$) 2\.

These matrices are **parameters** that start as random numbers and are optimized during training 3, 4\. This allows the model to learn how to "re-interpret" the same input word differently depending on whether it is acting as a query, a key, or a piece of content 5\.

#### 2\. The Three Vectors ($q, k, v$)

The sources explain these terms using an **information retrieval (database)** analogy 6:

* **Query ($q$):** This represents the **current token** the model is trying to understand 6\. It acts like a **search query** you type into a database to find relevant information from other parts of the sentence 6\.  
* **Key ($k$):** This acts like a **database index** or a label 7\. Every token in the sequence provides a Key that is compared against the current Query. The more a Key matches a Query (calculated via dot product), the more "attention" the model pays to that token 7, 8\.  
* **Value ($v$):** This represents the **actual content** or information of the token 7\. Once the model uses the Query and Key to decide which tokens are important, it extracts and sums up the "Values" of those important tokens to create the final context vector 7, 9\.

#### 3\. How they interact in the math

The process follows these specific steps:

1. **Transformation:** You multiply your input $x$ by $W_q, W_k, W_v$ to get your $q, k, v$ vectors 2\.  
2. **Scoring:** You take the **dot product of your Query ($q$) and all Keys ($k$)** to see how well they match 8, 10\.  
3. **Scaling and Normalization:** You scale these scores (dividing by the square root of the dimension) and apply **softmax** to turn them into percentages (attention weights) 11\.  
4. **Weighted Sum:** You multiply those percentages by the **Value ($v$)** vectors and sum them up 9\. This ensures the final "context vector" is made up of the most relevant information 12, 13\.

**Analogy:** Imagine you are a **researcher (the Query)** looking for information about "climate change" in a **library**.

* The **library's digital catalog (the Keys)** contains titles and keywords for every book. You compare your query to these titles to see which books are relevant.  
* The **actual information inside the books (the Values)** is what you actually read and take notes from.  
* Your **final research paper (the Context Vector)** is a "weighted sum" of all the information you gathered, where you took 80% of your notes from the most relevant books and only 2% from the less relevant ones.

In [257]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [245]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [299]:

# torch.nn.Parameter is a wrapper around a tensor that allows it to be trainable 
W_query = torch.nn.Parameter(torch.randn(d_in, d_out))
W_query

Parameter containing:
tensor([[-0.9724, -0.7550],
        [ 0.3239, -0.1085],
        [ 0.2103, -0.3908]], requires_grad=True)

In [295]:
W_key = torch.nn.Parameter(torch.randn(d_in, d_out))
W_key

Parameter containing:
tensor([[-0.4796, -0.5166],
        [-0.3107,  0.2057],
        [ 0.9657,  0.7057]], requires_grad=True)

In [294]:
W_value = torch.nn.Parameter(torch.randn(d_in, d_out))
W_value

Parameter containing:
tensor([[-0.9601, -0.4087],
        [ 1.0764, -0.4015],
        [-0.7291, -0.1218]], requires_grad=True)

In [293]:
query_2 = x_2 @ W_query
query_2

tensor([-1.1729, -0.0048], grad_fn=<SqueezeBackward4>)

In [255]:
keys = inputs @ W_key
keys

tensor([[-0.1823, -0.6888],
        [-0.1142, -0.7676],
        [-0.1443, -0.7728],
        [ 0.0434, -0.3580],
        [-0.6467, -0.6476],
        [ 0.3262, -0.3395]], grad_fn=<MmBackward0>)

In [266]:
values = inputs @ W_value
values

tensor([[-0.5526, -0.7627],
        [-0.5401,  0.4147],
        [-0.5412,  0.3956],
        [-0.2483,  0.4370],
        [-0.4085, -0.0622],
        [-0.2599,  0.6266]], grad_fn=<MmBackward0>)

In [258]:
# Attention scores
keys_2 = keys[1]
keys_2

tensor([-0.1142, -0.7676], grad_fn=<SelectBackward0>)

In [260]:
attention_score_22 = torch.dot(query_2, keys_2)
attention_score_22

tensor(0.1376, grad_fn=<DotBackward0>)

In [261]:
attention_scores_2 = query_2 @ keys.T
attention_scores_2


tensor([ 0.2172,  0.1376,  0.1730, -0.0491,  0.7616, -0.3809],
       grad_fn=<SqueezeBackward4>)

In [279]:
d_k = keys.shape[1]
print(d_k)


attn_weights_2 = torch.softmax(attention_scores_2 / d_k ** 0.5, dim=-1) # ** 0.5 is the square root of the dimension of the keys
attn_weights_2


2


tensor([0.1704, 0.1611, 0.1652, 0.1412, 0.2505, 0.1117],
       grad_fn=<SoftmaxBackward0>)

In [280]:
attn_weights_2.sum()

tensor(1.0000, grad_fn=<SumBackward0>)

In [301]:
# print(values.shape)
# print(values)
# print(attn_weights_2.shape)
#print(attn_weights_2)

context_vec_2 = attn_weights_2 @ values
context_vec_2

tensor([-0.4370,  0.1182], grad_fn=<SqueezeBackward4>)

## Implementing a compact Self-Attention class

In [304]:
# generalize it and get the context vector for all the inputs at once

import torch.nn as nn

class SelfAttention_v1(nn.Module):
  def __init__(self, d_in, d_out):
    super().__init__()
    self.W_query = torch.nn.Parameter(torch.randn(d_in, d_out)) # trainable weights
    self.W_key = torch.nn.Parameter(torch.randn(d_in, d_out)) # trainable weights
    self.W_value = torch.nn.Parameter(torch.randn(d_in, d_out)) # trainable weights

  def forward(self, x):
    queries = x @ self.W_query
    keys = x @ self.W_key
    values = x @ self.W_value

    attn_scores = queries @ keys.T
    attn_weights = torch.softmax(attn_scores / d_k ** 0.5, dim=-1) # ** 0.5 is the square root of the dimension of the keys
    context_vec = attn_weights @ values

    return context_vec


torch.manual_seed(123)

self_attn_v1 = SelfAttention_v1(d_in, d_out)

self_attn_v1(inputs)

tensor([[0.2845, 0.4071],
        [0.2854, 0.4081],
        [0.2854, 0.4075],
        [0.2864, 0.3974],
        [0.2863, 0.3910],
        [0.2860, 0.4039]], grad_fn=<MmBackward0>)

In [316]:
m = torch.nn.Linear(2, 3)
m.weight

Parameter containing:
tensor([[-0.5980, -0.2029],
        [-0.4980,  0.0467],
        [-0.1320, -0.3793]], requires_grad=True)

In [None]:
# Optimize the self-attention class above

class SelfAttention_v2(nn.Module):
  def __init__(self, d_in, d_out, qkv_bias=False):
    super().__init__()
    self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)

  def forward(self, x):
    queries = self.W_query(x)
    keys = self.W_key(x)
    values = self.W_value(x)

    attn_scores = queries @ keys.T
    attn_weights = torch.softmax(attn_scores / d_k ** 0.5, dim=-1) # ** 0.5 is the square root of the dimension of the keys
    context_vec = attn_weights @ values

    return context_vec


torch.manual_seed(123)
self_attn_v2 = SelfAttention_v2(d_in, d_out)
self_attn_v2(inputs)

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

## Hiding future words with causal attention

### Applying causal attention mask

Basically hiding the unknown words "Your -> journey" in the first step, hide the other 4 and so on.

In [321]:
queries = self_attn_v2.W_query(inputs)
keys = self_attn_v2.W_key(inputs)
values = self_attn_v2.W_value(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / d_k ** 0.5, dim=-1) # ** 0.5 is the square root of the dimension of the keys
attn_weights

tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
        [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
        [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)

In [322]:
context_length = attn_weights.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
mask_simple



tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [324]:
attn_weights_masked_simple = attn_weights * mask_simple
attn_weights_masked_simple


tensor([[0.1717, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1749, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1637, 0.1749, 0.1746, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.0000, 0.0000],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<MulBackward0>)

In [325]:
# now normalize the weights again (the don't sum to 1 anymore)
rows_sum = attn_weights_masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = attn_weights_masked_simple / rows_sum
masked_simple_norm





tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<DivBackward0>)

In [None]:
# Now let's do the same in less steps

mask = torch.tril(torch.ones(inputs.shape[0], inputs.shape[0]))
masked = attn_scores.masked_fill(mask == 0, float("-inf")) # -inf squared is 0
masked


tensor([[0.3111,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1655, 0.2602,   -inf,   -inf,   -inf,   -inf],
        [0.1667, 0.2602, 0.2577,   -inf,   -inf,   -inf],
        [0.0510, 0.1080, 0.1064, 0.0643,   -inf,   -inf],
        [0.1415, 0.1875, 0.1863, 0.0987, 0.1121,   -inf],
        [0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MaskedFillBackward0>)

In [None]:
# now normalize the weights
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=-1) # keys.shape[-1] is the dimension of the keys and the same as d_k
attn_weights



tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)

In [330]:
# now add a dropout layer - though with newer LLMs it is not used anymore
torch.manual_seed(123)

layer = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
layer(example)



tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])

In [None]:
layer(attn_weights) # scales the remaining values to compensate for the dropped out values

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.6816, 0.6804, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5085, 0.4936, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.3906, 0.0000],
        [0.3249, 0.3418, 0.0000, 0.3308, 0.3249, 0.3363]],
       grad_fn=<MulBackward0>)

## Implementing a compact causal self-attention class

In [333]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [335]:
batch = torch.stack([inputs, inputs], dim=0)
batch.shape

torch.Size([2, 6, 3])

In [339]:

class CausalAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.dropout = torch.nn.Dropout(dropout)
    self.register_buffer("mask", torch.tril(torch.ones(context_length, context_length))) # register buffer because otherwise it would not be moved to the GPU


  def forward(self, x):
    b, num_tokens, d_in = x.shape
    # x = batch, 2 x 6 x 3
    queries = self.W_query(x)
    keys = self.W_key(x)
    values = self.W_value(x)

    attn_scores = queries @ keys.transpose(1, 2)
    attn_scores.masked_fill_( # the underscore means in place operation = no copy of the tensor is made
      self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) # :num_tokens to account for cases where the context length is not the same for all the inputs
    attn_weights = torch.softmax(
      attn_scores / keys.shape[-1] ** 0.5, dim=-1) # ** 0.5 is the square root of the dimension of the keys
    attn_weights = self.dropout(attn_weights)
    context_vec = attn_weights @ values

    return context_vec

torch.manual_seed(123)
dropout = 0.5
context_length = batch.shape[1] # batch.shape[1] = torch.Size([2, 6, 3]) -> 6
ca = CausalAttention(d_in, d_out, context_length, dropout)
ca(batch)


tensor([[[-0.7976, -0.2622],
         [-0.8226, -0.2565],
         [-0.3206, -0.1053],
         [-0.9066, -0.3096],
         [ 0.0000,  0.0000],
         [    nan,     nan]],

        [[-0.9584, -0.2871],
         [-0.5867, -0.1790],
         [ 0.0000,  0.0000],
         [-0.4799, -0.1576],
         [-0.8427, -0.3002],
         [    nan,     nan]]], grad_fn=<UnsafeViewBackward0>)

## Extending single-head attention to multi-head attention 

### Stacking multiple single-head attention layers

In [None]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_in, d_out, dropout, num_heads=2, qkv_bias=False):
    super().__init__()
    self.heads = nn.ModuleList([
      CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads) # not optimal because for loop is slow and not parallelized in GPU
    ])
    
  def forward(self, x):
    out = torch.cat([head(x) for head in self.heads], dim=-1)
    return out

torch.manual_seed(123)
dropout = 0.5
context_length = batch.shape[1]
d_in, d_out = 3, 2
num_heads = 2

multi_head_attn = MultiHeadAttention(d_in, d_out, dropout, num_heads)
multi_head_attn(batch)

## Implementing a compact multi-head self-attention class





tensor([[[-0.7675, -0.2244,  0.2713,  0.1995],
         [-0.3961, -0.1508,  0.6994,  0.5711],
         [-0.5729, -0.2084,  0.2476,  0.2284],
         [ 0.0000,  0.0000,  0.8587,  0.6483],
         [-0.8427, -0.3002,  0.7671,  0.7078],
         [    nan,     nan,     nan,     nan]],

        [[-0.6067, -0.1995,  0.1965,  0.1226],
         [-0.4463, -0.1524,  0.5849,  0.4014],
         [-0.8586, -0.3102,  0.2476,  0.2284],
         [-0.4799, -0.1576,  0.4816,  0.3004],
         [ 0.0000,  0.0000,  0.7671,  0.7078],
         [    nan,     nan,     nan,     nan]]], grad_fn=<CatBackward0>)

### Implementing multi-head attention with wieght splits 

In [None]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
    super().__init__()
    assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out // num_heads
    
    self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.out_proj = torch.nn.Linear(d_out, d_out) # Linear layer to combine the head outputs
    self.dropout = torch.nn.Dropout(dropout)
    self.register_buffer("mask", torch.tril(torch.ones(context_length, context_length), diagonal=1))

  def forward(self, x):
    b, num_tokens, d_in = x.shape
    q = self.W_query(x)
    k = self.W_key(x)
    v = self.W_value(x)
    