In [26]:
import torch

In [27]:
# Hard-coded 3-dimensional embeddings, one row per token of the sentence
# "Your journey starts with one step"; row i is the input vector x_i.
input_embeddings = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x_0: "Your"
        [0.55, 0.87, 0.66],  # x_1: "journey"
        [0.57, 0.85, 0.64],  # x_2: "starts"
        [0.22, 0.58, 0.33],  # x_3: "with"
        [0.77, 0.25, 0.10],  # x_4: "one"
        [0.05, 0.80, 0.55],  # x_5: "step"
    ]
)

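# Calculate context vectors naively

For each query token: score it against every input token with a dot product, normalize the scores with softmax, then take the attention-weighted sum of the input vectors.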

## calculate context vector for query 0

In [28]:
# Use the first token ("Your", x_0) as the query.
query_0 = input_embeddings[0, :]
query_0

tensor([0.4300, 0.1500, 0.8900])

In [29]:
# Unnormalized attention scores: the dot product of every input x_i with query_0.
# torch.empty is fine because every slot is written by the loop.
attn_scores_0 = torch.empty(len(input_embeddings))
for i in range(len(input_embeddings)):
    x_i = input_embeddings[i]
    attn_scores_0[i] = x_i @ query_0
attn_scores_0

tensor([0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310])

In [30]:
# Softmax turns the raw scores into positive weights that sum to 1.
attn_weights_0 = attn_scores_0.softmax(dim=0)
attn_weights_0

tensor([0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452])

In [31]:
# Context vector for query 0: the attention-weighted sum of all input vectors.
# It is accumulated in place, so it must start from zeros.
context_vec_0 = torch.zeros(query_0.shape)
for i in range(len(input_embeddings)):
    x_i = input_embeddings[i]
    context_vec_0.add_(attn_weights_0[i] * x_i)
context_vec_0

tensor([0.4421, 0.5931, 0.5790])

## calculate context vector for query 1

In [32]:
# Query 1 ("journey", x_1): scores -> softmax weights -> weighted sum of inputs.
query_1 = input_embeddings[1, :]

# Scores: dot product of each input x_i with query_1 (every slot is overwritten).
attn_scores_1 = torch.empty(len(input_embeddings))
for i in range(len(input_embeddings)):
    x_i = input_embeddings[i]
    attn_scores_1[i] = x_i @ query_1

# Weights: softmax-normalized scores.
attn_weights_1 = attn_scores_1.softmax(dim=0)

# Context vector: accumulated in place, so it starts from zeros.
context_vec_1 = torch.zeros(query_1.shape)
for i in range(len(input_embeddings)):
    x_i = input_embeddings[i]
    context_vec_1.add_(attn_weights_1[i] * x_i)
context_vec_1

tensor([0.4419, 0.6515, 0.5683])

## calculate context vector for query 2

In [33]:
# Query 2 ("starts", x_2): scores -> softmax weights -> weighted sum of inputs.
query_2 = input_embeddings[2, :]

# Scores: dot product of each input x_i with query_2 (every slot is overwritten).
attn_scores_2 = torch.empty(len(input_embeddings))
for i in range(len(input_embeddings)):
    x_i = input_embeddings[i]
    attn_scores_2[i] = x_i @ query_2

# Weights: softmax-normalized scores.
attn_weights_2 = attn_scores_2.softmax(dim=0)

# Context vector: accumulated in place, so it starts from zeros.
context_vec_2 = torch.zeros(query_2.shape)
for i in range(len(input_embeddings)):
    x_i = input_embeddings[i]
    context_vec_2.add_(attn_weights_2[i] * x_i)
context_vec_2

tensor([0.4431, 0.6496, 0.5671])

## calculate context vector for query 3

In [34]:
# Query 3 ("with", x_3): scores -> softmax weights -> weighted sum of inputs.
query_3 = input_embeddings[3]
# Score buffer: every slot is overwritten by the loop, so no zero-initialization
# is needed (torch.empty, the same allocation used for queries 0-2).
attn_scores_3 = torch.empty(input_embeddings.shape[0])
for i, x_i in enumerate(input_embeddings):
    attn_scores_3[i] = torch.dot(x_i, query_3)
attn_weights_3 = torch.softmax(attn_scores_3, dim=0)
# The context vector is accumulated with +=, so it must start at zero.
context_vec_3 = torch.zeros(query_3.shape)
for i, x_i in enumerate(input_embeddings):
    context_vec_3 += attn_weights_3[i] * x_i
context_vec_3

tensor([0.4304, 0.6298, 0.5510])

## calculate context vector for query 4

In [35]:
# Query 4 ("one", x_4): scores -> softmax weights -> weighted sum of inputs.
query_4 = input_embeddings[4]
# Score buffer: every slot is overwritten by the loop, so no zero-initialization
# is needed (torch.empty, the same allocation used for queries 0-2).
attn_scores_4 = torch.empty(input_embeddings.shape[0])
for i, x_i in enumerate(input_embeddings):
    attn_scores_4[i] = torch.dot(x_i, query_4)
attn_weights_4 = torch.softmax(attn_scores_4, dim=0)
# The context vector is accumulated with +=, so it must start at zero.
context_vec_4 = torch.zeros(query_4.shape)
for i, x_i in enumerate(input_embeddings):
    context_vec_4 += attn_weights_4[i] * x_i
context_vec_4

tensor([0.4671, 0.5910, 0.5266])

## calculate context vector for query 5

In [36]:
# Query 5 ("step", x_5): scores -> softmax weights -> weighted sum of inputs.
query_5 = input_embeddings[5]
# Score buffer: every slot is overwritten by the loop, so no zero-initialization
# is needed (torch.empty, the same allocation used for queries 0-2).
attn_scores_5 = torch.empty(input_embeddings.shape[0])
for i, x_i in enumerate(input_embeddings):
    attn_scores_5[i] = torch.dot(x_i, query_5)
attn_weights_5 = torch.softmax(attn_scores_5, dim=0)
# The context vector is accumulated with +=, so it must start at zero.
context_vec_5 = torch.zeros(query_5.shape)
for i, x_i in enumerate(input_embeddings):
    context_vec_5 += attn_weights_5[i] * x_i
context_vec_5

tensor([0.4177, 0.6503, 0.5645])

## Combine the per-query context vectors into one matrix

In [48]:
# Stack the six per-query context vectors row-wise: row k is context_vec_k.
context_vecs = torch.stack(
    (
        context_vec_0, context_vec_1, context_vec_2,
        context_vec_3, context_vec_4, context_vec_5,
    ),
    dim=0,
)
context_vecs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

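# Calculate all context vectors in one go

The same computation as matrix operations. `X @ X.T` gives every pairwise score, a row-wise softmax gives the weights, and `weights @ X` gives all context vectors at once.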

In [45]:
# All pairwise dot products at once: entry (i, j) is the score of input x_j
# for query x_i, so row i equals the loop-built attn_scores_i.
attn_scores = torch.matmul(input_embeddings, input_embeddings.T)
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [46]:
# Softmax over the last dimension normalizes each row independently,
# giving one query's attention weights per row.
attn_weights = attn_scores.softmax(dim=-1)
attn_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [47]:
# Row i of the product is the attention-weighted sum of the inputs for query x_i.
context_vecs = attn_weights @ input_embeddings

# Sanity check: the matrix form must reproduce the per-query loop results.
# Compare against the individual context_vec_k (not the earlier stacked
# `context_vecs`, which this cell overwrites) so the cell stays idempotent.
# allclose rather than equal: the summation order may differ from the loops.
assert torch.allclose(
    context_vecs,
    torch.stack([
        context_vec_0, context_vec_1, context_vec_2,
        context_vec_3, context_vec_4, context_vec_5,
    ]),
), "matrix-form context vectors disagree with the naive per-query loops"
context_vecs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])