In [1]:
import torch
import torch.nn as nn


In [10]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention with learned projections.

    Args:
        d_in: Size of the last dimension of the input (per-token feature size).
        d_out: Size of the query/key/value projections, and of each output row.
        qkv_bias: Whether the three projection layers include a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Creation order (query, key, value) is kept as-is: it determines how
        # the global RNG is consumed during weight initialisation.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, batched,
                ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``; each row is the
            attention-weighted sum of the value vectors.
        """
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        # Swap only the last two axes. Unlike `.T`, this is a proper matrix
        # transpose for batched (3-D and higher) inputs as well as 2-D ones.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalising each query's row of scores.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        context_vector = attn_weights @ values
        return context_vector

In [11]:
# Toy sentence used as the running example. Every word is unique, so each
# word's vocabulary id is simply its position in the sentence.
tokenized_input = ["Your", "journey", "starts", "with", "one", "step"]
vocab = {word: idx for idx, word in enumerate(tokenized_input)}
# Number of features in each token's embedding vector.
embedding_dim = 3

In [12]:
# Random embeddings
# Seed torch's global RNG so the randomly drawn embedding table is reproducible.
# NOTE(review): no later cell in view re-seeds, so subsequent random draws
# (e.g. layer weight initialisation) presumably also follow from this seed;
# confirm before reordering or inserting cells.
torch.manual_seed(789)
# One row per vocabulary entry, `embedding_dim` standard-normal values per row.
embeddings = torch.randn(len(vocab), embedding_dim)

In [13]:
# Map each token to its vocabulary id, then gather the matching embedding
# rows in sentence order (one row per token).
token_ids = torch.tensor([vocab[word] for word in tokenized_input])
inputs = embeddings[token_ids]

In [14]:
# Initialize and apply self-attention
# The layer's input width is read from the data it will receive rather than
# repeated as a literal, so it cannot drift out of sync with the embeddings.
sa_v2 = SelfAttention_v2(d_in=inputs.shape[-1], d_out=2)
# One 2-dimensional context vector per input token.
print(sa_v2(inputs))

tensor([[-0.0427,  0.0316],
        [-0.0348,  0.0876],
        [-0.0490, -0.0365],
        [-0.0355,  0.1067],
        [-0.0369,  0.1218],
        [-0.0342,  0.0994]], grad_fn=<MmBackward0>)


In [16]:
# Recompute the plain (not yet masked) attention weights step by step; they
# are the starting point for the causal masking that follows.
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
# Pairwise query/key dot products: row i holds token i's score for every token.
attn_scores = torch.matmul(queries, keys.T)
# Scale by sqrt(key dimension), then normalise each row into weights.
attn_weights = (attn_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[0.1670, 0.1690, 0.1660, 0.1656, 0.1644, 0.1679],
        [0.1379, 0.1451, 0.1693, 0.2051, 0.1718, 0.1708],
        [0.2051, 0.1875, 0.1616, 0.1281, 0.1620, 0.1557],
        [0.1309, 0.1469, 0.1672, 0.2135, 0.1641, 0.1773],
        [0.1259, 0.1504, 0.1648, 0.2189, 0.1561, 0.1840],
        [0.1330, 0.1433, 0.1688, 0.2121, 0.1699, 0.1729]],
       grad_fn=<SoftmaxBackward0>)


In [18]:
# Lower-triangular 0/1 mask: ones on and below the diagonal, zeros above it
context_length = attn_scores.shape[0]
mask_simple = torch.ones(context_length, context_length).tril()
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [20]:
# Element-wise product with the 0/1 mask: weights above the diagonal become 0,
# the rest are left unchanged (rows no longer sum to 1 after this step)
masked_simple = torch.mul(attn_weights, mask_simple)
print(masked_simple)

tensor([[0.1670, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1379, 0.1451, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2051, 0.1875, 0.1616, 0.0000, 0.0000, 0.0000],
        [0.1309, 0.1469, 0.1672, 0.2135, 0.0000, 0.0000],
        [0.1259, 0.1504, 0.1648, 0.2189, 0.1561, 0.0000],
        [0.1330, 0.1433, 0.1688, 0.2121, 0.1699, 0.1729]],
       grad_fn=<MulBackward0>)


In [21]:
# Renormalize: divide every row by its own sum so the remaining weights add
# up to 1 again (keepdim=True keeps a column shape that broadcasts per row)
row_sums = torch.sum(masked_simple, dim=-1, keepdim=True)
masked_simple_norm = torch.div(masked_simple, row_sums)
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4873, 0.5127, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3701, 0.3384, 0.2915, 0.0000, 0.0000, 0.0000],
        [0.1988, 0.2231, 0.2539, 0.3242, 0.0000, 0.0000],
        [0.1543, 0.1843, 0.2020, 0.2682, 0.1913, 0.0000],
        [0.1330, 0.1433, 0.1688, 0.2121, 0.1699, 0.1729]],
       grad_fn=<DivBackward0>)


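Using the negative-infinity trick to implement masking more efficiently: the scores above the diagonal are set to $-\infty$ *before* the softmax, so they receive exactly zero weight and every row is normalized in a single step, with no separate renormalization.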

In [23]:
# Build the causal mask on the raw scores: ones strictly above the diagonal
# mark the positions to hide, and those scores are replaced by -inf
mask = torch.ones(context_length, context_length).triu(diagonal=1)
masked = attn_scores.masked_fill(mask.to(torch.bool), float("-inf"))
print(masked)

tensor([[ 0.0025,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.2798, -0.2078,    -inf,    -inf,    -inf,    -inf],
        [ 0.3394,  0.2129,  0.0021,    -inf,    -inf,    -inf],
        [-0.3581, -0.1953, -0.0126,  0.3334,    -inf,    -inf],
        [-0.4171, -0.1660, -0.0364,  0.3649, -0.1134,    -inf],
        [-0.3330, -0.2270,  0.0043,  0.3274,  0.0140,  0.0385]],
       grad_fn=<MaskedFillBackward0>)


In [24]:
# Scale the masked scores and take the softmax over the last axis (each row).
# exp(-inf) is 0, so hidden positions get zero weight and each row already
# sums to 1. dim=-1 matches the other softmax calls in this notebook and stays
# correct if a leading batch axis is ever added.
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4873, 0.5127, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3701, 0.3384, 0.2915, 0.0000, 0.0000, 0.0000],
        [0.1988, 0.2231, 0.2539, 0.3242, 0.0000, 0.0000],
        [0.1543, 0.1843, 0.2020, 0.2682, 0.1913, 0.0000],
        [0.1330, 0.1433, 0.1688, 0.2121, 0.1699, 0.1729]],
       grad_fn=<SoftmaxBackward0>)
