In [1]:

import torch.nn as nn
import torch


inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
    [0.55, 0.87, 0.66], # journey  (x^2)
    [0.57, 0.85, 0.64], # starts   (x^3)
    [0.22, 0.58, 0.33], # with     (x^4)
    [0.77, 0.25, 0.10], # one      (x^5)
    [0.05, 0.80, 0.55]] # step     (x^6)
)


d_in =3
d_out = 2


class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention with learned projections.

    Queries, keys and values are produced by three independent ``nn.Linear``
    layers, each mapping ``d_in`` input features to ``d_out`` features.

    Args:
        d_in: Size of the last dimension of the input tensor.
        d_out: Size of the last dimension of the queries, keys and values,
            and therefore of the returned context vectors.
        qkv_bias: Whether the three projection layers include a bias term.
    """

    def __init__(self, d_in: int, d_out: int, qkv_bias: bool = False):
        super().__init__()
        # Creation order (key, query, value) is kept as-is: it determines
        # which random draws each layer receives under a fixed seed.
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute one context vector per input position.

        Args:
            x: Tensor of shape ``(..., num_tokens, d_in)``; it must have at
                least two dimensions.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two axes so any leading batch dimensions are
        # preserved (``.T`` would reverse every axis on >2-D input).
        attn_scores = queries @ keys.transpose(-2, -1)

        # Scale by the key width read from the tensor itself rather than
        # from a module-level ``d_out`` that may not match this layer.
        scale = keys.shape[-1] ** 0.5
        attn_weights = torch.softmax(attn_scores / scale, dim=-1)

        return attn_weights @ values
    
# Seed before constructing the module so its randomly initialised
# projection weights (and hence the printed output) are reproducible.
torch.manual_seed(567)
sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)
print(sa_v2(inputs))



"""
- In this section, we are converting the previous self-attention mechanism into a causal self-attention mechanism
- Causal self-attention ensures that the model's prediction for a certain position in a sequence is only dependent on the known outputs at previous positions, not on future positions
- In simpler words, this ensures that each next word prediction should only depend on the preceding words
- To achieve this, for each given token, we mask out the future tokens (the ones that come after the current token in the input text):
"""


# Step 0: recompute the unmasked attention weights by hand, reusing the
# query/key projections of the sa_v2 module built above.
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
# The softmax scale is derived from the key width so that this call and the
# masked softmax further down can never disagree.
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
print("attn_weights:\n",attn_weights)

# Approach 1: mask after softmax, then renormalise.
# A lower-triangular matrix of ones keeps position i's weights for
# positions <= i and zeroes everything to the right of the diagonal.
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

masked_simple = attn_weights * mask_simple
print(masked_simple)

# Zeroing entries leaves rows that no longer sum to 1; divide each row by
# its own sum to restore that property.
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

# Approach 2: mask before softmax.
# Ones strictly above the diagonal mark the future positions; filling those
# scores with -inf gives them zero weight after softmax, so the rows come
# out normalised without a separate renormalisation step.
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)


# Same key-width scale as the unmasked softmax above (previously a
# hard-coded 2**0.5).
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=-1)
print(attn_weights )

tensor([[-0.0063, -0.2999],
        [-0.0118, -0.2994],
        [-0.0117, -0.2993],
        [-0.0066, -0.2995],
        [-0.0067, -0.2978],
        [-0.0079, -0.3003]], grad_fn=<MmBackward0>)
attn_weights:
 tensor([[0.1751, 0.1620, 0.1636, 0.1585, 0.1967, 0.1442],
        [0.1668, 0.1576, 0.1603, 0.1576, 0.2246, 0.1331],
        [0.1658, 0.1577, 0.1604, 0.1582, 0.2241, 0.1338],
        [0.1691, 0.1619, 0.1636, 0.1611, 0.1980, 0.1464],
        [0.1486, 0.1622, 0.1636, 0.1714, 0.1972, 0.1570],
        [0.1793, 0.1605, 0.1625, 0.1552, 0.2048, 0.1377]],
       grad_fn=<SoftmaxBackward0>)
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
tensor([[0.1751, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1668, 0.1576, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1658, 0.1577, 0.1604, 0.0000, 0.0000, 0.0000],
        [0.1691, 0.1619, 0.163