# Self-attention PyTorch implementation
Python implementation of the self (scaled-dot product) attention mechanism originally proposed in "Attention Is All You Need".

# Setup

In [1]:
import torch
import torch.nn.functional as F

In [2]:
torch.manual_seed(42)
dim_word_embedding = 16

Let's map each **unique** word present in the sentence to an index.
In real-world scenarios, there would be a bigger vocubalary, e.g. hundred of thousands of words. 

In [3]:
sentence = 'Life is short, eat dessert first is life'
sentence = sentence.lower().replace(',', '').split()
n_sentence = len(sentence)
dict_words = {word: idx for idx, word in enumerate(sorted(sentence))}
sentence_ints = torch.tensor([dict_words[char] for char in sentence])

In [4]:
print('Word dictionary =', dict_words)
print('Numeric representation of the input sentence =', sentence_ints)

Word dictionary = {'dessert': 0, 'eat': 1, 'first': 2, 'is': 4, 'life': 6, 'short': 7}
Numeric representation of the input sentence = tensor([6, 4, 7, 1, 0, 2, 4, 6])


# Single head self-attention
![single-attention-head](./images/single-head.png)

Now create a $n_{words} \times dim_{word_{embedding}} = 8 \times 16$ embedding tensor.

These initial values will be randomly generated and $dim_{vec} = 16$ is an hyper-parameter.

In [5]:
embedding_layer = torch.nn.Embedding(n_sentence, dim_word_embedding)
embedded_sentence = embedding_layer(sentence_ints).detach()
print('Embedded sentence shape =', embedded_sentence.shape)

Embedded sentence shape = torch.Size([8, 16])


Let's choose values for Queries, Keys, and Values.
Since Queries and Keys vector are multiplied afterward, their dimensions MUST be equal.

In [6]:
dim_q, dim_v = 24, 28
dim_k = dim_q
W_q = torch.rand(dim_q, dim_word_embedding)
W_k = torch.rand(dim_k, dim_word_embedding)
W_v = torch.rand(dim_v, dim_word_embedding)

## Compute queries, keys, values for the input words
First, try to understand this on single words.

### Example #1
Compute query, key, and value vector related to the first and second words in the sentence ($idx = 0, 1$).

In [7]:
word_idx = 0
x_1 = embedded_sentence[word_idx]
query_1 = W_q @ x_1
key_1 = W_k @ x_1
value_1 = W_v @ x_1
print('Word idx =', word_idx)
print('\tQuery size =', query_1.shape)
print('\tKey size =', key_1.shape)
print('\tValue size =', value_1.shape)

Word idx = 0
	Query size = torch.Size([24])
	Key size = torch.Size([24])
	Value size = torch.Size([28])


In [8]:
word_idx = 1
x_2 = embedded_sentence[word_idx]
query_2 = W_q @ x_2
key_2 = W_k @ x_2
value_2 = W_v @ x_2
print('Word idx =', word_idx)
print('\tQuery size =', query_2.shape)
print('\tKey size =', key_2.shape)
print('\tValue size =', value_2.shape)

Word idx = 1
	Query size = torch.Size([24])
	Key size = torch.Size([24])
	Value size = torch.Size([28])


### Compute queries, keys, values for ALL input words
Now that I have understood the process, let's extend this to all the input words.
Transpose the matrices to have words on rows and vector components on columns.

In [9]:
queries = (W_q @ embedded_sentence.T).T
keys = (W_k @ embedded_sentence.T).T
values = (W_v @ embedded_sentence.T).T
print('All input words -> queries.shape =', queries.shape) 
print('All input words -> keys.shape =', keys.shape) 
print('All input words -> values.shape =', values.shape)

All input words -> queries.shape = torch.Size([8, 24])
All input words -> keys.shape = torch.Size([8, 24])
All input words -> values.shape = torch.Size([8, 28])


## Unnormalized attention scores
As before, I am trying to understand this concept on single words.

### Example #2
Let's compute the unnormalized attention score $\omega$ (omega) for the first word w.r.t. the 5th word

In [10]:
omega_1_5 = queries[0] @ keys[4]
print('Unnormalized attention score of first word w.r.t. 5th word =', omega_1_5.item())

Unnormalized attention score of first word w.r.t. 5th word = -46.02456283569336


### Compute unnormalized attention scores w.r.t. ALL input words
Let's compute the unnormalized attention scores for the first word w.r.t. to all other words.

In [11]:
omega_1_all = queries[0] @ keys.T
print(f'Unnormalized attention scores of first word w.r.t. ALL other words:\n{omega_1_all}')

Unnormalized attention scores of first word w.r.t. ALL other words:
tensor([ 47.9667,  58.9805,  42.1271, 141.0643, -46.0246, -72.1767,  58.9805,
         47.9667])


### Compute unnormalized attention scores for ALL input words
Let's compute the unnormalized attention scores for all the words w.r.t. to all other words.

In [12]:
omega_all = queries @ keys.T
print(f'Unnormalized attention scores of ALL words w.r.t. ALL other words:\n{omega_all}')
print("All input words scores -> omega_all.shape =", omega_all.shape) 

Unnormalized attention scores of ALL words w.r.t. ALL other words:
tensor([[  47.9667,   58.9805,   42.1272,  141.0642,  -46.0246,  -72.1767,
           58.9805,   47.9667],
        [  58.7503,   93.6661,   65.4516,  229.0244,  -51.6797, -109.5712,
           93.6661,   58.7503],
        [  47.7602,   53.6036,   42.4971,  132.3753,  -45.4809,  -63.6152,
           53.6036,   47.7602],
        [ 145.1907,  182.7591,  157.4097,  479.7147, -139.6413, -238.2749,
          182.7591,  145.1907],
        [ -26.0050,  -25.1779,  -21.8697,  -61.3257,   20.7895,   35.2284,
          -25.1779,  -26.0050],
        [ -71.2604,  -94.5636,  -83.3776, -257.1654,   66.7258,  130.5431,
          -94.5636,  -71.2604],
        [  58.7503,   93.6661,   65.4516,  229.0244,  -51.6797, -109.5712,
           93.6661,   58.7503],
        [  47.9667,   58.9805,   42.1272,  141.0642,  -46.0246,  -72.1767,
           58.9805,   47.9667]])
All input words scores -> omega_all.shape = torch.Size([8, 8])


## Normalized attention scores
Why? To reach more numeric stability and thus, reduce errors.
Note that the sum of the values is 1 (thanks to the softmax function).

### Example #3
Let's compute the normalized attention score $\alpha$ (alpha) for the first word w.r.t. the ALL other words.

In [13]:
normalized_attention_scores_1 = F.softmax(omega_1_all / dim_k ** 0.5, dim=0)
print(f'Normalized attention scores of first word w.r.t. ALL other words:\n{normalized_attention_scores_1}')
print('Sum of this vector =', normalized_attention_scores_1.sum().item())

Normalized attention scores of first word w.r.t. ALL other words:
tensor([5.5834e-09, 5.2879e-08, 1.6952e-09, 1.0000e+00, 2.5976e-17, 1.2479e-19,
        5.2879e-08, 5.5834e-09])
Sum of this vector = 1.0


### Compute normalized attention scores for ALL input words
Let's compute the normalized attention score $\alpha$ (alpha) for ALL words w.r.t. ALL other words.

In [14]:
normalized_attention_scores = F.softmax(omega_all / dim_k ** 0.5, dim=0)
print(f'Normalized attention scores of ALL words w.r.t. ALL other words:\n{normalized_attention_scores}')
print('Sum of this vector =', normalized_attention_scores.sum().item())

Normalized attention scores of ALL words w.r.t. ALL other words:
tensor([[2.4049e-09, 1.0642e-11, 6.0284e-11, 9.5202e-31, 1.0108e-10, 1.0688e-18,
         1.0642e-11, 2.4049e-09],
        [2.1730e-08, 1.2645e-08, 7.0456e-09, 5.9746e-23, 3.1866e-11, 5.1745e-22,
         1.2645e-08, 2.1730e-08],
        [2.3056e-09, 3.5511e-12, 6.5013e-11, 1.6157e-31, 1.1294e-10, 6.1358e-18,
         3.5511e-12, 2.3056e-09],
        [1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 5.0760e-19, 2.0151e-33,
         1.0000e+00, 1.0000e+00],
        [6.6606e-16, 3.6846e-19, 1.2790e-16, 0.0000e+00, 8.4667e-05, 3.5510e-09,
         3.6846e-19, 6.6606e-16],
        [6.4806e-20, 2.6022e-25, 4.5103e-22, 0.0000e+00, 9.9992e-01, 1.0000e+00,
         2.6022e-25, 6.4806e-20],
        [2.1730e-08, 1.2645e-08, 7.0456e-09, 5.9747e-23, 3.1866e-11, 5.1745e-22,
         1.2645e-08, 2.1730e-08],
        [2.4049e-09, 1.0642e-11, 6.0284e-11, 9.5202e-31, 1.0108e-10, 1.0688e-18,
         1.0642e-11, 2.4049e-09]])
Sum of this ve

## Compute $z$ vector
The $z$ vector is an enhanced version of the input word. In other words, it *"embeds"* the information about ALL the other inputs words.

![context-vector](./images/context-vector.png)

Note that the dimension of $z = dim_v$, and in this case $dim_v > dim_{word_{embedding}}$, but it turns out that can be arbitrarly chosen.

### Example #4
Let's compute $z$ for the first word w.r.t. ALL other words.

In [15]:
context_vector_z_1 = normalized_attention_scores_1 @ values
print('Context vector related to the first word w.r.t. ALL other words')
print('Shape =', context_vector_z_1.shape)
print(context_vector_z_1)

Context vector related to the first word w.r.t. ALL other words
Shape = torch.Size([28])
tensor([5.2083, 4.5906, 3.1900, 4.1853, 4.7579, 4.2178, 3.4120, 5.0137, 4.0621,
        4.1455, 6.3549, 3.0836, 6.4934, 4.0425, 4.8676, 3.0821, 6.8482, 5.5784,
        4.6929, 5.4580, 5.6707, 4.9629, 4.2686, 5.6802, 5.3528, 4.5219, 4.6112,
        4.7807])


### Compute $z$ for ALL input words
Let's compute $z$ for ALL words w.r.t. ALL other words.

In [16]:
context_z = normalized_attention_scores @ values
print('Context vector related to ALL input words w.r.t. ALL other words')
print('Shape =', context_z.shape)
print(context_z)

Context vector related to ALL input words w.r.t. ALL other words
Shape = torch.Size([8, 28])
tensor([[ 1.0208e-08,  4.5245e-09,  8.9584e-09,  8.8031e-09,  1.0888e-08,
          1.1263e-08,  6.6825e-09,  9.1426e-09,  3.3903e-09,  9.7982e-09,
          1.1035e-08,  4.7999e-09,  1.4235e-08,  1.4572e-08,  6.5525e-09,
          7.8531e-09,  8.0995e-09,  1.5258e-08,  1.1228e-09,  1.2500e-08,
          1.6545e-09,  8.0533e-09,  1.2058e-08,  1.8843e-09,  1.4327e-08,
          1.2405e-08,  1.1857e-09,  1.0192e-08],
        [ 1.3333e-07,  1.4292e-07,  1.4621e-07,  5.8652e-08,  9.4583e-08,
          2.0076e-07,  1.0602e-07,  1.5732e-07,  9.1718e-08,  1.1024e-07,
          1.1862e-07,  5.8737e-08,  1.7114e-07,  1.5896e-07,  1.0163e-07,
          9.6130e-08,  1.6256e-07,  2.0144e-07,  6.1971e-08,  1.4980e-07,
          5.8659e-08,  1.8701e-07,  1.8030e-07,  9.1482e-08,  1.2541e-07,
          1.3888e-07,  7.5303e-08,  1.3610e-07],
        [ 9.7775e-09,  4.3164e-09,  8.5718e-09,  8.4555e-09,  1.0414e

## Ok but...are you getting confused by all these dimension?
Try to check out this picture.
![single-attention-head](./images/single-attention-head.png)

# Multi-head self-attention
Each query has now shape $n_{heads} \times dim_{word_{embedding}}$. 
![multi-attention-head](./images/multi-head.png)

In [17]:
n_heads = 3
multihead_W_query = torch.rand(n_heads, dim_q, dim_word_embedding) 
multihead_W_key = torch.rand(n_heads, dim_k, dim_word_embedding)
multihead_W_value = torch.rand(n_heads, dim_v, dim_word_embedding)

Follow this example with the first word.

In [18]:
word_idx = 0
x_1 = embedded_sentence[word_idx]
multihead_query_1 = multihead_W_query @ x_1
multihead_key_1 = multihead_W_key @ x_1
multihead_value_1 = multihead_W_value @ x_1
print('Word idx =', word_idx)
print('\tMultihead-query size =', multihead_query_1.shape)
print('\tMultihead-key size =', multihead_key_1.shape)
print('\tMultihead-value size =', multihead_value_1.shape)

Word idx = 0
	Multihead-query size = torch.Size([3, 24])
	Multihead-key size = torch.Size([3, 24])
	Multihead-value size = torch.Size([3, 28])


But for extend this example to all the words, we have to use the `bmm()` (Batch Matrix Mutliplication) method from **PyTorch**. It's useful when there's the need to deal with tensor. The effective `matmul`is performed along the last two dimensions, while the first one representing heads is preserved.

Finally, to make the input feasible (i.e., choose the right dimensions) for the multi-head attention layer, I need to replicate it for the number of attention heads (in this case, three times).

In [19]:
repeated_inputs = embedded_sentence.T.repeat(n_heads, 1, 1)
print('Repeated input size =', repeated_inputs.shape)

Repeated input size = torch.Size([3, 16, 8])


In [20]:
multihead_queries = torch.bmm(multihead_W_query, repeated_inputs)
multihead_keys = torch.bmm(multihead_W_key, repeated_inputs)
multihead_values = torch.bmm(multihead_W_value, repeated_inputs)
print('All input words -> multihead_queries.shape =', multihead_queries.shape) 
print('All input words -> multihead_keys.shape =', multihead_keys.shape) 
print('All input words -> multihead_values.shape =', multihead_values.shape)

All input words -> multihead_queries.shape = torch.Size([3, 24, 8])
All input words -> multihead_keys.shape = torch.Size([3, 24, 8])
All input words -> multihead_values.shape = torch.Size([3, 28, 8])


But - as can be seen from above - the last two dimensions are swapped, because I would like to have words as the second dimension and vector components as third one. Hence, let's permute (i.e., swap) the last two dimensions.

Here, the `permute()` method is asking the desidered ordering of dimensions.

In [21]:
multihead_queries = multihead_queries.permute(0, 2, 1)
multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_values = multihead_values.permute(0, 2, 1)
print('All input words -> multihead_queries.shape =', multihead_queries.shape) 
print('All input words -> multihead_keys.shape =', multihead_keys.shape) 
print('All input words -> multihead_values.shape =', multihead_values.shape)

All input words -> multihead_queries.shape = torch.Size([3, 8, 24])
All input words -> multihead_keys.shape = torch.Size([3, 8, 24])
All input words -> multihead_values.shape = torch.Size([3, 8, 28])


## Compute unnormalized attention scores
As did before, compute the unnormalzied attention scores for each head.

In [22]:
multihead_omega_all = torch.bmm(multihead_queries, multihead_keys.permute(0,2,1))
print(f'Unnormalized attention scores of ALL words w.r.t. ALL other words:\n{multihead_omega_all}')
print("All input words scores -> omega_all.shape =", multihead_omega_all.shape) 

Unnormalized attention scores of ALL words w.r.t. ALL other words:
tensor([[[  63.3466,   69.7982,   60.7684,  168.7238,  -35.8708,  -90.1632,
            69.7982,   63.3466],
         [  64.3794,   50.0433,   56.6445,  156.4643,  -33.4781,  -78.7985,
            50.0433,   64.3794],
         [  54.1203,   63.1504,   53.7396,  154.9160,  -35.4911,  -82.0983,
            63.1504,   54.1203],
         [ 178.0788,  180.1088,  162.0738,  483.6229,  -87.1213, -246.8582,
           180.1088,  178.0788],
         [ -27.7026,  -43.6290,  -35.4398, -103.1395,   30.3216,   52.3912,
           -43.6290,  -27.7026],
         [ -76.8325,  -87.6214,  -66.1114, -224.5666,   31.7686,  111.7955,
           -87.6214,  -76.8325],
         [  64.3794,   50.0433,   56.6445,  156.4643,  -33.4781,  -78.7985,
            50.0433,   64.3794],
         [  63.3466,   69.7982,   60.7684,  168.7238,  -35.8708,  -90.1632,
            69.7982,   63.3466]],

        [[  40.2509,   51.2320,   37.3632,  140.1730,  -24.

### Compute normalized attention scores for ALL input words
Let's compute the normalized attention score $\alpha$ (alpha) for ALL words w.r.t. ALL other words.

In [23]:
multihead_normalized_attention_scores = F.softmax(multihead_omega_all / dim_k ** 0.5, dim=0)
print(f'Multihead normalized attention scores of ALL words w.r.t. ALL other words:\n{multihead_normalized_attention_scores}')
print('Sum of this vector =', multihead_normalized_attention_scores.sum().item())

Multihead normalized attention scores of ALL words w.r.t. ALL other words:
tensor([[[5.5202e-01, 9.7718e-01, 8.2888e-01, 8.5065e-01, 8.6035e-02,
          1.8976e-03, 9.7718e-01, 5.5202e-01],
         [1.5916e-01, 5.0917e-02, 4.2787e-01, 8.0445e-05, 6.5592e-01,
          3.3284e-02, 5.0917e-02, 1.5916e-01],
         [5.1472e-01, 7.0178e-01, 7.4659e-01, 2.9650e-02, 2.5820e-01,
          1.2600e-02, 7.0178e-01, 5.1472e-01],
         [9.9754e-01, 1.4868e-03, 9.8379e-01, 1.7937e-07, 9.2739e-01,
          3.2141e-05, 1.4868e-03, 9.9754e-01],
         [4.5928e-01, 2.2055e-02, 1.5059e-01, 1.2142e-02, 3.6933e-01,
          7.8774e-01, 2.2055e-02, 4.5928e-01],
         [1.9087e-01, 1.2551e-02, 5.9268e-01, 9.5008e-01, 9.0600e-04,
          6.9592e-02, 1.2551e-02, 1.9087e-01],
         [1.5916e-01, 5.0917e-02, 4.2787e-01, 8.0445e-05, 6.5592e-01,
          3.3284e-02, 5.0917e-02, 1.5916e-01],
         [5.5202e-01, 9.7718e-01, 8.2888e-01, 8.5065e-01, 8.6035e-02,
          1.8976e-03, 9.7718e-01, 5.

### Compute $z$ for ALL input words
Let's compute $z$ for ALL words w.r.t. ALL other words.

In [24]:
multihead_context_z = multihead_normalized_attention_scores @ multihead_values
print('Multihead context vector related to ALL input words w.r.t. ALL other words')
print('Shape =', multihead_context_z.shape)
print(multihead_context_z)

Multihead context vector related to ALL input words w.r.t. ALL other words
Shape = torch.Size([3, 8, 28])
tensor([[[ 1.2550e+01,  8.2323e+00,  9.1246e+00,  2.9574e+00,  1.3149e+01,
           1.1474e+01,  5.4783e+00,  8.9908e+00,  8.1721e+00,  9.5334e+00,
           1.2736e+01,  1.0823e+01,  1.4178e+01,  1.3128e+01,  1.0271e+01,
           7.1355e+00,  9.9221e+00,  1.1401e+01,  9.2407e+00,  8.2230e+00,
           1.4161e+01,  1.2087e+01,  5.3284e+00,  1.3517e+01,  1.0876e+01,
           1.0542e+01,  1.3312e+01,  9.9359e+00],
         [ 8.4613e-01,  5.4901e-01, -9.3701e-01, -5.6394e-02,  1.7809e-01,
           6.3644e-01,  3.3348e-01, -6.3719e-02,  7.8778e-01,  4.3655e-01,
           5.3653e-01,  6.2437e-01,  7.7565e-01, -5.2384e-01, -1.3249e-01,
           5.0235e-01,  1.7263e-01,  9.4807e-01, -1.7731e-01,  8.8894e-02,
           9.8492e-02,  6.2525e-01, -1.3491e+00,  1.0078e+00,  5.1598e-01,
          -1.0351e+00,  7.7803e-01,  1.1364e+00],
         [ 6.7393e+00,  4.8771e+00,  3.2697e

# Cross-attention [TODO]
This mechanism is used in Transformer to perform an attention calculus among different inputs.