## Full Transformer Architecture
<img src="images/Transformer_arch.png" width="450" height="600">

## Notes
- The attention mechanism views its input as a set rather than a sequence, i.e. on its own it is insensitive to the order of the tokens. However, in language processing the order of words is important $\rightarrow$ we need positional encoding.
- Multi-head self-attention (input: $b$ sequences (the batch size), each sequence has $t$ tokens, each token has $k$ dimensions; the layer uses $h$ heads)
    1. Project the input of dimension $(t, k)$ into $h$ inputs of dimension $(t, \frac{k}{h})$ using $3 \times h$ projection matrices (one each for key, query, and value per head) of size $(k, \frac{k}{h})$
    2. The $h$ head outputs are then concatenated $\rightarrow$ size $(t, k)$ $\rightarrow$ passed through a "unify matrix" of size $(k, k)$

In [1]:
import torch
from  torch import nn
from torch import Tensor

In [None]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input of shape ``(b, t, k)`` (batch, tokens, token dimensions) is
    projected to keys, queries and values, split into ``n_heads`` heads of
    ``k // n_heads`` dimensions each, attended per head, concatenated back
    to ``k`` dimensions and passed through a final "unify" linear layer.

    Each of the three ``(k, k)`` projections is the ``n_heads`` per-head
    ``(k, k // n_heads)`` projection matrices stacked side by side, so one
    matrix multiplication computes the projection for every head at once.

    Args:
        token_dims: Dimensionality ``k`` of every input token.
        n_heads: Number of attention heads ``h``; must divide ``token_dims``.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``token_dims`` evenly.
    """

    def __init__(self, token_dims: int, n_heads: int) -> None:
        # Registering as an nn.Module is what makes the sub-layers' weights
        # visible to optimizers, ``.to(device)`` and ``state_dict()``.
        super().__init__()

        # Explicit check instead of ``assert`` so it survives ``python -O``
        # and also guards the division below against ``n_heads == 0``.
        if n_heads <= 0 or token_dims % n_heads != 0:
            raise ValueError(
                f"token_dims ({token_dims}) must be divisible by a positive "
                f"n_heads ({n_heads})"
            )

        self.n_heads = n_heads
        self.token_dims = token_dims
        self.head_dims = token_dims // n_heads

        # k -> k projections: the output is later viewed as (h, k/h) per token.
        self.key_transform = nn.Linear(token_dims, token_dims, bias=False)
        self.query_transform = nn.Linear(token_dims, token_dims, bias=False)
        self.value_transform = nn.Linear(token_dims, token_dims, bias=False)

        # Mixes the concatenated head outputs back into a single k-dim vector.
        self.unify = nn.Linear(token_dims, token_dims)

    def _split_heads(self, projected: Tensor) -> Tensor:
        """Reshape ``(b, t, k)`` into per-head form ``(b, h, t, s)``."""
        batch_size, n_tokens, _ = projected.shape
        per_head = projected.view(batch_size, n_tokens, self.n_heads, self.head_dims)
        return per_head.transpose(1, 2)

    def forward(self, X: Tensor) -> Tensor:
        """Apply self-attention to ``X``.

        Args:
            X: Input tensor of shape ``(b, t, k)`` with ``k == token_dims``.

        Returns:
            Tensor of shape ``(b, t, k)``.
        """
        batch_size, n_tokens, _ = X.shape

        key = self._split_heads(self.key_transform(X))        # (b, h, t, s)
        query = self._split_heads(self.query_transform(X))    # (b, h, t, s)
        value = self._split_heads(self.value_transform(X))    # (b, h, t, s)

        # matmul broadcasts over the leading (b, h) axes; bmm only takes 3-D.
        # Scaling by sqrt(s) keeps the logits' magnitude independent of s.
        scores = torch.matmul(query, key.transpose(-2, -1)) / self.head_dims ** 0.5  # (b, h, t, t)
        weights = torch.softmax(scores, dim=-1)     # normalise over the key axis
        attended = torch.matmul(weights, value)     # (b, h, t, s)

        # Concatenate the heads: (b, h, t, s) -> (b, t, h * s). ``reshape``
        # handles the non-contiguous layout produced by the transpose.
        merged = attended.transpose(1, 2).reshape(batch_size, n_tokens, self.token_dims)
        return self.unify(merged)

## Resources
- https://peterbloem.nl/blog/transformers