In [1]:
import torch
import torch.nn as nn

In [2]:
# Each token is paired with its 3-dimensional embedding so the label and the
# vector cannot drift apart.
labelled_embeddings = [
    ('Dream', [0.72, 0.45, 0.31]),
    ('big',   [0.75, 0.20, 0.55]),
    ('and',   [0.30, 0.80, 0.40]),
    ('work',  [0.85, 0.35, 0.60]),
    ('for',   [0.55, 0.15, 0.75]),
    ('it',    [0.25, 0.20, 0.85]),
]

# Token labels, in sequence order.
words = [token for token, _ in labelled_embeddings]

# One row per token, one column per embedding dimension.
inputs = torch.tensor([vector for _, vector in labelled_embeddings])


In [3]:
# Projection sizes: the input width is read from the embedding matrix, the
# output width is chosen by hand.
d_in = inputs.size(1)
d_out = 2

# Second token's embedding (row index 1), used as the running example.
x_2 = inputs[1]

In [4]:
torch.manual_seed(123)

# Three frozen (d_in, d_out) projection matrices. The generator is consumed in
# order during unpacking, so the random draws happen query -> key -> value.
W_query, W_key, W_value = (
    nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [5]:
# Project the second token's embedding into query, key and value space.
query_2 = torch.matmul(x_2, W_query)
key_2 = torch.matmul(x_2, W_key)
value_2 = torch.matmul(x_2, W_value)

print(query_2)

tensor([0.3131, 1.0017])


In [6]:
# Project every token at once: each result has one row per token.
queries = torch.matmul(inputs, W_query)
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
print("queries.shape:", queries.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])
queries.shape: torch.Size([6, 2])


In [7]:
# Unnormalised attention score of the second token with itself:
# the dot product of its query with its own key.
keys_2 = keys[1]
attn_score_22 = torch.dot(query_2, keys_2)
print(attn_score_22)

tensor(0.6990)


In [8]:
# Scores of the second token's query against every key in the sequence.
attn_scores_2 = torch.matmul(query_2, keys.T)
print(attn_scores_2)

tensor([0.7021, 0.6990, 0.9867, 0.8707, 0.7880, 0.8624])


In [9]:
# Full score matrix: entry (i, j) is query i dotted with key j.
attn_scores = torch.matmul(queries, keys.transpose(0, 1))
print(attn_scores)

tensor([[0.6807, 0.6795, 0.9526, 0.8454, 0.7654, 0.8359],
        [0.7021, 0.6990, 0.9867, 0.8707, 0.7880, 0.8624],
        [0.7350, 0.7315, 1.0337, 0.9113, 0.8248, 0.9029],
        [0.8436, 0.8402, 1.1848, 1.0464, 0.9471, 1.0361],
        [0.7080, 0.7025, 1.0003, 0.8764, 0.7929, 0.8699],
        [0.6680, 0.6606, 0.9486, 0.8254, 0.7465, 0.8210]])


In [10]:
# Width of a key vector; scores are divided by its square root before softmax.
d_k = keys.size(-1)

scaled_scores_2 = attn_scores_2 / d_k**0.5
attn_weights_2 = scaled_scores_2.softmax(dim=-1)

print(attn_weights_2)
print(d_k)

tensor([0.1531, 0.1528, 0.1873, 0.1725, 0.1627, 0.1715])
2


In [11]:
# Context vector for the second token: attention-weighted sum of all value rows.
context_vec_2 = torch.matmul(attn_weights_2, values)
print(context_vec_2)

tensor([0.2274, 0.7362])


In [12]:
# Same computation as for the single token, now for every row at once:
# scale the scores, normalise each row, then mix the value vectors.
scaled_scores = attn_scores / d_k**0.5
attn_weights = scaled_scores.softmax(dim=-1)

context_vec = torch.matmul(attn_weights, values)

print(context_vec)

tensor([[0.2273, 0.7361],
        [0.2274, 0.7362],
        [0.2276, 0.7363],
        [0.2280, 0.7368],
        [0.2275, 0.7362],
        [0.2275, 0.7360]])


In [13]:
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The query, key and value projections are trainable ``nn.Parameter``
    matrices of shape ``(d_in, d_out)`` initialised with ``torch.rand``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the projections and of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, optionally with leading
                batch dimensions, i.e. ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys    = x @ self.W_key
        queries = x @ self.W_query
        values  = x @ self.W_value

        # Swap only the last two axes. `keys.T` reverses *all* axes, which is
        # equivalent for 2-D input but breaks (and is deprecated) for batches.
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5,
            dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec


In [14]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention using ``nn.Linear`` layers.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the projections and of the output.
        qkv_bias: Whether the query, key and value layers carry a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, optionally with leading
                batch dimensions, i.e. ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys    = self.W_key(x)
        queries = self.W_query(x)
        values  = self.W_value(x)

        # Swap only the last two axes. `keys.T` reverses *all* axes, which is
        # equivalent for 2-D input but breaks (and is deprecated) for batches.
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5,
            dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec


In [15]:
# Use the notebook's shared dimension variables rather than repeating the
# literals 3 and 2, so the layer stays in sync with `inputs` if they change.
sa = SelfAttention_v2(d_in=d_in, d_out=d_out)
output = sa(inputs)
print(output)

tensor([[0.5269, 0.2695],
        [0.5274, 0.2714],
        [0.5269, 0.2714],
        [0.5278, 0.2726],
        [0.5277, 0.2733],
        [0.5277, 0.2743]], grad_fn=<MmBackward0>)


In [17]:
# Recompute the intermediate tensors from the layer's own projections so the
# attention weights can be inspected (and later masked) outside the module.
keys = sa.W_key(inputs)
queries = sa.W_query(inputs)

attn_scores = torch.matmul(queries, keys.T)
attn_weights = (attn_scores / keys.shape[-1]**0.5).softmax(dim=-1)

print(attn_weights)

tensor([[0.1767, 0.1661, 0.1764, 0.1692, 0.1582, 0.1535],
        [0.1798, 0.1657, 0.1796, 0.1697, 0.1555, 0.1496],
        [0.1793, 0.1648, 0.1816, 0.1682, 0.1554, 0.1508],
        [0.1818, 0.1655, 0.1819, 0.1700, 0.1538, 0.1472],
        [0.1827, 0.1649, 0.1839, 0.1695, 0.1527, 0.1463],
        [0.1841, 0.1641, 0.1869, 0.1689, 0.1512, 0.1449]],
       grad_fn=<SoftmaxBackward0>)


In [18]:
# Number of tokens in the sequence (rows of the score matrix).
context_length = attn_scores.size(0)

# Lower-triangular matrix of ones: row i keeps columns 0..i and zeroes the rest.
mask_simple = torch.ones(context_length, context_length).tril()

print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [19]:
# Zero out every weight above the diagonal by element-wise multiplication.
masked_simple = torch.mul(attn_weights, mask_simple)
print(masked_simple)

tensor([[0.1767, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1798, 0.1657, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1793, 0.1648, 0.1816, 0.0000, 0.0000, 0.0000],
        [0.1818, 0.1655, 0.1819, 0.1700, 0.0000, 0.0000],
        [0.1827, 0.1649, 0.1839, 0.1695, 0.1527, 0.0000],
        [0.1841, 0.1641, 0.1869, 0.1689, 0.1512, 0.1449]],
       grad_fn=<MulBackward0>)


In [20]:
# Renormalise so each row sums to 1 again after masking; keepdim=True keeps
# the sums as a column so the division broadcasts row by row.
row_sums = torch.sum(masked_simple, dim=1, keepdim=True)
masked_simple_norm = masked_simple.div(row_sums)

print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5203, 0.4797, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3411, 0.3135, 0.3454, 0.0000, 0.0000, 0.0000],
        [0.2601, 0.2367, 0.2601, 0.2431, 0.0000, 0.0000],
        [0.2140, 0.1931, 0.2154, 0.1985, 0.1789, 0.0000],
        [0.1841, 0.1641, 0.1869, 0.1689, 0.1512, 0.1449]],
       grad_fn=<DivBackward0>)


In [21]:
# Raw query-key dot products from the SelfAttention_v2 projections; no scaling
# or masking has been applied to them yet.
print(attn_scores)

tensor([[ 0.0884,  0.0014,  0.0860,  0.0271, -0.0677, -0.1102],
        [ 0.1127, -0.0022,  0.1116,  0.0310, -0.0924, -0.1470],
        [ 0.0795, -0.0397,  0.0974, -0.0108, -0.1228, -0.1649],
        [ 0.1275, -0.0058,  0.1279,  0.0323, -0.1095, -0.1716],
        [ 0.1222, -0.0232,  0.1311,  0.0158, -0.1316, -0.1927],
        [ 0.1160, -0.0471,  0.1368, -0.0065, -0.1629, -0.2232]],
       grad_fn=<MmBackward0>)


In [22]:
# Ones strictly above the diagonal (diagonal=1 excludes the diagonal itself):
# these entries are the ones that get masked out.
mask = torch.ones(context_length, context_length).triu(diagonal=1)

print(mask)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])


In [28]:
# Scale the scores by sqrt(d_k) before masking. The scores are recomputed from
# `queries` and `keys` instead of dividing `attn_scores` by itself, so this
# cell is idempotent: running it a second time no longer shrinks the scores
# again. The key width is read from `keys` rather than hard-coded as 2.
attn_scores = (queries @ keys.T) / keys.shape[-1]**0.5

In [29]:
# Wherever the mask is non-zero, replace the score with -inf so that softmax
# assigns exactly zero weight to that position.
masked = attn_scores.masked_fill(mask.to(torch.bool), -float("inf"))
print(masked)

tensor([[ 0.0442,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.0564, -0.0011,    -inf,    -inf,    -inf,    -inf],
        [ 0.0398, -0.0198,  0.0487,    -inf,    -inf,    -inf],
        [ 0.0637, -0.0029,  0.0639,  0.0161,    -inf,    -inf],
        [ 0.0611, -0.0116,  0.0656,  0.0079, -0.0658,    -inf],
        [ 0.0580, -0.0235,  0.0684, -0.0032, -0.0814, -0.1116]],
       grad_fn=<MaskedFillBackward0>)


In [32]:
# Row-wise softmax over the masked scores; the -inf entries become zeros.
attn_weights = masked.softmax(dim=-1)

In [33]:
# Display the causal attention weights: each row is a softmax over the masked
# scores, so every entry above the diagonal is zero.
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5144, 0.4856, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3389, 0.3193, 0.3419, 0.0000, 0.0000, 0.0000],
        [0.2571, 0.2405, 0.2572, 0.2452, 0.0000, 0.0000],
        [0.2099, 0.1952, 0.2109, 0.1991, 0.1849, 0.0000],
        [0.1790, 0.1650, 0.1809, 0.1684, 0.1557, 0.1511]],
       grad_fn=<SoftmaxBackward0>)

In [34]:
# All-ones 6x6 matrix used as a toy input.
example = torch.ones((6, 6))
print(example)

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])


In [35]:
# Build the toy input and the dropout layer (p=0.5: each element is zeroed
# with probability 0.5 and the survivors are scaled by 1 / (1 - p) = 2).
example = torch.ones((6, 6))
dropout = nn.Dropout(p=0.5)

# Seed right before the only random operation so the printed mask is repeatable.
torch.manual_seed(123)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


In [36]:
# Re-seed so the same dropout pattern as in the all-ones demo is drawn, then
# apply it to the causal attention weights.
torch.manual_seed(123)
dropped_attn_weights = dropout(attn_weights)
print(dropped_attn_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6777, 0.6385, 0.6838, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4811, 0.5143, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3904, 0.0000, 0.3981, 0.0000, 0.0000],
        [0.0000, 0.3300, 0.3617, 0.3367, 0.3114, 0.0000]],
       grad_fn=<MulBackward0>)
