In [1]:
import torch


# Input sentence: "Can you help me to translate this sentence"
#
# Each word is encoded as an integer token id, listed in sentence order:
#   can=0, you=7, help=1, me=2, to=5, translate=6, this=4, sentence=3
sentence = torch.tensor([0, 7, 1, 2, 5, 6, 4, 3])

sentence

tensor([0, 7, 1, 2, 5, 6, 4, 3])

In [2]:
# Seed the RNG so the randomly initialised embedding table is reproducible.
torch.manual_seed(123)

# Lookup table with 10 rows (one per token id) of 16-dimensional vectors.
embed = torch.nn.Embedding(num_embeddings=10, embedding_dim=16)

# One embedding row per token in the sentence; detach() cuts the result
# off from the autograd graph.
embedded_sentence = embed(sentence).detach()
embedded_sentence.shape

torch.Size([8, 16])

In [3]:
# Pairwise dot products between all token embeddings:
# omega[i, j] = <x_i, x_j>.
# The matrix size and dtype are derived from the embeddings instead of
# being hard-coded, so the cell keeps working for any sentence length.
seq_len = embedded_sentence.shape[0]
omega = torch.empty(seq_len, seq_len, dtype=embedded_sentence.dtype)

for i, x_i in enumerate(embedded_sentence):
    for j, x_j in enumerate(embedded_sentence):
        omega[i, j] = torch.dot(x_i, x_j)

In [4]:
# Same pairwise dot products as the explicit double loop, computed with a
# single matrix product X @ X^T.
omega_mat = embedded_sentence @ embedded_sentence.T

In [5]:
# Sanity check: the matrix-product result (omega_mat) must agree with the
# loop-built matrix (omega) within torch.allclose's default tolerances.
torch.allclose(omega_mat, omega)

True

In [6]:
# Display the unnormalised attention scores; entry [i, j] is the dot
# product of the embeddings of tokens i and j.
omega

tensor([[ 9.7601,  1.7326,  4.7543, -1.3587,  0.4752, -1.6717,  1.0227, -0.1286],
        [ 1.7326, 16.0787,  9.0642, -0.3370,  1.1368,  1.1972,  1.6485, -1.2789],
        [ 4.7543,  9.0642, 22.6615, -0.8519,  7.7799,  2.7483, -0.6832,  1.6236],
        [-1.3587, -0.3370, -0.8519, 13.9473, -1.4198, 10.9659, -0.5887,  2.3869],
        [ 0.4752,  1.1368,  7.7799, -1.4198, 13.7511, -6.8568, -2.5114, -3.3468],
        [-1.6717,  1.1972,  2.7483, 10.9659, -6.8568, 24.6738, -3.8294,  4.9581],
        [ 1.0227,  1.6485, -0.6832, -0.5887, -2.5114, -3.8294, 15.8691,  2.0269],
        [-0.1286, -1.2789,  1.6236,  2.3869, -3.3468,  4.9581,  2.0269, 18.7382]])

In [7]:
import torch.nn.functional as F

# Normalise each row of the score matrix with softmax (last axis of the
# 2-D matrix), so every row becomes a set of weights summing to 1.
attention_weights = F.softmax(omega, dim=-1)
attention_weights.shape

torch.Size([8, 8])

In [8]:
# Display the softmax-normalised attention weights (one row per token).
attention_weights

tensor([[9.9270e-01, 3.2398e-04, 6.6502e-03, 1.4723e-05, 9.2135e-05, 1.0766e-05,
         1.5929e-04, 5.0374e-05],
        [5.8773e-07, 9.9910e-01, 8.9788e-04, 7.4187e-08, 3.2391e-07, 3.4407e-07,
         5.4033e-07, 2.8926e-08],
        [1.6712e-08, 1.2438e-06, 1.0000e+00, 6.1412e-11, 3.4437e-07, 2.2482e-09,
         7.2703e-11, 7.3008e-10],
        [2.1438e-07, 5.9550e-07, 3.5585e-07, 9.5172e-01, 2.0167e-07, 4.8272e-02,
         4.6299e-07, 9.0760e-06],
        [1.7110e-06, 3.3158e-06, 2.5448e-03, 2.5719e-07, 9.9745e-01, 1.1195e-09,
         8.6338e-08, 3.7443e-08],
        [3.6165e-12, 6.3713e-11, 3.0052e-10, 1.1136e-06, 2.0250e-14, 1.0000e+00,
         4.1804e-13, 2.7390e-09],
        [3.5667e-07, 6.6694e-07, 6.4779e-08, 7.1194e-08, 1.0410e-08, 2.7865e-09,
         1.0000e+00, 9.7366e-07],
        [6.4013e-09, 2.0263e-09, 3.6918e-08, 7.9205e-08, 2.5622e-10, 1.0361e-06,
         5.5258e-08, 1.0000e+00]])

In [9]:
# Row sums: softmax was applied along dim=1, so each row should total 1.
torch.sum(attention_weights, dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [10]:
# Column sums, for contrast: the softmax normalised rows, not columns,
# so these are not constrained to equal 1.
torch.sum(attention_weights, dim=0)

tensor([0.9927, 0.9994, 1.0101, 0.9517, 0.9975, 1.0483, 1.0002, 1.0001])

In [11]:
# Context vector for the second token (row index 1): the sum of all token
# embeddings, each weighted by row 1 of the attention-weight matrix.
x_2 = embedded_sentence[1, :]
context_vec_2 = torch.zeros_like(x_2)

# Iterate over the embedding rows themselves rather than a hard-coded
# range(8), so the loop follows the actual sentence length.
for j, x_j in enumerate(embedded_sentence):
    context_vec_2 += attention_weights[1, j] * x_j

context_vec_2

tensor([-9.3975e-01, -4.6856e-01,  1.0311e+00, -2.8192e-01,  4.9373e-01,
        -1.2896e-02, -2.7327e-01, -7.6358e-01,  1.3958e+00, -9.9543e-01,
        -7.1287e-04,  1.2449e+00, -7.8077e-02,  1.2765e+00, -1.4589e+00,
        -2.1601e+00])

In [12]:
# All context vectors at once: row i is the attention-weighted sum of the
# token embeddings using row i of attention_weights.
context_vectors = attention_weights @ embedded_sentence

# Row 1 must match the context vector accumulated in the explicit loop.
torch.allclose(context_vec_2, context_vectors[1])

True

### scaled dot-product attention

In [13]:
# Re-seed so the random projection matrices are reproducible.
torch.manual_seed(123)

# d is the embedding dimension (number of columns of embedded_sentence).
d = embedded_sentence.size(1)

# Three d x d matrices of uniform random values, drawn in the order
# query, key, value so the RNG stream is consumed in a fixed sequence.
U_query, U_key, U_value = (torch.rand(d, d) for _ in range(3))

In [14]:
# Embedding of the second token (row index 1) and its query projection.
x_2 = embedded_sentence[1]
query_2 = torch.matmul(U_query, x_2)

In [15]:
# Key and value projections of the same token embedding x_2.
key_2 = torch.matmul(U_key, x_2)
value_2 = torch.matmul(U_value, x_2)

In [16]:
# Key vectors for every token: project all embeddings (as columns) with
# U_key, then transpose back so that row i is the key of token i.
keys = torch.matmul(U_key, embedded_sentence.T).T

# Row 1 must match the key computed for the single token above.
torch.allclose(key_2, keys[1])

True

In [17]:
# Value vectors for every token, built the same way as the keys:
# row i is U_value applied to the embedding of token i.
values = torch.matmul(U_value, embedded_sentence.T).T

# Row 1 must match the value computed for the single token above.
torch.allclose(value_2, values[1])

True

In [18]:
# Unnormalised attention score between the query of the second token and
# the key of the third token (row index 2).
omega_23 = torch.dot(query_2, keys[2])
omega_23

tensor(14.3667)

In [19]:
# Scores of the second token's query against the keys of all tokens:
# entry j is the dot product of query_2 with keys[j].
omega_2 = torch.matmul(query_2, keys.T)
omega_2

tensor([-25.1623,   9.3602,  14.3667,  32.1482,  53.8976,  46.6626,  -1.2131,
        -32.9392])

In [20]:
# Scaled dot-product attention: divide the scores by sqrt(d) before the
# softmax, then normalise over the single axis of the score vector.
scale = d**0.5
attention_weights_2 = F.softmax(omega_2 / scale, dim=0)
attention_weights_2

tensor([2.2317e-09, 1.2499e-05, 4.3696e-05, 3.7242e-03, 8.5596e-01, 1.4026e-01,
        8.8897e-07, 3.1935e-10])

In [21]:
# Context vector for the second token: the attention-weighted sum of the
# value vectors of all tokens.
context_vector_2 = attention_weights_2 @ values
context_vector_2

tensor([-1.2226, -3.4387, -4.3928, -5.2125, -1.1249, -3.3041, -1.4316, -3.2765,
        -2.5114, -2.6105, -1.5793, -2.8433, -2.4142, -0.3998, -1.9917, -3.3499])