# Part 8.1: Tokenization & Language Model Training — The Formula 1 Edition

Before an LLM can process text, it must convert characters into numbers. This seemingly mundane step — **tokenization** — turns out to be one of the most consequential design decisions in modern AI. The choice of tokenizer affects model performance, multilingual capabilities, inference cost, and even what the model can learn.

**F1 analogy:** Think of tokenization as how an F1 team encodes telemetry data. Raw sensor readings arrive as continuous streams — throttle position, brake pressure, tire temperatures, GPS coordinates. Before the pit wall can analyze them, these signals must be discretized into a vocabulary of meaningful patterns. Just as BPE discovers common character sequences and merges them into tokens, an F1 data system discovers common telemetry patterns (e.g., "hard braking into a slow corner" or "DRS activation on a straight") and encodes them as reusable units. The vocabulary size is a tradeoff: too small and you miss rare but critical events (a one-off sensor anomaly), too large and the system becomes unwieldy — just like choosing between 32K and 100K tokens.

In this notebook, we'll build tokenizers from scratch, understand how language models are trained, and explore the **scaling laws** that govern how model performance improves with size, data, and compute.

## Learning Objectives

- [ ] Understand why tokenization matters and its impact on model behavior
- [ ] Implement Byte Pair Encoding (BPE) from scratch
- [ ] Build a WordPiece tokenizer and compare with BPE
- [ ] Understand pretraining objectives: causal LM vs masked LM
- [ ] Implement a small language model training loop from scratch
- [ ] Explore scaling laws: how performance relates to model size, data, and compute
- [ ] Understand training dynamics: loss curves, learning rate schedules, warm-up

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from collections import defaultdict, Counter
import re
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

np.random.seed(42)
torch.manual_seed(42)

print("Part 8.1: Tokenization & Language Model Training — The Formula 1 Edition")
print("=" * 72)

---

## 1. Why Tokenization Matters

Models don't see text — they see sequences of integer IDs. The tokenizer defines this mapping.

| Approach | Example: "unhappiness" | Vocab Size | Trade-off | F1 Parallel |
|----------|----------------------|------------|----------|-------------|
| **Character** | u, n, h, a, p, p, i, n, e, s, s | ~256 | Small vocab, long sequences | Recording every individual sensor tick: precise but overwhelming |
| **Word** | unhappiness | ~100K+ | Short sequences, huge vocab, OOV problem | One code per entire maneuver: compact but can't represent novel situations |
| **Subword** | un, happi, ness | ~32K-100K | Moderate vocab and sequence length, no OOV | Encoding reusable telemetry patterns such as "brake-turn-in" and "apex-throttle" |

### Why Subword Tokenization Won

1. **No OOV**: Any word can be represented by composing subwords, provided the base vocabulary covers every character (or byte) that can appear
2. **Shared morphology**: "unhappy" and "happiness" share subwords
3. **Efficient**: Common words stay whole, rare words decompose
4. **Multilingual**: Works across languages with shared scripts
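
To make the first three points concrete, here is a minimal sketch contrasting a word-level lookup with a subword segmentation. The two vocabularies are hypothetical and chosen only for illustration, and `greedy_subword` uses greedy longest-match, which is one simple segmentation strategy and not the BPE procedure built later in this notebook.

```python
# Hypothetical toy vocabularies, invented for this illustration.
word_vocab = {"happy", "unhappy", "happiness"}
subword_vocab = {"un", "happi", "ness", "happy"} | set("abcdefghijklmnopqrstuvwxyz")


def word_level(text):
    """Word-level lookup: anything outside the vocabulary becomes <unk>."""
    return [w if w in word_vocab else "<unk>" for w in text.split()]


def greedy_subword(word, vocab):
    """Greedy longest-match segmentation (illustrative, not BPE itself)."""
    pieces, start = [], 0
    while start < len(word):
        # Try the longest remaining substring first, shrinking until a match.
        for end in range(len(word), start, -1):
            if word[start:end] in vocab:
                pieces.append(word[start:end])
                start = end
                break
        else:
            # The character is not covered by the base alphabet.
            pieces.append("<unk>")
            start += 1
    return pieces


print(word_level("unhappiness"))                     # ['<unk>']
print(greedy_subword("unhappiness", subword_vocab))  # ['un', 'happi', 'ness']
print(greedy_subword("unhappily", subword_vocab))    # ['un', 'happi', 'l', 'y']
```

The unseen word "unhappily" still gets a representation because the single-letter fallback covers the pieces that have no dedicated subword, while reusing "un" and "happi" from related words.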

**F1 analogy:** Subword tokenization is like the telemetry compression an F1 team uses over race weekend. Common sequences — a clean lap through Maggotts-Becketts, a standard pit stop — get encoded as single compact tokens. But a rare event like a puncture or a safety car restart can still be represented by composing smaller known patterns. The corpus of race transcripts, telemetry logs, and FIA regulations all feed into building this vocabulary.

In [None]:
# Visualize tokenization approaches
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

text = "unhappiness"

approaches = [
    ('Character-level', list(text), '#e74c3c'),
    ('Word-level', [text], '#3498db'),
    ('Subword (BPE)', ['un', 'happi', 'ness'], '#2ecc71'),
]

for ax, (name, tokens, color) in zip(axes, approaches):
    ax.set_xlim(-0.5, max(len(tokens), 4) - 0.5)
    ax.set_ylim(-0.5, 1.5)
    ax.axis('off')
    ax.set_title(name, fontsize=13, fontweight='bold')
    
    total_width = len(tokens)
    start_x = (max(len(tokens), 4) - total_width) / 2
    
    for i, tok in enumerate(tokens):
        box = mpatches.FancyBboxPatch((start_x + i - 0.4, 0.2), 0.8, 0.8,
                                       boxstyle="round,pad=0.05", facecolor=color,
                                       edgecolor='black', linewidth=1.5, alpha=0.8)
        ax.add_patch(box)
        ax.text(start_x + i, 0.6, tok, ha='center', va='center',
               fontsize=10 if len(tok) < 5 else 8, fontweight='bold', color='white')
    
    ax.text((max(len(tokens), 4) - 1) / 2, -0.2, f'{len(tokens)} token{"s" if len(tokens) > 1 else ""}',
           ha='center', fontsize=10, color='gray')

plt.suptitle(f'Tokenizing: "{text}"', fontsize=14, fontweight='bold', y=1.05)
plt.tight_layout()
plt.show()

---

## 2. Byte Pair Encoding (BPE)

BPE is the most widely used tokenization algorithm (GPT, LLaMA, etc.). It works by iteratively merging the most frequent pair of adjacent tokens.

### Algorithm

1. Start with a vocabulary of individual characters
2. Count all adjacent pairs in the corpus
3. Merge the most frequent pair into a new token
4. Repeat until desired vocabulary size is reached

### Example
```
Corpus: "low lower newest"
Start: l o w _ l o w e r _ n e w e s t
Pair counts: ('l','o') = 2, ('o','w') = 2, ('w','e') = 2, all others = 1 (ties broken arbitrarily)
Merge 'l','o' -> 'lo':    lo w _ lo w e r _ n e w e s t
Merge 'lo','w' -> 'low':  low _ low e r _ n e w e s t
...
```
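
One way to check which pair wins the first merge on this toy corpus is to count the pairs directly. The sketch below is a minimal illustration, assuming (like the tokenizer class later in this notebook) that pairs never cross word boundaries; the helper names `count_pairs` and `merge` are made up for this example.

```python
from collections import Counter

# Each word becomes a tuple of single-character symbols.
words = [tuple(w) for w in "low lower newest".split()]


def count_pairs(words):
    """Count adjacent symbol pairs inside each word."""
    pairs = Counter()
    for symbols in words:
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += 1
    return pairs


def merge(words, pair):
    """Replace every occurrence of `pair` with the concatenated symbol."""
    merged_words = []
    for symbols in words:
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged_words.append(tuple(out))
    return merged_words


pairs = count_pairs(words)
print(pairs[("l", "o")], pairs[("o", "w")], pairs[("w", "e")], pairs[("e", "w")])  # 2 2 2 1

# Three pairs tie at count 2; picking ('l', 'o') is one valid tie-break.
words = merge(words, ("l", "o"))
print(words)
```

The pairs `('l','o')`, `('o','w')` and `('w','e')` each occur twice, so the first merge depends on how ties are broken; `('e','w')` occurs only once, in "newest".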

**F1 analogy:** BPE is like the way an F1 engineer learns to read telemetry traces over a season. At first, every data point is individual — throttle at 73%, brake pressure 42 bar, steering angle 12 degrees. Over thousands of laps, the engineer starts recognizing frequent *pairs* of events: "lift-and-coast" (throttle drop + coasting), "trail-braking" (brake + turn-in together). These get merged into single recognized patterns. The most common sequences get their own shorthand first, while rare corner-specific patterns stay decomposed.

In [None]:
class BPETokenizer:
    """Byte Pair Encoding tokenizer from scratch."""
    
    def __init__(self):
        self.merges = {}  # (a, b) -> merged token
        self.vocab = {}   # token -> id
        self.inverse_vocab = {}  # id -> token
        self.merge_history = []  # For visualization
    
    def _get_word_freqs(self, text):
        """Split text into words and count frequencies."""
        words = re.findall(r'\S+', text.lower())
        word_freqs = Counter()
        for word in words:
            # Add end-of-word marker
            chars = tuple(list(word) + ['</w>'])
            word_freqs[chars] += 1
        return word_freqs
    
    def _get_pair_counts(self, word_freqs):
        """Count adjacent pairs across all words."""
        pairs = Counter()
        for word, freq in word_freqs.items():
            for i in range(len(word) - 1):
                pairs[(word[i], word[i+1])] += freq
        return pairs
    
    def _merge_pair(self, word_freqs, pair):
        """Merge a pair in all words."""
        new_word_freqs = {}
        merged = pair[0] + pair[1]
        
        for word, freq in word_freqs.items():
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and word[i] == pair[0] and word[i+1] == pair[1]:
                    new_word.append(merged)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word_freqs[tuple(new_word)] = freq
        
        return new_word_freqs
    
    def train(self, text, num_merges=20, verbose=True):
        """Train BPE tokenizer on text."""
        word_freqs = self._get_word_freqs(text)
        
        # Initial vocab: all characters + end-of-word
        all_chars = set()
        for word in word_freqs:
            all_chars.update(word)
        
        self.vocab = {ch: i for i, ch in enumerate(sorted(all_chars))}
        self.merges = {}  # Reset so repeated train() calls don't accumulate old merges
        self.merge_history = []
        
        if verbose:
            print(f"Initial vocab ({len(self.vocab)} tokens): {sorted(self.vocab.keys())}")
            print(f"\nTraining {num_merges} merges...\n")
        
        for step in range(num_merges):
            pairs = self._get_pair_counts(word_freqs)
            if not pairs:
                break
            
            best_pair = max(pairs, key=pairs.get)
            best_count = pairs[best_pair]
            
            # Merge
            word_freqs = self._merge_pair(word_freqs, best_pair)
            merged_token = best_pair[0] + best_pair[1]
            self.merges[best_pair] = merged_token
            self.vocab[merged_token] = len(self.vocab)
            
            self.merge_history.append({
                'step': step + 1,
                'pair': best_pair,
                'merged': merged_token,
                'count': best_count,
                'vocab_size': len(self.vocab)
            })
            
            if verbose:
                print(f"  Step {step+1}: merge '{best_pair[0]}' + '{best_pair[1]}' "
                      f"-> '{merged_token}' (freq={best_count}, vocab={len(self.vocab)})")
        
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
        return self
    
    def encode(self, text):
        """Encode text to token IDs."""
        words = re.findall(r'\S+', text.lower())
        all_tokens = []
        
        for word in words:
            tokens = list(word) + ['</w>']
            
            # Apply merges in order
            for pair, merged in self.merges.items():
                i = 0
                new_tokens = []
                while i < len(tokens):
                    if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i+1] == pair[1]:
                        new_tokens.append(merged)
                        i += 2
                    else:
                        new_tokens.append(tokens[i])
                        i += 1
                tokens = new_tokens
            
            all_tokens.extend(tokens)
        
        # Note: tokens missing from the vocab fall back to ID 0, which is also a real token's ID
        return [self.vocab.get(t, 0) for t in all_tokens], all_tokens
    
    def decode(self, ids):
        """Decode token IDs back to text."""
        tokens = [self.inverse_vocab.get(i, '?') for i in ids]
        text = ''.join(tokens).replace('</w>', ' ').strip()
        return text


# Training corpus
corpus = """the cat sat on the mat. the cat ate the rat. 
the dog sat on the log. the dog chased the cat.
a cat is a small animal. a dog is a loyal animal.
the cat and the dog are friends. the mat is on the floor."""

tokenizer = BPETokenizer()
tokenizer.train(corpus, num_merges=25)

# Test encoding
test_texts = ["the cat", "the dog sat", "animal friends"]
print("\nEncoding examples:")
for text in test_texts:
    ids, tokens = tokenizer.encode(text)
    print(f"  '{text}' -> {tokens} -> {ids}")
    decoded = tokenizer.decode(ids)
    print(f"  Decoded: '{decoded}'")

In [None]:
# Visualize BPE merge history
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Vocab size growth
ax = axes[0]
steps = [h['step'] for h in tokenizer.merge_history]
vocab_sizes = [h['vocab_size'] for h in tokenizer.merge_history]
ax.plot(steps, vocab_sizes, 'b-o', linewidth=2, markersize=5)
ax.set_xlabel('Merge Step', fontsize=11)
ax.set_ylabel('Vocabulary Size', fontsize=11)
ax.set_title('BPE Vocabulary Growth', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3)

# Merge frequency (how frequent each merged pair was)
ax = axes[1]
merge_freqs = [h['count'] for h in tokenizer.merge_history]
merge_labels = [h['merged'][:8] for h in tokenizer.merge_history]
colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(merge_freqs)))
ax.bar(range(len(merge_freqs)), merge_freqs, color=colors, edgecolor='black', alpha=0.8)
ax.set_xticks(range(len(merge_labels)))
ax.set_xticklabels(merge_labels, rotation=60, ha='right', fontsize=7)
ax.set_xlabel('Merged Token', fontsize=11)
ax.set_ylabel('Pair Frequency', fontsize=11)
ax.set_title('BPE Merge Frequencies', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

---

## 3. WordPiece Tokenization

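WordPiece (used by BERT) is similar to BPE but uses a different criterion for choosing merges: instead of picking the pair with the highest raw frequency, it picks the merge that most increases the **likelihood of the training data**. A common way to express that criterion is to score each candidate pair as: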

$$\text{score}(a, b) = \frac{\text{freq}(ab)}{\text{freq}(a) \times \text{freq}(b)}$$

This favors merging pairs where the combination is more frequent *relative* to the individual pieces — capturing meaningful subwords rather than just frequent character sequences.
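
A small worked example shows how the two criteria can disagree. The counts below are hypothetical, invented only to illustrate the formula, and `wordpiece_score` is a helper name made up for this sketch.

```python
from collections import Counter

# Hypothetical token and pair counts, not taken from any real corpus.
token_freq = Counter({"t": 100, "h": 80, "q": 5, "u": 6})
pair_freq = Counter({("t", "h"): 40, ("q", "u"): 5})


def wordpiece_score(pair):
    """freq(ab) / (freq(a) * freq(b)), as in the formula above."""
    a, b = pair
    return pair_freq[pair] / (token_freq[a] * token_freq[b])


bpe_choice = max(pair_freq, key=pair_freq.get)   # highest raw frequency
wp_choice = max(pair_freq, key=wordpiece_score)  # highest normalized score

for pair in pair_freq:
    print(pair, pair_freq[pair], round(wordpiece_score(pair), 4))
# ('t', 'h') 40 0.005
# ('q', 'u') 5 0.1667

print("BPE picks:", bpe_choice)        # ('t', 'h')
print("WordPiece picks:", wp_choice)   # ('q', 'u')
```

Under these assumed counts, BPE merges the frequent pair `('t', 'h')`, while the WordPiece score prefers `('q', 'u')` because "q" is almost always followed by "u".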

**F1 analogy:** If BPE merges whatever telemetry patterns appear most often (like "throttle-on + gear-up" which happens on every straight), WordPiece asks a smarter question: "Is this pair appearing together *more than you'd expect by chance*?" A rare but always-cooccurring pair like "anti-stall activation + clutch override" would score high in WordPiece even though each event alone is rare — because whenever one happens, the other always follows. That's a more meaningful pattern than just "the two most common signals."

In [None]:
class WordPieceTokenizer:
    """WordPiece tokenizer (BERT-style) from scratch."""
    
    def __init__(self):
        self.vocab = {}
        self.merges = {}
        self.merge_history = []
    
    def _get_word_freqs(self, text):
        words = re.findall(r'\S+', text.lower())
        word_freqs = Counter()
        for word in words:
            # WordPiece uses ## prefix for continuation tokens
            chars = tuple([word[0]] + ['##' + c for c in word[1:]])
            word_freqs[chars] += 1
        return word_freqs
    
    def _get_pair_scores(self, word_freqs):
        """Score pairs by likelihood ratio (WordPiece criterion)."""
        pair_freqs = Counter()
        token_freqs = Counter()
        
        for word, freq in word_freqs.items():
            for i in range(len(word) - 1):
                pair_freqs[(word[i], word[i+1])] += freq
            for token in word:
                token_freqs[token] += freq
        
        scores = {}
        for pair, freq in pair_freqs.items():
            denom = token_freqs[pair[0]] * token_freqs[pair[1]]
            scores[pair] = freq / denom if denom > 0 else 0
        
        return scores, pair_freqs
    
    def _merge_pair(self, word_freqs, pair):
        # Strip only the leading '##' continuation prefix of the second token
        merged = pair[0] + (pair[1][2:] if pair[1].startswith('##') else pair[1])
        new_word_freqs = {}
        
        for word, freq in word_freqs.items():
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and word[i] == pair[0] and word[i+1] == pair[1]:
                    new_word.append(merged)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word_freqs[tuple(new_word)] = freq
        
        return new_word_freqs, merged
    
    def train(self, text, num_merges=20, verbose=True):
        word_freqs = self._get_word_freqs(text)
        
        all_tokens = set()
        for word in word_freqs:
            all_tokens.update(word)
        self.vocab = {t: i for i, t in enumerate(sorted(all_tokens))}
        
        if verbose:
            print(f"Initial vocab ({len(self.vocab)} tokens)")
            print(f"Training {num_merges} merges...\n")
        
        for step in range(num_merges):
            scores, pair_freqs = self._get_pair_scores(word_freqs)
            if not scores:
                break
            
            best_pair = max(scores, key=scores.get)
            word_freqs, merged_token = self._merge_pair(word_freqs, best_pair)
            
            self.merges[best_pair] = merged_token
            self.vocab[merged_token] = len(self.vocab)
            
            self.merge_history.append({
                'step': step + 1,
                'pair': best_pair,
                'merged': merged_token,
                'score': scores[best_pair],
                'freq': pair_freqs[best_pair]
            })
            
            if verbose:
                print(f"  Step {step+1}: merge '{best_pair[0]}' + '{best_pair[1]}' "
                      f"-> '{merged_token}' (score={scores[best_pair]:.4f}, freq={pair_freqs[best_pair]})")
        
        return self


# Train WordPiece on the same corpus
wp_tokenizer = WordPieceTokenizer()
wp_tokenizer.train(corpus, num_merges=15)

# Compare merge choices
print("\n" + "=" * 50)
print("BPE vs WordPiece merge comparison (first 10):")
print(f"{'Step':>5} {'BPE Merge':>20} {'WordPiece Merge':>20}")
for i in range(min(10, len(tokenizer.merge_history), len(wp_tokenizer.merge_history))):
    bpe = tokenizer.merge_history[i]['merged']
    wp = wp_tokenizer.merge_history[i]['merged']
    print(f"{i+1:>5} {bpe:>20} {wp:>20}")

---

## 4. Tokenization Impact on Models

The tokenizer directly affects what the model "sees". Let's measure how different tokenization granularities change the sequence length the model has to process.

**F1 analogy:** This is the fundamental tradeoff every F1 data team faces: how granular should your telemetry encoding be? Character-level is like logging every sensor at 1000Hz — you capture everything but drown in data. Word-level is like logging only "completed a lap" — efficient but you've lost all the nuance. BPE-style subword encoding hits the sweet spot: recognizable patterns like "chicane-sequence" or "tire-deg-phase" that are compact but still informative.

In [None]:
# Compare tokenization approaches on different texts
test_corpus = [
    "The transformer architecture revolutionized natural language processing.",
    "Backpropagation computes gradients efficiently using the chain rule.",
    "Reinforcement learning from human feedback aligns language models.",
    "Self-attention enables parallel processing of sequential data.",
    "Neural networks approximate complex nonlinear functions.",
]

def char_tokenize(text):
    return list(text.lower())

def word_tokenize(text):
    return re.findall(r'\w+', text.lower())

def bpe_tokenize(text, tokenizer):
    _, tokens = tokenizer.encode(text)
    return tokens

# Train BPE on test corpus for fair comparison
all_text = ' '.join(test_corpus)
test_bpe = BPETokenizer()
test_bpe.train(all_text, num_merges=40, verbose=False)

print("Tokenization Comparison\n")
print(f"{'Text (first 55 chars)':>55} {'Char':>6} {'Word':>6} {'BPE':>6}")
print("-" * 80)

char_lens, word_lens, bpe_lens = [], [], []

for text in test_corpus:
    cl = len(char_tokenize(text))
    wl = len(word_tokenize(text))
    bl = len(bpe_tokenize(text, test_bpe))
    
    char_lens.append(cl)
    word_lens.append(wl)
    bpe_lens.append(bl)
    
    print(f"{text[:55]:>55} {cl:>6} {wl:>6} {bl:>6}")

print(f"\n{'Average':>55} {np.mean(char_lens):>6.1f} {np.mean(word_lens):>6.1f} {np.mean(bpe_lens):>6.1f}")

# Visualize
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
x = np.arange(len(test_corpus))
w = 0.25
ax.bar(x - w, char_lens, w, label='Character', color='#e74c3c', edgecolor='black', alpha=0.8)
ax.bar(x, bpe_lens, w, label='BPE', color='#2ecc71', edgecolor='black', alpha=0.8)
ax.bar(x + w, word_lens, w, label='Word', color='#3498db', edgecolor='black', alpha=0.8)

ax.set_xticks(x)
ax.set_xticklabels([f'Text {i+1}' for i in range(len(test_corpus))])
ax.set_ylabel('Sequence Length', fontsize=11)
ax.set_title('Sequence Length by Tokenization Method', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()

---

## 5. Pretraining Objectives

How do we train a language model? The choice of **pretraining objective** determines what the model learns.

### Causal Language Modeling (CLM) — GPT-style
Predict the next token given all previous tokens:
$$P(x_t | x_1, x_2, \ldots, x_{t-1})$$

### Masked Language Modeling (MLM) — BERT-style
Randomly mask tokens and predict them from context:
$$P(x_\text{mask} | x_1, \ldots, x_{\text{mask}-1}, x_{\text{mask}+1}, \ldots, x_n)$$

| Feature | CLM (GPT) | MLM (BERT) | F1 Parallel |
|---------|-----------|------------|-------------|
| Direction | Left-to-right only | Bidirectional | Predicting the next sector time vs. inferring a missing mid-sector from surrounding data |
| Generation | Natural text generation | Not designed for generation | Live strategy calls (what happens next?) vs. post-race gap analysis |
| Understanding | Good but unidirectional | Excellent bidirectional | Forecasting from history vs. understanding a full race in hindsight |
| Use case | Chatbots, code gen, writing | Classification, NER, QA | Real-time pit wall prediction vs. post-race data classification |

**F1 analogy:** Causal LM is like the pit wall predicting what will happen *next* on track — given everything up to now, what's the next event? Masked LM is like a post-race analyst filling in missing telemetry gaps: "Given the braking zone entry and the corner exit, what must have happened at the apex?"
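
The masked objective can be sketched in a few lines of PyTorch. This is a minimal illustration under stated assumptions: the 15% masking rate is a commonly used value and not something the objective itself fixes, `mask_id = 1` and the toy token IDs are hypothetical, and random logits stand in for the output of a real bidirectional model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size, mask_id = 10, 1                      # hypothetical toy values
token_ids = torch.tensor([[4, 7, 3, 9, 5, 6]])   # one sequence of 6 token IDs

# Assumption: mask each position independently with probability 0.15.
mask = torch.rand(token_ids.shape) < 0.15
mask[0, 2] = True                                # force one masked position for the demo

# The model sees <mask> at the chosen positions...
inputs = token_ids.masked_fill(mask, mask_id)
# ...and is only scored there: -100 marks positions the loss should skip.
labels = token_ids.masked_fill(~mask, -100)

# Stand-in for model output: logits of shape (batch, seq_len, vocab_size).
logits = torch.randn(1, token_ids.shape[1], vocab_size)

loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),
    labels.reshape(-1),
    ignore_index=-100,
)
print(inputs)
print(labels)
print(f"MLM loss on masked positions only: {loss.item():.4f}")
```

Unlike the causal objective, the loss here is averaged only over the masked positions, and the model is free to use context on both sides of each mask.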

In [None]:
# Demonstrate both pretraining objectives

class SimpleTokenizer:
    """Minimal word-level tokenizer for demonstrations."""
    def __init__(self, texts):
        words = set()
        for text in texts:
            words.update(re.findall(r'\w+', text.lower()))
        
        self.word2id = {'<pad>': 0, '<mask>': 1, '<unk>': 2}
        for i, w in enumerate(sorted(words)):
            self.word2id[w] = i + 3
        self.id2word = {v: k for k, v in self.word2id.items()}
        self.vocab_size = len(self.word2id)
    
    def encode(self, text):
        return [self.word2id.get(w, 2) for w in re.findall(r'\w+', text.lower())]
    
    def decode(self, ids):
        return ' '.join(self.id2word.get(i, '?') for i in ids)


# Training data
train_texts = [
    "the cat sat on the mat",
    "the dog chased the cat",
    "a bird flew over the tree",
    "the fish swam in the pond",
    "a cat is a small animal",
    "the dog is a loyal friend",
    "birds can fly very high",
    "fish live in the water",
]

tok = SimpleTokenizer(train_texts)
print(f"Vocabulary size: {tok.vocab_size}")

# CLM: show next-token prediction setup
print("\n--- Causal Language Modeling (Next Token Prediction) ---")
example = "the cat sat on the mat"
tokens = example.split()
for i in range(1, len(tokens)):
    context = ' '.join(tokens[:i])
    target = tokens[i]
    print(f"  Input: '{context}' -> Predict: '{target}'")

# MLM: show masked prediction setup
print("\n--- Masked Language Modeling ---")
for text in train_texts[:3]:
    tokens = text.split()
    mask_idx = np.random.randint(0, len(tokens))
    original = tokens[mask_idx]
    masked = tokens.copy()
    masked[mask_idx] = '<mask>'  # Same mask symbol as in SimpleTokenizer's vocabulary
    print(f"  Input: '{' '.join(masked)}' -> Predict: '{original}' at position {mask_idx}")

---

## 6. Training a Small Language Model

Let's train a tiny transformer language model from scratch to see the complete training loop: data preparation, causal masking, loss computation, and generation.

**F1 analogy:** This is like building a miniature version of the simulation system F1 teams use. Real teams train their models on hundreds of thousands of virtual laps. Our tiny LM is like a simplified simulator that learns patterns from a small dataset of race transcripts — not powerful enough for race day, but perfect for understanding how the learning process works.
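
To make the causal masking mentioned above concrete, here is a minimal sketch for a sequence of length 4. It assumes uniform attention scores (all zeros) so that the effect of the mask is easy to read off; a real model would use learned query-key scores.

```python
import torch

T = 4  # toy sequence length

# True above the diagonal = positions that must NOT be attended to (the future).
causal_mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
print(causal_mask)

# Uniform scores, with future positions set to -inf before the softmax.
scores = torch.zeros(T, T).masked_fill(causal_mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)
print(weights)
# Row t spreads its attention evenly over positions 0..t:
# row 0 -> [1.00, 0.00, 0.00, 0.00]
# row 3 -> [0.25, 0.25, 0.25, 0.25]
```

Each row is the attention distribution of one position. Because masked entries become exactly zero after the softmax, the prediction at position t cannot depend on any later token.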

In [None]:
class TinyLM(nn.Module):
    """Tiny transformer language model for demonstration."""
    
    def __init__(self, vocab_size, d_model=64, n_heads=4, n_layers=2, max_seq_len=32):
        super().__init__()
        self.d_model = d_model
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        
        # Transformer decoder layers
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_model * 4,
            dropout=0.1, batch_first=True
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        B, T = x.shape
        positions = torch.arange(T, device=x.device).unsqueeze(0).expand(B, T)
        
        h = self.token_emb(x) + self.pos_emb(positions)
        
        # Causal mask (upper triangular = masked)
        causal_mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        
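        # nn.TransformerDecoder requires a memory input, so h is passed as its own memory;
        # the causal mask is applied to both self- and cross-attention to hide future tokens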
        h = self.transformer(h, h, tgt_mask=causal_mask, memory_mask=causal_mask)
        h = self.ln(h)
        logits = self.head(h)
        
        return logits
    
    def generate(self, start_tokens, max_new_tokens=10, temperature=1.0):
        """Autoregressive generation."""
        self.eval()
        tokens = start_tokens.clone()
        
        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits = self(tokens)
                next_logits = logits[:, -1, :] / temperature
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, 1)
                tokens = torch.cat([tokens, next_token], dim=1)
        
        return tokens


# Prepare training data for CLM
def prepare_clm_data(texts, tokenizer, max_len=16):
    """Create input-target pairs for causal LM training."""
    all_ids = []
    for text in texts:
        ids = tokenizer.encode(text)
        if len(ids) > max_len:
            ids = ids[:max_len]
        else:
            ids = ids + [0] * (max_len - len(ids))
        all_ids.append(ids)
    
    data = torch.tensor(all_ids)
    # Input: all tokens except last; Target: all tokens except first
    inputs = data[:, :-1]
    targets = data[:, 1:]
    return inputs, targets


inputs, targets = prepare_clm_data(train_texts, tok)
print(f"Training data: {inputs.shape[0]} sequences, length {inputs.shape[1]}")
print(f"Vocab size: {tok.vocab_size}")
print(f"\nExample:")
print(f"  Input:  {tok.decode(inputs[0].tolist())}")
print(f"  Target: {tok.decode(targets[0].tolist())}")

In [None]:
# Train the model
model = TinyLM(tok.vocab_size, d_model=64, n_heads=4, n_layers=2)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)

n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")

losses = []
n_epochs = 200

model.train()
for epoch in range(n_epochs):
    logits = model(inputs)
    
    # Cross-entropy loss (ignoring padding)
    loss = F.cross_entropy(
        logits.reshape(-1, tok.vocab_size),
        targets.reshape(-1),
        ignore_index=0  # Ignore padding
    )
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    losses.append(loss.item())
    
    if (epoch + 1) % 40 == 0:
        print(f"  Epoch {epoch+1}/{n_epochs}: loss = {loss.item():.4f}")

# Generate some text
print("\nGeneration examples:")
prompts = ["the cat", "a bird", "the dog"]
for prompt in prompts:
    start = torch.tensor([tok.encode(prompt)])
    generated = model.generate(start, max_new_tokens=6, temperature=0.8)
    print(f"  '{prompt}' -> '{tok.decode(generated[0].tolist())}'")

In [None]:
# Visualize training
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curve
ax = axes[0]
ax.plot(losses, color='#3498db', linewidth=1, alpha=0.3)
# Smoothed
window = 10
smoothed = [np.mean(losses[max(0,i-window):i+1]) for i in range(len(losses))]
ax.plot(smoothed, color='#3498db', linewidth=2, label='Smoothed')
ax.set_xlabel('Epoch', fontsize=11)
ax.set_ylabel('Cross-Entropy Loss', fontsize=11)
ax.set_title('Language Model Training Loss', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Token probability heatmap for a test sequence
ax = axes[1]
test_input = torch.tensor([tok.encode("the cat sat on the")])
model.eval()
with torch.no_grad():
    logits = model(test_input)
    probs = F.softmax(logits[0], dim=-1)

# Show top-5 predictions for each position
input_tokens = "the cat sat on the".split()
n_pos = len(input_tokens)
top_k = 5

heatmap = np.zeros((top_k, n_pos))
labels_y = []

for pos in range(n_pos):
    top_probs, top_ids = probs[pos].topk(top_k)
    for k in range(top_k):
        heatmap[k, pos] = top_probs[k].item()
        if pos == 0:
            labels_y.append(f"Top-{k+1}")

im = ax.imshow(heatmap, cmap='YlOrRd', aspect='auto')
ax.set_xticks(range(n_pos))
ax.set_xticklabels(input_tokens, fontsize=10)
ax.set_yticks(range(top_k))
ax.set_yticklabels(labels_y, fontsize=10)
ax.set_xlabel('Input Position', fontsize=11)
ax.set_title('Next Token Probabilities', fontsize=13, fontweight='bold')

# Add text annotations
for pos in range(n_pos):
    top_probs, top_ids = probs[pos].topk(top_k)
    for k in range(top_k):
        word = tok.id2word.get(top_ids[k].item(), '?')
        prob = top_probs[k].item()
        color = 'white' if prob > 0.3 else 'black'
        ax.text(pos, k, f'{word}\n{prob:.2f}', ha='center', va='center',
               fontsize=7, color=color, fontweight='bold')

plt.colorbar(im, ax=ax, label='Probability')
plt.tight_layout()
plt.show()

---

## 7. Scaling Laws

One of the most important discoveries in modern AI: model performance follows **predictable power laws** as you scale model size, dataset size, and compute.

### Chinchilla Scaling Laws (Hoffmann et al., 2022)

Hoffmann et al. fit the final training loss as a function of model size and data size:
$$L(N, D) = \frac{A}{N^\alpha} + \frac{B}{D^\beta} + L_\infty$$

where $N$ = parameters, $D$ = training tokens, $A$ and $B$ are fitted constants, $L_\infty$ is the irreducible loss, and $\alpha \approx 0.34$, $\beta \approx 0.28$.

**Key insight**: For compute-optimal training, a model should see roughly 20 training tokens per parameter, so a 1B-parameter model calls for ~20B tokens.

### What This Means in Practice

| Model Size | Optimal Training Tokens | Approximate Training Cost (rough) | F1 Parallel |
|-----------|------------------------|--------------------|-------------|
| 1B | ~20B tokens | ~$10K | A backmarker team's simulation budget — basic wind tunnel + limited CFD |
| 7B | ~140B tokens | ~$100K | A midfield team's off-season development program |
| 70B | ~1.4T tokens | ~$2M | A top team's full-scale simulation farm for one season |
| 400B | ~8T tokens | ~$50M+ | The entire F1 grid's combined simulation capacity |
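
The token counts in the table follow directly from the 20-tokens-per-parameter rule. The sketch below reproduces them and also estimates training FLOPs using the common approximation C ≈ 6·N·D, the same one this notebook's plotting code relies on. The function name `chinchilla_budget` is made up for this example, and the dollar figures in the table are not derived here.

```python
def chinchilla_budget(n_params, tokens_per_param=20, flops_per_param_token=6):
    """Return (training tokens, approximate training FLOPs) for a model size.

    Assumes D = 20 * N and C ~= 6 * N * D; both are rules of thumb.
    """
    tokens = tokens_per_param * n_params
    flops = flops_per_param_token * n_params * tokens
    return tokens, flops


for label, n in [("1B", 1e9), ("7B", 7e9), ("70B", 7e10), ("400B", 4e11)]:
    tokens, flops = chinchilla_budget(n)
    print(f"{label:>5}: {tokens:.2e} tokens, {flops:.2e} FLOPs")
#    1B: 2.00e+10 tokens, 1.20e+20 FLOPs
#    7B: 1.40e+11 tokens, 5.88e+21 FLOPs
#   70B: 1.40e+12 tokens, 5.88e+23 FLOPs
#  400B: 8.00e+12 tokens, 1.92e+25 FLOPs
```

Because tokens scale linearly with parameters under this rule, compute grows with the square of model size: a 10x larger model needs about 100x the FLOPs.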

**F1 analogy:** Scaling laws in AI are remarkably similar to the diminishing returns in F1 development. Spending your first $10M on aero development gives huge lap time gains. The next $10M gives less. At some point, you need more *data* (track time, wind tunnel hours) not just a bigger model (more engineers). The Chinchilla insight — balance model size with training data — is exactly like the FIA cost cap philosophy: don't just throw money at the car; invest proportionally in testing, simulation, and track time.

In [None]:
# Simulate and visualize scaling laws

def scaling_law_loss(N, D, A=406.4, B=410.7, alpha=0.34, beta=0.28, L_inf=1.69):
    """Chinchilla-style scaling law for loss prediction.
    
    N: number of parameters
    D: number of training tokens
    """
    return A / (N ** alpha) + B / (D ** beta) + L_inf


fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# 1. Loss vs model size (fixed data)
ax = axes[0]
param_counts = np.logspace(6, 11, 50)  # 1M to 100B
for D_label, D in [('1B tokens', 1e9), ('10B tokens', 1e10), ('100B tokens', 1e11), ('1T tokens', 1e12)]:
    losses_pred = [scaling_law_loss(N, D) for N in param_counts]
    ax.plot(param_counts, losses_pred, linewidth=2, label=D_label)

ax.set_xscale('log')
ax.set_xlabel('Parameters (N)', fontsize=11)
ax.set_ylabel('Loss', fontsize=11)
ax.set_title('Loss vs Model Size', fontsize=13, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# 2. Loss vs data size (fixed model)
ax = axes[1]
token_counts = np.logspace(8, 13, 50)  # 100M to 10T
for N_label, N in [('100M params', 1e8), ('1B params', 1e9), ('10B params', 1e10), ('70B params', 7e10)]:
    losses_pred = [scaling_law_loss(N, D) for D in token_counts]
    ax.plot(token_counts, losses_pred, linewidth=2, label=N_label)

ax.set_xscale('log')
ax.set_xlabel('Training Tokens (D)', fontsize=11)
ax.set_ylabel('Loss', fontsize=11)
ax.set_title('Loss vs Data Size', fontsize=13, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# 3. Compute-optimal frontier (Chinchilla)
ax = axes[2]
compute_budgets = np.logspace(18, 24, 30)  # FLOPs

# For each compute budget, find optimal N and D
# Chinchilla: D ≈ 20 * N, and C ≈ 6 * N * D
optimal_N = []
optimal_D = []
optimal_loss = []

for C in compute_budgets:
    # C = 6ND, D = 20N -> C = 120N^2 -> N = sqrt(C/120)
    N = math.sqrt(C / 120)
    D = 20 * N
    optimal_N.append(N)
    optimal_D.append(D)
    optimal_loss.append(scaling_law_loss(N, D))

ax.plot(compute_budgets, optimal_loss, 'r-', linewidth=2, label='Compute-optimal')

# Also show suboptimal: too-big model, too-small model
for factor, label, color in [(0.2, 'Too small model', '#3498db'), (5, 'Too big model', '#f39c12')]:
    sub_losses = []
    for C in compute_budgets:
        N = factor * math.sqrt(C / 120)
        D = C / (6 * N)
        sub_losses.append(scaling_law_loss(N, D))
    ax.plot(compute_budgets, sub_losses, '--', linewidth=2, label=label, color=color)

ax.set_xscale('log')
ax.set_xlabel('Compute (FLOPs)', fontsize=11)
ax.set_ylabel('Loss', fontsize=11)
ax.set_title('Compute-Optimal Scaling', fontsize=13, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

plt.suptitle('Neural Scaling Laws', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# Print some specific predictions
print("Scaling Law Predictions:\n")
for N_label, N in [('1B', 1e9), ('7B', 7e9), ('70B', 7e10)]:
    D_optimal = 20 * N
    loss = scaling_law_loss(N, D_optimal)
    print(f"  {N_label} params + {D_optimal/1e9:.0f}B tokens: predicted loss = {loss:.3f}")

---

## 8. Training Dynamics

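Training a language model successfully means controlling the learning rate, gradient magnitudes, regularization, and numeric precision over the whole run. The techniques below are the standard tools for each.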

### Key Concepts

| Technique | Purpose | Details | F1 Parallel |
|-----------|---------|--------|-------------|
| **Learning rate warmup** | Stabilize early training | Linearly increase LR from 0 to its peak over the first `warmup_steps` steps | Warming up tires on an out-lap before pushing: go too hard too early and you spin |
| **Cosine schedule** | Gradual LR decay | Smoothly decrease LR following cosine curve | Fuel-load management — push hard early in a stint, then manage as tires degrade |
| **Gradient clipping** | Prevent exploding gradients | Cap gradient norm at max value | Rev limiter — prevents the engine from destroying itself under full load |
| **Weight decay** | Regularization | Shrinks weights toward zero each step (an L2 penalty in plain SGD; applied as a separate decoupled step in AdamW) | Minimum weight regulations: prevents the car from being optimized into fragility |
| **Mixed precision** | Speed + memory savings | FP16 or BF16 compute with FP32 master weights and accumulation | Using lower-precision sensors where acceptable, high-precision only where critical |
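To make the gradient clipping row concrete, here is a minimal NumPy sketch of clipping by global norm. It is an illustration under simple assumptions (gradients held as a plain list of arrays, a hypothetical helper named `clip_by_global_norm`), not the implementation used by any particular framework.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm before clipping.
    """
    total_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    # scale is 1.0 when the gradients are already within the limit
    scale = min(1.0, max_norm / (total_norm + 1e-12))
    return [g * scale for g in grads], total_norm

grads = [np.array([3.0, 4.0]), np.array([12.0])]  # joint norm = sqrt(9 + 16 + 144) = 13
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = float(np.sqrt(sum(np.sum(g ** 2) for g in clipped)))
print(norm_before, norm_after)
```

All arrays are scaled by the same factor, so the update keeps its direction and only its length is capped. Gradients that are already below `max_norm` pass through unchanged.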

In [None]:
class LRScheduler:
    """Learning rate schedules commonly used in LLM training."""
    
    @staticmethod
    def cosine_with_warmup(step, total_steps, warmup_steps, max_lr, min_lr=0):
        """Cosine schedule with linear warmup."""
        if step < warmup_steps:
            # Linear warmup
            return max_lr * step / warmup_steps
        
        # Cosine decay
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
    
    @staticmethod
    def linear_with_warmup(step, total_steps, warmup_steps, max_lr):
        """Linear decay with warmup."""
        if step < warmup_steps:
            return max_lr * step / warmup_steps
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return max_lr * (1 - progress)
    
    @staticmethod
    def constant_with_warmup(step, total_steps, warmup_steps, max_lr):
        """Constant LR with warmup."""
        if step < warmup_steps:
            return max_lr * step / warmup_steps
        return max_lr


# Visualize schedules
total_steps = 1000
warmup_steps = 100
max_lr = 3e-4

steps = range(total_steps)

schedules = {
    'Cosine + Warmup': [LRScheduler.cosine_with_warmup(s, total_steps, warmup_steps, max_lr, max_lr * 0.1) for s in steps],
    'Linear + Warmup': [LRScheduler.linear_with_warmup(s, total_steps, warmup_steps, max_lr) for s in steps],
    'Constant + Warmup': [LRScheduler.constant_with_warmup(s, total_steps, warmup_steps, max_lr) for s in steps],
}

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# LR schedules
ax = axes[0]
colors = ['#3498db', '#e74c3c', '#2ecc71']
for (name, lrs), color in zip(schedules.items(), colors):
    ax.plot(steps, lrs, linewidth=2, label=name, color=color)

ax.axvline(x=warmup_steps, color='gray', linestyle=':', alpha=0.5, label='Warmup end')
ax.set_xlabel('Training Step', fontsize=11)
ax.set_ylabel('Learning Rate', fontsize=11)
ax.set_title('Learning Rate Schedules', fontsize=13, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# Simulated training loss for different schedules
ax = axes[1]
np.random.seed(42)
base_loss = 4.0

for (name, lrs), color in zip(schedules.items(), colors):
    # Toy simulation, not a real training run: each step lowers the loss
    # in proportion to the current LR, plus Gaussian noise of fixed size
    sim_loss = [base_loss]
    for i in range(1, total_steps):
        lr = lrs[i]
        # Loss decreases proportional to LR, with noise
        decrease = lr / max_lr * 0.005
        noise = np.random.normal(0, 0.02)
        new_loss = max(1.5, sim_loss[-1] - decrease + noise)
        sim_loss.append(new_loss)
    
    # Smooth for display
    w = 20
    smoothed = [np.mean(sim_loss[max(0,i-w):i+1]) for i in range(len(sim_loss))]
    ax.plot(steps, smoothed, linewidth=2, label=name, color=color)

ax.set_xlabel('Training Step', fontsize=11)
ax.set_ylabel('Loss', fontsize=11)
ax.set_title('Simulated Training Loss', fontsize=13, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

## Exercises

### Exercise 1: Unigram Tokenizer

Implement a **Unigram** tokenizer (used by SentencePiece/T5). Unlike BPE which builds up by merging, Unigram starts with a large vocabulary and prunes tokens that contribute least to the likelihood of the training data. Compare its vocabulary with BPE on the same corpus.

**F1 scenario:** Imagine you start with an enormous telemetry codebook containing every possible sensor pattern (all substrings up to length k). Your job is to slim it down by removing the patterns that matter least to reconstructing real race data — keeping the codebook compact while still covering the critical patterns.
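One building block you will likely need is a way to segment text under a fixed unigram vocabulary. The sketch below uses Viterbi-style dynamic programming to find the highest-probability segmentation; the function name `viterbi_segment` and the toy probabilities are assumptions for illustration, and this is only the segmentation step, not the full training or pruning loop.

```python
import math

def viterbi_segment(text, log_probs):
    """Most likely segmentation of `text` under a unigram token model.

    log_probs: dict mapping token string -> log probability (toy values here).
    Returns (tokens, total_log_prob); raises ValueError if no segmentation exists.
    """
    n = len(text)
    max_len = max(len(t) for t in log_probs)
    best = [-math.inf] * (n + 1)  # best[i]: best score for text[:i]
    back = [0] * (n + 1)          # back[i]: start index of the last token in text[:i]
    best[0] = 0.0
    for end in range(1, n + 1):
        for start in range(max(0, end - max_len), end):
            piece = text[start:end]
            if piece in log_probs and best[start] + log_probs[piece] > best[end]:
                best[end] = best[start] + log_probs[piece]
                back[end] = start
    if best[n] == -math.inf:
        raise ValueError("text cannot be segmented with this vocabulary")
    tokens, i = [], n
    while i > 0:
        tokens.append(text[back[i]:i])
        i = back[i]
    return tokens[::-1], best[n]

# Toy vocabulary with made-up probabilities
toy_vocab = {"p": 0.1, "i": 0.1, "t": 0.1, "s": 0.1, "it": 0.2, "pit": 0.3, "pits": 0.1}
toy_log_probs = {tok: math.log(p) for tok, p in toy_vocab.items()}

tokens, score = viterbi_segment("pits", toy_log_probs)
print(tokens, round(score, 3))
```

A pruning step can then be scored by removing one token from the vocabulary, re-segmenting the corpus, and measuring how much the total log-probability drops. In the toy vocabulary above, removing `"pits"` forces the segmentation `["pit", "s"]`, which has a lower score.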

In [None]:
# Exercise 1: Your code here
# Hint: Start with all substrings up to length k, then iteratively remove
# the token whose removal increases training loss the least.


### Exercise 2: Masked Language Model

Modify the TinyLM to support masked language modeling (BERT-style). Randomly mask 15% of tokens, and train the model to predict the masked tokens. Compare the learned representations with the CLM model.

**F1 scenario:** Instead of predicting "what happens next on track," train a model that fills in missing telemetry gaps. Mask out 15% of a lap's data points and train the model to reconstruct them from context — like a race engineer reconstructing corrupted sectors from the data around them.
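The masking and loss steps can be sketched independently of the model. The NumPy example below is a simplified illustration: `MASK_ID`, `IGNORE_INDEX`, and both helper functions are assumptions made for this sketch, and every selected position is replaced with the mask token (BERT's original recipe also leaves some selected tokens unchanged or swaps in random ones).

```python
import numpy as np

MASK_ID = 0          # hypothetical id reserved for the [MASK] token
IGNORE_INDEX = -100  # label value that the loss skips (a common convention)

def mask_tokens(token_ids, mask_prob=0.15, rng=None):
    """Return (inputs, labels, is_masked) for masked language modeling."""
    if rng is None:
        rng = np.random.default_rng()
    token_ids = np.asarray(token_ids)
    is_masked = rng.random(token_ids.shape) < mask_prob
    inputs = np.where(is_masked, MASK_ID, token_ids)       # hide the chosen tokens
    labels = np.where(is_masked, token_ids, IGNORE_INDEX)  # predict only those
    return inputs, labels, is_masked

def masked_cross_entropy(logits, labels):
    """Mean cross-entropy over masked positions. logits: (T, V), labels: (T,)."""
    keep = labels != IGNORE_INDEX
    if not keep.any():
        return 0.0
    shifted = logits[keep] - logits[keep].max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(keep.sum()), labels[keep]].mean())

rng = np.random.default_rng(0)
token_ids = np.arange(1, 201)  # toy sequence; ids start at 1 so MASK_ID stays free
inputs, labels, is_masked = mask_tokens(token_ids, rng=rng)
logits = rng.normal(size=(token_ids.size, 256))  # stand-in for model outputs
loss = masked_cross_entropy(logits, labels)
print(f"masked {is_masked.sum()} of {token_ids.size} tokens, loss = {loss:.3f}")
```

In the actual exercise, `logits` would come from the model run without its causal mask, so each position can attend to tokens on both sides.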

In [None]:
# Exercise 2: Your code here
# Hint: Remove the causal mask, add random masking to inputs,
# and only compute loss on masked positions.


### Exercise 3: Scaling Experiment

Train 3 versions of TinyLM with different sizes (e.g., d_model=32, 64, 128) on the same data. Plot their loss curves together and check whether the larger models reach a lower loss after the same number of training steps. Does the scaling law prediction hold even at this tiny scale?

**F1 scenario:** Think of this as comparing three different simulation rigs: a basic desktop sim (d_model=32), a professional driver-in-the-loop simulator (d_model=64), and a full-scale hydraulic platform (d_model=128). All trained on the same track data. The bigger rigs should converge to accurate predictions faster — but do the scaling laws hold even at "model car" scale?
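Once you have a final loss for each model size, one way to test for a power law is a straight-line fit in log-log space. The sketch below is a simplified check: it fits loss ≈ a·N^(-alpha) and ignores the irreducible-loss term, the helper `fit_power_law` is hypothetical, and the numbers are synthetic placeholders generated from a known power law, not measured results.

```python
import numpy as np

def fit_power_law(param_counts, final_losses):
    """Fit loss ≈ a * N**(-alpha) by linear regression in log-log space."""
    log_n = np.log(np.asarray(param_counts, dtype=float))
    log_l = np.log(np.asarray(final_losses, dtype=float))
    slope, intercept = np.polyfit(log_n, log_l, deg=1)
    return float(np.exp(intercept)), float(-slope)  # (a, alpha)

# Placeholder data, NOT real training results: drawn from loss = 50 * N**-0.25
toy_params = np.array([2e4, 8e4, 3e5])
toy_losses = 50.0 * toy_params ** -0.25

a, alpha = fit_power_law(toy_params, toy_losses)
print(f"a ≈ {a:.2f}, alpha ≈ {alpha:.3f}")
```

With only three model sizes the fit is weakly constrained, so treat the recovered exponent as a sanity check. If your points bend away from a straight line on a log-log plot, the irreducible-loss floor or limited training data may be the cause.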

In [None]:
# Exercise 3: Your code here
# Hint: Loop over model sizes, train each, collect loss curves, plot together.


---

## Summary

### Key Concepts

| Concept | What It Does | F1 Parallel |
|---------|-------------|-------------|
| **Subword tokenization** (BPE, WordPiece) | Balances vocabulary size with sequence length | Building a codebook of telemetry patterns — common sequences get single tokens, rare events decompose |
| **BPE** | Iteratively merges the most frequent adjacent pairs | Learning the most common sensor co-occurrences across thousands of laps |
| **WordPiece** | Uses a likelihood ratio instead of raw frequency | Finding patterns that are *surprisingly* co-occurring, not just frequent |
| **Causal LM** (GPT-style) | Predicts the next token | Pit wall predicting the next event on track from everything so far |
| **Masked LM** (BERT-style) | Fills in masked tokens from bidirectional context | Reconstructing corrupted telemetry from surrounding data |
| **Scaling laws** | Loss decreases as a power law with model size, data, and compute | Diminishing returns on development spend — but predictable ones |
| **Chinchilla optimal** | ~20 tokens per parameter | Balance car development budget with testing/simulation time |
| **Training dynamics** | Warmup, cosine decay, gradient clipping for stability | Tire warmup laps, fuel management, rev limiters for the training process |

### Why This Matters

Tokenization and pretraining are the foundation that everything else builds on. The tokenizer determines what the model can represent — like how the telemetry encoding determines what patterns the pit wall can detect. The pretraining objective determines what it learns — predicting the future (CLM) or understanding context (MLM). The scaling laws tell us how to allocate our compute budget — just as an F1 team must allocate its cost cap between car development, testing, and race operations. Every downstream capability — from following instructions to writing code — depends on getting these foundations right.

---

## Next Steps

Now that we understand how language models are trained from scratch — building the telemetry codebook (tokenization), teaching the model to predict the next event (pretraining), and understanding the budget tradeoffs (scaling laws) — the next question is: how do we make these models *fast* in production? In **Notebook 30: Inference Optimization**, we'll explore the techniques that take a model from the simulation farm to the live pit wall — quantization, KV caching, speculative decoding, and more.