# 📖 In‑Depth Introduction to Attention

The **attention mechanism** is a core component of the Transformer architecture. Its primary role is to let the model decide **which parts of the input are most relevant** when computing each output representation. Unlike convolutions, which mix only a fixed local window, or RNNs, which pass information step by step through a recurrent state, attention enables **dynamic, data-driven interactions** between all input positions.

At the heart of attention are three learned linear projections of the input: **queries**, **keys**, and **values**. For each output position, the model compares its query against every key in the sequence, producing **content-based weights** that determine how much each value contributes to the final representation.


---

## 🔑 Queries, Keys & Values: The Semantics

1. **Queries (Q)**  
   - Think of a query as **“what I’m looking for.”**  
   - For each position in the output, we derive a query vector that encodes the aspects of the input we want to match against.  
   - Semantically, a query represents the “question” we ask of all other positions.

2. **Keys (K)**  
   - A key represents **“what each position has to offer.”**  
   - Each input position produces a key vector that describes its content or characteristics.  
   - Semantically, a key is like a “descriptor” that the query compares against.

3. **Values (V)**  
   - A value represents **“what information to bring back.”**  
   - Each position also produces a value vector that carries the actual content to be integrated into the output.  
   - If a query “attends” strongly to a key (high query-key similarity), the value at that key’s position contributes more to the updated representation at the query’s position.
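To make the three roles concrete, here is a minimal NumPy sketch (NumPy rather than the PyTorch used later, to keep it dependency-light) with one query and three hand-picked key/value pairs. The numbers are illustrative, not learned. The query points in the same direction as key 0, so key 0 receives about two thirds of the weight and value 0 dominates the output.

```python
import numpy as np

def softmax(z):
    # Subtracting the max improves numerical stability without changing the result.
    e = np.exp(z - z.max())
    return e / e.sum()

# One query and three positions, each with a 2-d key and a 2-d value (toy numbers).
q = np.array([1.0, 0.0])                  # "what I'm looking for"
K = np.array([[1.0, 0.0],                 # key 0: same direction as the query
              [0.0, 1.0],                 # key 1: orthogonal to the query
              [-1.0, 0.0]])               # key 2: opposite direction
V = np.array([[10.0, 0.0],                # "what each position brings back"
              [0.0, 10.0],
              [-10.0, -10.0]])

scores = K @ q                            # dot product of q with every key
weights = softmax(scores)                 # content-based weights that sum to 1
output = weights @ V                      # weighted sum of the values

print(np.round(weights, 3))
print(np.round(output, 3))
```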

---

## 🚀 Why Attention Matters

- **Dynamic Context**  
  - Every output position can attend to every input position.  
  - The model learns to focus on the most relevant parts of the sequence—whether they are nearby or far apart.

- **Parallelizable**  
  - All queries, keys, and values are computed at once for the entire sequence.  
  - Unlike RNNs, no position has to wait for the previous one to be processed, so GPUs/TPUs can compute all positions in parallel.

- **Rich Representations**  
  - By mixing information from multiple positions, attention captures long‑range dependencies and nuanced relationships.  
  - This leads to more expressive and context‑aware features than fixed‑window or local‑only methods.

- **Flexible & General**  
  - Attention is the core of models in NLP (machine translation, summarization), vision (Vision Transformer), audio, video, graphs, and beyond.  
  - Its fundamental idea—learning content‑based weights—applies across domains and modalities.


# 🔍 Single‑Head Self‑Attention: Step‑by‑Step

**Core idea**  
Each token “looks at” every token in the sequence (including itself) and builds a new representation as a weighted sum of their values.

---

## 1. Input
$$
X \;\in\; \mathbb{R}^{B \times N \times D},
$$
where  
- $B$ = batch size  
- $N$ = sequence length (number of tokens)  
- $D$ = embedding dimension  

---

## 2. Linear projections
Learn three weight matrices  
$$
W^Q,\;W^K,\;W^V \;\in\; \mathbb{R}^{D \times D}
$$  
and compute:
$$
Q = X\,W^Q,\quad
K = X\,W^K,\quad
V = X\,W^V
$$  
All of shape $\mathbb{R}^{B \times N \times D}$.
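A quick NumPy sketch of the projection step, assuming random matrices in place of learned weights and no bias terms (the PyTorch `nn.Linear` layers later in this notebook also add a bias). It shows that one $D \times D$ matrix is applied to every token independently.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, D = 2, 5, 8
X = rng.standard_normal((B, N, D))

# Random stand-ins for the learned W^Q, W^K, W^V (biases omitted).
W_Q, W_K, W_V = (rng.standard_normal((D, D)) for _ in range(3))

# Matmul broadcasts the same (D, D) matrix over the batch and sequence axes.
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(Q.shape, K.shape, V.shape)  # (2, 5, 8) (2, 5, 8) (2, 5, 8)

# The projection is position-wise: Q[b, n] depends only on X[b, n].
assert np.allclose(Q[1, 3], X[1, 3] @ W_Q)
```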

---

## 3. Compute raw scores
$$
\mathrm{Scores} = Q\,K^\top
\;\in\;\mathbb{R}^{B \times N \times N},
$$  
where $K^\top$ denotes transposing the last two dimensions of $K$.
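In NumPy terms (an illustrative sketch with random $Q$ and $K$), transposing only the last two axes and batching the matrix product yields one dot product per query/key pair, which is the same contraction an `einsum` would express:

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, D = 2, 4, 8
Q = rng.standard_normal((B, N, D))
K = rng.standard_normal((B, N, D))

# Transpose only the last two axes: (B, N, D) -> (B, D, N).
scores = Q @ K.transpose(0, 2, 1)          # -> (B, N, N)

# Entry (b, i, j) is the dot product of query i with key j in batch element b.
b, i, j = 1, 2, 3
assert np.isclose(scores[b, i, j], Q[b, i] @ K[b, j])

# Equivalent einsum formulation.
assert np.allclose(scores, np.einsum("bid,bjd->bij", Q, K))
print(scores.shape)  # (2, 4, 4)
```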

---

## 4. Scale and normalize
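$$
\mathrm{Scores}_{\text{scaled}}
= \frac{\mathrm{Scores}}{\sqrt{D}},
\qquad
A = \mathrm{softmax}\!\bigl(\mathrm{Scores}_{\text{scaled}}\bigr)
\;\in\;\mathbb{R}^{B \times N \times N}.
$$
The softmax is applied over the last axis (the keys), so each row of $A$ is a probability distribution over positions. Dividing by $\sqrt{D}$ keeps the variance of the scores roughly independent of $D$; without it, large dot products would push the softmax toward a nearly one-hot output with very small gradients.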
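Why divide by $\sqrt{D}$? If the entries of a query and a key are independent with unit variance, their dot product has variance $D$, so unscaled scores grow with the embedding size. The NumPy sketch below checks this with random Gaussian vectors (an assumption chosen because it makes the variance argument exact in expectation) and confirms that a row-wise softmax makes every row of $A$ sum to 1.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Empirical variance of q . k before and after dividing by sqrt(D).
variances = {}
for D in (16, 256):
    q = rng.standard_normal((10_000, D))
    k = rng.standard_normal((10_000, D))
    raw = (q * k).sum(axis=-1)
    variances[D] = (raw.var(), (raw / np.sqrt(D)).var())
    print(f"D={D:3d}  raw var={variances[D][0]:7.1f}  scaled var={variances[D][1]:.2f}")

# Scaled scores followed by a softmax over the key axis.
B, N, D = 2, 5, 16
Q, K = rng.standard_normal((2, B, N, D))
A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(D))   # -> (B, N, N)
```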

---

## 5. Weighted sum of values
$$
Z = A\,V
\;\in\;\mathbb{R}^{B \times N \times D},
$$  
where each output token $z_i = \sum_{j=1}^N A_{ij}\,v_j$.
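A small NumPy check, with random row-normalized weights standing in for a real $A$: the batched product `A @ V` equals the per-token sum $z_i = \sum_j A_{ij}\,v_j$, and because each row of $A$ is non-negative and sums to 1, every output coordinate stays within the range spanned by the corresponding value coordinates.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, D = 2, 5, 16
A = rng.random((B, N, N))
A /= A.sum(axis=-1, keepdims=True)         # stand-in attention weights: rows sum to 1
V = rng.standard_normal((B, N, D))

Z = A @ V                                   # -> (B, N, D)

# z_i = sum_j A_ij v_j, written out for one batch element and one token.
b, i = 0, 3
z_i = sum(A[b, i, j] * V[b, j] for j in range(N))
assert np.allclose(Z[b, i], z_i)

# Convex combination: each output coordinate lies between the min and max
# of that coordinate across the value vectors of the same sequence.
assert np.all(Z >= V.min(axis=1, keepdims=True) - 1e-12)
assert np.all(Z <= V.max(axis=1, keepdims=True) + 1e-12)
```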

---

## 6. Complexity
- **Projections**: $\mathcal{O}(N\,D^2)$  
- **Attention (dot + weighted sum)**: $\mathcal{O}(N^2\,D)$  
- **Total**: $\displaystyle \mathcal{O}(N^2 D + N D^2)$
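To see which term dominates, here is a rough multiply-add counter for one sequence (it ignores the softmax, biases, and the batch factor $B$; the function name is just for illustration). Setting $3ND^2 = 2N^2D$ shows that the quadratic attention term overtakes the projections once $N > 1.5\,D$.

```python
def attention_flops(N, D):
    """Rough multiply-add counts for single-head self-attention on one sequence."""
    projections = 3 * N * D * D     # Q, K, V: three (N x D) @ (D x D) products
    scores = N * N * D              # Q @ K^T: (N x D) @ (D x N)
    weighted_sum = N * N * D        # A @ V:   (N x N) @ (N x D)
    return projections, scores + weighted_sum

for N in (256, 1152, 4096):
    proj, attn = attention_flops(N, D=768)
    print(f"N={N:5d}  projections={proj:.2e}  attention={attn:.2e}")
```

With $D = 768$, the two terms are equal at $N = 1152$; doubling $N$ doubles the projection cost but quadruples the attention cost.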


```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
  def __init__(self, dim):
    super().__init__()
    self.dim = dim
    self.q_proj = nn.Linear(dim, dim)   # W^Q (plus a bias)
    self.k_proj = nn.Linear(dim, dim)   # W^K (plus a bias)
    self.v_proj = nn.Linear(dim, dim)   # W^V (plus a bias)

  def forward(self, x):
    """
    Apply single-head self-attention on the input sequence.

    Args:
        x (Tensor): Input tensor of shape (B, N, D), where:
            - B = batch size
            - N = number of tokens (sequence length or patches)
            - D = embedding dimension

    Returns:
        Tensor: Output tensor of shape (B, N, D), where each token's representation
                is updated based on its attention to all tokens in the sequence.

    Steps:
        1. Compute queries, keys, and values using linear projections.
        2. Compute attention scores using scaled dot-product between queries and keys.
        3. Normalize scores with softmax to get attention weights.
        4. Compute weighted sum of value vectors based on attention weights.
    """

    # Calculate query, key & value
    q = self.q_proj(x)                              # -> (B, N, D)
    k = self.k_proj(x)                              # -> (B, N, D)
    v = self.v_proj(x)                              # -> (B, N, D)

    # Compute attention weights
    k_t = k.transpose(1, 2)                         # -> (B, D, N)
    scores = q @ k_t / math.sqrt(self.dim)          # -> (B, N, N)
    attn_weights = F.softmax(scores, dim=-1)        # each row sums to 1

    # Weighted sum of values
    out = attn_weights @ v                          # -> (B, N, D)

    return out
```

```python
x = torch.randn(2, 5, 16)  # -> (2, 5, 16)
attn = Attention(dim=16)
x = attn(x)
print(x.shape)             # -> (2, 5, 16)
```

```
torch.Size([2, 5, 16])
```


# 🔍 Multi‑Head Self‑Attention: Step‑by‑Step

**Core idea**  
Instead of computing a single attention distribution, multi‑head attention splits the embedding into multiple smaller subspaces (heads), applies attention in each subspace independently, and then combines the results. This allows the model to **capture diverse patterns** of interaction in parallel—e.g., short‑term vs. long‑term dependencies, or syntactic vs. semantic roles.

---

## 1. Input
$$
X \;\in\; \mathbb{R}^{B \times N \times D},
$$
where  
- $B$ = batch size  
- $N$ = number of tokens (sequence length)  
- $D$ = total embedding dimension (must be divisible by the number of heads $H$)

Let $d = D / H$ be the dimension per head.

---

## 2. Linear projections
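We learn a single fused projection matrix:
$$
W^{QKV} \;\in\; \mathbb{R}^{D \times 3D}
$$

Project the input once, then split the last axis into three equal parts to obtain the queries, keys, and values:
$$
\bigl[\,Q \;\|\; K \;\|\; V\,\bigr] = X\,W^{QKV} \;\in\; \mathbb{R}^{B \times N \times 3D},
\qquad
Q,\;K,\;V \in \mathbb{R}^{B \times N \times D}
$$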

Then reshape to split into $H$ heads:
$$
Q,\;K,\;V \in \mathbb{R}^{B \times H \times N \times d}
$$
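The fused projection and head split can be sketched in NumPy as follows, assuming random weights in place of a learned $W^{QKV}$ and the same `(B, N, 3, H, d)` layout as the `MultiheadSelfAttention` cell below. The final check shows that head $h$ simply uses its own block of $d$ columns from the query part of $W^{QKV}$.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, D, H = 2, 4, 12, 3
d = D // H                                    # per-head dimension
X = rng.standard_normal((B, N, D))
W_qkv = rng.standard_normal((D, 3 * D))       # random stand-in for W^{QKV}

qkv = X @ W_qkv                                # -> (B, N, 3D)
# (B, N, 3D) -> (B, N, 3, H, d) -> (3, B, H, N, d)
qkv = qkv.reshape(B, N, 3, H, d).transpose(2, 0, 3, 1, 4)
Q, K, V = qkv                                   # each -> (B, H, N, d)

# Head h of Q equals X times columns [h*d, (h+1)*d) of the query block.
h = 1
W_Q = W_qkv[:, :D]
assert np.allclose(Q[:, h], X @ W_Q[:, h * d:(h + 1) * d])
print(Q.shape)  # (2, 3, 4, 4)
```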

---

## 3. Compute scaled dot-product attention (per head)
$$
\mathrm{Scores} = \frac{Q \cdot K^\top}{\sqrt{d}} \quad\in\; \mathbb{R}^{B \times H \times N \times N}
$$

Apply softmax across the last dimension:
$$
A = \mathrm{softmax}(\mathrm{Scores}) \quad\in\; \mathbb{R}^{B \times H \times N \times N}
$$

---

## 4. Weighted sum of values
$$
Z = A \cdot V \quad\in\; \mathbb{R}^{B \times H \times N \times d}
$$

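Then concatenate the heads: move the head axis next to the feature axis ($B \times N \times H \times d$) and merge the last two axes, since $H \cdot d = D$:
$$
Z_{\text{concat}} = \bigl[\,Z^{(1)} \;\|\; Z^{(2)} \;\|\; \cdots \;\|\; Z^{(H)}\,\bigr] \in \mathbb{R}^{B \times N \times D},
$$
where $Z^{(h)} \in \mathbb{R}^{B \times N \times d}$ is the output of head $h$.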
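Merging heads is the inverse of the split. In this NumPy sketch (random per-head outputs), the transpose before the reshape is what keeps each token's features together; reshaping $(B, H, N, d)$ straight to $(B, N, D)$ would silently mix rows from different tokens.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N, H, d = 2, 4, 3, 4
D = H * d
Z = rng.standard_normal((B, H, N, d))                  # per-head outputs

# Move heads next to the feature axis, then merge (H, d) -> D.
Z_concat = Z.transpose(0, 2, 1, 3).reshape(B, N, D)    # -> (B, N, D)

# Token n's features are [head 0 | head 1 | ... | head H-1] side by side.
n = 2
assert np.allclose(Z_concat[0, n], np.concatenate([Z[0, h, n] for h in range(H)]))

# Splitting again recovers the per-head tensor exactly.
assert np.allclose(Z_concat.reshape(B, N, H, d).transpose(0, 2, 1, 3), Z)
```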

---

## 5. Final projection
Apply a final learned projection to mix the information from all heads:
$$
\text{Output} = Z_{\text{concat}} \cdot W^O, \quad W^O \in \mathbb{R}^{D \times D}
$$

---

## 6. Complexity
Let $H$ be the number of heads:
- **Projections**: $\mathcal{O}(N D^2)$  
- **Attention per head**: $\mathcal{O}(N^2 d)$; across all $H$ heads, $H \cdot \mathcal{O}(N^2 d) = \mathcal{O}(N^2 D)$ because $H d = D$  
- **Total**: $\displaystyle \mathcal{O}(N^2 D + N D^2)$
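A rough per-sequence cost model makes the point explicit (illustrative only: it ignores the softmax and biases, and `mhsa_costs` is a hypothetical helper). The attention compute does not depend on $H$ because $H \cdot d = D$, but the attention-weight tensor $A$ stores $H \cdot N^2$ entries, so memory for $A$ grows with the number of heads.

```python
def mhsa_costs(N, D, H):
    """Rough multiply-adds and attention-matrix size for one sequence."""
    assert D % H == 0, "D must be divisible by H"
    d = D // H
    projections = 3 * N * D * D + N * D * D   # fused W^{QKV} plus output projection W^O
    attention = H * (2 * N * N * d)           # Q K^T and A V, for every head
    score_entries = H * N * N                 # entries stored in A
    return projections, attention, score_entries

for H in (1, 4, 8):
    proj, attn, entries = mhsa_costs(N=512, D=512, H=H)
    print(f"H={H}  projections={proj:.2e}  attention={attn:.2e}  |A| entries={entries}")
```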

> Like single-head attention, multi-head attention still scales quadratically with sequence length $N$, but it allows the model to learn richer patterns through head diversity.


```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadSelfAttention(nn.Module):
  def __init__(self, dim, num_heads):
    super().__init__()
    assert dim % num_heads == 0, "dim must be divisible by num_heads"
    self.num_heads = num_heads
    self.dim = dim
    self.dk = dim // num_heads                            # per-head dimension d
    self.qkv = nn.Linear(dim, 3 * dim)                    # fused W^{QKV}
    self.proj = nn.Linear(dim, dim)                       # output projection W^O

  def forward(self, x):
    """
    Apply multi-head self-attention mechanism on the input tensor.

    Args:
        x (Tensor): Input tensor of shape (B, N, D), where:
            - B = batch size
            - N = number of tokens (sequence length or patches)
            - D = embedding dimension (must be divisible by num_heads)

    Returns:
        Tensor: Output tensor of shape (B, N, D) after applying multi-head self-attention
                and combining all heads. Each token's representation is updated by attending
                to all tokens in the sequence.

    Steps:
        1. Project input into concatenated queries, keys, and values.
        2. Split into multiple heads and compute scaled dot-product attention per head.
        3. Concatenate outputs from all heads.
        4. Apply final linear projection to combine attention outputs.
    """
    B, N, D = x.shape
    qkv = self.qkv(x)                                     # -> (B, N, 3*D)
    qkv = qkv.reshape(B, N, 3, self.num_heads, self.dk)   # -> (B, N, 3, num_heads, head_dim)
    qkv = qkv.permute(2, 0, 3, 1, 4)                      # -> (3, B, num_heads, N, head_dim)

    # Split into query, key & value
    q, k, v = qkv[0], qkv[1], qkv[2]                      # each -> (B, num_heads, N, head_dim)

    # Scaled dot-product attention, computed for all heads at once
    k_t = k.transpose(-1, -2)                             # -> (B, num_heads, head_dim, N)
    scores = q @ k_t / math.sqrt(self.dk)                 # -> (B, num_heads, N, N)
    attn_weights = F.softmax(scores, dim=-1)

    # Compute output per head
    out = attn_weights @ v                                # -> (B, num_heads, N, head_dim)

    # Concatenate heads: move the head axis next to head_dim, then merge
    out = out.permute(0, 2, 1, 3)                         # -> (B, N, num_heads, head_dim)
    out = out.reshape(B, N, D)                            # -> (B, N, D)

    # Mix information from all heads into a single representation
    out = self.proj(out)                                  # -> (B, N, D)

    return out
```

```python
# Test
x = torch.randn(size=(1, 4, 10))
mhsa = MultiheadSelfAttention(dim=10, num_heads=2)
x = mhsa(x)
print(x.shape)
```

```
torch.Size([1, 4, 10])
```


# 🔧 Transformer Encoder Layer: Step-by-Step

**Core idea**  
A **Transformer layer** refines token representations using two key operations:
1. A **multi-head self-attention block** that mixes information across tokens.
2. A **feedforward MLP block** that processes each token independently.

Each block is preceded by **layer normalization** (the pre-LN arrangement) and wrapped in a **residual connection**, which keeps training stable as many layers are stacked.

---

## 1. Input
$$
X \;\in\; \mathbb{R}^{B \times N \times D},
$$
where  
- $B$ = batch size  
- $N$ = sequence length  
- $D$ = embedding dimension

---

## 2. Multi-Head Self-Attention (MHSA)

Apply **layer normalization** before attention:
$$
X_{\text{norm1}} = \text{LayerNorm}(X)
$$

Apply **multi-head self-attention**:
$$
Z = \text{MultiHeadSelfAttention}(X_{\text{norm1}})
\quad\in\; \mathbb{R}^{B \times N \times D}
$$

Add residual connection:
$$
X_{\text{attn}} = X + Z
$$
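Below is a NumPy sketch of what `nn.LayerNorm` computes (per-token normalization over the feature axis, biased variance, default `eps=1e-5`, affine parameters at their initial values $\gamma = 1$, $\beta = 0$), followed by the residual add. The `toy_sublayer` function is a hypothetical stand-in for multi-head self-attention, kept trivial so the sketch stays short.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each token over its feature axis (last dim), then scale and shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)        # biased variance, as in nn.LayerNorm
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def toy_sublayer(h):
    # Hypothetical stand-in for MultiHeadSelfAttention.
    return 0.1 * h

rng = np.random.default_rng(0)
B, N, D = 2, 5, 8
X = 3.0 + 2.0 * rng.standard_normal((B, N, D))
gamma, beta = np.ones(D), np.zeros(D)           # nn.LayerNorm's initial affine parameters

X_norm1 = layer_norm(X, gamma, beta)            # zero mean, unit std per token
Z = toy_sublayer(X_norm1)
X_attn = X + Z                                  # residual uses the *un-normalized* X
```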

---

## 3. Feedforward MLP

Apply **layer normalization** again:
$$
X_{\text{norm2}} = \text{LayerNorm}(X_{\text{attn}})
$$

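Pass through a **two-layer MLP with GELU** (with biases, as in `nn.Linear`):
$$
\text{MLP}(x) = \text{GELU}(x\,W_1 + b_1)\,W_2 + b_2 \quad \text{with}\quad W_1 \in \mathbb{R}^{D \times D'},\; W_2 \in \mathbb{R}^{D' \times D}
$$
where the hidden width $D'$ is typically $4D$ (`mlp_ratio=4.0` in the code below).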

Then:
$$
Y = \text{MLP}(X_{\text{norm2}}) \quad\in\; \mathbb{R}^{B \times N \times D}
$$

Add second residual connection:
$$
\text{Output} = X_{\text{attn}} + Y
$$
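A NumPy sketch of the feedforward sublayer under the row-vector convention $x\,W$ used throughout this notebook. It uses the tanh approximation of GELU (available in PyTorch as `nn.GELU(approximate="tanh")`; plain `nn.GELU()` uses the exact erf form), random weights, and a LayerNorm without learned scale/shift; all of these are assumptions for illustration. The check confirms that the MLP acts on each token independently.

```python
import numpy as np

def gelu_tanh(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def layer_norm(x, eps=1e-5):
    # LayerNorm without the learned scale/shift (gamma = 1, beta = 0)
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def mlp(x, W1, b1, W2, b2):
    # (..., D) @ (D, D') -> GELU -> (..., D') @ (D', D)
    return gelu_tanh(x @ W1 + b1) @ W2 + b2

rng = np.random.default_rng(0)
B, N, D = 2, 5, 8
D_hidden = 4 * D                                   # corresponds to mlp_ratio=4.0
W1, b1 = 0.1 * rng.standard_normal((D, D_hidden)), np.zeros(D_hidden)
W2, b2 = 0.1 * rng.standard_normal((D_hidden, D)), np.zeros(D)

X_attn = rng.standard_normal((B, N, D))            # stand-in for the attention sublayer's output
X_norm2 = layer_norm(X_attn)
Y = mlp(X_norm2, W1, b1, W2, b2)                   # -> (B, N, D)
output = X_attn + Y                                # second residual connection

# Position-wise: token (0, 3) gives the same result alone or inside the batch.
assert np.allclose(Y[0, 3], mlp(X_norm2[0, 3], W1, b1, W2, b2))
```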

---

## 4. Summary: The Full Transformer Layer

```python
x = x + MultiHeadSelfAttention(LayerNorm(x))
x = x + MLP(LayerNorm(x))
```


```python
class Block(nn.Module):
  def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0):
    super().__init__()
    self.norm1 = nn.LayerNorm(dim)
    self.attn = MultiheadSelfAttention(dim=dim, num_heads=num_heads)
    self.norm2 = nn.LayerNorm(dim)
    hidden_dim = int(dim * mlp_ratio)
    self.mlp = nn.Sequential(
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
        nn.Dropout(dropout),
    )

  def forward(self, x):
    """
    Transformer encoder block with:
    - Pre-LN + Multi-Head Self-Attention
    - Pre-LN + MLP
    - Two residual connections

    Args:
        x (Tensor): Input of shape (B, N, D)

    Returns:
        Tensor of shape (B, N, D)
    """
    # Multi-head self-attention with residual connection
    x = x + self.attn(self.norm1(x))                      # -> (B, N, D)

    # Feed-forward network with residual connection
    x = x + self.mlp(self.norm2(x))                       # -> (B, N, D)

    return x
```

```python
# Test
x = torch.randn(size=(1, 5, 512))
block = Block(dim=512, num_heads=4)
x = block(x)
print(x.shape)
```

```
torch.Size([1, 5, 512])
```
