### Simple self-attention mechanism

In [1]:
import torch
# Sample sentence used throughout the walkthrough (six words).
text = "Your journey starts with one step"

# Hand-picked stand-in for an embedding layer's output: six rows, each a
# 3-dimensional vector. Presumably row i belongs to the i-th word of `text`;
# the pairing is by position only -- nothing in the code links them.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
])

In [2]:
# Use the second row of `inputs` (index 1) as the query vector for the
# single-token worked example.
query = inputs[1]

In [3]:
query

tensor([0.5500, 0.8700, 0.6600])

In [4]:
# Unnormalised attention scores for the query token: the dot product of the
# query with every input row (including itself), collected into a 1-D tensor
# with one entry per row of `inputs`.
# Uses the `query` variable rather than a hard-coded `inputs[1]`, so changing
# the query token in one place changes it here as well.
attn_scores_2 = torch.stack([torch.dot(query, x_i) for x_i in inputs])

In [5]:
attn_scores_2

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [6]:
# Normalise the scores with softmax over the only dimension so the weights
# are positive and sum to 1.
attn_weights_2 = attn_scores_2.softmax(dim=0)

In [7]:
attn_weights_2

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

In [8]:
# %%timeit
# attn_weights_2_2 = torch.cat([attn_weights_2] * 1000)
# context_vector_2 = attn_weights_2_2 @ torch.cat([inputs] * 1000)
#
# NOTE: disabled timing experiment. It concatenates 1000 copies of the weights
# and of the inputs and times the resulting vector-matrix product. To run it,
# uncomment the three lines above; `%%timeit` is a cell magic, so it must stay
# the first line of the cell.

In [9]:
# Context vector for the query token: the weighted sum of all input rows,
# using the attention weights as coefficients (vector @ matrix).
context_vector_2 = torch.matmul(attn_weights_2, inputs)

In [10]:
context_vector_2

tensor([0.4419, 0.6515, 0.5683])

In [11]:
# All pairwise scores at once: entry [i, j] is the dot product of
# inputs[i] and inputs[j].
attn_scores = torch.matmul(inputs, inputs.T)

In [12]:
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [13]:
# numpy stays imported here so that any cell still referencing `np` keeps working.
import numpy as np

# Sanity check: row 1 of the full score matrix must match the scores computed
# one dot product at a time for the same query token. torch.allclose compares
# the tensors natively instead of relying on implicit tensor -> ndarray
# conversion (which breaks for tensors that require grad or live off-CPU).
assert torch.allclose(attn_scores[1], attn_scores_2)

In [14]:
# Softmax along the last dimension: every row of the score matrix is
# normalised independently into a set of weights.
attn_weights = attn_scores.softmax(dim=-1)

In [15]:
attn_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [16]:
# Sanity check: the single-token weights must equal row 1 of the full weight
# matrix. Compared with torch.allclose so no tensor -> ndarray conversion is
# involved.
assert torch.allclose(attn_weights_2, attn_weights[1])

In [17]:
# Sanity check: every row of the weight matrix sums to 1 (within float
# tolerance). torch.allclose needs two tensors, so compare against a tensor of
# ones with the same shape and dtype as the row sums.
_row_sums = torch.sum(attn_weights, dim=-1)
assert torch.allclose(_row_sums, torch.ones_like(_row_sums))

In [18]:
# One context vector per input row: row i is the weighted sum of all input
# rows, using row i of the weight matrix as coefficients.
context_vectors = torch.matmul(attn_weights, inputs)

In [19]:
context_vectors

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

In [20]:
# Sanity check: row 1 of the batched result must match the context vector
# computed for the single query token. Compared with torch.allclose so no
# tensor -> ndarray conversion is involved.
assert torch.allclose(context_vectors[1], context_vector_2)

### Self-attention with trainable weights

In [33]:
# Input dimension is the width of each row of `inputs`; the projections map
# it down to d_out = 2 dimensions.
d_in = inputs.shape[1]
d_out = 2
# Fix the RNG seed so the random weight matrices below are reproducible.
torch.manual_seed(123)
# Key, value and query projection matrices, each of shape (d_in, d_out).
# requires_grad=False: no gradients are tracked for them in this walkthrough.
# NOTE: the three randn calls consume the seeded RNG in this order
# (W_k, W_v, W_q); reordering them changes every value computed from them.
W_k = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_v = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_q = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)

In [34]:
# Second input row (index 1), used as the single-token example below.
x_2 = inputs[1]

In [35]:
# Project the example token x_2 into key, value and query space. Each
# projection uses its own weight matrix: W_k for the key, W_v for the value
# and W_q for the query.
key_2 = x_2 @ W_k
value_2 = x_2 @ W_v  # fixed: previously used W_k, which made value_2 a copy of key_2
query_2 = x_2 @ W_q

In [36]:
query_2

tensor([0.4107, 0.6274])

In [37]:
# Project every input row at once: one matrix product per weight matrix,
# giving the keys, values and queries for all rows of `inputs`.
keys, values, queries = (inputs @ weight for weight in (W_k, W_v, W_q))


In [38]:
# Unnormalised scores: entry [i, j] is the dot product of query i with key j.
attention_scores = torch.matmul(queries, keys.T)

In [39]:
attention_scores

tensor([[-0.2118, -0.1385, -0.1361, -0.0602, -0.0547, -0.0894],
        [-0.3533, -0.4847, -0.4709, -0.2879, -0.0888, -0.4388],
        [-0.3491, -0.4829, -0.4691, -0.2874, -0.0877, -0.4381],
        [-0.2002, -0.2877, -0.2794, -0.1728, -0.0502, -0.2635],
        [-0.1753, -0.3144, -0.3046, -0.1974, -0.0434, -0.3020],
        [-0.2533, -0.3215, -0.3126, -0.1871, -0.0639, -0.2848]])

In [40]:
# Scaled dot-product attention weights: divide the scores by the square root
# of the key dimension (the width of `keys`), then softmax each row.
attention_weights = (attention_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
print(attention_weights.shape)

torch.Size([6, 6])


In [41]:
attention_weights

tensor([[0.1555, 0.1638, 0.1641, 0.1731, 0.1738, 0.1696],
        [0.1660, 0.1512, 0.1527, 0.1738, 0.2001, 0.1562],
        [0.1663, 0.1512, 0.1527, 0.1737, 0.2000, 0.1561],
        [0.1674, 0.1574, 0.1583, 0.1707, 0.1861, 0.1601],
        [0.1719, 0.1559, 0.1569, 0.1693, 0.1888, 0.1572],
        [0.1644, 0.1567, 0.1577, 0.1723, 0.1880, 0.1608]])

In [42]:
attention_weights.sum(dim=1)

tensor([1., 1., 1., 1., 1., 1.])

In [45]:
# Context vectors for all rows: each one is the attention-weighted sum of the
# value vectors.
context_vector = torch.matmul(attention_weights, values)

In [46]:
context_vector[1]

tensor([-0.1405, -0.5932])