## Intro

1. Concept of attention was introduced in [Bahdanau paper](https://arxiv.org/abs/1409.0473). Idea was to improve Recurrent Neural networks used in the context of machine translation, so that we do not translate word-by-word, but rather have access to all sequence elements at each time step.
   ![sentence translation](img/sentence.png)
2. Moreover, attention is selective and attributes different weights to different words in the sequence. Using the context notation:
   $$
   c_i = \sum\limits_{j=1}^{T_x} \alpha_{ij} h_j
   $$
3. Transformer architecture was later introduced in ["Attention is all you need" paper](https://arxiv.org/pdf/1706.03762), removing the need for RNNs altogether by utilising the concept of self-attention.
   ![self attention](img/self_attention.png)
4. Self-attention essentially adds additional context information to each input. This context information is used by the model in order to adjust the relative impact of each word on the resulting output.
5. There are many types of self-attention mechanisms, the original one is named "scaled dot-product attention".

## Implementation Steps

### 1. Embedding

Let's consider a simple sentence:

In [164]:
sentence = 'Life is short, eat dessert first'

Let's pretend that our vocabulary consists *only* of the words in the sentence above. We construct it via:


In [165]:
dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print(dc)

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}


Now we can use this dictionary to assign an index to each word:

In [166]:
import torch

sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print(sentence_int)

tensor([0, 4, 5, 2, 1, 3])


By now, we already have a vocabulary, `dc`, and a vector representation of the sentence, `sentence_int`. Let's use PyTorch's `Embedding` layer to construct an embedding for the sentence:

In [167]:
torch.manual_seed(123)
embed = torch.nn.Embedding(6, 16)
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence)
print(embedded_sentence.shape)

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])
torch.Size([6, 16])


### 2. Weight matrices

Self-attention relies on query, key, and value weight matrices.

We'll denote them by $W_q$, $W_k$, and $W_v$, respectively.

Two important things to keep mind:
1. These are model parameters, therefore, are being **adjusted during training**.
2. They are multiplied with inputs in order to obtain query, key, and value sequences via:
   - Query sequence: $q^{(i)} = W_q x^{(i)}, \; i = \overline{0,T}$
   - Key sequence: $k^{(i)} = W_k x^{(i)}, \; i = \overline{0,T}$
   - Value sequence: $v^{(i)} = W_v x^{(i)}, \; i = \overline{0,T}$
     
   Here $T$ is the length of the input sequence.

A visual representation:

![attention matrices](img/attention-matrices.png)

**Dimensions** of these are:
- $x^{(i)}$ have length $d$
- $W_q$ and $W_k$ are $d_k \times d$
- $W_v$ is $d_v \times d$
- $q^{(i)}$ and $q^{(i)}$ have lengths $d_k$
- $q^{(i)}$ has length $d_v$

**Note:** since we will be computing a product of query and key vectors $q^{(i)}$ and $k^{(i)}$, their dimensions are identical.

Now the code:

In [168]:
torch.manual_seed(123)

d = embedded_sentence.shape[1]

d_q, d_k, d_v = 24, 24, 28

W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))

Let's compute the sequences for the **second input element**. It will act as a **query element** for subsequent computations, that is why it's shaded:

![second_input](img/second_input_computation.png)

In [169]:
x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

torch.Size([24])
torch.Size([24])
torch.Size([28])


We can generalize the computation of key and value matrices for all sequence elements:

In [170]:
keys = W_key.matmul(embedded_sentence.T).T
values = W_value.matmul(embedded_sentence.T).T

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])


### 3. Attention weights

Now we can compute **unnormalized attention weights** $\omega$. These are defined as a product of query and key vectors:
$$
\omega_{ij} = q^{(i)T} k^{(j)}
$$

![attention weights](img/attention-weights.png)

For instance, the computation of our query element and fifth key element is performed via:

In [171]:
omega_25 = query_2.dot(keys[4])
print(omega_25)

tensor(11.1466, grad_fn=<DotBackward0>)


In matrix form:

In [172]:
omega_2 = query_2.matmul(keys.T)
print(omega_2)

tensor([ 8.5808, -7.6597,  3.2558,  1.0395, 11.1466, -0.4800],
       grad_fn=<SqueezeBackward4>)


### 4. Attention scores

Now we need to normalize the unnormalized attention weights $\omega$. We will denote **normalized attention weights** by $\alpha_{ij}$.

Normalization is done in two steps:
1. Divide $\omega_{ij}$ by $\sqrt{d_k}$, so that the length of weight vectors is within a certain amount.
2. Apply softmax to the above.

Visual representation:

![attention scores](img/attention-scores.png)

The code for the computation described above:

In [173]:
import torch.nn.functional as F

attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
print(attention_weights_2)

tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458],
       grad_fn=<SoftmaxBackward0>)


### 5. Context vector

And now the value vectors $v^{(i)}$ finally come into play. We will use them to compute a context vector $z^{(2)}$ for current input $x^{(2)}$ and current query $q^{(2)}$:
$$
z^{(2)} = \sum\limits_{j=1}^T \alpha_{2,j} v^{(j)}
$$

![context_vector](img/context-vector.png)

The code:

In [174]:
context_vector_2 = attention_weights_2.matmul(values)

print(context_vector_2.shape)
print(context_vector_2)

torch.Size([28])
tensor([-1.5993,  0.0156,  1.2670,  0.0032, -0.6460, -1.1407, -0.4908, -1.4632,
         0.4747,  1.1926,  0.4506, -0.7110,  0.0602,  0.7125, -0.1628, -2.0184,
         0.3838, -2.1188, -0.8136, -1.5694,  0.7934, -0.2911, -1.3640, -0.2366,
        -0.9564, -0.5265,  0.0624,  1.7084], grad_fn=<SqueezeBackward4>)


To summarize, below are the computations and dimensions of different tensors involved in the process:

![summary](img/summary.png)

### 6. Multi-head attention

So far we've implemented single-head attention model. "Single-head" means that we utilized a single set of (query, key, value) matrices:

![single-head](img/single-head.png)

In multi-head attention, we use several sets of such matrices. This is similar to using multiple kernels (or filters) in convolutional neural networks:

![multi-head](img/multi-head.png)

Let's suppose we have $h=3$ attention heads. In code this will be:

In [175]:
h = 3
multihead_W_query = torch.nn.Parameter(torch.rand(h, d_q, d))
multihead_W_key = torch.nn.Parameter(torch.rand(h, d_k, d))
multihead_W_value = torch.nn.Parameter(torch.rand(h, d_v, d))

Let's again consider $x^{(2)}$. Each query element will now be $3 \times d_q$-dimensional:

In [176]:
multihead_query_2 = multihead_W_query.matmul(x_2)
print(multihead_query_2.shape)

torch.Size([3, 24])


Key and value sequences are obtained via:

In [177]:
multihead_key_2 = multihead_W_key.matmul(x_2)
multihead_value_2 = multihead_W_value.matmul(x_2)

We need *all* key and value elements in order to compute attention scores for the second input element. So we first need to expand the input sequence embeddings to size 3:

In [178]:
stacked_inputs = embedded_sentence.T.repeat(3, 1, 1)
print(stacked_inputs.shape)

torch.Size([3, 16, 6])


Now we compute *all* keys and values via PyTorch's batch matrix multiplication:

In [179]:
multihead_keys = torch.bmm(multihead_W_key, stacked_inputs)
multihead_values = torch.bmm(multihead_W_value, stacked_inputs)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 24, 6])
multihead_values.shape: torch.Size([3, 28, 6])


In order to make keys and values easier to interpret, we can swap their second and third dimensions, so that their shape is similar to the one of `embedded_sentence`:

In [180]:
multihead_keys = multihead_keys.permute(0, 2, 1)
multihead_values = multihead_values.permute(0, 2, 1)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 6, 24])
multihead_values.shape: torch.Size([3, 6, 28])


Now we can follow the steps outlined for single-head attention in order to compute:

1. Unscaled attention weights $\omega$
2. Scaled attention weights $\alpha$
3. $h$ $d_v$-dimensional context vectors $z_i^{(2)}, i = \overline{1,h}$.

Final question remains - how do we construct a single $d_v$-dimensional vector expected by the subsequent layers in the architecture? By utilizing an additional learned matrix $W^o$ with dimensions $hd_v \times d_v$, which get multiplied with stacked $z_i^{(2)}$ matrices via:
$$
z^{(2)} = \left(z_1^{(2)} \frown \dots \frown z_h^{(2)}\right)W^o
$$
Computation of weights/scores/final context vector for multi-head attention is left as an **exercise**.

In [181]:
import torch.nn.functional as F

multihead_queries = torch.bmm(multihead_W_query, stacked_inputs)
print("multihead_queries.shape:", multihead_queries.shape)

attention_scores = multihead_keys.bmm(multihead_queries)
print("attention_scores.shape:", attention_scores.shape)

d_k = multihead_queries.size(-1)
attention_weights = F.softmax(attention_scores / (d_k ** 0.5), dim=-1)
print("attention_weights.shape:", attention_weights.shape)

context_vectors = attention_weights.transpose(-2, -1).bmm(multihead_values).transpose(-2, -1)
print("context_vectors.shape:", context_vectors.shape)

W_output = torch.nn.Parameter(torch.rand(d_v * h, d_v))
print("W_output.shape:", W_output.shape)

final_context = W_output.T @ context_vectors.reshape(-1, 6)
print("final_context.shape:", final_context.shape)

print(final_context)


multihead_queries.shape: torch.Size([3, 24, 6])
attention_scores.shape: torch.Size([3, 6, 6])
attention_weights.shape: torch.Size([3, 6, 6])
context_vectors.shape: torch.Size([3, 28, 6])
W_output.shape: torch.Size([84, 28])
final_context.shape: torch.Size([28, 6])
tensor([[ 4.0896e-02, -3.4264e-04, -1.9216e+02, -4.2820e-02, -3.3769e-02,
          1.1699e+02],
        [ 3.0241e-02, -2.9074e-04, -2.3712e+02, -3.8028e-02, -3.4496e-02,
          1.3699e+02],
        [ 3.8418e-02,  1.5550e-05, -1.8847e+02, -3.7054e-03, -3.9817e-03,
          1.1592e+02],
        [ 1.7002e-02, -2.6866e-04, -2.1844e+02, -3.4931e-02, -3.5825e-02,
          1.3128e+02],
        [ 4.0694e-02, -2.4788e-04, -2.0304e+02, -3.2624e-02, -2.7242e-02,
          1.1894e+02],
        [ 3.8981e-02, -5.9892e-04, -2.3794e+02, -7.1452e-02, -6.3189e-02,
          1.3102e+02],
        [ 4.9119e-02, -5.1041e-05, -1.6964e+02, -1.1154e-02, -5.9654e-03,
          1.0055e+02],
        [ 3.6791e-02, -9.9137e-05, -1.9562e+02, -1.6645e