## Coding attention mechanisms


### Implement four different variants of attention mechanisms:
* Simplified self-attention
* Self-attention
* Causal attention
* Multi-head attention

#### A simple self-attention mechanism without trainable weights

The goal of self-attention is to compute a context vector for each input element that combines information from all other input elements. A context vector can be interpreted as an enriched embedding vector, which is created by incorporating information from all other elements in the sequence.

Consider the following input sentence, which has already been embedded into three-dimensional vectors:


In [2]:
%pip install torch numpy


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [3]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # your     (x1)
     [0.55, 0.87, 0.66],  # journey  (x2)
     [0.57, 0.85, 0.64],  # starts   (x3)
     [0.22, 0.58, 0.33],  # with     (x4)
     [0.77, 0.25, 0.10],  # one      (x5)
     [0.05, 0.80, 0.55]]  # step     (x6)
)

**Step1**. To calculate the intermediate attention scores for the second input element, we compute the dot product of x2 with every input token (including x2 itself):

In [4]:
# the second input token serves as the query
query = inputs[1]

attn_scores_for_2 = torch.empty(inputs.shape[0])

for i, x_i in enumerate(inputs):
    attn_scores_for_2[i] = torch.dot(x_i, query)

print(attn_scores_for_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


The dot product is a measure of similarity because it quantifies how closely two vectors are aligned: a higher dot product indicates a greater degree of alignment or similarity between the vectors. In the context of self-attention mechanisms, the dot product determines the extent to which each element in the sequence focuses on, or **"attends to"**, any other element: the higher the dot product, the higher the similarity and attention score between two elements.
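
As a minimal sketch of the point above, the dot product can be written out as a sum of element-wise products, and all six scores for x2 can be obtained at once with a matrix-vector product. The names `manual_dot` and `attn_scores_2` are illustrative and do not appear in the cells of this notebook.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # your     (x1)
     [0.55, 0.87, 0.66],  # journey  (x2)
     [0.57, 0.85, 0.64],  # starts   (x3)
     [0.22, 0.58, 0.33],  # with     (x4)
     [0.77, 0.25, 0.10],  # one      (x5)
     [0.05, 0.80, 0.55]]  # step     (x6)
)
query = inputs[1]  # x2

# A dot product is the sum of element-wise products:
# 0.43*0.55 + 0.15*0.87 + 0.89*0.66
manual_dot = (inputs[0] * query).sum()
print(manual_dot, torch.dot(inputs[0], query))

# Matrix-vector product: one dot product per row of `inputs`,
# equivalent to the explicit for-loop over tokens.
attn_scores_2 = inputs @ query
print(attn_scores_2)
```

The loop version makes the per-token computation explicit; the matrix form is what one would typically use in practice because it avoids Python-level iteration.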

**Step2**. Normalize each of the attention scores:


In [5]:
attn_weights_2_tmp = attn_scores_for_2 / attn_scores_for_2.sum()

print("Attention weights: ", attn_weights_2_tmp)
print("Sum: ", attn_weights_2_tmp.sum())

Attention weights:  tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum:  tensor(1.0000)


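**In practice, it is more common to use the softmax function for normalization.** Softmax handles extreme values better and guarantees that all attention weights are positive.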

In [9]:
attn_weights_2_softmax = torch.softmax(attn_scores_for_2, dim=0)

print("Attention weights: ", attn_weights_2_softmax)
print("Sum: ", attn_weights_2_softmax.sum())

Attention weights:  tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum:  tensor(1.)
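
To make explicit what softmax computes, here is a rough sketch of two hand-written versions. The helpers `softmax_naive` and `softmax_stable` are hypothetical names introduced only for illustration; the stable variant subtracts the maximum before exponentiating, a common trick to avoid overflow when scores are large.

```python
import torch

def softmax_naive(x):
    # Exponentiate each score, then divide by the sum of exponentials.
    exp = torch.exp(x)
    return exp / exp.sum(dim=0)

def softmax_stable(x):
    # Subtracting the maximum leaves the result unchanged mathematically
    # but keeps torch.exp from overflowing on large inputs.
    exp = torch.exp(x - x.max())
    return exp / exp.sum(dim=0)

# Attention scores for x2, copied from the printed output of Step1
scores = torch.tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

print(softmax_naive(scores))
print(softmax_stable(scores))
print(torch.softmax(scores, dim=0))
```

For scores of this magnitude all three calls should agree; the built-in `torch.softmax` remains the safer choice in real code.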


**Step3**. Calculate the context vector by multiplying each embedded input token by its corresponding attention weight and then summing the resulting vectors. The context vector is therefore the attention-weighted sum of all input vectors.

In [19]:
query = inputs[1] # x2

print(query.shape)
context_vec_2 = torch.zeros(query.shape)

for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2_softmax[i] * x_i

print(context_vec_2)

torch.Size([3])
tensor([0.4419, 0.6515, 0.5683])
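
The weighted sum can also be expressed as a single vector-matrix product. The following is a sketch under the same inputs as above; `attn_weights_2` and `context_vec_2_matmul` are illustrative names chosen here, not part of the original cells.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # your     (x1)
     [0.55, 0.87, 0.66],  # journey  (x2)
     [0.57, 0.85, 0.64],  # starts   (x3)
     [0.22, 0.58, 0.33],  # with     (x4)
     [0.77, 0.25, 0.10],  # one      (x5)
     [0.05, 0.80, 0.55]]  # step     (x6)
)
query = inputs[1]  # x2

# Steps 1 and 2: dot-product scores, then softmax normalization
attn_weights_2 = torch.softmax(inputs @ query, dim=0)

# Step 3: weights of shape (6,) times inputs of shape (6, 3)
# gives the weighted sum of the rows, a vector of shape (3,)
context_vec_2_matmul = attn_weights_2 @ inputs
print(context_vec_2_matmul)
```

This should reproduce the context vector computed by the loop, since multiplying a weight vector by a matrix is exactly a weighted sum of the matrix rows.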


To recap, here are all of the calculations worked out by hand:
```
attn_scores_for_2 = inputs @ x2 =
    [0.43, 0.15, 0.89]
    [0.55, 0.87, 0.66] (x2)
    [0.57, 0.85, 0.64]
    [0.22, 0.58, 0.33]
    [0.77, 0.25, 0.10]
    [0.05, 0.80, 0.55]         @   [0.55, 0.87, 0.66] (x2)  =
                                   [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865]

attn_weights_2_softmax = softmax([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865]) =
                                 [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581]

context_vec_2 = sum(inputs[i] * attn_weights_2_softmax[i])
              = sum(
                    [0.43, 0.15, 0.89] * 0.1385,
                    [0.55, 0.87, 0.66] * 0.2379,
                    [0.57, 0.85, 0.64] * 0.2333,
                    [0.22, 0.58, 0.33] * 0.1240,
                    [0.77, 0.25, 0.10] * 0.1082,
                    [0.05, 0.80, 0.55] * 0.1581)
              = sum(
                    [0.0596, 0.0208, 0.1233],
                    [0.1308, 0.2070, 0.1570],
                    [0.1330, 0.1983, 0.1493],
                    [0.0273, 0.0719, 0.0409],
                    [0.0833, 0.0270, 0.0108],
                    [0.0079, 0.1265, 0.0870])   =   [0.4419, 0.6515, 0.5683]
```
