# Attention mechanism and multi-head attention

With the encoder-decoder framework in mind, we now build the components of a transformer one by one, starting with the attention mechanism (highlighted in the figure below).

We'll first build single-head attention, then extend it to multi-head attention and implement masking for the decoder inputs.

![transformer model with its attention mechanism selected](./photos/Screenshot%20from%202022-09-07%2017-36-39.png)

According to the [transformer paper](https://arxiv.org/pdf/1706.03762.pdf), the hidden states are first **linearly projected** into query, key, and value vectors, which are then combined by scaled dot-product attention:
$$\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
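The formula translates almost line by line into tensor operations. The sketch below is a minimal functional version (the function name `attention` and the tensor sizes are illustrative choices, not from the paper); it works on batched inputs of shape `[..., seq_len, dim]` and also returns the attention weights so we can inspect them.

```python
import torch

def attention(q, k, v):
    """Illustrative helper: softmax(q @ k^T / sqrt(d_k)) @ v for inputs shaped [..., seq_len, dim]."""
    d_k = q.size(-1)
    # [..., seq_q, seq_k]: one score per (query, key) pair, scaled by sqrt(d_k)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5
    # Softmax over the key axis, so each query's weights sum to 1
    weights = torch.softmax(scores, dim=-1)
    # Weighted sum of the value vectors: [..., seq_q, d_v]
    return weights @ v, weights

torch.manual_seed(0)
q = torch.randn(2, 4, 8)  # batch_size=2, 4 queries, d_k=8
k = torch.randn(2, 6, 8)  # 6 keys
v = torch.randn(2, 6, 8)  # 6 values, one per key
out, weights = attention(q, k, v)
print(out.shape)      # torch.Size([2, 4, 8])
print(weights.shape)  # torch.Size([2, 4, 6])
```

Note that the softmax is taken along the last dimension (the keys): for every query, the weights over all keys form a probability distribution.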

![An illustration of the attention mechanism](./photos/Screenshot%20from%202022-09-07%2018-12-14.png)

As the illustration above indicates, the query vectors $Q$ and the key vectors $K$ must have the same dimension, $d_q = d_k$, because their dot products form the attention scores. The value vectors can have a different dimension $d_v$: they enter only after the $SoftMax$, as the vectors being averaged by the attention weights, so the output simply has dimension $d_v$. In the paper, $d_q = d_k = d_v$ (64 in the base model).
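A quick shape check makes this constraint concrete. The sizes below (`seq_len=4`, `d_k=8`, `d_v=5`) are arbitrary and chosen only so that $d_v \neq d_k$; the sketch shows that a different $d_v$ works fine, while keys whose width differs from the queries cannot be multiplied at all.

```python
import torch

torch.manual_seed(0)
seq_len, d_k, d_v = 4, 8, 5  # illustrative sizes; d_v deliberately differs from d_k
q = torch.randn(seq_len, d_k)
k = torch.randn(seq_len, d_k)  # must share d_k with q for the dot products
v = torch.randn(seq_len, d_v)  # free to differ: only multiplied by the weights

weights = torch.softmax(q @ k.T / d_k ** 0.5, dim=-1)  # [seq_len, seq_len]
out = weights @ v                                      # [seq_len, d_v]
print(out.shape)  # torch.Size([4, 5])

# Keys of a different width cannot be dotted with the queries
k_bad = torch.randn(seq_len, d_k + 1)
try:
    q @ k_bad.T
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```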

The linear projection matrices of a single attention head, $W^Q$, $W^K$, and $W^V$, have dimensions $d_{model}\times d_k$, $d_{model}\times d_k$, and $d_{model}\times d_v$, respectively.
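In PyTorch these projections are usually `nn.Linear` layers. One detail worth knowing: `nn.Linear` stores its weight as `[out_features, in_features]`, so the paper's $W^Q$ of shape $d_{model}\times d_k$ corresponds to the *transpose* of the layer's `weight`. The sketch below uses the base-model sizes $d_{model}=512$, $d_k=d_v=64$ and sets `bias=False` only to mirror the paper's pure matrix notation (the implementation that follows keeps PyTorch's default bias).

```python
import torch
from torch import nn

d_model, d_k, d_v = 512, 64, 64
# bias=False here only to match the plain matrix products x W^Q, x W^K, x W^V
w_q = nn.Linear(d_model, d_k, bias=False)
w_k = nn.Linear(d_model, d_k, bias=False)
w_v = nn.Linear(d_model, d_v, bias=False)

# nn.Linear stores [out_features, in_features], i.e. (W^Q)^T
print(w_q.weight.shape)  # torch.Size([64, 512])

x = torch.randn(10, d_model)  # 10 tokens
q = w_q(x)                    # [10, d_k]
# The layer computes exactly x @ W^Q with W^Q = w_q.weight.T
print(torch.allclose(q, x @ w_q.weight.T, atol=1e-5))  # True
```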

Below is an implementation of single-head attention:

```python
import torch
from torch import nn
```

```python
class SingleHeadAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_k = d_k

        # Linear projections W^Q, W^K, W^V (here d_q = d_k = d_v)
        self.w_q = nn.Linear(d_model, d_k)
        self.w_k = nn.Linear(d_model, d_k)
        self.w_v = nn.Linear(d_model, d_k)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        # q, k, v: [batch_size, seq_len, d_k]
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        # scores, scaled_scores, softmax_scores: [batch_size, seq_len, seq_len]
        scores = torch.matmul(q, k.transpose(-2, -1))
        # Scale by sqrt(d_k), as in the formula above
        scaled_scores = scores / self.d_k ** 0.5
        # Softmax over the last (key) dimension: each query's weights over all keys sum to 1
        softmax_scores = nn.functional.softmax(scaled_scores, dim=-1)

        # attentions: [batch_size, seq_len, d_k]
        attentions = torch.matmul(softmax_scores, v)
        return attentions
```
Now let's write tests to make sure the dimensions and the computed attentions are correct, starting with the output shape.

```python
single_head_attention = SingleHeadAttention(5, 3)

# A dummy input of shape [batch_size, seq_len, d_model]
dummy_inputs = torch.randn((1, 4, 5))
dummy_outputs = single_head_attention(dummy_inputs)

# Expected output shape: [batch_size, seq_len, d_k]
assert dummy_outputs.shape == (1, 4, 3)
```

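```python
dummy_outputs
```

```
tensor([[[-0.7294,  0.2412,  0.0146],
         [-0.0643, -0.3173, -0.3077],
         [ 1.5143, -0.2022,  0.2589],
         [-0.9435, -0.0336, -0.3060]]], grad_fn=<UnsafeViewBackward0>)
```

Since the inputs are random and no seed is set, the exact values will differ from run to run.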
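The shape assertion only checks the dimensions. One possible way to also check the computed values is to compare the vectorized formula with an explicit loop that handles one query at a time and applies the softmax by hand. To keep the sketch self-contained, it uses three freshly created `nn.Linear` layers in place of the module's `w_q`, `w_k`, `w_v`; the same comparison could be run with `single_head_attention`'s own layers.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, d_k, seq_len = 5, 3, 4
# Stand-ins for the module's projection layers
w_q, w_k, w_v = (nn.Linear(d_model, d_k) for _ in range(3))
x = torch.randn(1, seq_len, d_model)

with torch.no_grad():
    q, k, v = w_q(x)[0], w_k(x)[0], w_v(x)[0]  # each [seq_len, d_k]

    # Vectorized version of softmax(Q K^T / sqrt(d_k)) V
    vec = torch.softmax(q @ k.T / d_k ** 0.5, dim=-1) @ v

    # Reference: one query at a time, softmax written out explicitly
    loop = torch.empty(seq_len, d_k)
    for i in range(seq_len):
        s = torch.stack([q[i] @ k[j] for j in range(seq_len)]) / d_k ** 0.5
        a = torch.exp(s) / torch.exp(s).sum()      # attention weights of query i
        loop[i] = (a[:, None] * v).sum(dim=0)       # weighted sum of value vectors

print(torch.allclose(vec, loop, atol=1e-5))  # True
```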