In [1]:
import sys
import torch
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Integer token ids for the example sentence, one id per word, in word order.
token_ids = [
    0,  # can
    7,  # you
    1,  # help
    2,  # me
    5,  # to
    6,  # translate
    4,  # this
    3,  # sentence
]
sentence = torch.tensor(token_ids)

sentence

tensor([0, 7, 1, 2, 5, 6, 4, 3])

In [6]:
# Fix the RNG so the randomly initialised embedding table is reproducible.
torch.manual_seed(123)

# Embedding table with 10 rows of 16-dimensional vectors.
embed = torch.nn.Embedding(num_embeddings=10, embedding_dim=16)

# Look up one vector per token id; detach() drops the autograd graph because
# the embeddings are only used as plain data in the cells that follow.
embedded_sentence = embed(sentence).detach()
embedded_sentence.shape

torch.Size([8, 16])

In [4]:
# Unnormalised attention scores: omega[i, j] is the dot product between the
# embeddings of token i and token j. The matrix size is derived from the
# input instead of being hard-coded, so the cell works for any sentence length.
n_tokens = embedded_sentence.shape[0]
omega = torch.empty(n_tokens, n_tokens)

# Explicit double loop, kept deliberately as the reference implementation
# that a vectorised matmul can be checked against.
for i, x_i in enumerate(embedded_sentence):
    for j, x_j in enumerate(embedded_sentence):
        omega[i, j] = torch.dot(x_i, x_j)

In [9]:
# Vectorised form of the pairwise dot products: X @ X^T.
omega_mat = embedded_sentence @ embedded_sentence.T
omega_mat.shape

torch.Size([8, 8])

In [8]:
# The matmul result should match the loop-based scores up to float tolerance.
omega_mat.allclose(omega)

True

In [10]:
# Normalise each row of scores with softmax (dim=1), giving one weight
# distribution per token.
attention_weights = torch.softmax(omega, dim=1)
attention_weights.shape

torch.Size([8, 8])

In [11]:
# Sum the weights along dim=1; softmax over that dim makes each row sum to 1.
torch.sum(attention_weights, dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [12]:
# Context vector for the second token (row index 1): the attention-weighted
# sum of all token embeddings, accumulated one term at a time.
x_2 = embedded_sentence[1, :]
context_vec_2 = torch.zeros(x_2.shape)

# Pair each weight in row 1 with its embedding directly, so the loop follows
# the actual sentence length rather than a hard-coded token count.
for weight_j, x_j in zip(attention_weights[1], embedded_sentence):
    context_vec_2 += weight_j * x_j
print(context_vec_2)

tensor([-9.3975e-01, -4.6856e-01,  1.0311e+00, -2.8192e-01,  4.9373e-01,
        -1.2896e-02, -2.7327e-01, -7.6358e-01,  1.3958e+00, -9.9543e-01,
        -7.1287e-04,  1.2449e+00, -7.8077e-02,  1.2765e+00, -1.4589e+00,
        -2.1601e+00])


In [13]:
# All context vectors at once: row i is the attention-weighted sum of the
# embeddings using the weights in attention_weights[i].
context_vectors = attention_weights @ embedded_sentence

In [14]:
# Row 1 of the batched result should match the loop-built context vector.
context_vec_2.allclose(context_vectors[1])

True

In [16]:
# Embedding dimension, read from the second axis of the embedded sentence.
d = embedded_sentence.size(1)

# Three random d x d projection matrices, drawn in the order
# query -> key -> value so the RNG stream is consumed in a fixed sequence.
U_query, U_key, U_value = (torch.rand(d, d) for _ in range(3))

# Project the second token's embedding (row index 1) into query, key and
# value vectors.
x_2 = embedded_sentence[1]
query_2 = U_query @ x_2
key_2 = U_key @ x_2
value_2 = U_value @ x_2


In [17]:
# Key vectors for every token: project all embeddings at once, then transpose
# back so that row i holds the key for token i.
keys = U_key.matmul(embedded_sentence.T).T
# Only a cell's final expression is displayed, so this intermediate check is
# asserted; as a bare expression its result would be silently discarded.
assert torch.allclose(key_2, keys[1])

# Value vectors for every token, built the same way.
values = U_value.matmul(embedded_sentence.T).T
torch.allclose(value_2, values[1])


True

In [18]:
# Unnormalised scores for the second token: its query dotted with every key.
omega_2 = query_2 @ keys.T
omega_2

tensor([-10.6576,  17.5616,  11.4914,  14.7860,  48.0565,  22.1839,   6.8517,
        -57.4444])

In [19]:
# Scale the scores by sqrt(d) before the softmax (scaled dot-product
# attention), then normalise over the single token axis.
attention_weights_2 = torch.softmax(omega_2 / d**0.5, dim=0)
attention_weights_2

tensor([4.2086e-07, 4.8754e-04, 1.0689e-04, 2.4359e-04, 9.9758e-01, 1.5483e-03,
        3.3512e-05, 3.5022e-12])

In [20]:
# Context vector for the second token: attention-weighted sum of the values.
context_vector_2 = attention_weights_2 @ values
context_vector_2


tensor([-3.5123, -1.2417, -4.7549, -1.6927, -1.9569, -1.6585, -4.0222, -2.3917,
        -1.5843, -1.1964, -1.6968, -4.3375, -2.1805, -3.6918, -4.2858, -3.5367])

torch.Size([8, 16])
