# Self Attention - Basics

In [1]:
import numpy as np
import math
import torch.nn.functional as F
import torch
import torch.nn as nn

This is how attention is described in the [Attention Is All You Need](https://arxiv.org/abs/1706.03762) paper: <br>

"An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key."

-- Attention Is All You Need, Vaswani et al.
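To make the quoted description concrete, here is a minimal sketch of a single query attending over three key-value pairs. The vectors are made-up toy numbers (an assumption, not taken from the paper). The compatibility function is a plain dot product followed by a softmax. The $\sqrt{d_k}$ scaling from the paper's formula is left out here to keep the example small.

```python
import torch
import torch.nn.functional as F

# Toy numbers chosen for illustration only (not from the paper).
query = torch.tensor([1.0, 0.0])             # one query vector
keys = torch.tensor([[1.0, 0.0],              # one key per value
                     [0.0, 1.0],
                     [-1.0, 0.0]])
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [5.0, 5.0]])

# Compatibility function: dot product of the query with each key.
scores = keys @ query                        # shape (3,)
# Turn the scores into weights that are positive and sum to 1.
weights = F.softmax(scores, dim=0)
# Output: weighted sum of the values.
output = weights @ values                    # shape (2,)
print(weights, output)
```

The first key points in the same direction as the query, so the first value receives the largest weight.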

$\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left( \frac{QK^{T}}{\sqrt{d_{k}}} \right)V$
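The formula can be written as a small PyTorch function. This is a sketch, not PyTorch's own API. The name `attention` and the random test inputs are assumptions made for illustration, and batching and masking are left out.

```python
import math
import torch
import torch.nn.functional as F

def attention(Q, K, V):
    """Hypothetical helper. Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v).
    Returns (output, weights)."""
    d_k = Q.size(-1)
    # Compatibility scores between every query and every key, scaled by sqrt(d_k).
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    # Softmax over the keys so that each query's weights sum to 1.
    weights = F.softmax(scores, dim=-1)
    # Each output row is a weighted sum of the rows of V.
    return weights @ V, weights

torch.manual_seed(0)
Q = torch.randn(3, 2)
K = torch.randn(3, 2)
V = torch.randn(3, 4)
out, w = attention(Q, K, V)
print(out.shape)  # torch.Size([3, 4])
```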

Suppose we have a sentence of three words, for example 'sky is blue'. Tokenization is the process of breaking a sequence of characters into smaller units called tokens, and there are many tokenization methods. Here we use a very simple one, chosen only to make the attention mechanism easy to explain. Each word is a token, and the three words are represented by the numbers 1, 2 and 3. Our sequence of tokens is therefore given by the following tensor.
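Here is a sketch of this word-level tokenization. It assumes a hypothetical hand-written vocabulary `vocab` and a hypothetical helper `tokenize` that splits on whitespace. Real tokenizers are usually subword-based and considerably more involved.

```python
import torch

# Hypothetical vocabulary for this toy example: each word gets an integer id.
vocab = {"sky": 1, "is": 2, "blue": 3}

def tokenize(text):
    # Split on whitespace and look up each word's id (word-level tokenization).
    return [vocab[word] for word in text.split()]

token_ids = tokenize("sky is blue")
# Wrap in an outer list to add a batch dimension of size 1.
sequence = torch.tensor([token_ids])
print(sequence)  # tensor([[1, 2, 3]])
```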

In [28]:
sequence =  torch.tensor([[1,2,3]])

In [29]:
sequence

tensor([[1, 2, 3]])

In this case, our `sequence length` is 3.

Tokenization alone is not enough to create an information-rich representation of a character sequence. We also need a vector representation for each token. These vectors are called embeddings.

Let's assume that the embedding vectors of tokens 1, 2 and 3 are as follows.

In [4]:
embeddings = {1: [-1.0720, -0.5001], 2:  [-0.0020, -0.4311], 3: [-0.0020, -0.4311]}
embeddings

{1: [-1.072, -0.5001], 2: [-0.002, -0.4311], 3: [-0.002, -0.4311]}

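Therefore, our embedded tokens are as follows. Note that in this toy example tokens 2 and 3 have identical embeddings. As a result, rows 2 and 3 are also identical in every matrix computed below.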

In [5]:
embedded_tokens = torch.tensor([[-1.0720, -0.5001], [-0.0020, -0.4311], [-0.0020, -0.4311] ])

In [6]:
embedded_tokens

tensor([[-1.0720, -0.5001],
        [-0.0020, -0.4311],
        [-0.0020, -0.4311]])
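In a real model, the embedding vectors are stored in a lookup table that is learned during training. The following sketch uses PyTorch's `nn.Embedding`, with the table filled in by hand, to reproduce `embedded_tokens` from `sequence`. Row 0 of the table is an unused placeholder added as an assumption of this sketch, so that token id `i` maps to row `i`.

```python
import torch
import torch.nn as nn

sequence = torch.tensor([[1, 2, 3]])
# Embedding table with one row per token id; row 0 is an unused placeholder.
table = torch.tensor([[0.0, 0.0],
                      [-1.0720, -0.5001],
                      [-0.0020, -0.4311],
                      [-0.0020, -0.4311]])
embedding = nn.Embedding.from_pretrained(table)  # frozen by default

embedded = embedding(sequence)   # (batch, sequence length, embedding dim)
print(embedded.shape)  # torch.Size([1, 3, 2])
embedded_tokens = embedded[0]    # drop the batch dimension -> (3, 2)
```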

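As the next step, let's create the Q, K and V matrices. This is done by transforming our embedding matrix with three weight matrices: `WQ`, `WK` and `WV`. In a real model these weights are learned during training. Here we simply fix them.

Let's consider the following `WQ`, `WK` and `WV` matrices: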

In [7]:
WQ = torch.tensor([[-0.0271, -0.3840],
        [-0.3940, -0.6610]]).requires_grad_(True)

WK = torch.tensor([[-0.4109,  0.5777],
        [-0.1162, -0.1661]]).requires_grad_(True)

WV = torch.tensor([[-0.2045,  0.1210],
        [-0.1712, -0.4462]]).requires_grad_(True)

Now, let's transform our embedded tokens using the transformation matrices WQ, WK and WV.

In [8]:
Q = embedded_tokens @ WQ
K = embedded_tokens @ WK
V = embedded_tokens @ WV

In [9]:
Q

tensor([[0.2261, 0.7422],
        [0.1699, 0.2857],
        [0.1699, 0.2857]], grad_fn=<MmBackward0>)

In [10]:
K

tensor([[ 0.4986, -0.5362],
        [ 0.0509,  0.0705],
        [ 0.0509,  0.0705]], grad_fn=<MmBackward0>)

In [11]:
V

tensor([[0.3048, 0.0934],
        [0.0742, 0.1921],
        [0.0742, 0.1921]], grad_fn=<MmBackward0>)
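Products such as `embedded_tokens @ WQ` are what a bias-free linear layer computes. The sketch below builds the query projection with `nn.Linear` (the key and value projections work the same way). `nn.Linear` multiplies its input by the *transpose* of its `weight`, so we copy `WQ.T` into it.

```python
import torch
import torch.nn as nn

embedded_tokens = torch.tensor([[-1.0720, -0.5001],
                                [-0.0020, -0.4311],
                                [-0.0020, -0.4311]])
WQ = torch.tensor([[-0.0271, -0.3840],
                   [-0.3940, -0.6610]])

# nn.Linear computes x @ weight.T (+ bias), so store WQ.T as its weight
# and disable the bias to reproduce embedded_tokens @ WQ.
q_proj = nn.Linear(2, 2, bias=False)
with torch.no_grad():
    q_proj.weight.copy_(WQ.T)

Q = q_proj(embedded_tokens)
print(Q)
```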

In [12]:
d_k = Q.size()[-1]

$$ Q = \begin{bmatrix} Q_{1,1} & Q_{1,2} \\ Q_{2,1} & Q_{2,2} \\ Q_{3,1} & Q_{3,2} \end{bmatrix}$$

$$ Q = \begin{bmatrix} \vec{Q_{1}} \\ \vec{Q_{2}} \\ \vec{Q_{3}} \end{bmatrix}$$

where $\vec{Q_{i}} = (Q_{i,1}, Q_{i,2})$ is the $i^{th}$ row of $Q$. Similarly,

$$ K = \begin{bmatrix} K_{1,1} & K_{1,2} \\ K_{2,1} & K_{2,2} \\ K_{3,1} & K_{3,2} \end{bmatrix} = \begin{bmatrix} \vec{K_{1}} \\ \vec{K_{2}} \\ \vec{K_{3}} \end{bmatrix}$$

$$ K^{T} = \begin{bmatrix} K_{1,1} & K_{2,1} & K_{3,1} \\ K_{1,2} & K_{2,2} & K_{3,2} \end{bmatrix} = \begin{bmatrix} \vec{K_{1}}^{T} & \vec{K_{2}}^{T} & \vec{K_{3}}^{T} \end{bmatrix}$$

$$ V = \begin{bmatrix} V_{1,1} & V_{1,2} \\ V_{2,1} & V_{2,2}  \\ V_{3,1} & V_{3,2} \end{bmatrix}$$

Attention logits are given by $QK^{T}$

$$ QK^{T} = \begin{bmatrix} Q_{1,1}\times K_{1,1} + Q_{1,2}\times K_{1,2} &
Q_{1,1}\times K_{2,1} + Q_{1,2}\times K_{2,2} & 
Q_{1,1}\times K_{3,1} + Q_{1,2}\times K_{3,2} \\ 
Q_{2,1}\times K_{1,1} + Q_{2,2}\times K_{1,2} & 
Q_{2,1}\times K_{2,1} + Q_{2,2}\times K_{2,2} & 
Q_{2,1}\times K_{3,1} + Q_{2,2}\times K_{3,2} \\
Q_{3,1}\times K_{1,1} + Q_{3,2}\times K_{1,2} & 
Q_{3,1}\times K_{2,1} + Q_{3,2}\times K_{2,2} & 
Q_{3,1}\times K_{3,1} + Q_{3,2}\times K_{3,2} \end{bmatrix}$$

Equivalently,

$$ QK^{T} = \begin{bmatrix} \vec{Q_{1}}\cdot \vec{K_{1}} & 
\vec{Q_{1}}\cdot \vec{K_{2}} & 
\vec{Q_{1}}\cdot \vec{K_{3}} \\ 
\vec{Q_{2}}\cdot \vec{K_{1}} & 
\vec{Q_{2}}\cdot \vec{K_{2}} & 
\vec{Q_{2}}\cdot \vec{K_{3}} \\
\vec{Q_{3}}\cdot \vec{K_{1}} & 
\vec{Q_{3}}\cdot \vec{K_{2}} & 
\vec{Q_{3}}\cdot \vec{K_{3}} \end{bmatrix}$$

In [30]:
attn_logits = torch.matmul(Q, K.T)
attn_logits

tensor([[-0.2853,  0.0638,  0.0638],
        [-0.0685,  0.0288,  0.0288],
        [-0.0685,  0.0288,  0.0288]], grad_fn=<MmBackward0>)
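The following sketch confirms that `torch.matmul(Q, K.T)` computes exactly the dot products written out in the matrix above. It rebuilds the logits entry by entry from the rounded printed values of `Q` and `K`. For example, $\vec{Q_{1}}\cdot\vec{K_{1}} = 0.2261 \times 0.4986 + 0.7422 \times (-0.5362) \approx -0.285$.

```python
import torch

Q = torch.tensor([[0.2261, 0.7422],
                  [0.1699, 0.2857],
                  [0.1699, 0.2857]])
K = torch.tensor([[0.4986, -0.5362],
                  [0.0509,  0.0705],
                  [0.0509,  0.0705]])

# Entry (i, j) is the dot product of query row i with key row j.
n = Q.size(0)
logits_loop = torch.empty(n, n)
for i in range(n):
    for j in range(n):
        logits_loop[i, j] = torch.dot(Q[i], K[j])

# The same matrix in a single call.
logits_matmul = Q @ K.T
print(logits_loop)
```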

What is $\vec{Q_{1}}$? It is a representation of the $1^{st}$ token (the first row of $Q$). <br>
What is $\vec{K_{1}}$? It is another representation of the $1^{st}$ token (the first row of $K$). <br>

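Each element of the `attn_logits` matrix is the dot product of a query vector $\vec{Q_{i}}$ and a key vector $\vec{K_{j}}$. This dot product serves as a measure of how similar the two vectors are.

Each row of the `attn_logits` matrix shows how similar one query is to all the keys. For example, row 1 is (similarity between $\vec{Q_{1}}$ and $\vec{K_{1}}$, similarity between $\vec{Q_{1}}$ and $\vec{K_{2}}$, similarity between $\vec{Q_{1}}$ and $\vec{K_{3}}$).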

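The next step is to scale the `attn_logits` matrix by dividing it by $\sqrt{d_{k}}$, where $d_{k}$ is the dimension of the query and key vectors ($d_{k} = 2$ here). According to the paper, for large $d_{k}$ the dot products grow large in magnitude. This pushes the softmax into regions where it has extremely small gradients, and dividing by $\sqrt{d_{k}}$ counteracts the effect.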


$$ QK^{T}/\sqrt{d_{k}} = \begin{bmatrix} \frac{\vec{Q_{1}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}} & 
\frac{\vec{Q_{1}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}} & 
\frac{\vec{Q_{1}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}} \\ 
\frac{\vec{Q_{2}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}} & 
\frac{\vec{Q_{2}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}} & 
\frac{\vec{Q_{2}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}} \\
\frac{\vec{Q_{3}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}} & 
\frac{\vec{Q_{3}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}} & 
\frac{\vec{Q_{3}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}} \end{bmatrix}$$

In [31]:
scaled_attention_logits = torch.matmul(Q, K.T)/ math.sqrt(d_k)

In [32]:
scaled_attention_logits

tensor([[-0.2017,  0.0451,  0.0451],
        [-0.0484,  0.0204,  0.0204],
        [-0.0484,  0.0204,  0.0204]], grad_fn=<DivBackward0>)
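Why divide by $\sqrt{d_{k}}$? The paper's argument is that if the components of $q$ and $k$ are independent random variables with mean 0 and variance 1, then $q \cdot k$ has mean 0 and variance $d_{k}$. The sketch below checks this empirically with random Gaussian vectors. The sample size and the $d_{k}$ values are arbitrary choices. The standard deviation of the raw dot products should grow like $\sqrt{d_{k}}$, while the scaled ones stay near 1.

```python
import math
import torch

torch.manual_seed(0)
n = 10_000  # number of random query/key pairs per dimension
stds = {}

for d_k in (2, 64, 512):
    # Queries and keys with independent, zero-mean, unit-variance components.
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    dots = (q * k).sum(dim=-1)            # one dot product per pair
    scaled = dots / math.sqrt(d_k)
    stds[d_k] = (dots.std().item(), scaled.std().item())
    print(f"d_k={d_k:4d}  std(q.k)={stds[d_k][0]:.2f}  "
          f"std(q.k/sqrt(d_k))={stds[d_k][1]:.2f}")
```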

Finally, we apply the softmax operation to each row. This normalizes each row so that its values are positive and sum to 1.

$$ \mathrm{softmax}\left(QK^{T}/\sqrt{d_{k}}\right) = \begin{bmatrix} \frac{\exp\left(\frac{\vec{Q_{1}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}}\right)}{\exp\left(\frac{\vec{Q_{1}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{1}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{1}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}}\right)} & 
\frac{\exp\left(\frac{\vec{Q_{1}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}}\right)}{\exp\left(\frac{\vec{Q_{1}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{1}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{1}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}}\right)} & 
\frac{\exp\left(\frac{\vec{Q_{1}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}}\right)}{\exp\left(\frac{\vec{Q_{1}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{1}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{1}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}}\right)} \\ 
\frac{\exp\left(\frac{\vec{Q_{2}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}}\right)}{\exp\left(\frac{\vec{Q_{2}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{2}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{2}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}}\right)} & 
\frac{\exp\left(\frac{\vec{Q_{2}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}}\right)}{\exp\left(\frac{\vec{Q_{2}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{2}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{2}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}}\right)} & 
\frac{\exp\left(\frac{\vec{Q_{2}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}}\right)}{\exp\left(\frac{\vec{Q_{2}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{2}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{2}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}}\right)} \\
\frac{\exp\left(\frac{\vec{Q_{3}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}}\right)}{\exp\left(\frac{\vec{Q_{3}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{3}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{3}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}}\right)} & 
\frac{\exp\left(\frac{\vec{Q_{3}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}}\right)}{\exp\left(\frac{\vec{Q_{3}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{3}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{3}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}}\right)} & 
\frac{\exp\left(\frac{\vec{Q_{3}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}}\right)}{\exp\left(\frac{\vec{Q_{3}}\cdot \vec{K_{1}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{3}}\cdot \vec{K_{2}}}{\sqrt{d_{k}}}\right) + \exp\left(\frac{\vec{Q_{3}}\cdot \vec{K_{3}}}{\sqrt{d_{k}}}\right)} \end{bmatrix}$$
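The row-wise softmax above can be written by hand in a few lines. This sketch starts from the rounded printed `scaled_attention_logits`. It adds a standard numerical-stability trick that the notebook does not mention: subtract each row's maximum before exponentiating. The result is unchanged because the common factor cancels between numerator and denominator. The helper name `row_softmax` is hypothetical.

```python
import torch
import torch.nn.functional as F

scaled_attention_logits = torch.tensor([[-0.2017, 0.0451, 0.0451],
                                        [-0.0484, 0.0204, 0.0204],
                                        [-0.0484, 0.0204, 0.0204]])

def row_softmax(x):
    # Subtracting each row's max cancels in the ratio below but keeps
    # exp() from overflowing when the logits are large.
    x = x - x.max(dim=-1, keepdim=True).values
    e = torch.exp(x)
    # Divide each row by its own sum so that every row sums to 1.
    return e / e.sum(dim=-1, keepdim=True)

manual = row_softmax(scaled_attention_logits)
builtin = F.softmax(scaled_attention_logits, dim=-1)
print(manual)
```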

In [34]:
attention_weights = F.softmax(scaled_attention_logits, dim=-1)

In [35]:
attention_weights

tensor([[0.2809, 0.3595, 0.3595],
        [0.3182, 0.3409, 0.3409],
        [0.3182, 0.3409, 0.3409]], grad_fn=<SoftmaxBackward0>)

Note that the values in each row add up to 1.

In [36]:
attention_weights[0].sum()

tensor(1., grad_fn=<SumBackward0>)

In [37]:
attention_weights[1].sum()

tensor(1., grad_fn=<SumBackward0>)

In [38]:
attention_weights[2].sum()

tensor(1., grad_fn=<SumBackward0>)
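Instead of checking each row separately, you can compute all row sums at once by summing over the last dimension (the keys). The short sketch below rebuilds `attention_weights` from the rounded printed logits.

```python
import torch
import torch.nn.functional as F

scaled_attention_logits = torch.tensor([[-0.2017, 0.0451, 0.0451],
                                        [-0.0484, 0.0204, 0.0204],
                                        [-0.0484, 0.0204, 0.0204]])
attention_weights = F.softmax(scaled_attention_logits, dim=-1)

# One total per query row: sum over the key dimension.
row_sums = attention_weights.sum(dim=-1)
print(row_sums)
assert torch.allclose(row_sums, torch.ones_like(row_sums))
```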

In [22]:
attention_weights @ V

tensor([[0.1390, 0.1644],
        [0.1476, 0.1607],
        [0.1476, 0.1607]], grad_fn=<MmBackward0>)

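This product gives the output values of the attention operation. Each output row is a weighted sum of the value vectors (the rows of $V$), weighted by the corresponding row of `attention_weights`. The `@` operator is equivalent to `torch.matmul`: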

In [23]:
torch.matmul(attention_weights, V)

tensor([[0.1390, 0.1644],
        [0.1476, 0.1607],
        [0.1476, 0.1607]], grad_fn=<MmBackward0>)
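Each output row is a weighted sum of the rows of $V$: $\text{output}_{i} = \sum_{j} w_{i,j}\,\vec{V_{j}}$. The sketch below computes this with explicit loops from the rounded printed values and compares the result with the matrix product. It also compares against `torch.nn.functional.scaled_dot_product_attention`, which exists in PyTorch 2.0 and later and uses the same default $1/\sqrt{d_{k}}$ scale. Guarding that comparison with `hasattr` is an assumption of this sketch, so that it still runs on older versions.

```python
import math
import torch
import torch.nn.functional as F

Q = torch.tensor([[0.2261, 0.7422], [0.1699, 0.2857], [0.1699, 0.2857]])
K = torch.tensor([[0.4986, -0.5362], [0.0509, 0.0705], [0.0509, 0.0705]])
V = torch.tensor([[0.3048, 0.0934], [0.0742, 0.1921], [0.0742, 0.1921]])

weights = F.softmax(Q @ K.T / math.sqrt(Q.size(-1)), dim=-1)

# Output row i = sum over j of weights[i, j] * V[j].
out_loop = torch.zeros(Q.size(0), V.size(-1))
for i in range(Q.size(0)):
    for j in range(K.size(0)):
        out_loop[i] += weights[i, j] * V[j]

out_matmul = weights @ V

# Fused implementation in PyTorch >= 2.0 (default scale is 1/sqrt(d_k)).
if hasattr(F, "scaled_dot_product_attention"):
    # Add a leading batch dimension: inputs are (batch, seq_len, d).
    out_fused = F.scaled_dot_product_attention(Q[None], K[None], V[None])[0]
    assert torch.allclose(out_fused, out_matmul, atol=1e-5)
print(out_matmul)
```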