In [115]:
import torch

# Toy "encoder outputs" for four words, each a 3-dimensional float vector.
# word_3 is the sum of word_1 and word_2; word_4 is orthogonal to the others.
word_1, word_2, word_3, word_4 = (
    torch.tensor(vec)
    for vec in (
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [1.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    )
)


In [116]:
# Random 3x3 projection matrices for queries, keys and values.
# Seeding first makes the draws reproducible; they are taken in the order
# W_Q, W_K, W_V, so that order must not change.
torch.manual_seed(42)
W_Q, W_K, W_V = (torch.rand(3, 3) for _ in range(3))

In [117]:
# Project each word into query, key and value space. Words are treated as
# column vectors, so every projection is W @ x and yields a 3-vector.
Q_1, K_1, V_1 = (W @ word_1 for W in (W_Q, W_K, W_V))
Q_2, K_2, V_2 = (W @ word_2 for W in (W_Q, W_K, W_V))
Q_3, K_3, V_3 = (W @ word_3 for W in (W_Q, W_K, W_V))
Q_4, K_4, V_4 = (W @ word_4 for W in (W_Q, W_K, W_V))

In [118]:
# Compare the first word's query with every key via a dot product; the
# individual scalars are kept as scores_1..scores_4 and also stacked
# into a length-4 vector.
scores_1, scores_2, scores_3, scores_4 = (
    torch.dot(Q_1, key) for key in (K_1, K_2, K_3, K_4)
)
scores = torch.stack([scores_1, scores_2, scores_3, scores_4])

In [119]:
# Turn the raw scores into attention weights with a softmax over the four keys.
# Scores are divided by sqrt(d_k), as in scaled dot-product attention. d_k is
# read from the key vectors rather than hard-coded as 3, so the cell stays
# correct if the projection size changes.
d_k = K_1.shape[0]
weights = torch.nn.functional.softmax(scores / torch.sqrt(torch.tensor(float(d_k))), dim=0)

In [120]:
# Blend the value vectors: V_i contributes in proportion to weights[i],
# giving the attention output for the first word.
attention = weights @ torch.stack([V_1, V_2, V_3, V_4])
print(attention)

tensor([0.6075, 0.6090, 0.3037])


In [121]:
# Vectorised version: attention for all four words at once.
# The inputs are re-created and the seed is reset so this cell is self-contained
# and draws the same W_Q, W_K, W_V (same seed, same draw order) as the
# step-by-step cells above.
word_1 = torch.tensor([1.0, 0.0, 0.0])
word_2 = torch.tensor([0.0, 1.0, 0.0])
word_3 = torch.tensor([1.0, 1.0, 0.0])
word_4 = torch.tensor([0.0, 0.0, 1.0])
words = torch.stack([word_1, word_2, word_3, word_4])  # (4, 3): one word per row

torch.manual_seed(42)
W_Q = torch.rand(3, 3)
W_K = torch.rand(3, 3)
W_V = torch.rand(3, 3)

# The projections act on column vectors, so words are transposed to (3, 4):
# column j of Q / K / V is the query / key / value of word j.
Q = torch.matmul(W_Q, words.T)
K = torch.matmul(W_K, words.T)
V = torch.matmul(W_V, words.T)

# scores[i, j] = q_i . k_j / sqrt(d_k); d_k is the key dimension (rows of K),
# derived instead of hard-coding 3.
d_k = K.shape[0]
scores = torch.matmul(Q.T, K) / torch.sqrt(torch.tensor(float(d_k)))
# Softmax along dim=1 normalises each query's row over all keys.
weights = torch.nn.functional.softmax(scores, dim=1)
assert torch.allclose(weights.sum(dim=1), torch.ones(weights.shape[0])), "softmax rows must sum to 1"
# Row i of attention is the output for word i; row 0 should reproduce the
# step-by-step result for the first word.
attention = torch.matmul(weights, V.T)
print(attention)


tensor([[0.6075, 0.6090, 0.3037],
        [0.6144, 0.6002, 0.3056],
        [0.7058, 0.6380, 0.3293],
        [0.6064, 0.6041, 0.3020]])
