# MultiHead Attention

By far the most important component of the Transformer is the Multi-Head Attention mechanism.
After all, the paper is titled "Attention Is All You Need". From now on, Multi-Head Attention will be abbreviated as MHA for convenience.

Therefore, before diving deep into the rest of the Transformer, I think it is worthwhile to try implementing
this module from scratch, which in turn will help us understand the subtle parts of MHA.

There is an official PyTorch implementation (`torch.nn.MultiheadAttention`), which is used in their official Transformer block, but I am a strong believer in
building something from scratch to appreciate the details before using a de facto implementation.



# Attention
MHA is simply multiple Attention blocks run side by side with their outputs concatenated, so in order to truly understand MHA, we need to do a deep dive on attention itself.
From the paper, Attention is defined by this equation:

$\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$

with the following diagram to support the equation
![Attention Diagram](images/attention.jpeg)
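
To make the equation concrete, here is a minimal sketch of it in plain PyTorch. The function name `scaled_dot_product_attention` and the toy shapes are my own choices for illustration; they are not taken from the paper.

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """Illustrative helper: softmax(Q K^T / sqrt(d_k)) V for batched inputs."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (batch, seq_len, seq_len)
    weights = torch.softmax(scores, dim=-1)              # each row sums to 1
    return weights @ V, weights

torch.manual_seed(0)
Q = torch.rand(2, 5, 64)   # toy shapes: batch=2, seq_len=5, d_k=64
K = torch.rand(2, 5, 64)
V = torch.rand(2, 5, 64)
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape)       # torch.Size([2, 5, 64])
print(weights.shape)   # torch.Size([2, 5, 5])
```

Each row of `weights` is a probability distribution over the positions of `K`, and the output is the corresponding weighted average of the rows of `V`.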

# Q, K, and V
Q, K, and V stand for Query, Key, and Value respectively. To quote this awesome [Stack Exchange answer](https://stats.stackexchange.com/questions/421935/what-exactly-are-keys-queries-and-values-in-attention-mechanisms):
> The key/value/query concepts come from retrieval systems. For example, when you type a query to search for some video on Youtube, the search engine will map your query against a set of keys (video title, description etc.) associated with candidate videos in the database, then present you the best matched videos (values).

One thing that really confused me personally when I was studying Attention is how the paper seemed to use the terms Q, K, V interchangeably even though they meant different things in different places.
The following is what I mean:
![Attention Diagram 2](images/confusion.png "Attention and MHA")
The inputs to both Scaled Dot-Product Attention and Multi-Head Attention are labeled Q, K, V, but Multi-Head Attention contains Scaled Dot-Product Attention, which appears to take in Q, K, V only after they have been projected by a Linear layer.
Later on, the paper gives the following equation for MHA:
![MHA Equation](images/mha_equation.png "MHA Equation")

Comparing it to the previous Attention equation, which takes in only Q, K, and V, the Attention call inside MHA takes in
$QW_{i}^{Q}$, $KW_{i}^{K}$, and $VW_{i}^{V}$. So which one is it...?


Maybe it was just me, but if you were also confused by this during your first read-through of the paper,
hopefully the following description helps.

The diagram for the Attention module is correct in the sense that **Q** (Query), **K** (Key), **V** (Value) are its inputs, but I think the MHA diagram should label its inputs as something other than **V, K, Q**. Those inputs should be some other vectors, and the vectors that come out of the **Linear** layers should be the **Q, K, V** that go into each Attention head.
I crudely redrew the Attention and MHA blocks to illustrate what I mean.
![New Attention Diagram](images/AttentionDiagram_New.png "New Attention Diagram for Illustration")

The Scaled Dot Product Attention box diagram would have the same internals as the previous diagram from the paper.

The new equation that follows the new diagram would be: 

![MHA Equation](images/new_mha_equation.png "MHA Equation")

$\text{MultiHead}(q, kv)=\text{Concat}(\text{head}_1,...,\text{head}_h)W^O$

$\text{where } \text{head}_i = \text{Attention}(qW_i^Q, kvW_i^K, kvW_i^V)$

The **q** input gets turned into **Q** after being projected by $W^Q$, and **kv** gets turned into **K** and **V** after being projected by $W^K$ and $W^V$ respectively. **K** (Key) and **V** (Value) are extracted from the same input vector. Personally, I am not sure why, but this is how it is in the paper.
Instead of using the raw input $h$ times, once for each head of MHA, the authors decided to split up the input vector by projecting it with linear layers with output sizes of $d_k, d_k, d_v$ (dimensions for $Q, K, V$ respectively), where $d_k = d_v = \frac{d_{model}}{h} = 64$
(Not sure why they distinguished between $d_k$ and $d_v$, but in the paper, they are ultimately equal).
This allowed for multiple heads to "jointly attend to information from different representation subspaces at different positions", which simply put,
means the network is able to obtain information better by looking at the input differently through multiple heads.
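
For reference, these are the projection shapes given in the paper, with the base model's numbers ($h = 8$, $d_{model} = 512$) plugged in:

$W_i^Q \in \mathbb{R}^{d_{model} \times d_k}, \quad W_i^K \in \mathbb{R}^{d_{model} \times d_k}, \quad W_i^V \in \mathbb{R}^{d_{model} \times d_v}, \quad W^O \in \mathbb{R}^{hd_v \times d_{model}}$

$d_k = d_v = \frac{d_{model}}{h} = \frac{512}{8} = 64, \qquad h \cdot d_v = 8 \cdot 64 = 512 = d_{model}$

So concatenating the $h$ head outputs gives back a vector of size $d_{model}$, which is why $W^O$ can map straight back to the model dimension.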




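In self-attention (used in the Encoder, and also in the first attention block of the Decoder), **q** and **kv** are the same vector, whereas in encoder-decoder attention (the second attention block of the Decoder), **q** comes from the decoder and **kv** comes from the encoder output.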
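
The two call patterns can be sketched with PyTorch's built-in `nn.MultiheadAttention`. The tensor names and sequence lengths below are made up for illustration, and a real model would use separate modules for the two kinds of attention; one module is reused here only to keep the sketch short.

```python
import torch
from torch import nn

torch.manual_seed(0)
# PyTorch's built-in module, used here only to illustrate the two call patterns
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

encoder_states = torch.rand(2, 7, 512)   # (batch, source seq_len, d_model)
decoder_states = torch.rand(2, 5, 512)   # (batch, target seq_len, d_model)

# Self-attention: query, key and value all come from the same tensor
self_out, self_weights = mha(encoder_states, encoder_states, encoder_states)

# Encoder-decoder attention: queries from the decoder, keys/values from the encoder
cross_out, cross_weights = mha(decoder_states, encoder_states, encoder_states)

print(self_out.shape, self_weights.shape)    # torch.Size([2, 7, 512]) torch.Size([2, 7, 7])
print(cross_out.shape, cross_weights.shape)  # torch.Size([2, 5, 512]) torch.Size([2, 5, 7])
```

Note that the output always has the sequence length of the query, while the attention map has one column per key/value position.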


# Coding Attention Module

Okay, so I think the general overview of the Attention Module and MHA is somewhat complete. The rest of the details will be understood better with code.

The first part of creating the MHA module is to create the Scaled Dot Product Attention Module. MHA is a simple stacking of multiple Attention Modules, so once the actual Attention Module is properly formed, we can expect the rest to be pretty straightforward.

In [1]:
# Import necessary libraries
import torch
from torch import nn
from torch.nn import Module, ModuleList
from torch.nn import functional as F
import numpy as np

In [2]:
class Attention(Module):
    def __init__(self, d_k, mask=None):
        # initialize all the components here
        super().__init__()
        self.d_k = d_k
        self.mask = mask

    def forward(self, q, k, v):
        """
        q,k,v are in lower case letters to match Python practice. They are the same as the capital letter variables from above
        All the tensors have shape (batch, seq_len, d_k)
        :param q: Query
        :param k: Key
        :param v: Value
        :return: Result tensor with shape (batch, seq_len, d_k)
        """
        # check the feature dimension (last axis) before doing any math with the inputs
        assert self.d_k == q.shape[-1] and self.d_k == k.shape[-1] and self.d_k == v.shape[-1], "Dimension set and actual dimension does not match"
        # need to transpose k to matrix multiply with q. Swaps the last two dimensions
        k_T = torch.transpose(k, -2, -1) # k_T.shape = (batch, d_k, seq_len)
        # torch.matmul performs batch-matrix-multiplication.
        # More detail on how matmul deals with its cases is listed here https://pytorch.org/docs/stable/generated/torch.matmul.html
        qkT = torch.matmul(q, k_T) # qkT.shape = (batch, seq_len, seq_len)
        qkT_scaled = qkT * (self.d_k**(-0.5))
        if self.mask is not None:
            # masking is not implemented yet, so fail loudly instead of handing None to matmul below
            raise NotImplementedError("Masking will be added with the decoder module")
        prob_map = F.softmax(input=qkT_scaled, dim=-1)

        # prob_map -> (batch, seq_len, seq_len)
        # v -> (batch, seq_len, d_k)
        # Therefore, result -> (batch, seq_len, d_k), which is the same dimension as the original inputs
        result = torch.matmul(prob_map, v)
        return result
        

The Attention module is now created. For now, the mask is ignored, but we will get back to it when we deal with the decoder module.
Notice that the output of the module has the same dimensions as the input vectors.
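
As a preview of what the mask will eventually do, here is one possible sketch of causal masking. The paper describes masking as setting the disallowed entries to $-\infty$ before the softmax; the helper name `masked_attention` and the boolean lower-triangular mask are my own assumptions, not necessarily how the decoder part will end up implementing it.

```python
import math
import torch

def masked_attention(q, k, v, mask):
    """Illustrative helper: positions where mask is False get zero attention weight."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    scores = scores.masked_fill(~mask, float("-inf"))   # -inf becomes 0 after softmax
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights

torch.manual_seed(0)
seq_len = 4
x = torch.rand(1, seq_len, 64)
# lower-triangular mask: position i may only look at positions <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
out, weights = masked_attention(x, x, x, causal_mask)
print(weights[0])   # entries above the diagonal are exactly 0
```

Every row still sums to 1, because the softmax renormalizes over the positions that remain visible.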

Now that Attention Module is done, let's go ahead and make the MHA module. 

# Coding MHA Module

In [3]:
# Quick shape check before writing the module: one (W^Q, W^K, W^V) projection per head
wqs = ModuleList()
wks = ModuleList()
wvs = ModuleList()
for head in range(8):
    wqs.append(nn.Linear(in_features=512, out_features=64))
    wks.append(nn.Linear(in_features=512, out_features=64))
    wvs.append(nn.Linear(in_features=512, out_features=64))

seq_len = 5
inputs = torch.rand((2, seq_len, 512)) # (batch, seq_len, d_model)

# each head projects the same input from d_model=512 down to d_k=64
for wq, wk, wv in zip(wqs, wks, wvs):
    output1 = wq(inputs)
    output2 = wk(inputs)
    output3 = wv(inputs)
    print(output1.size(), output2.size(), output3.size())
    


torch.Size([2, 5, 64]) torch.Size([2, 5, 64]) torch.Size([2, 5, 64])
torch.Size([2, 5, 64]) torch.Size([2, 5, 64]) torch.Size([2, 5, 64])
torch.Size([2, 5, 64]) torch.Size([2, 5, 64]) torch.Size([2, 5, 64])
torch.Size([2, 5, 64]) torch.Size([2, 5, 64]) torch.Size([2, 5, 64])
torch.Size([2, 5, 64]) torch.Size([2, 5, 64]) torch.Size([2, 5, 64])
torch.Size([2, 5, 64]) torch.Size([2, 5, 64]) torch.Size([2, 5, 64])
torch.Size([2, 5, 64]) torch.Size([2, 5, 64]) torch.Size([2, 5, 64])
torch.Size([2, 5, 64]) torch.Size([2, 5, 64]) torch.Size([2, 5, 64])


In [4]:
class MultiHeadAttention(Module):
    def __init__(self, h, d_model, mask=None):
        """
        :param h: number of heads  (8 in paper)
        :param d_model: input feature size (512 in the paper)
        :param mask: optional attention mask, handed to the Attention module (not used yet)
        """
        # initialize all the components here
        super().__init__()
        self.d_model = d_model
        self.mask = mask
        self.h = h
        self.d_k = self.d_model//self.h
        self.attention = Attention(d_k=self.d_k, mask=self.mask)

        # one set of projection layers (W_i^Q, W_i^K, W_i^V) per head
        self.wqs = ModuleList()
        self.wks = ModuleList()
        self.wvs = ModuleList()

        for head in range(self.h):
            self.wqs.append(nn.Linear(in_features=d_model, out_features=self.d_k))
            self.wks.append(nn.Linear(in_features=d_model, out_features=self.d_k))
            self.wvs.append(nn.Linear(in_features=d_model, out_features=self.d_k))

        self.wo = nn.Linear(in_features=self.d_k*self.h, out_features=self.d_model)

    def forward(self, q, kv):
        """
        q and kv are the raw inputs. The projected Q, K, V for each head are computed inside this method
        Both inputs have shape (batch, seq_len, d_model); the seq_len of q and kv may differ
        :param q: Input used to get "q" (query)
        :param kv: Input used to get "k" (key) and "v" (value)
        :return: Result tensor with shape (batch, seq_len of q, d_model)
        """
        # recall that (seq_len, d_model) * (d_model, d_k) -> (seq_len, d_k)
        # which is the matrix multiplication between the input vector and a linear layer (wq, wk, wv)
        head_outputs = []
        for wi_q, wi_k, wi_v in zip(self.wqs, self.wks, self.wvs):
            Q = wi_q(q)
            K = wi_k(kv)
            V = wi_v(kv)
            _output = self.attention(Q, K, V) # each output has size of (batch, seq_len, d_k)
            head_outputs.append(_output)
        # head_outputs is already a list of tensors, so it is passed to torch.cat directly
        attention_output = torch.cat(head_outputs, dim=-1) # concat it to get (batch, seq_len, d_model)
        output = self.wo(attention_output)
        return output
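
A side note on the per-head loop: many implementations replace the `h` small linear layers with one large linear layer and then reshape the result into heads. The sketch below is not part of the module above; it only checks, under the paper's sizes, that stacking the per-head weights into a single hypothetical `fused` layer gives the same projection as looping over the heads.

```python
import torch
from torch import nn

torch.manual_seed(0)
h, d_model = 8, 512
d_k = d_model // h
heads = nn.ModuleList([nn.Linear(d_model, d_k) for _ in range(h)])

# Build one big layer whose weight is the per-head weights stacked along the output dimension
fused = nn.Linear(d_model, h * d_k)
with torch.no_grad():
    fused.weight.copy_(torch.cat([layer.weight for layer in heads], dim=0))
    fused.bias.copy_(torch.cat([layer.bias for layer in heads], dim=0))

x = torch.rand(2, 5, d_model)
per_head = torch.cat([layer(x) for layer in heads], dim=-1)   # (2, 5, 512), loop version
all_at_once = fused(x)                                         # (2, 5, 512), single matmul
print(torch.allclose(per_head, all_at_once, atol=1e-4))

# The fused output can be split back into heads with a reshape: (batch, seq_len, h, d_k)
split = all_at_once.view(2, 5, h, d_k)
print(split.shape)
```

The loop version is easier to read and maps directly onto the equation, which is why it is used in this walkthrough; the fused version is mainly a speed optimization.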

Now we have all the major components needed to build the encoder portion of the Transformer.
The architecture of the Transformer can be viewed from this image, provided by the original paper.

![Transformer Encoder Diagram](images/transformer_encoder.png "Encoder of Transformer")

The normalization from "Add & Norm" block is Layer Normalization. 
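
In the paper, each "Add & Norm" block computes $\text{LayerNorm}(x + \text{Sublayer}(x))$. A minimal sketch of that step, with a random tensor standing in for the sub-layer output (the variable names are mine):

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model = 512
layer_norm = nn.LayerNorm(d_model)

x = torch.rand(2, 5, d_model)                # input to the sub-layer
sublayer_output = torch.rand(2, 5, d_model)  # stand-in for the MHA or feed-forward output

# "Add & Norm": residual connection followed by layer normalization
y = layer_norm(x + sublayer_output)

print(y.shape)
print(y.mean(dim=-1)[0, 0].item())                 # close to 0
print(y.var(dim=-1, unbiased=False)[0, 0].item())  # close to 1
```

Layer Normalization works over the feature dimension of each position separately, so every position ends up with roughly zero mean and unit variance before the learnable scale and shift are applied.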

Feed Forward according to the paper is characterized by this equation:
![FFN Equation](images/ffn_equation.png "FFN Equation")

$W_1$ and $W_2$ indicate that the feed forward network is made from two linear layers with a ReLU in between, with the input and output vector dimension being $d_{in} = d_{out} = 512$, and an inner layer vector dimension of $d_{ff} = 2048$.
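
The paper writes this network as $\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$. A minimal sketch with the sizes above, assuming `nn.Sequential` as the container (any equivalent module layout would do):

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, d_ff = 512, 2048
ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),   # x W_1 + b_1
    nn.ReLU(),                  # max(0, .)
    nn.Linear(d_ff, d_model),   # (.) W_2 + b_2
)

x = torch.rand(2, 5, d_model)
y = ffn(x)
print(y.shape)   # torch.Size([2, 5, 512])

# "Position-wise": each position is transformed independently of the others
print(torch.allclose(y[:, 0], ffn(x[:, 0]), atol=1e-4))

n_params = sum(p.numel() for p in ffn.parameters())
print(n_params)  # 2099712
```

The parameter count is $512 \cdot 2048 + 2048 + 2048 \cdot 512 + 512 = 2{,}099{,}712$, so most of an encoder layer's weights actually live in this feed forward block.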


In [7]:
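class TransformerEncoderLayer(Module):
    def __init__(self, h, d_model, d_ff=2048):
        super().__init__()
        self.mha = MultiHeadAttention(h=h, d_model=d_model)
        # FFN(x) = max(0, x W_1 + b_1) W_2 + b_2, so a ReLU sits between the two linear layers
        self.feed_forward = nn.Sequential(nn.Linear(in_features=d_model, out_features=d_ff), nn.ReLU(), nn.Linear(in_features=d_ff, out_features=d_model))
        # each "Add & Norm" block gets its own Layer Normalization with learnable parameters
        self.norm_1 = nn.LayerNorm(normalized_shape=d_model)
        self.norm_2 = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, x):
        mha_output = self.mha(q=x, kv=x) # self-attention: q and kv are the same input
        add_and_norm_1 = self.norm_1(x + mha_output)
        ffn_output = self.feed_forward(add_and_norm_1)
        output = self.norm_2(add_and_norm_1 + ffn_output)
        return output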
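
For comparison, PyTorch ships its own `nn.TransformerEncoderLayer`. The sketch below only shows that, with the paper's base hyperparameters, it maps a `(batch, seq_len, d_model)` tensor to a tensor of the same shape. It is not a drop-in equivalent of the class above: among other things, the built-in layer includes dropout.

```python
import torch
from torch import nn

torch.manual_seed(0)
# PyTorch's built-in encoder layer with the paper's base hyperparameters
reference_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048, batch_first=True)
reference_layer.eval()   # turn off dropout so the output is deterministic

x = torch.rand(2, 5, 512)
with torch.no_grad():
    y = reference_layer(x)
print(y.shape)   # torch.Size([2, 5, 512])
```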
