In [1]:
import torch

# Toy embeddings for the example sentence "Your journey starts with one step":
# one 3-dimensional row per token, giving a tensor with 6 rows and 3 columns.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],  # step     (x^6)
    ]
)

In [8]:
import torch.nn as nn


# using nn.Linear() as way to define the weight matrices
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    The query, key and value projections are ``nn.Linear`` layers, so the
    weight matrices are created and initialised by PyTorch.

    Args:
        d_in: Size of each input token embedding (last dimension of ``x``).
        d_out: Size of each projected query/key/value vector, and therefore
            of each returned context vector.
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # Trainable projections; creation order is kept (query, key, value)
        # so seeded initialisation stays the same.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, batched,
                ``(..., num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``d_out``.
        """
        # Project the inputs, e.g. (6, 3) @ (3, 2) -> (6, 2).
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two axes: identical to ``keys.T`` for 2-D input,
        # but also correct when leading batch dimensions are present.
        attention_scores = queries @ keys.transpose(-2, -1)

        # Scale by sqrt(d_k) and normalise over the key axis so that every
        # query's weights sum to 1.
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1] ** 0.5, dim=-1
        )

        # Weighted sum of the value vectors.
        context_vectors = attention_weights @ values
        return context_vectors


# build a SelfAttention_v2 module (3-dim inputs -> 2-dim context vectors)
# and run it on the toy inputs
d_in, d_out = 3, 2
sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out, qkv_bias=False)
context_vectors = sa_v2(inputs)
print(context_vectors)

tensor([[-0.4829, -0.0462],
        [-0.4823, -0.0540],
        [-0.4823, -0.0539],
        [-0.4827, -0.0537],
        [-0.4830, -0.0520],
        [-0.4825, -0.0543]], grad_fn=<MmBackward0>)


In [11]:
# causal attention, step 1: plain (unmasked) attention weights, computed from
# the query/key projections of the sa_v2 module created earlier
keys = sa_v2.W_key(inputs)
queries = sa_v2.W_query(inputs)
attn_scores = torch.matmul(queries, keys.T)
# scale by sqrt(d_k), then normalise each row (one row per query token)
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
print(attn_weights)

tensor([[0.1718, 0.1573, 0.1567, 0.1766, 0.1560, 0.1816],
        [0.1542, 0.1584, 0.1578, 0.1835, 0.1572, 0.1888],
        [0.1545, 0.1586, 0.1580, 0.1831, 0.1574, 0.1884],
        [0.1574, 0.1626, 0.1622, 0.1765, 0.1619, 0.1793],
        [0.1623, 0.1641, 0.1639, 0.1722, 0.1637, 0.1738],
        [0.1548, 0.1605, 0.1600, 0.1805, 0.1596, 0.1846]],
       grad_fn=<SoftmaxBackward0>)


In [21]:
# causal attention, step 2: zero the attention weights above the diagonal by
# multiplying element-wise with a lower-triangular matrix of ones
block_size = attn_scores.size(0)
mask_simple = torch.ones((block_size, block_size))
simple_mask = mask_simple.tril()
print(simple_mask)

masked_attn_weights = attn_weights.mul(simple_mask)
print(masked_attn_weights)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
tensor([[0.1718, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1542, 0.1584, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1545, 0.1586, 0.1580, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.1626, 0.1622, 0.1765, 0.0000, 0.0000],
        [0.1623, 0.1641, 0.1639, 0.1722, 0.1637, 0.0000],
        [0.1548, 0.1605, 0.1600, 0.1805, 0.1596, 0.1846]],
       grad_fn=<MulBackward0>)


In [22]:
# causal attention, step 3: renormalise the masked weights so that every row
# sums to 1 again
row_sums = masked_attn_weights.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_attn_weights.div(row_sums)
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4933, 0.5067, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3279, 0.3367, 0.3354, 0.0000, 0.0000, 0.0000],
        [0.2390, 0.2468, 0.2463, 0.2680, 0.0000, 0.0000],
        [0.1964, 0.1986, 0.1984, 0.2084, 0.1982, 0.0000],
        [0.1548, 0.1605, 0.1600, 0.1805, 0.1596, 0.1846]],
       grad_fn=<DivBackward0>)


In [31]:
# more efficient implementation of causal attention:
# set the scores above the diagonal to -inf *before* softmax, so they come out
# as exactly 0 and no separate renormalisation step is needed
mask = torch.ones(block_size, block_size).triu(diagonal=1)
print(mask)
masked = attn_scores.masked_fill(mask.bool(), float("-inf"))
print("apply inf to masks:", masked)
print("attn scores without mask:", attn_scores)
print("attn scores after masking:", masked)

# softmax over each row of the scaled, masked scores
attn_weights = (masked / keys.shape[-1] ** 0.5).softmax(dim=-1)
print("attn weights after normalization:", attn_weights)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
apply inf to masks: tensor([[-0.1928,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.3860, -0.3479,    -inf,    -inf,    -inf,    -inf],
        [-0.3781, -0.3406, -0.3463,    -inf,    -inf,    -inf],
        [-0.2345, -0.1888, -0.1920, -0.0723,    -inf,    -inf],
        [-0.1290, -0.1131, -0.1149, -0.0451, -0.1165,    -inf],
        [-0.3241, -0.2726, -0.2771, -0.1067, -0.2809, -0.0749]],
       grad_fn=<MaskedFillBackward0>)
attn scores without mask: tensor([[-0.1928, -0.3181, -0.3234, -0.1538, -0.3291, -0.1143],
        [-0.3860, -0.3479, -0.3537, -0.1406, -0.3587, -0.0996],
        [-0.3781, -0.3406, -0.3463, -0.1376, -0.3511, -0.0975],
        [-0.2345, -0.1888, -0.1920, -0.0723, -0.1945, -0.0505],
        [-0.1290, -0.1131, -0.1149, -0.0451, -0.1165, -0.0319],
      

In [29]:
# dropout on top of causal attention, used to reduce overfitting
torch.manual_seed(123)
dropout = nn.Dropout(0.5)
example = torch.ones((6, 6))
print("before dropout:", example)
# with p=0.5 the surviving entries are rescaled (1 -> 2 in the recorded output)
print("after dropout:", dropout(example))

# same dropout layer applied to the causal attention weights
print(dropout(attn_weights))

before dropout: tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
after dropout: tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6707, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4936, 0.0000, 0.5360, 0.0000, 0.0000],
        [0.0000, 0.3973, 0.3968, 0.4168, 0.3963, 0.0000],
        [0.3096, 0.3210, 0.0000, 0.0000, 0.3192, 0.3692]],
       grad_fn=<MulBackward0>)


In [30]:
# build a minimal batch of two identical sequences by stacking `inputs`
# along a new leading dimension
batch = torch.stack([inputs, inputs])
print(batch.shape)

torch.Size([2, 6, 3])


In [32]:
# complete python class implementing causal attention with dropout


class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention with dropout.

    Each token attends only to itself and to earlier tokens in the sequence.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of each projected query/key/value vector, and therefore
            of each returned context vector.
        block_size: Side length of the pre-built causal mask. The mask is
            sliced to the actual sequence length in ``forward``, so inputs
            are expected to have at most ``block_size`` tokens.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):
        super().__init__()
        # Honour ``qkv_bias``; it was previously accepted but never passed
        # on, so the layers always had a bias.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(p=dropout)
        # Upper-triangular ones (diagonal excluded) mark the "future"
        # positions. Registered as a buffer so it moves with the module
        # (``.to(device)``) without being a trainable parameter.
        self.register_buffer(
            "mask", torch.triu(torch.ones(block_size, block_size), diagonal=1)
        )

    def forward(self, x):
        """Compute causal context vectors for a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        _, num_token, _ = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # Shape trace kept as-is: the notebook's recorded output shows it.
        print("queries shape:", queries.shape)
        print("keys shape:", keys.shape)

        # (batch, T, d_out) @ (batch, d_out, T) -> (batch, T, T)
        attn_scores = queries @ keys.transpose(1, 2)
        # Future positions get -inf so softmax assigns them exactly 0.
        attn_scores.masked_fill_(self.mask[:num_token, :num_token].bool(), -torch.inf)
        # Normalise over the LAST axis (keys). With a leading batch dimension,
        # dim=1 would normalise over the query axis instead, so rows would not
        # sum to 1 and later tokens would influence earlier outputs.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vectors = attn_weights @ values
        return context_vectors


# seed before construction so the layer initialisation is reproducible
torch.manual_seed(123)
block_size = batch.size(1)  # number of tokens per sequence
d_in, d_out = 3, 2
ca = CausalAttention(d_in=d_in, d_out=d_out, block_size=block_size, dropout=0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)
# context_vecs.shape: torch.Size([2, 6, 2])

queries shape: torch.Size([2, 6, 2])
keys shape: torch.Size([2, 6, 2])
context_vecs.shape: torch.Size([2, 6, 2])
