# BPE Language Model Comparison


In [1]:
import math
import os
import time
import copy
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from bpeasy.tokenizer import BPEasyTokenizer


In [2]:
def _pick_device() -> str:
    """Return the preferred torch device string: "mps" if usable, else "cuda", else "cpu"."""
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        return "mps"
    return "cuda" if torch.cuda.is_available() else "cpu"


device = _pick_device()

# Seed the global torch RNG; CUDA generators are additionally seeded when CUDA is used.
torch.manual_seed(42)
if device == "cuda":
    torch.cuda.manual_seed_all(42)

print(f"Using device: {device}")


Using device: mps


In [3]:
# Read the whole corpus as UTF-8 text and show a short preview.
data_path = Path("input.txt")
with data_path.open(encoding="utf-8") as corpus_file:
    raw_text = corpus_file.read()

print(f"Total characters: {len(raw_text):,}")
print(raw_text[:500])


Total characters: 1,115,393
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


In [None]:
# Learn a 10k-entry BPE vocabulary from the corpus, passed in as a single document.
# NOTE(review): the tokenizer is trained on all of raw_text, including the tail that
# later becomes the validation split, so validation text influences the learned merges.
tokenizer_path = Path("tokenizer_bpeasy.json")
tokenizer = BPEasyTokenizer.train(
    iterator=iter([raw_text]),
    vocab_size=10000,
    max_token_length=256,
    name="tiny-bpe",
)
tokenizer.save(tokenizer_path.as_posix())
print(f"Tokenizer saved to {tokenizer_path}")

vocab_size = len(tokenizer)
print(f"Vocabulary size: {vocab_size}")

# Encode once, then split the token stream: the first 90% trains, the rest validates.
encoded = tokenizer.encode(raw_text)
data = torch.tensor(encoded, dtype=torch.long)
train_ratio = 0.9
n = int(train_ratio * len(data))
train_data, val_data = data[:n], data[n:]

total_tokens = len(data)
print(f"Total tokens: {total_tokens:,}")
print(f"Train tokens: {len(train_data):,}, Val tokens: {len(val_data):,}")


Tokenizer saved to tokenizer_bpeasy.json
Vocabulary size: 10000
Total tokens: 277,628
Train tokens: 249,865, Val tokens: 27,763


In [None]:
# Hyperparameters (shared by both models).
batch_size = 48          # sequences per optimisation step
block_size = 192         # context length in tokens; also the Transformer's position-table size
max_iters = 300          # training steps per model
eval_interval = 100      # evaluate every N steps (train_model also evaluates at step 1)
eval_iters = 20          # batches averaged per split in estimate_loss
learning_rate = 3e-4     # AdamW learning rate
grad_clip = 1.0          # max global gradient norm; None disables clipping
patience = 6             # consecutive non-improving evaluations before early stopping
min_delta = 1e-3         # minimum val-loss decrease that counts as an improvement
n_embd = 192             # embedding width (both models)
n_head = 6               # attention heads per Transformer block
n_layer = 3              # Transformer blocks, and also the number of LSTM layers
dropout = 0.2

# RNN
hidden_size = 192        # LSTM hidden state width

# Optional environment overrides for quick smoke runs; they can only shrink the budgets above.
env_max_iters = os.environ.get("LM_MAX_ITERS")
if env_max_iters:
    max_iters = min(int(env_max_iters), max_iters)
    # Never let the interval reach 0 (e.g. LM_MAX_ITERS=0): `step % eval_interval`
    # in train_model requires a positive interval.
    eval_interval = max(1, min(eval_interval, max_iters))

env_eval_iters = os.environ.get("LM_EVAL_ITERS")
if env_eval_iters:
    # At least one batch: with zero batches the per-split mean in estimate_loss is NaN,
    # and a negative count cannot allocate the loss buffer.
    eval_iters = max(1, min(int(env_eval_iters), eval_iters))

def get_batch(split: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample a random batch of (input, next-token target) windows from a split.

    Args:
        split: "train" selects ``train_data``; any other value selects ``val_data``.

    Returns:
        ``(x, y)``, each ``(batch_size, block_size)`` on ``device``, where ``y`` is
        ``x`` shifted one token to the right.

    Raises:
        ValueError: if the selected split has fewer than ``block_size + 1`` tokens.
    """
    data_source = train_data if split == "train" else val_data
    # A window needs block_size inputs plus one extra token for the last target, so the
    # valid starts are 0 .. len - block_size - 1 inclusive (randint's `high` is exclusive).
    num_starts = len(data_source) - block_size
    if num_starts < 1:
        raise ValueError(
            f"split {split!r} has {len(data_source)} tokens; need at least {block_size + 1}"
        )
    idx = torch.randint(0, num_starts, (batch_size,))
    x = torch.stack([data_source[i:i + block_size] for i in idx])
    y = torch.stack([data_source[i + 1:i + block_size + 1] for i in idx])
    return x.to(device), y.to(device)


@torch.no_grad()
def estimate_loss(model: nn.Module, num_batches: int = eval_iters) -> dict[str, float]:
    """Estimate the mean loss on the train and val splits with dropout disabled.

    Args:
        model: a module whose ``forward(x, y)`` returns ``(logits, loss)``.
        num_batches: random batches (drawn with ``get_batch``) averaged per split.

    Returns:
        ``{"train": mean_loss, "val": mean_loss}``.

    The model is switched to eval mode for the estimate and returned to whichever
    train/eval mode it was in before the call, even if evaluation raises.
    """
    was_training = model.training
    model.eval()
    losses: dict[str, float] = {}
    try:
        for split in ("train", "val"):
            split_losses = torch.zeros(num_batches)
            for k in range(num_batches):
                xb, yb = get_batch(split)
                _, loss = model(xb, yb)
                split_losses[k] = loss.item()
            losses[split] = split_losses.mean().item()
    finally:
        model.train(was_training)
    return losses


def loss_to_perplexity(loss: float) -> float:
    """Convert a mean cross-entropy loss (natural log) to perplexity, ``exp(loss)``.

    Returns ``inf`` instead of raising ``OverflowError`` for very large losses
    (e.g. from a diverged run), so reporting code can still format the value.
    """
    try:
        return math.exp(loss)
    except OverflowError:
        return math.inf


In [6]:
class RNNLanguageModel(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        n_embd: int = 256,
        hidden_size: int = 512,
        num_layers: int = 2,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, n_embd)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(
            n_embd,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.proj = nn.Linear(hidden_size, vocab_size)

    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None):
        x = self.embedding(idx)
        x = self.dropout(x)
        output, _ = self.lstm(x)
        output = self.dropout(output)
        logits = self.proj(output)
        loss = None
        if targets is not None:
            logits_flat = logits.view(-1, self.vocab_size)
            targets_flat = targets.view(-1)
            loss = F.cross_entropy(logits_flat, targets_flat)
        return logits, loss

    def generate(self, idx: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        idx = idx.to(next(self.parameters()).device)
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits_last = logits[:, -1, :]
            probs = F.softmax(logits_last, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
        return idx


In [7]:
class MaskedSelfAttentionHead(nn.Module):
    def __init__(self, n_embd: int, head_size: int, dropout: float) -> None:
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * (k.size(-1) ** -0.5)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out


class MultiHeadAttention(nn.Module):
    """Parallel causal attention heads, concatenated and projected back to ``n_embd``.

    Args:
        n_embd: input and output feature width.
        num_heads: number of heads; each head is ``n_embd // num_heads`` wide.
        dropout: dropout probability for the attention weights and the output.
    """

    def __init__(self, n_embd: int, num_heads: int, dropout: float) -> None:
        super().__init__()
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList(
            [MaskedSelfAttentionHead(n_embd, head_size, dropout) for _ in range(num_heads)]
        )
        # The concatenated heads are head_size * num_heads wide, which is smaller than
        # n_embd when n_embd is not divisible by num_heads; size the projection to match.
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``(B, T, n_embd)`` to an output of the same shape."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))


class FeedForward(nn.Module):
    """Position-wise MLP: widen to 4x, GELU, project back to ``n_embd``, then dropout."""

    def __init__(self, n_embd: int, dropout: float) -> None:
        super().__init__()
        inner_width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, inner_width),
            nn.GELU(),
            nn.Linear(inner_width, n_embd),
            nn.Dropout(dropout),
        ]
        # Kept as a Sequential named `net` so state_dict keys stay `net.0.*` / `net.2.*`.
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.net(x)
        return out


class TransformerBlock(nn.Module):
    """Pre-LayerNorm Transformer block: residual self-attention, then residual MLP."""

    def __init__(self, n_embd: int, n_head: int, dropout: float) -> None:
        super().__init__()
        # Registration order (sa, ff, ln1, ln2) is kept so random init and state_dict match.
        self.sa = MultiHeadAttention(n_embd, n_head, dropout)
        self.ff = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attended = x + self.sa(self.ln1(x))
        return attended + self.ff(self.ln2(attended))


class TransformerLanguageModel(nn.Module):
    """Decoder-only Transformer language model with learned position embeddings.

    Token and position embeddings are summed, passed through ``n_layer``
    ``TransformerBlock``s, a final LayerNorm and a linear head over the vocabulary.
    The position table has ``block_size`` rows (the notebook-level constant read at
    construction), which is the longest sequence ``forward`` accepts.

    Args:
        vocab_size: number of token ids; also the width of the output logits.
        n_embd: model width.
        n_head: attention heads per block.
        n_layer: number of Transformer blocks.
        dropout: dropout probability used inside every block.
    """

    def __init__(
        self,
        vocab_size: int,
        n_embd: int = 256,
        n_head: int = 8,
        n_layer: int = 4,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[TransformerBlock(n_embd, n_head, dropout) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None):
        """Compute next-token logits and, optionally, the mean cross-entropy loss.

        Args:
            idx: ``(B, T)`` token ids with ``T`` at most the position-table size.
            targets: optional ``(B, T)`` token ids to score against.

        Returns:
            ``(logits, loss)``; logits are ``(B, T, vocab_size)`` and loss is ``None``
            when ``targets`` is not given.

        Raises:
            ValueError: if ``T`` exceeds the number of learned positions.
        """
        _, T = idx.shape
        max_len = self.position_embedding_table.num_embeddings
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds maximum context {max_len}")
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb.unsqueeze(0)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted as well.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        """Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

        Each step conditions on at most the last ``num_embeddings`` tokens of the
        position table. Sampling runs without autograd and with dropout disabled;
        the module's previous train/eval mode is restored afterwards.

        Returns:
            ``(B, T + max_new_tokens)`` token ids on the model's device.
        """
        idx = idx.to(next(self.parameters()).device)
        max_len = self.position_embedding_table.num_embeddings
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                logits, _ = self(idx[:, -max_len:])
                probs = F.softmax(logits[:, -1, :], dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, next_token), dim=1)
        finally:
            self.train(was_training)
        return idx


In [8]:
def train_model(
    model_name: str,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    max_iters: int = max_iters,
    eval_interval: int = eval_interval,
) -> dict:
    """Train ``model`` with early stopping and restore its best-validation weights.

    Each step draws one training batch via ``get_batch``, backpropagates, clips the
    global gradient norm to ``grad_clip`` (unless it is None) and steps ``optimizer``.
    The model is evaluated with ``estimate_loss`` at step 1, every ``eval_interval``
    steps, and at the final step. An evaluation counts as an improvement only if the
    val loss drops by more than ``min_delta``; after ``patience`` consecutive
    non-improving evaluations training stops early.

    Returns:
        dict with ``model`` (the same module, best weights loaded), ``history``
        (list of ``{"step", "train", "val"}`` dicts), ``best_val_loss``,
        ``best_val_perplexity``, ``final_eval`` (a fresh ``estimate_loss`` on the
        restored weights), ``duration_sec`` (wall-clock of the training loop including
        its evaluations, excluding the final one) and ``best_step`` (0 if no
        evaluation improved on the initial state).
    """
    best_val = float("inf")
    best_step = 0
    # Snapshot the starting weights so there is always a state to restore.
    best_state = copy.deepcopy(model.state_dict())
    history: list[dict[str, float]] = []
    start_time = time.perf_counter()
    steps_without_improvement = 0

    for step in range(1, max_iters + 1):
        model.train()
        xb, yb = get_batch("train")
        _, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        # Also evaluate the final step, so training done after the last eval_interval
        # boundary can still become the restored checkpoint instead of being discarded.
        if step % eval_interval == 0 or step == 1 or step == max_iters:
            losses = estimate_loss(model)
            history.append({"step": step, **losses})
            print(f"{model_name} step {step}: train {losses['train']:.3f}, val {losses['val']:.3f}")
            if losses["val"] + min_delta < best_val:
                best_val = losses["val"]
                best_step = step
                best_state = copy.deepcopy(model.state_dict())
                steps_without_improvement = 0
            else:
                steps_without_improvement += 1
            if steps_without_improvement >= patience:
                print(f"{model_name} stopping early at step {step}")
                break

    model.load_state_dict(best_state)
    duration = time.perf_counter() - start_time
    final_eval = estimate_loss(model)
    return {
        "model": model,
        "history": history,
        "best_val_loss": best_val,
        "best_val_perplexity": loss_to_perplexity(best_val),
        "final_eval": final_eval,
        "duration_sec": duration,
        "best_step": best_step,
    }


In [9]:
# The LSTM reuses the shared width, dropout and layer count (n_layer) from the config cell.
rnn_kwargs = dict(
    vocab_size=vocab_size,
    n_embd=n_embd,
    hidden_size=hidden_size,
    num_layers=n_layer,
    dropout=dropout,
)
rnn_model = RNNLanguageModel(**rnn_kwargs).to(device)
rnn_optimizer = torch.optim.AdamW(rnn_model.parameters(), lr=learning_rate)
rnn_results = train_model("LSTM", rnn_model, rnn_optimizer)


LSTM step 1: train 9.199, val 9.198
LSTM step 100: train 6.749, val 6.868
LSTM step 200: train 6.738, val 6.883
LSTM step 300: train 6.734, val 6.883


In [10]:
# The Transformer uses the same n_embd / n_layer / dropout settings given to the LSTM.
transformer_kwargs = dict(
    vocab_size=vocab_size,
    n_embd=n_embd,
    n_head=n_head,
    n_layer=n_layer,
    dropout=dropout,
)
transformer_model = TransformerLanguageModel(**transformer_kwargs).to(device)
transformer_optimizer = torch.optim.AdamW(transformer_model.parameters(), lr=learning_rate)
transformer_results = train_model("Transformer", transformer_model, transformer_optimizer)


Transformer step 1: train 9.069, val 9.062
Transformer step 100: train 6.108, val 6.255
Transformer step 200: train 5.385, val 5.793
Transformer step 300: train 5.101, val 5.681


In [11]:
results = {"LSTM": rnn_results, "Transformer": transformer_results}
for name, res in results.items():
    final_train_loss = res["final_eval"]["train"]
    final_val_loss = res["final_eval"]["val"]
    print(name)
    print(f"  best val loss: {res['best_val_loss']:.3f}")
    print(f"  best val perplexity: {res['best_val_perplexity']:.2f}")
    print(f"  final train loss: {final_train_loss:.3f} (perplexity {loss_to_perplexity(final_train_loss):.2f})")
    print(f"  final val loss: {final_val_loss:.3f} (perplexity {loss_to_perplexity(final_val_loss):.2f})")
    print(f"  training time: {res['duration_sec']:.1f}s")


LSTM
  best val loss: 6.868
  best val perplexity: 961.11
  final train loss: 6.754 (perplexity 857.17)
  final val loss: 6.864 (perplexity 957.16)
  training time: 48.7s
Transformer
  best val loss: 5.681
  best val perplexity: 293.16
  final train loss: 5.119 (perplexity 167.24)
  final val loss: 5.690 (perplexity 295.76)
  training time: 62.0s


In [16]:
def generate_text(model: nn.Module, prompt: str, max_new_tokens: int = 200) -> str:
    """Sample a continuation of ``prompt`` from ``model`` and decode it to text.

    The prompt is encoded with the notebook's ``tokenizer``; if it encodes to no
    tokens, the encoding of a single space is used instead. Sampling runs in eval
    mode (dropout off) without autograd, and the model's previous train/eval mode
    is restored afterwards.

    Returns:
        The decoded first row returned by ``model.generate`` (prompt tokens followed
        by the sampled tokens).

    Raises:
        ValueError: if neither ``prompt`` nor the fallback encodes to any tokens.
    """
    tokens = tokenizer.encode(prompt)
    if not tokens:
        tokens = tokenizer.encode(" ")
    if not tokens:
        raise ValueError("prompt and fallback ' ' both encoded to an empty token list")
    context = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated = model.generate(context, max_new_tokens=max_new_tokens)
    finally:
        model.train(was_training)
    return tokenizer.decode(generated[0].tolist())


sample_runs = (("LSTM", rnn_results), ("Transformer", transformer_results))
for run_index, (sample_label, sample_results) in enumerate(sample_runs):
    if run_index > 0:
        print()  # blank line between the two samples
    print(f"{sample_label} sample:")
    print(generate_text(sample_results["model"], "The"))


LSTM sample:
The anyHePETRUCHIO dutyical again of chamberAw show serpent present heigh, brother a ratsworth news a particularThat of as ' a weAUTOLYCUS mayin, mother,
 soul all my there feet in his, and thou ins:
ItAw,
 III justice allbrs I in love,
-nightAn may,
's why have have since, to;'less liege ask, thereTheseIf spoke old Juliet Duke
, you enemy I.

ant is youBelike wounds!
 my curse beCarry I profess burningit for ne committed not, disp dishonour ' night wife itself to what man,:
 and andled so him and:
:
 se,:
 to show:
 fear gate lost some a die it those hear giantFirst or means that'll your days unto cure by:
 and heavily sweeter sir isSpeak the the.
.

This world.
 rec stri mothercued.

 ch embassFlFLORIZEL the,Think
 two manKING.


 hurlWith it usSinceuck we

Transformer sample:
The dog take the proclamation my till every thing,
With lady Musician:
Grathough who Montague, whose service anyers
 ginger donesthy of me learn their approvedven- than proud;
Hesemble former friar

In [17]:
summary = []
for name, res in results.items():
    final_losses = res["final_eval"]
    summary.append(
        dict(
            model=name,
            best_val_loss=res["best_val_loss"],
            best_val_perplexity=res["best_val_perplexity"],
            train_perplexity=loss_to_perplexity(final_losses["train"]),
            val_perplexity=loss_to_perplexity(final_losses["val"]),
            time_sec=res["duration_sec"],
        )
    )

# Rank from best (lowest final val perplexity) to worst; the sort is stable.
summary = sorted(summary, key=lambda row: row["val_perplexity"])
for row in summary:
    print(
        f"{row['model']}: val perplexity {row['val_perplexity']:.2f}, "
        f"train perplexity {row['train_perplexity']:.2f}, time {row['time_sec']:.1f}s"
    )

best, worst = summary[0], summary[-1]
print()
print(f"{best['model']} achieved the lowest validation perplexity at {best['val_perplexity']:.2f}.")
if best["model"] != worst["model"]:
    gap = worst["val_perplexity"] - best["val_perplexity"]
    # Improvement expressed relative to the worse model's perplexity.
    rel = gap / worst["val_perplexity"] * 100
    print(f"That is {gap:.2f} lower than {worst['model']} ({rel:.1f}% relative improvement).")
print("Compare the generated samples to judge qualitative fluency.")


Transformer: val perplexity 295.76, train perplexity 167.24, time 62.0s
LSTM: val perplexity 957.16, train perplexity 857.17, time 48.7s

Transformer achieved the lowest validation perplexity at 295.76.
That is 661.40 lower than LSTM (69.1% relative improvement).
Compare the generated samples to judge qualitative fluency.
