# 3 Transformer Anatomy

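There are three types of transformer architectures:
- **Encoder-only**: bidirectional attention; every token attends to every token in the sequence (e.g. BERT)
- **Decoder-only**: autoregressive (causal) attention; each token attends only to itself and earlier tokens (e.g. GPT)
- **Encoder-decoder**: an encoder processes the source sequence and a decoder generates the target sequence, attending to the encoder output through cross-attention (e.g. T5)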
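
In practice, the difference between encoder-only and decoder-only attention comes down to the mask applied before the softmax. A minimal PyTorch sketch, assuming a 5-token sequence with random scores: the encoder mask lets every position see every other position, while the decoder mask is lower-triangular, so future positions receive zero weight. The helper `masked_softmax` is a hypothetical name used only for this illustration.

```python
import torch
import torch.nn.functional as F

seq_len = 5
torch.manual_seed(0)
scores = torch.randn(seq_len, seq_len)  # illustrative raw attention scores

# Encoder-only: bidirectional, nothing is masked
bidirectional_mask = torch.ones(seq_len, seq_len)

# Decoder-only: causal, position i may attend to positions 0..i only
causal_mask = torch.tril(torch.ones(seq_len, seq_len))


def masked_softmax(scores, mask):
    # Masked positions get -inf, so their softmax weight becomes exactly 0
    return F.softmax(scores.masked_fill(mask == 0, float("-inf")), dim=-1)


encoder_weights = masked_softmax(scores, bidirectional_mask)
decoder_weights = masked_softmax(scores, causal_mask)

print(causal_mask)
print(decoder_weights)  # the upper triangle (future tokens) is all zeros
```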

## Calculating Attention and Contextualized Embeddings

### 1. Embedding Matrix $E$ $(5 \times 4)$

The input sequence of 5 tokens, each with an embedding dimension of 4, is represented as:

$$
E = \begin{bmatrix}
e_{11} & e_{12} & e_{13} & e_{14} \\
e_{21} & e_{22} & e_{23} & e_{24} \\
e_{31} & e_{32} & e_{33} & e_{34} \\
e_{41} & e_{42} & e_{43} & e_{44} \\
e_{51} & e_{52} & e_{53} & e_{54}
\end{bmatrix}, \quad
E \in \mathbb{R}^{5 \times 4}
$$
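
For concreteness, $E$ is typically produced by an embedding lookup. A minimal sketch, assuming a hypothetical vocabulary of 100 tokens and made-up token IDs; in a real model the IDs come from a tokenizer and the embedding table is learned.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, embed_dim = 100, 4  # hypothetical vocabulary size; embedding dim from the example
token_embeddings = nn.Embedding(vocab_size, embed_dim)

input_ids = torch.tensor([12, 7, 55, 3, 91])  # made-up IDs for a 5-token sequence
E = token_embeddings(input_ids)  # row i is the embedding vector of token i
print(E.shape)  # torch.Size([5, 4])
```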

### 2. Weight Matrices $W_q$, $W_k$, and $W_v$ $(4 \times 2)$

The weight matrices for the query, key, and value transformations are:

$$
W_q = \begin{bmatrix}
w_{11}^q & w_{12}^q \\
w_{21}^q & w_{22}^q \\
w_{31}^q & w_{32}^q \\
w_{41}^q & w_{42}^q
\end{bmatrix}, \quad
W_k = \begin{bmatrix}
w_{11}^k & w_{12}^k \\
w_{21}^k & w_{22}^k \\
w_{31}^k & w_{32}^k \\
w_{41}^k & w_{42}^k
\end{bmatrix}, \quad
W_v = \begin{bmatrix}
w_{11}^v & w_{12}^v \\
w_{21}^v & w_{22}^v \\
w_{31}^v & w_{32}^v \\
w_{41}^v & w_{42}^v
\end{bmatrix}, \quad
W_q, W_k, W_v \in \mathbb{R}^{4 \times 2}
$$
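
In PyTorch these projections are usually `nn.Linear` layers, as in the `AttentionHead` class later in this notebook. One detail worth knowing: `nn.Linear` stores its weight as `(out_features, in_features)`, so the $4 \times 2$ matrix $W_q$ corresponds to the transpose of `weight`. A sketch, assuming `bias=False` so that each layer is exactly a matrix product (the `AttentionHead` class keeps the default bias).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, head_dim = 4, 2
W_q = nn.Linear(embed_dim, head_dim, bias=False)
W_k = nn.Linear(embed_dim, head_dim, bias=False)
W_v = nn.Linear(embed_dim, head_dim, bias=False)

print(W_q.weight.shape)    # torch.Size([2, 4]): stored as (out, in)
print(W_q.weight.T.shape)  # torch.Size([4, 2]): the W_q of the formulas

E = torch.randn(5, embed_dim)
# Calling the layer is the same as multiplying by the transposed weight
assert torch.allclose(W_q(E), E @ W_q.weight.T, atol=1e-6)
```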

### 3. Query $Q$, Key $K$, and Value $V$ Matrices $(5 \times 2)$

The query, key, and value matrices are computed as:

$$
Q = E W_q, \quad K = E W_k, \quad V = E W_v
$$

Explicitly:
$$
Q = \begin{bmatrix}
q_{11} & q_{12} \\
q_{21} & q_{22} \\
q_{31} & q_{32} \\
q_{41} & q_{42} \\
q_{51} & q_{52}
\end{bmatrix}, \quad
K = \begin{bmatrix}
k_{11} & k_{12} \\
k_{21} & k_{22} \\
k_{31} & k_{32} \\
k_{41} & k_{42} \\
k_{51} & k_{52}
\end{bmatrix}, \quad
V = \begin{bmatrix}
v_{11} & v_{12} \\
v_{21} & v_{22} \\
v_{31} & v_{32} \\
v_{41} & v_{42} \\
v_{51} & v_{52}
\end{bmatrix}, \quad
Q, K, V \in \mathbb{R}^{5 \times 2}
$$
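
The three products can be checked numerically. A minimal sketch with random values standing in for the learned weights (plain tensors rather than `nn.Linear`, to mirror the formulas):

```python
import torch

torch.manual_seed(0)
E = torch.randn(5, 4)  # 5 tokens, embedding dim 4
W_q, W_k, W_v = torch.randn(4, 2), torch.randn(4, 2), torch.randn(4, 2)

Q, K, V = E @ W_q, E @ W_k, E @ W_v
print(Q.shape, K.shape, V.shape)  # each torch.Size([5, 2])

# Row i of Q is the query vector of token i: the i-th embedding projected by W_q
assert torch.allclose(Q[2], E[2] @ W_q, atol=1e-6)
```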

### 4. Attention Scores with Softmax

The attention scores are computed as:

$$
\text{Attention Scores} = \text{Softmax}\left(\frac{QK^T}{\sqrt{d}}\right)
$$

where $d = 2$ (the dimension of the key vectors). The matrix multiplication $QK^T$ results in a $5 \times 5$ matrix:

$$
QK^T = \begin{bmatrix}
q_{11}k_{11} + q_{12}k_{12} & q_{11}k_{21} + q_{12}k_{22} & \cdots & q_{11}k_{51} + q_{12}k_{52} \\
q_{21}k_{11} + q_{22}k_{12} & q_{21}k_{21} + q_{22}k_{22} & \cdots & q_{21}k_{51} + q_{22}k_{52} \\
\vdots & \vdots & \ddots & \vdots \\
q_{51}k_{11} + q_{52}k_{12} & q_{51}k_{21} + q_{52}k_{22} & \cdots & q_{51}k_{51} + q_{52}k_{52}
\end{bmatrix}
$$
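
Entry $(i, j)$ of $QK^T$ is the dot product of query $i$ with key $j$, i.e. how well token $i$'s query matches token $j$'s key. A quick numerical check, assuming random $Q$ and $K$ of the shapes above (note that Python indices start at 0, so entry `[0, 1]` is row 1, column 2 of the matrix):

```python
import torch

torch.manual_seed(0)
Q, K = torch.randn(5, 2), torch.randn(5, 2)

scores = Q @ K.T
print(scores.shape)  # torch.Size([5, 5])

# Row 1, column 2 written out as in the matrix above: q11*k21 + q12*k22
manual = Q[0, 0] * K[1, 0] + Q[0, 1] * K[1, 1]
assert torch.isclose(scores[0, 1], manual, atol=1e-6)
```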

After scaling by $\frac{1}{\sqrt{d}}$ and applying the softmax function row-wise, the attention scores matrix $A$ is:

$$
A = \text{Softmax}\left(\frac{QK^T}{\sqrt{2}}\right), \quad
A \in \mathbb{R}^{5 \times 5}
$$
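
A sketch of this step, assuming random $Q$ and $K$: scale by $\sqrt{d}$, then apply the softmax along each row (`dim=-1`), so each row of $A$ is a probability distribution over the 5 tokens. The manual softmax is only there to show what `torch.softmax` computes.

```python
from math import sqrt

import torch

torch.manual_seed(0)
Q, K = torch.randn(5, 2), torch.randn(5, 2)
d = K.size(-1)  # 2, the dimension of the key vectors

scaled = Q @ K.T / sqrt(d)
A = torch.softmax(scaled, dim=-1)  # row-wise softmax

# The same thing by hand: exponentiate, then normalize each row
manual = scaled.exp() / scaled.exp().sum(dim=-1, keepdim=True)
assert torch.allclose(A, manual, atol=1e-6)

print(A.sum(dim=-1))  # every row sums to 1
```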

### 5. Attention Output

The final output of the self-attention mechanism is computed as:

$$
\text{Output} = A V
$$

Explicitly:

$$
\text{Output} = \begin{bmatrix}
a_{11}v_{11} + a_{12}v_{21} + \cdots + a_{15}v_{51} & a_{11}v_{12} + a_{12}v_{22} + \cdots + a_{15}v_{52} \\
a_{21}v_{11} + a_{22}v_{21} + \cdots + a_{25}v_{51} & a_{21}v_{12} + a_{22}v_{22} + \cdots + a_{25}v_{52} \\
\vdots & \vdots \\
a_{51}v_{11} + a_{52}v_{21} + \cdots + a_{55}v_{51} & a_{51}v_{12} + a_{52}v_{22} + \cdots + a_{55}v_{52}
\end{bmatrix}, \quad
\text{Output} \in \mathbb{R}^{5 \times 2}
$$
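
Row $i$ of the output is a weighted sum of the rows of $V$, with the weights taken from row $i$ of $A$. A sketch that runs the whole pipeline on random inputs and checks the matrix form $AV$ against that per-row sum:

```python
from math import sqrt

import torch

torch.manual_seed(0)
E = torch.randn(5, 4)
W_q, W_k, W_v = torch.randn(4, 2), torch.randn(4, 2), torch.randn(4, 2)
Q, K, V = E @ W_q, E @ W_k, E @ W_v

A = torch.softmax(Q @ K.T / sqrt(K.size(-1)), dim=-1)
output = A @ V
print(output.shape)  # torch.Size([5, 2])

# Row i: sum over j of a_ij * v_j, where v_j is the j-th row of V
for i in range(5):
    row = sum(A[i, j] * V[j] for j in range(5))
    assert torch.allclose(output[i], row, atol=1e-5)
```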

BERT-base has a hidden size of 768 and 12 attention heads, so in the following visualization each head's $Q$, $K$, and $V$ matrices have $768 / 12 = 64$ columns.

<img src="assets/ch3/1.png" width=750>

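In other words, the contextualized embedding $x'_i$ of token $i$ is a linear combination of all rows of the value matrix $V$, weighted by the attention scores in row $i$ of $A$:

$$
x'_i = \sum_{j=1}^{5} a_{ij} v_j
$$

where $v_j$ is the $j$-th row of $V$.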


```python
import warnings

from bertviz.neuron_view import show
from bertviz.transformers_neuron_view import BertModel
from transformers import AutoTokenizer

warnings.filterwarnings("ignore")

model_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = BertModel.from_pretrained(model_ckpt)
text = "Time flies like an arrow"

show(model, "bert", tokenizer, text, display_mode="light", layer=0, head=8)
```

The query, key, and value terminology comes from the field of information retrieval. When everything is text, it is harder to see why three separate roles are needed; an example from Chollet's *Deep Learning with Python* makes it clear by using keys and values of different modalities.

<img src="assets/ch3/2.png" width=600>
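
The retrieval analogy can be made concrete. A Python dict performs a hard lookup: the query must match one key exactly and exactly one value comes back. Attention performs a soft lookup: the query is compared with every key, and the result blends all values according to similarity. A toy sketch, assuming hand-picked 2-D keys and 1-D values (the numbers are illustrative only); note that the values need not live in the same space as the keys.

```python
import torch

# Hard lookup: exact key match, exactly one value returned
table = {"cat": 1.0, "dog": 2.0}
print(table["cat"])  # 1.0

# Soft lookup: the query does not have to match any key exactly
keys = torch.tensor([[1.0, 0.0],    # key for item 0
                     [0.0, 1.0],    # key for item 1
                     [-1.0, 0.0]])  # key for item 2
values = torch.tensor([[10.0], [20.0], [30.0]])  # one value per key, in a different space
query = torch.tensor([0.9, 0.1])  # most similar to key 0

weights = torch.softmax(keys @ query, dim=-1)  # similarities turned into weights summing to 1
result = weights @ values  # blend of all values, weighted most heavily toward value 0
print(weights, result)
```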

```python
from math import sqrt

import torch
import torch.nn as nn
import torch.nn.functional as F


def scaled_dot_product(query, key, value, mask=None):
    # query, key, value: (batch, seq_len, head_dim)
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    if mask is not None:  # a multi-element tensor cannot be used directly in `if`
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value)


class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.q = nn.Linear(embed_dim, head_dim)  # W_q
        self.k = nn.Linear(embed_dim, head_dim)  # W_k
        self.v = nn.Linear(embed_dim, head_dim)  # W_v

    def forward(self, hidden_state):
        # Each of Q, K, V uses its own projection
        Q = self.q(hidden_state)
        K = self.k(hidden_state)
        V = self.v(hidden_state)
        attn_outputs = scaled_dot_product(Q, K, V)
        return attn_outputs


class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        num_heads = config.num_attention_heads
        head_dim = embed_dim // num_heads  # e.g. 768 // 12 = 64 for BERT-base
        self.heads = nn.ModuleList([AttentionHead(embed_dim, head_dim) for _ in range(num_heads)])
        self.output_layer = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        # Concatenate the head outputs back to embed_dim, then mix them linearly
        x = torch.cat([h(hidden_state) for h in self.heads], dim=-1)
        x = self.output_layer(x)
        return x


class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear_1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.linear_2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.gelu(x)
        x = self.linear_2(x)
        x = self.dropout(x)
        return x


class TransformerEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)

    def forward(self, x):
        # Pre-layer normalization: normalize before each sublayer, then add the skip connection
        hidden_state = self.layer_norm_1(x)
        x = x + self.attention(hidden_state)
        x = x + self.feed_forward(self.layer_norm_2(x))
        return x


class Embedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        # Create position IDs on the same device as the input
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = token_embeddings + position_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class TransformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = Embedding(config)
        self.layers = nn.ModuleList([TransformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, x):
        x = self.embeddings(x)
        for layer in self.layers:  # ModuleList is iterable, not callable
            x = layer(x)
        return x


class TransformerForSequenceClassification(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = TransformerEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, x):
        # Classify from the hidden state of the first token ([CLS])
        x = self.encoder(x)[:, 0, :]
        x = self.dropout(x)
        x = self.classifier(x)
        return x
```
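
The encoder above never passes a `mask`, but `scaled_dot_product` accepts one. In BERT-style encoders a mask is typically used to ignore padding tokens when sequences of different lengths are batched together. A self-contained sketch, assuming a batch of two sequences padded to length 4 where the second sequence has only 2 real tokens; the function is repeated here as a variant that also returns the attention weights so they can be inspected.

```python
from math import sqrt

import torch
import torch.nn.functional as F


def scaled_dot_product(query, key, value, mask=None):
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, value), weights  # variant: also return the weights


torch.manual_seed(0)
batch_size, seq_len, head_dim = 2, 4, 8
x = torch.randn(batch_size, seq_len, head_dim)

# 1 = real token, 0 = padding; the second sequence ends with two padding positions
padding_mask = torch.tensor([[1, 1, 1, 1],
                             [1, 1, 0, 0]])
# Shape (batch, 1, key_len) broadcasts over all queries: no query attends to padded keys
mask = padding_mask.unsqueeze(1)

output, weights = scaled_dot_product(x, x, x, mask=mask)
print(weights[1])  # the last two columns are zero
```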

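In cross-attention, the source and target sequences can have different lengths, so the attention score matrix is rectangular rather than square. For example, with 4 source tokens and 8 target tokens, both with embedding dimension 4:

$$
\begin{aligned}
SE &\in \mathbb{R}^{4 \times 4} \quad \text{(source embeddings)} \\
TE &\in \mathbb{R}^{8 \times 4} \quad \text{(target embeddings)}
\end{aligned}
$$

The query matrix comes from the decoder (target), while the key and value matrices come from the encoder (source). The scores $QK^T$ therefore have shape $8 \times 4$: one row per target token and one column per source token.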
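```python
import torch

# Illustrative random values; batch dimension of 1 because scaled_dot_product uses torch.bmm
SE = torch.randn(1, 4, 4)  # source embeddings
TE = torch.randn(1, 8, 4)  # target embeddings
W_tq, W_sk, W_sv = torch.randn(4, 2), torch.randn(4, 2), torch.randn(4, 2)

query = TE @ W_tq  # (1, 8, 2), from the decoder (target)
key = SE @ W_sk    # (1, 4, 2), from the encoder (source)
value = SE @ W_sv  # (1, 4, 2), from the encoder (source)
# Rectangular attention scores (1, 8, 4); output (1, 8, 2)
output = scaled_dot_product(query, key, value, mask=None)
```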
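
PyTorch's built-in `nn.MultiheadAttention` takes `query`, `key`, and `value` as separate arguments, so the same cross-attention pattern can be expressed with it. A sketch using the shapes above (4 source tokens, 8 target tokens, embedding dimension 4); a single head and random embeddings are assumptions for illustration. With `batch_first=True`, the returned weights have shape `(batch, target_len, source_len)`, which is the rectangular score matrix after the softmax.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
SE = torch.randn(1, 4, 4)  # source embeddings: (batch, source_len, embed_dim)
TE = torch.randn(1, 8, 4)  # target embeddings: (batch, target_len, embed_dim)

cross_attn = nn.MultiheadAttention(embed_dim=4, num_heads=1, batch_first=True)
# Queries from the decoder side, keys and values from the encoder side
output, weights = cross_attn(query=TE, key=SE, value=SE)

print(output.shape)   # torch.Size([1, 8, 4]): one vector per target token
print(weights.shape)  # torch.Size([1, 8, 4]): target_len x source_len
```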
