# Module 7b: The Transformer Architecture

## Prerequisites
- **Module 7a: Self-Attention** -- you should be comfortable with scaled dot-product attention and multi-head attention before starting this module.

## Overview

In Module 7a we built the core **self-attention** mechanism from scratch. In this notebook we assemble the remaining pieces needed to go from a single attention operation to a **complete Transformer architecture**:

1. **Positional Encoding** -- injecting sequence-order information
2. **Layer Normalization** -- stabilizing training
3. **Residual (Skip) Connections** -- enabling gradient flow
4. **Position-wise Feed-Forward Network** -- adding non-linear capacity
5. **Encoder Block** -- combining all of the above
6. **Decoder Block** -- adding causal masking and cross-attention
7. **Architecture Variants** -- BERT, GPT, T5 and how they differ
8. **Pre-training Objectives** -- MLM, CLM, span corruption

By the end you will have a working, from-scratch Transformer encoder that you can trace tensor shapes through from input to output.

---

## 1. Setup

In [None]:
!pip install -q torch numpy matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

---

## 2. Positional Encoding

### Why do we need positional information?

Self-attention is **permutation-equivariant**: if you shuffle the input tokens, the outputs are shuffled in exactly the same way, and each token's output vector is otherwise unchanged. Nothing in the attention equations themselves distinguishes "position 0" from "position 5". In contrast, RNNs process tokens sequentially, so order is built into the computation.

To give the Transformer a sense of **where** each token sits in the sequence, the original paper (Vaswani et al., 2017) adds a **positional encoding** vector to each token embedding before it enters the first layer.

### Sinusoidal Positional Encoding

The original Transformer uses fixed sinusoidal functions:

$$PE_{(pos, 2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

$$PE_{(pos, 2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

where:
- `pos` is the position in the sequence (0, 1, 2, ...)
- `i` is the dimension index (0, 1, 2, ..., d_model/2 - 1)
- `d_model` is the embedding dimension

**Intuition:** Each pair of dimensions oscillates at a different frequency. Low-index dimensions change rapidly (short wavelengths, starting at $2\pi$), while high-index dimensions change slowly (wavelengths approaching $10000 \cdot 2\pi$). Together they give every position a unique "fingerprint", and for any fixed offset $k$ the encoding at position $pos + k$ is a linear function of the encoding at $pos$.
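
The linearity claim can be checked numerically. For each frequency $\omega_i = 1/10000^{2i/d_{\text{model}}}$, the (sin, cos) pair at position $pos + k$ is a rotation of the pair at $pos$ by the angle $k\,\omega_i$. The following is a minimal NumPy sketch of that check; `sinusoidal_pe` and `shift_matrix` are hypothetical helper names introduced only for this illustration.

```python
import numpy as np

def sinusoidal_pe(max_len: int, d_model: int) -> np.ndarray:
    """Return the (max_len, d_model) sinusoidal encoding matrix."""
    pos = np.arange(max_len)[:, None]                               # (max_len, 1)
    freqs = 1.0 / (10000 ** (np.arange(0, d_model, 2) / d_model))   # (d_model/2,)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(pos * freqs)
    pe[:, 1::2] = np.cos(pos * freqs)
    return pe

def shift_matrix(k: int, d_model: int) -> np.ndarray:
    """Block-diagonal matrix M_k such that PE[pos + k] = M_k @ PE[pos] for every pos."""
    freqs = 1.0 / (10000 ** (np.arange(0, d_model, 2) / d_model))
    M = np.zeros((d_model, d_model))
    for i, w in enumerate(freqs):
        c, s = np.cos(k * w), np.sin(k * w)
        # sin(a + b) = sin(a)cos(b) + cos(a)sin(b)
        # cos(a + b) = cos(a)cos(b) - sin(a)sin(b)
        M[2 * i:2 * i + 2, 2 * i:2 * i + 2] = [[c, s], [-s, c]]
    return M

pe = sinusoidal_pe(max_len=50, d_model=16)
M = shift_matrix(k=5, d_model=16)

# One fixed matrix maps every position to the position 5 steps later
shift_ok = np.allclose(pe[:45] @ M.T, pe[5:])

# Wavelength of each sin/cos pair, from the lowest to the highest dimension index
wavelengths = 2 * np.pi * 10000 ** (np.arange(0, 16, 2) / 16)
print(shift_ok, wavelengths[[0, -1]].round(1))
```

The matrix `M` depends only on the offset `k`, not on `pos`, which is what makes relative offsets easy for a linear layer to pick up.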

### Learned Positional Embeddings

An alternative (used in BERT, GPT-2, etc.) is to **learn** a positional embedding matrix of shape `(max_seq_len, d_model)`. This is simple to implement (`nn.Embedding`), but there is no embedding for positions beyond `max_seq_len`, so the model cannot handle longer sequences. The sinusoidal encoding can be computed for any position, although good behavior on lengths far beyond those seen in training is not guaranteed.
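
As a rough illustration of the learned variant, the sketch below wraps `nn.Embedding` in a small module. The class name `LearnedPositionalEmbedding` and the explicit length check are assumptions made for this example, not the exact implementation used by BERT or GPT-2.

```python
import torch
import torch.nn as nn

class LearnedPositionalEmbedding(nn.Module):
    """Learned absolute positions: one trainable vector per index < max_len."""

    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()
        self.max_len = max_len
        self.pos_emb = nn.Embedding(max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, d_model)
        seq_len = x.size(1)
        if seq_len > self.max_len:
            # There is simply no row in the table for these positions
            raise ValueError(f"seq_len {seq_len} exceeds max_len {self.max_len}")
        positions = torch.arange(seq_len, device=x.device)   # (seq_len,)
        return x + self.pos_emb(positions)                    # broadcasts over the batch

learned_pe = LearnedPositionalEmbedding(d_model=64, max_len=32)
out_learned = learned_pe(torch.zeros(2, 20, 64))
n_pos_params = sum(p.numel() for p in learned_pe.parameters())
print(out_learned.shape, n_pos_params)
```

Unlike the sinusoidal version, this adds `max_len * d_model` trainable parameters, and a sequence of length 33 would raise an error here.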

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding from 'Attention Is All You Need'."""

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create a matrix of shape (max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # Compute the div_term: 1 / 10000^(2i/d_model) -- use log-space for numerical stability
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )  # (d_model/2,)

        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices

        pe = pe.unsqueeze(0)  # (1, max_len, d_model) -- batch dimension
        self.register_buffer('pe', pe)  # not a parameter, but should move with .to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, d_model)
        Returns:
            Tensor of same shape with positional encoding added.
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


# Quick test
d_model = 64
seq_len = 20
pe_module = PositionalEncoding(d_model=d_model, dropout=0.0)

# Feed in zeros so we can see the raw positional encodings
dummy = torch.zeros(1, seq_len, d_model)
encoded = pe_module(dummy)
print(f"Input shape:  {dummy.shape}")
print(f"Output shape: {encoded.shape}")
print(f"PE values at position 0: {encoded[0, 0, :8]}")
print(f"PE values at position 1: {encoded[0, 1, :8]}")

### Visualizing Positional Encodings

Let's create a heatmap showing how the encoding values vary across positions and dimensions.

In [None]:
# Visualize positional encodings
pe_vis = PositionalEncoding(d_model=128, max_len=100, dropout=0.0)
pe_values = pe_vis(torch.zeros(1, 100, 128))[0].detach().numpy()  # (100, 128)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Heatmap
im = axes[0].imshow(pe_values, aspect='auto', cmap='RdBu', interpolation='nearest')
axes[0].set_xlabel('Embedding Dimension')
axes[0].set_ylabel('Position')
axes[0].set_title('Sinusoidal Positional Encoding Heatmap')
plt.colorbar(im, ax=axes[0])

# A few individual dimensions over position
for dim_idx in [0, 1, 4, 5, 20, 21]:
    axes[1].plot(pe_values[:, dim_idx], label=f'dim {dim_idx}')
axes[1].set_xlabel('Position')
axes[1].set_ylabel('Encoding Value')
axes[1].set_title('Positional Encoding by Dimension')
axes[1].legend(fontsize=8)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Notice: lower dimensions oscillate rapidly, higher dimensions oscillate slowly.")
print("Each position gets a unique encoding vector.")

---

## Exercise 1: Implement Sinusoidal Positional Encoding + Visualize

Implement the `PositionalEncodingExercise` class below. Fill in the `TODO` placeholders.

**Requirements:**
1. Compute the sinusoidal positional encoding matrix of shape `(max_len, d_model)`
2. Apply `sin` to even indices and `cos` to odd indices
3. In `forward`, add the positional encoding to the input
4. Create a heatmap visualization of the result

In [None]:
class PositionalEncodingExercise(nn.Module):
    """Exercise: Implement sinusoidal positional encoding from scratch."""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)

        # TODO: Compute div_term using the formula: exp(arange(0, d_model, 2) * (-log(10000) / d_model))
        div_term = None

        # TODO: Fill even indices (0, 2, 4, ...) with sin(position * div_term)
        # pe[:, 0::2] = ...

        # TODO: Fill odd indices (1, 3, 5, ...) with cos(position * div_term)
        # pe[:, 1::2] = ...

        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TODO: Add positional encoding to x (only up to x's sequence length)
        return None


# TODO: Create an instance with d_model=64, max_len=50
# Pass zeros through it and visualize as a heatmap
# pe_ex = ...
# result = ...

# TODO: Create heatmap using plt.imshow
# plt.figure(figsize=(10, 4))
# plt.imshow(...)
# plt.show()

### Solution

In [None]:
class PositionalEncodingExercise(nn.Module):
    """Solution: Sinusoidal positional encoding from scratch."""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)

        # Compute div_term in log-space for numerical stability
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )

        # Even indices get sin, odd indices get cos
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:, :x.size(1), :]
        return x


# Create and visualize
pe_ex = PositionalEncodingExercise(d_model=64, max_len=50)
result = pe_ex(torch.zeros(1, 50, 64))
pe_data = result[0].detach().numpy()

plt.figure(figsize=(10, 4))
plt.imshow(pe_data, aspect='auto', cmap='RdBu', interpolation='nearest')
plt.colorbar(label='Encoding Value')
plt.xlabel('Embedding Dimension')
plt.ylabel('Token Position')
plt.title('Exercise 1 Solution: Sinusoidal Positional Encoding')
plt.tight_layout()
plt.show()

print(f"Output shape: {result.shape}")
print(f"Each position has a unique encoding. Position 0 != Position 1:")
print(f"  Position 0: {pe_data[0, :6].round(3)}")
print(f"  Position 1: {pe_data[1, :6].round(3)}")

---

## 3. Layer Normalization

### Why Normalization Helps Training

Normalization was originally motivated by **internal covariate shift**: the distribution of each layer's inputs changes as the parameters of preceding layers change during training. Later work suggests that much of the benefit comes from a smoother optimization landscape, but the practical effect is the same. Normalization re-centers and re-scales activations, which:

- Stabilizes and accelerates training
- Reduces sensitivity to initialization and learning rate
- Acts as a mild regularizer

### Layer Norm vs Batch Norm

| Property | Batch Norm | Layer Norm |
|----------|-----------|------------|
| Normalizes across | Batch dimension (for each feature) | Feature dimension (for each sample) |
| Statistics | Mean/var over batch | Mean/var over features |
| Depends on batch size | Yes | No |
| Works for variable-length sequences | Poorly | Well |
| Used in Transformers | Rarely | Standard (LayerNorm or a variant such as RMSNorm) |

**Layer Norm** computes statistics across the **feature dimension** for each individual sample independently. This makes it ideal for Transformers where:
- Batch sizes may be small
- Sequence lengths vary
- We need inference with batch_size=1
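
The batch-size independence in the table can be seen directly. In the sketch below (an illustrative setup, not part of the original notebook), the same sample is placed in two different batches: the LayerNorm output for that sample is unchanged, while the BatchNorm output in training mode depends on the other samples in the batch.

```python
import torch
import torch.nn as nn

g = torch.Generator().manual_seed(0)   # local generator, leaves the global seed untouched
features = 8
sample = torch.randn(1, features, generator=g)

# The same sample, surrounded by two different sets of "other" samples
batch_a = torch.cat([sample, torch.randn(3, features, generator=g)])
batch_b = torch.cat([sample, 5.0 * torch.randn(3, features, generator=g)])

layer_norm = nn.LayerNorm(features)
batch_norm = nn.BatchNorm1d(features)  # training mode: normalizes with batch statistics

ln_same = torch.allclose(layer_norm(batch_a)[0], layer_norm(batch_b)[0], atol=1e-6)
bn_same = torch.allclose(batch_norm(batch_a)[0], batch_norm(batch_b)[0], atol=1e-6)
print(f"LayerNorm output for the shared sample unchanged: {ln_same}")
print(f"BatchNorm output for the shared sample unchanged: {bn_same}")
```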

### Formula

For a vector $\mathbf{x}$ of dimension $d$:

$$\text{LayerNorm}(\mathbf{x}) = \gamma \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where $\mu = \frac{1}{d}\sum_i x_i$, $\sigma^2 = \frac{1}{d}\sum_i (x_i - \mu)^2$, and $\gamma$, $\beta$ are learnable scale/shift parameters.
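
A small worked example may help make the formula concrete. The NumPy sketch below normalizes the vector $(1, 2, 3, 4)$ by hand, with $\gamma = 1$ and $\beta = 0$ (their usual initial values); the variable names are chosen only for this illustration.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
eps = 1e-6

mu = x.mean()                       # 2.5
var = ((x - mu) ** 2).mean()        # 1.25 (divides by d, not d - 1)
x_hat = (x - mu) / np.sqrt(var + eps)

gamma, beta = np.ones_like(x), np.zeros_like(x)   # initial values of the learnable parameters
y = gamma * x_hat + beta
print(y.round(3))  # approximately [-1.342 -0.447  0.447  1.342]
```

The result has mean 0 and (population) variance very close to 1, regardless of the scale of the original vector.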

In [None]:
class LayerNorm(nn.Module):
    """Layer Normalization implemented from scratch."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(d_model))   # learnable shift
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch_size, seq_len, d_model)
        mean = x.mean(dim=-1, keepdim=True)          # mean over d_model
        var = x.var(dim=-1, keepdim=True, unbiased=False)  # variance over d_model
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta


# Compare with PyTorch's built-in LayerNorm
d_model = 64
x = torch.randn(2, 10, d_model)  # (batch=2, seq_len=10, d_model=64)

our_ln = LayerNorm(d_model)
pytorch_ln = nn.LayerNorm(d_model, eps=1e-6)  # match our eps (PyTorch's default is 1e-5)

# Copy weights so we can compare outputs
pytorch_ln.weight = nn.Parameter(our_ln.gamma.data.clone())
pytorch_ln.bias = nn.Parameter(our_ln.beta.data.clone())

out_ours = our_ln(x)
out_pytorch = pytorch_ln(x)

print(f"Input shape: {x.shape}")
print(f"Our LayerNorm output shape: {out_ours.shape}")
print(f"Max absolute difference: {(out_ours - out_pytorch).abs().max().item():.2e}")
print(f"Outputs match: {torch.allclose(out_ours, out_pytorch, atol=1e-5)}")

# Verify normalization: mean should be ~0, std should be ~1
print(f"\nAfter LayerNorm (sample 0, position 0):")
print(f"  Mean: {out_ours[0, 0].mean().item():.6f}")
print(f"  Std:  {out_ours[0, 0].std(unbiased=False).item():.4f}")  # population std, as in the formula

---

## 4. Residual Connections

### Skip Connections

In the original Transformer, each sub-layer (attention, feed-forward) is wrapped in a **residual connection** followed by layer normalization:

$$\text{output} = \text{LayerNorm}(x + \text{Sublayer}(x))$$

### Why Residual Connections Help

1. **Gradient flow**: Without skip connections, gradients must flow through every layer sequentially. With skip connections, gradients have a "highway" that bypasses layers, reducing vanishing gradients.
2. **Training stability**: The identity path ensures that, at worst, a layer can learn to do nothing (the identity function), rather than being forced to learn a complex transformation from scratch.
3. **Deeper networks**: Residual connections are what make it practical to train Transformers with 12, 24, or even 100+ layers.
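
The gradient-flow argument can be made explicit with a short derivation. As a simplification, ignore the layer normalization and write a stack of residual layers as $h_{l+1} = h_l + F_l(h_l)$. Each layer's Jacobian is then the identity plus the sub-layer's Jacobian:

$$\frac{\partial h_{l+1}}{\partial h_l} = I + \frac{\partial F_l}{\partial h_l}$$

By the chain rule, the gradient of the loss $\mathcal{L}$ with respect to the input of the stack is (schematically, up to the ordering of the matrix factors):

$$\frac{\partial \mathcal{L}}{\partial h_0} = \frac{\partial \mathcal{L}}{\partial h_L} \prod_{l=0}^{L-1} \left( I + \frac{\partial F_l}{\partial h_l} \right)$$

Expanding the product yields the term $\frac{\partial \mathcal{L}}{\partial h_L}$ on its own, plus terms that pass through one or more sub-layers. Without the skip connections the product is $\prod_l \frac{\partial F_l}{\partial h_l}$, so every path to the input must pass through all $L$ Jacobians, and the gradient shrinks geometrically whenever their norms are below 1.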

In [None]:
class SublayerConnection(nn.Module):
    """Residual connection followed by layer normalization.
    
    Implements: LayerNorm(x + Sublayer(x))
    """

    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        self.norm = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, sublayer_output: torch.Tensor) -> torch.Tensor:
        """Apply residual connection: LayerNorm(x + Dropout(sublayer_output))"""
        return self.norm(x + self.dropout(sublayer_output))


# Demonstrate gradient flow with and without residual connections
print("=== Gradient Flow: With vs Without Residual Connections ===")
print()

def simulate_gradient_flow(n_layers, use_residual=True):
    """Run one forward/backward pass and return the gradient norm at the input."""
    torch.manual_seed(42)
    d = 64
    x = torch.randn(1, d, requires_grad=True)

    layers = [nn.Linear(d, d) for _ in range(n_layers)]
    # std=0.5 is well above the ~1/sqrt(d) scale of standard initializations,
    # so tanh saturates quickly, which exaggerates deep-network gradient issues
    for layer in layers:
        nn.init.normal_(layer.weight, std=0.5)

    # Forward pass
    activations = [x]
    h = x
    for layer in layers:
        out = torch.tanh(layer(h))
        if use_residual:
            h = h + out  # residual connection
        else:
            h = out
        activations.append(h)

    # Backward pass
    loss = h.sum()
    loss.backward()

    return x.grad.norm().item()


layer_counts = [1, 5, 10, 20, 50]
grads_no_res = [simulate_gradient_flow(n, use_residual=False) for n in layer_counts]
grads_with_res = [simulate_gradient_flow(n, use_residual=True) for n in layer_counts]

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(layer_counts, grads_no_res, 'ro-', label='Without Residual', linewidth=2)
ax.plot(layer_counts, grads_with_res, 'bs-', label='With Residual', linewidth=2)
ax.set_xlabel('Number of Layers')
ax.set_ylabel('Gradient Norm at Input')
ax.set_title('Gradient Flow: Residual vs Non-Residual Networks')
ax.legend()
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("Without residuals, the gradient must pass through every layer's Jacobian to reach the input.")
print("With residuals, the identity path gives gradients a direct route back, even in deep networks.")

---

## 5. Feed-Forward Network

Each Transformer layer contains a **position-wise feed-forward network** (FFN) applied independently to each position:

$$\text{FFN}(x) = \max(0,\, xW_1 + b_1)\, W_2 + b_2$$

Key properties:
- **Two linear transformations** with a ReLU activation in between
- **Expansion factor**: The inner dimension `d_ff` is typically **4x** the model dimension. For example, if `d_model=512`, then `d_ff=2048`.
- **Position-wise**: The same FFN is applied to every position independently (like a 1x1 convolution).
- **Purpose**: Adds non-linear transformation capacity. Once the attention weights are computed, attention only forms weighted averages of the value vectors, so the FFN supplies most of the per-token non-linear processing.
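
The "position-wise" property can be verified with a quick experiment. The sketch below uses a stand-in FFN built from `nn.Sequential` (an assumption for this example, without dropout) and checks that the output at one position is the same whether or not the other positions are present, and that permuting positions simply permutes the outputs.

```python
import torch
import torch.nn as nn

g = torch.Generator().manual_seed(0)
d_model_demo, d_ff_demo = 8, 32

# Stand-in for FFN(x) = max(0, xW1 + b1)W2 + b2
ffn_sketch = nn.Sequential(
    nn.Linear(d_model_demo, d_ff_demo),
    nn.ReLU(),
    nn.Linear(d_ff_demo, d_model_demo),
)

x_demo = torch.randn(2, 5, d_model_demo, generator=g)   # (batch, seq_len, d_model)
with torch.no_grad():
    full = ffn_sketch(x_demo)                  # all positions at once
    one_pos = ffn_sketch(x_demo[:, 3, :])      # position 3 on its own

    # The output at position 3 does not depend on any other position
    position_wise = torch.allclose(full[:, 3, :], one_pos, atol=1e-5)

    # Permuting the positions permutes the outputs in the same way
    perm = torch.tensor([4, 2, 0, 3, 1])
    permutes = torch.allclose(ffn_sketch(x_demo[:, perm, :]), full[:, perm, :], atol=1e-5)

print(position_wise, permutes)
```

Any mixing of information between positions therefore has to come from the attention sub-layer.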

In [None]:
class FeedForward(nn.Module):
    """Position-wise Feed-Forward Network.
    
    FFN(x) = max(0, xW1 + b1)W2 + b2
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, d_model)
        return self.linear2(self.dropout(F.relu(self.linear1(x))))


# Test the feed-forward network
d_model = 64
d_ff = 256  # 4x expansion
ffn = FeedForward(d_model, d_ff, dropout=0.0)

x = torch.randn(2, 10, d_model)  # (batch=2, seq_len=10, d_model=64)
out = ffn(x)

print(f"Input shape:  {x.shape}")
print(f"Output shape: {out.shape}")
print(f"\nFFN parameters:")
print(f"  W1: {d_model} -> {d_ff} = {d_model * d_ff:,} weights + {d_ff} bias")
print(f"  W2: {d_ff} -> {d_model} = {d_ff * d_model:,} weights + {d_model} bias")
total_params = sum(p.numel() for p in ffn.parameters())
print(f"  Total: {total_params:,} parameters")

---

## 6. Multi-Head Attention (Recap from 7a)

Before building the encoder block, let's define a clean multi-head attention module. This is the same mechanism from Module 7a, packaged as an `nn.Module`.

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-Head Attention mechanism.
    
    Splits d_model into n_heads parallel attention heads,
    each with dimension d_k = d_model / n_heads.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # Linear projections for Q, K, V, and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute scaled dot-product attention.
        
        Args:
            Q: (batch, n_heads, seq_len, d_k)
            K: (batch, n_heads, seq_len, d_k)
            V: (batch, n_heads, seq_len, d_k)
            mask: optional mask tensor
        Returns:
            output: (batch, n_heads, seq_len, d_k)
            attn_weights: (batch, n_heads, seq_len, seq_len)
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        output = torch.matmul(attn_weights, V)
        return output, attn_weights

    def forward(self, query, key, value, mask=None):
        """Forward pass.
        
        Args:
            query: (batch, seq_len_q, d_model)
            key:   (batch, seq_len_k, d_model)
            value: (batch, seq_len_v, d_model)
            mask:  optional mask
        Returns:
            output: (batch, seq_len_q, d_model)
            attn_weights: (batch, n_heads, seq_len_q, seq_len_k)
        """
        batch_size = query.size(0)

        # 1. Linear projections
        Q = self.W_q(query)  # (batch, seq_len, d_model)
        K = self.W_k(key)
        V = self.W_v(value)

        # 2. Reshape to (batch, n_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # 3. Scaled dot-product attention
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # 4. Concatenate heads: (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # 5. Final linear projection
        output = self.W_o(attn_output)

        return output, attn_weights


# Quick test
mha = MultiHeadAttention(d_model=64, n_heads=4, dropout=0.0)
x = torch.randn(2, 10, 64)
out, weights = mha(x, x, x)
print(f"Multi-Head Attention:")
print(f"  Input:   {x.shape}")
print(f"  Output:  {out.shape}")
print(f"  Weights: {weights.shape}  (batch, heads, seq_q, seq_k)")

---

## 7. Complete Encoder Block

Now we assemble all components into a single **Transformer Encoder Block**:

```
Input (batch, seq_len, d_model)
  |
  |---> Multi-Head Self-Attention
  |         |
  +-----> Add (residual)
            |
          LayerNorm
            |
            |---> Feed-Forward Network
            |         |
            +-----> Add (residual)
                      |
                    LayerNorm
                      |
                    Output (batch, seq_len, d_model)
```

Each encoder block preserves the input shape, so blocks can be stacked.
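
As a point of comparison, the shape-preserving and stackable behavior can be seen with PyTorch's built-in `nn.TransformerEncoderLayer`. The sketch below configures it to resemble the diagram above (post-norm, ReLU feed-forward, batch-first tensors); the sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

# Built-in layer: self-attention + Add & Norm, then FFN + Add & Norm
builtin_layer = nn.TransformerEncoderLayer(
    d_model=64, nhead=4, dim_feedforward=256,
    dropout=0.0, activation="relu",
    batch_first=True,    # tensors are (batch, seq_len, d_model)
    norm_first=False,    # LayerNorm after the residual add, as in the diagram
)
builtin_encoder = nn.TransformerEncoder(builtin_layer, num_layers=3)

x_demo = torch.randn(2, 10, 64)
with torch.no_grad():
    y_demo = builtin_encoder(x_demo)

layer_params = sum(p.numel() for p in builtin_layer.parameters())
print(x_demo.shape, y_demo.shape, f"{layer_params:,} parameters per layer")
```

Because input and output shapes match, any number of layers can be chained. For these sizes a single layer has 49,984 parameters: 16,640 in attention, 33,088 in the feed-forward network, and 256 in the two layer norms.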

In [None]:
class TransformerEncoderBlock(nn.Module):
    """A single Transformer Encoder block.
    
    Components:
        1. Multi-Head Self-Attention + Add & Norm
        2. Feed-Forward Network + Add & Norm
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()

        # Sub-layer 1: Multi-Head Self-Attention
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # Sub-layer 2: Feed-Forward Network
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask=None):
        """Forward pass.
        
        Args:
            x: (batch_size, seq_len, d_model)
            mask: optional attention mask
        Returns:
            (batch_size, seq_len, d_model)
        """
        # Sub-layer 1: Self-Attention + Add & Norm
        attn_output, attn_weights = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_output))  # residual + norm

        # Sub-layer 2: FFN + Add & Norm
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout2(ffn_output))   # residual + norm

        return x


# Test the encoder block
d_model = 64
n_heads = 4
d_ff = 256

encoder_block = TransformerEncoderBlock(d_model, n_heads, d_ff, dropout=0.0)

x = torch.randn(2, 10, d_model)  # (batch=2, seq_len=10, d_model=64)
out = encoder_block(x)

print(f"Encoder Block Test:")
print(f"  Input shape:  {x.shape}")
print(f"  Output shape: {out.shape}")
print(f"  Shape preserved: {x.shape == out.shape}")

total_params = sum(p.numel() for p in encoder_block.parameters())
print(f"\n  Total parameters: {total_params:,}")
print(f"\n  Parameter breakdown:")
for name, param in encoder_block.named_parameters():
    print(f"    {name}: {param.shape} ({param.numel():,})")

---

## Exercise 2: Build a Transformer Encoder Block from Scratch

Implement a complete Transformer encoder block as an `nn.Module`. You should use the `MultiHeadAttention`, `FeedForward`, and `nn.LayerNorm` classes.

**Requirements:**
1. Self-attention with residual connection and layer norm
2. Feed-forward with residual connection and layer norm
3. The output shape must match the input shape

In [None]:
class TransformerEncoderBlockExercise(nn.Module):
    """Exercise: Implement a Transformer Encoder Block."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()

        # TODO: Create multi-head attention layer
        self.self_attn = None

        # TODO: Create first layer norm
        self.norm1 = None

        # TODO: Create feed-forward network
        self.ffn = None

        # TODO: Create second layer norm
        self.norm2 = None

        # TODO: Create dropout layers
        self.dropout1 = None
        self.dropout2 = None

    def forward(self, x: torch.Tensor, mask=None):
        # TODO: Sub-layer 1 -- Self-Attention + Add & Norm
        # attn_output, _ = self.self_attn(...)
        # x = self.norm1(...)

        # TODO: Sub-layer 2 -- Feed-Forward + Add & Norm
        # ffn_output = self.ffn(...)
        # x = self.norm2(...)

        return None


# TODO: Test your implementation
# enc = TransformerEncoderBlockExercise(d_model=64, n_heads=4, d_ff=256, dropout=0.0)
# test_input = torch.randn(2, 10, 64)
# test_output = enc(test_input)
# print(f"Input: {test_input.shape}, Output: {test_output.shape}")
# assert test_input.shape == test_output.shape, "Shape mismatch!"

### Solution

In [None]:
class TransformerEncoderBlockExercise(nn.Module):
    """Solution: Complete Transformer Encoder Block."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()

        # Multi-head self-attention
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)

        # Layer norms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Feed-forward network
        self.ffn = FeedForward(d_model, d_ff, dropout)

        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask=None):
        # Sub-layer 1: Self-Attention + Add & Norm
        attn_output, attn_weights = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_output))

        # Sub-layer 2: Feed-Forward + Add & Norm
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout2(ffn_output))

        return x


# Test the solution
enc = TransformerEncoderBlockExercise(d_model=64, n_heads=4, d_ff=256, dropout=0.0)
test_input = torch.randn(2, 10, 64)
test_output = enc(test_input)

print(f"Input shape:  {test_input.shape}")
print(f"Output shape: {test_output.shape}")
assert test_input.shape == test_output.shape, "Shape mismatch!"
print("Shape check passed!")

# Stack multiple blocks to form a deeper encoder
n_layers = 3
encoder_blocks = nn.ModuleList([
    TransformerEncoderBlockExercise(d_model=64, n_heads=4, d_ff=256, dropout=0.0)
    for _ in range(n_layers)
])

h = test_input
for i, block in enumerate(encoder_blocks):
    h = block(h)
    print(f"After block {i}: {h.shape}")

print(f"\nFinal output shape: {h.shape}")
print(f"Total parameters in {n_layers}-layer encoder: {sum(p.numel() for p in encoder_blocks.parameters()):,}")

---

## 8. Decoder Block

The Transformer **decoder** block is similar to the encoder block but with two key additions:

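1. **Masked Self-Attention (Causal Masking):** The decoder must not look at future tokens. During training the whole target sequence is fed in at once, so without a mask each position could simply copy the token it is supposed to predict. We apply a causal mask so that position $i$ can only attend to positions $\leq i$.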

2. **Cross-Attention:** After masked self-attention, the decoder attends to the **encoder output**. Here, the queries come from the decoder but the keys and values come from the encoder.
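
To see what causal masking does to the attention weights, the sketch below applies a lower-triangular mask to a small matrix of random scores before the softmax. It mirrors the `masked_fill` approach used in Section 6; the variable names and sizes are chosen only for this illustration.

```python
import torch
import torch.nn.functional as F

g = torch.Generator().manual_seed(0)
n_pos = 4
scores = torch.randn(n_pos, n_pos, generator=g)    # raw attention scores for one head
allowed = torch.tril(torch.ones(n_pos, n_pos))     # 1 where key index j <= query index i

# Blocked entries become -inf, and exp(-inf) = 0 inside the softmax
masked_scores = scores.masked_fill(allowed == 0, float("-inf"))
causal_weights = F.softmax(masked_scores, dim=-1)

print(causal_weights)
```

Each row still sums to 1, every entry above the diagonal is exactly 0, and the first position can only attend to itself.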

```
Decoder Input (batch, tgt_len, d_model)
  |
  |---> Masked Multi-Head Self-Attention    <-- causal mask
  |         |
  +-----> Add & Norm
            |
            |---> Multi-Head Cross-Attention <-- Q from decoder, K,V from encoder
            |         |
            +-----> Add & Norm
                      |
                      |---> Feed-Forward
                      |         |
                      +-----> Add & Norm
                                |
                              Output (batch, tgt_len, d_model)
```
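
The shapes involved in cross-attention can be traced with PyTorch's built-in `nn.MultiheadAttention`, used here as a stand-in for the from-scratch module. The tensor names and sizes below are assumptions for this sketch.

```python
import torch
import torch.nn as nn

cross_attn_demo = nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True)

dec_states = torch.randn(2, 8, 64)    # (batch, tgt_len, d_model) -> queries
enc_states = torch.randn(2, 12, 64)   # (batch, src_len, d_model) -> keys and values

with torch.no_grad():
    ctx, attn = cross_attn_demo(query=dec_states, key=enc_states, value=enc_states)

print(ctx.shape)   # torch.Size([2, 8, 64])  one output vector per decoder position
print(attn.shape)  # torch.Size([2, 8, 12])  head-averaged weights over source positions
```

The output length follows the queries (the target side), while each row of the attention matrix is a distribution over the source positions.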

In [None]:
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Create a causal (lower-triangular) mask.
    
    Returns a mask of shape (1, 1, seq_len, seq_len) where:
        mask[..., i, j] = 1 if j <= i (allowed)
        mask[..., i, j] = 0 if j > i  (blocked / future)
    """
    mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
    return mask  # (1, 1, seq_len, seq_len)


# Visualize the causal mask
seq_len = 8
causal_mask = create_causal_mask(seq_len)
print(f"Causal mask shape: {causal_mask.shape}")
print(f"Causal mask:\n{causal_mask[0, 0].int()}")
print()
print("1 = can attend, 0 = cannot attend (future token)")

In [None]:
class TransformerDecoderBlock(nn.Module):
    """A single Transformer Decoder block.
    
    Components:
        1. Masked Multi-Head Self-Attention + Add & Norm
        2. Multi-Head Cross-Attention + Add & Norm
        3. Feed-Forward Network + Add & Norm
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()

        # Sub-layer 1: Masked Self-Attention
        self.masked_self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # Sub-layer 2: Cross-Attention (decoder queries, encoder keys/values)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

        # Sub-layer 3: Feed-Forward
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """Forward pass.
        
        Args:
            x:              (batch, tgt_len, d_model) -- decoder input
            encoder_output: (batch, src_len, d_model) -- encoder output
            src_mask:       mask for cross-attention (optional)
            tgt_mask:       causal mask for self-attention (optional)
        Returns:
            (batch, tgt_len, d_model)
        """
        # Sub-layer 1: Masked Self-Attention
        self_attn_out, _ = self.masked_self_attn(x, x, x, mask=tgt_mask)
        x = self.norm1(x + self.dropout1(self_attn_out))

        # Sub-layer 2: Cross-Attention
        # Query from decoder, Key and Value from encoder
        cross_attn_out, _ = self.cross_attn(x, encoder_output, encoder_output, mask=src_mask)
        x = self.norm2(x + self.dropout2(cross_attn_out))

        # Sub-layer 3: Feed-Forward
        ffn_out = self.ffn(x)
        x = self.norm3(x + self.dropout3(ffn_out))

        return x


# Test the decoder block
d_model = 64
n_heads = 4
d_ff = 256

decoder_block = TransformerDecoderBlock(d_model, n_heads, d_ff, dropout=0.0)

# Encoder output (e.g., from processing a source sentence)
src_len = 12
encoder_output = torch.randn(2, src_len, d_model)

# Decoder input (e.g., target sentence generated so far)
tgt_len = 8
decoder_input = torch.randn(2, tgt_len, d_model)

# Create causal mask for decoder self-attention
tgt_mask = create_causal_mask(tgt_len)

decoder_output = decoder_block(decoder_input, encoder_output, tgt_mask=tgt_mask)

print(f"Decoder Block Test:")
print(f"  Encoder output shape: {encoder_output.shape}  (source)")
print(f"  Decoder input shape:  {decoder_input.shape}   (target)")
print(f"  Decoder output shape: {decoder_output.shape}  (same as decoder input)")
print(f"  Causal mask shape:    {tgt_mask.shape}")

total_params = sum(p.numel() for p in decoder_block.parameters())
print(f"\n  Total parameters: {total_params:,}")
print(f"  (More than encoder block because of the extra cross-attention sub-layer)")

---

## 9. Architecture Variants

The original Transformer is an **encoder-decoder** model. In practice, three major variants have emerged, each using a different subset of the architecture:

### BERT: Encoder-Only (Bidirectional)

```
  Input: [CLS] The cat sat on the mat [SEP]
           |    |   |   |   |   |   |    |
         +------------------------------------+
         |     Transformer Encoder Block      |
         |     (bidirectional attention)      |
         |            x N layers              |
         +------------------------------------+
           |    |   |   |   |   |   |    |
          h0   h1  h2  h3  h4  h5  h6   h7
           |                              |
       [CLS] for                      Token-level
     classification               representations
```

- **Attention**: Each token attends to every token in the sequence, itself included (bidirectional)
- **Pre-training**: Masked Language Modeling (MLM) + Next Sentence Prediction (NSP)
- **Use cases**: Text classification, NER, question answering, semantic similarity
- **Models**: BERT, RoBERTa, ALBERT, DistilBERT, DeBERTa

### GPT: Decoder-Only (Causal/Autoregressive)

```
  Input: The  cat  sat  on   the  mat
          |    |    |    |    |    |
        +-------------------------------+
        |  Transformer Decoder Block    |
        |  (causal/masked attention)    |
        |       x N layers              |
        +-------------------------------+
          |    |    |    |    |    |
         cat  sat  on   the  mat  [END]
         (predict next token at each step)
```

- **Attention**: Each token attends only to itself and previous tokens (causal mask)
- **Pre-training**: Causal Language Modeling (predict next token)
- **Use cases**: Text generation, code generation, chatbots, reasoning
- **Models**: GPT-2, GPT-3, GPT-4, LLaMA, Claude, Mistral
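
One way to convince yourself that a causal mask really blocks information from the future is to change the last token and check that earlier outputs do not move. The following is a minimal sketch under simplifying assumptions (single head, no learned projections); `causal_self_attention` is a hypothetical helper written for this check, not a function from the notebook.

```python
import math
import torch


def causal_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Single-head self-attention with a causal mask and no learned projections."""
    seq_len, d = x.shape
    scores = x @ x.T / math.sqrt(d)
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~allowed, float("-inf"))  # future positions get zero weight
    return torch.softmax(scores, dim=-1) @ x


x = torch.randn(6, 16)
x_changed = x.clone()
x_changed[-1] = torch.randn(16)  # alter only the last ("future") token

out = causal_self_attention(x)
out_changed = causal_self_attention(x_changed)

# Positions 0..4 never attend to position 5, so their outputs should be unaffected.
print(torch.allclose(out[:-1], out_changed[:-1]))
# The last position attends to itself, so its output should differ.
print(torch.allclose(out[-1], out_changed[-1]))
```

Running the same experiment without the mask would change every row, which is the behavior of the bidirectional encoder.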

### T5: Encoder-Decoder (Sequence-to-Sequence)

```
  Input: translate English to French: The cat sat on the mat
          |   |   |   |   |   |   |   |   |   |    |   |
        +----------------------------------------------+
        |           Transformer Encoder                |
        |       (bidirectional attention)               |
        +----------------------------------------------+
                          |
                    encoder output
                          |
                          v
        +----------------------------------------------+
        |           Transformer Decoder                |
        |  (causal self-attn + cross-attn to encoder)  |
        +----------------------------------------------+
          |      |      |      |      |       |
         Le    chat   s'est  assis   sur      le tapis
```

- **Attention**: Encoder = bidirectional; Decoder = causal self-attention + cross-attention to encoder
- **Pre-training**: Span corruption (mask random spans, reconstruct them)
- **Use cases**: Translation, summarization, question answering
- **Models**: T5, BART, mBART, FLAN-T5

### Comparison Table

| Feature | BERT (Encoder) | GPT (Decoder) | T5 (Enc-Dec) |
|---------|---------------|---------------|---------------|
| Attention Direction | Bidirectional | Causal (left-to-right) | Encoder: bi, Decoder: causal |
| Pre-training | MLM + NSP | Next token prediction | Span corruption |
| Generation | Not naturally | Excellent | Excellent |
| Understanding | Excellent | Good | Excellent |
| Cross-Attention | No | No | Yes |
| Parameters (base) | 110M | 117M (GPT-2 small, as originally reported) | 220M |
| Typical Use | Classification, NER | Generation, chat | Translation, summarization |
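
The roughly 110M figure for BERT-base can be reproduced with a back-of-the-envelope count. The sketch below assumes the commonly published BERT-base (uncased) hyperparameters: a 30,522-token vocabulary, 12 layers, `d_model=768`, `d_ff=3072`, 512 learned positions and 2 segment embeddings. It counts the pooler layer but no task-specific heads. The function name `bert_base_param_count` is hypothetical.

```python
def bert_base_param_count(
    vocab_size: int = 30522,   # assumed: WordPiece vocabulary of the uncased base model
    d_model: int = 768,
    n_layers: int = 12,
    d_ff: int = 3072,
    max_positions: int = 512,
    n_segments: int = 2,
) -> int:
    layer_norm = 2 * d_model  # scale and shift vectors

    # Token, position and segment embedding tables, followed by one LayerNorm.
    embeddings = (vocab_size + max_positions + n_segments) * d_model + layer_norm

    # Q, K, V and output projections, each with a bias.
    attention = 4 * (d_model * d_model + d_model)
    # Two linear layers of the position-wise feed-forward network.
    ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    block = attention + layer_norm + ffn + layer_norm

    # Dense layer applied to the [CLS] hidden state.
    pooler = d_model * d_model + d_model

    return embeddings + n_layers * block + pooler


total = bert_base_param_count()
print(f"{total:,}")
```

Most of the parameters sit in the 12 encoder blocks, with the token embedding table as the second-largest contributor.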

---

## Exercise 3: Compare BERT and GPT Attention Patterns

Create the attention masks for:
1. **BERT-style** (bidirectional): every token can attend to every other token
2. **GPT-style** (causal): each token can only attend to itself and previous tokens

Visualize them side-by-side as heatmaps.

In [None]:
seq_len = 8
tokens = ["The", "cat", "sat", "on", "the", "mat", ".", "[END]"]

# TODO: Create BERT-style attention mask (all ones -- every token attends to every token)
# bert_mask = ...
bert_mask = None

# TODO: Create GPT-style causal attention mask (lower triangular -- no future tokens)
# gpt_mask = ...
gpt_mask = None

# TODO: Create side-by-side heatmap visualization
# fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# Plot bert_mask on axes[0] with title "BERT (Bidirectional)"
# Plot gpt_mask on axes[1] with title "GPT (Causal)"
# Set tick labels to the tokens list
# plt.show()

### Solution

In [None]:
seq_len = 8
tokens = ["The", "cat", "sat", "on", "the", "mat", ".", "[END]"]

# BERT: bidirectional -- all tokens attend to all tokens
bert_mask = torch.ones(seq_len, seq_len)

# GPT: causal -- lower triangular (each token attends only to itself and past)
gpt_mask = torch.tril(torch.ones(seq_len, seq_len))

# Visualize side-by-side
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# BERT mask
im0 = axes[0].imshow(bert_mask.numpy(), cmap='Blues', vmin=0, vmax=1)
axes[0].set_title('BERT (Bidirectional) Attention Mask', fontsize=13)
axes[0].set_xlabel('Key Position (attending to)')
axes[0].set_ylabel('Query Position (attending from)')
axes[0].set_xticks(range(seq_len))
axes[0].set_yticks(range(seq_len))
axes[0].set_xticklabels(tokens, rotation=45, ha='right', fontsize=9)
axes[0].set_yticklabels(tokens, fontsize=9)

# GPT mask
im1 = axes[1].imshow(gpt_mask.numpy(), cmap='Oranges', vmin=0, vmax=1)
axes[1].set_title('GPT (Causal) Attention Mask', fontsize=13)
axes[1].set_xlabel('Key Position (attending to)')
axes[1].set_ylabel('Query Position (attending from)')
axes[1].set_xticks(range(seq_len))
axes[1].set_yticks(range(seq_len))
axes[1].set_xticklabels(tokens, rotation=45, ha='right', fontsize=9)
axes[1].set_yticklabels(tokens, fontsize=9)

# Add cell values
for ax, mask in [(axes[0], bert_mask), (axes[1], gpt_mask)]:
    for i in range(seq_len):
        for j in range(seq_len):
            val = int(mask[i, j].item())
            color = 'white' if val == 1 else 'black'
            ax.text(j, i, str(val), ha='center', va='center', fontsize=8, color=color)

plt.tight_layout()
plt.show()

print("BERT: All positions are 1 -- every token sees the full context.")
print("GPT:  Lower triangle is 1 -- each token only sees past and present.")
print(f"\nBERT total connections: {int(bert_mask.sum().item())} (all-to-all)")
print(f"GPT  total connections: {int(gpt_mask.sum().item())} (causal)")

---

## 10. Pre-training Objectives

Each architecture variant uses a different **pre-training objective** -- the task the model is trained on before fine-tuning:

### Masked Language Modeling (MLM) -- BERT

Randomly select 15% of input tokens and train the model to predict the original token at each selected position:

```
Input:  The [MASK] sat on the [MASK]
Target:      cat              mat
```

Of the 15% selected tokens:
- 80% are replaced with `[MASK]`
- 10% are replaced with a random token
- 10% are left unchanged

This forces bidirectional understanding: to predict a masked word, the model must use both left and right context.
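
The 80/10/10 rule above can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not BERT's actual preprocessing code: it selects each token independently with probability 15% and works on word strings rather than token IDs. The helper name `mlm_corrupt` is hypothetical.

```python
import random


def mlm_corrupt(tokens, vocab, mask_prob=0.15, seed=0):
    """Return (corrupted tokens, labels); labels are None where no loss is computed."""
    rng = random.Random(seed)
    corrupted, labels = [], []
    for tok in tokens:
        if rng.random() >= mask_prob:
            corrupted.append(tok)
            labels.append(None)          # not selected: position is ignored by the loss
            continue
        labels.append(tok)               # selected: the model must predict the original
        roll = rng.random()
        if roll < 0.8:
            corrupted.append("[MASK]")   # 80%: replace with the mask token
        elif roll < 0.9:
            corrupted.append(rng.choice(vocab))  # 10%: replace with a random token
        else:
            corrupted.append(tok)        # 10%: leave unchanged
    return corrupted, labels


tokens = "the cat sat on the mat".split() * 50
vocab = sorted(set(tokens))
corrupted, labels = mlm_corrupt(tokens, vocab)

selected = sum(label is not None for label in labels)
print(f"selected {selected} of {len(tokens)} tokens")
```

Because some selected tokens are left unchanged or swapped for random ones, the model cannot assume that a visible token is correct, which is the motivation usually given for not masking all selected positions.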

### Causal Language Modeling (CLM) -- GPT

Predict the next token given all previous tokens:

```
Input:  The  cat  sat  on   the
Target: cat  sat  on   the  mat
```

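The loss is computed at every position, so a single forward pass yields a training signal for every token (MLM, by contrast, learns only from the roughly 15% of positions it selects). This naturally produces a model that can **generate** text by sampling one token at a time.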
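
The input/target pairing is just the same sequence shifted by one position. A minimal sketch, using word strings in place of token IDs:

```python
tokens = ["The", "cat", "sat", "on", "the", "mat"]

inputs = tokens[:-1]   # everything except the last token
targets = tokens[1:]   # the same sequence shifted left by one

# With a causal mask, position t sees inputs[0..t] and is trained to predict targets[t].
for t, target in enumerate(targets):
    context = inputs[: t + 1]
    print(f"{' '.join(context):<22} -> {target}")
```

A sequence of n tokens therefore provides n - 1 prediction problems, all evaluated in parallel during training.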

### Span Corruption -- T5

Replace random contiguous spans with sentinel tokens, and train the model to output the missing spans:

```
Input:  The <X> on the <Y>
Target: <X> cat sat <Y> mat <END>
```

This is a seq2seq objective: the encoder processes the corrupted input and the decoder generates the missing pieces. It combines the benefits of both MLM (bidirectional context in the encoder) and CLM (autoregressive generation in the decoder).
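
The construction of the corrupted input and its target can be sketched as follows. The spans here are chosen by hand to reproduce the example above, whereas real pre-training samples them randomly. The sentinel names follow this notebook's notation, and `span_corrupt` is a hypothetical helper.

```python
def span_corrupt(tokens, spans, sentinels=("<X>", "<Y>", "<Z>")):
    """Replace each span with a sentinel.

    spans: sorted, non-overlapping (start, end) index pairs, end exclusive.
    Spans beyond the number of available sentinels are ignored.
    """
    corrupted, target = [], []
    cursor = 0
    for sentinel, (start, end) in zip(sentinels, spans):
        corrupted.extend(tokens[cursor:start])  # keep the text before the span
        corrupted.append(sentinel)              # one sentinel stands in for the whole span
        target.append(sentinel)
        target.extend(tokens[start:end])        # the decoder must produce the dropped tokens
        cursor = end
    corrupted.extend(tokens[cursor:])           # keep whatever follows the last span
    target.append("<END>")
    return corrupted, target


tokens = "The cat sat on the mat".split()
corrupted, target = span_corrupt(tokens, spans=[(1, 3), (5, 6)])

print(" ".join(corrupted))
print(" ".join(target))
```

Since a single sentinel can hide several tokens, the target is much shorter than the full sequence, which keeps decoding cheap during pre-training.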

### Summary

| Objective | Model | Masks | Predicts | Direction |
|-----------|-------|-------|----------|-----------|
| MLM | BERT | Random 15% of tokens | Masked tokens | Bidirectional |
| CLM | GPT | Causal (future) | Next token | Left-to-right |
| Span Corruption | T5 | Random spans | Missing spans | Encoder: bi; Decoder: L-to-R |

---

## 11. Putting It All Together

Let's build a minimal **complete Transformer Encoder** that processes a sequence through all stages, and trace the tensor shapes at every step.

In [None]:
class TransformerEncoder(nn.Module):
    """Complete Transformer Encoder.
    
    Pipeline:
        Token IDs -> Embedding -> Positional Encoding -> N x EncoderBlock -> Output
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        n_heads: int,
        d_ff: int,
        n_layers: int,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.d_model = d_model

        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)

        # Stack of encoder blocks
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])

        # Final layer norm
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask=None, verbose: bool = False):
        """
        Args:
            x: (batch_size, seq_len) -- token IDs
            mask: optional attention mask
            verbose: if True, print shapes at each step
        Returns:
            (batch_size, seq_len, d_model)
        """
        if verbose:
            print(f"  Input token IDs:      {x.shape}")

        # Step 1: Token embedding (scale by sqrt(d_model) as in original paper)
        x = self.token_embedding(x) * math.sqrt(self.d_model)
        if verbose:
            print(f"  After embedding:      {x.shape}")

        # Step 2: Add positional encoding
        x = self.pos_encoding(x)
        if verbose:
            print(f"  After pos encoding:   {x.shape}")

        # Step 3: Pass through encoder blocks
        for i, layer in enumerate(self.layers):
            x = layer(x, mask)
            if verbose:
                print(f"  After encoder block {i}: {x.shape}")

        # Step 4: Final layer norm
        x = self.final_norm(x)
        if verbose:
            print(f"  After final norm:     {x.shape}")

        return x


# === Build and trace a complete encoder ===

# Hyperparameters (small for demonstration)
vocab_size = 1000
d_model = 64
n_heads = 4
d_ff = 256
n_layers = 3
max_seq_len = 128

# Create model
encoder = TransformerEncoder(
    vocab_size=vocab_size,
    d_model=d_model,
    n_heads=n_heads,
    d_ff=d_ff,
    n_layers=n_layers,
    max_seq_len=max_seq_len,
    dropout=0.0,
)

# Create dummy input: batch of 2 sequences, each 10 tokens
batch_size = 2
seq_len = 10
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

print("=" * 60)
print("Tracing data flow through the Transformer Encoder")
print("=" * 60)
print(f"\nConfig: vocab={vocab_size}, d_model={d_model}, heads={n_heads}, "
      f"d_ff={d_ff}, layers={n_layers}")
print()

output = encoder(input_ids, verbose=True)

print(f"\nFinal output shape: {output.shape}")
print(f"  - Batch size: {output.shape[0]}")
print(f"  - Sequence length: {output.shape[1]}")
print(f"  - Hidden dimension: {output.shape[2]}")

total_params = sum(p.numel() for p in encoder.parameters())
print(f"\nTotal model parameters: {total_params:,}")

In [None]:
# Visualize the architecture as a shape diagram
print("\n" + "=" * 60)
print("  Complete Transformer Encoder Architecture")
print("=" * 60)
print(f"""
  Token IDs: ({batch_size}, {seq_len})
      |
      v
  [Token Embedding]  (vocab_size={vocab_size}, d_model={d_model})
      |  -> ({batch_size}, {seq_len}, {d_model})
      v
  [* sqrt(d_model)]  (scaling factor = {math.sqrt(d_model):.2f})
      |
      v
  [+ Positional Encoding]  (sinusoidal, max_len={max_seq_len})
      |  -> ({batch_size}, {seq_len}, {d_model})
      v""")

for i in range(n_layers):
    print(f"  [Encoder Block {i}]")
    print(f"      |-- Multi-Head Attn ({n_heads} heads, d_k={d_model // n_heads})")
    print(f"      |-- Add & LayerNorm")
    print(f"      |-- FFN ({d_model} -> {d_ff} -> {d_model})")
    print(f"      |-- Add & LayerNorm")
    print(f"      |  -> ({batch_size}, {seq_len}, {d_model})")
    print(f"      v")

print(f"  [Final LayerNorm]")
print(f"      |  -> ({batch_size}, {seq_len}, {d_model})")
print(f"      v")
print(f"  Output: ({batch_size}, {seq_len}, {d_model})")
print(f"\n  Each position now contains a context-aware representation")
print(f"  that can be used for downstream tasks.")
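
As a rough illustration of how these per-position representations feed downstream tasks, the sketch below attaches two untrained heads to a tensor of the same shape as the encoder output. A random tensor stands in for `output` so the snippet runs on its own; the heads `classifier` and `token_tagger` and the choice of three classes are assumptions made for the example.

```python
import torch
import torch.nn as nn

batch_size, seq_len, d_model, n_classes = 2, 10, 64, 3

# Stand-in for the encoder output; in the notebook this would be `output`.
hidden = torch.randn(batch_size, seq_len, d_model)

# Sequence-level tasks need one vector per sequence.
first_token = hidden[:, 0, :]      # BERT-style: take the first ([CLS]) position
mean_pooled = hidden.mean(dim=1)   # alternative: average over all positions

classifier = nn.Linear(d_model, n_classes)    # hypothetical, untrained head
sequence_logits = classifier(mean_pooled)     # (batch_size, n_classes)

# Token-level tasks (e.g. NER) apply a head to every position.
token_tagger = nn.Linear(d_model, n_classes)  # hypothetical, untrained head
token_logits = token_tagger(hidden)           # (batch_size, seq_len, n_classes)

print(sequence_logits.shape)
print(token_logits.shape)
```

Mean pooling as written averages over padding positions too; with padded batches you would typically weight the average by the attention mask.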

---

## 12. Summary & References

### Key Takeaways

1. **Positional Encoding** injects sequence-order information that attention alone cannot capture. The sinusoidal approach uses fixed sin/cos functions at different frequencies; the learned approach uses a trainable embedding table.

2. **Layer Normalization** normalizes across the feature dimension for each token independently, stabilizing training without depending on batch statistics.

3. **Residual Connections** (`x + Sublayer(x)`) provide gradient highways that enable training of very deep networks.

4. **The Feed-Forward Network** is a two-layer MLP (with ReLU and a 4x expansion) applied position-wise, providing non-linear transformation capacity.

5. **The Encoder Block** combines self-attention and FFN, each wrapped with residual connections and layer norm. Multiple blocks are stacked to form deeper models.

6. **The Decoder Block** extends the encoder block in two ways: its self-attention is masked (causal), and it adds a cross-attention sub-layer over the encoder output. Together these enable autoregressive generation conditioned on the source sequence.

7. **Architecture Variants** choose different subsets:
   - **BERT** (encoder-only): bidirectional, great for understanding tasks
   - **GPT** (decoder-only): causal, great for generation tasks
   - **T5** (encoder-decoder): full model, great for seq2seq tasks

8. **Pre-training Objectives** define what the model learns:
   - MLM (BERT): predict masked tokens
   - CLM (GPT): predict next token
   - Span corruption (T5): reconstruct masked spans

### References

- **Paper:** Vaswani et al., ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762) (2017) -- the original Transformer paper
- **Blog:** Jay Alammar, ["The Illustrated Transformer"](https://jalammar.github.io/illustrated-transformer/) -- excellent visual walkthrough
- **Blog:** Lilian Weng, ["Attention? Attention!"](https://lilianweng.github.io/posts/2018-06-24-attention/) -- comprehensive attention survey
- **Book:** Sebastian Raschka, ["Build a Large Language Model (From Scratch)"](https://www.manning.com/books/build-a-large-language-model-from-scratch) (Ch. 3-4) -- hands-on implementation
- **Code:** Harvard NLP, ["The Annotated Transformer"](https://nlp.seas.harvard.edu/annotated-transformer/) -- line-by-line code walkthrough of the paper