In [1]:
import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [2]:
import torch.nn as nn

In [3]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in: int, d_out: int, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        print("x.shape is ", x.shape)
        keys = self.W_key(x)
        print("keys.shape is ", keys.shape)
        queries = self.W_query(x)
        print("queries.shape is ", queries.shape)
        values = self.W_value(x)
        print("values.shape is ", values.shape)

        attn_scores = queries @ keys.T
        print("attn_scores.shape is ", attn_scores.shape)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        print(attn_weights)
        self.attn_scores = attn_scores
        context_vec = attn_weights @ values
        return context_vec

In [4]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in=3, d_out=2)
print(sa_v2(inputs))

x.shape is  torch.Size([6, 3])
keys.shape is  torch.Size([6, 2])
queries.shape is  torch.Size([6, 2])
values.shape is  torch.Size([6, 2])
attn_scores.shape is  torch.Size([6, 6])
tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


In [5]:
pre_causual_attn_weights = torch.Tensor(
    [
        [0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529],
    ]
)

In [6]:
pre_causual_attn_weights.shape

torch.Size([6, 6])

In [7]:
context_length = pre_causual_attn_weights.shape[0]  #6
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [8]:
masked_simple = pre_causual_attn_weights * mask_simple
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]])


In [9]:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
print(row_sums)

tensor([[0.1921],
        [0.3700],
        [0.5357],
        [0.6775],
        [0.8415],
        [1.0001]])


In [10]:
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5516, 0.4484, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3801, 0.3097, 0.3102, 0.0000, 0.0000, 0.0000],
        [0.2759, 0.2461, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1985, 0.1887, 0.1970, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]])


In [11]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
print(mask)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])


In [12]:
print(mask.bool())

tensor([[False,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True],
        [False, False, False, False,  True,  True],
        [False, False, False, False, False,  True],
        [False, False, False, False, False, False]])


In [13]:
masked = pre_causual_attn_weights.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.1921,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.2041, 0.1659,   -inf,   -inf,   -inf,   -inf],
        [0.2036, 0.1659, 0.1662,   -inf,   -inf,   -inf],
        [0.1869, 0.1667, 0.1668, 0.1571,   -inf,   -inf],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658,   -inf],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]])


In [14]:
pre_causual_attn_weights = torch.softmax(
    masked / 2**0.5,
    dim=1
)
print(pre_causual_attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5068, 0.4932, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3393, 0.3303, 0.3304, 0.0000, 0.0000, 0.0000],
        [0.2531, 0.2495, 0.2495, 0.2478, 0.0000, 0.0000],
        [0.2021, 0.1998, 0.1998, 0.1987, 0.1996, 0.0000],
        [0.1698, 0.1666, 0.1666, 0.1652, 0.1666, 0.1650]])


In [15]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)    # We choose a dropout rate of 50%
example = torch.ones(6, 6)      # we create a matrix of 1s.
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


In [16]:
torch.manual_seed(123)
print(dropout(pre_causual_attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.9865, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6608, 0.0000, 0.0000, 0.0000],
        [0.5062, 0.4990, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4042, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3332, 0.0000, 0.0000, 0.0000, 0.0000]])


In [17]:
print(inputs.shape)
batch = torch.stack((inputs, inputs), dim=0)
# two input - each has 6 tokens, and each token has 3 dimensions
print(batch.shape)

torch.Size([6, 3])
torch.Size([2, 6, 3])


In [18]:
class CasualAttention(nn.Module):
    def __init__(self, d_in: int, d_out: int, context_length: int, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(p=dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

        
