## Multi-Head attention

<img src="../assets/multi-head-attention.png" width="500" height="500" alt="Multi-head attention diagram" >

- In the Transformer, the Attention module runs several attention computations in parallel.
- Each of these parallel computations is called an `Attention Head`.
- The **Attention module splits its Query, Key, and Value projections N ways and passes each split independently through a separate Head**.
- *`The outputs of all heads are then concatenated and passed through a final linear layer to produce the module's output`*.
- This is called **`Multi-head attention`**. Because each head works on its own slice of the projections, different heads can learn to capture different relationships and nuances between the words of a sentence.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
```

```python
sequence_length = 4
batch_size = 1
input_dim = 512
d_model = 512
x = torch.randn((batch_size, sequence_length, input_dim))

x.shape
```

```
torch.Size([1, 4, 512])
```

- Think of `x` as the embedding of a single 4-word sentence (batch size 1), where each word is represented by a vector of length 512.
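As a minimal sketch of where such an `x` could come from, an `nn.Embedding` lookup turns token ids into 512-dimensional vectors. The vocabulary size and the token ids below are made up for illustration, and a real Transformer would also add positional encodings, which are left out here.

```python
import torch
import torch.nn as nn

vocab_size = 10_000  # hypothetical vocabulary size
d_model = 512

# hypothetical token ids for one 4-word sentence: shape [batch_size=1, sequence_length=4]
token_ids = torch.tensor([[17, 254, 9, 3021]])

embedding = nn.Embedding(vocab_size, d_model)  # learnable lookup table [vocab_size, d_model]
x = embedding(token_ids)                       # one 512-dim vector per token

print(x.shape)  # torch.Size([1, 4, 512])
```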

---

## Query, Key, Value

- Now, we will pass it through a linear layer to get the `Query`, `Key`, and `Value` representations of the sentence.

```python
qkv_layer = nn.Linear(input_dim, 3 * d_model)  # 3 * d_model because one layer produces q, k, and v together
qkv = qkv_layer(x)

qkv.shape
```

```
torch.Size([1, 4, 1536])
```
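A single `nn.Linear(input_dim, 3 * d_model)` computes the same thing as three separate 512-to-512 projections stacked along the output axis. A minimal sketch that checks this by slicing the packed weight (random input, arbitrary seed):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 512
x = torch.randn(1, 4, d_model)

# one packed projection whose output is [block 0 | block 1 | block 2] along the last axis
qkv_layer = nn.Linear(d_model, 3 * d_model)
qkv = qkv_layer(x)  # [1, 4, 1536]

# nn.Linear stores weight as [out_features, in_features], so its rows split into three blocks
w_q, w_k, w_v = qkv_layer.weight.chunk(3, dim=0)  # each [512, 512]
b_q, b_k, b_v = qkv_layer.bias.chunk(3, dim=0)    # each [512]

# applying each block on its own reproduces the matching slice of the packed output
q_sep = F.linear(x, w_q, b_q)
k_sep = F.linear(x, w_k, b_k)
v_sep = F.linear(x, w_v, b_v)

packed = torch.cat([q_sep, k_sep, v_sep], dim=-1)
print(torch.allclose(packed, qkv, atol=1e-5))  # should print True
```

Packing the three projections into one layer simply means one larger matrix multiply instead of three small ones.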

---

## Split it for Multi-Head Attention

- Now, we will split the `Query`, `Key`, and `Value` representations into `N` heads.
- Each head will have a dimension of `512/N` (64 for `N = 8`).
- We will then calculate the `Attention` for each head.
- Finally, we will concatenate the outputs of all heads to get the final `Attention` output.

```python
num_heads = 8
head_dim = d_model // num_heads  # 512 // 8 = 64
qkv = qkv.reshape(batch_size, sequence_length, num_heads, 3 * head_dim)  # 3 * 64 = 192 per head

qkv.shape
```

```
torch.Size([1, 4, 8, 192])
```
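With this reshape, head `h` receives the contiguous block of columns `h*192` to `h*192 + 191` of the packed 1536-wide output, and the later `chunk` splits that block into its query, key, and value parts. So the packed output is not read as `[all Q | all K | all V]`; since the weights are learned, this layout works as long as it is used consistently. A small index-tracking sketch (the choice of head 2 is arbitrary):

```python
import torch

num_heads, head_dim = 8, 64
d_model = num_heads * head_dim  # 512

# label each of the 3 * d_model = 1536 output columns with its own index
cols = torch.arange(3 * d_model).reshape(1, 1, 3 * d_model)  # [batch=1, seq=1, 1536]

# same reshape as above: every head gets a contiguous block of 3 * head_dim = 192 columns
per_head = cols.reshape(1, 1, num_heads, 3 * head_dim)

h = 2                     # inspect one head
block = per_head[0, 0, h]  # the 192 column indices owned by head 2
q_cols, k_cols, v_cols = block.chunk(3)

print(q_cols[0].item(), q_cols[-1].item())  # 384 447
print(k_cols[0].item(), k_cols[-1].item())  # 448 511
print(v_cols[0].item(), v_cols[-1].item())  # 512 575
```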

---

## Swap axes (after the batch axis: heads first, then the words, then the dimension of each word)

- With the heads on axis 1, `torch.matmul` treats the batch and head axes as batch dimensions, so the attention for all heads is computed in a single call.

```python
qkv = qkv.permute(0, 2, 1, 3)  # [batch_size, num_heads, sequence_length, 3*head_dim]
qkv.shape
```

```
torch.Size([1, 8, 4, 192])
```
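`permute` does not move any data: it returns a view of the same memory with reordered strides, which is why the result is non-contiguous. `view` needs compatible strides and can fail on such a tensor, while `reshape` falls back to a copy when needed. A short sketch on a random tensor of the same shape:

```python
import torch

qkv = torch.randn(1, 4, 8, 192)  # [batch, seq, heads, 3*head_dim]

swapped = qkv.permute(0, 2, 1, 3)  # [batch, heads, seq, 3*head_dim]

# swapping only axes 1 and 2 can equivalently be written with transpose
print(torch.equal(swapped, qkv.transpose(1, 2)))  # True

# permute returns a view over the same storage, not a copy
print(swapped.is_contiguous())               # False
print(swapped.data_ptr() == qkv.data_ptr())  # True
```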

---

## Split QKV into Q, K, and V for each head

```python
q, k, v = qkv.chunk(3, dim=-1)  # split the 192-wide last axis into three 64-wide pieces
q.shape, k.shape, v.shape
```

```
(torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]),
 torch.Size([1, 8, 4, 64]))
```

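> So, for 8 heads and 4 words, each of `q`, `k`, and `v` has shape `[1, 8, 4, 64]`: the 192-wide last axis of each head was split into three 64-dimensional vectors.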

---

## Self Attention for multiple heads

- We will perform the same `Self-Attention` for each head.

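$$ \text{attentionScore} = \operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}} + M \right) $$
$$ \text{contextualEmbedding} = \text{attentionScore} \cdot V $$

Where,

- $d_k$ = dimension of the keys (and queries) in each head, here `head_dim` = 64
- $Q$ = Query
- $K$ = Key
- $V$ = Value
- $M$ = Mask: $0$ where attention is allowed and $-\infty$ where it is blocked (useful for the decoder, which must not attend to future words)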
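To make the formula concrete, here is a hand-checkable sketch with two tokens and $d_k = 2$. The values of $Q$, $K$, and $V$ are made up so the arithmetic is easy to follow, and $M$ is the causal mask that stops token 0 from looking at token 1.

```python
import math

import torch
import torch.nn.functional as F

# made-up 2-token example with d_k = 2
Q = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
K = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
V = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
d_k = Q.size(-1)

# causal mask: token 0 may not attend to token 1
M = torch.tensor([[0.0, float("-inf")], [0.0, 0.0]])

scores = Q @ K.T / math.sqrt(d_k) + M  # [[0.7071, -inf], [0.0, 0.7071]]
attention = F.softmax(scores, dim=-1)  # row 0 -> [1, 0]; row 1 -> about [0.330, 0.670]
out = attention @ V                    # row 0 -> [1, 2]; row 1 -> about [2.340, 3.340]

print(attention)
print(out)
```

Row 0 can only see itself, so its contextual embedding is exactly $V_0$; row 1 leans toward $V_1$ because its query matches key 1 more strongly.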

```python
import math

d_k = q.size()[-1]  # dimension of k (key), 64 here

# for tensors with more than 2 dimensions, tensor.T would reverse *all* axes (and is deprecated),
# so we transpose only the last two axes: [..., seq, d_k] -> [..., d_k, seq]
scaled = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
scaled.shape
```

```
torch.Size([1, 8, 4, 4])
```
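Why divide by $\sqrt{d_k}$: if the entries of a query and a key are independent with mean 0 and variance 1, their dot product is a sum of $d_k$ such products and has variance $d_k$. Large scores push the softmax toward a one-hot output with very small gradients, and dividing by $\sqrt{d_k}$ brings the variance back to about 1. An empirical check under those idealised assumptions:

```python
import torch

torch.manual_seed(0)
d_k = 64
n = 10_000  # number of random query/key pairs

# entries with mean 0 and variance 1, like the random tensors above
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

dots = (q * k).sum(dim=-1)  # one raw dot product per pair
scaled = dots / d_k ** 0.5  # the same scaling as the cell above

print(dots.var().item())    # close to d_k = 64
print(scaled.var().item())  # close to 1
```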

---

## Masking (prevents each word from attending to future words; used in the decoder)

```python
# -inf above the diagonal blocks future positions; 0 elsewhere leaves the scores unchanged
mask = torch.full(scaled.size(), float("-inf"))
mask = torch.triu(mask, diagonal=1)

mask.shape, mask[0][1]  # mask[0][1]: the mask applied to a single head
```

```
(torch.Size([1, 8, 4, 4]),
 tensor([[0., -inf, -inf, -inf],
         [0., 0., -inf, -inf],
         [0., 0., 0., -inf],
         [0., 0., 0., 0.]]))
```
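The mask does not have to be built for the full `[batch, heads, seq, seq]` shape: a single `[seq, seq]` mask broadcasts over the batch and head axes. A sketch of an alternative using a boolean mask with `masked_fill`, checked against the additive mask on random scores:

```python
import torch

torch.manual_seed(0)
batch_size, num_heads, seq_len = 1, 8, 4
scaled = torch.randn(batch_size, num_heads, seq_len, seq_len)

# additive mask built for the full shape, as in the cell above
additive = torch.triu(torch.full(scaled.size(), float("-inf")), diagonal=1)

# one [seq_len, seq_len] boolean mask: True marks a future position to block
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# masked_fill broadcasts the 2-D mask over the batch and head axes
filled = scaled.masked_fill(causal, float("-inf"))

print(torch.equal(filled, scaled + additive))  # True
```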

```python
(scaled + mask)[0][0]
```

```
tensor([[-0.3185,    -inf,    -inf,    -inf],
        [-0.1710, -0.2914,    -inf,    -inf],
        [ 0.4216,  0.3508,  0.4670,    -inf],
        [-0.2973,  0.3598, -0.0526, -0.1818]], grad_fn=<SelectBackward0>)
```

```python
scaled += mask
```

```python
attention = F.softmax(scaled, dim=-1)  # normalise each row (one query word) over the keys

attention.shape
```

```
torch.Size([1, 8, 4, 4])
```

```python
attention[0][0]
```

```
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5301, 0.4699, 0.0000, 0.0000],
        [0.3358, 0.3128, 0.3514, 0.0000],
        [0.1877, 0.3620, 0.2397, 0.2106]], grad_fn=<SelectBackward0>)
```
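Two properties are worth checking on any causal attention matrix like the one printed above: each row sums to 1 (the softmax normalises over the last axis, `dim=-1`), and the entries for future positions are exactly 0 because $e^{-\infty} = 0$. A sketch on random scores:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 4
scores = torch.randn(1, 8, seq_len, seq_len)
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

attention = F.softmax(scores + mask, dim=-1)

# every row is a probability distribution over the positions it may attend to
print(torch.allclose(attention.sum(dim=-1), torch.ones(1, 8, seq_len)))  # True

# entries above the diagonal (future positions) are exactly zero
upper = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
print(bool((attention.masked_select(upper) == 0).all()))  # True
```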

```python
values = torch.matmul(attention, v)  # [1, 8, 4, 4] @ [1, 8, 4, 64] -> [1, 8, 4, 64]
values.shape
```

```
torch.Size([1, 8, 4, 64])
```
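Each output row is a weighted average of value rows, with the weights taken from one row of the attention matrix. With the causal mask, the first word can only attend to itself, so its output is exactly its own value vector. A sketch that checks both facts on random `q`, `k`, `v` of the same shapes as above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_heads, seq_len, head_dim = 1, 8, 4, 64
q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)
v = torch.randn(batch_size, num_heads, seq_len, head_dim)

mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
scores = q @ k.transpose(-2, -1) / head_dim ** 0.5 + mask
attention = F.softmax(scores, dim=-1)
values = attention @ v  # [1, 8, 4, 64]

# position 0 attends only to itself, so its output equals v at position 0
print(torch.allclose(values[:, :, 0], v[:, :, 0]))  # True

# the last position mixes all four value rows using its attention weights
manual = sum(attention[:, :, 3, j, None] * v[:, :, j] for j in range(seq_len))
print(torch.allclose(values[:, :, 3], manual, atol=1e-6))  # True
```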

---

## Functional way of implementing multi-head attention

```python
import math


def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    # attention scores: [..., seq, seq]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled += mask  # -inf entries become 0 after the softmax
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention
```

```python
values, attention = scaled_dot_product(q, k, v, mask=mask)

attention.shape, values.shape, attention[0][0]
```

```
(torch.Size([1, 8, 4, 4]),
 torch.Size([1, 8, 4, 64]),
 tensor([[1.0000, 0.0000, 0.0000, 0.0000],
         [0.5301, 0.4699, 0.0000, 0.0000],
         [0.3358, 0.3128, 0.3514, 0.0000],
         [0.1877, 0.3620, 0.2397, 0.2106]], grad_fn=<SelectBackward0>))
```
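As a sanity check, assuming PyTorch 2.0 or newer, the helper can be compared with the built-in `torch.nn.functional.scaled_dot_product_attention`, whose `is_causal=True` option applies the same "no looking ahead" mask internally and whose default scale is $1/\sqrt{d_k}$:

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product(q, k, v, mask=None):
    # same helper as in the cell above
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled += mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention


torch.manual_seed(0)
q, k, v = (torch.randn(1, 8, 4, 64) for _ in range(3))
mask = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)  # broadcasts over batch and heads

ours, _ = scaled_dot_product(q, k, v, mask=mask)
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # PyTorch >= 2.0

print(torch.allclose(ours, builtin, atol=1e-5))  # should print True
```

The built-in version can dispatch to fused kernels, so it is usually the better choice in real models; the hand-written helper is mainly useful for seeing each step.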

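```python
# [batch, heads, seq, head_dim] -> [batch, seq, heads, head_dim] -> [batch, seq, heads * head_dim]
# the permute is needed so that each word's vector is the concatenation of its 8 head outputs
values = values.permute(0, 2, 1, 3).reshape(batch_size, sequence_length, num_heads * head_dim)
values.size()
```

```
torch.Size([1, 4, 512])
```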
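Merging the heads means that, for each word, the 8 per-head 64-dim vectors are placed side by side. A sketch that checks a permute-then-reshape merge against an explicit `torch.cat`, and shows that reshaping straight from `[batch, heads, seq, head_dim]` gives the right shape but the wrong layout:

```python
import torch

torch.manual_seed(0)
batch_size, num_heads, seq_len, head_dim = 1, 8, 4, 64
values = torch.randn(batch_size, num_heads, seq_len, head_dim)  # per-head outputs

# reference: for each word, concatenate the 8 per-head vectors side by side
reference = torch.cat([values[:, h] for h in range(num_heads)], dim=-1)  # [1, 4, 512]

# move the head axis next to head_dim first, then flatten those two axes
merged = values.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_heads * head_dim)
print(torch.equal(merged, reference))  # True

# reshaping without the permute mixes heads and word positions
naive = values.reshape(batch_size, seq_len, num_heads * head_dim)
print(torch.equal(naive, reference))  # False
```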

```python
linear_layer = nn.Linear(d_model, d_model)  # final output projection that mixes the heads

out = linear_layer(values)

out.shape
```

```
torch.Size([1, 4, 512])
```

```python
# each output vector is a learned projection of a weighted mix of the value vectors of the words
# it attended to, so it carries context from the rest of the sentence: the essence of self-attention
out
```

```
tensor([[[ 0.6225,  0.5013,  0.2466,  ...,  0.1342,  0.1591,  0.0558],
         [ 0.0986, -0.0177,  0.3862,  ..., -0.4152,  0.2971, -0.3395],
         [ 0.0876, -0.1153,  0.2119,  ...,  0.0884, -0.1442, -0.0105],
         [ 0.2140, -0.0680,  0.2028,  ..., -0.0101,  0.0936, -0.0240]]],
       grad_fn=<ViewBackward0>)
```

---

## Class way of implementing multi-head attention

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    scaled = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        scaled += mask
    attention = F.softmax(scaled, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention


class MultiHeadAttention(nn.Module):

    def __init__(self, input_dim, d_model, num_heads):
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv_layer = nn.Linear(input_dim, 3 * d_model)  # packed q, k, v projection
        self.linear_layer = nn.Linear(d_model, d_model)  # output projection after merging heads

    def forward(self, x, mask=None):
        batch_size, sequence_length, input_dim = x.size()
        print(f"x.size(): {x.size()}")
        qkv = self.qkv_layer(x)
        print(f"qkv.size(): {qkv.size()}")
        qkv = qkv.reshape(
            batch_size, sequence_length, self.num_heads, 3 * self.head_dim
        )
        print(f"qkv.size(): {qkv.size()}")
        qkv = qkv.permute(0, 2, 1, 3)
        print(f"qkv.size(): {qkv.size()}")
        q, k, v = qkv.chunk(3, dim=-1)
        print(f"q size: {q.size()}, k size: {k.size()}, v size: {v.size()}, ")
        values, attention = scaled_dot_product(q, k, v, mask)
        print(f"values.size(): {values.size()}, attention.size:{ attention.size()} ")
        # move heads back next to head_dim before flattening them into d_model
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.num_heads * self.head_dim
        )
        print(f"values.size(): {values.size()}")
        out = self.linear_layer(values)
        print(f"out.size(): {out.size()}")
        return out
```

```python
input_dim = 1024
d_model = 512
num_heads = 8

batch_size = 30
sequence_length = 5
x = torch.randn((batch_size, sequence_length, input_dim))

model = MultiHeadAttention(input_dim, d_model, num_heads)
out = model(x)  # calling the module runs forward() plus any registered hooks
```

```
x.size(): torch.Size([30, 5, 1024])
qkv.size(): torch.Size([30, 5, 1536])
qkv.size(): torch.Size([30, 5, 8, 192])
qkv.size(): torch.Size([30, 8, 5, 192])
q size: torch.Size([30, 8, 5, 64]), k size: torch.Size([30, 8, 5, 64]), v size: torch.Size([30, 8, 5, 64]), 
values.size(): torch.Size([30, 8, 5, 64]), attention.size:torch.Size([30, 8, 5, 5]) 
values.size(): torch.Size([30, 5, 512])
out.size(): torch.Size([30, 5, 512])
```
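For comparison, PyTorch ships a built-in `nn.MultiheadAttention` module. A minimal sketch of calling it for causal self-attention with the same batch size, sequence length, and number of heads; note that, unlike the class above, it expects the query's feature size to equal `embed_dim`, so the input here is 512-dimensional rather than 1024. With a boolean `attn_mask`, `True` marks positions that may not be attended to.

```python
import torch
import torch.nn as nn

batch_size, sequence_length, d_model, num_heads = 30, 5, 512, 8
x = torch.randn(batch_size, sequence_length, d_model)

# batch_first=True makes inputs and outputs [batch, seq, embed], like the class above
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)

# causal mask: True above the diagonal blocks attention to future words
causal = torch.triu(torch.ones(sequence_length, sequence_length, dtype=torch.bool), diagonal=1)

# self-attention: the same tensor is used as query, key, and value
out, weights = mha(x, x, x, attn_mask=causal)

print(out.shape)      # torch.Size([30, 5, 512])
print(weights.shape)  # torch.Size([30, 5, 5]), averaged over heads by default
```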
