# Block 1 — Mathematical Foundations of Tokenization, Batching & Attention Masking

This notebook builds the **mathematical backbone** for how LLMs go from raw text to tokens, how batches are formed and losses computed, and how attention masks (causal/padding/head) enforce the same principles architecturally.

---

## What you will learn in Block 1
1. **Objects & notation**: strings, normalization, tokenizer, vocabulary, special tokens.  
2. **Autoregressive factorization** of sequence probability.  
3. **Maximum-likelihood objective (NLL)** and its link to **cross-entropy**.  
4. **Batching with padding** and the **masked mean token loss**.  
5. **Perplexity** as exponential of mean NLL.  
6. **Attention masking (preview)**: causal mask, key-padding mask, and head masking.

---

## 1. Objects and notation

- **Alphabet & strings.**  
  Let $\Sigma$ be the raw character alphabet (e.g., Unicode code points).  
  Any finite string is an element:
  $$
  x \in \Sigma^{*}.
  $$

- **Normalization.**  
  A preprocessing map:
  $$
  \nu:\Sigma^{*} \to \Sigma^{*}, \qquad x \mapsto \nu(x),
  $$
  that may include:
  - Unicode normalization (e.g., NFKC),  
  - case folding (lowercasing),  
  - whitespace and punctuation rules.  

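  *Remark:* Normalization **changes the support** of the data distribution: strings that differ before normalization can collapse to the same token sequence afterwards (e.g., case folding maps “Café” and “CAFÉ” to “café”; accent stripping, if it is part of $\nu$, would further give “cafe”).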

- **Tokenizer.**  
  A **tokenizer** is a function that segments normalized text into discrete units (tokens) from a fixed vocabulary $\mathcal{V}$.  
  Mathematically:
  $$
  \tau: \Sigma^{*} \to \mathcal{V}^{*}, \qquad
  \mathbf{t}=\tau(\nu(x))=(t_1,\dots,t_{L(x)}), \quad t_i \in \mathcal{V}.
  $$

  
  - Tokenizers define the *basic symbols* LLMs see.  
  - They can split text into **characters** (character-level), **words** (word-level), or **subwords** (BPE/Unigram).  
  - Subword tokenization (e.g., “tokenization” → `["token", "ization"]`) balances vocabulary size with sequence length.  
  - This mapping is deterministic: the same string always yields the same token sequence.  
  - Without tokenization, raw text would be unmanageable: the space of possible strings $\Sigma^{*}$ is infinite, but the space of tokens $\mathcal{V}$ is finite and tractable.

- **Vocabulary & special tokens.**  
  The vocabulary $\mathcal{V}=\{0,1,\dots,V-1\}$ contains learned tokens plus reserved IDs:
  - `<bos>` : beginning of sequence  
  - `<eos>` : end of sequence  
  - `<pad>` : padding (used in batching)  
  - `<unk>` : unknown token (for OOV strings)

- **Augmented sequence (optional).**  
  To mark boundaries explicitly:
  $$
  \tilde{\mathbf{t}} = (\texttt{<bos>},\, t_1,\dots,t_{L(x)},\, \texttt{<eos>})
  $$

- **Dataset.**  
  A dataset is a collection of $N$ independent samples:
  $$
  \mathcal{D} = \{\, x^{(n)} \,\}_{n=1}^{N}.
  $$
  After tokenization:
  $$
  \mathcal{D}_{\text{tok}} = \{\, \mathbf{t}^{(n)} \,\}_{n=1}^{N}, \qquad
  \mathbf{t}^{(n)} = (t^{(n)}_1, \dots, t^{(n)}_{L(x^{(n)})}).
  $$
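
The objects above can be chained in a few lines. The sketch below is a toy, word-level pipeline written in plain Python; the names `VOCAB`, `normalize`, and `tokenize` and the ID assignments are illustrative assumptions, not the API of any real tokenizer.

```python
import unicodedata

# Hypothetical toy vocabulary: reserved IDs first, then a few word tokens.
VOCAB = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3, "the": 4, "cat": 5, "sleeps": 6}

def normalize(text):
    """nu: NFKC normalization, case folding, whitespace collapsing."""
    text = unicodedata.normalize("NFKC", text)
    return " ".join(text.casefold().split())

def tokenize(text, add_boundaries=True):
    """tau: whitespace split; out-of-vocabulary words map to <unk>."""
    ids = [VOCAB.get(word, VOCAB["<unk>"]) for word in text.split()]
    if add_boundaries:
        ids = [VOCAB["<bos>"]] + ids + [VOCAB["<eos>"]]
    return ids

raw = "The  CAT\u00a0sleeps soundly"   # double space and a no-break space
print(tokenize(normalize(raw)))        # [1, 4, 5, 6, 3, 2]
```

The word "soundly" is not in the toy vocabulary, so it maps to `<unk>` (ID 3); a subword tokenizer would instead split it into smaller known pieces.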

---

## 2. Autoregressive factorization (preview)

An LLM with parameters $\theta$ assigns probability to a tokenized string via the **causal chain rule**:

$$
p_\theta(x) = \prod_{i=1}^{L(x)} p_\theta\!\big(t_i \mid t_{<i}\big),
\qquad t_{<i}=(t_1,\dots,t_{i-1}).
$$

With explicit boundaries:
$$
p_\theta(x) = \prod_{i=1}^{\tilde L} p_\theta\!\big(\tilde t_i \mid \tilde t_{<i}\big),
\quad \tilde t_1=\texttt{<bos>},\ \tilde t_{\tilde L}=\texttt{<eos>}.
$$

*Interpretation:*  
The model decomposes the text into **conditional next-token probabilities**; causality forbids peeking at the future.

---

## 3. Training objective (preview)

Given dataset $\mathcal{D}$, **maximum likelihood** maximizes the log-probability assigned to the data:

$$
\max_{\theta}\; \sum_{x \in \mathcal{D}} \sum_{i=1}^{L(x)} \log p_\theta\!\big(t_i \mid t_{<i}\big).
$$

Equivalently, we **minimize the negative log-likelihood (NLL)**.  
Later we will connect this to **softmax cross-entropy**, show the **masked mean over valid tokens** (excluding `<pad>`), and define **perplexity**:

$$
\mathrm{PPL} = \exp\!\big(\text{mean NLL per valid token}\big).
$$

---

## 4. Attention masking (preview)

Self-attention computes scaled dot products between queries and keys.  
To respect the mathematics above, we add:
- a **causal mask** (no attending to future positions),  
- a **key-padding mask** (ignore `<pad>` positions),  
- and optionally a **head mask** (enable/disable heads for ablation/interpretability).

Later we will show how these masks combine into the attention logits **before** the softmax.

---

### Outcome of Block 1

By the end of Block 1 you will be able to:
- Map raw strings $\to$ tokens with clear assumptions about normalization and special tokens.  
- Write and reason about $p_\theta(x)$ via autoregressive factorization.  
- Derive the MLE/NLL objective and relate it to cross-entropy.  
- Implement padding-aware **masked mean** losses and compute **perplexity**.  
- Explain how **attention masks** (causal/padding/head) enforce the same constraints architecturally.


## 2) Autoregressive Factorization and Training Objective

Now that we understand what **tokens** are and how a **tokenizer** converts text into discrete units, we ask:  
**How does a Large Language Model assign a probability to a whole sequence of tokens?**

---

### 2.1 The chain rule of probability

From basic probability theory:  
For any sequence of random variables $(X_1, X_2, \dots, X_L)$ we can always write:

$$
p(X_1, X_2, \dots, X_L) = \prod_{i=1}^L p(X_i \mid X_1, \dots, X_{i-1}).
$$

This is called the **chain rule**.  
It is an identity, not an assumption — it always holds.

- $p(X_1)$ is the probability of the first element.  
- $p(X_2 \mid X_1)$ is the probability of the second given the first.  
- $p(X_3 \mid X_1, X_2)$ is the probability of the third given the first two.  
- And so on.
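
Because the chain rule is an identity, it can be checked numerically on any joint distribution. The sketch below draws an arbitrary joint over three binary variables (the seed and table are made up for illustration) and compares each joint probability with the product of conditionals.

```python
import itertools
import random

random.seed(0)
# Arbitrary joint distribution over three binary variables (X1, X2, X3).
outcomes = list(itertools.product([0, 1], repeat=3))
weights = [random.random() for _ in outcomes]
total = sum(weights)
joint = {o: w / total for o, w in zip(outcomes, weights)}

def marginal(prefix):
    """p(X_1..X_k = prefix): sum the joint over the remaining variables."""
    k = len(prefix)
    return sum(p for o, p in joint.items() if o[:k] == prefix)

def chain_rule_product(outcome):
    """prod_i p(x_i | x_<i), each conditional computed as a ratio of marginals."""
    prob = 1.0
    for i in range(len(outcome)):
        prob *= marginal(outcome[: i + 1]) / marginal(outcome[:i])
    return prob

max_gap = max(abs(joint[o] - chain_rule_product(o)) for o in outcomes)
print(f"largest |joint - chain-rule product| = {max_gap:.2e}")
```

The gap is at the level of floating-point rounding, whatever joint table is used.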

---

### 2.2 Applying the chain rule to tokens

For a token sequence $\mathbf{t} = (t_1, t_2, \dots, t_L)$ produced by the tokenizer:

$$
p_\theta(\mathbf{t}) = \prod_{i=1}^L p_\theta(t_i \mid t_1, \dots, t_{i-1}),
$$

where $\theta$ are the model parameters (the weights of the neural network).

- $t_i$ is the $i$-th token.  
- $t_{<i} = (t_1, \dots, t_{i-1})$ is the prefix (all tokens before $i$).  
- Each factor is the probability of the **next token** given all previous ones.

With explicit boundary tokens `<bos>` (begin) and `<eos>` (end):

$$
p_\theta(\tilde{\mathbf{t}}) = \prod_{i=1}^{\tilde L} p_\theta(\tilde t_i \mid \tilde t_{<i}),
$$

where $\tilde L = L + 2$ because of the added boundaries.
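
As a worked instance of this product, the sketch below scores one boundary-augmented sequence. It assumes a hypothetical bigram table `COND` (each token depends only on the previous one), which is a drastic simplification of a real LLM, and it adopts the common convention that `<bos>` is given, so its factor is 1.

```python
import math

# Hypothetical bigram table: COND[prev][next] = p(next | prev).
COND = {
    "<bos>": {"the": 0.8, "cat": 0.2},
    "the": {"cat": 0.9, "<eos>": 0.1},
    "cat": {"sleeps": 0.7, "<eos>": 0.3},
    "sleeps": {"<eos>": 1.0},
}

def sequence_log_prob(tokens):
    """log p(t~) for a sequence that starts with <bos> and ends with <eos>.

    <bos> is treated as given (probability 1), so the sum starts at the
    first real token.
    """
    return sum(math.log(COND[prev][nxt]) for prev, nxt in zip(tokens, tokens[1:]))

seq = ["<bos>", "the", "cat", "sleeps", "<eos>"]
log_p = sequence_log_prob(seq)
print(f"log p = {log_p:.4f}, p = {math.exp(log_p):.4f}")  # p = 0.8 * 0.9 * 0.7 * 1.0 = 0.504
```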

---

### 2.3 Why autoregression?

- **Causality.** Humans generate text left-to-right; the model imitates this by only looking at the past.  
- **Simplicity.** Instead of modeling an entire sequence directly, we only need to model *next-token prediction*.  
- **Flexibility.** With this formulation we can generate text token by token: sample $t_1$ from $p(t_1)$, then $t_2$ from $p(t_2\mid t_1)$, and so on.
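
Token-by-token generation can be sketched as a loop that samples from the conditional distribution of the current context. The bigram table `COND` and the helper `sample_sequence` below are hypothetical stand-ins for a trained model; a real LLM would condition on the full prefix.

```python
import random

# Hypothetical bigram table p(next | prev), standing in for a trained model.
COND = {
    "<bos>": {"the": 0.8, "cat": 0.2},
    "the": {"cat": 0.9, "<eos>": 0.1},
    "cat": {"sleeps": 0.7, "<eos>": 0.3},
    "sleeps": {"<eos>": 1.0},
}

def sample_sequence(rng, max_len=10):
    """Sample t_1, t_2, ... one at a time until <eos> or max_len tokens."""
    tokens = ["<bos>"]
    while tokens[-1] != "<eos>" and len(tokens) < max_len:
        dist = COND[tokens[-1]]
        nxt = rng.choices(list(dist), weights=list(dist.values()), k=1)[0]
        tokens.append(nxt)
    return tokens

rng = random.Random(0)
for _ in range(3):
    print(sample_sequence(rng))
```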

---

### 2.4 Training objective: Maximum Likelihood Estimation (MLE)

We want our model to assign **high probability** to real text sequences from the dataset.  

Given a dataset $\mathcal{D} = \{ \mathbf{t}^{(n)} \}_{n=1}^N$, the **log-likelihood** is:

$$
\log p_\theta(\mathcal{D}) = \sum_{n=1}^N \log p_\theta(\mathbf{t}^{(n)}).
$$

Using the autoregressive factorization:

$$
\log p_\theta(\mathbf{t}^{(n)}) = \sum_{i=1}^{L(x^{(n)})} \log p_\theta(t^{(n)}_i \mid t^{(n)}_{<i}).
$$

Thus the training objective is:

$$
\max_\theta \; \sum_{n=1}^N \sum_{i=1}^{L(x^{(n)})} \log p_\theta\!\big(t^{(n)}_i \mid t^{(n)}_{<i}\big).
$$

This is called **Maximum Likelihood Estimation (MLE).**

---

### 2.5 Why logs?

If we worked directly with the product $\prod p(t_i \mid t_{<i})$, probabilities quickly become astronomically small (multiplying many numbers less than 1).  
Taking the logarithm:

- Turns products into sums (easier to compute).  
- Stabilizes numerics (avoids underflow).  
- Matches information-theoretic interpretations (log-loss).  
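
The underflow is easy to reproduce. The sketch below uses 400 made-up next-token probabilities of 0.1 each; the direct product collapses to zero in float64, while the sum of logs stays well within range.

```python
import math

probs = [0.1] * 400          # 400 next-token probabilities of 0.1 each

product = 1.0
for p in probs:
    product *= p             # true value 1e-400 is below the float64 range

log_sum = sum(math.log(p) for p in probs)

print(product)               # 0.0 (underflow)
print(log_sum)               # about -921.03, i.e. 400 * log(0.1)
```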

---

### 2.6 Negative log-likelihood (NLL)

Instead of maximizing log-likelihood, we minimize the **negative log-likelihood**:

$$
\mathcal{L}_{\text{NLL}}(\theta) = -\sum_{n=1}^N \sum_{i=1}^{L(x^{(n)})} \log p_\theta(t^{(n)}_i \mid t^{(n)}_{<i}).
$$

This is the **loss function** of LLMs.  
Minimizing it means: *the model gets penalized whenever it assigns low probability to the correct next token.*

---

### 2.7 Connection to cross-entropy

At each position $i$, the model outputs a probability distribution over the vocabulary $\mathcal{V}$.  

If the true token is $t_i$, the per-token loss is:

$$
\ell_i = -\log p_\theta(t_i \mid t_{<i}).
$$

This is exactly the **cross-entropy** between the predicted distribution $p_\theta(\cdot \mid t_{<i})$ and the true one-hot distribution $y$, where $y_{k} = 1$ if $k=t_i$ and $y_k = 0$ otherwise:

$$
\ell_i = H(y, p_\theta) = - \sum_{k \in \mathcal{V}} y_k \log p_\theta(k \mid t_{<i}).
$$
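
The equality between the one-hot cross-entropy and the per-token NLL can be verified by hand. The logits and target index below are arbitrary illustrative values.

```python
import math

def softmax(logits):
    m = max(logits)                       # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

logits = [2.0, 0.5, -1.0, 0.0]            # arbitrary scores over a vocabulary of size 4
target = 1                                # index of the true next token t_i

p = softmax(logits)
one_hot = [1.0 if k == target else 0.0 for k in range(len(logits))]

cross_entropy = -sum(y * math.log(q) for y, q in zip(one_hot, p))
nll = -math.log(p[target])
print(cross_entropy, nll)                 # the two values coincide
```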

---

### 2.8 Gradient identity (intuition for learning)

The derivative of the cross-entropy with respect to the logits $z_{i,k}$ is:

$$
\frac{\partial \ell_i}{\partial z_{i,k}} = p_\theta(k \mid t_{<i}) - y_k.
$$

This means:
- If the model assigns too much probability to the wrong token, the gradient pushes it down.  
- If the model assigns too little to the correct token, the gradient pushes it up.  
- Learning is simply **adjusting logits so that predicted probabilities match observed tokens.**
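
The gradient identity can be checked against central finite differences. This is a sketch with arbitrary logits; `loss` is a hypothetical helper computing $-\log \text{softmax}(z)_{t}$.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def loss(logits, target):
    return -math.log(softmax(logits)[target])

logits = [2.0, 0.5, -1.0, 0.0]
target = 1

# Analytic gradient: p_k - y_k.
p = softmax(logits)
analytic = [p[k] - (1.0 if k == target else 0.0) for k in range(len(logits))]

# Numerical gradient: perturb one logit at a time.
eps = 1e-6
numeric = []
for k in range(len(logits)):
    up = list(logits)
    down = list(logits)
    up[k] += eps
    down[k] -= eps
    numeric.append((loss(up, target) - loss(down, target)) / (2 * eps))

max_err = max(abs(a - n) for a, n in zip(analytic, numeric))
print(f"max |analytic - numeric| = {max_err:.2e}")
```

Only the entry for the correct token is negative; all other entries are positive, and the entries sum to zero.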

---

### 2.9 Intuition

- Training an LLM = **teaching it to be a good next-token predictor**.  
- By predicting the next token well across billions of examples, the model learns grammar, semantics, style, reasoning patterns, and world knowledge.  
- All of this comes from one simple principle: *minimize negative log-likelihood (maximize the probability of real text).*

---

**Next (Block 1.3):**  
We extend this objective to **batches of sequences** (with different lengths), introduce **padding**, and show how to compute a **masked mean token loss** that ignores `<pad>` tokens.


## 3) Batching, Padding, and the Masked Token Loss

So far we looked at single sequences. But in practice, models are trained with **batches** of sequences in parallel, for efficiency.

---

### 3.1 Why batching?

- GPUs/TPUs can process many examples at once.  
- Instead of updating weights after every single sequence, we group **B sequences** into a **batch**.  
- This allows **vectorized operations** and gives a lower-variance gradient estimate, because the batch gradient is the average of the per-sequence gradients.

Formally:  
A batch $\mathcal{B}$ is a set of $B$ token sequences:
$$
\mathcal{B} = \{\mathbf{t}^{(1)}, \mathbf{t}^{(2)}, \dots, \mathbf{t}^{(B)}\}.
$$

---

### 3.2 The problem of variable lengths

Different sequences have different lengths. With word-level tokenization, for example:
- `"The cat"` → 2 tokens.  
- `"A very long sentence here"` → 5 tokens.  

Neural networks expect tensors with the **same shape**, so we cannot stack them directly.  
Solution: **Padding**.

---

### 3.3 Padding sequences

Let $L_{\max}$ be the length of the longest sequence in the batch.  
We build a matrix:

$$
T \in \mathbb{N}^{B \times L_{\max}}, \quad
T_{b,i} =
\begin{cases}
t^{(b)}_i & \text{if } i \leq L(\mathbf{t}^{(b)}), \\
\texttt{<pad>} & \text{otherwise}.
\end{cases}
$$

- Each row = one sequence.  
- Shorter sequences are padded with a special `<pad>` token until they reach length $L_{\max}$.

---

### 3.4 Mask of valid tokens

To ensure padding does not affect the loss, we build a **mask**:

$$
M \in \{0,1\}^{B \times L_{\max}}, \qquad
M_{b,i} =
\begin{cases}
1 & \text{if } T_{b,i} \neq \texttt{<pad>} \quad (\text{valid token}), \\
0 & \text{if } T_{b,i} = \texttt{<pad>} \quad (\text{ignored}).
\end{cases}
$$
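
The padded matrix $T$ and the mask $M$ can be built together. The sketch below uses nested Python lists; `pad_batch` is a hypothetical helper, and `PAD_ID = 0` is an assumed choice that must not collide with a real token ID.

```python
PAD_ID = 0  # assumed ID of <pad>

def pad_batch(sequences, pad_id=PAD_ID):
    """Right-pad token-ID lists to a common length; return (T, M) as nested lists."""
    l_max = max(len(seq) for seq in sequences)
    T = [seq + [pad_id] * (l_max - len(seq)) for seq in sequences]
    M = [[1] * len(seq) + [0] * (l_max - len(seq)) for seq in sequences]
    return T, M

T, M = pad_batch([[5, 6, 7, 8], [9, 10, 11], [12, 13, 14, 15, 16, 17]])
for row_t, row_m in zip(T, M):
    print(row_t, row_m)
```

Here the mask is derived from the sequence lengths rather than from `T != pad_id`; the two agree as long as `pad_id` never appears as a real token.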

---

### 3.5 Masked mean token loss

Without masking, the loss would include `<pad>` tokens, corrupting training.  
Instead we average only over **valid tokens**:

$$
\mathcal{L}_{\text{token}} =
- \frac{1}{N_{\text{valid}}}
\sum_{b=1}^B \sum_{i=1}^{L_{\max}}
M_{b,i} \; \log p_\theta\!\big(T_{b,i} \mid T_{b,<i}\big),
$$

where
$$
N_{\text{valid}} = \sum_{b=1}^B \sum_{i=1}^{L_{\max}} M_{b,i}.
$$

This ensures:
- Loss is **independent of how much padding there is**.  
- Gradients reflect only *real tokens*.  
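
The independence from padding can be seen on a small example. The per-token losses below are made-up numbers, with 9.9 as an arbitrary value at padded positions; `masked_mean_nll` is a hypothetical helper implementing the formula above.

```python
def masked_mean_nll(token_nll, mask):
    """Average per-token NLL over positions where mask == 1."""
    total = sum(
        loss * m
        for row_l, row_m in zip(token_nll, mask)
        for loss, m in zip(row_l, row_m)
    )
    n_valid = sum(sum(row) for row in mask)
    return total / n_valid

# Hypothetical per-token losses; values at padded positions are arbitrary.
nll = [[2.0, 1.0, 3.0, 9.9], [1.0, 2.0, 9.9, 9.9]]
mask = [[1, 1, 1, 0], [1, 1, 0, 0]]

# The same batch with two extra columns of padding.
nll_wide = [row + [9.9, 9.9] for row in nll]
mask_wide = [row + [0, 0] for row in mask]

print(masked_mean_nll(nll, mask))            # (2 + 1 + 3 + 1 + 2) / 5 = 1.8
print(masked_mean_nll(nll_wide, mask_wide))  # 1.8 again
```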

---

### 3.6 Implementation intuition

Most deep learning frameworks implement this using:
- A **CrossEntropyLoss** with `ignore_index=pad_id` (in PyTorch).  
- An **attention_mask** to indicate valid tokens (in Transformers).

Thus:
- Forward pass uses the padded tensor.  
- Loss computation ignores `<pad>`.  
- Attention layers also ignore `<pad>` (via masking).



In [7]:
import torch
import torch.nn as nn

# Suppose we have 3 sequences of different lengths
pad_id = 0
batch = torch.tensor([
    [5, 6, 7, 8, pad_id, pad_id],      # length 4
    [9, 10, 11, pad_id, pad_id, pad_id], # length 3
    [12, 13, 14, 15, 16, 17]            # length 6 (max)
])

print("Batch shape:", batch.shape)  # (3, 6)

# Mask: 1 for valid tokens, 0 for pad
mask = (batch != pad_id).long()
print("Mask:\n", mask)


# Example: "the cat sleeps"
#
# We add special tokens: [BOS] at the start, [EOS] at the end.
# Sequence becomes: [BOS], "the", "cat", "sleeps", [EOS]
#
# The model output (logits) has shape (B, L, V):
# - B = batch size (here 1 sentence)
# - L = sequence length (5 tokens including BOS/EOS)
# - V = vocabulary size
#
# At each position 'Pos':
# - Pos 1: Context = [BOS], model predicts the first real token ("the")
# - Pos 2: Context = [BOS, "the"], model predicts "cat"
# - Pos 3: Context = [BOS, "the", "cat"], model predicts "sleeps"
# - Pos 4: Context = [BOS, "the", "cat", "sleeps"], model predicts [EOS]
#
# Each row logits[0, Pos, :] is a vector of size V with scores
# (before softmax) for all tokens in the vocabulary.
# After applying softmax, we get a probability distribution
# P(next_token | context).
#
# In training, we compare this distribution with the true target token
# at that position using cross-entropy loss.



# Example: logits from a model (random for demo)
B, L, V = batch.shape[0], batch.shape[1], 20  # vocab size = 20
logits = torch.randn(B, L, V)  # simulation of a model output

# Uncomment the print below to inspect the logits. The tensor has shape (B, L, V):
# one (L, V) matrix per sequence, one row per position, and each row holds V raw
# scores (logits, not probabilities) over the vocabulary.
#print("Logits: \n", logits)


# Loss function that ignores pad tokens when computing the average.
# CrossEntropyLoss expects logits of shape (N, V) and target indices of shape (N,).
# 'ignore_index=pad_id' ensures positions labeled as PAD do NOT contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)

targets = batch.clone()

# Simplification for this demo: the targets are the batch itself. In real
# next-token training the targets are shifted by one position, so the logits at
# position i are scored against token i+1 (e.g. logits[:, :-1] vs. batch[:, 1:]).
# PAD positions keep the value pad_id, which lets 'ignore_index' exclude them.

# A quick debug print to verify the exact target tensor being used.
# Good practice: print shapes and a small slice if the tensor is large.
print("Targets (same shape as batch):\n", targets)  # shape: (B, L)

# CrossEntropyLoss expects flattened shapes:
# - logits.view(-1, V):   (B * L, V)
# - targets.view(-1):     (B * L,)
# Each row in logits corresponds to exactly one target class index.
# Positions where target == pad_id will be ignored in the loss.
loss = loss_fn(logits.view(-1, V), targets.view(-1))


# .view(-1, V) flattens the first two dimensions (B, L) into one.
# The -1 tells PyTorch to automatically infer that size (B*L).
# Example: logits.shape = (B, L, V) -> logits.view(-1, V) = (B*L, V).
# Each row now corresponds to one token position, with V scores (logits).


# Print the scalar loss value. "Masked" indicates PAD positions were excluded
# thanks to 'ignore_index=pad_id'.
print("Masked loss:", loss.item())


# Without ignore_index, PAD tokens are treated as real targets:
# - Loss & gradients get biased toward predicting PAD
# - Shorter sequences contribute many PAD positions
# - Metrics (loss/perplexity) become misleading
unmasked_loss_fn = nn.CrossEntropyLoss(reduction="mean")  # no ignore_index
unmasked_loss = unmasked_loss_fn(logits.view(-1, V), targets.view(-1))

print("Unmasked loss (PAD counted):", unmasked_loss.item())
print("Masked loss   (PAD ignored):", loss.item())

# Rule of thumb:
# With random logits the two numbers above differ only by chance. With a real
# model, counting PAD distorts the reported loss and spends gradient on
# predicting PAD instead of real tokens.
# Always ignore PAD in the loss (via ignore_index) and in attention masks.

Batch shape: torch.Size([3, 6])
Mask:
 tensor([[1, 1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1]])
Targets (same shape as batch):
 tensor([[ 5,  6,  7,  8,  0,  0],
        [ 9, 10, 11,  0,  0,  0],
        [12, 13, 14, 15, 16, 17]])
Masked loss: 3.303070306777954
Unmasked loss (PAD counted): 3.4462995529174805
Masked loss   (PAD ignored): 3.303070306777954


## 4) Perplexity: measuring how well a model predicts

Up to now we focused on the **loss** (negative log-likelihood). But in language modeling, people often report **perplexity (PPL)**.

---

### 4.1 Definition

For a dataset of $N_{\text{valid}}$ valid tokens, perplexity is:

$$
\mathrm{PPL} = \exp\!\Bigg( \frac{1}{N_{\text{valid}}}
\sum_{n=1}^N \sum_{i=1}^{L(x^{(n)})}
-\log p_\theta(t^{(n)}_i \mid t^{(n)}_{<i}) \Bigg).
$$

This is simply the exponential of the **average negative log-likelihood per token**.

---

### 4.2 Intuition

- If $\mathrm{PPL}=1$: the model is perfect — it always assigns probability 1 to the correct token.  
- If $\mathrm{PPL}=V$ (vocabulary size): the model does no better than a uniform guess over the vocabulary. Perplexity can exceed $V$ when the model is confidently wrong.  
- Lower perplexity = better model.

In words: **perplexity measures “how many tokens the model is confused among, on average.”**
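
The reference points can be reproduced directly from the definition. In the sketch below, `perplexity` is a hypothetical helper that takes the probabilities a model assigned to the correct tokens; the inputs are made-up values.

```python
import math

def perplexity(target_probs):
    """exp of the mean negative log-probability assigned to the correct tokens."""
    nll = [-math.log(p) for p in target_probs]
    return math.exp(sum(nll) / len(nll))

V = 20
print(perplexity([1.0] * 5))        # 1.0: every correct token gets probability 1
print(perplexity([1.0 / V] * 5))    # about 20: uniform guessing over V tokens
print(perplexity([0.01] * 5))       # about 100, larger than V: confidently wrong
```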

---

### 4.3 Why use perplexity?

- It is easier to interpret than raw log-loss.  
- It is standard in language modeling benchmarks (e.g., Penn Treebank, WikiText).  
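- It allows direct comparison between models that share the same tokenizer and evaluation data. Across different vocabularies, per-token perplexity is not directly comparable because the same text yields a different number of tokens; normalizing per character, byte, or word is the usual workaround.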

---

### 4.4 Example with code
Now let’s compute perplexity from the loss in PyTorch.


In [None]:
import torch
import torch.nn as nn
import math

# Suppose we have predictions for a batch
pad_id = 0
targets = torch.tensor([
    [5, 6, 7, 8, pad_id, pad_id],
    [9, 10, 11, pad_id, pad_id, pad_id],
    [12, 13, 14, 15, 16, 17]
])

B, L = targets.shape
V = 20
logits = torch.randn(B, L, V)

# Loss function that ignores padding
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id, reduction="sum")

# Compute total loss (sum over tokens, not mean yet)
loss_sum = loss_fn(logits.view(-1, V), targets.view(-1))

# Count valid tokens (non-pad)
valid_tokens = (targets != pad_id).sum().item()

# Average NLL per token
nll_avg = loss_sum.item() / valid_tokens

# Perplexity
ppl = math.exp(nll_avg)

print("Average NLL per token:", nll_avg)
print("Perplexity:", ppl)


## 5) Attention Masking: Causal, Padding, and Head Masks

So far we focused on **token probabilities** and **loss functions**.  
Now we connect this theory to the **attention mechanism** inside Transformers — the core building block of LLMs.

---

### 5.1 Self-attention in a nutshell

In self-attention, every token can attend to every other token.  
For a sequence of length $L$ and embedding size $d$, we compute:

- **Queries:** $Q \in \mathbb{R}^{L \times d}$  
- **Keys:** $K \in \mathbb{R}^{L \times d}$  
- **Values:** $V \in \mathbb{R}^{L \times d}$  

The **attention scores** are:

$$
S = \frac{QK^\top}{\sqrt{d}} \quad \in \mathbb{R}^{L \times L}.
$$

After applying softmax row-wise, we get the **attention weights**:

$$
A = \text{softmax}(S),
$$

and then the output is:

$$
O = A \, V.
$$

---

### 5.2 The problem

- Without restrictions, a token at position $i$ could attend to **future tokens** $j>i$.  
- `<pad>` tokens could receive attention, contaminating representations.  
- Sometimes we want to **disable certain heads** for ablation or interpretability.

**Solution:** use **masks** that tell the model where it is allowed to attend.

---

### 5.3 Causal mask (autoregression)

**Goal:** prevent a token from looking into the future.

Definition:

$$
C \in \{0,1\}^{L \times L}, \qquad
C_{ij} =
\begin{cases}
1 & \text{if } j>i \quad (\text{future blocked}), \\
0 & \text{otherwise}.
\end{cases}
$$

Applied to scores:

$$
S \;\leftarrow\; S + (-10^9)\cdot C.
$$

This sets future positions to a very large negative value (≈$-\infty$), so after softmax they get probability 0.
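
The effect of the additive mask can be traced in plain Python. The scores below are arbitrary finite numbers; only the mask pattern matters.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    z = sum(exps)
    return [e / z for e in exps]

L = 4
S = [[0.1 * (i + j) for j in range(L)] for i in range(L)]     # arbitrary finite scores
C = [[1 if j > i else 0 for j in range(L)] for i in range(L)]  # 1 marks a blocked (future) key

masked = [[S[i][j] + (-1e9) * C[i][j] for j in range(L)] for i in range(L)]
A = [softmax(row) for row in masked]

for row in A:
    print([round(a, 3) for a in row])
```

Each row sums to 1 and is zero to the right of the diagonal; the first query can attend only to itself.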

---

### 5.4 Padding mask

**Goal:** ignore `<pad>` tokens added during batching.

Definition for a batch of size $B$:

$$
P \in \{0,1\}^{B \times L}, \qquad
P_{b,j} =
\begin{cases}
1 & \text{if token $j$ in sequence $b$ is } \texttt{<pad>}, \\
0 & \text{otherwise}.
\end{cases}
$$

When broadcasted into attention scores, positions marked with 1 are also masked out.

---

### 5.5 Head mask

**Goal:** enable or disable specific heads.

For $H$ attention heads:

$$
h \in \{0,1\}^H \quad \text{(or } h \in \{0,1\}^{B \times H}\text{ if per-example)}.
$$

- $h_k=0$ → head $k$ is disabled (all its contributions zeroed).  
- Useful for **interpretability** (probing which heads matter), or for **structured pruning**.
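
A head mask amounts to multiplying each head's contribution by 0 or 1. The sketch below applies it to per-head outputs for a single token (made-up numbers); zeroing a head's attention weights instead has the same effect on its output.

```python
# Per-head attention outputs for one token: 3 heads, head dimension 2 (made-up numbers).
head_outputs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
head_mask = [1, 0, 1]    # disable the second head

masked = [[h * x for x in out] for h, out in zip(head_mask, head_outputs)]
concat = [x for out in masked for x in out]   # what the output projection would receive
print(concat)   # [1.0, 2.0, 0.0, 0.0, 5.0, 6.0]
```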

---

### 5.6 Combined masking

For a full batch, the final scores are:

$$
S = \frac{QK^\top}{\sqrt{d}}.
$$

Then:

1. Broadcast $C$ (causal mask) to $(1,1,L,L)$.  
2. Broadcast $P$ (padding mask) to $(B,1,1,L)$.  
3. Combine:  
   $$
   M = \text{broadcast}(C) \;\lor\; \text{broadcast}(P).
   $$
4. Apply to scores:  
   $$
   S \;\leftarrow\; S + (-10^9)\cdot M.
   $$
5. Softmax row-wise over keys:  
   $$
   A = \text{softmax}(S).
   $$
6. Apply head mask (if any):  
   $$
   A \;\leftarrow\; A \odot \text{broadcast}(h).
   $$
7. Final output:  
   $$
   O = A \, V.
   $$
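
Steps 1 to 5 can be sketched without any tensor library by writing the broadcast out as explicit loops. The helper `combined_mask` and the toy padding flags are illustrative assumptions; all scores are set to zero so that the resulting weights are uniform over the allowed keys.

```python
import math

NEG = -1e9   # large negative constant standing in for -infinity, as in the text

def softmax(row):
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    z = sum(exps)
    return [e / z for e in exps]

def combined_mask(pad_flags):
    """M[b][i][j] = 1 if key j lies in the future of query i OR key j is <pad> in sequence b."""
    L = len(pad_flags[0])
    return [
        [[1 if (j > i or flags[j]) else 0 for j in range(L)] for i in range(L)]
        for flags in pad_flags
    ]

P = [[0, 0, 0, 1],    # sequence 0: last position is <pad>
     [0, 0, 0, 0]]    # sequence 1: no padding
M = combined_mask(P)

# All-zero scores make the weights uniform over whatever keys remain allowed.
S = [[[0.0] * 4 for _ in range(4)] for _ in P]
A = [
    [softmax([s + NEG * m for s, m in zip(s_row, m_row)]) for s_row, m_row in zip(S_b, M_b)]
    for S_b, M_b in zip(S, M)
]

print(A[0][3])   # three equal weights, then 0.0 for the padded key
print(A[1][3])   # [0.25, 0.25, 0.25, 0.25]
```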

---

### 5.7 Why attention masking is essential

- **Causal mask:** enforces the autoregressive factorization (no cheating by looking at the future).  
- **Padding mask:** ensures that `<pad>` tokens do not leak into the computation.  
- **Head mask:** gives flexibility for analysis and pruning.

Together, these masks align the **mathematics of training (loss + padding)** with the **architecture of attention**.

---

**Next (Block 1.6):**  
We will implement these masks in **PyTorch code**:  
1. Build a causal mask matrix.  
2. Build a padding mask from a batch with `<pad>`.  
3. Show how they are applied inside a MultiheadAttention layer.  
4. Visualize the causal mask.


In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# -------------------------------
# 1. Create toy batch (with padding)
# -------------------------------
pad_id = 0
batch = torch.tensor([
    [5, 6, 7, 8, pad_id, pad_id],      # length 4
    [9, 10, 11, pad_id, pad_id, pad_id], # length 3
    [12, 13, 14, 15, 16, 17]            # length 6 (max)
])

B, L = batch.shape
D = 16   # embedding dimension
H = 2    # number of heads

# -------------------------------
# 2. Build causal mask
# -------------------------------
# Shape: (L, L)
causal_mask = torch.triu(torch.ones(L, L) * float("-inf"), diagonal=1)

print("Causal mask:\n", causal_mask)

# -------------------------------
# 3. Build key-padding mask
# -------------------------------
# True where PAD tokens are present
padding_mask = (batch == pad_id)
print("Padding mask shape:", padding_mask.shape)  # (B, L)

# -------------------------------
# 4. MultiheadAttention with masks
# -------------------------------
attn = nn.MultiheadAttention(embed_dim=D, num_heads=H, batch_first=True)

# Random embeddings for the tokens
x = torch.randn(B, L, D)

# Apply attention with both masks
out, attn_weights = attn(
    x, x, x,
    attn_mask=causal_mask,         # (L, L)
    key_padding_mask=padding_mask  # (B, L)
)

print("Output shape:", out.shape)                 # (B, L, D)
print("Attention weights shape:", attn_weights.shape) # (B, L, L): averaged over heads by default

# -------------------------------
# 5. Visualize causal mask
# -------------------------------
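# Plot the blocked/allowed pattern rather than the raw values: -inf is not a
# finite number, so it cannot be placed on a color scale.
plt.imshow(torch.isinf(causal_mask).float().numpy(), cmap="gray")
plt.title("Causal Attention Mask (white = blocked, i.e. -inf)")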
plt.colorbar()
plt.show()


## 6) Attention Masking Implementation: Mathematical Foundations and PyTorch Realization

In this section we bridge the **mathematical theory** of autoregressive language modeling with the **architectural constraints** enforced by attention masking. We will prove that proper masking is not just an implementation detail, but a **necessary condition** for the mathematical coherence of the training objective.

---

### 6.1 Mathematical preliminaries: Attention as conditional probability computation

**Formal setup:**  
Let $\mathbf{X} \in \mathbb{R}^{L \times d}$ be a sequence of token embeddings, where $L$ is sequence length and $d$ is embedding dimension.

The **scaled dot-product attention** mechanism computes:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
$$

where:
- $Q = XW_Q \in \mathbb{R}^{L \times d_k}$ (queries)
- $K = XW_K \in \mathbb{R}^{L \times d_k}$ (keys)  
- $V = XW_V \in \mathbb{R}^{L \times d_v}$ (values)
- $W_Q, W_K \in \mathbb{R}^{d \times d_k}$, $W_V \in \mathbb{R}^{d \times d_v}$ are learned parameters

**Critical observation:**  
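For each query $i$, the attention weights $A_{ij} = \text{softmax}_{j}\left(\frac{q_i^T k_j}{\sqrt{d_k}}\right)$ are non-negative and sum to 1 over $j$, so they form a **probability distribution** over keys: $A_{ij}$ is the share of query $i$'s attention placed on key $j$.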

**Mathematical constraint from autoregression:**  
For the model to respect the causal factorization $p_\theta(t_i | t_{<i})$, we must enforce:

$$
A_{ij} = 0 \quad \forall j > i
$$

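This is not optional: the output at position $i$ is used to predict the next token $t_{i+1}$, so it may depend only on $t_{\le i}$. The chain rule itself involves no independence assumption; the constraint arises because all $L$ conditionals are computed in a single parallel forward pass.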

---

### 6.2 Formal definition of attention masks

**Definition 6.1 (Attention Mask):**  
An attention mask is a function $M: \mathbb{N} \times \mathbb{N} \rightarrow \{0, -\infty\}$ that modifies attention scores:

$$
\tilde{S}_{ij} = S_{ij} + M(i,j)
$$

where $S_{ij} = \frac{q_i^T k_j}{\sqrt{d_k}}$ are the raw attention scores.

**Definition 6.2 (Causal Mask):**  
The causal mask $C: \mathbb{N} \times \mathbb{N} \rightarrow \{0, -\infty\}$ is defined as:

$$
C(i,j) = \begin{cases}
0 & \text{if } j \leq i \\
-\infty & \text{if } j > i
\end{cases}
$$

**Theorem 6.1 (Causal Consistency):**  
*The causal mask is necessary and sufficient to ensure that the attention mechanism respects autoregressive ordering.*

**Proof:**  
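*Necessity:* Suppose $C(i,j) \neq -\infty$ for some $j > i$, with $S_{ij}$ finite. Then after softmax, $A_{ij} > 0$, so the output at position $i$ depends on the future position $j$. Since that output is used to predict the token following position $i$, the prediction is no longer conditioned on the past alone, which violates the factorization $\prod_i p_\theta(t_i \mid t_{<i})$.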

*Sufficiency:* If $C(i,j) = -\infty$ for all $j > i$, then $\tilde{S}_{ij} = -\infty$, so $A_{ij} = 0$ after softmax, ensuring position $i$ only accesses $\{t_1, \ldots, t_i\}$. □
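
The sufficiency direction can be probed empirically: with the causal mask in place, changing a later position must leave earlier outputs untouched. The sketch below is a single-head, plain-Python version with made-up scores and values; `causal_attention` is a hypothetical helper, and for simplicity only the value vector of the last position is perturbed.

```python
import math

def causal_attention(S, V):
    """Single-head attention with a causal mask; S is (L x L) scores, V is (L x d) values."""
    L = len(S)
    out = []
    for i in range(L):
        row = [S[i][j] if j <= i else float("-inf") for j in range(L)]
        m = max(row)
        w = [math.exp(s - m) for s in row]        # exp(-inf) == 0.0
        z = sum(w)
        out.append([sum(w[j] / z * V[j][c] for j in range(L)) for c in range(len(V[0]))])
    return out

S = [[0.3, -1.2, 0.8], [1.1, 0.4, -0.5], [-0.7, 0.9, 0.2]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

V_changed = [V[0], V[1], [100.0, -100.0]]   # overwrite only the last position
before = causal_attention(S, V)
after = causal_attention(S, V_changed)
print(before[0] == after[0], before[1] == after[1], before[2] == after[2])   # True True False
```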

---

### 6.3 Padding mask: Mathematical treatment of variable-length sequences

**Problem statement:**  
In batched training, sequences have different lengths $L_1, L_2, \ldots, L_B$. To form tensors, we pad to $L_{\max} = \max_b L_b$ using a special token $\text{<pad>}$.

**Mathematical issue:**  
Padded positions contain **no semantic information** but standard attention would still compute:

$$
o_i = \sum_{j=1}^{L_{\max}} A_{ij} v_j = \sum_{j=1}^{L_b} A_{ij} v_j + \sum_{j=L_b+1}^{L_{\max}} A_{ij} v_{\text{<pad>}}
$$

The second sum **contaminates** the representation with meaningless pad embeddings.

**Definition 6.3 (Key-Padding Mask):**  
For a batch $\mathcal{B} = \{\mathbf{t}^{(b)}\}_{b=1}^B$ with lengths $\{L_b\}$, the key-padding mask is:

$$
P^{(b)}(i,j) = \begin{cases}
0 & \text{if } j \leq L_b \\
-\infty & \text{if } j > L_b
\end{cases}
$$

**Theorem 6.2 (Padding Invariance):**  
*With proper key-padding masking, the attention output is invariant to the amount of padding.*

**Proof:**  
Let $\tilde{S}^{(b)}_{ij} = S^{(b)}_{ij} + P^{(b)}(i,j)$. For $j > L_b$, we have $\tilde{S}^{(b)}_{ij} = -\infty$, so:

$$
A^{(b)}_{ij} = \frac{\exp(\tilde{S}^{(b)}_{ij})}{\sum_{k=1}^{L_{\max}} \exp(\tilde{S}^{(b)}_{ik})} = \frac{0}{\sum_{k=1}^{L_b} \exp(\tilde{S}^{(b)}_{ik})} = 0
$$

Therefore:
$$
o_i^{(b)} = \sum_{j=1}^{L_b} A^{(b)}_{ij} v_j^{(b)}
$$

which depends only on the first $L_b$ positions, regardless of $L_{\max}$. □
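
Padding invariance can also be checked numerically. In the sketch below, one query attends over three real keys followed by a varying number of padded keys; the scores, values, and the helper `attend` are illustrative assumptions.

```python
import math

def attend(scores_row, values, n_real):
    """Attention output for one query; keys at positions >= n_real are padding."""
    row = [s if j < n_real else float("-inf") for j, s in enumerate(scores_row)]
    m = max(row)
    w = [math.exp(s - m) for s in row]
    z = sum(w)
    dim = len(values[0])
    return [sum(w[j] / z * values[j][c] for j in range(len(values))) for c in range(dim)]

real_scores = [0.5, -0.3, 1.2]
real_values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
pad_value = [7.0, -7.0]                      # stand-in for the projected <pad> embedding

outputs = []
for n_pad in (0, 2, 5):
    scores = real_scores + [3.0] * n_pad     # pad scores are finite before masking
    values = real_values + [pad_value] * n_pad
    outputs.append(attend(scores, values, n_real=3))

for out in outputs:
    print(out)                               # the same vector for every amount of padding
```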

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Reproducibility
torch.manual_seed(42)

# Define special tokens with clear semantic meaning
PAD_ID = 0   # Padding token (no semantic content)
BOS_ID = 1   # Beginning of sequence marker
EOS_ID = 2   # End of sequence marker

# Create a realistic batch demonstrating the masking problem
# Each sequence represents: [BOS, content_tokens..., EOS, PAD, PAD, ...]
batch_tokens = torch.tensor([
    [BOS_ID, 5, 6, 7, 8, EOS_ID, PAD_ID, PAD_ID],        # Real length: 6
    [BOS_ID, 9, 10, 11, EOS_ID, PAD_ID, PAD_ID, PAD_ID],  # Real length: 5  
    [BOS_ID, 12, 13, 14, 15, 16, 17, EOS_ID]             # Real length: 8 (no padding)
])

B, L = batch_tokens.shape  # B=3 sequences, L=8 max length
print(f"Batch shape: {batch_tokens.shape}")
print("Batch contents (showing real vs padded tokens):")
for i, seq in enumerate(batch_tokens):
    real_tokens = seq[seq != PAD_ID]
    pad_count = (seq == PAD_ID).sum().item()
    print(f"  Sequence {i}: {real_tokens.tolist()} + {pad_count} padding tokens")

### 6.4 Constructing the causal mask: Theory and implementation

**Mathematical derivation:**  
We need to construct a mask matrix $C \in \{0, -\infty\}^{L \times L}$ such that:

$$
C_{ij} = \begin{cases}
0 & \text{if } j \leq i \text{ (causally valid)} \\
-\infty & \text{if } j > i \text{ (causally invalid)}
\end{cases}
$$

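This is the **strictly upper triangular** pattern. With $U = \text{triu}(\mathbf{1}_{L \times L}, k=1)$:

$$
C_{ij} = \begin{cases}
-\infty & \text{if } U_{ij} = 1 \\
0 & \text{if } U_{ij} = 0
\end{cases}
$$

where $\text{triu}(\cdot, k=1)$ extracts the strictly upper triangular part. Writing this as the product $-\infty \cdot U$ is tempting but ill-defined, because $-\infty \cdot 0$ is undefined (NaN in floating point); implementations therefore fill the selected entries instead of multiplying.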
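
A literal translation of $-\infty \cdot \text{triu}(\cdot)$ into floating-point multiplication needs care. The short sketch below shows the pitfall in plain Python and an entry-by-entry construction that avoids it.

```python
import math

neg_inf = float("-inf")
print(neg_inf * 1.0)              # -inf: blocked entries come out as intended
print(neg_inf * 0.0)              # nan: allowed entries would be corrupted

# Safer construction: choose the value entry by entry instead of multiplying.
L = 3
C = [[neg_inf if j > i else 0.0 for j in range(L)] for i in range(L)]
for row in C:
    print(row)
```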

**Computational complexity:**  
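The causal mask depends only on the sequence length: $O(L^2)$ space and $O(L^2)$ time to construct. For very long sequences (e.g., $L = 10^5$) a dense float32 mask would have $10^{10}$ entries, about 40 GB, which is why long-context implementations typically enforce causality inside the attention kernel instead of storing the mask.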

**Implementation note:**  
In practice, the mask is pre-computed once per sequence length and reused across batches.
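
One way to reuse the mask is to memoize it by sequence length. The sketch below uses `functools.lru_cache` on a plain-Python mask (the function name and cache size are arbitrary choices for illustration) and prints the quadratic memory growth of a dense float32 mask.

```python
from functools import lru_cache

@lru_cache(maxsize=8)
def cached_causal_mask(seq_length):
    """Build the (L x L) mask once per sequence length; tuples keep the cached value immutable."""
    neg_inf = float("-inf")
    return tuple(
        tuple(neg_inf if j > i else 0.0 for j in range(seq_length))
        for i in range(seq_length)
    )

first = cached_causal_mask(8)
second = cached_causal_mask(8)        # served from the cache, no recomputation
print(first is second)                # True
print(cached_causal_mask.cache_info())

# Memory for a dense float32 mask (4 bytes per entry) grows quadratically with L.
for L in (1_000, 10_000, 100_000):
    print(L, f"{L * L * 4 / 1e9:.3f} GB")
```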

In [None]:
def create_causal_mask(seq_length, device=None, dtype=torch.float32):
    """
    Construct a causal attention mask preventing future information leakage.
    
    Mathematical definition:
        C[i,j] = 0     if j <= i (position j is causally valid for query i)
        C[i,j] = -inf  if j > i  (position j is in the future relative to i)
    
    Args:
        seq_length (int): Length of the sequence
        device (torch.device): Device to place the mask on
        dtype (torch.dtype): Data type for the mask
    
    Returns:
        torch.Tensor: Causal mask of shape (seq_length, seq_length)
        
    Time complexity: O(L²)
    Space complexity: O(L²)
    """
    # Fill an (L, L) matrix with -inf directly on the target device, then keep
    # only the strictly upper triangle (diagonal=1). triu() writes 0 everywhere
    # else, so allowed positions (j <= i) end up as 0.
    causal_mask = torch.triu(
        torch.full((seq_length, seq_length), float('-inf'), dtype=dtype, device=device),
        diagonal=1,
    )

    return causal_mask

def verify_causal_mask_properties(mask):
    """
    Verify mathematical properties of the causal mask.
    """
    L = mask.shape[0]
    print(f"=== Causal Mask Verification (L={L}) ===")
    
    # Property 1: Lower triangular + diagonal should be 0
    lower_triangular = torch.tril(mask)
    if torch.all(lower_triangular == 0):
        print("✅ Lower triangular part is zero (causally valid positions)")
    else:
        print("❌ Lower triangular part contains non-zero values")
    
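    # Property 2: Strictly upper triangular entries should be -inf.
    # Select them with a boolean index: comparing the full output of torch.triu
    # to -inf always fails, because triu zero-fills the diagonal and lower triangle.
    strictly_upper = torch.triu(torch.ones_like(mask), diagonal=1).bool()
    if torch.all(mask[strictly_upper] == float('-inf')):
        print("✅ Upper triangular part is -inf (future positions blocked)")
    else:
        print("❌ Upper triangular part contains finite values")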
    
    # Property 3: Diagonal should be 0 (token can attend to itself)
    diagonal = torch.diag(mask)
    if torch.all(diagonal == 0):
        print("✅ Diagonal is zero (self-attention allowed)")
    else:
        print("❌ Diagonal contains non-zero values")

# Create and verify causal mask
causal_mask = create_causal_mask(L)
print(f"Causal mask shape: {causal_mask.shape}")
print("Causal mask (showing first 5×5 submatrix):")
print(causal_mask[:5, :5])

verify_causal_mask_properties(causal_mask)

### 6.5 Key-padding mask: Measure-theoretic foundations

**Problem formalization:**  
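Padding can be phrased measure-theoretically as **zero-measure support** in the token space. Let $\mathcal{T}$ be the token space, including the padding symbol, and let $\mu$ be the counting measure restricted to real tokens, $\mu(A) = |A \setminus \{\text{<pad>}\}|$. The padding set $\mathcal{T}_{\text{pad}} = \{\text{<pad>}\}$ then satisfies $\mu(\mathcal{T}_{\text{pad}}) = 0$: padding carries no semantic mass. (Under the unrestricted counting measure a singleton has measure 1, so the restriction is essential.)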

**Definition 6.4 (Semantic Support):**  
For a padded sequence $\tilde{\mathbf{t}} = (t_1, \ldots, t_{L_{\text{real}}}, \text{<pad>}, \ldots, \text{<pad>})$, the semantic support is:

$$
\text{supp}_{\text{sem}}(\tilde{\mathbf{t}}) = \{i : \tilde{t}_i \neq \text{<pad>}\} = \{1, 2, \ldots, L_{\text{real}}\}
$$

**Theorem 6.3 (Attention Concentration):**  
*For any query position $i$, attention weights must satisfy the concentration property:*

$$
\sum_{j \in \text{supp}_{\text{sem}}(\tilde{\mathbf{t}})} A_{ij} = 1, \quad A_{ij} = 0 \text{ for } j \notin \text{supp}_{\text{sem}}(\tilde{\mathbf{t}})
$$

**Proof sketch:**  
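This follows from softmax normalization together with $\exp(-\infty) = 0$: masked keys contribute zero to both the numerator and the denominator, so the unit mass is distributed over the semantic support only. The statement presupposes a non-empty support; if every key is masked, the softmax reduces to $0/0$ and is undefined.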
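
The concentration property can be checked numerically on a toy example. The sketch below assumes a single right-padded sequence with three real tokens and random scores; all names (`toy_scores`, `toy_pad_additive`, `toy_attn`) are illustrative.

```python
import torch

torch.manual_seed(0)
toy_real, toy_max = 3, 5                       # 3 real tokens followed by 2 <pad> positions
toy_scores = torch.randn(toy_max, toy_max)     # stand-in for raw attention scores

# Additive key-padding mask: 0 on the semantic support, -inf on padded keys
toy_pad_additive = torch.zeros(toy_max)
toy_pad_additive[toy_real:] = float("-inf")

# The (L,) mask broadcasts over query rows, so every query ignores the padded keys
toy_attn = torch.softmax(toy_scores + toy_pad_additive, dim=-1)

print(toy_attn[:, toy_real:])               # weights on padded keys
print(toy_attn[:, :toy_real].sum(dim=-1))   # row sums over the support
```

The weights on the padded keys are exactly zero (not merely small), because the softmax evaluates $\exp(-\infty)$ for those entries.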

**Broadcasting theory:**  
When working with batches, the key-padding mask must be broadcast correctly. Let $\mathcal{B} = \{\mathbf{t}^{(b)}\}_{b=1}^B$ with support sets $\{S_b\}_{b=1}^B$ where $S_b = \text{supp}_{\text{sem}}(\mathbf{t}^{(b)})$.

The batch padding mask $\mathbf{P} \in \{0, -\infty\}^{B \times L_{\max}}$ satisfies:

$$
\mathbf{P}_{b,j} = \begin{cases}
0 & \text{if } j \in S_b \\
-\infty & \text{if } j \notin S_b
\end{cases}
$$
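
As a minimal sketch, the additive mask $\mathbf{P}$ can be built directly from the real sequence lengths, assuming right padding so that $S_b$ is a prefix of positions. The lengths below are made up for illustration, and code positions are 0-based while the formulas are 1-based.

```python
import torch

toy_lengths = torch.tensor([5, 3, 1])        # hypothetical real lengths for B = 3
toy_L_max = int(toy_lengths.max())

positions = torch.arange(toy_L_max)          # (L_max,) 0-based key positions
# in_support[b, j] is True iff position j belongs to S_b
in_support = positions.unsqueeze(0) < toy_lengths.unsqueeze(1)   # (B, L_max)

# 0 on the support, -inf outside it
P_toy = torch.zeros(len(toy_lengths), toy_L_max).masked_fill(~in_support, float("-inf"))
print(P_toy)
```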

**Computational considerations:**  
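- **Space complexity:** $O(B \times L_{\max})$ for the mask tensor
- **Broadcasting complexity:** Broadcasting from $(B, L)$ to $(B, H, L, L)$ is a stride-0 view and allocates nothing; only an explicitly materialized mask costs $O(B \times H \times L^2)$ memory
- **Kernel efficiency:** Padding masks are highly structured (with right padding, one contiguous suffix per sequence), so fused or variable-length attention kernels can skip padded blocks instead of reading a dense mask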
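
The difference between a broadcast view and a materialized mask can be inspected through tensor strides. This is a small sketch with made-up sizes; the variable names are illustrative.

```python
import torch

nb, nh, nl = 2, 4, 6                                   # toy B, H, L
pad_bool = torch.zeros(nb, nl, dtype=torch.bool)
pad_bool[1, 4:] = True                                 # second sequence has 2 padded keys

# (B, L) -> (B, 1, 1, L) -> logical (B, H, L, L) without copying any data
pad_view = pad_bool[:, None, None, :].expand(nb, nh, nl, nl)
print(pad_view.shape)
print(pad_view.stride())                               # stride 0 on the head and query axes
print(pad_view.data_ptr() == pad_bool.data_ptr())      # same underlying storage

# Only an explicit copy pays the O(B * H * L^2) cost
pad_dense = pad_view.contiguous()
print(pad_dense.numel(), "elements materialized vs", pad_bool.numel(), "stored")
```

Adding a `(B, 1, 1, L)` mask to `(B, H, L, L)` scores relies on the same broadcasting rule, so the mask itself never needs to be expanded in memory.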

In [None]:
def create_key_padding_mask(batch_tokens, pad_id, return_inverse=False):
    """
    Create a boolean key-padding mask from a batch of token IDs.
    
    Mathematical definition:
        For batch element b and position j:
        P[b,j] = True   if token[b,j] == pad_id (position to ignore)
        P[b,j] = False  if token[b,j] != pad_id (valid position)
    
    Args:
        batch_tokens (torch.Tensor): Shape (B, L) containing token IDs
        pad_id (int): Token ID used for padding
        return_inverse (bool): If True, also return the inverse mask
    
    Returns:
        torch.Tensor: Boolean mask, True for positions to ignore
        torch.Tensor (optional): Inverse mask for convenience
    """
    # Create boolean mask: True where padding tokens exist
    padding_mask = (batch_tokens == pad_id)
    
    if return_inverse:
        # Inverse mask: True for valid (non-padding) positions
        valid_mask = ~padding_mask
        return padding_mask, valid_mask
    
    return padding_mask

def analyze_batch_padding_statistics(batch_tokens, pad_id):
    """
    Analyze padding patterns and their impact on computational efficiency.
    """
    B, L = batch_tokens.shape
    
    # Compute sequence lengths (number of non-pad tokens per row), vectorized
    sequence_lengths = (batch_tokens != pad_id).sum(dim=1).tolist()
    
    # Padding statistics
    total_tokens = B * L
    valid_tokens = sum(sequence_lengths)
    padding_tokens = total_tokens - valid_tokens
    padding_ratio = padding_tokens / total_tokens
    
    # Efficiency metrics
    avg_length = np.mean(sequence_lengths)
    length_variance = np.var(sequence_lengths)
    efficiency = valid_tokens / total_tokens  # Fraction of useful computation
    
    print(f"=== Batch Padding Analysis ===")
    print(f"Batch size: {B}, Max length: {L}")
    print(f"Sequence lengths: {sequence_lengths}")
    print(f"Average length: {avg_length:.2f} ± {np.sqrt(length_variance):.2f}")
    print(f"Valid tokens: {valid_tokens}/{total_tokens} ({efficiency:.1%})")
    print(f"Padding ratio: {padding_ratio:.1%}")
    print(f"Computational efficiency: {efficiency:.1%}")
    
    # Memory overhead analysis (element counts, not bytes)
    mask_memory = B * L  # Boolean mask
    attention_memory = B * L * L  # Attention matrix for one head; multiply by H for all heads
    print(f"Mask memory overhead: {mask_memory} bools")
    print(f"Attention memory (per head): {attention_memory} floats")

# Create padding mask and analyze batch
key_padding_mask, valid_mask = create_key_padding_mask(batch_tokens, PAD_ID, return_inverse=True)

print("Key padding mask (True = ignore, False = attend):")
print(key_padding_mask.int())
print("\nValid token mask (True = attend, False = ignore):")
print(valid_mask.int())

analyze_batch_padding_statistics(batch_tokens, PAD_ID)

### 6.6 Mask composition theory: Linear algebra of attention constraints

**Composition operator:**  
Multiple masks are combined **elementwise** (Hadamard-style) in log-space. Given masks $M_1, M_2, \ldots, M_k$ with entries in $\{0, -\infty\}$, the composed mask is:

$$
M_{\text{composed}} = M_1 \oplus M_2 \oplus \cdots \oplus M_k
$$

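where $\oplus$ acts as a **logical OR on the "blocked" indicator** (equivalently, a logical AND on the "allowed" indicator) in $\{0, -\infty\}$ space, and coincides with ordinary addition, $a \oplus b = a + b$: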

$$
a \oplus b = \begin{cases}
0 & \text{if } a = 0 \text{ AND } b = 0 \\
-\infty & \text{otherwise}
\end{cases}
$$

**Theorem 6.4 (Mask Commutativity):**  
*The mask composition operator is commutative and associative:*

$$
M_1 \oplus M_2 = M_2 \oplus M_1, \quad (M_1 \oplus M_2) \oplus M_3 = M_1 \oplus (M_2 \oplus M_3)
$$
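
Because $\{0, -\infty\}$ has only two elements, both properties can be verified exhaustively. The sketch below uses plain Python floats; the function name `compose` is a hypothetical stand-in for $\oplus$.

```python
import itertools
import math

NEG_INF = -math.inf

def compose(a: float, b: float) -> float:
    """The mask operator on {0, -inf}: the result is 0 only if both inputs are 0."""
    return 0.0 if (a == 0.0 and b == 0.0) else NEG_INF

mask_values = (0.0, NEG_INF)
for a, b, c in itertools.product(mask_values, repeat=3):
    assert compose(a, b) == a + b == min(a, b)                     # same as addition and min
    assert compose(a, b) == compose(b, a)                          # commutative
    assert compose(compose(a, b), c) == compose(a, compose(b, c))  # associative
print("compose agrees with + and min on {0, -inf}; commutative and associative")
```

The agreement with ordinary addition is what allows masks to be summed into the scores in any order.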

**Broadcasting mathematics:**  
Consider the dimensional analysis:
- Causal mask: $C \in \{0, -\infty\}^{L \times L}$
- Padding mask: $P \in \{0, -\infty\}^{B \times L}$  
- Attention scores: $S \in \mathbb{R}^{B \times H \times L \times L}$

**Broadcasting transformation:**
$$
\begin{aligned}
C_{\text{broadcast}} &: (L, L) \to (1, 1, L, L) \to (B, H, L, L) \\
P_{\text{broadcast}} &: (B, L) \to (B, 1, 1, L) \to (B, H, L, L)
\end{aligned}
$$

**Final mask application:**
$$
\tilde{S}_{bhij} = S_{bhij} + C_{\text{broadcast}}[b,h,i,j] + P_{\text{broadcast}}[b,h,i,j]
$$

**Theorem 6.5 (Attention Conservation):**  
*Under proper masking, attention weights satisfy:*

$$
\sum_{j \in \mathcal{V}_{bi}} A_{bhij} = 1 \quad \forall b,h,i
$$

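*where $\mathcal{V}_{bi} = \{j : j \leq i \text{ AND } \mathbf{t}^{(b)}_j \neq \text{<pad>}\}$ is the valid attention set (not to be confused with the vocabulary $\mathcal{V}$), assumed non-empty.*

**Proof:**  
Masked entries satisfy $\exp(\tilde{S}_{bhij}) = \exp(-\infty) = 0$, so the softmax denominator sums over $\mathcal{V}_{bi}$ only and the weights on that set sum to one. If $\mathcal{V}_{bi}$ is empty (e.g., an early padded query under left padding), the denominator is zero and the row is undefined (NaN in floating point).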
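
The following toy check illustrates the theorem and its non-emptiness caveat. It assumes a batch of two length-4 sequences, one right-padded and one left-padded, with a single head; all tensors and names are made up for illustration.

```python
import torch

torch.manual_seed(0)
n_pos = 4
raw_scores = torch.randn(2, 1, n_pos, n_pos)           # (B, H, L, L) with B = 2, H = 1

causal_add = torch.triu(torch.full((n_pos, n_pos), float("-inf")), diagonal=1)
is_pad = torch.tensor([[False, False, False, True],    # right-padded sequence
                       [True,  False, False, False]])  # left-padded sequence
padding_add = torch.zeros(2, n_pos).masked_fill(is_pad, float("-inf"))

# Broadcast (L, L) -> (1, 1, L, L) and (B, L) -> (B, 1, 1, L), then add
masked = raw_scores + causal_add[None, None] + padding_add[:, None, None, :]
weights = torch.softmax(masked, dim=-1)

print(weights[0, 0].sum(dim=-1))   # right padding: every row has a non-empty valid set
print(weights[1, 0, 0])            # left padding, first query: empty valid set
```

For the left-padded sequence, the first query can only see position 0 under the causal mask, and that position is padding, so the whole row is masked and the softmax returns NaN. Training code usually has to handle such rows explicitly, for instance by excluding padded queries from the loss.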

In [None]:
def apply_attention_masks(scores, causal_mask=None, key_padding_mask=None, head_mask=None):
    """
    Apply causal and key-padding masks additively to raw attention scores.
    
    Mathematical operation:
        scores_masked = scores + causal_mask_broadcasted + padding_mask_broadcasted
    
    Args:
        scores (torch.Tensor): Shape (B, H, L, L) - raw attention scores
        causal_mask (torch.Tensor): Shape (L, L) - additive causal mask (0 / -inf)
        key_padding_mask (torch.Tensor): Shape (B, L) - boolean, True = padded key
        head_mask (torch.Tensor): Shape (H,) - head enable/disable; not applied here,
            only validated and returned in the info dict for use after softmax
    
    Returns:
        torch.Tensor: Masked scores ready for softmax
        dict: Names of the masks applied and the broadcasting shapes used
    """
    B, H, L_q, L_k = scores.shape
    # The additions below are out-of-place, so the caller's `scores` tensor is not modified.
    
    # Verification info
    mask_info = {
        'original_shape': scores.shape,
        'masks_applied': [],
        'broadcasting_ops': []
    }
    
    # Apply causal mask
    if causal_mask is not None:
        assert causal_mask.shape == (L_q, L_k), f"Causal mask shape {causal_mask.shape} != ({L_q}, {L_k})"
        
        # Broadcasting: (L, L) -> (1, 1, L, L) -> (B, H, L, L)
        causal_broadcasted = causal_mask.unsqueeze(0).unsqueeze(0)  # (1, 1, L, L)
        scores = scores + causal_broadcasted
        
        mask_info['masks_applied'].append('causal')
        mask_info['broadcasting_ops'].append(f"causal: {causal_mask.shape} -> {causal_broadcasted.shape}")
    
    # Apply key-padding mask
    if key_padding_mask is not None:
        assert key_padding_mask.shape == (B, L_k), f"Padding mask shape {key_padding_mask.shape} != ({B}, {L_k})"
        
        # Broadcasting: (B, L) -> (B, 1, 1, L) -> (B, H, L, L)
        # We expand over the query dimension (dim=2) since we mask keys
        padding_broadcasted = key_padding_mask.unsqueeze(1).unsqueeze(1)  # (B, 1, 1, L)
        
        # Convert boolean mask to additive mask: True -> -inf, False -> 0
        padding_additive = torch.where(padding_broadcasted, 
                                     torch.tensor(float('-inf'), dtype=scores.dtype, device=scores.device),
                                     torch.tensor(0.0, dtype=scores.dtype, device=scores.device))
        
        scores = scores + padding_additive
        
        mask_info['masks_applied'].append('key_padding')
        mask_info['broadcasting_ops'].append(f"padding: {key_padding_mask.shape} -> {padding_additive.shape}")
    
    # Apply head mask (after softmax, so we'll return it for later application)
    if head_mask is not None:
        assert head_mask.shape == (H,), f"Head mask shape {head_mask.shape} != ({H},)"
        mask_info['masks_applied'].append('head')
        mask_info['head_mask'] = head_mask
    
    return scores, mask_info

def verify_mask_application(original_scores, masked_scores, mask_info):
    """
    Verify that masking only introduces -inf entries and leaves other scores unchanged.
    
    Returns:
        bool: True if every entry that is still finite equals its original value.
    """
    print("=== Mask Application Verification ===")
    
    finite_original = torch.isfinite(original_scores)
    finite_masked = torch.isfinite(masked_scores)
    
    print(f"Original finite values: {finite_original.sum().item()}")
    print(f"Masked finite values: {finite_masked.sum().item()}")
    print(f"Masks applied: {mask_info['masks_applied']}")
    print(f"Broadcasting operations: {mask_info['broadcasting_ops']}")
    
    # Count the -inf entries introduced by masking
    new_neginf = torch.isneginf(masked_scores) & ~torch.isneginf(original_scores)
    print(f"New -inf positions introduced: {new_neginf.sum().item()}")
    
    # Unmasked entries only received an additive 0, so they must match exactly
    values_preserved = torch.equal(masked_scores[finite_masked], original_scores[finite_masked])
    print(f"Unmasked values preserved: {values_preserved}")
    
    return values_preserved

# Demonstrate mask application
B, H, L = batch_tokens.shape[0], 2, batch_tokens.shape[1]

# Create dummy attention scores (small values for numerical stability)
demo_scores = torch.randn(B, H, L, L) * 0.1

print("Original scores (batch 0, head 0, first 5×5):")
print(demo_scores[0, 0, :5, :5])

# Apply both causal and padding masks.
# The causal mask is rebuilt for this L so that its shape always matches the scores.
masked_scores, mask_info = apply_attention_masks(
    demo_scores, 
    causal_mask=create_causal_mask(L), 
    key_padding_mask=key_padding_mask
)

print("\nAfter applying causal + padding masks:")
print(masked_scores[0, 0, :5, :5])

# Verify the application
verify_mask_application(demo_scores, masked_scores, mask_info)