# Module 2: Understanding Transformers

In this notebook, we'll build a deep understanding of transformer architectures from the ground up.

## Learning Objectives

By the end of this notebook, you will:
1. Understand the core concepts behind self-attention
2. Implement scaled dot-product attention from scratch
3. Build multi-head attention mechanisms
4. Explore positional encodings (sinusoidal, learned, and rotary)
5. Understand the decoder-only architecture (GPT-style)
6. Visualize attention patterns

## Prerequisites

- Basic understanding of neural networks
- Familiarity with PyTorch
- Completed Module 1 (Data Preparation)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Configure plotting
sns.set_style("whitegrid")
plt.rcParams["figure.figsize"] = (10, 6)

## Part 1: The Attention Mechanism

### What is Attention?

Attention allows a model to focus on relevant parts of the input when processing each element. Think of it like reading a sentence - when you process the word "it", you look back at previous words to understand what "it" refers to.

### Key Concepts:

- **Query (Q)**: What am I looking for?
- **Key (K)**: What do I have to offer?
- **Value (V)**: What information do I actually contain?

The attention mechanism computes a weighted sum of values, where weights are determined by the compatibility between queries and keys.
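
To make the "weighted sum of values" concrete, here is a minimal plain-Python sketch for a single query and three key/value pairs. The vectors are made up for illustration, and the scores are left unscaled to keep the arithmetic easy to follow by hand.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# One query and three key/value pairs (toy 2-dimensional vectors)
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

# Compatibility score = dot product between the query and each key
scores = [sum(q * k for q, k in zip(query, key)) for key in keys]

# Softmax turns the scores into weights that sum to 1
weights = softmax(scores)

# The output is the weighted sum of the value vectors
output = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(2)]

print("scores: ", scores)
print("weights:", [round(w, 3) for w in weights])
print("output: ", [round(o, 3) for o in output])
```

The key that points in the same direction as the query gets the largest weight, so its value dominates the output.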

### Scaled Dot-Product Attention

The fundamental attention operation:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:
- $d_k$ is the dimension of the key vectors
- Division by $\sqrt{d_k}$ keeps the scores from growing with $d_k$; without it, large scores push the softmax toward a one-hot distribution with very small gradients
- Softmax converts scores into a probability distribution
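
The effect of the $\sqrt{d_k}$ scaling can be seen with a small experiment. The sketch below is plain Python and assumes query and key entries drawn independently from a standard normal distribution, in which case the dot product has variance $d_k$. The helper name `random_scores` and the choice of 8 keys are illustrative only.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def random_scores(d_k, num_keys, rng):
    """Dot products between one random query and num_keys random keys."""
    q = [rng.gauss(0, 1) for _ in range(d_k)]
    keys = [[rng.gauss(0, 1) for _ in range(d_k)] for _ in range(num_keys)]
    return [sum(a * b for a, b in zip(q, k)) for k in keys]


rng = random.Random(0)
d_k = 512
raw = random_scores(d_k, num_keys=8, rng=rng)
scaled = [s / math.sqrt(d_k) for s in raw]

# Without scaling, one key tends to absorb almost all of the attention weight
print("max weight, unscaled:", round(max(softmax(raw)), 3))
print("max weight, scaled:  ", round(max(softmax(scaled)), 3))
```

Dividing by $\sqrt{d_k}$ acts like raising the softmax temperature, so the largest weight can only shrink or stay the same.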

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        query: Query tensor of shape (batch_size, seq_len, d_k)
        key: Key tensor of shape (batch_size, seq_len, d_k)
        value: Value tensor of shape (batch_size, seq_len, d_v)
        mask: Optional mask tensor (1 for positions to keep, 0 to mask)

    Returns:
        output: Attention output (batch_size, seq_len, d_v)
        attention_weights: Attention weights (batch_size, seq_len, seq_len)
    """
    d_k = query.size(-1)

    # 1. Compute attention scores: Q @ K^T
    scores = torch.matmul(query, key.transpose(-2, -1))  # (batch, seq_len, seq_len)

    # 2. Scale by sqrt(d_k)
    scores = scores / np.sqrt(d_k)

    # 3. Apply mask (if provided) - set masked positions to -inf
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))

    # 4. Apply softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)

    # 5. Compute weighted sum of values
    output = torch.matmul(attention_weights, value)

    return output, attention_weights

### Let's Test It!

We'll create a simple example with 4 tokens in a sequence.

In [None]:
# Create sample input
batch_size = 1
seq_len = 4
d_model = 8  # embedding dimension

# Random query, key, value vectors
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

# Compute attention
output, attn_weights = scaled_dot_product_attention(Q, K, V)

print(f"Input shape: {Q.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print("\nAttention weights (each row sums to 1):")
print(attn_weights[0].detach().numpy())

### Visualizing Attention

Let's visualize the attention pattern. With the `Blues` colormap used here, darker cells indicate higher attention weights.

In [None]:
def plot_attention(attention_weights, tokens=None):
    """
    Visualize attention weights as a heatmap.
    """
    # Convert to numpy
    attn = attention_weights.detach().cpu().numpy()
    if len(attn.shape) == 3:
        attn = attn[0]  # Take first batch

    # Create labels
    if tokens is None:
        tokens = [f"Token {i}" for i in range(attn.shape[0])]

    # Plot
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        attn,
        annot=True,
        fmt=".2f",
        cmap="Blues",
        xticklabels=tokens,
        yticklabels=tokens,
        cbar_kws={"label": "Attention Weight"},
    )
    plt.xlabel("Key/Value Position")
    plt.ylabel("Query Position")
    plt.title("Attention Pattern")
    plt.tight_layout()
    plt.show()


# Visualize our attention weights
plot_attention(attn_weights)

### Causal Masking (for Decoder-only Models)

In language models (like GPT), we predict the next token based only on previous tokens. This requires **causal masking** - preventing tokens from attending to future tokens.
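
Masking works by setting the scores of forbidden positions to `-inf` before the softmax, so that they receive a weight of exactly zero. Here is a minimal plain-Python sketch for a single query row. The helper name `masked_softmax` and the score values are made up for illustration.

```python
import math


def masked_softmax(scores, keep):
    """Softmax over scores, with positions where keep == 0 set to -inf first."""
    masked = [s if k else float("-inf") for s, k in zip(scores, keep)]
    m = max(masked)
    exps = [math.exp(s - m) for s in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


# Scores for the query at position 1 of a 4-token sequence
scores = [0.5, 1.2, 3.0, -0.7]
keep = [1, 1, 0, 0]  # position 1 may only see positions 0 and 1

weights = masked_softmax(scores, keep)
print([round(w, 3) for w in weights])
```

The future position with the highest raw score (3.0) contributes nothing, and the remaining weights still sum to 1.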

In [None]:
def create_causal_mask(seq_len):
    """
    Create a causal mask that prevents attention to future positions.

    Returns a lower triangular matrix:
    [[1, 0, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 1, 0],
     [1, 1, 1, 1]]
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask


# Create and visualize causal mask
mask = create_causal_mask(4)
print("Causal Mask:")
print(mask.numpy())

# Apply causal attention
output_causal, attn_weights_causal = scaled_dot_product_attention(Q, K, V, mask=mask)

print("\nCausal Attention Weights:")
plot_attention(attn_weights_causal, tokens=["The", "cat", "sat", "down"])

## Part 2: Multi-Head Attention

### Why Multiple Heads?

Multi-head attention allows the model to attend to different aspects of the input simultaneously. Think of it like having multiple "perspectives" or "experts" looking at the same data.

For example, when processing the sentence "The cat sat on the mat":
- Head 1 might focus on subject-verb relationships
- Head 2 might focus on spatial relationships expressed by prepositions (on, under, etc.)
- Head 3 might focus on noun attributes

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension per head

        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x):
        """Split the last dimension into (num_heads, d_k)."""
        batch_size, seq_len, d_model = x.size()
        # Reshape: (batch, seq_len, num_heads, d_k)
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        # Transpose: (batch, num_heads, seq_len, d_k)
        return x.transpose(1, 2)

    def combine_heads(self, x):
        """Combine heads back to original dimension."""
        batch_size, num_heads, seq_len, d_k = x.size()
        # Transpose: (batch, seq_len, num_heads, d_k)
        x = x.transpose(1, 2).contiguous()
        # Reshape: (batch, seq_len, d_model)
        return x.view(batch_size, seq_len, self.d_model)

    def forward(self, query, key, value, mask=None):
        """
        Forward pass for multi-head attention.

        Args:
            query, key, value: Input tensors of shape (batch, seq_len, d_model)
            mask: Optional attention mask
        """
        # 1. Linear projections
        Q = self.W_q(query)  # (batch, seq_len, d_model)
        K = self.W_k(key)
        V = self.W_v(value)

        # 2. Split into multiple heads
        Q = self.split_heads(Q)  # (batch, num_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)

        # 3. Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)

        # 4. Apply mask
        if mask is not None:
            # Expand mask for multiple heads
            mask = mask.unsqueeze(1)  # (batch, 1, seq_len, seq_len)
            scores = scores.masked_fill(mask == 0, float("-inf"))

        # 5. Apply softmax and dropout
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 6. Apply attention to values
        context = torch.matmul(attention_weights, V)  # (batch, num_heads, seq_len, d_k)

        # 7. Combine heads
        context = self.combine_heads(context)  # (batch, seq_len, d_model)

        # 8. Final linear projection
        output = self.W_o(context)

        return output, attention_weights

### Testing Multi-Head Attention

In [None]:
# Create multi-head attention module
d_model = 512
num_heads = 8
seq_len = 10
batch_size = 2

mha = MultiHeadAttention(d_model, num_heads)
mha.eval()  # Disable dropout so each row of attention weights sums to 1

# Create sample input
x = torch.randn(batch_size, seq_len, d_model)
mask = create_causal_mask(seq_len).unsqueeze(0).expand(batch_size, -1, -1)

# Forward pass
output, attn_weights = mha(x, x, x, mask=mask)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(
    f"Attention weights shape: {attn_weights.shape}"
)  # (batch, num_heads, seq_len, seq_len)
print(f"\nNumber of parameters: {sum(p.numel() for p in mha.parameters()):,}")
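
The parameter count printed by the cell above can be reproduced by hand. The sketch below assumes the same layout as `MultiHeadAttention`: four `d_model -> d_model` linear layers, each with a bias. `linear_params` is a hypothetical helper introduced only for this check.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of a dense layer: weight matrix plus optional bias vector."""
    return in_features * out_features + (out_features if bias else 0)


d_model = 512
num_heads = 8

# W_q, W_k, W_v and W_o are each d_model -> d_model with a bias
mha_params = 4 * linear_params(d_model, d_model)
print(f"{mha_params:,}")

# Each head works on a d_k-dimensional slice of the projected vectors
d_k = d_model // num_heads
print(d_k)
```

Splitting into heads only reshapes the projections, so the total does not depend on `num_heads`.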

### Visualizing Different Attention Heads

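Let's visualize the attention patterns of the first four heads. The module is untrained, so any differences between heads come from random initialization rather than from learning: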

In [None]:
# Visualize first 4 heads
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axes = axes.flatten()

for i in range(4):
    attn = attn_weights[0, i].detach().cpu().numpy()  # First batch, i-th head

    sns.heatmap(
        attn,
        ax=axes[i],
        cmap="Blues",
        cbar=True,
        cbar_kws={"label": "Attention Weight"},
    )
    axes[i].set_title(f"Head {i + 1}")
    axes[i].set_xlabel("Key Position")
    axes[i].set_ylabel("Query Position")

plt.tight_layout()
plt.show()

print("Each head shows a different pattern, even before any training!")

## Part 3: Positional Encoding

### The Problem

Self-attention on its own is **permutation-equivariant**: shuffling the input tokens simply shuffles the outputs in the same way, so the model has no notion of token order. We need to inject positional information.
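
This order-blindness can be checked directly. The sketch below is a plain-Python self-attention without projections or positional information (queries, keys and values are all the raw token vectors, a simplification for illustration). It shows that reordering the inputs only reorders the outputs.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x):
    """Unprojected self-attention: Q, K and V are all the rows of x."""
    d_k = len(x[0])
    out = []
    for q in x:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in x]
        weights = softmax(scores)
        out.append([sum(w * v[i] for w, v in zip(weights, x)) for i in range(d_k)])
    return out


tokens = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]  # toy embeddings
perm = [2, 0, 1]  # a reordering of the three positions

out_original = self_attention(tokens)
out_permuted = self_attention([tokens[i] for i in perm])

# Each output row follows its token to the new position and is otherwise unchanged
for new_pos, old_pos in enumerate(perm):
    assert all(
        math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
        for a, b in zip(out_permuted[new_pos], out_original[old_pos])
    )
print("outputs match up to the same reordering")
```

Because each token's output is the same wherever the token sits, nothing in the result reveals the original order.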

### Two Approaches:

1. **Sinusoidal Positional Encoding** (original Transformer)
2. **Rotary Positional Embeddings (RoPE)** (modern, used in our model)

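We'll focus on RoPE because it encodes relative position directly in the attention scores and is used in GPT-NeoX, LLaMA, and our Storyteller model. Note that the simplified `SimpleGPT` in Part 6 uses learned positional embeddings instead.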

### Rotary Positional Embeddings (RoPE)

RoPE applies a rotation to query and key vectors based on their position. This elegantly encodes relative positions.

Key advantages:
- Encodes relative positions naturally
- Works well for long sequences
- No learnable parameters needed
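
The relative-position property can be illustrated in two dimensions. The sketch below rotates a toy query and key by an angle proportional to their positions and shows that the dot product depends only on the distance between them. The single frequency `theta=0.1` is an assumption for illustration; RoPE itself uses a different frequency for each pair of dimensions.

```python
import math


def rotate(vec, angle):
    """Rotate a 2-D vector counter-clockwise by angle (radians)."""
    x, y = vec
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)


def rope_score(q, k, q_pos, k_pos, theta=0.1):
    """Dot product after rotating q and k by an angle proportional to position."""
    q_rot = rotate(q, q_pos * theta)
    k_rot = rotate(k, k_pos * theta)
    return q_rot[0] * k_rot[0] + q_rot[1] * k_rot[1]


q = (1.0, 2.0)  # toy query
k = (0.5, -1.0)  # toy key

# Both pairs are 3 positions apart, so the two scores should agree
near = rope_score(q, k, q_pos=5, k_pos=2)
far = rope_score(q, k, q_pos=105, k_pos=102)
print(round(near, 6), round(far, 6))

# A different offset gives a different score
other = rope_score(q, k, q_pos=5, k_pos=3)
print(round(other, 6))
```

Rotations preserve vector length, so the position information lives entirely in the angle between the rotated query and key.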

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary Positional Embeddings (RoPE) from Su et al. (2021).
    """

    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Cache for efficiency
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cache(self, seq_len, device):
        """Update cached cos/sin values if sequence length changed."""
        if seq_len != self._seq_len_cached:
            self._seq_len_cached = seq_len
            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self._cos_cached = emb.cos()[None, :, None, :]
            self._sin_cached = emb.sin()[None, :, None, :]

    def rotate_half(self, x):
        """Rotate half the hidden dims of the input."""
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q, k):
        """
        Apply rotary embeddings to query and key tensors.

        Args:
            q: Query tensor (batch, seq_len, num_heads, d_k)
            k: Key tensor (batch, seq_len, num_heads, d_k)
        """
        seq_len = q.shape[1]
        self._update_cache(seq_len, q.device)

        # Apply rotation
        q_embed = (q * self._cos_cached[:, :seq_len]) + (
            self.rotate_half(q) * self._sin_cached[:, :seq_len]
        )
        k_embed = (k * self._cos_cached[:, :seq_len]) + (
            self.rotate_half(k) * self._sin_cached[:, :seq_len]
        )

        return q_embed, k_embed


# Test RoPE
rope = RotaryPositionalEmbedding(dim=64)
q = torch.randn(1, 10, 8, 64)  # (batch, seq_len, num_heads, d_k)
k = torch.randn(1, 10, 8, 64)

q_rot, k_rot = rope(q, k)
print(f"Query shape: {q_rot.shape}")
print(f"Key shape: {k_rot.shape}")
print("Rotary embeddings applied successfully!")

## Part 4: Feed-Forward Networks

After attention, transformers apply a position-wise feed-forward network to each token independently.

Structure:
```
FFN(x) = activation(x @ W1 + b1) @ W2 + b2
```

Typically:
- First layer expands dimension (e.g., 512 → 2048)
- Activation function (GELU, ReLU, etc.)
- Second layer projects back (e.g., 2048 → 512)
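
"Position-wise" means the same weights are applied to every token separately. Here is a minimal plain-Python sketch with toy weights (`d_model = 2`, `d_ff = 4`, all values made up for illustration). It uses ReLU rather than GELU so the result is easy to verify by hand.

```python
def relu(v):
    """Element-wise ReLU."""
    return [max(0.0, a) for a in v]


def linear(v, weight, bias):
    """Compute v @ weight + bias for a vector v and a weight matrix stored as rows."""
    return [
        sum(v[i] * weight[i][j] for i in range(len(v))) + bias[j]
        for j in range(len(bias))
    ]


def ffn(token, w1, b1, w2, b2):
    """Expand, apply the non-linearity, then project back to the input size."""
    return linear(relu(linear(token, w1, b1)), w2, b2)


# Toy weights: 2 -> 4 -> 2
w1 = [[1.0, -1.0, 0.5, 0.0], [0.0, 1.0, -0.5, 2.0]]
b1 = [0.0, 0.0, 0.0, -1.0]
w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
b2 = [0.0, 0.0]

# Positions 0 and 2 hold the same vector
sequence = [[1.0, 2.0], [3.0, -1.0], [1.0, 2.0]]
outputs = [ffn(tok, w1, b1, w2, b2) for tok in sequence]
print(outputs)
```

Positions 0 and 2 produce identical outputs because the network never looks at neighbouring tokens; mixing information across positions is left to attention.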

In [None]:
class FeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network.
    """

    def __init__(self, d_model, d_ff, dropout=0.1, activation="gelu"):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

        # Choose activation function
        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation: {activation}")

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = self.fc1(x)  # (batch, seq_len, d_ff)
        x = self.activation(x)  # Apply non-linearity
        x = self.dropout(x)
        x = self.fc2(x)  # (batch, seq_len, d_model)
        return x


# Test
ffn = FeedForward(d_model=512, d_ff=2048)
x = torch.randn(2, 10, 512)
output = ffn(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Parameters: {sum(p.numel() for p in ffn.parameters()):,}")

## Part 5: Building a Complete Transformer Block

Now let's combine everything into a complete transformer decoder block:

1. Multi-head self-attention (with causal mask)
2. Add & Norm (residual connection + layer normalization)
3. Feed-forward network
4. Add & Norm
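
The "Add & Norm" step can be sketched in plain Python for a single token vector. This version leaves out the learned scale and shift that `nn.LayerNorm` applies by default, and the input values are made up for illustration.

```python
import math


def layer_norm(v, eps=1e-5):
    """Normalize a vector to zero mean and (approximately) unit variance."""
    mean = sum(v) / len(v)
    var = sum((a - mean) ** 2 for a in v) / len(v)
    return [(a - mean) / math.sqrt(var + eps) for a in v]


def add_and_norm(x, sublayer_out):
    """Residual connection followed by layer normalization."""
    return layer_norm([a + b for a, b in zip(x, sublayer_out)])


x = [1.0, 2.0, 3.0, 4.0]  # toy token vector
sublayer_out = [0.5, -0.5, 1.0, -1.0]  # stand-in for an attention or FFN output

y = add_and_norm(x, sublayer_out)
print([round(a, 3) for a in y])
```

The residual path lets the input pass through unchanged, while normalization keeps the scale of the activations stable from block to block.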

In [None]:
class TransformerBlock(nn.Module):
    """
    A single transformer decoder block.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # Multi-head attention
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)

        # Feed-forward network
        self.ffn = FeedForward(d_model, d_ff, dropout)

        # Layer normalization
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Forward pass through transformer block.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            mask: Optional causal mask
        """
        # 1. Multi-head attention with residual connection
        attn_output, attn_weights = self.attention(x, x, x, mask)
        x = self.ln1(x + self.dropout(attn_output))  # Add & Norm

        # 2. Feed-forward with residual connection
        ffn_output = self.ffn(x)
        x = self.ln2(x + self.dropout(ffn_output))  # Add & Norm

        return x, attn_weights


# Test a complete block
block = TransformerBlock(d_model=512, num_heads=8, d_ff=2048)
x = torch.randn(2, 10, 512)
mask = create_causal_mask(10).unsqueeze(0).expand(2, -1, -1)

output, attn = block(x, mask)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nTotal parameters: {sum(p.numel() for p in block.parameters()):,}")

## Part 6: Decoder-Only Architecture (GPT-Style)

Our Storyteller model uses a **decoder-only** architecture, like GPT. Let's understand why:

### Encoder-Decoder vs. Decoder-Only

**Encoder-Decoder** (original Transformer, T5):
- Encoder processes input without causal masking
- Decoder generates output with cross-attention to encoder
- Good for: translation, summarization (input → different output)

**Decoder-Only** (GPT, LLaMA, our model):
- Single stack of decoder blocks
- Always uses causal masking
- Good for: text generation, completion, few-shot learning

### Why Decoder-Only for Story Generation?

1. **Autoregressive generation**: Naturally suited for sequential text
2. **Simplicity**: Fewer components to train
3. **Flexibility**: Can handle various tasks via prompting
4. **Scalability**: Easier to scale to billions of parameters
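
Autoregressive generation boils down to a simple loop: predict one token, append it, and repeat. The sketch below uses a hypothetical `toy_next_token` function in place of a real language model; with an actual model, that function would pick (or sample) a token from the logits at the last position.

```python
def generate(next_token_fn, prompt, max_new_tokens):
    """Greedy autoregressive loop: each new token is predicted from all previous ones."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(next_token_fn(tokens))
    return tokens


def toy_next_token(tokens):
    """Hypothetical stand-in for a language model: predicts (last token + 1) mod 10."""
    return (tokens[-1] + 1) % 10


print(generate(toy_next_token, prompt=[7, 8], max_new_tokens=4))
# prints [7, 8, 9, 0, 1, 2]
```

Because every step only ever looks backwards, the causal mask used during training matches exactly what the model sees at generation time.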

In [None]:
class SimpleGPT(nn.Module):
    """
    A simplified GPT-style decoder-only language model.
    """

    def __init__(
        self,
        vocab_size,
        d_model=512,
        num_layers=6,
        num_heads=8,
        d_ff=2048,
        max_seq_len=1024,
        dropout=0.1,
    ):
        super().__init__()

        # Token embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Positional embeddings (learnable)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)

        # Transformer blocks
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(d_model, num_heads, d_ff, dropout)
                for _ in range(num_layers)
            ]
        )

        # Final layer norm
        self.ln_f = nn.LayerNorm(d_model)

        # Output projection to vocabulary
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        self.max_seq_len = max_seq_len

    def forward(self, input_ids, labels=None):
        """
        Forward pass.

        Args:
            input_ids: Token indices (batch, seq_len)
            labels: Optional target tokens for loss computation
        """
        batch_size, seq_len = input_ids.size()

        # Create position indices
        positions = torch.arange(0, seq_len, device=input_ids.device).unsqueeze(0)

        # Embed tokens and positions
        token_emb = self.token_embedding(input_ids)  # (batch, seq_len, d_model)
        pos_emb = self.pos_embedding(positions)  # (1, seq_len, d_model)

        # Combine embeddings
        x = self.dropout(token_emb + pos_emb)

        # Create causal mask
        mask = create_causal_mask(seq_len).to(input_ids.device)
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)

        # Pass through transformer blocks
        attention_maps = []
        for block in self.blocks:
            x, attn = block(x, mask)
            attention_maps.append(attn)

        # Final layer norm
        x = self.ln_f(x)

        # Project to vocabulary
        logits = self.lm_head(x)  # (batch, seq_len, vocab_size)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            # Shift so predictions align with next token
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            # Flatten and compute cross-entropy
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        return {"logits": logits, "loss": loss, "attention_maps": attention_maps}


# Create a small model
model = SimpleGPT(
    vocab_size=10000, d_model=256, num_layers=4, num_heads=4, d_ff=1024, max_seq_len=128
)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Model created with {total_params:,} parameters")
print(f"Model size: ~{total_params * 4 / 1e6:.1f} MB (float32)")
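
The loss in `forward` shifts logits and labels by one position so that each position is scored on the token that follows it. A small plain-Python illustration of that alignment, using a made-up four-token sequence:

```python
tokens = ["The", "cat", "sat", "down"]  # toy sequence

# Position t predicts token t + 1, so drop the last prediction and the first label
prediction_positions = tokens[:-1]  # mirrors logits[:, :-1, :]
targets = tokens[1:]  # mirrors labels[:, 1:]

for context_end, target in zip(prediction_positions, targets):
    print(f"context ends at {context_end!r} -> target {target!r}")
```

This is why the same tensor can be passed as both `input_ids` and `labels`: the shift happens inside the model.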

### Testing the Model

Let's do a forward pass with some dummy tokens:

In [None]:
# Create dummy input
batch_size = 4
seq_len = 20
input_ids = torch.randint(0, 10000, (batch_size, seq_len))

# Forward pass (eval mode disables dropout, so the outputs are deterministic)
model.eval()
with torch.no_grad():
    outputs = model(input_ids, labels=input_ids)

print(f"Input shape: {input_ids.shape}")
print(f"Output logits shape: {outputs['logits'].shape}")
print(f"Loss: {outputs['loss'].item():.4f}")
print(f"Number of attention maps: {len(outputs['attention_maps'])}")

# Look at the next-token predictions made at the first position
first_token_logits = outputs["logits"][0, 0]  # First batch, first position
first_token_probs = F.softmax(first_token_logits, dim=-1)  # Normalize over the full vocabulary
top5_tokens = torch.topk(first_token_probs, k=5)
print(f"\nTop 5 predicted token IDs for first position: {top5_tokens.indices.tolist()}")
print(f"Their probabilities: {top5_tokens.values.tolist()}")
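
A useful sanity check for the loss printed above: a model that assigns equal probability to every token has a cross-entropy of $\ln(\text{vocab\_size})$. An untrained model is not exactly uniform, so its loss will usually be close to this value, often slightly above it, rather than equal to it.

```python
import math

vocab_size = 10000

# A uniform predictor gives every token probability 1 / vocab_size
uniform_loss = -math.log(1.0 / vocab_size)
print(f"{uniform_loss:.4f}")
```

If the reported loss is far from this baseline before any training, that usually points to a bug in the label shifting or the loss computation.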

## Summary and Key Takeaways

In this notebook, you learned:

1. **Attention Mechanism**:
   - Query, Key, Value paradigm
   - Scaled dot-product attention
   - Causal masking for autoregressive generation

2. **Multi-Head Attention**:
   - Multiple parallel attention "heads"
   - Different heads learn different patterns
   - Richer representation of relationships

3. **Positional Encoding**:
   - Why we need it (attention alone cannot see token order)
   - Rotary embeddings (RoPE) for relative positions

4. **Transformer Architecture**:
   - Feed-forward networks
   - Residual connections and layer normalization
   - Complete transformer blocks

5. **Decoder-Only Models**:
   - GPT-style architecture
   - Autoregressive text generation
   - Why it works for story generation

### What's Next?

In the next notebook, we'll explore **Mixture of Experts (MoE)** - a technique to scale our model efficiently by using sparse expert routing.

### Further Reading

- [Attention Is All You Need](https://arxiv.org/abs/1706.03762) (Vaswani et al., 2017)
- [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) (Su et al., 2021)
- [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/) (Jay Alammar)
- [Language Models are Few-Shot Learners](https://arxiv.org/abs/2005.14165) (Brown et al., 2020)

## Exercise: Build Your Own Attention Variant

Try implementing these variations:

1. **Sliding Window Attention**: Limit attention to nearby tokens only
2. **Cross-Attention**: Attention between two different sequences
3. **Relative Position Bias**: Add learnable biases based on distance

In [None]:
# Your code here!

# Example: Sliding window attention
def create_sliding_window_mask(seq_len, window_size):
    """
    Create a mask that only allows attention within a sliding window.

    Args:
        seq_len: Sequence length
        window_size: Size of attention window on each side
    """
    # TODO: Implement this!
    pass


# Test your implementation
# mask = create_sliding_window_mask(10, window_size=2)
# plot_attention(mask.unsqueeze(0), tokens=[f"T{i}" for i in range(10)])