## What is an attention mechanism?

*   Words represented as static vectors don't consider context: "bat" (baseball) and "bat" (animal) get the same vector.
*   An attention mechanism changes the original vector, which may carry multiple meanings, by considering its context. SUPA interesting.
*   Even more, when the weights are applied, the vector for "bat" (the animal) can be shifted in the direction of, say, caves, blindness, etc.
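To make the first bullet concrete, here is a minimal sketch of a plain embedding lookup. The toy vocabulary and the two example sentences are made up for illustration; the point is only that a lookup table returns the same vector for "bat" whatever the neighbouring words are.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy vocabulary; the ids are arbitrary.
vocab = {"the": 0, "bat": 1, "flew": 2, "swung": 3}
embedding = torch.nn.Embedding(num_embeddings=len(vocab), embedding_dim=3)

sentence_animal = torch.tensor([vocab["the"], vocab["bat"], vocab["flew"]])
sentence_baseball = torch.tensor([vocab["the"], vocab["bat"], vocab["swung"]])

# Look up "bat" (position 1) in both sentences
bat_animal = embedding(sentence_animal)[1]
bat_baseball = embedding(sentence_baseball)[1]

# The lookup ignores the surrounding tokens, so the vectors are identical
print(torch.equal(bat_animal, bat_baseball))
```

The comparison prints `True`; attention is what later makes the two vectors differ.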



### Omega -> Context Vector





$$
Q_i = x_i W^Q, \quad K_j = x_j W^K, \quad V_j = x_j W^V
$$

$$
\space
$$

$$
 \omega_{ij} = Q_i \cdot K_j^T
$$

$$
\space
$$

$$
\alpha_{ij} = \text{softmax}(\omega_{ij})
$$

$$
\space
$$

$$
z_i = \sum_{j=1}^{n} \alpha_{ij} V_j
$$



In [None]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

$$
\text{"Your journey starts with one step"}
\;\longrightarrow\;
\begin{aligned}
\mathbf{x}^{(1)} &= \begin{bmatrix} 0.43 & 0.15 & 0.89 \end{bmatrix}
  \;\bigl(\text{“Your”}\bigr)\\
\mathbf{x}^{(2)} &= \begin{bmatrix} 0.55 & 0.87 & 0.66 \end{bmatrix}
  \;\bigl(\text{“journey”}\bigr)\\
\mathbf{x}^{(3)} &= \begin{bmatrix} 0.57 & 0.85 & 0.64 \end{bmatrix}
  \;\bigl(\text{“starts”}\bigr)\\
\mathbf{x}^{(4)} &= \begin{bmatrix} 0.22 & 0.58 & 0.33 \end{bmatrix}
  \;\bigl(\text{“with”}\bigr)\\
\mathbf{x}^{(5)} &= \begin{bmatrix} 0.77 & 0.25 & 0.10 \end{bmatrix}
  \;\bigl(\text{“one”}\bigr)\\
\mathbf{x}^{(6)} &= \begin{bmatrix} 0.05 & 0.80 & 0.55 \end{bmatrix}
  \;\bigl(\text{“step”}\bigr)
\end{aligned}
$$


### Simple Calculation of Attention Weights

The dot product $\vec{a} \cdot \vec{b} = \|\vec{a}\| \|\vec{b}\| \cos(\theta)$ measures how aligned two vectors are, so it serves as a mathematical stand-in for context when words are represented as points in space. Taking the dot product of the query vector (the token we are currently at) with the vector of every token in the context, including itself, gives one score per token:

$$
\vec{\omega_{2j}} =
\begin{bmatrix}
\varnothing \\
\varnothing \\
\varnothing \\
\varnothing \\
\varnothing \\
\varnothing
\end{bmatrix}
\rightarrow
\begin{bmatrix}
0.9544 \\
1.4950 \\
1.4754 \\
0.8434 \\
0.7070 \\
1.0865
\end{bmatrix},
\qquad \omega_{2j} = x_2 \cdot x_j
$$
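As a rough check of the alignment idea, the sketch below compares raw dot products with cosine similarities for the same query ("journey"). It uses `torch.nn.functional.cosine_similarity`; the variable names are mine, not part of the notes.

```python
import torch
import torch.nn.functional as F

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)
query = inputs[1]

dots = inputs @ query                       # raw dot products with x^2
norms = inputs.norm(dim=-1)                 # vector lengths
cosines = F.cosine_similarity(inputs, query.unsqueeze(0), dim=-1)

# Check a.b = |a||b|cos(theta) row by row
print(torch.allclose(dots, norms * query.norm() * cosines))
print(cosines)
```

"starts" points in almost the same direction as "journey", while "one" is much less aligned, which matches the ordering of the raw scores above.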




In [None]:
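query = inputs[1]  # x^2 ("journey") is the query token


print(inputs.shape[0])  # number of tokens
attention_score_x2 = torch.empty(inputs.shape[0])
for index, x_i in enumerate(inputs):
    # Unnormalized score: dot product of each token with the query
    attention_score_x2[index] = torch.dot(x_i, query)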




6


### Normalizing

The scores are normalized so that they sum to one, which lets each one be read as the share of attention a token receives. The softmax formula is written out below for later implementation, but in general just use the PyTorch one; it's robust.


$$
\vec{\alpha_{2j}} =
\begin{bmatrix}
\varnothing \\
\varnothing \\
\varnothing \\
\varnothing \\
\varnothing \\
\varnothing
\end{bmatrix}
\rightarrow
\begin{bmatrix}
0.9544 \\
1.4950 \\
1.4754 \\
0.8434 \\
0.7070 \\
1.0865
\end{bmatrix}
\rightarrow
\text{softmax}(x_i)
\rightarrow
\begin{bmatrix}
0.1385 \\
0.2379 \\
0.2333 \\
0.1240 \\
0.1082 \\
0.1581
\end{bmatrix}
$$



where


$$
\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}
$$


In [None]:


def softmax_naive(x):
    # dim=0 because we're summing over the single dimension of the score vector
    return torch.exp(x) / torch.exp(x).sum(dim=0)


attn_weights_2_solid = torch.softmax(attention_score_x2, dim=0)
print("Refined Attention Weights:", attn_weights_2_solid)


Refined Attention Weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
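One concrete reason to prefer the built-in softmax: the naive formula overflows for large scores. The sketch below contrasts it with the standard max-subtraction trick. `softmax_stable` is my own illustrative helper, and I am assuming (not asserting) that this is the kind of robustness meant above.

```python
import torch

def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

def softmax_stable(x):
    # Subtracting the max leaves the result unchanged mathematically
    # but keeps exp() from overflowing.
    shifted = x - x.max()
    return torch.exp(shifted) / torch.exp(shifted).sum(dim=0)

large = torch.tensor([1000.0, 1001.0, 1002.0])

print(softmax_naive(large))          # exp(1000) overflows to inf, giving nan
print(softmax_stable(large))         # finite, sums to one
print(torch.softmax(large, dim=0))   # built-in version for comparison
```

For the small scores in this notebook both versions agree, so the difference only shows up with extreme values.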


### Finalizing the Calculation of $z$
$$
\vec{\alpha_{2j}} =
\begin{bmatrix}
\varnothing \\
\varnothing \\
\varnothing \\
\varnothing \\
\varnothing \\
\varnothing
\end{bmatrix}
\rightarrow
\begin{bmatrix}
0.9544 \\
1.4950 \\
1.4754 \\
0.8434 \\
0.7070 \\
1.0865
\end{bmatrix}
\rightarrow
\text{softmax}(x_i)
\rightarrow
\begin{bmatrix}
0.1385 \\
0.2379 \\
0.2333 \\
0.1240 \\
0.1082 \\
0.1581
\end{bmatrix}
$$

$$
\space
$$

$$
\vec{z_{2}} =
\begin{bmatrix}
\varnothing \\
\varnothing \\
\varnothing
\end{bmatrix}
\rightarrow
\alpha_{21} V_1 + \alpha_{22} V_2 + \cdots
\rightarrow
\begin{bmatrix}
0.4419 \\
0.6515 \\
0.5683
\end{bmatrix} = z_2 =\sum_j \alpha_{2j}V_j
$$

Here $V_j = x_j$, since no weight matrices are applied yet, so $z_2$ has one entry per embedding dimension (3), not one per token (6).



In [None]:
# One entry per embedding dimension, initialized to zero so += is well defined
context_vector = torch.zeros(inputs.shape[1])

# Weighted sum over the input tokens x_j (not over the entries of the query)
for index, x_j in enumerate(inputs):
    context_vector += attn_weights_2_solid[index] * x_j

print(context_vector)


tensor([0.4419, 0.6515, 0.5683])


## Computing Attention Weights $\rightarrow$ context vectors for all Input Tokens

The example below is a very quick way of calculating the normalized attention weights, and then the context vectors, for every token in the input at once.



In [None]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

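# Explicit version: one score for every (query i, key j) pair
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)


# Equivalent matrix product; overwrites the loop result with the same values
attn_scores = inputs @ inputs.T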
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)
context_vector = attn_weights @ inputs
print(context_vector)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


### A quick verification softmax was applied properly

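This matters especially for `dim=-1`, which normalizes over the last dimension (each row here); the right `dim` can change from code to code depending on how the tensors are laid out.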

In [None]:
row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print("Row 2 sum:", row_2_sum)
print("All row sums:", attn_weights.sum(dim=-1))

Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
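To see what the `dim` argument changes, here is a small sketch on a made-up 2×3 score matrix: `dim=-1` makes each row sum to one, `dim=0` makes each column sum to one.

```python
import torch

# Hypothetical scores, chosen only to make the two results differ
scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])

row_wise = torch.softmax(scores, dim=-1)  # normalize across each row
col_wise = torch.softmax(scores, dim=0)   # normalize down each column

print(row_wise.sum(dim=-1))  # row sums
print(col_wise.sum(dim=0))   # column sums
print(col_wise.sum(dim=-1))  # rows of the column-wise result do not sum to one
```

For attention weights each query (row) should distribute its attention over the keys, which is why the notebook uses `dim=-1`.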


## Weights Q, K, V


We set these aside a while ago; now we go back to compute them and find out their significance. Mathematically, this is what we are doing:
$$
Q_i = x_i W^Q, \quad K_j = x_j W^K, \quad V_j = x_j W^V
$$

In [None]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)


tensor([0.4306, 1.4551])


Notice that $Q_i$ depends only on $i$: it is the query vector for the current token, so we only need to perform the single operation `query_2 = x_2 @ W_query`. However, because $K_j$ and $V_j$ range over $j$, we need to perform these operations on all input vectors, below.

Also, `@` is the matrix-multiplication operator, which here replaces a double for loop of dot products. Whenever you feel like you need to write that loop, use `@`.

In [None]:
keys = inputs @ W_key
values = inputs @ W_value
print(keys, values)

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]]) tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])


### Recomputing Context Vectors with Calculated Weights

Starting with
$$
 \omega_{ij} = Q_i \cdot K_j^T
$$

In [None]:
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])



$$
\alpha_{ij} = \text{softmax}\!\left(\frac{\omega_{ij}}{\sqrt{d_k}}\right)
$$

In [None]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2, sum(attn_weights_2))

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820]) tensor(1.)
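The division by $\sqrt{d_k}$ is easier to motivate with a quick experiment. The sketch below assumes query and key entries drawn from a standard normal distribution (the notebook itself uses `torch.rand`, so this is an idealized setting) and measures how the spread of the dot products grows with `d_k`. `dot_product_std` is a hypothetical helper.

```python
import torch

torch.manual_seed(0)

def dot_product_std(d_k, num_samples=10_000):
    # Hypothetical helper: std of q.k for random standard-normal q and k,
    # before and after dividing by sqrt(d_k)
    q = torch.randn(num_samples, d_k)
    k = torch.randn(num_samples, d_k)
    dots = (q * k).sum(dim=-1)
    return dots.std().item(), (dots / d_k**0.5).std().item()

for d_k in (2, 64, 512):
    raw_std, scaled_std = dot_product_std(d_k)
    print(f"d_k={d_k:4d}  raw std={raw_std:6.2f}  scaled std={scaled_std:4.2f}")
```

Under this assumption the raw spread grows roughly like $\sqrt{d_k}$, while the scaled scores stay near one, which keeps softmax from collapsing onto a single token as the dimension grows.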


$$z_2 = \sum_j \alpha_{2j} V_j$$

In [None]:
context_vector_2 = attn_weights_2 @ values
print(context_vector_2)

tensor([0.3061, 0.8210])


## Generalizing the Process to Compute All Context Vectors

So we've implemented the weight matrices, and we just computed the attention weights (softmax of the scaled dot products) for the second word. This process can now be generalized to all tokens. The pieces are all there in my head, though not fully organized yet, so this will be an important chapter to code myself. Right now, the one new thing is the `@` notation, which replaces nested for loops.

`d_in` and `d_out` are confusing though: `d_in` is the embedding size of each input token (3 here) and `d_out` is the size of each projected query, key, and value vector (2 here).

$$
\begin{aligned}
K &= X \;@\; W_{\mathrm{key}},\\
Q &= X \;@\; W_{\mathrm{query}},\\
V &= X \;@\; W_{\mathrm{value}},
\end{aligned}
$$

$$
\begin{aligned}
\mathrm{scores} &= Q \;@\; K^{T},\\
\mathrm{weights} &= \mathrm{softmax}\!\Bigl(\tfrac{\mathrm{scores}}{\sqrt{d_{\mathrm{key}}}}\Bigr),\\
\mathrm{context} &= \mathrm{weights} \;@\; V.
\end{aligned}
$$
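Tracing the shapes through these equations helps with the `d_in`/`d_out` confusion. This is a minimal sketch with a random stand-in for the input matrix; the sizes match the ones used in this notebook.

```python
import torch

torch.manual_seed(123)

num_tokens, d_in, d_out = 6, 3, 2
X = torch.rand(num_tokens, d_in)        # stand-in for the `inputs` tensor
W_query = torch.rand(d_in, d_out)
W_key = torch.rand(d_in, d_out)
W_value = torch.rand(d_in, d_out)

Q, K, V = X @ W_query, X @ W_key, X @ W_value
scores = Q @ K.T                                       # (num_tokens, num_tokens)
weights = torch.softmax(scores / d_out**0.5, dim=-1)   # d_key equals d_out here
context = weights @ V                                  # (num_tokens, d_out)

for name, t in [("X", X), ("Q", Q), ("K", K), ("V", V),
                ("scores", scores), ("weights", weights), ("context", context)]:
    print(f"{name:8s} {tuple(t.shape)}")
```

`d_in` only appears in `X` and the weight matrices; everything after the projection lives in `d_out` dimensions, and the score matrix is always tokens by tokens.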

In [None]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x): #x is an entire input tensor
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value #access the first by values[0], same for above

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec



In [None]:
class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec
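The only change from v1 to v2 is swapping `nn.Parameter(torch.rand(...))` for `nn.Linear`. A small sketch of how the two relate, assuming no bias: `nn.Linear` stores its weight with shape `(d_out, d_in)`, so the manual equivalent multiplies by the transpose.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)

d_in, d_out = 3, 2
x = torch.rand(6, d_in)

linear = nn.Linear(d_in, d_out, bias=False)

# Manual matrix product equivalent to calling the layer
manual = x @ linear.weight.T

print(linear.weight.shape)                 # stored as (d_out, d_in)
print(torch.allclose(linear(x), manual))
```

`nn.Linear` also uses its own default initialization rather than `torch.rand`, so v1 and v2 will not produce the same numbers even with the same seed.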

## Causal Attention

Realistic text generation relies on the words that come before the current position. As in, we won't have access to future words in a sentence. Let's first calculate the attention-weight tensor for the given sentence.

In [None]:
import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

sa_v2 = SelfAttention_v2(d_in, d_out)
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)


tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
        [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
        [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)


Now apply the mask, making sure softmax is still applied and each row still sums to one.

In [None]:
context_length = attn_scores.shape[0]

# Option 1: zero out future positions after softmax, then renormalize each row
mask_simple = torch.tril(torch.ones(context_length, context_length))
masked_simple = mask_simple * attn_weights

row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums

# Option 2: set future scores to -inf before softmax; gives the same weights
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)

attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)
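The two masking routes in the cell above give the same weights. A self-contained sketch with random stand-in scores (names are mine) checks that directly:

```python
import torch

torch.manual_seed(0)

T = 4
scores = torch.randn(T, T)  # stand-in attention scores

# Route A: softmax first, zero the future, renormalize rows
weights = torch.softmax(scores, dim=-1)
lower = torch.tril(torch.ones(T, T))
route_a = (weights * lower) / (weights * lower).sum(dim=-1, keepdim=True)

# Route B: fill future scores with -inf, then a single softmax
future = torch.triu(torch.ones(T, T), diagonal=1).bool()
route_b = torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1)

print(torch.allclose(route_a, route_b))
```

Route B needs only one normalization pass, which is presumably why the classes later in these notes use `masked_fill` with `-inf`.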


## Dropouts

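A regularization technique used in deep learning: during training, dropout randomly zeroes a fraction of the entries and scales up the rest. I'm noticing that the result has the same shape as the tensor passed in (here the 6×6 attention weights), and that with `p=0.5` the surviving weights are doubled. ==ok so dimensions might be a little more obvious==.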

In [None]:
import torch

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
print(dropout(attn_weights))


tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0335, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6804, 0.0000, 0.0000, 0.0000],
        [0.4889, 0.5090, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3988, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3418, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)
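Two details of `torch.nn.Dropout` are worth checking in isolation: surviving entries are scaled by `1 / (1 - p)`, and the layer does nothing in eval mode. A minimal sketch on a matrix of ones:

```python
import torch

torch.manual_seed(123)

p = 0.5
dropout = torch.nn.Dropout(p)
ones = torch.ones(6, 6)

dropout.train()
train_out = dropout(ones)   # entries are either 0 or 1 / (1 - p)

dropout.eval()
eval_out = dropout(ones)    # identity in eval mode

print(train_out.unique())
print(torch.equal(eval_out, ones))
```

The rescaling keeps the expected value of each entry unchanged, which explains the `2.0000` in the first row of the output above.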


## Developing new SAC with Dropouts

The full class with this functionality comes further down. First, stack two copies of `inputs` into a batch to test it on.

In [None]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch)

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])


## Mathematical Explanation for Attention with Batches

$$
X \;=\;
\begin{bmatrix}
X^{(1)} \\[4pt]
X^{(2)}
\end{bmatrix}
\;\in\;\mathbb{R}^{2\times6\times3}
$$
$$
\space
$$


$$
Q^{(b)} = X^{(b)}\,W_Q,\quad
K^{(b)} = X^{(b)}\,W_K,\quad
V^{(b)} = X^{(b)}\,W_V
\;\in\;\mathbb{R}^{6\times D},
\quad D = d_{\mathrm{out}} = 2.
$$
$$
\space
$$
$$
\omega^{(b)} \;=\; Q^{(b)}\,\bigl(K^{(b)}\bigr)^{T}
\;\in\;\mathbb{R}^{6\times6}
$$
$$
\space
$$
$$
\widetilde \omega^{(b)}
= \text{masked}(\omega^{(b)}), \quad
A^{(b)} =
\text{softmax}\!\Bigl(\tfrac{1}{\sqrt{D}}\,\widetilde \omega^{(b)}\Bigr)
\;\in\;\mathbb{R}^{6\times6}.
$$
$$
\space
$$
$$
C^{(b)}
= A^{(b)}\,V^{(b)}
\;\in\;\mathbb{R}^{6\times D},
\quad
C^{(b)}_{i,:}
= \sum_{j=1}^{6} A^{(b)}_{ij}\,V^{(b)}_{j,:}.
$$

$$
\space
$$

$$
z =
\begin{bmatrix}
C^{(1)} \\[4pt]
C^{(2)}
\end{bmatrix}
\;\in\;\mathbb{R}^{2\times6\times D}.
$$


### Quick Blurb about Masking

We define a mask $M\in\mathbb{R}^{T\times T}$ that blocks future tokens:
$$
M_{ij} =
\begin{cases}
0, & j \le i,\\
-\infty, & j > i.
\end{cases}
$$

Add $M$ to the raw scores $\omega^{(b)}$ before softmax so that any "future" positions get zero weight ($D$ is the key dimension):
$$
\widetilde \omega^{(b)} = \omega^{(b)} + M,
\quad
A^{(b)} = \text{softmax}\!\bigl(\tfrac{1}{\sqrt{D}}\,\widetilde \omega^{(b)}\bigr).
$$
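The additive form of the mask can be written almost literally in PyTorch. The sketch below builds $M$ with `torch.triu` over a matrix of `-inf` and checks that adding it matches the `masked_fill` approach; the scores are random stand-ins and the variable names are mine.

```python
import torch

torch.manual_seed(0)

T, D = 5, 2
scores = torch.randn(T, T)  # stand-in for one batch element's raw scores

# M is 0 on and below the diagonal and -inf above it
M = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

A_additive = torch.softmax((scores + M) / D**0.5, dim=-1)

# Same result via masked_fill on a boolean upper-triangular mask
future = torch.triu(torch.ones(T, T), diagonal=1).bool()
A_filled = torch.softmax(scores.masked_fill(future, float("-inf")) / D**0.5, dim=-1)

print(M)
print(torch.allclose(A_additive, A_filled))
```

Both forms are equivalent here; `masked_fill` avoids allocating a float mask, while the additive form mirrors the equation above.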

In [None]:
class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout) # New
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # This is the mask

    def forward(self, x):
        b, num_tokens, d_in = x.shape # New batch dimension b
        # For inputs where `num_tokens` exceeds `context_length`, this will result in errors
        # in the mask creation further below.
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs
        # do not exceed `context_length` before reaching this forward method.
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose
        # `:num_tokens` handles batches with fewer tokens than the supported context_length
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights) # New

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)

context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
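A quick way to convince yourself the mask really makes attention causal: change only the last token and confirm the earlier context vectors do not move. `causal_attention` below is a hypothetical functional stand-in for the class above (single sequence, no dropout, random weights), not the class itself.

```python
import torch

torch.manual_seed(123)

def causal_attention(x, W_q, W_k, W_v):
    # Hypothetical functional stand-in: no batch dimension, no dropout
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.T
    future = torch.triu(torch.ones_like(scores), diagonal=1).bool()
    scores = scores.masked_fill(future, float("-inf"))
    weights = torch.softmax(scores / k.shape[-1]**0.5, dim=-1)
    return weights @ v

x = torch.rand(6, 3)
W_q, W_k, W_v = (torch.rand(3, 2) for _ in range(3))

out = causal_attention(x, W_q, W_k, W_v)

# Perturb only the last token and recompute
x_changed = x.clone()
x_changed[-1] += 1.0
out_changed = causal_attention(x_changed, W_q, W_k, W_v)

print(torch.allclose(out[:-1], out_changed[:-1]))  # earlier positions
print(torch.allclose(out[-1], out_changed[-1]))    # last position
```

Positions 1 to 5 are unaffected because their weights on the sixth token are exactly zero; only the last context vector changes.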


## Multi-Head Attention

In [None]:
import torch.nn as nn
import torch

class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        # One independent CausalAttention module per head
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        # Concatenate the heads' outputs along the feature dimension
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [None]:
torch.manual_seed(123)

context_length = batch.shape[1] # This is the number of tokens, this is standard
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


## Refined MultiHead Class

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`,
        # this will result in errors in the mask creation further below.
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec
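The `view`/`transpose` bookkeeping in `forward` is the part most worth checking by hand. The sketch below uses a counting tensor instead of real projections (my choice, to make the layout visible) and runs the split and the merge in isolation.

```python
import torch

b, num_tokens, num_heads, head_dim = 2, 6, 2, 2
d_out = num_heads * head_dim

# Counting tensor so every entry is identifiable
x = torch.arange(b * num_tokens * d_out, dtype=torch.float32).view(b, num_tokens, d_out)

# Split: (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)
heads = x.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)
print(heads.shape)

# Head 0 holds the first head_dim features of every token
print(torch.equal(heads[:, 0], x[..., :head_dim]))

# Merge back: transpose, make contiguous, then flatten the two head dimensions
merged = heads.transpose(1, 2).contiguous().view(b, num_tokens, d_out)
print(torch.equal(merged, x))
```

The `.contiguous()` call is needed because `transpose` returns a non-contiguous view and `view` requires compatible memory layout; the round trip recovers the original tensor exactly.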
