# 3. Coding attention mechanisms

<img src="./images/chapter-3/Fig3.1.png" width="720">  
  
<img src="./images/chapter-3/Fig3.2.png" width="720">

### 3.3 Attending to different parts of the input with self-attention

#### 3.3.1 A simple self-attention mechanism without trainable weights

In [201]:
import torch

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],  # step     (x^6)
    ]
)

In [202]:
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [203]:
# Normalize the attention scores computed above so that the resulting attention weights sum to 1. This convention aids interpretation and helps keep training stable in an LLM.
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In [204]:
# In practice, it’s more common and advisable to use the softmax function for normalization. This approach is better at managing extreme values and offers more favorable gradient properties during training.
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)


attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [205]:
# Note that this naive softmax implementation (softmax_naive) may encounter numerical instability problems, such as overflow and underflow, when dealing with large or small input values. Therefore, in practice, it’s advisable to use the PyTorch implementation of softmax, which has been extensively optimized for performance:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
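
The overflow problem mentioned above can be seen with a small sketch. The score values below are made up for illustration, and `softmax_stable` is a hypothetical helper (not part of the original notebook) that uses the common trick of subtracting the maximum before exponentiating.

```python
import torch


def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)


def softmax_stable(x):
    # Subtracting the maximum does not change the result mathematically,
    # but it keeps every exponent <= 0, so exp() cannot overflow.
    shifted = x - x.max()
    return torch.exp(shifted) / torch.exp(shifted).sum(dim=0)


# Made-up scores, large enough that exp() overflows in float32
large_scores = torch.tensor([1000.0, 1001.0, 1002.0])

naive_out = softmax_naive(large_scores)  # inf / inf produces nan
stable_out = softmax_stable(large_scores)
builtin_out = torch.softmax(large_scores, dim=0)

print("naive:  ", naive_out)
print("stable: ", stable_out)
print("builtin:", builtin_out)
```

The naive version returns `nan` entries here, while the shifted version agrees with `torch.softmax`.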


In [206]:
# Thus, context vector z(2) is the weighted sum of all input vectors, obtained by multiplying each input vector by its corresponding attention weight:
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])
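
As a quick check, the weighted sum in the loop above can also be written as a single vector-matrix product. This is a minimal sketch that repeats the inputs so it runs on its own; the names `context_loop` and `context_matmul` are introduced here for illustration.

```python
import torch

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)

query = inputs[1]
attn_scores_2 = inputs @ query  # dot product of every input row with the query
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

# Loop version, as in the cell above
context_loop = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_loop += attn_weights_2[i] * x_i

# Vectorized version: (6,) @ (6, 3) -> (3,)
context_matmul = attn_weights_2 @ inputs

print(context_loop)
print(context_matmul)
print(torch.allclose(context_loop, context_matmul))
```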


#### 3.3.2 Computing attention weights for all input tokens

1) Compute attention scores: dot products between the inputs.  
2) Compute attention weights: a normalized version of the attention scores.  
3) Compute context vectors: a weighted sum over the inputs.

In [207]:
attn_scores = torch.empty(6, 6)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [208]:
# The same computation as a single matrix multiplication
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [209]:
attn_scores.shape

torch.Size([6, 6])

In [210]:
# we normalize each row so that the values in each row sum to 1
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)
# By setting dim=-1, we are instructing the softmax function to apply the normalization along the last dimension of the attn_scores tensor
# If attn_scores is a two-dimensional tensor (for example, with a shape of [rows, columns]), it will normalize across the columns so that the values in each row (summing over the column dimension) sum up to 1.

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
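
To make the effect of the `dim` argument concrete, here is a small sketch on a made-up, non-symmetric 2x3 tensor (the values are chosen only for illustration). With `dim=-1` each row sums to 1; with `dim=0` each column does.

```python
import torch

# Made-up scores with shape [rows, columns] = [2, 3]
scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])

row_norm = torch.softmax(scores, dim=-1)  # normalize within each row
col_norm = torch.softmax(scores, dim=0)   # normalize within each column

print("Row sums with dim=-1:   ", row_norm.sum(dim=-1))
print("Column sums with dim=0: ", col_norm.sum(dim=0))
print("Column sums with dim=-1:", row_norm.sum(dim=0))  # generally not 1
```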


In [211]:
# all context vectors
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
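
The three steps of this section fit in three lines. The sketch below repeats the inputs so it is self-contained, then checks that each row of weights sums to 1 and that the second row of the result matches the context vector computed for the second token in section 3.3.1.

```python
import torch

inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)

attn_scores = inputs @ inputs.T                    # 1) attention scores, shape (6, 6)
attn_weights = torch.softmax(attn_scores, dim=-1)  # 2) attention weights, rows sum to 1
all_context_vecs = attn_weights @ inputs           # 3) context vectors, shape (6, 3)

print("Row sums:", attn_weights.sum(dim=-1))
print("Second context vector:", all_context_vecs[1])
```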


### 3.4 Implementing self-attention with trainable weights

<img src="./images/chapter-3/Fig3.13.png" width="720">

#### 3.4.1 Computing the attention weights step by step

In [212]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [213]:
torch.manual_seed(123)

W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [214]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)

tensor([0.4306, 1.4551])


In [215]:
# Even though our temporary goal is only to compute the one context vector, z(2), we still require the key and value vectors for all input elements as they are involved in computing the attention weights with respect to the query q(2).
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


In [216]:
# First, let’s compute the attention score ω22:
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

tensor(1.8524)


In [217]:
# all attention scores
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [218]:
# all attention weights
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
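
A common motivation for dividing the scores by `d_k**0.5` is that dot products of random vectors grow in magnitude with the embedding dimension, which would push the softmax toward very peaked outputs. The sketch below illustrates this with standard-normal vectors; the seed, sample count, and dimensions are arbitrary choices for illustration, not values from this chapter.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability

raw_std = {}
scaled_std = {}
for d_k in (2, 64, 512):
    q = torch.randn(10_000, d_k)
    k = torch.randn(10_000, d_k)
    dots = (q * k).sum(dim=-1)  # one dot product per row
    raw_std[d_k] = dots.std().item()
    scaled_std[d_k] = (dots / d_k**0.5).std().item()
    print(f"d_k={d_k:4d}  std(raw)={raw_std[d_k]:.2f}  std(scaled)={scaled_std[d_k]:.2f}")
```

The raw standard deviation grows roughly like the square root of `d_k`, while the scaled version stays close to 1 regardless of dimension.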


In [219]:
# compute context vector
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


#### 3.4.2 Implementing a compact self-attention Python class

In [220]:
import torch.nn as nn


class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        # Note: reseeding before each torch.rand call makes the three weight matrices identical
        torch.manual_seed(123)
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        torch.manual_seed(123)
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        torch.manual_seed(123)
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T  # omega
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [221]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.3306, 1.1527],
        [0.3371, 1.1767],
        [0.3368, 1.1755],
        [0.3262, 1.1349],
        [0.3245, 1.1272],
        [0.3301, 1.1505]], grad_fn=<MmBackward0>)


In [222]:
# We can improve the SelfAttention_v1 implementation further by utilizing PyTorch’s nn.Linear layers, which effectively perform matrix multiplication when the bias units are disabled. Additionally, a significant advantage of using nn.Linear instead of manually implementing nn.Parameter(torch.rand(...)) is that nn.Linear has an optimized weight initialization scheme, contributing to more stable and effective model training.
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        # to reproduce SelfAttention_v1 results
        torch.manual_seed(123)
        self.W_query.weight.data = nn.Parameter(torch.rand(d_in, d_out)).T
        torch.manual_seed(123)
        self.W_key.weight.data = nn.Parameter(torch.rand(d_in, d_out)).T
        torch.manual_seed(123)
        self.W_value.weight.data = nn.Parameter(torch.rand(d_in, d_out)).T

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T  # omega
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [223]:
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[0.3306, 1.1527],
        [0.3371, 1.1767],
        [0.3368, 1.1755],
        [0.3262, 1.1349],
        [0.3245, 1.1272],
        [0.3301, 1.1505]], grad_fn=<MmBackward0>)
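
The `.T` in `SelfAttention_v2` is needed because `nn.Linear` stores its weight with shape `(d_out, d_in)` and computes `x @ weight.T`. A minimal sketch of that relationship, using a made-up input `x` and a weight matrix `W` introduced here for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, d_out = 3, 2

W = torch.rand(d_in, d_out)  # layout used by SelfAttention_v1: (d_in, d_out)
linear = nn.Linear(d_in, d_out, bias=False)
print("nn.Linear weight shape:", linear.weight.shape)  # (d_out, d_in)

# Copy W into the layer in the transposed layout that nn.Linear expects
with torch.no_grad():
    linear.weight.copy_(W.T)

x = torch.rand(6, d_in)  # made-up batch of 6 token embeddings
print(torch.allclose(x @ W, linear(x)))
```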


In [224]:
# torch.manual_seed(789)
# sa_v2 = SelfAttention_v2(d_in, d_out)
# print(sa_v2(inputs))

### 3.5 Hiding future words with causal attention

In [225]:
# Causal attention, also known as masked attention, restricts a model to consider only the previous and current inputs in a sequence when computing attention scores for any given token. This is in contrast to the standard self-attention mechanism, which allows access to the entire input sequence at once.



#### 3.5.1 Applying a causal attention mask

In [226]:
# Step 1: compute the attention weights using the softmax function

queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
print(attn_weights)

tensor([[0.1615, 0.2204, 0.2168, 0.1283, 0.1161, 0.1569],
        [0.1565, 0.2404, 0.2353, 0.1154, 0.1016, 0.1508],
        [0.1567, 0.2395, 0.2344, 0.1160, 0.1023, 0.1511],
        [0.1631, 0.2065, 0.2040, 0.1380, 0.1286, 0.1598],
        [0.1630, 0.2008, 0.1988, 0.1421, 0.1348, 0.1606],
        [0.1615, 0.2187, 0.2153, 0.1295, 0.1178, 0.1571]],
       grad_fn=<SoftmaxBackward0>)


In [227]:
# Step 2: use PyTorch’s tril function to create a mask where the values above the diagonal are zero:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [228]:
masked_simple = attn_weights * mask_simple
print(masked_simple)

tensor([[0.1615, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1565, 0.2404, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1567, 0.2395, 0.2344, 0.0000, 0.0000, 0.0000],
        [0.1631, 0.2065, 0.2040, 0.1380, 0.0000, 0.0000],
        [0.1630, 0.2008, 0.1988, 0.1421, 0.1348, 0.0000],
        [0.1615, 0.2187, 0.2153, 0.1295, 0.1178, 0.1571]],
       grad_fn=<MulBackward0>)


In [229]:
# Step 3: renormalize the attention weights so that each row sums to 1 again.
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3942, 0.6058, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2485, 0.3798, 0.3718, 0.0000, 0.0000, 0.0000],
        [0.2292, 0.2902, 0.2867, 0.1939, 0.0000, 0.0000],
        [0.1942, 0.2392, 0.2369, 0.1693, 0.1605, 0.0000],
        [0.1615, 0.2187, 0.2153, 0.1295, 0.1178, 0.1571]],
       grad_fn=<DivBackward0>)


In [230]:
# A more efficient alternative: mask the attention scores before applying softmax, so no renormalization step is needed.
# We create a mask with 1s above the diagonal and replace the attention scores at those positions with negative infinity (-inf):
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[1.2559,   -inf,   -inf,   -inf,   -inf,   -inf],
        [1.6951, 2.3026,   -inf,   -inf,   -inf,   -inf],
        [1.6722, 2.2722, 2.2421,   -inf,   -inf,   -inf],
        [0.9305, 1.2640, 1.2472, 0.6938,   -inf,   -inf],
        [0.7889, 1.0838, 1.0700, 0.5948, 0.5200,   -inf],
        [1.2143, 1.6432, 1.6211, 0.9020, 0.7681, 1.1753]],
       grad_fn=<MaskedFillBackward0>)


In [231]:
# Now all we need to do is apply the softmax function to these masked results, and we are done:
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3942, 0.6058, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2485, 0.3798, 0.3718, 0.0000, 0.0000, 0.0000],
        [0.2292, 0.2902, 0.2867, 0.1939, 0.0000, 0.0000],
        [0.1942, 0.2392, 0.2369, 0.1693, 0.1605, 0.0000],
        [0.1615, 0.2187, 0.2153, 0.1295, 0.1178, 0.1571]],
       grad_fn=<SoftmaxBackward0>)
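
The two masking approaches give the same weights because softmax maps `-inf` to exactly 0, and renormalizing the surviving entries is the same as never including the masked ones in the sum. A small sketch on made-up scores (random values, arbitrary seed, and without the `d_k**0.5` scaling, which does not affect the comparison):

```python
import torch

torch.manual_seed(0)  # arbitrary seed
context_length = 4
attn_scores = torch.randn(context_length, context_length)  # made-up scores

# Approach 1: softmax, zero out the future positions, renormalize each row
weights = torch.softmax(attn_scores, dim=-1)
tril = torch.tril(torch.ones(context_length, context_length))
masked_weights = weights * tril
approach_1 = masked_weights / masked_weights.sum(dim=-1, keepdim=True)

# Approach 2: fill future positions with -inf, then a single softmax
triu = torch.triu(torch.ones(context_length, context_length), diagonal=1)
approach_2 = torch.softmax(attn_scores.masked_fill(triu.bool(), -torch.inf), dim=-1)

print(approach_1)
print(approach_2)
print(torch.allclose(approach_1, approach_2))
```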


#### 3.5.2 Masking additional attention weights with dropout

In [232]:
# We will apply the dropout mask after computing the attention weights.
# With a dropout rate of 0.5, each element is set to zero with probability 0.5, so roughly half of the elements in the matrix are dropped. To compensate for the reduction in active elements, the values of the remaining elements are scaled up by a factor of 1/0.5 = 2.
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


In [233]:
# Now let’s apply dropout to the attention weight matrix itself:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4969, 0.7595, 0.7436, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5803, 0.5735, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4784, 0.0000, 0.3385, 0.0000, 0.0000],
        [0.0000, 0.4374, 0.4307, 0.2590, 0.2356, 0.0000]],
       grad_fn=<MulBackward0>)
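
Dropout is only active in training mode. The sketch below, which reuses the all-ones example from above, shows that after calling `.eval()` the same layer passes its input through unchanged, so no attention weights are dropped at inference time.

```python
import torch

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)

dropout.train()  # training mode (the default): elements are zeroed, survivors are scaled
train_out = dropout(example)

dropout.eval()   # evaluation mode: dropout acts as the identity
eval_out = dropout(example)

print(train_out)
print(eval_out)
```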


#### 3.5.3 Implementing a compact causal attention class

In [234]:
# A compact causal attention class
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec


# Notes on self.register_buffer:
# The use of register_buffer in PyTorch is not strictly necessary for all use cases but offers several advantages here. For instance, when we use the CausalAttention class in our LLM, buffers are automatically moved to the appropriate device (CPU or GPU) along with our model, which will be relevant when training our LLM. This means we don’t need to manually ensure these tensors are on the same device as our model parameters, avoiding device mismatch errors.
# If a model has tensors that should be saved and restored in the state_dict, but not trained by the optimizer, they should be registered as buffers.
# Buffers are not returned by model.parameters(), so the optimizer has no chance to update them.
# 1. One reason to register a tensor as a buffer is to be able to serialize the model and restore all internal states.
# 2. Another reason is to make sure that the tensor is moved to the right device when calling model.to(device), model.cuda(), or model.cpu().
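
The buffer behavior described above can be inspected directly. This is a minimal sketch using a hypothetical toy module, `MaskHolder`, which is not part of the chapter's code; it only exists to show where a registered buffer does and does not appear.

```python
import torch
import torch.nn as nn


class MaskHolder(nn.Module):  # hypothetical toy module for illustration
    def __init__(self, context_length):
        super().__init__()
        self.proj = nn.Linear(context_length, context_length, bias=False)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )


module = MaskHolder(4)

param_names = [name for name, _ in module.named_parameters()]
buffer_names = [name for name, _ in module.named_buffers()]
state_keys = list(module.state_dict().keys())

print("parameters:", param_names)   # what an optimizer would see
print("buffers:   ", buffer_names)  # moved by .to(device), not trained
print("state_dict:", state_keys)    # both are saved and restored
```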

In [235]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [236]:
# We can use the CausalAttention class as follows, similar to SelfAttention previously:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])
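
One way to check that the mask does what it should: changing the last token must not change the context vectors of any earlier position. The sketch below repeats the `CausalAttention` class from above so that it runs on its own; the random input `x` and its altered copy `x_changed` are made up for illustration.

```python
import torch
import torch.nn as nn


class CausalAttention(nn.Module):  # repeated from the cell above
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ values


torch.manual_seed(123)
ca = CausalAttention(d_in=3, d_out=2, context_length=6, dropout=0.0)

x = torch.rand(1, 6, 3)           # made-up batch with one sequence of 6 tokens
x_changed = x.clone()
x_changed[0, -1] = torch.rand(3)  # alter only the last token

out = ca(x)
out_changed = ca(x_changed)

earlier_unchanged = torch.allclose(out[0, :-1], out_changed[0, :-1])
last_unchanged = torch.allclose(out[0, -1], out_changed[0, -1])
print("Earlier positions unchanged:", earlier_unchanged)
print("Last position unchanged:    ", last_unchanged)
```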


### 3.6 Extending single-head attention to multi-head attention

1. First, we will intuitively build a multi-head attention module by stacking multiple `CausalAttention` modules  
2. Then we will implement the same multi-head attention module in a more complicated but more computationally efficient way.  

#### 3.6.1 Stacking multiple single-head attention layers

In [237]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
                for _ in range(num_heads)
            ]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [238]:
# example
torch.manual_seed(123)
context_length = batch.shape[1]  # This is the number of tokens
d_in, d_out = 3, 2

mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


In [239]:
# MultiHeadAttentionWrapper combines multiple single-head attention modules. However, these are processed sequentially via [head(x) for head in self.heads] in the forward method. We can improve this implementation by processing the heads in parallel.

#### 3.6.2 Implementing multi-head attention with weight splits

In [240]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # reshaping the keys, queries, and values tensors to have the shape (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        # transposing the keys, queries, and values tensors to have the shape (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        # calculating the attention scores
        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        # calculate the attention weights
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # calculating the context vectors
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        # projecting the context vectors to the output dimension
        context_vec = self.out_proj(context_vec)
        return context_vec

In [241]:
# example
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
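
The `view` and `transpose` calls in `forward` are easier to follow on a tiny tensor. The sketch below uses made-up sizes and an `arange` tensor (so every element is identifiable) to show that each head receives a contiguous slice of the last dimension, and that the reverse `transpose`, `contiguous`, and `view` sequence restores the original layout.

```python
import torch

# Made-up sizes for illustration
b, num_tokens, num_heads, head_dim = 1, 3, 2, 2
d_out = num_heads * head_dim

x = torch.arange(b * num_tokens * d_out, dtype=torch.float32).view(b, num_tokens, d_out)

# (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
heads = x.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)
print("heads.shape:", heads.shape)
print("head 0:\n", heads[0, 0])  # first head_dim columns of every token
print("head 1:\n", heads[0, 1])  # next head_dim columns of every token

# Reverse: (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, d_out)
merged = heads.transpose(1, 2).contiguous().view(b, num_tokens, d_out)
print("round trip matches:", torch.equal(merged, x))
```

The `.contiguous()` call is needed because `transpose` returns a non-contiguous view, and `view` requires a compatible memory layout.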


## Summary

New PyTorch features learned in this chapter:  
1. lower triangular matrix  
`mask = torch.tril(torch.ones(context_length, context_length))`  

2. upper triangular matrix  
`mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)`  

3. masked_fill:  
    ```python  
    mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)  
    masked = attn_scores.masked_fill(mask.bool(), -torch.inf)  
    ```


4. register_buffer in nn.Module:
    - The use of register_buffer in PyTorch is not strictly necessary for all use cases but offers several advantages here. For instance, when we use the CausalAttention class in our LLM, buffers are automatically moved to the appropriate device (CPU or GPU) along with our model, which will be relevant when training our LLM. This means we don’t need to manually ensure these tensors are on the same device as our model parameters, avoiding device mismatch errors.  
    - If you have tensors in your model that should be saved and restored in the state_dict, but not trained by the optimizer, you should register them as buffers.  
    - Buffers won’t be returned in model.parameters(), so the optimizer won’t have a chance to update them.  
        - one reason to register the tensor as a buffer is to be able to serialize the model and restore all internal states.  
        - another reason is to make sure that the tensor is moved to the right device when you call model.to(device) or model.cuda() or model.cpu(). Buffers are automatically moved to the right device when you call these functions.  