# Self-Attention 

Ref: https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html

We can think of self-attention as a mechanism that enhances the information content of an input embedding by including information about the input’s context. In other words, the self-attention mechanism enables the model to weigh the importance of different elements in an input sequence and dynamically adjust their influence on the output.

### Embedding an Input Sentence

In [1]:
sentence = 'Life is short, eat dessert first'

dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}


Next, we use this dictionary to assign an integer index to each word:

In [2]:
import torch

sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print(sentence_int)

tensor([0, 4, 5, 2, 1, 3])
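As a quick sanity check of the mapping, the sketch below encodes the sentence and decodes it back. The inverse dictionary `inv_dc` is a name introduced here for illustration and is not part of the original notebook.

```python
sentence = 'Life is short, eat dessert first'
words = sentence.replace(',', '').split()

# Word -> index, built from the alphabetically sorted words as above.
dc = {s: i for i, s in enumerate(sorted(words))}

# Index -> word (hypothetical helper for this illustration).
inv_dc = {i: s for s, i in dc.items()}

encoded = [dc[w] for w in words]
decoded = [inv_dc[i] for i in encoded]

print(encoded)           # [0, 4, 5, 2, 1, 3]
print(decoded == words)  # True
```

The round trip only works this simply because every word in the sentence is unique.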


Now, using the integer-vector representation of the input sentence, we can use an embedding layer to encode the inputs into a real-vector embedding. Here, we will use a 16-dimensional embedding such that each input word is represented by a 16-dimensional vector. Since the sentence consists of 6 words, this will result in a **6×16**-dimensional embedding:

In [5]:
torch.manual_seed(123)

d = 16 # size of word vector embedding
embed = torch.nn.Embedding(6, d) # 6 vocabulary entries, d-dimensional vectors
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence)
print(embedded_sentence.shape)

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])
torch.Size([6, 16])
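It may help to see that an embedding layer is just a trainable lookup table. The minimal sketch below rebuilds the same toy setup and checks that calling the layer gives the same result as indexing rows of its weight matrix. The names `looked_up` and `indexed` are introduced here for illustration.

```python
import torch

torch.manual_seed(123)

# Same toy setup as above: 6 vocabulary entries, 16-dimensional vectors.
embed = torch.nn.Embedding(6, 16)
sentence_int = torch.tensor([0, 4, 5, 2, 1, 3])

# Row i of embed.weight is the vector assigned to word index i.
looked_up = embed(sentence_int).detach()
indexed = embed.weight[sentence_int].detach()

print(torch.equal(looked_up, indexed))  # True
print(looked_up.shape)                  # torch.Size([6, 16])
```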


## Scaled dot-product self attention mechanism

Self-attention utilizes three weight matrices, referred to as $\mathbf{W_{q}}$, $\mathbf{W_{k}}$, and $\mathbf{W_{v}}$, which are adjusted as model parameters during training. These matrices serve to project the inputs into the query, key, and value components of the sequence, respectively.

The respective query, key and value sequences are obtained via matrix multiplication between the weight matrices $\mathbf{W}$ and the embedded inputs $\mathbf{x}$:

- Query sequence: $\mathbf{q^{(i)} = W_{q}x^{(i)}}$ for $i \in [1,T]$
- Key sequence: $\mathbf{k^{(i)} = W_{k}x^{(i)}}$ for $i \in [1,T]$
- Value sequence: $\mathbf{v^{(i)} = W_{v}x^{(i)}}$ for $i \in [1,T]$

Here, $i$ refers to the token index position in the input sequence, which has length $T$.

Let $d$ be the size of each word vector $\mathbf{x}$. The projection matrices $\mathbf{W_{q}}$ and $\mathbf{W_{k}}$ have shape $d_{k} \times d$, while $\mathbf{W_{v}}$ has shape $d_{v} \times d$. Consequently, both $\mathbf{q^{(i)}}$ and $\mathbf{k^{(i)}}$ are vectors of dimension $d_{k}$, and $\mathbf{v^{(i)}}$ is a vector of dimension $d_{v}$.

<img src="./img/attention-matrices.png" alt="alt text"  style="height: 50%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 50%;"/>

Since we are computing the dot product between the query and key vectors, these two vectors have to contain the same number of elements ($d_{q} = d_{k}$). However, the number of elements in the value vector $\mathbf{v^{(i)}}$, which determines the size of the resulting context vector, is arbitrary.

In [6]:
torch.manual_seed(123)

d = embedded_sentence.shape[1]

d_q, d_k, d_v = 24, 24, 28

W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))

### Computing the Unnormalized Attention Weights

Now, let’s suppose we are interested in computing the attention vector for the second input element, which acts as the query here:

<img src="./img/query.png" alt="alt text" 
  style="height: 50%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 50%;"/>

In [8]:
x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

torch.Size([24])
torch.Size([24])
torch.Size([28])


We can then generalize this to compute the remaining key and value elements for all inputs as well, since we will need them in the next step when we compute the unnormalized attention weights $\omega$:

In [9]:
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])
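The same projections can also be written as a single product with the inputs stored row-wise. The sketch below uses random stand-in tensors with the same shapes as above to check that both forms agree. `X`, `W_key`, and `W_value` here are placeholders, not the notebook's actual values.

```python
import torch

torch.manual_seed(0)

T, d, d_k, d_v = 6, 16, 24, 28
X = torch.randn(T, d)         # stand-in for embedded_sentence
W_key = torch.rand(d_k, d)    # stand-in projection matrices
W_value = torch.rand(d_v, d)

# Column-wise form used above: project X^T, then transpose back.
keys_a = W_key.matmul(X.T).T
values_a = W_value.matmul(X.T).T

# Row-wise form: each row of X is multiplied by W^T.
keys_b = X @ W_key.T
values_b = X @ W_value.T

print(keys_b.shape, values_b.shape)
print(torch.allclose(keys_a, keys_b, atol=1e-5))
print(torch.allclose(values_a, values_b, atol=1e-5))
```

The row-wise form is the one most commonly seen in library code, since batches of sequences are usually stored with tokens along the rows.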


### Unnormalized attention weights $\omega$

<img src="./img/omega.png" alt="alt text" 
  style="height: 50%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 50%;"/>

Here, $\omega_{i,j}$ is the dot product between the $i$-th query and the $j$-th key, $\omega_{i,j} = \mathbf{q^{(i)^{T}}k^{(j)}}$.

The unnormalized attention weight for the query and the 5th input element (index 4, since indexing is zero-based) will be:

In [11]:
omega_24 = query_2.dot(keys[4])
print(omega_24)

tensor(11.1466, grad_fn=<DotBackward0>)


Let's compute the $\omega$ values for all input tokens as illustrated before:

In [12]:
omega_2 = query_2.matmul(keys.T)
print(omega_2)

tensor([ 8.5808, -7.6597,  3.2558,  1.0395, 11.1466, -0.4800],
       grad_fn=<SqueezeBackward4>)
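The computation above covers a single query. As a sketch of how it extends to every token at once, the full matrix of unnormalized weights can be obtained with one matrix product. The tensors below are random stand-ins with the same shapes as in this notebook, and the names `queries` and `omega` are introduced here for illustration.

```python
import torch

torch.manual_seed(0)

T, d, d_k = 6, 16, 24
X = torch.randn(T, d)         # stand-in for embedded_sentence
W_query = torch.rand(d_k, d)
W_key = torch.rand(d_k, d)

queries = X @ W_query.T       # shape (T, d_k)
keys = X @ W_key.T            # shape (T, d_k)

# omega[i, j] is the dot product of query i with key j.
omega = queries @ keys.T      # shape (T, T)

# Row 1 corresponds to the per-query computation for the second token.
omega_2 = queries[1].matmul(keys.T)

print(omega.shape)
print(torch.allclose(omega[1], omega_2, rtol=1e-4, atol=1e-4))
```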


## Computing the Attention Scores

The subsequent step in self-attention is to normalize the unnormalized attention weights, $\omega$, to obtain the normalized attention weights, $\alpha$, by applying the softmax function. Additionally, $\frac{1}{\sqrt{d_{k}}}$ is used to scale $\omega$ before normalizing it through the softmax function, as shown below:


<img src="./img/attention-scores.png" alt="alt text" 
  style="height: 80%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 80%;"/>

The scaling by $\frac{1}{\sqrt{d_{k}}}$ keeps the dot products at approximately the same magnitude regardless of the vector dimension $d_{k}$. This helps prevent the attention weights from becoming too small or too large, which could lead to numerical instability or affect the model’s ability to converge during training.
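One way to see the effect of the scaling is a small simulation. Assuming, purely for illustration, that the components of the query and key vectors are independent with zero mean and unit variance, their dot product has a standard deviation of about $\sqrt{d_{k}}$, and dividing by $\sqrt{d_{k}}$ brings it back to about 1. The sample size `n` and the random vectors below are assumptions of this sketch, not values from the notebook.

```python
import torch

torch.manual_seed(0)

d_k = 24
n = 10_000

# Assumption for illustration: components are i.i.d. standard normal.
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

dots = (q * k).sum(dim=1)   # n independent dot products
scaled = dots / d_k**0.5

print(dots.std())    # roughly sqrt(24), i.e. about 4.9
print(scaled.std())  # roughly 1
```

Without the scaling, the spread of the softmax inputs would grow with $d_{k}$, making the softmax output increasingly peaked.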

In [13]:
import torch.nn.functional as F

attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
print(attention_weights_2)

tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458],
       grad_fn=<SoftmaxBackward0>)
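To make the softmax step explicit, the sketch below recomputes the attention weights by hand from the printed $\omega$ values, using only the standard library. Because the inputs are rounded to four decimals, the result should match the output above only up to rounding.

```python
import math

# Unnormalized weights copied from the printed output above.
omega_2 = [8.5808, -7.6597, 3.2558, 1.0395, 11.1466, -0.4800]
d_k = 24

scaled = [w / math.sqrt(d_k) for w in omega_2]

# Subtract the maximum before exponentiating for numerical stability;
# this does not change the result of the softmax.
m = max(scaled)
exps = [math.exp(s - m) for s in scaled]
total = sum(exps)
alpha_2 = [e / total for e in exps]

print([round(a, 4) for a in alpha_2])
print(sum(alpha_2))  # sums to 1 up to floating-point error
```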


### Attention-weighted version of original query $\mathbf{x}$

Finally, the last step is to compute the context vector $\mathbf{z^{(2)}}$, which is an attention-weighted version of our original query input $\mathbf{x^{(2)}}$, including all the other input elements as its context via the attention weights:

<img src="./img/context-vector.png" alt="alt text" 
  style="height: 80%;   
    display: block;
    margin-left: auto;
    margin-right: auto;
    width: 80%;"/>

In [14]:
context_vector_2 = attention_weights_2.matmul(values)

print(context_vector_2.shape)
print(context_vector_2)

torch.Size([28])
tensor([-1.5993,  0.0156,  1.2670,  0.0032, -0.6460, -1.1407, -0.4908, -1.4632,
         0.4747,  1.1926,  0.4506, -0.7110,  0.0602,  0.7125, -0.1628, -2.0184,
         0.3838, -2.1188, -0.8136, -1.5694,  0.7934, -0.2911, -1.3640, -0.2366,
        -0.9564, -0.5265,  0.0624,  1.7084], grad_fn=<SqueezeBackward4>)


Note that this output vector has more dimensions ($d_{v}=28$) than the original input vector ($d=16$) since we specified $d_{v}>d$ earlier; however, the embedding size choice is arbitrary.
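Putting the steps together, a minimal sketch of scaled dot-product self-attention for the whole sequence could look as follows. The function name `self_attention` and the random stand-in tensors are assumptions made for this illustration. The notebook itself only computes the context vector for the second token.

```python
import torch
import torch.nn.functional as F


def self_attention(x, W_query, W_key, W_value):
    """Scaled dot-product self-attention for a whole sequence.

    x: (T, d) inputs; W_query, W_key: (d_k, d); W_value: (d_v, d).
    Returns the (T, d_v) context vectors and the (T, T) attention weights.
    """
    queries = x @ W_query.T
    keys = x @ W_key.T
    values = x @ W_value.T

    d_k = keys.shape[-1]
    omega = queries @ keys.T                     # unnormalized weights
    alpha = F.softmax(omega / d_k**0.5, dim=-1)  # each row sums to 1
    return alpha @ values, alpha


torch.manual_seed(0)
x = torch.randn(6, 16)          # stand-in for embedded_sentence
W_query = torch.rand(24, 16)
W_key = torch.rand(24, 16)
W_value = torch.rand(28, 16)

context, alpha = self_attention(x, W_query, W_key, W_value)
print(context.shape)      # torch.Size([6, 28])
print(alpha.sum(dim=-1))  # every entry is close to 1
```

Row `context[1]` corresponds to the context vector computed step by step above for the second token, here with stand-in values.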