# Transformers - Advanced Topics

This notebook extends your basic transformer with:
1. Tokenization (BPE)
2. Sinusoidal positional encodings
3. Training loop
4. Text generation strategies

**Your notes here:**


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import re
from collections import Counter
import matplotlib.pyplot as plt

---
## Part 1: Tokenization - How Text Becomes Numbers

Before transformers can process text, we need to convert it to integers.

**Your notes here:**


In [None]:
class SimpleTokenizer:
    """
    Minimal word-level tokenizer: lowercased words and punctuation marks
    are mapped to integer IDs.

    Real systems normally use a subword scheme instead:
    - BPE (Byte Pair Encoding) - GPT
    - WordPiece - BERT
    - SentencePiece - T5, LLaMA

    IDs 0-3 are reserved for <PAD>, <UNK>, <BOS> and <EOS>; regular
    vocabulary entries start at ID 4.
    """
    def __init__(self):
        # Both mappings stay empty until build_vocab() is called.
        self.word_to_id = {}
        self.id_to_word = {}
        self.vocab_size = 0

    def _split(self, text):
        """Lowercase `text` and split it into word and punctuation tokens."""
        return re.findall(r'\w+|[^\w\s]', text.lower())

    def build_vocab(self, texts, max_vocab=1000):
        """
        Build the vocabulary from an iterable of strings.

        The four special tokens are always included; the remaining
        `max_vocab - 4` slots go to the most frequent tokens in `texts`.
        Prints the resulting vocabulary size.
        """
        frequencies = Counter()
        for text in texts:
            frequencies.update(self._split(text))

        # Special tokens occupy the first four IDs.
        mapping = {
            '<PAD>': 0,   # padding token
            '<UNK>': 1,   # unknown token
            '<BOS>': 2,   # beginning of sequence
            '<EOS>': 3    # end of sequence
        }
        ranked = frequencies.most_common(max_vocab - 4)
        for token_id, (word, _) in enumerate(ranked, start=4):
            mapping[word] = token_id

        self.word_to_id = mapping
        self.id_to_word = dict(zip(mapping.values(), mapping.keys()))
        self.vocab_size = len(mapping)

        print(f"Built vocabulary of {self.vocab_size} tokens")

    def encode(self, text):
        """Return the list of token IDs for `text`; unseen tokens map to <UNK>."""
        lookup = self.word_to_id
        return [lookup.get(token, lookup['<UNK>']) for token in self._split(text)]

    def decode(self, ids):
        """Return the space-joined tokens for `ids`; unknown IDs become '<UNK>'."""
        return ' '.join(self.id_to_word.get(token_id, '<UNK>') for token_id in ids)

# Demo: build a tiny vocabulary and round-trip one sentence through it.
# `texts` is also reused later as the training corpus for the BPE demo.
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "The dog was lazy but the fox was quick.",
    "A brown fox and a brown dog."
]

tokenizer = SimpleTokenizer()
# max_vocab=50 includes the four special tokens.
tokenizer.build_vocab(texts, max_vocab=50)

sample_text = "The fox jumps"
encoded = tokenizer.encode(sample_text)
# encode() lowercases and decode() joins tokens with single spaces, so the
# decoded string is the lowercased form rather than the exact original.
decoded = tokenizer.decode(encoded)

print(f"\nOriginal: {sample_text}")
print(f"Encoded: {encoded}")
print(f"Decoded: {decoded}")

### Byte Pair Encoding (BPE) - Used in GPT

BPE learns subword units. Why?
- Unknown words: "unhappiness" → "un" + "happiness"
- Rare words: Split into known pieces
- Efficient vocabulary: Balance between character and word level

**Your notes here:**


In [None]:
class SimpleBPE:
    """
    Simplified Byte Pair Encoding.

    Algorithm:
    1. Start with character vocabulary
    2. Find most frequent pair of adjacent tokens
    3. Merge that pair into a new token
    4. Repeat until desired vocab size

    Words are represented as tuples of symbols ending in the marker
    '</w>', so merges never cross a word boundary.
    """
    def __init__(self, num_merges=100):
        self.num_merges = num_merges  # upper bound on the number of merges learned
        self.merges = {}  # (a, b) -> ab, in the order the merges were learned
        self.vocab = set()  # symbols present in the corpus after the final merge

    def get_pairs(self, word):
        """
        Return the set of distinct adjacent symbol pairs in `word`.

        A word with fewer than two symbols has no pairs and yields an
        empty set.
        """
        return set(zip(word, word[1:]))

    def train(self, texts):
        """
        Learn up to `num_merges` BPE merges from an iterable of strings.

        Each call starts from scratch: `merges` and `vocab` are reset, so
        merges learned from an earlier corpus are not mixed in. Progress is
        printed every 20 merges, followed by the final vocabulary size.
        """
        self.merges = {}
        self.vocab = set()

        # Word (as a tuple of characters + end marker) -> corpus frequency.
        vocab = {}
        for text in texts:
            for word in text.lower().split():
                word_chars = tuple(word) + ('</w>',)  # end of word marker
                vocab[word_chars] = vocab.get(word_chars, 0) + 1

        for i in range(self.num_merges):
            # Count every occurrence of every adjacent pair. Walking the
            # symbols in order (rather than through a set) counts a pair
            # repeated inside one word each time it occurs, and keeps the
            # dict insertion order, and therefore tie-breaking in max(),
            # the same from run to run.
            pair_counts = {}
            for word, freq in vocab.items():
                for pair in zip(word, word[1:]):
                    pair_counts[pair] = pair_counts.get(pair, 0) + freq

            # Every word is already a single symbol: nothing left to merge.
            if not pair_counts:
                break

            # Most frequent pair; ties go to the pair seen first.
            best_pair = max(pair_counts, key=pair_counts.get)
            merged = best_pair[0] + best_pair[1]
            self.merges[best_pair] = merged

            # Re-segment every word with the new merge applied.
            new_vocab = {}
            for word, freq in vocab.items():
                new_word = self.merge_pair(word, best_pair)
                new_vocab[new_word] = new_vocab.get(new_word, 0) + freq
            vocab = new_vocab

            if i % 20 == 0:
                print(f"Merge {i}: {best_pair} -> {merged}")

        # Collect the symbols that remain after the final merge.
        for word in vocab:
            self.vocab.update(word)

        print(f"\nFinal vocab size: {len(self.vocab)}")

    def merge_pair(self, word, pair):
        """
        Return `word` as a tuple with each left-to-right, non-overlapping
        occurrence of `pair` replaced by the concatenated symbol.
        """
        new_word = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and (word[i], word[i+1]) == pair:
                new_word.append(pair[0] + pair[1])
                i += 2  # skip both symbols of the merged pair
            else:
                new_word.append(word[i])
                i += 1
        return tuple(new_word)

    def encode_word(self, word):
        """
        Split `word` into subword symbols using the learned merges.

        The word is broken into characters plus the '</w>' marker, then
        every merge is applied in the order it was learned. Returns a list
        of symbols.
        """
        word = tuple(word) + ('</w>',)

        # dicts preserve insertion order, so this replays training order.
        for pair in self.merges:
            word = self.merge_pair(word, pair)

        return list(word)

# Demo BPE: learn merges from the `texts` corpus defined in the tokenizer demo.
bpe = SimpleBPE(num_merges=50)
bpe.train(texts)

print("\nEncoding examples:")
# "jumping" does not occur in `texts` (only "jumps" does), so it shows how
# a word unseen during training is split into learned pieces.
for word in ["quick", "lazy", "jumping"]:
    encoded = bpe.encode_word(word)
    print(f"{word:10s} -> {encoded}")

---
## Part 2: Positional Encodings - Sinusoidal vs Learned

Transformers have no inherent notion of order. We must inject position information.

**Your notes here:**


In [None]:
def sinusoidal_position_encoding(max_len, d_model):
    """
    Build the sinusoidal positional encoding of the original transformer paper.

    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

    Why sine/cosine?
    - Model can learn to attend to relative positions
    - Generalizes to longer sequences than seen in training
    - No learned parameters

    Args:
        max_len: number of positions (rows) to encode.
        d_model: embedding width (columns); may be even or odd.

    Returns:
        Float tensor of shape (max_len, d_model). Even columns hold the
        sine terms and odd columns the cosine terms.
    """
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1).float()

    # 1 / 10000^(2i/d_model), computed in log space; one entry per even column.
    div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                        -(math.log(10000.0) / d_model))
    angles = position * div_term  # (max_len, ceil(d_model / 2))

    pe[:, 0::2] = torch.sin(angles)
    # With an odd d_model there is one fewer odd column than even columns,
    # so drop the last frequency to keep the shapes aligned.
    pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

    return pe

# Visualize sinusoidal positional encoding
# `max_len`, `d_model` and `pe` are module-level and reused by the next cell.
max_len = 100
d_model = 128

pe = sinusoidal_position_encoding(max_len, d_model)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Heatmap: one row per position, one column per embedding dimension
im = axes[0].imshow(pe.numpy(), aspect='auto', cmap='RdBu')
axes[0].set_xlabel('Embedding Dimension')
axes[0].set_ylabel('Position')
axes[0].set_title('Sinusoidal Positional Encoding')
plt.colorbar(im, ax=axes[0])

# Individual positions: first 64 dimensions of a few selected rows
for pos in [0, 10, 20, 50]:
    axes[1].plot(pe[pos, :64].numpy(), label=f'Position {pos}', alpha=0.7)
axes[1].set_xlabel('Dimension')
axes[1].set_ylabel('Value')
axes[1].set_title('Encoding for Different Positions')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Notice: Each position has a unique pattern")
# The frequency term 1 / 10000^(2i/d_model) shrinks as the dimension index
# grows, so the lowest dimensions change fastest from one position to the next.
print("Low dimensions = fast oscillation, High dimensions = slow oscillation")

In [None]:
# Compare sinusoidal vs learned embeddings
# Reuses `pe`, `max_len` and `d_model` from the visualization cell.
# Learned (what we used in your transformer)
# This embedding is freshly initialized and only used to count parameters.
learned_pe = nn.Embedding(max_len, d_model)

print(f"Sinusoidal PE shape: {pe.shape}")
print(f"Sinusoidal: 0 parameters, generalizes to unseen lengths\n")

# Parameter count read from the module, then the same figure as a formula.
print(f"Learned PE params: {sum(p.numel() for p in learned_pe.parameters()):,}")
print(f"Learned: {max_len * d_model:,} parameters, fixed max length")

# Markdown-style summary table, printed as plain text.
print("\n| Type | Pros | Cons |")
print("|------|------|------|")
print("| Sinusoidal | No params, extrapolates | Fixed formula |")
print("| Learned | Flexible, data-driven | Doesn't extrapolate, needs training |")

---
## Part 3: Training Loop - Actually Train the Transformer

Let's train on a simple task: learning to copy sequences.

**Your notes here:**


In [None]:
# Import your transformer from the main notebook
# For this demo, we'll use a simplified version

from transformers import TwoLayerTransformer  # assumes you saved it

# Or copy-paste the class here if needed

In [None]:
# Create toy dataset: learn to repeat sequences
def create_copy_dataset(n_samples=1000, seq_len=10, vocab_size=20):
    """
    Generate random token sequences for the toy copy task.

    Simple task: given [1, 5, 3, ...], predict [1, 5, 3, ...]

    Args:
        n_samples: number of sequences (rows); 0 yields an empty dataset.
        seq_len: number of tokens per sequence.
        vocab_size: exclusive upper bound on token IDs. IDs are drawn from
            1 to vocab_size - 1, so ID 0 never appears in the data.

    Returns:
        Integer tensor of shape (n_samples, seq_len).
    """
    # A single draw replaces the per-sample loop followed by torch.stack,
    # which raised on an empty list when n_samples was 0.
    return torch.randint(1, vocab_size, (n_samples, seq_len))

# Generate data
# `vocab_size` and `seq_len` are module-level; the model and training cells
# below read them as well.
vocab_size = 50
seq_len = 16
# Token IDs are drawn from 1..vocab_size-1, so ID 0 never occurs in train_data.
train_data = create_copy_dataset(n_samples=500, seq_len=seq_len, vocab_size=vocab_size)

print(f"Training data shape: {train_data.shape}")
print(f"Sample sequence: {train_data[0].tolist()}")

In [None]:
# Initialize model
# TwoLayerTransformer is defined outside this notebook; the keyword meanings
# below are presumed from their names -- confirm against its definition.
model = TwoLayerTransformer(
    vocab_size=vocab_size,
    d_model=128,
    num_heads=4,
    d_ff=512,
    max_seq_len=seq_len,  # set to exactly the training sequence length
    dropout=0.1
)

# Training setup
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")

In [None]:
# Training loop
# Mini-batch training over `train_data`; `losses` collects one average loss
# per epoch and is plotted in the next cell.
model.train()
losses = []
batch_size = 32
n_epochs = 100

for epoch in range(n_epochs):
    total_loss = 0
    n_batches = 0
    
    # Batches are consecutive slices taken in the same order every epoch (no
    # shuffling); the last slice is shorter when len(train_data) is not a
    # multiple of batch_size.
    for i in range(0, len(train_data), batch_size):
        batch = train_data[i:i+batch_size]
        
        # Forward pass
        logits = model(batch)  # (batch, seq_len, vocab_size)
        
        # Loss: predict next token at each position
        # Teacher forcing: use ground truth for input
        # NOTE(review): the targets are the inputs shifted by one, i.e. this is
        # next-token prediction. create_copy_dataset draws every token
        # independently at random, so the next token cannot be inferred from
        # the prefix -- confirm this is the intended "copy" objective and
        # whether TwoLayerTransformer applies a causal mask.
        loss = criterion(
            logits[:, :-1].reshape(-1, vocab_size),  # predictions for positions 0 to n-1
            batch[:, 1:].reshape(-1)                  # targets: tokens at positions 1 to n
        )
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        
        # Gradient clipping (prevents exploding gradients)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        total_loss += loss.item()
        n_batches += 1
    
    # Unweighted mean over batches (the shorter last batch counts the same).
    avg_loss = total_loss / n_batches
    losses.append(avg_loss)
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch:3d}: Loss = {avg_loss:.4f}")

print("\nTraining complete!")

In [None]:
# Plot training curve
# `losses` holds one average loss per epoch from the training loop above.
plt.figure(figsize=(10, 5))
plt.plot(losses, linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Transformer Training on Copy Task')
plt.grid(True, alpha=0.3)
plt.show()

print(f"Final loss: {losses[-1]:.4f}")

In [None]:
# Test the trained model
# The sequence checked here is the first training example, so this measures
# fit to the training data rather than generalization.
model.eval()
test_seq = train_data[0]

with torch.no_grad():
    logits = model(test_seq.unsqueeze(0))  # add a batch dimension of 1
    predictions = torch.argmax(logits, dim=-1).squeeze()

print("\nTest on training example:")
print(f"Input:      {test_seq.tolist()}")
print(f"Predicted:  {predictions.tolist()}")
# The target is the input shifted left by one; the trailing 0 only pads the
# printed list to the same length as the prediction.
print(f"Target:     {test_seq.tolist()[1:] + [0]}")

# Calculate accuracy
# The last prediction has no ground-truth next token, so it is left out.
correct = (predictions[:-1] == test_seq[1:]).sum().item()
total = len(test_seq) - 1
print(f"\nAccuracy: {correct}/{total} = {100*correct/total:.1f}%")

---
## Part 4: Generation Strategies

How to generate text from a trained transformer.

**Your notes here:**


In [None]:
def generate_greedy(model, start_tokens, max_len=20, pad_token=0):
    """
    Greedy decoding: extend the sequence with the highest-scoring token
    at every step.

    Simple and deterministic, but tends to produce repetitive output.

    Args:
        model: callable taking a (1, seq_len) token tensor and returning
            logits indexed as (batch, position, vocab); it is switched to
            eval mode and left there.
        start_tokens: 1-D tensor of prompt token IDs (not modified).
        max_len: total length of the returned sequence, prompt included.
        pad_token: accepted for interface compatibility; not used.

    Returns:
        1-D tensor holding the prompt followed by the generated tokens.
    """
    model.eval()
    sequence = start_tokens.clone()

    with torch.no_grad():
        n_new_tokens = max_len - len(sequence)
        for _ in range(n_new_tokens):
            # Re-run the model on the whole sequence; only the scores at
            # the final position decide the next token.
            all_logits = model(sequence.unsqueeze(0))  # (1, seq_len, vocab)
            best_token = all_logits[0, -1, :].argmax()
            sequence = torch.cat([sequence, best_token.unsqueeze(0)])

    return sequence

In [None]:
def generate_temperature(model, start_tokens, max_len=20, temperature=1.0):
    """
    Temperature sampling:
    - temperature > 1: more random (more diversity)
    - temperature < 1: more confident (less diversity)
    - temperature = 1: sample from true distribution
    - temperature = 0: greedy decoding (the limit of the formula as T -> 0)

    Formula: p_i = exp(logit_i / T) / sum(exp(logit_j / T))

    Args:
        model: callable taking a (1, seq_len) token tensor and returning
            logits indexed as (batch, position, vocab); it is switched to
            eval mode and left there.
        start_tokens: 1-D tensor of prompt token IDs (not modified).
        max_len: total length of the returned sequence, prompt included.
        temperature: non-negative scaling factor applied to the logits.

    Returns:
        1-D tensor holding the prompt followed by the sampled tokens.

    Raises:
        ValueError: if `temperature` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")

    model.eval()
    tokens = start_tokens.clone()

    with torch.no_grad():
        for _ in range(max_len - len(tokens)):
            logits = model(tokens.unsqueeze(0))
            next_token_logits = logits[0, -1, :]

            if temperature == 0:
                # Dividing by zero would turn the logits into inf/nan and
                # make sampling fail; pick the most likely token instead.
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            else:
                # Sample from the temperature-scaled distribution
                probs = F.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            tokens = torch.cat([tokens, next_token])

    return tokens

In [None]:
def generate_top_k(model, start_tokens, max_len=20, k=10):
    """
    Top-k sampling: only sample from k most likely tokens.

    Prevents sampling very unlikely tokens (reduces nonsense).
    Used in GPT-2.

    Args:
        model: callable taking a (1, seq_len) token tensor and returning
            logits indexed as (batch, position, vocab); it is switched to
            eval mode and left there.
        start_tokens: 1-D tensor of prompt token IDs (not modified).
        max_len: total length of the returned sequence, prompt included.
        k: number of candidates kept per step. Values larger than the
            vocabulary are clamped to the vocabulary size.

    Returns:
        1-D tensor holding the prompt followed by the sampled tokens.

    Raises:
        ValueError: if `k` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    model.eval()
    tokens = start_tokens.clone()

    with torch.no_grad():
        for _ in range(max_len - len(tokens)):
            logits = model(tokens.unsqueeze(0))
            next_token_logits = logits[0, -1, :]

            # torch.topk rejects k larger than the dimension size, so cap it
            # at the vocabulary size (equivalent to plain sampling).
            effective_k = min(k, next_token_logits.size(-1))
            top_k_logits, top_k_indices = torch.topk(next_token_logits, effective_k)

            # Renormalize over the survivors and sample one of them
            probs = F.softmax(top_k_logits, dim=-1)
            next_token_idx = torch.multinomial(probs, num_samples=1)
            # Map the position within the top-k list back to a vocabulary ID
            next_token = top_k_indices[next_token_idx]

            tokens = torch.cat([tokens, next_token])

    return tokens

In [None]:
def generate_top_p(model, start_tokens, max_len=20, p=0.9):
    """
    Top-p (nucleus) sampling: draw the next token from the smallest set of
    most-likely tokens whose cumulative probability reaches p.

    Used in GPT-3, ChatGPT.
    Adaptive: the candidate set is small when the model is confident and
    grows when it is uncertain.

    Args:
        model: callable taking a (1, seq_len) token tensor and returning
            logits indexed as (batch, position, vocab); it is switched to
            eval mode and left there.
        start_tokens: 1-D tensor of prompt token IDs (not modified).
        max_len: total length of the returned sequence, prompt included.
        p: cumulative-probability threshold for the candidate set.

    Returns:
        1-D tensor holding the prompt followed by the sampled tokens.
    """
    model.eval()
    sequence = start_tokens.clone()

    with torch.no_grad():
        for _ in range(max_len - len(sequence)):
            last_logits = model(sequence.unsqueeze(0))[0, -1, :]

            # Rank the vocabulary from most to least likely.
            ranked_probs, ranked_ids = torch.sort(
                F.softmax(last_logits, dim=-1), descending=True
            )
            running_total = torch.cumsum(ranked_probs, dim=0)

            # Mark tokens lying beyond the point where the running total
            # passes p.
            drop = running_total > p
            if not drop[0]:
                # Shift the mask right by one so the token that crosses the
                # threshold is itself kept.
                drop[1:] = drop[:-1].clone()
            # The top-ranked token always survives.
            drop[0] = False

            # Zero the dropped tokens and renormalize what is left.
            ranked_probs[drop] = 0
            ranked_probs = ranked_probs / ranked_probs.sum()

            # Sample a rank, then translate it back to a vocabulary ID.
            chosen_rank = torch.multinomial(ranked_probs, num_samples=1)
            sequence = torch.cat([sequence, ranked_ids[chosen_rank]])

    return sequence

In [None]:
# Demo all sampling methods: run every strategy from the same prompt.
start = torch.tensor([1, 5, 8, 3])

print("Starting tokens:", start.tolist())
print("\nGeneration methods:")
print("-" * 60)

# Each entry pairs the printed label with a deferred call, so the strategies
# run one at a time, in this order, as their line is printed.
demo_runs = [
    ("Greedy:           ", lambda: generate_greedy(model, start, max_len=12)),
    ("Temp=0.5 (safe):  ", lambda: generate_temperature(model, start, max_len=12, temperature=0.5)),
    ("Temp=1.0 (std):   ", lambda: generate_temperature(model, start, max_len=12, temperature=1.0)),
    ("Temp=1.5 (wild):  ", lambda: generate_temperature(model, start, max_len=12, temperature=1.5)),
    ("Top-k (k=5):      ", lambda: generate_top_k(model, start, max_len=12, k=5)),
    ("Top-p (p=0.9):    ", lambda: generate_top_p(model, start, max_len=12, p=0.9)),
]
for label, run in demo_runs:
    print(f"{label}{run().tolist()}")

### Sampling Strategy Comparison

| Method | Description | Pros | Cons | Used In |
|--------|-------------|------|------|--------|
| Greedy | Pick argmax | Fast, deterministic | Repetitive, no diversity | Debugging |
| Temperature | Scale logits | Controls randomness | Samples bad tokens | Tuning creativity |
| Top-k | Sample from k best | Filters unlikely tokens | Fixed cutoff | GPT-2 |
| Top-p | Sample from cumulative p | Adaptive cutoff | More complex | GPT-3, ChatGPT |

**Your notes here:**


---
## Summary

You now understand:
1. **Tokenization**: Word-level and BPE (subword)
2. **Positional encodings**: Sinusoidal (generalizes) vs Learned (flexible)
3. **Training**: Cross-entropy loss, teacher forcing, gradient clipping
4. **Generation**: Greedy, temperature, top-k, top-p sampling

**Next steps for building your own library:**
- Implement encoder-decoder architecture (for translation)
- Add Pre-LN vs Post-LN (layer norm placement)
- Implement KV caching (for efficient generation)
- Understand Flash Attention (memory-efficient attention)
- Study RoPE (Rotary Position Embeddings) - used in LLaMA

**Your notes here:**
