# minGPT

**Note:** The `autoreload` extension makes the interpreter reload modules every time a cell is executed, which is useful when editing the code of a module. The following cell enables the extension, installs the minGPT package from GitHub (pinned to a specific commit), and patches `model.py` to work around [minGPT issue #120](https://github.com/karpathy/minGPT/issues/120). You can now double-click on a file such as `model.py`, edit its contents, and press Ctrl+S to save it. If you then re-run the notebook cells, including those that create an object of the corresponding class, you will see the changes reflected. Note that the next cell should *only be executed once*, as running `pip install` again will overwrite your modifications to the module.

Recall that changes to files (other than the notebook itself) do not persist across Colab sessions unless you save them to your Google Drive.

In [2]:
%load_ext autoreload
%autoreload 2
# Editable install of minGPT pinned to a fixed commit. Run this cell only once:
# re-installing overwrites any local edits to the checked-out sources.
%pip install -e 'git+https://github.com/karpathy/minGPT.git@37baab71b9abea1b76ab957409a1cc2fbfba8a26#egg=mingpt'

# Fix this issue: https://github.com/karpathy/minGPT/issues/120
# Replaces line 200 of model.py (presumably the key-count assert in the pinned commit —
# TODO confirm) with an assert that ignores `.attn.bias` keys. The path assumes the Colab
# layout (/content/src/mingpt); adjust it when running elsewhere.
!sed -i '200s/.*/        assert len(keys) == len([k for k in sd if not k.endswith(".attn.bias")])/' /content/src/mingpt/mingpt/model.py


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Obtaining mingpt from git+https://github.com/karpathy/minGPT.git@37baab71b9abea1b76ab957409a1cc2fbfba8a26#egg=mingpt
  Skipping because already up-to-date.
  Installing build dependencies ... [?25ldone
[?25h  Checking if build backend supports build_editable ... [?25ldone
[?25h  Getting requirements to build editable ... [?25ldone
[?25h  Preparing editable metadata (pyproject.toml) ... [?25ldone
[?25hCollecting torch (from mingpt)
  Using cached torch-2.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting filelock (from torch->mingpt)
  Using cached filelock-3.20.2-py3-none-any.whl.metadata (2.1 kB)
Collecting typing-extensions>=4.10.0 (from torch->mingpt)
  Using cached typing_extensions-4.15.0-py3-none-any.whl.metadata (3.3 kB)
Collecting setuptools (from torch->mingpt)
  Using cached setuptools-80.9.0-py3-none-any.whl.metadata (6.6 kB)
Collecting sympy>=1.13.3 (from torch

Add the module's location to `sys.path`, the list of directories the Python interpreter searches for modules (it is initialised from `PYTHONPATH`). The previous `pip install -e` ran in a separate process, so the already-running interpreter is not aware of the newly installed location.

In [3]:
import sys

# Make the editable minGPT checkout importable from this kernel. Guard against
# appending a duplicate entry each time the cell is re-run.
MINGPT_SRC = '/content/src/mingpt'
if MINGPT_SRC not in sys.path:
    sys.path.append(MINGPT_SRC)

In [4]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.57.3-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m303.4 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers)
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting numpy>=1.17 (from transformers)
  Downloading numpy-2.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (6.6 kB)
Collecting pyyaml>=5.1 (from transformers)
  Using cached pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (2.4 kB)
Collecting regex!=2019.12.17 (from transformers)
  Downloading regex-2025.11.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.5/40.5 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting request

In [6]:
# Show which interpreter backs this kernel (to check that installs target the same environment).
import sys
print(sys.executable)


/home/bledyx/UA/master-ia/TPLN/code/lvl1/lvl2/tpln-practice2/.venv/bin/python


In [3]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from mingpt.model import GPT
from mingpt.utils import set_seed
from mingpt.bpe import BPETokenizer
set_seed(3407)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
use_mingpt = True # use minGPT or huggingface/transformers model?
model_type = 'gpt2'
# Fall back to CPU when no GPU is visible instead of failing later in `model.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# Load pretrained GPT-2 weights either into minGPT's GPT or into the HF model.
if not use_mingpt:
    model = GPT2LMHeadModel.from_pretrained(model_type)
    # reuse the EOS id as the pad id (suppresses a warning from HF generate)
    model.config.pad_token_id = model.config.eos_token_id
else:
    model = GPT.from_pretrained(model_type)

# move the weights to the target device and switch to inference (eval) mode
model.to(device)
model.eval();

number of parameters: 124.44M


In [5]:

def generate(prompt='', num_samples=10, steps=20, do_sample=True):
    """Generate and print `num_samples` continuations of `prompt`.

    Relies on the notebook globals `model`, `device`, `use_mingpt` and `model_type`.

    Args:
        prompt: Conditioning text. An empty string produces unconditional samples that
            start from the `<|endoftext|>` token.
        num_samples: Number of continuations; they are generated together as one batch.
        steps: Number of new tokens to generate for each sample.
        do_sample: If True, sample (top_k=40); otherwise take the most likely token.
    """

    # tokenize the input prompt into integer input sequence
    if use_mingpt:
        tokenizer = BPETokenizer()
        if prompt == '':
            # to create unconditional samples...
            # manually create a tensor with only the special <|endoftext|> token
            # similar to what openai's code does here https://github.com/openai/gpt-2/blob/master/src/generate_unconditional_samples.py
            # Build it directly on `device`: the model lives there, and a CPU tensor
            # would make model.generate fail with a device mismatch on GPU.
            x = torch.tensor([[tokenizer.encoder.encoder['<|endoftext|>']]], dtype=torch.long, device=device)
        else:
            x = tokenizer(prompt).to(device)
    else:
        tokenizer = GPT2Tokenizer.from_pretrained(model_type)
        if prompt == '':
            # to create unconditional samples...
            # huggingface/transformers tokenizer special cases these strings
            prompt = '<|endoftext|>'
        encoded_input = tokenizer(prompt, return_tensors='pt').to(device)
        x = encoded_input['input_ids']

    # we'll process all desired num_samples in a batch, so expand out the batch dim
    x = x.expand(num_samples, -1)

    # forward the model `steps` times to get samples, in a batch
    y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)

    for i in range(num_samples):
        out = tokenizer.decode(y[i].cpu().squeeze())
        print('-'*80)
        print(out)


In [6]:
# Sample 10 continuations of 20 new tokens each for a fixed prompt.
generate(prompt='Andrej Karpathy, the Earth representative on', num_samples=10, steps=20)

downloading https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json to /home/bledyx/.cache/mingpt/encoder.json
downloading https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe to /home/bledyx/.cache/mingpt/vocab.bpe
--------------------------------------------------------------------------------
Andrej Karpathy, the Earth representative on NASA's Juno mission, will also receive this year's award.

While his experience shows that
--------------------------------------------------------------------------------
Andrej Karpathy, the Earth representative on Russia's delegation to the G20 summit, spoke about his work to end climate change in his blog
--------------------------------------------------------------------------------
Andrej Karpathy, the Earth representative on the United Nations Security Council who was asked on Monday to take a position on climate change in order to
----------------------------------------------------------------------------

In [7]:
%%writefile repo_orientation.py
"""
Section 2 helper: repository/codebase orientation for minGPT in Colab.

What this script does:
- Prints where mingpt is installed.
- Locates mingpt/model.py.
- Extracts/prints the key lines of GPT.forward that matter for the assignment:
  embeddings -> transformer blocks -> ln_f -> lm_head -> logits
- Provides programmatic checks used by unit tests.

This does NOT implement activation caching/patching yet. It only verifies
we understand where it would go later.
"""

from __future__ import annotations

import inspect
import pathlib
import re
from dataclasses import dataclass
from typing import Dict, List, Tuple

import mingpt
import mingpt.model
from mingpt.model import GPT


@dataclass(frozen=True)
class ForwardLandmarks:
    """Which GPT.forward pipeline stages were detected in the method's source text."""

    has_tok_emb: bool  # both "tok_emb" and "wte" appear
    has_pos_emb: bool  # both "pos_emb" and "wpe" appear
    has_blocks_loop: bool  # a loop over self.transformer.h was found
    has_ln_f: bool  # "ln_f" (final layer norm) appears
    has_lm_head: bool  # both "lm_head" and "logits" appear


def get_paths() -> Dict[str, str]:
    """Return resolved file paths of the installed `mingpt` package and its `model` module.

    Keys are the literal attribute expressions "mingpt.__file__" and "mingpt.model.__file__".
    """
    sources = (("mingpt.__file__", mingpt), ("mingpt.model.__file__", mingpt.model))
    return {label: str(pathlib.Path(module.__file__).resolve()) for label, module in sources}


def read_model_source() -> str:
    """Return the full text of the installed mingpt/model.py, decoded as UTF-8."""
    source_file = pathlib.Path(mingpt.model.__file__).resolve()
    with source_file.open(encoding="utf-8") as handle:
        return handle.read()


def attn_bias_fix_present(model_source: str) -> bool:
    """Return True if `model_source` contains the key-count fix for minGPT issue #120.

    The fix makes from_pretrained's assertion ignore the `.attn.bias` entries:
        assert len(keys) == len([k for k in sd if not k.endswith(".attn.bias")])
    Matching accepts either quote style and arbitrary whitespace, so an equivalent
    hand-applied edit is recognised as well as the exact `sed` replacement.
    """
    # \1 forces the closing quote to match the opening one.
    pattern = (
        r"len\(\s*\[\s*k\s+for\s+k\s+in\s+sd\s+if\s+not\s+k\.endswith\(\s*"
        r"(['\"])\.attn\.bias\1\s*\)\s*\]\s*\)"
    )
    return re.search(pattern, model_source) is not None


def forward_source() -> str:
    """Return the source code of the installed `GPT.forward` method (via `inspect`)."""
    forward_fn = GPT.forward
    return inspect.getsource(forward_fn)


def find_forward_landmarks(src: str) -> ForwardLandmarks:
    """Detect the main GPT.forward pipeline stages in `src` (the method's source text).

    Checks are substring/regex based so they tolerate formatting differences:
    - token embedding: both "tok_emb" and "wte" appear
    - positional embedding: both "pos_emb" and "wpe" appear
    - block loop: `for <targets> in self.transformer.h` or `self.transformer['h']`,
      optionally wrapped in `enumerate(...)` (e.g. `for i, block in enumerate(...)`)
    - final layer norm: "ln_f" appears
    - LM head: both "lm_head" and "logits" appear
    """
    has_tok_emb = "tok_emb" in src and "wte" in src
    has_pos_emb = "pos_emb" in src and "wpe" in src
    # [ \t] (not \s) keeps the whole `for ... in ...` header on a single line.
    blocks_loop_re = (
        r"for[ \t]+[\w \t,()]+?[ \t]+in[ \t]+(?:enumerate\([ \t]*)?"
        r"self\.transformer(?:\.h\b|\[[ \t]*(['\"])h\1[ \t]*\])"
    )
    has_blocks_loop = re.search(blocks_loop_re, src) is not None
    has_ln_f = "ln_f" in src
    has_lm_head = "lm_head" in src and "logits" in src
    return ForwardLandmarks(
        has_tok_emb=has_tok_emb,
        has_pos_emb=has_pos_emb,
        has_blocks_loop=has_blocks_loop,
        has_ln_f=has_ln_f,
        has_lm_head=has_lm_head,
    )


def print_forward_snippet(src: str, max_lines: int = 80) -> None:
    """Print the first `max_lines` lines of `src`, numbered from 001.

    When `src` has more lines, a final line reports how many were left out.
    """
    all_lines = src.splitlines()
    shown = all_lines[:max_lines]
    print("=== GPT.forward (snippet) ===")
    for lineno, text in zip(range(1, len(shown) + 1), shown):
        print(f"{lineno:03d}: {text}")
    hidden = len(all_lines) - max_lines
    if hidden > 0:
        print(f"... ({hidden} more lines)")


def main() -> None:
    """Print install paths, the issue-#120 fix check, forward landmarks and the forward source."""
    installed = get_paths()
    print("=== Installed paths ===")
    for label, location in installed.items():
        print(f"{label}: {location}")

    fix_ok = attn_bias_fix_present(read_model_source())
    print("\n=== .attn.bias fix present? ===")
    print(fix_ok)

    forward_src = forward_source()
    found = find_forward_landmarks(forward_src)
    print("\n=== Forward pipeline landmarks ===")
    print(found)

    print()
    print_forward_snippet(forward_src)


if __name__ == "__main__":
    main()


Writing repo_orientation.py


In [8]:
# Run the orientation script in a subprocess (uses the first `python` found on PATH).
!python repo_orientation.py


=== Installed paths ===
mingpt.__file__: /home/bledyx/UA/master-ia/TPLN/code/lvl1/lvl2/tpln-practice2/.venv/src/mingpt/mingpt/__init__.py
mingpt.model.__file__: /home/bledyx/UA/master-ia/TPLN/code/lvl1/lvl2/tpln-practice2/.venv/src/mingpt/mingpt/model.py

=== .attn.bias fix present? ===
True

=== Forward pipeline landmarks ===
ForwardLandmarks(has_tok_emb=True, has_pos_emb=True, has_blocks_loop=True, has_ln_f=True, has_lm_head=True)

=== GPT.forward (snippet) ===
001:     def forward(self, idx, targets=None):
002:         device = idx.device
003:         b, t = idx.size()
004:         assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
005:         pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
006: 
007:         # forward the GPT model itself
008:         tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
009:         pos_emb = self.transformer.wpe(pos) # position

In [9]:
%%writefile generate_driver.py
"""
Section 2 driver skeleton (will be extended in Sections 3+ and especially 5–7).

Right now it only:
- loads GPT-2 small via GPT.from_pretrained('gpt2')
- tokenizes a prompt with BPETokenizer
- runs a single forward pass to confirm logits shape
- runs model.generate to confirm decoding loop works

Later, you'll add:
- control flags passed into GPT.forward (save_activations, patch params, etc.)
"""

from __future__ import annotations

import torch

from mingpt.model import GPT
from mingpt.bpe import BPETokenizer
from mingpt.utils import set_seed


def get_device() -> str:
    """Return "cuda" when torch can see a CUDA device, otherwise "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


@torch.no_grad()
def main() -> None:
    """Load GPT-2 small, check one forward pass's logits shape, and sample a short continuation."""
    set_seed(3407)

    device = get_device()
    print("Device:", device)

    model = GPT.from_pretrained("gpt2")
    model.to(device)
    model.eval()

    tokenizer = BPETokenizer()
    prompt = "Andrej Karpathy, the Earth representative on"
    input_ids = tokenizer(prompt).to(device)  # shape (1, T)

    # One forward pass without targets; logits should be (B, T, vocab).
    logits, _ = model(input_ids)
    print("Input shape:", tuple(input_ids.shape))
    print("Logits shape:", tuple(logits.shape))
    assert logits.ndim == 3, "Expected (B, T, V) logits"
    batch_ok = logits.shape[0] == input_ids.shape[0]
    assert batch_ok and logits.shape[1] == input_ids.shape[1], "B,T must match input"

    # A short sampled continuation, only to show that the decoding loop runs.
    generated = model.generate(input_ids, max_new_tokens=20, do_sample=True, top_k=40)
    text = tokenizer.decode(generated[0].cpu())
    print("\n=== Generated ===")
    print(text)


if __name__ == "__main__":
    main()


Writing generate_driver.py


In [10]:
# Run the generation smoke test in a subprocess (uses the first `python` found on PATH).
!python generate_driver.py


Device: cuda
number of parameters: 124.44M
Input shape: (1, 10)
Logits shape: (1, 10, 50257)

=== Generated ===
Andrej Karpathy, the Earth representative on NASA's Mars Exploration Rover Curiosity, talks about the success of the science rover Curiosity, which now has


In [12]:
%%writefile test_all.py
import os
import pathlib
import pytest
import torch

import mingpt
import mingpt.model
from mingpt.model import GPT

import repo_orientation as ro


def test_mingpt_importable_and_paths_exist():
    """get_paths() reports both entries, and each points to an existing file."""
    paths = ro.get_paths()
    assert "mingpt.__file__" in paths and "mingpt.model.__file__" in paths

    checks = (
        ("mingpt.__file__", "mingpt package file"),
        ("mingpt.model.__file__", "mingpt.model file"),
    )
    for key, label in checks:
        target = pathlib.Path(paths[key])
        assert target.exists(), f"{label} not found: {target}"


def test_attn_bias_fix_present_or_applied():
    """The installed model.py must contain the issue-#120 fix (ignore `.attn.bias` keys)."""
    message = (
        "Required fix not found in mingpt/model.py. "
        "Expected assert to ignore keys ending with .attn.bias."
    )
    model_text = ro.read_model_source()
    assert ro.attn_bias_fix_present(model_text), message


def test_forward_pipeline_landmarks_present():
    """GPT.forward must reference every stage: both embeddings, block loop, ln_f, lm_head."""
    landmarks = ro.find_forward_landmarks(ro.forward_source())
    expectations = (
        ("has_tok_emb", "Expected token embedding (wte/tok_emb) usage in forward."),
        ("has_pos_emb", "Expected positional embedding (wpe/pos_emb) usage in forward."),
        ("has_blocks_loop", "Expected loop over transformer blocks in forward."),
        ("has_ln_f", "Expected final layer norm ln_f in forward."),
        ("has_lm_head", "Expected lm_head/logits in forward."),
    )
    for attr, message in expectations:
        assert getattr(landmarks, attr), message


def test_fast_forward_and_generate_from_scratch():
    """Tiny randomly initialised GPT: forward output shape/loss and generate length (no downloads)."""
    config = GPT.get_default_config()
    config.model_type = "gpt-nano"  # tiny
    config.vocab_size = 1000
    config.block_size = 64
    net = GPT(config)
    net.eval()

    prompt_len, new_tokens = 10, 5
    idx = torch.randint(0, config.vocab_size, (1, prompt_len), dtype=torch.long)
    with torch.no_grad():
        logits, loss = net(idx)
        assert logits.shape == (1, prompt_len, config.vocab_size)
        assert loss is None
        generated = net.generate(idx, max_new_tokens=new_tokens, do_sample=False)
    assert generated.shape[1] == prompt_len + new_tokens


@pytest.mark.slow
def test_slow_from_pretrained_gpt2_loads_and_runs():
    """Load the real GPT-2 weights and check the logits shape/loss of one forward pass.

    Only environment problems cause a skip:
    - OSError: download, HTTP and cache failures surface as OSError subclasses.
    - ImportError: e.g. `transformers` is not installed.

    Any other exception from from_pretrained now fails the test instead of being
    silently turned into a skip. An example is the key-count AssertionError of minGPT
    issue #120.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        model = GPT.from_pretrained("gpt2")
    except (OSError, ImportError) as e:
        pytest.skip(f"Skipping from_pretrained test due to load/download error: {e}")

    model.to(device)
    model.eval()

    idx = torch.randint(0, 50257, (1, 8), dtype=torch.long, device=device)
    with torch.no_grad():
        logits, loss = model(idx)

    assert logits.shape == (1, 8, 50257)
    assert loss is None


Overwriting test_all.py


In [22]:
%%writefile pytest.ini
[pytest]
markers =
    slow: marks tests as slow (deselect with '-m "not slow"')


Writing pytest.ini


In [24]:
# Install pytest into *this kernel's* interpreter and run the suite with that same
# interpreter, so the tests import the same torch/mingpt as the notebook.
# Note: `-q` runs every test, including those marked `slow` in pytest.ini;
# add `-m "not slow"` to deselect them.
import sys
print("Kernel python:", sys.executable)

!{sys.executable} -m pip install -q pytest
!{sys.executable} -c "import torch; print('torch:', torch.__version__)"
!{sys.executable} -m pytest -q


Kernel python: /home/bledyx/UA/master-ia/TPLN/code/lvl1/lvl2/tpln-practice2/.venv/bin/python
torch: 2.9.1+cu128
[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m                                                                    [100%][0m
[32m[32m[1m5 passed[0m[32m in 11.96s[0m[0m


In [29]:
%%writefile tokenization_protocol.py
"""
Section 3: Tokenization Protocol and "Same Number of Tokens" Guarantee.

This module provides:
- Tokenization reports (token ids, per-token decoded strings, token count)
- Pair comparison (same-length check, diff positions, one-token-diff check)
- Report-friendly Markdown export for token-by-token decomposition
- Heuristic suggestions to fix token length mismatches

Designed for minGPT's BPETokenizer (mingpt/bpe.py).
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Dict

import torch
from mingpt.bpe import BPETokenizer


# -----------------------------
# Data structures
# -----------------------------

@dataclass(frozen=True)
class TokenizationReport:
    """Tokenization of one text: ids, per-token strings, length and decoded round-trip."""

    text: str
    token_ids: List[int]
    token_strs: List[str]  # decoded per-token strings (may include leading spaces)
    seq_len: int
    decoded_roundtrip: str

    def short_preview(self, max_chars: int = 120) -> str:
        """Return `text` on one line (newlines shown as a literal backslash-n), at most `max_chars` long.

        Truncated previews end with "..."; if `max_chars` is too small to hold the
        ellipsis (< 3), the escaped text is simply cut to `max_chars` characters
        (an empty string for zero or negative values).
        """
        s = self.text.replace("\n", "\\n")
        if len(s) <= max_chars:
            return s
        if max_chars < 3:
            # Previously s[: max_chars - 3] used a negative index here and returned
            # nearly the whole text plus "...", far longer than max_chars.
            return s[: max(max_chars, 0)]
        return s[: max_chars - 3] + "..."


@dataclass(frozen=True)
class PairComparison:
    """Outcome of comparing the token ids of a clean prompt against a corrupted one."""

    clean: TokenizationReport
    corrupt: TokenizationReport
    same_length: bool
    diff_positions: List[int]  # differing indices (tail of the longer sequence included)
    diff_count: int

    @property
    def one_token_diff(self) -> bool:
        """True when both prompts have the same length and differ at exactly one position."""
        if not self.same_length:
            return self.same_length
        return self.diff_count == 1


# -----------------------------
# Core tokenization helpers
# -----------------------------

def tokenize_2d(bpe: BPETokenizer, text: str, device: Optional[str] = None) -> torch.LongTensor:
    """
    Encode `text` with `bpe`, returning the (1, T) id tensor BPETokenizer produces,
    moved to `device` when one is given.
    """
    encoded = bpe(text)  # (1, T)
    return encoded if device is None else encoded.to(device)


def tokenize_1d_ids(bpe: BPETokenizer, text: str) -> List[int]:
    """
    Encode `text` and return the ids of its single batch row as a list of plain ints.
    """
    first_row = bpe(text)[0]
    return list(map(int, first_row.tolist()))


def decode_token_id(bpe: BPETokenizer, token_id: int) -> str:
    """
    Decode one token id to text by passing a one-element int64 tensor to `bpe.decode`.
    """
    return bpe.decode(torch.tensor([token_id], dtype=torch.long))


def decode_tokens_1d(bpe: BPETokenizer, token_ids: Sequence[int]) -> str:
    """
    Decode a whole sequence of token ids back to text with `bpe.decode`.
    """
    id_tensor = torch.tensor([*token_ids], dtype=torch.long)
    return bpe.decode(id_tensor)


def per_token_strings(bpe: BPETokenizer, token_ids: Sequence[int]) -> List[str]:
    """
    Decode every id on its own, so each token's leading space (if any) stays visible.
    """
    pieces: List[str] = []
    for tid in token_ids:
        pieces.append(decode_token_id(bpe, int(tid)))
    return pieces


def build_report(bpe: BPETokenizer, text: str) -> TokenizationReport:
    """
    Tokenize `text` once and bundle its ids, per-token strings, length and round-trip text.
    """
    ids = tokenize_1d_ids(bpe, text)
    return TokenizationReport(
        text=text,
        token_ids=ids,
        token_strs=per_token_strings(bpe, ids),
        seq_len=len(ids),
        decoded_roundtrip=decode_tokens_1d(bpe, ids),
    )


# -----------------------------
# Comparison and validations
# -----------------------------

def diff_positions(a: Sequence[int], b: Sequence[int]) -> List[int]:
    """
    Indices at which `a` and `b` differ.

    Overlapping positions are compared as ints. When the lengths differ, every index
    past the end of the shorter sequence is reported as a difference too.
    """
    shared = min(len(a), len(b))
    mismatches = [pos for pos in range(shared) if int(a[pos]) != int(b[pos])]
    mismatches += range(shared, max(len(a), len(b)))
    return mismatches


def compare_clean_corrupt(clean: TokenizationReport, corrupt: TokenizationReport) -> PairComparison:
    """Compare two reports position-by-position and package the result."""
    positions = diff_positions(clean.token_ids, corrupt.token_ids)
    return PairComparison(
        clean=clean,
        corrupt=corrupt,
        same_length=clean.seq_len == corrupt.seq_len,
        diff_positions=positions,
        diff_count=len(positions),
    )


def assert_same_length(clean: TokenizationReport, corrupt: TokenizationReport) -> None:
    """Raise ValueError, showing both previews, when the reports differ in token count."""
    if clean.seq_len == corrupt.seq_len:
        return
    raise ValueError(
        f"Token length mismatch: clean={clean.seq_len}, corrupt={corrupt.seq_len}.\n"
        f"Clean preview: {clean.short_preview()}\n"
        f"Corrupt preview: {corrupt.short_preview()}"
    )


def assert_one_token_difference(comp: PairComparison) -> None:
    """Raise ValueError unless `comp` has equal lengths and exactly one differing position."""
    if not comp.same_length:
        raise ValueError(
            f"Cannot check one-token-diff: lengths differ (clean={comp.clean.seq_len}, corrupt={comp.corrupt.seq_len})."
        )
    if comp.diff_count == 1:
        return
    raise ValueError(
        f"Expected exactly 1 differing token position, found {comp.diff_count}: {comp.diff_positions}\n"
        f"Tip: inspect the per-token strings and adjust the text until only one BPE token changes."
    )


def validate_pair(
    bpe: BPETokenizer,
    clean_text: str,
    corrupt_text: str,
    require_same_length: bool = True,
    require_one_token_diff: bool = True,
) -> PairComparison:
    """
    Tokenize both texts, compare them, and enforce the requested constraints.

    Raises:
        ValueError: when `require_same_length` is set and the token counts differ, or
            when `require_one_token_diff` is set and the pair does not differ at
            exactly one position.
    """
    clean, corrupt = build_report(bpe, clean_text), build_report(bpe, corrupt_text)
    comparison = compare_clean_corrupt(clean, corrupt)

    if require_same_length:
        assert_same_length(clean, corrupt)
    if require_one_token_diff:
        assert_one_token_difference(comparison)
    return comparison


# -----------------------------
# Printing / report exports
# -----------------------------

def format_token_list_for_console(rep: TokenizationReport) -> str:
    """
    One "pos | id | repr(token)" line per token; repr() makes leading spaces visible.
    """
    return "\n".join(
        f"{pos:02d} | {tid:5d} | {tok!r}"
        for pos, (tid, tok) in enumerate(zip(rep.token_ids, rep.token_strs))
    )


def format_pair_diff_markdown(comp: PairComparison) -> str:
    """
    Markdown table: position-wise clean vs corrupt tokens, ready to paste into the report.

    Tokens are shown via repr() so leading spaces stay visible. Positions past the end
    of the shorter prompt leave their id and token cells blank. A `|` inside a token is
    escaped so it cannot split the table row. Differing positions are marked with ✅.
    """
    clean = comp.clean
    corrupt = comp.corrupt
    max_len = max(clean.seq_len, corrupt.seq_len)
    diff_set = set(comp.diff_positions)  # O(1) membership instead of a list scan per row

    def _cells(rep, i):
        # (id, token) cell texts for position i, or blanks when `rep` is too short.
        if i >= rep.seq_len:
            return "", ""
        return str(rep.token_ids[i]), repr(rep.token_strs[i]).replace("|", "\\|")

    header = "| pos | clean_id | clean_tok | corrupt_id | corrupt_tok | diff? |\n|---:|---:|---|---:|---|:---:|\n"
    rows = []
    for i in range(max_len):
        c_id, c_tok = _cells(clean, i)
        k_id, k_tok = _cells(corrupt, i)
        diff = "✅" if i in diff_set else ""
        rows.append(f"| {i} | {c_id} | {c_tok} | {k_id} | {k_tok} | {diff} |")
    return header + "\n".join(rows) + "\n"


def describe_pair(comp: PairComparison) -> str:
    """
    Multi-line, human-readable summary of a clean/corrupt comparison.
    """
    fields = [
        ("Clean tokens:   ", comp.clean.seq_len),
        ("Corrupt tokens: ", comp.corrupt.seq_len),
        ("Same length?    ", comp.same_length),
        ("Diff count:     ", comp.diff_count),
        ("Diff positions: ", comp.diff_positions),
        ("One-token diff? ", comp.one_token_diff),
    ]
    body = "".join(f"{label}{value}\n" for label, value in fields)
    return "=== Pair summary ===\n" + body


# -----------------------------
# Heuristic suggestions (for mismatch debugging)
# -----------------------------

def suggest_fixes(clean: TokenizationReport, corrupt: TokenizationReport) -> List[str]:
    """
    Heuristics to help the user fix length mismatches / multi-token mismatches.
    Not an automatic fixer; it gives actionable suggestions.

    Returns a list of suggestion strings. The final one is always the GPT-2 BPE
    leading-space reminder.
    """
    suggestions: List[str] = []

    # Length mismatch guidance
    if clean.seq_len != corrupt.seq_len:
        suggestions.append(
            "Token length mismatch detected. Common causes: whitespace differences, punctuation attachment, "
            "or swapping a word that tokenizes into a different number of BPE tokens."
        )
        suggestions.append(
            "Try keeping punctuation identical (e.g., 'student.' vs 'student .') and keep spaces consistent around the changed word."
        )
        suggestions.append(
            "Proper nouns are often unstable: try swapping to a more common single-token alternative and re-check."
        )

    # Same-length guidance: either nothing changed, or more than one token changed.
    diffs = diff_positions(clean.token_ids, corrupt.token_ids)
    if clean.seq_len == corrupt.seq_len:
        if not diffs:
            # Previously reported as "More than one token differs (0)".
            suggestions.append(
                "Clean and corrupt prompts tokenize to identical ids; change exactly one token in the corrupt prompt."
            )
        elif len(diffs) > 1:
            suggestions.append(
                f"More than one token differs ({len(diffs)}). You want exactly 1 differing BPE token position."
            )
            suggestions.append(
                "Inspect per-token strings around the diff positions; often a punctuation or whitespace token is also changing."
            )

    # Space-specific hint
    suggestions.append(
        "Remember GPT-2 BPE: tokens in the middle often include a leading space. "
        "If you care about the token 'Jones', the actual token is usually ' Jones'."
    )

    return suggestions


Writing tokenization_protocol.py


In [25]:
%%writefile tokenization_driver.py
"""
Section 3 driver: tokenize clean/corrupt prompts, enforce same-length and one-token-diff,
print per-token decomposition, and export a Markdown token table for the report.

Usage in Colab:
!python tokenization_driver.py

Or override defaults by editing the CLEAN_TEXT / CORRUPT_TEXT constants below.
"""

from __future__ import annotations

import argparse
from pathlib import Path

from mingpt.bpe import BPETokenizer

import tokenization_protocol as tp


# Edit these defaults for your own experiment.
CLEAN_TEXT = "Michelle Jones was a top-notch student. Michelle"
CORRUPT_TEXT = "Michelle Smith was a top-notch student. Michelle"


def parse_args() -> argparse.Namespace:
    """Parse CLI options: the prompt pair, the one-diff toggle, and the Markdown output path."""
    parser = argparse.ArgumentParser()
    # Prompt pair (defaults are the module-level constants above).
    parser.add_argument("--clean", type=str, default=CLEAN_TEXT, help="Clean prompt text")
    parser.add_argument("--corrupt", type=str, default=CORRUPT_TEXT, help="Corrupted prompt text")
    # Validation and output options.
    parser.add_argument("--no-require-one-diff", action="store_true", help="Do not require exactly 1 token difference")
    parser.add_argument("--out_md", type=str, default="token_table.md", help="Output markdown file for token table")
    parsed = parser.parse_args()
    return parsed


def main() -> None:
    """Tokenize the prompt pair, print per-token tables, validate, and write the Markdown table.

    A failed validation (ValueError from tp.validate_pair) is reported together with
    tp.suggest_fixes() hints and does not stop the script, so the Markdown table is
    always written. Any other exception propagates, because it indicates a bug rather
    than a bad prompt pair.
    """
    args = parse_args()
    bpe = BPETokenizer()

    clean_rep = tp.build_report(bpe, args.clean)
    corrupt_rep = tp.build_report(bpe, args.corrupt)
    comp = tp.compare_clean_corrupt(clean_rep, corrupt_rep)

    print(tp.describe_pair(comp))

    print("=== Clean prompt ===")
    print(clean_rep.text)
    print("\n=== Clean tokens (pos | id | repr(token)) ===")
    print(tp.format_token_list_for_console(clean_rep))

    print("\n=== Corrupt prompt ===")
    print(corrupt_rep.text)
    print("\n=== Corrupt tokens (pos | id | repr(token)) ===")
    print(tp.format_token_list_for_console(corrupt_rep))

    # Enforce constraints as requested by the assignment
    require_one = not args.no_require_one_diff
    try:
        tp.validate_pair(
            bpe=bpe,
            clean_text=args.clean,
            corrupt_text=args.corrupt,
            require_same_length=True,
            require_one_token_diff=require_one,
        )
    except ValueError as e:
        # The length / one-token-diff checks signal violations with ValueError only;
        # catching everything here used to mislabel unrelated bugs as failed validation.
        print("\n❌ Validation failed:")
        print(e)
        print("\nSuggestions:")
        for s in tp.suggest_fixes(clean_rep, corrupt_rep):
            print("-", s)
    else:
        print("\n✅ Validation passed.")

    # Export markdown table for report
    md = tp.format_pair_diff_markdown(comp)
    out_path = Path(args.out_md)
    out_path.write_text(md, encoding="utf-8")
    print(f"\nWrote Markdown token table to: {out_path.resolve()}")


if __name__ == "__main__":
    main()


Writing tokenization_driver.py


In [31]:
# Validate the default clean/corrupt pair and write token_table.md (subprocess; first `python` on PATH).
!python tokenization_driver.py


=== Pair summary ===
Clean tokens:   11
Corrupt tokens: 11
Same length?    True
Diff count:     1
Diff positions: [1]
One-token diff? True

=== Clean prompt ===
Michelle Jones was a top-notch student. Michelle

=== Clean tokens (pos | id | repr(token)) ===
00 | 48736 | 'Michelle'
01 |  5437 | ' Jones'
02 |   373 | ' was'
03 |   257 | ' a'
04 |  1353 | ' top'
05 |    12 | '-'
06 |  1662 | 'not'
07 |   354 | 'ch'
08 |  3710 | ' student'
09 |    13 | '.'
10 | 16738 | ' Michelle'

=== Corrupt prompt ===
Michelle Smith was a top-notch student. Michelle

=== Corrupt tokens (pos | id | repr(token)) ===
00 | 48736 | 'Michelle'
01 |  4176 | ' Smith'
02 |   373 | ' was'
03 |   257 | ' a'
04 |  1353 | ' top'
05 |    12 | '-'
06 |  1662 | 'not'
07 |   354 | 'ch'
08 |  3710 | ' student'
09 |    13 | '.'
10 | 16738 | ' Michelle'

✅ Validation passed.

Wrote Markdown token table to: /home/bledyx/UA/master-ia/TPLN/code/lvl1/lvl2/tpln-practice2/token_table.md


In [32]:
# Print (up to) the first 120 lines of the generated Markdown token table.
!sed -n '1,120p' token_table.md


| pos | clean_id | clean_tok | corrupt_id | corrupt_tok | diff? |
|---:|---:|---|---:|---|:---:|
| 0 | 48736 | 'Michelle' | 48736 | 'Michelle' |  |
| 1 | 5437 | ' Jones' | 4176 | ' Smith' | ✅ |
| 2 | 373 | ' was' | 373 | ' was' |  |
| 3 | 257 | ' a' | 257 | ' a' |  |
| 4 | 1353 | ' top' | 1353 | ' top' |  |
| 5 | 12 | '-' | 12 | '-' |  |
| 6 | 1662 | 'not' | 1662 | 'not' |  |
| 7 | 354 | 'ch' | 354 | 'ch' |  |
| 8 | 3710 | ' student' | 3710 | ' student' |  |
| 9 | 13 | '.' | 13 | '.' |  |
| 10 | 16738 | ' Michelle' | 16738 | ' Michelle' |  |


In [27]:
%%writefile test_all.py
import pathlib
import sys

import pytest
import torch

# Colab-friendly: the pytest subprocess does not inherit the notebook's sys.path edit,
# so append the editable minGPT checkout here whenever that path exists (e.g. on Colab).
COLAB_MINGPT_PATH = pathlib.Path("/content/src/mingpt")
if COLAB_MINGPT_PATH.exists():
    sys.path.append(f"{COLAB_MINGPT_PATH}")

import mingpt
import mingpt.model
from mingpt.model import GPT

import repo_orientation as ro
import tokenization_protocol as tp


# --------------------------
# Section 2 tests (repo orientation)
# --------------------------

def test_mingpt_importable_and_paths_exist():
    """get_paths() reports both entries, and each points to an existing file."""
    paths = ro.get_paths()
    assert "mingpt.__file__" in paths and "mingpt.model.__file__" in paths

    checks = (
        ("mingpt.__file__", "mingpt package file"),
        ("mingpt.model.__file__", "mingpt.model file"),
    )
    for key, label in checks:
        target = pathlib.Path(paths[key])
        assert target.exists(), f"{label} not found: {target}"


def test_attn_bias_fix_present_or_applied():
    """The installed model.py must contain the issue-#120 fix (ignore `.attn.bias` keys)."""
    message = (
        "Required fix not found in mingpt/model.py. "
        "Expected assert to ignore keys ending with .attn.bias."
    )
    model_text = ro.read_model_source()
    assert ro.attn_bias_fix_present(model_text), message


def test_forward_pipeline_landmarks_present():
    """GPT.forward must reference every stage: both embeddings, block loop, ln_f, lm_head."""
    landmarks = ro.find_forward_landmarks(ro.forward_source())
    expectations = (
        ("has_tok_emb", "Expected token embedding (wte/tok_emb) usage in forward."),
        ("has_pos_emb", "Expected positional embedding (wpe/pos_emb) usage in forward."),
        ("has_blocks_loop", "Expected loop over transformer blocks in forward."),
        ("has_ln_f", "Expected final layer norm ln_f in forward."),
        ("has_lm_head", "Expected lm_head/logits in forward."),
    )
    for attr, message in expectations:
        assert getattr(landmarks, attr), message


def test_fast_forward_and_generate_from_scratch():
    """Tiny randomly initialised GPT: forward output shape/loss and generate length (no downloads)."""
    config = GPT.get_default_config()
    config.model_type = "gpt-nano"  # tiny
    config.vocab_size = 1000
    config.block_size = 64
    net = GPT(config)
    net.eval()

    prompt_len, new_tokens = 10, 5
    idx = torch.randint(0, config.vocab_size, (1, prompt_len), dtype=torch.long)
    with torch.no_grad():
        logits, loss = net(idx)
        assert logits.shape == (1, prompt_len, config.vocab_size)
        assert loss is None
        generated = net.generate(idx, max_new_tokens=new_tokens, do_sample=False)
    assert generated.shape[1] == prompt_len + new_tokens


# --------------------------
# Section 3 tests (tokenization protocol)
# --------------------------

def test_diff_positions_length_mismatch_includes_tail():
    """Trailing ids present only in the longer sequence are reported as differing positions."""
    shorter, longer = [1, 2, 3], [1, 2, 3, 4, 5]
    assert tp.diff_positions(shorter, longer) == [3, 4]


def test_compare_reports_detects_one_token_diff_synthetic():
    """A same-length pair differing only at index 1 is flagged as a one-token difference."""

    def make_report(label, ids, pieces):
        # seq_len and the round-trip string are derived from the fixture so they stay consistent.
        return tp.TokenizationReport(
            text=label,
            token_ids=ids,
            token_strs=pieces,
            seq_len=len(ids),
            decoded_roundtrip="".join(pieces),
        )

    comp = tp.compare_clean_corrupt(
        make_report("clean", [10, 20, 30], ["a", "b", "c"]),
        make_report("corrupt", [10, 99, 30], ["a", "X", "c"]),
    )
    assert comp.same_length is True
    assert comp.diff_positions == [1]
    assert comp.diff_count == 1
    assert comp.one_token_diff is True


def test_assert_one_token_difference_raises_when_multi_diff():
    """assert_one_token_difference must raise ValueError when more than one position differs."""
    # label -> (token ids, token strings); insertion order gives clean first, then corrupt.
    fixtures = {
        "clean": ([1, 2, 3], ["a", "b", "c"]),
        "corrupt": ([9, 2, 8], ["X", "b", "Y"]),
    }
    clean, corrupt = (
        tp.TokenizationReport(
            text=label,
            token_ids=ids,
            token_strs=pieces,
            seq_len=len(ids),
            decoded_roundtrip="".join(pieces),
        )
        for label, (ids, pieces) in fixtures.items()
    )

    comp = tp.compare_clean_corrupt(clean, corrupt)
    assert comp.diff_count == 2  # indices 0 and 2 differ
    with pytest.raises(ValueError):
        tp.assert_one_token_difference(comp)


@pytest.mark.slow
def test_bpe_tokenization_roundtrip_and_lengths():
    """
    Tokenize a sample sentence with minGPT's BPETokenizer and sanity-check the report.

    Marked slow because BPETokenizer may download merges/vocab on first use in a
    fresh runtime; any tokenizer init/download failure skips the test.
    """
    from mingpt.bpe import BPETokenizer

    try:
        tokenizer = BPETokenizer()
    except Exception as e:
        pytest.skip(f"Skipping BPETokenizer test due to tokenizer init/download error: {e}")

    report = tp.build_report(tokenizer, "Michelle Jones was a top-notch student. Michelle")

    # The token count must be positive and agree with both per-token lists.
    assert report.seq_len > 0
    for per_token in (report.token_ids, report.token_strs):
        assert len(per_token) == report.seq_len

    # Only check containment: exact round-trip equality may differ by whitespace normalization.
    assert "Michelle" in report.decoded_roundtrip


@pytest.mark.slow
def test_bpe_pair_validation_example_michelle_jones_smith():
    """
    Validate the assignment's canonical-style clean/corrupt prompt pair.

    The pair must tokenize to the same length and differ in at least one
    position. A single-token difference is the target for the report, but the
    test tolerates more in case tokenizer quirks vary.
    """
    from mingpt.bpe import BPETokenizer

    try:
        bpe = BPETokenizer()
    except Exception as e:
        pytest.skip(f"Skipping BPETokenizer test due to tokenizer init/download error: {e}")

    prompts = (
        "Michelle Jones was a top-notch student. Michelle",
        "Michelle Smith was a top-notch student. Michelle",
    )
    clean_rep, corrupt_rep = [tp.build_report(bpe, prompt) for prompt in prompts]
    comp = tp.compare_clean_corrupt(clean_rep, corrupt_rep)

    assert comp.same_length is True, (
        f"Expected same token length; got {clean_rep.seq_len} vs {corrupt_rep.seq_len}"
    )
    # For the assignment report, aim for diff_count == 1; here any difference is accepted.
    assert comp.diff_count >= 1


@pytest.mark.slow
def test_slow_from_pretrained_gpt2_loads_and_runs():
    """
    Slow test: downloads and loads GPT-2 weights, then runs one forward pass.

    If network/cache issues happen in Colab, we skip rather than fail hard. Only
    environment-type failures cause a skip:
    - ``OSError``, the base of the connection and cache/file errors raised while
      fetching or reading the pretrained files;
    - ``ImportError``, when the HF dependency is missing.
    Any other exception is a real bug and must fail the test instead of being
    silently skipped. An example is an ``AssertionError`` from the state-dict key
    check in model.py (see minGPT issue #120).
    """
    gpt2_vocab_size = 50257  # size of the GPT-2 BPE vocabulary
    device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        model = GPT.from_pretrained("gpt2")
    except (OSError, ImportError) as e:
        pytest.skip(f"Skipping from_pretrained test due to load/download error: {e}")

    model.to(device)
    model.eval()

    idx = torch.randint(0, gpt2_vocab_size, (1, 8), dtype=torch.long, device=device)
    with torch.no_grad():
        logits, loss = model(idx)

    assert logits.shape == (1, 8, gpt2_vocab_size)
    assert loss is None  # no targets were passed


Overwriting test_all.py


In [30]:
# Run pytest in quiet mode with the kernel's own interpreter. There is no `-m` filter, so @pytest.mark.slow tests also run.
!{sys.executable} -m pytest -q

[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m                                                               [100%][0m
[32m[32m[1m10 passed[0m[32m in 16.71s[0m[0m
