# Lesson 3 — N-gram Language Models & Cross-Entropy
**Goal:** Train unigram, bigram, and trigram language models on your custom corpus; generate sample text; and measure how surprised the model is using cross-entropy.

**Where we’re headed**
- An **N-gram** model predicts the next token by looking at the previous `N-1` tokens. Example: a trigram (N=3) predicts the next word using the last two words.
- By counting how often sequences appear in your text, you can estimate probabilities like `P("spaceship" | "the", "silver")`.
- Once you know those probabilities, you can both *sample* new text and *score* how likely a sentence is.
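Here is the counting idea on a tiny made-up corpus (not from your data files). Without smoothing, the maximum-likelihood estimate of `P("spaceship" | "the", "silver")` is the number of times `the silver spaceship` appears divided by the number of times `the silver` appears. The names `words`, `trigrams`, and `bigrams` are illustrative only.

```python
from collections import Counter

# Toy corpus (invented for illustration), already split into word tokens.
words = "the silver spaceship landed . the silver spaceship rose . the silver moon glowed .".split()

# Count every 3-word window (trigram) and every 2-word window (its context).
trigrams = Counter(zip(words, words[1:], words[2:]))
bigrams = Counter(zip(words, words[1:]))

# Maximum-likelihood estimate: count(context + word) / count(context)
p = trigrams[("the", "silver", "spaceship")] / bigrams[("the", "silver")]
print(p)  # 2 of the 3 "the silver" contexts continue with "spaceship"
```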

**Vocabulary check**
- **Unigram / Bigram / Trigram:** Models that use 0, 1, or 2 previous tokens respectively.
- **Conditional probability:** The chance of event A given event B already happened, written `P(A | B)`.
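- **Add-k smoothing:** A trick for avoiding zero probabilities: pretend every possible N-gram was seen `k` extra times (often a fraction such as 0.1), so sequences you never observed still get a small probability.
- **Cross-entropy / Perplexity:** Cross-entropy is the average number of bits of surprise per token on real text; perplexity is `2 ** cross-entropy`, roughly how many equally likely choices the model is torn between. Lower is better for both because the model is less surprised.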
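A small worked example of both definitions, using invented numbers. The helper `add_k_prob` is a hypothetical name that simply spells out the add-k formula `(count + k) / (context count + k * V)`, and the three probabilities fed into the cross-entropy are made up for illustration.

```python
import math

def add_k_prob(count_ngram, count_context, vocab_size, k):
    """Add-k estimate: (count + k) / (context count + k * V)."""
    return (count_ngram + k) / (count_context + k * vocab_size)

# Suppose a context was seen 10 times, the vocabulary has 1,000 words, and k = 0.1.
seen = add_k_prob(4, 10, 1000, k=0.1)    # a word that followed the context 4 times
unseen = add_k_prob(0, 10, 1000, k=0.1)  # a word that never followed it
print(f"seen={seen:.4f}, unseen={unseen:.5f}")

# Cross-entropy = average of -log2 p over the tokens; perplexity = 2 ** H.
probs = [0.5, 0.25, 0.125]  # made-up model probabilities for 3 tokens
H = sum(-math.log2(p) for p in probs) / len(probs)
print(H, 2 ** H)  # 2.0 bits, perplexity 4.0
```

A perplexity of 4 means the model was, on average, as unsure as if it were picking uniformly among 4 tokens. Notice also how small the unseen probability is: with a large vocabulary, even `k = 0.1` gives most of the mass in `k * V` to words the model never saw in that context.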

**Step-by-step roadmap**
1. **Prepare the corpus.** Reuse your tokenizer from Lesson 1 or keep things word-based for easier counting.
2. **Count N-grams.** Use Python dictionaries or `collections.Counter` to count how often each context→next-token pair appears.
3. **Turn counts into probabilities.** Divide counts by the total count for that context, optionally applying add-k smoothing.
4. **Sample text.** Starting from a seed context (ideally a special start token; the code below simply picks random words), repeatedly draw the next token from the probabilities you learned.
5. **Evaluate with cross-entropy.** Feed in held-out text the model never trained on, look up the probability of each token given its context, and average the negative base-2 log probabilities.
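The lesson's sampling code starts from random words rather than a start token. If you want step 4 to work as described, one option is to pad each sentence with start and end markers before counting. The sketch below assumes your corpus is already split into sentences (lists of words) and uses `<s>` / `</s>` as hypothetical marker strings. It also stores counts as a context → `Counter` table, an alternative to the flat tuple counts used later in the lesson.

```python
from collections import Counter, defaultdict

START, END = "<s>", "</s>"  # hypothetical marker tokens; any unused strings work

def pad(sentence, n):
    """Prepend n-1 start markers and append one end marker."""
    return [START] * (n - 1) + sentence + [END]

def count_next(sentences, n):
    """Map each (n-1)-token context to a Counter of the tokens that follow it."""
    table = defaultdict(Counter)
    for sent in sentences:
        padded = pad(sent, n)
        for i in range(n - 1, len(padded)):
            context = tuple(padded[i - n + 1:i])
            table[context][padded[i]] += 1
    return table

sentences = [["the", "dragon", "breathed", "fire"], ["the", "dragon", "slept"]]
table = count_next(sentences, n=3)
print(table[(START, START)])          # what can begin a sentence
print(table[("dragon", "breathed")])  # what followed "dragon breathed"
```

With this table, sampling can always begin from the context `(START, START)` and stop when it draws `END`.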

> 🧠 Mental model: Imagine predicting the next word in a story. If you just read "The dragon breathed", you strongly expect "fire". N-gram models capture that instinct by counting real examples.

In [None]:

from pathlib import Path
import re, random, math, collections

data_dir = Path("../data")
text = ""
for fname in ["space.txt","animals.txt","minecraft.txt"]:
    text += (data_dir / fname).read_text(encoding="utf-8") + "\n"

tokens = re.findall(r"[a-zA-Z']+|[.,!?;:]", text.lower())
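To see what the tokenizer regex keeps and drops, try it on a short string. This applies the same pattern to invented example text: runs of letters and apostrophes become words, each listed punctuation mark becomes its own token, and anything else (digits, hyphens, quotation marks) is silently discarded.

```python
import re

# Same pattern as the cell above: letter/apostrophe runs, or single punctuation marks.
pattern = r"[a-zA-Z']+|[.,!?;:]"

sample_text = "The dragon's fire lit the sky, then... silence!"
print(re.findall(pattern, sample_text.lower()))

# Digits and hyphens are dropped, so "3-headed" becomes just "headed".
print(re.findall(pattern, "A 3-headed dragon".lower()))
```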


In [None]:

def ngrams(tokens, n):
    for i in range(len(tokens)-n+1):
        yield tuple(tokens[i:i+n])

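def train_ngram(tokens, n=2, k=1.0):
    """Return (prob, vocab) for an add-k smoothed n-gram model."""
    counts = collections.Counter(ngrams(tokens, n))
    # Count only contexts that have a following token (drop the last token),
    # so that P(. | context) sums to 1 over the vocabulary.
    ctx_counts = collections.Counter(ngrams(tokens[:-1], n-1)) if n>1 else None
    vocab = sorted(set(tokens))
    V = len(vocab)
    def prob(context, w):
        if n == 1:
            return (counts[(w,)] + k) / (len(tokens) + k*V)
        c = counts[context + (w,)]
        ctx = ctx_counts[context]
        return (c + k) / (ctx + k*V)
    return prob, vocab

# Train on the first 80% only; the last 20% is held out for cross-entropy below.
split = int(0.8*len(tokens))
train_tokens = tokens[:split]

unigram, V1 = train_ngram(train_tokens, 1, k=1.0)
bigram, V2 = train_ngram(train_tokens, 2, k=0.5)
trigram, V3 = train_ngram(train_tokens, 3, k=0.1)
len(V1), len(V2), len(V3)  # identical: all three share the training vocabulary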
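A quick sanity check for any add-k model: for each context, the probabilities over the whole vocabulary should sum to 1. The function `add_k_bigram` below is a standalone, hypothetical re-implementation of the bigram case (not the lesson's `train_ngram` itself), run on a six-token made-up corpus. It counts contexts from `tokens[:-1]`; if you also counted the final token as a context, the row for that token would sum to less than 1 because it never has a successor.

```python
from collections import Counter

def add_k_bigram(tokens, k):
    """Return (prob, vocab) where prob(prev, w) is the add-k estimate of P(w | prev)."""
    pairs = Counter(zip(tokens, tokens[1:]))
    ctx = Counter(tokens[:-1])  # only tokens that are followed by something
    vocab = sorted(set(tokens))
    V = len(vocab)

    def prob(prev, w):
        return (pairs[(prev, w)] + k) / (ctx[prev] + k * V)

    return prob, vocab

toy = "a b a c a b".split()
prob, vocab = add_k_bigram(toy, k=0.5)
for prev in vocab:
    total = sum(prob(prev, w) for w in vocab)
    print(prev, round(total, 10))  # each row should print 1.0
```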


In [None]:

import random

def sample(prob, vocab, n=2, max_len=30, temperature=1.0, seed=None):
    """Generate max_len tokens from an n-gram model.

    The starting context is n-1 randomly chosen vocabulary words (no start
    token is used); only the tokens generated after it are returned.
    """
    rng = random.Random(seed)  # local RNG: reproducible without touching global state
    context = tuple(rng.choice(vocab) for _ in range(n-1))  # () for a unigram
    result = []
    for _ in range(max_len):
        # Temperature scaling: p ** (1/T); T < 1 sharpens, T > 1 flattens.
        weights = [prob(context, w) ** (1.0/temperature) for w in vocab]
        # random.choices normalizes the weights, so they need not sum to 1.
        w = rng.choices(vocab, weights=weights, k=1)[0]
        result.append(w)
        # Slide the window: the last n-1 tokens become the next context.
        context = (context + (w,))[1:] if n > 1 else ()
    return " ".join(result)

print("Bigram sample:", sample(bigram, V2, n=2, temperature=0.8, seed=42))
print("Trigram sample:", sample(trigram, V3, n=3, temperature=0.9, seed=7))
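Temperature scaling in `sample` raises each probability to the power `1/T` before drawing. The standalone helper below (a hypothetical name, applied to a made-up three-token distribution) performs the same transform and renormalizes, so you can see the effect directly: `T < 1` pushes probability toward the most likely token, `T > 1` spreads it out, and `T = 1` leaves the distribution unchanged.

```python
def apply_temperature(probs, temperature):
    """Rescale a distribution with p ** (1/T), then renormalize to sum to 1."""
    scaled = [p ** (1.0 / temperature) for p in probs]
    total = sum(scaled)
    return [s / total for s in scaled]

probs = [0.6, 0.3, 0.1]  # made-up next-token probabilities
for T in (0.5, 1.0, 2.0):
    print(T, [round(p, 3) for p in apply_temperature(probs, T)])
```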


In [None]:

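def cross_entropy(prob, vocab, tokens, n):
    """Average -log2 P(token | context) over the last 20% of `tokens`.

    This is only a fair held-out score if the model never trained on that tail.
    `vocab` is unused; it is kept so the signature matches the loop below.
    """
    split = int(0.8*len(tokens))
    test = tokens[split:]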
    H = 0.0
    count = 0
    for i in range(len(test)):
        if n == 1:
            context = ()
        elif n == 2:
            if i < 1: continue
            context = (test[i-1],)
        else:
            if i < 2: continue
            context = (test[i-2], test[i-1])
        p = max(prob(context, test[i]), 1e-12)
        H += -math.log2(p)
        count += 1
    return H / max(count,1)

for n, p, V in [(1, unigram, V1),(2, bigram, V2),(3, trigram, V3)]:
    H = cross_entropy(p, V, tokens, n)
    ppl = 2**H
    print(f"{n}-gram: cross-entropy={H:.2f}, perplexity={ppl:.2f}")
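A handy baseline for reading these numbers: a model that gives every token probability `1/V` has cross-entropy `log2(V)` bits and perplexity exactly `V`. The snippet below checks this with a hypothetical vocabulary size of 50 and placeholder test tokens. Comparing your models' perplexity against `len(V1)` tells you whether each one actually beats uniform guessing on the held-out tail.

```python
import math

V = 50               # hypothetical vocabulary size
test = ["tok"] * 20  # any 20 tokens: a uniform model ignores which ones they are

# Every token gets probability 1/V, so each contributes log2(V) bits.
H = sum(-math.log2(1.0 / V) for _ in test) / len(test)
print(f"cross-entropy={H:.2f} bits, perplexity={2 ** H:.2f}")
```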


### Challenges
- **Smoothing sweeps:** Change the add-k value (0, 0.1, 1.0) and record how perplexity changes on a held-out paragraph.
- **Order comparison:** Plot perplexity for unigram vs. bigram vs. trigram. When does adding more context help? When does data sparsity hurt?
- **Creative sampling:** Try temperature scaling when sampling to make the output more or less random. Compare to deterministic "pick the max" decoding.
- **Error detective:** Find examples where the trigram model makes a weird prediction. Can you explain it from the counts it saw?
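For the *Creative sampling* challenge, one possible starting point for deterministic decoding is below: at every step it takes the single most probable next token. The hand-written table behind `toy_prob` is made up so the example runs on its own; with the lesson's models you would pass `bigram` and `V2` (or `trigram` and `V3` with a 2-token starting context) instead. Notice how quickly greedy decoding falls into a loop.

```python
def greedy_decode(prob, vocab, context, n, max_len=10):
    """Always pick the most probable next token (no randomness)."""
    result = []
    for _ in range(max_len):
        # max() keeps the first token in vocab order when several tie.
        w = max(vocab, key=lambda tok: prob(context, tok))
        result.append(w)
        context = (context + (w,))[1:] if n > 1 else ()
    return " ".join(result)

# Invented bigram probabilities, keyed by 1-token context tuples.
toy_table = {
    ("the",): {"dragon": 0.7, "cave": 0.3},
    ("dragon",): {"breathed": 0.6, "slept": 0.4},
    ("breathed",): {"fire": 0.9, "smoke": 0.1},
    ("fire",): {"the": 1.0},
}

def toy_prob(context, w):
    return toy_table.get(context, {}).get(w, 0.0)

vocab = ["the", "dragon", "cave", "breathed", "slept", "fire", "smoke"]
print(greedy_decode(toy_prob, vocab, ("the",), n=2, max_len=6))
```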