In [1]:
import torch

# input sequence / sentence:
#  "Can you help me to translate this sentence"

sentence = torch.tensor(
    [0, # can
     7, # you     
     1, # help
     2, # me
     5, # to
     6, # translate
     4, # this
     3] # sentence
)

sentence

tensor([0, 7, 1, 2, 5, 6, 4, 3])

In [2]:
torch.manual_seed(123)
embed = torch.nn.Embedding(10, 16)
embedded_sentence = embed(sentence).detach()
embedded_sentence.shape

torch.Size([8, 16])

In [3]:

#omega = torch.empty(8, 8)

#for i, x_i in enumerate(embedded_sentence):
#    for j, x_j in enumerate(embedded_sentence):
#        omega[i, j] = torch.dot(x_i, x_j)
#more effective way (instead of nested for loop) with matrix multiplication
omega_mat = embedded_sentence.matmul(embedded_sentence.T)

In [4]:
print(omega_mat)

tensor([[ 9.7601,  1.7326,  4.7543, -1.3587,  0.4752, -1.6717,  1.0227, -0.1286],
        [ 1.7326, 16.0787,  9.0642, -0.3370,  1.1368,  1.1972,  1.6485, -1.2789],
        [ 4.7543,  9.0642, 22.6615, -0.8519,  7.7799,  2.7483, -0.6832,  1.6236],
        [-1.3587, -0.3370, -0.8519, 13.9473, -1.4198, 10.9659, -0.5887,  2.3869],
        [ 0.4752,  1.1368,  7.7799, -1.4198, 13.7511, -6.8568, -2.5114, -3.3468],
        [-1.6717,  1.1972,  2.7483, 10.9659, -6.8568, 24.6738, -3.8294,  4.9581],
        [ 1.0227,  1.6485, -0.6832, -0.5887, -2.5114, -3.8294, 15.8691,  2.0269],
        [-0.1286, -1.2789,  1.6236,  2.3869, -3.3468,  4.9581,  2.0269, 18.7382]])


In [None]:
#ezzel lehetne osszehasonlitani
#torch.allclose(omega_mat, omega)

In [5]:
omega = torch.empty(8, 8)

In [6]:
import torch.nn.functional as F

attention_weights = F.softmax(omega, dim=1)
attention_weights.shape

torch.Size([8, 8])

In [7]:
attention_weights.sum(dim=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [8]:
#now that we have the attention weights, we can compute the context vector (eg for the element at index 1)
x_2 = embedded_sentence[1, :]
context_vec_2 = torch.zeros(x_2.shape)
for j in range(8):
    x_j = embedded_sentence[j, :]
    context_vec_2 += attention_weights[1, j] * x_j
    
context_vec_2

tensor([-0.3214, -0.0936, -0.5494,  0.3680,  0.0444,  0.4959, -0.0534, -0.3653,
         0.8559, -0.2540,  0.1836, -0.2252, -0.5542,  0.2524, -0.1345, -0.9024])

In [9]:
context_vectors = torch.matmul(
        attention_weights, embedded_sentence)


torch.allclose(context_vec_2, context_vectors[1])

True

In [None]:
#scaled dot-product attention

In [10]:
torch.manual_seed(123)

d = embedded_sentence.shape[1]
U_query = torch.rand(d, d)
U_key = torch.rand(d, d)
U_value = torch.rand(d, d)

In [11]:
x_2 = embedded_sentence[1]
query_2 = U_query.matmul(x_2)
key_2 = U_key.matmul(x_2)
value_2 = U_value.matmul(x_2)
keys = U_key.matmul(embedded_sentence.T).T
torch.allclose(key_2, keys[1])

True

In [12]:
values = U_value.matmul(embedded_sentence.T).T
torch.allclose(value_2, values[1])

True

In [13]:
omega_23 = query_2.dot(keys[2])
omega_23        

tensor(14.3667)

In [14]:
omega_2 = query_2.matmul(keys.T)
omega_2

tensor([-25.1623,   9.3602,  14.3667,  32.1482,  53.8976,  46.6626,  -1.2131,
        -32.9392])

In [15]:
attention_weights_2 = F.softmax(omega_2 / d**0.5, dim=0)
attention_weights_2

tensor([2.2317e-09, 1.2499e-05, 4.3696e-05, 3.7242e-03, 8.5596e-01, 1.4026e-01,
        8.8897e-07, 3.1935e-10])

In [None]:
#context_vector_2nd = torch.zeros(values[1, :].shape)
#for j in range(8):
#    context_vector_2nd += attention_weights_2[j] * values[j, :]
    
#context_vector_2nd

In [16]:
context_vector_2 = attention_weights_2.matmul(values)
context_vector_2

tensor([-1.2226, -3.4387, -4.3928, -5.2125, -1.1249, -3.3041, -1.4316, -3.2765,
        -2.5114, -2.6105, -1.5793, -2.8433, -2.4142, -0.3998, -1.9917, -3.3499])

In [None]:
#the original transformer architecture

In [17]:
torch.manual_seed(123)

d = embedded_sentence.shape[1]
one_U_query = torch.rand(d, d)

In [18]:
h = 8
multihead_U_query = torch.rand(h, d, d)
multihead_U_key = torch.rand(h, d, d)
multihead_U_value = torch.rand(h, d, d)

In [19]:
multihead_query_2 = multihead_U_query.matmul(x_2)
multihead_query_2.shape

torch.Size([8, 16])

In [20]:
multihead_key_2 = multihead_U_key.matmul(x_2)
multihead_value_2 = multihead_U_value.matmul(x_2)

In [21]:
multihead_key_2[2]

tensor([-1.9619, -0.7701, -0.7280, -1.6840, -1.0801, -1.6778,  0.6763,  0.6547,
         1.4445, -2.7016, -1.1364, -1.1204, -2.4430, -0.5982, -0.8292, -1.4401])

In [22]:
stacked_inputs = embedded_sentence.T.repeat(8, 1, 1)
stacked_inputs.shape

torch.Size([8, 16, 8])

In [23]:
multihead_keys = torch.bmm(multihead_U_key, stacked_inputs)
multihead_keys.shape

torch.Size([8, 16, 8])

In [24]:
multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_keys.shape

torch.Size([8, 8, 16])

In [25]:
multihead_keys[2, 1] # index: [2nd attention head, 2nd key]

tensor([-1.9619, -0.7701, -0.7280, -1.6840, -1.0801, -1.6778,  0.6763,  0.6547,
         1.4445, -2.7016, -1.1364, -1.1204, -2.4430, -0.5982, -0.8292, -1.4401])

In [26]:
multihead_values = torch.matmul(multihead_U_value, stacked_inputs)
multihead_values = multihead_values.permute(0, 2, 1)

In [27]:
multihead_z_2 = torch.rand(8, 16)

In [28]:
linear = torch.nn.Linear(8*16, 16)
context_vector_2 = linear(multihead_z_2.flatten())
context_vector_2.shape

torch.Size([16])

In [None]:
#Learning a language model: decoder and masked multi-head attention