# Attention mechanism and multi-head attention

With the encoder-decoder framework in mind, we now build the components of a transformer one by one, starting with the attention mechanism (the highlighted area below).

We'll first build a single-head attention mechanism, then extend it to multi-head attention and implement masking for decoder inputs.

![transformer model with its attention mechanism selected](./photos/Screenshot%20from%202022-09-07%2017-36-39.png)

According to the [transformer paper](https://arxiv.org/pdf/1706.03762.pdf), the hidden states are first **linearly projected** into query, key, and value vectors, which are then combined by scaled dot-product attention:
$$Attention(Q,K,V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
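To make the formula concrete, here is a minimal pure-Python sketch (no PyTorch) of scaled dot-product attention on a toy example with one query and two keys. The helper names `softmax` and `scaled_dot_product_attention` are our own illustrative choices. Note that $V$ sits outside the softmax: the softmax turns the scaled scores into weights, and the output is the weighted sum of the value rows.

```python
import math

def softmax(xs):
    # Subtract the max for numerical stability; this does not change the result
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(Q, K, V):
    """Q: n_q x d_k, K: n_k x d_k, V: n_k x d_v, each given as a list of rows."""
    d_k = len(K[0])
    outputs, weights = [], []
    for q in Q:
        # Dot product of this query with every key, scaled by sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        w = softmax(scores)  # one weight per key, summing to 1
        weights.append(w)
        # Output = weighted sum of the value rows (V is applied after the softmax)
        outputs.append([sum(wj * v[c] for wj, v in zip(w, V)) for c in range(len(V[0]))])
    return outputs, weights

Q = [[1.0, 0.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
out, w = scaled_dot_product_attention(Q, K, V)
print(w[0], out[0])
```

The query matches the first key, so the first value row receives roughly 0.67 of the weight and the second roughly 0.33. Since the two value columns differ by exactly 1 in every row, the two output coordinates also differ by exactly 1, which is a quick sanity check that the output is a weighted average of the value rows.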

![An illustration of the attention mechanism](./photos/Screenshot%20from%202022-09-07%2018-12-14.png)

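As indicated in the illustration above, the query vector $Q$ and the key vector $K$ have the same dimension, that is, $d_q = d_k$, because every query is compared with every key through a dot product. The value vector may have a different dimension $d_v$ without any extra layer: for a sequence of $n$ tokens the softmax output is an $n \times n$ matrix regardless of $d_v$, so multiplying it by $V$ simply yields outputs of width $d_v$. In the paper, $d_k = d_q = d_v$ (64 for the base model).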
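A quick PyTorch sketch of the shape argument, with toy sizes chosen only for illustration: $d_v$ is deliberately different from $d_k$, and the attention still works because $d_v$ never enters the score computation.

```python
import torch

torch.manual_seed(0)
seq_len, d_k, d_v = 4, 3, 7  # d_v deliberately differs from d_k

q = torch.randn(seq_len, d_k)
k = torch.randn(seq_len, d_k)
v = torch.randn(seq_len, d_v)

# [seq_len, d_k] @ [d_k, seq_len] -> [seq_len, seq_len]; d_v plays no role here
weights = torch.softmax(q @ k.T / d_k ** 0.5, dim=-1)
# [seq_len, seq_len] @ [seq_len, d_v] -> [seq_len, d_v]
out = weights @ v
print(weights.shape, out.shape)  # torch.Size([4, 4]) torch.Size([4, 7])
```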

The linear projection matrices for a single-head attention mechanism, $W^Q$, $W^K$, $W^V$, have dimensions $d_{model}\times d_k$, $d_{model}\times d_k$, $d_{model}\times d_v$, respectively (recall that $d_q = d_k$).
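One detail worth knowing when these matrices are implemented with `nn.Linear`: PyTorch stores the weight as `[out_features, in_features]`, i.e. the transpose of the $d_{model}\times d_k$ matrix written in the paper. The sketch below uses the paper's base-model sizes ($d_{model}=512$, $d_k=d_v=64$) and `bias=False` to match the paper's description of plain projection matrices; the notebook's layers keep PyTorch's default `bias=True`, which adds $d_k$ extra parameters per projection.

```python
from torch import nn

d_model, d_k, d_v = 512, 64, 64  # base-model sizes from the paper

w_q = nn.Linear(d_model, d_k, bias=False)
w_k = nn.Linear(d_model, d_k, bias=False)
w_v = nn.Linear(d_model, d_v, bias=False)

# Weight is stored as [out_features, in_features], the transpose of W^Q
print(w_q.weight.shape)  # torch.Size([64, 512])

# 3 projections x 512 x 64 weights each
n_params = sum(p.numel() for m in (w_q, w_k, w_v) for p in m.parameters())
print(n_params)  # 98304
```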

Below is an implementation of the single-head attention:

In [1]:
import torch
from torch import nn

In [2]:
class SingleHeadAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_k = d_k

        self.w_q = nn.Linear(d_model, d_k)
        self.w_k = nn.Linear(d_model, d_k)
        self.w_v = nn.Linear(d_model, d_k)
    
    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        # q,k,v: [batch_size, seq_len, d_k]
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        # scores, normalized_scores, softmax_scores: [batch_size, seq_len, seq_len]
        scores = torch.matmul(q, torch.permute(k, (0, 2, 1)))
        # Scale by sqrt(d_k), as in the formula above (not by d_k)
        normalized_scores = scores / self.d_k ** 0.5
        # Softmax over the last dim: each query's weights over all keys sum to 1
        softmax_scores = nn.functional.softmax(normalized_scores, dim=-1)

        # attentions: [batch_size, seq_len, d_k]
        attentions = torch.matmul(softmax_scores, v)
        return {
            'k': k,
            'q': q,
            'v': v,
            'softmax_scores': softmax_scores,
            'attentions': attentions
        }
        
        

# Now let's write tests to make sure the dimensions and the computed attentions are correct

In [3]:
single_head_attention = SingleHeadAttention(5,3)

In [4]:
dummy_inputs = torch.randn((7,4,5)) # A dummy input of shape [batch_size, seq_len, d_model]

In [5]:
dummy_outputs = single_head_attention(dummy_inputs)

In [6]:
assert dummy_outputs['attentions'].shape == (7,4,3) 
# Make sure that the shape of the attentions is [batch_size, seq_len, d_v]
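Shape checks alone will not catch a wrong scaling factor (for example dividing by $d_k$ instead of $\sqrt{d_k}$). A stronger check, sketched below under the assumption that PyTorch 2.0 or newer is installed, compares a manual computation against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, which scales by $1/\sqrt{d_k}$ by default. The comparison is skipped on older versions that lack this function.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, d_k = 7, 4, 3
q = torch.randn(batch_size, seq_len, d_k)
k = torch.randn(batch_size, seq_len, d_k)
v = torch.randn(batch_size, seq_len, d_k)

# Manual softmax(QK^T / sqrt(d_k)) V
weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
manual = weights @ v

# Each row of the attention weights is a probability distribution over the keys
assert torch.allclose(weights.sum(dim=-1), torch.ones(batch_size, seq_len))

# PyTorch >= 2.0 provides a reference implementation to compare against
if hasattr(F, "scaled_dot_product_attention"):
    reference = F.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(manual, reference, atol=1e-5)
```

The same comparison can be run on `dummy_outputs['q']`, `dummy_outputs['k']`, `dummy_outputs['v']` and `dummy_outputs['attentions']` from the cell above; it only passes when the scores are divided by $\sqrt{d_k}$.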

In [7]:
print(dummy_outputs['v'])
print(dummy_outputs['attentions'])

tensor([[[-1.2249,  0.0661,  0.6594],
         [-0.5190, -0.0567,  0.1383],
         [-0.6636, -0.3275,  1.1567],
         [-0.9605, -0.0276,  1.5279]],

        [[ 0.1181, -0.2155,  0.0951],
         [-0.9050, -0.1242,  0.6470],
         [ 0.2617,  0.0027, -0.0351],
         [-0.3784,  0.4812, -0.6979]],

        [[ 1.4064, -0.0663,  0.2217],
         [-0.8405, -1.2527,  0.9365],
         [-0.7726, -0.0266,  0.2264],
         [-0.6351,  0.4558,  0.3582]],

        [[ 0.3821,  0.6742, -0.3431],
         [-0.4670,  0.7657, -0.3341],
         [-0.4310,  0.0676,  0.6857],
         [-0.3748,  1.1992, -0.9629]],

        [[ 0.6693, -0.5286,  0.5813],
         [-1.4740, -0.1369,  0.1863],
         [-0.5995,  0.0515, -0.5938],
         [-2.0923, -0.5931,  1.2989]],

        [[-0.8034,  0.2576,  0.0963],
         [-0.8780, -0.1405, -0.2514],
         [ 0.3408,  0.7539, -1.0165],
         [-0.5890, -0.3180,  0.8081]],

        [[ 0.1114, -0.1259, -0.3868],
         [ 0.8805,  0.5552, -1.2451],


# Multi-headed attentions
With this simple single-head attention module, we can now implement a multi-head version on top of it. Multi-head attention widens the q, k, v projections to `num_head * d_k` features, splits them into `num_head` heads that compute attention independently, then concatenates the heads' outputs and passes them through an additional linear layer (as shown in the illustration below).

![single-headed and multi-headed attention mechanism](./photos/Screenshot%20from%202022-09-07%2021-19-42.png)
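The trickiest part of the multi-head version is reshaping the projected tensors into heads and back. Here is a small sketch with toy sizes (chosen only for illustration) of the common `view` + `transpose` idiom, checking that it produces the same tensor as the `permute`/`reshape`/`permute` chain used in the `MultiHeadAttention` class below, and that merging the heads recovers the original layout.

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, num_head, d_k = 2, 5, 3, 4
x = torch.randn(batch_size, seq_len, num_head * d_k)  # e.g. the output of w_q

# Split: [B, S, H*d_k] -> [B, S, H, d_k] -> [B, H, S, d_k]
heads = x.view(batch_size, seq_len, num_head, d_k).transpose(1, 2)

# Same result as the permute/reshape/permute chain in the class below
alt = x.permute(0, 2, 1).reshape(batch_size, num_head, d_k, seq_len).permute(0, 1, 3, 2)
assert torch.equal(heads, alt)

# Merge (the "concat" step): [B, H, S, d_k] -> [B, S, H, d_k] -> [B, S, H*d_k]
merged = heads.transpose(1, 2).reshape(batch_size, seq_len, num_head * d_k)
assert torch.equal(merged, x)
```

Both variants put feature `h * d_k + d` of the projection into head `h`, position `d`, which is why the split and the final concatenation line up.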

In [8]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_head, d_k):
        super().__init__()
        self.d_k = d_k
        self.num_head = num_head
        self.d_model = d_model

        # The outputs of these linear layers are split into num_head
        # parts so that each head can perform its own attention
        # calculation
        self.w_q = nn.Linear(d_model, d_k*num_head)
        self.w_k = nn.Linear(d_model, d_k*num_head)
        self.w_v = nn.Linear(d_model, d_k*num_head)
        self.final_linear = nn.Linear(d_k*num_head, d_model)
    
    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        batch_size, seq_len, _ = x.shape
        

        # q,k,v: [batch_size, seq_len, num_head*d_k]
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        # q: [batch_size, seq_len, num_head*d_k] -> [batch_size, num_head, seq_len, d_k]
        q = torch.permute(q, (0, 2, 1)).reshape(batch_size, self.num_head, self.d_k, 
            seq_len).permute(0, 1, 3, 2)
        assert q.shape == (batch_size, self.num_head, seq_len, self.d_k)

        # k: [batch_size, seq_len, num_head*d_k] -> [batch_size, num_head, d_k, seq_len]
        k = torch.permute(k, (0, 2, 1)).reshape(batch_size, self.num_head, self.d_k, 
            seq_len)
        assert k.shape == (batch_size, self.num_head, self.d_k, seq_len)
        
        # scores, normalized_scores: [batch_size, num_head, seq_len, seq_len]
        scores = torch.matmul(q, k)
        assert scores.shape == (batch_size, self.num_head, seq_len, seq_len)
    
        # Scale by sqrt(d_k), as in the formula above (not by d_k)
        normalized_scores = scores / self.d_k ** 0.5
        # Softmax over the last dim: each query's weights over all keys sum to 1
        softmax_scores = nn.functional.softmax(normalized_scores, dim=-1)

        # v: [batch_size, seq_len, num_head*d_k] -> [batch_size, num_head, seq_len, d_k]
        v = torch.permute(v, (0, 2, 1)).reshape(batch_size, self.num_head, self.d_k, 
            seq_len).permute(0, 1, 3, 2)

        assert v.shape == (batch_size, self.num_head, seq_len, self.d_k)

        # attentions: [batch_size, num_head, seq_len, d_k]
        # Weight the values with the softmax-normalized scores
        attentions = torch.matmul(softmax_scores, v)

        # Concat attentions from all the attention heads
        # attentions: [batch_size, num_head, seq_len, d_k] ->
        # [batch_size, seq_len, num_head*d_k]
        attentions = attentions.permute(0,2,1,3).reshape((batch_size, seq_len, self.num_head*self.d_k))
        assert attentions.shape == (batch_size, seq_len, self.num_head*self.d_k)

        # Final linear layer: 
        #[batch_size, seq_len, num_head*d_k] -> [batch_size, seq_len, d_model]
        outputs = self.final_linear(attentions)
        assert outputs.shape == (batch_size, seq_len, self.d_model)

        return {
            'k': k,
            'q': q,
            'v': v,
            'softmax_scores': softmax_scores,
            'attentions': attentions,
            'outputs': outputs
        }
        
            

In [9]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is present
multihead_attention = MultiHeadAttention(2, 3, 5).to(device) #[d_model, num_head, d_k]

In [10]:
dummy_inputs = torch.randn((7, 11, 2)).to(device) #[batch_size, seq_len, d_model]

In [11]:
# [batch_size, seq_len, d_model] == (7, 11, 2)
multihead_attention(dummy_inputs)['outputs'].shape

torch.Size([7, 11, 2])
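For comparison, PyTorch ships its own `nn.MultiheadAttention`. The sketch below runs it as self-attention with the paper's base-model sizes (our choice for illustration). One design difference from the class above: PyTorch requires `embed_dim` to be divisible by `num_heads` and fixes each head's size to `embed_dim / num_heads`, whereas this notebook's `d_k` is a free parameter (the toy configuration `d_model=2, num_head=3, d_k=5` would not be expressible there).

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, seq_len, d_model, num_head = 7, 11, 512, 8

# Each head works with d_k = 512 / 8 = 64 features
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_head, batch_first=True)

x = torch.randn(batch_size, seq_len, d_model)
# Self-attention: query, key, and value are the same tensor
outputs, weights = mha(x, x, x)
print(outputs.shape)  # torch.Size([7, 11, 512])
print(weights.shape)  # torch.Size([7, 11, 11]), averaged over the heads by default
```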

In [14]:
multihead_attention(dummy_inputs)['outputs']

tensor([[[-0.8120,  0.3692],
         [-0.7812,  0.5577],
         [-0.4525, -0.2322],
         [-0.8219,  0.9419],
         [-0.6648, -0.2185],
         [-0.1780, -0.7772],
         [-0.6171,  0.1114],
         [-0.5702, -0.4131],
         [-0.9026,  0.9138],
         [-1.0254,  1.0927],
         [ 0.0157, -1.1121]],

        [[-0.3096, -0.3299],
         [-1.9933,  1.7937],
         [ 0.0325, -0.6157],
         [ 0.4420, -1.3144],
         [ 0.0763, -0.8618],
         [-0.7081,  0.4242],
         [-2.1309,  2.0334],
         [-1.6788,  1.1860],
         [-0.8109,  0.2756],
         [-1.1934,  0.8401],
         [-1.4569,  1.0500]],

        [[-2.2969,  1.9327],
         [-1.1403,  0.7162],
         [-0.6666,  0.1410],
         [ 1.0163, -1.6569],
         [-0.1207, -0.4114],
         [-2.3476,  2.0281],
         [-0.7486,  0.3493],
         [-0.7244,  0.1886],
         [-1.6772,  1.2432],
         [-1.6050,  1.1225],
         [ 0.3330, -0.9215]],

        [[ 1.2644, -1.5887],
        

In [16]:
multihead_attention.w_q.weight.grad
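The cell above shows no output because `.grad` is `None` until a backward pass has run, and so far the model has only been called forward. A minimal sketch with a single `nn.Linear` (sizes chosen to mirror the toy example) shows the difference; the loss is reduced to a scalar with `.mean()` so that `backward()` needs no gradient argument.

```python
import torch
from torch import nn

layer = nn.Linear(2, 5)
grad_before = layer.weight.grad
print(grad_before)  # None: no backward pass has run yet

x = torch.randn(7, 11, 2)
loss = layer(x).mean()  # scalar loss
loss.backward()
print(layer.weight.grad.shape)  # torch.Size([5, 2]), same shape as the weight
```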

# Summary
So there we have it: a multi-head attention mechanism built from scratch! However, as you might have noticed, the queries, keys, and values are all computed from the same input `x`, so it only lets the tokens fed into the encoder attend to each other, not decoder tokens to encoder tokens. This is therefore a **self-attention mechanism**. In the next section, we'll build the other parts of the encoder before moving on to the decoder.
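The masking mentioned at the start is not part of the code above. As a hedged sketch of one common way to add it for decoder inputs, a causal (look-ahead) mask sets the scores of future positions to $-\infty$ before the softmax, so each position can only attend to itself and earlier positions. The function name `causal_attention` is hypothetical, and the toy sizes are only for illustration.

```python
import math
import torch

def causal_attention(q, k, v):
    """Hypothetical helper. q, k: [..., seq_len, d_k], v: [..., seq_len, d_v]."""
    seq_len, d_k = q.shape[-2], q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    # True strictly above the diagonal marks "future" keys that must be hidden
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    # exp(-inf) = 0, so future keys receive exactly zero weight
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights

torch.manual_seed(0)
q = k = v = torch.randn(2, 4, 3)  # [batch_size, seq_len, d_k]
out, weights = causal_attention(q, k, v)
print(weights[0])  # lower-triangular; the first row is [1, 0, 0, 0]
```

Every row keeps at least its diagonal entry, so no row is fully masked and the softmax never divides by zero.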

## Unit test for `MultiHeadAttention`

In [38]:
from vit.multiheaded_attentions import MultiHeadAttention
import torch
from torch import nn

def multihead_attention_unit_test():
    multihead_attention = MultiHeadAttention(512, 8, 64).to('cpu') #[d_model, num_head, d_k]
    dummy_inputs = torch.randn((2, 128, 512)).to('cpu') #[batch_size, seq_len, d_model]
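    y = multihead_attention(dummy_inputs)
    # The MultiHeadAttention defined above returns a dict; take the final projected outputs
    if isinstance(y, dict):
        y = y['outputs']
    loss = y.mean()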
    loss.backward()
    for name, param in multihead_attention.named_parameters():
        assert param.grad is not None

In [None]:
multihead_attention_unit_test()