### Multihead attention

#### Encoder

Let:

- $Q \in \mathbb{R}^{s \times d_q}$
- $K \in \mathbb{R}^{s \times d_k}$
- $V \in \mathbb{R}^{s \times d_v}$

Here $s$ is the length of the token sequence.

$$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \cdots, \text{head}_h) W^O $$
$$ \text{head}_i = \text{Attention}(QW^Q_i, KW^K_i, VW^V_i) $$
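
The $\text{Attention}$ function is used above without being defined. A reasonable reading, assuming the standard scaled dot-product attention of the Transformer, is the following, where $d_k$ denotes the number of columns of the keys passed to the head (the queries must have the same number of columns for $QK^\top$ to be defined):

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V $$

The softmax is applied row by row, so each row of the $s \times s$ score matrix sums to 1 and the result has shape $s \times d_v$.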


Let $X \in \mathbb{R}^{s \times d}$ be a sequence of tokens.

$$ X := \begin{bmatrix}
    x^1_1 & x^1_2 & \cdots & x^1_d \\
    x^2_1 & x^2_2 & \cdots & x^2_d \\
    \vdots & \vdots & \ddots & \vdots \\
    x^s_1 & x^s_2 & \cdots & x^s_d
\end{bmatrix} $$


Splitting the $d$ features of each token into $h$ consecutive chunks of size $d/h$ gives one $h \times d/h$ block per token:

$$ X := \begin{bmatrix}
    \begin{bmatrix}
        x^1_1 & x^1_2 & \cdots & x^1_{d/h} \\
        \vdots & \vdots & \ddots & \vdots \\
        x^1_{d\frac{h-1}{h}+1} & x^1_{d\frac{h-1}{h}+2} & \cdots & x^1_d
    \end{bmatrix} \\
    \vdots \\
    \begin{bmatrix}
        x^s_1 & x^s_2 & \cdots & x^s_{d/h} \\
        \vdots & \vdots & \ddots & \vdots \\
        x^s_{d\frac{h-1}{h}+1} & x^s_{d\frac{h-1}{h}+2} & \cdots & x^s_d
    \end{bmatrix}
\end{bmatrix} $$



Swapping the token axis and the head axis groups the chunks by head, giving $h$ blocks of shape $s \times d/h$, one per head:

$$ X := \begin{bmatrix}
    \begin{bmatrix}
        x^1_1 & x^1_2 & \cdots & x^1_{d/h} \\
        \vdots & \vdots & \ddots & \vdots \\
        x^s_1 & x^s_2 & \cdots & x^s_{d/h}
    \end{bmatrix} \\
    \vdots \\
    \begin{bmatrix}
        x^1_{d\frac{h-1}{h}+1} & x^1_{d\frac{h-1}{h}+2} & \cdots & x^1_d \\
        \vdots & \vdots & \ddots & \vdots \\
        x^s_{d\frac{h-1}{h}+1} & x^s_{d\frac{h-1}{h}+2} & \cdots & x^s_d
    \end{bmatrix}
\end{bmatrix} $$
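
To make the indexing concrete, here is a minimal plain-Python sketch of the per-head grouping for a single sequence. The helper name `split_heads` and the example values are assumptions made for illustration; it works on nested lists rather than tensors and only mirrors the index pattern of the matrices above.

```python
# Plain-Python illustration of the head split; `split_heads` is a hypothetical helper.
def split_heads(tokens: list[list[float]], number_of_heads: int) -> list[list[list[float]]]:
    """Return one s x (d/h) block per head from an s x d list of token rows."""
    model_dimension = len(tokens[0])
    if model_dimension % number_of_heads != 0:
        raise ValueError("d must be divisible by h")
    head_dimension = model_dimension // number_of_heads
    # Head i keeps columns i*(d/h) .. (i+1)*(d/h) - 1 of every token row.
    return [
        [row[head * head_dimension:(head + 1) * head_dimension] for row in tokens]
        for head in range(number_of_heads)
    ]

# s = 2 tokens, d = 4 features, h = 2 heads
X = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
heads = split_heads(X, 2)
print(heads)  # [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
```

The first head sees the first half of every token's features and the second head sees the second half, which is the block layout of the last matrix.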

In [None]:
import math
from torch import Tensor
from torch.nn import Module
from torch.nn import Linear
from torch.nn.functional import softmax

def split(sequence: Tensor, number_of_heads: int) -> Tensor:
    # (batch, sequence, model) -> (batch, heads, sequence, model // heads)
    batch_size, sequence_length, model_dimension = sequence.size()
    sequence = sequence.view(batch_size, sequence_length, number_of_heads, model_dimension // number_of_heads)
    sequence = sequence.transpose(1, 2).contiguous()
    return sequence

def concat(sequence: Tensor) -> Tensor:
    # (batch, heads, sequence, head) -> (batch, sequence, heads * head)
    batch_size, number_of_heads, sequence_length, heads_dimension = sequence.size()
    sequence = sequence.transpose(1, 2).contiguous()
    sequence = sequence.view(batch_size, sequence_length, number_of_heads * heads_dimension)
    return sequence
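
A quick way to check a split/concat pair is a round trip: splitting into heads and concatenating again should return the original tensor. The sketch below is self-contained and inlines the `view`/`transpose` calls instead of calling the functions above. The sizes are arbitrary values chosen for illustration, and the head axis is assumed to sit at position 1, matching the block layout of the matrices.

```python
import torch

batch_size, sequence_length, model_dimension, number_of_heads = 2, 5, 8, 4
x = torch.randn(batch_size, sequence_length, model_dimension)

# Split: (batch, sequence, model) -> (batch, heads, sequence, model // heads)
heads = x.view(
    batch_size, sequence_length, number_of_heads, model_dimension // number_of_heads
).transpose(1, 2)

# Concat: undo the transpose, then merge the last two axes back into the model dimension.
restored = heads.transpose(1, 2).contiguous().view(batch_size, sequence_length, model_dimension)

print(heads.shape)               # torch.Size([2, 4, 5, 2])
print(torch.equal(restored, x))  # True
```

The order of the last two sizes in the `view` call matters: reshaping to `(..., model // heads, heads)` would interleave features across heads instead of giving each head a contiguous chunk.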

In [None]:

class Attention(Module):
    def __init__(self):
        super().__init__()

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
        scale = math.sqrt(key.size(-1))
        score = query @ key.transpose(-2, -1) / scale
        # Normalize over the key axis so each query's weights sum to 1.
        return softmax(score, dim=-1) @ value


class MultiHead(Module):
    def __init__(self, model_dimension: int, key_dimension: int, value_dimension: int, number_of_heads: int):
        super().__init__()
        assert model_dimension % number_of_heads == 0, "model_dimension must be divisible by number_of_heads"
        self.number_of_heads = number_of_heads
        # Queries and keys share key_dimension so that Q K^T is defined.
        self.query_projection = Linear(key_dimension, model_dimension)
        self.key_projection = Linear(key_dimension, model_dimension)
        self.value_projection = Linear(value_dimension, model_dimension)
        self.output_projection = Linear(model_dimension, model_dimension)  # W^O
        self.attention = Attention()

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        query, key, value = self.query_projection(query), self.key_projection(key), self.value_projection(value)
        heads = self.number_of_heads
        query, key, value = split(query, heads), split(key, heads), split(value, heads)
        # Attention runs independently per head over the last two axes.
        attention = self.attention(query, key, value)
        return self.output_projection(concat(attention))

In [2]:
from torch.nn import MultiheadAttention
from torch.nn.functional import scaled_dot_product_attention
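
The two imports above are PyTorch's built-in counterparts of the hand-written pieces. As a rough sanity check, the sketch below compares `scaled_dot_product_attention` with the manual formula and runs `MultiheadAttention` on a random batch. It assumes PyTorch 2.0 or later (where `scaled_dot_product_attention` is available); all sizes are arbitrary values chosen for illustration.

```python
import math
import torch
from torch.nn import MultiheadAttention
from torch.nn.functional import scaled_dot_product_attention, softmax

torch.manual_seed(0)
batch_size, number_of_heads, sequence_length, head_dimension = 2, 4, 5, 8

# Per-head tensors shaped (batch, heads, sequence, head).
query = torch.randn(batch_size, number_of_heads, sequence_length, head_dimension)
key = torch.randn(batch_size, number_of_heads, sequence_length, head_dimension)
value = torch.randn(batch_size, number_of_heads, sequence_length, head_dimension)

# Manual scaled dot-product attention versus the built-in function.
manual = softmax(query @ key.transpose(-2, -1) / math.sqrt(head_dimension), dim=-1) @ value
builtin = scaled_dot_product_attention(query, key, value)
matches = torch.allclose(manual, builtin, atol=1e-5)
print(matches)

# Built-in multi-head layer used as self-attention; batch_first=True expects (batch, sequence, model).
model_dimension = number_of_heads * head_dimension
layer = MultiheadAttention(embed_dim=model_dimension, num_heads=number_of_heads, batch_first=True)
tokens = torch.randn(batch_size, sequence_length, model_dimension)
output, weights = layer(tokens, tokens, tokens)
print(output.shape)   # torch.Size([2, 5, 32])
print(weights.shape)  # torch.Size([2, 5, 5])
```

By default `MultiheadAttention` returns attention weights averaged over heads, which is why `weights` has no head axis.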