# Model Comparison: The Evolution

*From word salad to... hopefully something better.*

---

This notebook loads all trained models from the Lil Transformy sequence and compares their outputs on the same prompts. Run it whenever you want to see how far we've come.

**Models in the main sequence:**
- 03: Bigram (embed → unembed)
- 04: + Attention
- 05: + Positional encoding
- 06: + FFN
- 07: + Residuals & LayerNorm ← THE FISH HAS A SPINE
- 08: + Stacked blocks (coming soon)
- 09: + Multi-head attention (coming soon)

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pathlib import Path
import json
import math

# Reproducibility for generation
torch.manual_seed(42)

# Device
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

Using device: mps


In [26]:
from transformers import GPT2TokenizerFast

class LilTokenizer:
    """Compact tokenizer for Lil Transformy."""
    
    def __init__(self, gpt2_to_compact, compact_to_gpt2, vocab_size):
        self.gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.gpt2_to_compact = gpt2_to_compact
        self.compact_to_gpt2 = compact_to_gpt2
        self.vocab_size = vocab_size
        self.pad_id = 0
        self.unk_id = 1
        self.eos_id = 2
    
    def encode(self, text, add_eos=True):
        gpt2_tokens = self.gpt2_tokenizer.encode(text)
        compact_tokens = [self.gpt2_to_compact.get(t, self.unk_id) for t in gpt2_tokens]
        if add_eos:
            compact_tokens.append(self.eos_id)
        return compact_tokens
    
    def decode(self, token_ids):
        gpt2_tokens = []
        for tid in token_ids:
            if tid in [self.pad_id, self.unk_id, self.eos_id]:
                continue
            if tid in self.compact_to_gpt2:
                gpt2_tokens.append(self.compact_to_gpt2[tid])
        return self.gpt2_tokenizer.decode(gpt2_tokens)
    
    def __len__(self):
        return self.vocab_size
    
    @classmethod
    def load(cls, path):
        with open(path, 'r') as f:
            config = json.load(f)
        gpt2_to_compact = {int(k): v for k, v in config['gpt2_to_compact'].items()}
        compact_to_gpt2 = {int(k): v for k, v in config['compact_to_gpt2'].items()}
        return cls(gpt2_to_compact, compact_to_gpt2, config['vocab_size'])


tokenizer = LilTokenizer.load('tokenizer/tokenizer.json')
VOCAB_SIZE = len(tokenizer)
print(f"Vocabulary size: {VOCAB_SIZE:,}")

Vocabulary size: 4,096


---

## Model Definitions

We need to define each model architecture to load the weights. This will grow as we add more notebooks.

In [27]:
# === 03: Bigram ===

class BigramLM(nn.Module):
    """Notebook 03: Simplest autoregressive model. Each position predicts next from itself only."""
    
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.unembed = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        return self.unembed(self.embedding(x))
    
    def generate(self, prompt_tokens, max_new_tokens=50, temperature=1.0):
        self.eval()
        tokens = list(prompt_tokens)
        generated = []
        
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Bigram only looks at last token
                x = torch.tensor([[tokens[-1]]], device=next(self.parameters()).device)
                logits = self.forward(x)
                probs = F.softmax(logits[0, 0] / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()
                
                generated.append(next_token)
                tokens.append(next_token)
                
                if next_token == 2:  # EOS
                    break
        
        return generated

In [28]:
# === 04: Attention ===

class CausalSelfAttention(nn.Module):
    """Single-head causal self-attention."""
    
    def __init__(self, d_model, max_seq_len=256):
        super().__init__()
        self.d_model = d_model
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        
        mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        self.register_buffer('mask', mask)
        self.scale = math.sqrt(d_model)
    
    def forward(self, x):
        B, T, C = x.shape
        Q, K, V = self.W_q(x), self.W_k(x), self.W_v(x)
        scores = (Q @ K.transpose(-2, -1)) / self.scale
        scores = scores.masked_fill(self.mask[:T, :T], float('-inf'))
        attn = F.softmax(scores, dim=-1)
        return self.W_o(attn @ V)


class AttentionLM(nn.Module):
    """Notebook 04: Bigram + single-head attention."""
    
    def __init__(self, vocab_size, d_model, max_seq_len=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.attention = CausalSelfAttention(d_model, max_seq_len)
        self.unembed = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        return self.unembed(self.attention(self.embedding(x)))
    
    def generate(self, prompt_tokens, max_new_tokens=50, temperature=1.0):
        self.eval()
        tokens = list(prompt_tokens)
        generated = []
        
        with torch.no_grad():
            for _ in range(max_new_tokens):
                context = tokens[-256:]
                x = torch.tensor([context], device=next(self.parameters()).device)
                logits = self.forward(x)
                probs = F.softmax(logits[0, -1] / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()
                
                generated.append(next_token)
                tokens.append(next_token)
                
                if next_token == 2:
                    break
        
        return generated


# === 05: Attention + Position ===

class PositionalAttentionLM(nn.Module):
    """Notebook 05: Attention + learned positional embeddings."""
    
    def __init__(self, vocab_size, d_model, max_seq_len=256):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.attention = CausalSelfAttention(d_model, max_seq_len)
        self.unembed = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        B, T = x.shape
        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(torch.arange(T, device=x.device))
        return self.unembed(self.attention(tok_emb + pos_emb))
    
    def generate(self, prompt_tokens, max_new_tokens=50, temperature=1.0):
        self.eval()
        tokens = list(prompt_tokens)
        generated = []
        
        with torch.no_grad():
            for _ in range(max_new_tokens):
                context = tokens[-self.max_seq_len:]
                x = torch.tensor([context], device=next(self.parameters()).device)
                logits = self.forward(x)
                probs = F.softmax(logits[0, -1] / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()
                
                generated.append(next_token)
                tokens.append(next_token)
                
                if next_token == 2:
                    break
        
        return generated


# === 06: Attention + Position + FFN ===

class FeedForward(nn.Module):
    """Position-wise feedforward network."""
    
    def __init__(self, d_model, d_ff=None):
        super().__init__()
        if d_ff is None:
            d_ff = 4 * d_model
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
    
    def forward(self, x):
        return self.linear2(F.relu(self.linear1(x)))


class AttentionFFNLM(nn.Module):
    """Notebook 06: Attention + Position + FFN."""
    
    def __init__(self, vocab_size, d_model, d_ff=None, max_seq_len=256):
        super().__init__()
        self.max_seq_len = max_seq_len
        if d_ff is None:
            d_ff = 4 * d_model
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.attention = CausalSelfAttention(d_model, max_seq_len)
        self.ffn = FeedForward(d_model, d_ff)
        self.unembed = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        B, T = x.shape
        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(torch.arange(T, device=x.device))
        attended = self.attention(tok_emb + pos_emb)
        processed = self.ffn(attended)
        return self.unembed(processed)
    
    def generate(self, prompt_tokens, max_new_tokens=50, temperature=1.0):
        self.eval()
        tokens = list(prompt_tokens)
        generated = []
        
        with torch.no_grad():
            for _ in range(max_new_tokens):
                context = tokens[-self.max_seq_len:]
                x = torch.tensor([context], device=next(self.parameters()).device)
                logits = self.forward(x)
                probs = F.softmax(logits[0, -1] / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()
                
                generated.append(next_token)
                tokens.append(next_token)
                
                if next_token == 2:
                    break
        
        return generated


# === 07: Transformer Block (Residuals + LayerNorm) ===

class TransformerBlock(nn.Module):
    """A single transformer block with pre-norm architecture."""
    
    def __init__(self, d_model, d_ff=None, max_seq_len=256):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attention = CausalSelfAttention(d_model, max_seq_len)
        self.ffn = FeedForward(d_model, d_ff)
    
    def forward(self, x):
        x = x + self.attention(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class TransformerLM(nn.Module):
    """Notebook 07: Proper transformer block with residuals and LayerNorm."""
    
    def __init__(self, vocab_size, d_model, d_ff=None, max_seq_len=256):
        super().__init__()
        self.max_seq_len = max_seq_len
        if d_ff is None:
            d_ff = 4 * d_model
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        self.block = TransformerBlock(d_model, d_ff, max_seq_len)
        self.ln_final = nn.LayerNorm(d_model)
        self.unembed = nn.Linear(d_model, vocab_size)
    
    def forward(self, x):
        B, T = x.shape
        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(torch.arange(T, device=x.device))
        x = tok_emb + pos_emb
        x = self.block(x)
        x = self.ln_final(x)
        return self.unembed(x)
    
    def generate(self, prompt_tokens, max_new_tokens=50, temperature=1.0):
        self.eval()
        tokens = list(prompt_tokens)
        generated = []
        
        with torch.no_grad():
            for _ in range(max_new_tokens):
                context = tokens[-self.max_seq_len:]
                x = torch.tensor([context], device=next(self.parameters()).device)
                logits = self.forward(x)
                probs = F.softmax(logits[0, -1] / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()
                
                generated.append(next_token)
                tokens.append(next_token)
                
                if next_token == 2:
                    break
        
        return generated

In [29]:
# === Future models will be added here ===
#
# 08: StackedTransformerLM
# 09: MultiHeadTransformerLM
#
# Each will be added as we build the notebooks.

print("Model architectures defined.")

Model architectures defined.


---

## Load Available Models

Check which checkpoints exist and load them.

In [30]:
# Registry of models: (checkpoint_file, model_class, display_name, extra_kwargs)
MODEL_REGISTRY = [
    ('03_bigram.pt', BigramLM, '03: Bigram', {}),
    ('04_attention.pt', AttentionLM, '04: + Attention (3 ep)', {}),
    ('05_positional.pt', PositionalAttentionLM, '05: + Position (1 ep)', {'max_seq_len': 256}),
    ('06_ffn.pt', AttentionFFNLM, '06: + FFN (1 ep)', {'max_seq_len': 256}),
    ('07_transformer_block.pt', TransformerLM, '07: + Residual (1 ep)', {'max_seq_len': 256}),
    # Future:
    # ('08_stacked.pt', ..., '08: Stacked'),
    # ('09_multihead.pt', ..., '09: Multi-Head'),
]

models = {}
stats = {}

print("Loading models...")
print("=" * 60)

for checkpoint_file, model_class, name, extra_kwargs in MODEL_REGISTRY:
    path = Path(checkpoint_file)
    if path.exists():
        checkpoint = torch.load(path, map_location='cpu', weights_only=False)
        
        # Handle d_ff for FFN models
        kwargs = {'vocab_size': checkpoint['vocab_size'], 'd_model': checkpoint['d_model']}
        kwargs.update(extra_kwargs)
        if 'd_ff' in checkpoint:
            kwargs['d_ff'] = checkpoint['d_ff']
        
        # Create model
        model = model_class(**kwargs)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
        
        models[name] = model
        stats[name] = {
            'params': sum(p.numel() for p in model.parameters()),
            'final_ppl': checkpoint['history']['val_perplexity'][-1]
        }
        
        print(f"✓ {name}")
        print(f"    Parameters: {stats[name]['params']:,}")
        print(f"    Final perplexity: {stats[name]['final_ppl']:.1f}")
    else:
        print(f"✗ {name} (not found: {checkpoint_file})")

print()
print(f"Loaded {len(models)} models.")

Loading models...
✓ 03: Bigram
    Parameters: 1,052,672
    Final perplexity: 35.8
✓ 04: + Attention (3 ep)
    Parameters: 1,118,208
    Final perplexity: 25.0
✓ 05: + Position (1 ep)
    Parameters: 1,150,976
    Final perplexity: 17.7
✓ 06: + FFN (1 ep)
    Parameters: 1,282,688
    Final perplexity: 13.4
✓ 07: + Residual (1 ep)
    Parameters: 1,283,456
    Final perplexity: 10.9

Loaded 5 models.


---

## Side-by-Side Generation

Give all models the same prompt, see what they produce.

In [31]:
def compare_generations(prompt, max_tokens=50, temperature=1.0, seed=None):
    """
    Generate from all loaded models with the same prompt.

    If seed is provided, it's combined with a hash of the prompt so that:
    - Same prompt + same seed = reproducible across runs
    - Different prompts + same seed = different outputs (not stuck in attractors)
    - All models for a given prompt get the same seed (fair comparison)
    """
    prompt_tokens = tokenizer.encode(prompt, add_eos=False)

    print("=" * 70)
    print(f"PROMPT: {prompt}")
    print("=" * 70)

    # Combine seed with prompt hash so different prompts get different randomness
    if seed is not None:
        prompt_seed = seed + hash(prompt) % 10000
    else:
        prompt_seed = None

    for name, model in models.items():
        if prompt_seed is not None:
            torch.manual_seed(prompt_seed)

        generated = model.generate(prompt_tokens, max_new_tokens=max_tokens, temperature=temperature)
        text = tokenizer.decode(generated)

        ppl = stats[name]['final_ppl']
        print(f"\n{name} (ppl={ppl:.1f}):")
        print(f"  {text}")

    print()

In [32]:
# The classics
print("\n" + "#" * 70)
print("# CLASSIC PROMPTS")
print("#" * 70)

compare_generations("Once upon a time", seed=42)
compare_generations("The little girl", seed=42)
compare_generations("He was very", seed=42)


######################################################################
# CLASSIC PROMPTS
######################################################################
PROMPT: Once upon a time

03: Bigram (ppl=35.8):
  , "It's go away. The sun. They walked. The next to row of toys we can play with a cat's phone. 




"Here, let you to touch it.

But she noticed that he

04: + Attention (3 ep) (ppl=25.0):
  , there was a time-. Every morning, the comet noticed a real starfish in a little bug. There, the hole in the park and jump high. The tree?"

The grass and it felt very special village. She remembered

05: + Position (1 ep) (ppl=17.7):
  , there was a dog named Buddy. Buddy Buddy Buddy loved. Buddy Buddy Buddy Buddy loved eat food. Buddy loved horses and eat yummy.
 One day, he had found lots of wild to touch it.

As she noticed that he

06: + FFN (1 ep) (ppl=13.4):
  , there were two friends named Lily. One day, they decided to explore the forest to eat food.

"Let's go and play together!"

In [33]:
# Story starters
print("\n" + "#" * 70)
print("# STORY STARTERS")
print("#" * 70)

compare_generations("Once upon a time there was a little girl named Lily. She", seed=42)
compare_generations("The big dog and the small cat were", seed=42)
compare_generations("One sunny day, the children went to the park to", seed=42)


######################################################################
# STORY STARTERS
######################################################################
PROMPT: Once upon a time there was a little girl named Lily. She

03: Bigram (ppl=35.8):
   was very angry.
Sally agreed.

"Mom put the baby found a jungle with your cars again. They see a.
The bird is shiny and ran to find it flew away. He went when he got ready to the

04: + Attention (3 ep) (ppl=25.0):
   loved very creative girl. One day Sarah asked her dad, "Why you are sad and told her cat to visit her not afraid to go home now they all the dream?" she asked her, so so she asked her when it got to bake the

05: + Position (1 ep) (ppl=17.7):
   loved to write. She loved her. One day, the world was a baby when she went with a towel everywhere she went for a.
 away next morning?" she asked her mom.

"Mom said, "We're sorry

06: + FFN (1 ep) (ppl=13.4):
   loved to play outside. One day, she saw a huge mountain near her house.

In [34]:
# More challenging - requires context
print("\n" + "#" * 70)
print("# CONTEXT-DEPENDENT PROMPTS")
print("#" * 70)

# These require remembering earlier context
compare_generations("Lily had a red ball. She loved to play with her", seed=42)
compare_generations("The boy was sad because his toy was broken. His mom said", seed=42)


######################################################################
# CONTEXT-DEPENDENT PROMPTS
######################################################################
PROMPT: Lily had a red ball. She loved to play with her

03: Bigram (ppl=35.8):
   imagination and ran back to throw things. They liked to move really fast and started to a notebook was sad because it. She knew it."
Joe said. She was a big red feather right. She stirred.
And that?"



04: + Attention (3 ep) (ppl=25.0):
   imagination and ran back to throw things. They ran to their heads. Their cries too.
Lily and hugged her shoe and showed them her brother, but she was very much. She always are safe.



One day,

05: + Position (1 ep) (ppl=17.7):
   friends and ran back to the ball until he finally got tired really.

06: + FFN (1 ep) (ppl=13.4):
   toys and ran back to their toys. They liked to play with her toys and books.

One day, they saw a big dog playing with the things. They wanted some people into the. They na

---

## Summary Statistics

In [35]:
print("=" * 60)
print("MODEL COMPARISON")
print("=" * 60)
print()
print(f"{'Model':<25} {'Parameters':<15} {'Perplexity':<12}")
print("-" * 52)

for name in models.keys():
    params = stats[name]['params']
    ppl = stats[name]['final_ppl']
    print(f"{name:<25} {params:>12,}   {ppl:>8.1f}")

print()
print("Lower perplexity = less surprised by correct answer = better.")

MODEL COMPARISON

Model                     Parameters      Perplexity  
----------------------------------------------------
03: Bigram                   1,052,672       35.8
04: + Attention (3 ep)       1,118,208       25.0
05: + Position (1 ep)        1,150,976       17.7
06: + FFN (1 ep)             1,282,688       13.4
07: + Residual (1 ep)        1,283,456       10.9

Lower perplexity = less surprised by correct answer = better.


---

## Try Your Own Prompts

Modify the cell below to test whatever you want.

In [36]:
# Your prompt here!
compare_generations(
    "Once upon a time there was a little girl who lived in a magical forest. She",
    max_tokens=60,
    temperature=1.0,
    seed=42  # Set to None for random each time
)

PROMPT: Once upon a time there was a little girl who lived in a magical forest. She

03: Bigram (ppl=35.8):
   wanted to break their! She met a toy car and loud noise. The fireman. All of the joy! You can do not disturbed Sally was too strong house. They shouted. His mommy. She cried.


The bird waved her and not want my toys, Lily is

04: + Attention (3 ep) (ppl=25.0):
   was amazed.

05: + Position (1 ep) (ppl=17.7):
   loved to spin around in the forest with joy. Every morning, she even in the forest, but mom the woods today when she met a great wonderful day, the forest found a sweet thing. The bird was very tasty, a fancy bird birdcage and told her that her message's mom

06: + FFN (1 ep) (ppl=13.4):
   wanted to find her way. She met a toy they would take a new home. On the way to the biggest scale, and pebbles had wonderful day!

So Sarah's mom took her to a. When Grace's mom saw that her and told her it was the perfect cup

07: + Residual (1 ep) (ppl=10.9):
   wanted to find so