## MultiheadAttention in detail

In [None]:
import torch
from torch import nn

### einops

In [None]:
import einops

In [None]:
# A 4-D tensor of shape [2, 3, 4, 5] holding 0, 1, ..., 119 in row-major order.
x = torch.arange(2 * 3 * 4 * 5).view(2, 3, 4, 5)
x

Let's swap the second and third dimensions, and then merge the first and second dimensions into a single one.

In [None]:
# View of x with dims 1 and 2 swapped: [2, 3, 4, 5] -> [2, 4, 3, 5].
z = torch.transpose(x, 1, 2)

In [None]:
# Flattening dims 0 and 1 is the same as concatenating the dim-0 slices.
# torch.equal (not allclose): the data is integer, and equal also requires
# identical shapes, whereas allclose would broadcast mismatched shapes.
assert torch.equal(torch.cat(list(z)), z.flatten(0, 1))

In [None]:
# [2, 3, 4, 5] -> [2, 4, 3, 5] -> [8, 3, 5]
# [2, 3, 4, 5] -> transpose(1, 2) -> [2, 4, 3, 5] -> flatten(0, 1) -> [8, 3, 5]
torch.flatten(torch.transpose(x, 1, 2), start_dim=0, end_dim=1).shape

In [None]:
# einops spells out the same transform by axis name. torch.equal checks exact
# equality *and* identical shapes; allclose would silently broadcast.
torch.equal(
    einops.rearrange(x, 'first second third fourth -> (first third) second fourth'),
    x.transpose(1, 2).flatten(0, 1),
)

Which is more readable? :)

In [None]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over ``[T, B, C]`` inputs.

    Q, K and V come from a single bias-free linear layer whose output is split
    into three chunks. Heads are folded into the batch dimension with
    ``einops`` so attention reduces to two batched matrix multiplications.

    Args:
        input_dim: Size ``C`` of the last input dimension. Must be positive and
            divisible by ``num_heads``.
        num_heads: Number of attention heads (positive). Each head gets
            ``input_dim // num_heads`` channels.
        dropout: Dropout probability. The same ``nn.Dropout`` module is applied
            both to the attention probabilities and to the projected output.

    Raises:
        ValueError: If ``input_dim`` or ``num_heads`` is not positive, or
            ``num_heads`` does not divide ``input_dim``.
    """

    def __init__(self, input_dim: int, num_heads: int, dropout: float):
        super().__init__()

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if input_dim <= 0 or num_heads <= 0:
            raise ValueError(
                f'input_dim and num_heads must be positive, got {input_dim} and {num_heads}'
            )
        if input_dim % num_heads != 0:
            raise ValueError(
                f'input_dim ({input_dim}) must be divisible by num_heads ({num_heads})'
            )

        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads

        # 1 / sqrt(head_dim): the "scaled" in scaled dot-product attention.
        self.scaling = self.head_dim ** -0.5

        # Gather Q, K, V projections into one big projection
        self.projection = nn.Linear(input_dim, input_dim * 3, bias=False)
        self.out_projection = nn.Linear(input_dim, input_dim, bias=False)

        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def get_key_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
        """Build a key padding mask from per-sequence lengths.

        Args:
            lengths: Integer tensor with the number of real (non-pad) tokens in
                each sequence; presumably 1-D of shape ``(batch,)``.

        Returns:
            Bool tensor of shape ``(batch, max(lengths))`` where ``True`` marks
            a padding position, i.e. ``mask[b, t] = t >= lengths[b]``.
        """
        max_length = torch.max(lengths).item()
        positions = torch.arange(max_length, device=lengths.device)
        # (max_length,) vs (batch, 1) broadcasts to (batch, max_length); `ge`
        # already yields a fresh contiguous bool tensor.
        return positions.ge(lengths.view(-1, 1))

    def _check_input_shape(self, input: torch.Tensor, mask: torch.BoolTensor):
        """Validate ``input`` as ``[T, B, C]`` and ``mask`` as a bool ``(B, T)`` tensor.

        Raises:
            ValueError: On a rank, width, dtype or mask-shape mismatch.
        """
        if input.dim() != 3:
            raise ValueError('Input should have 3 dimensions')

        if input.size(-1) != self.input_dim:
            raise ValueError('Expected order of dimensions is [T, B, C]')

        if mask.dtype != torch.bool:
            raise ValueError('Expected type of mask is torch.bool')

        # Without this check a (B, 1) or (1, T) mask would be silently
        # broadcast by masked_fill in forward().
        input_len, batch_size, _ = input.size()
        if tuple(mask.size()) != (batch_size, input_len):
            raise ValueError(
                f'Expected key_padding_mask of shape (batch, T) = '
                f'({batch_size}, {input_len}), got {tuple(mask.size())}'
            )

    def forward(self, input: torch.Tensor, key_padding_mask: torch.BoolTensor) -> torch.Tensor:
        """Apply multi-head self-attention.

        Args:
            input: Tensor of shape ``[T, B, C]`` with ``C == input_dim``.
            key_padding_mask: Bool tensor of shape ``(B, T)``; ``True`` marks
                key positions that must not be attended to (see
                ``get_key_padding_mask``).

        Returns:
            Tensor of shape ``[T, B, C]``.

        Note:
            If every key of a sequence is masked (length 0), all of that
            sequence's logits are ``-inf`` and softmax produces NaN for it.
        """
        self._check_input_shape(input, key_padding_mask)

        input_len, batch_size, _ = input.size()

        query, key, value = self.projection(input).chunk(3, dim=-1)
        assert query.size() == (input_len, batch_size, self.input_dim)

        # Gather batches with heads. "(batch head)" keeps batch as the outer
        # axis, which the view() calls below rely on. Key is laid out as
        # (dim, T) so that bmm(query, key) directly computes Q @ K^T.
        query = einops.rearrange(
            query, 'T batch (head dim) -> (batch head) T dim', head=self.num_heads
        )
        key = einops.rearrange(
            key, 'T batch (head dim) -> (batch head) dim T', head=self.num_heads
        )
        value = einops.rearrange(
            value, 'T batch (head dim) -> (batch head) T dim', head=self.num_heads
        )

        attn_weights = torch.bmm(query, key)
        attn_weights.mul_(self.scaling)
        assert attn_weights.size() == (batch_size * self.num_heads, input_len, input_len)

        # Masking padding scores: (B, T) -> (B, 1, 1, T), broadcast over heads
        # and query positions.
        attn_weights = attn_weights.view(batch_size, self.num_heads, input_len, input_len)
        attn_weights = attn_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_weights = attn_weights.view(batch_size * self.num_heads, input_len, input_len)

        attn_probs = torch.softmax(attn_weights, dim=-1)
        attn_probs = self.dropout(attn_probs)

        attn = torch.bmm(attn_probs, value)
        assert attn.size() == (batch_size * self.num_heads, input_len, self.head_dim)

        attn = einops.rearrange(
            attn, '(batch head) T dim -> T batch (head dim)', head=self.num_heads
        )
        attn = self.out_projection(attn)
        attn = self.dropout(attn)

        return attn


### Transformers in PyTorch

In [None]:
# PyTorch's built-in Transformer building blocks. A notebook cell displays only
# its last bare expression, so collect them in a list to show all of them.
transformer_modules = [
    nn.Transformer,
    nn.TransformerDecoder,
    nn.TransformerDecoderLayer,
    nn.TransformerEncoder,
    nn.TransformerEncoderLayer,
    nn.MultiheadAttention,
]
transformer_modules

## How to train Transformers

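Transformers without Tears: Improving the Normalization of Self-Attention (Nguyen & Salazar, 2019): https://tnq177.github.io/data/transformers_without_tears.pdf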

Training Tips for the Transformer Model (Popel & Bojar, 2018): https://arxiv.org/pdf/1804.00247.pdf

https://tunz.kr/post/4