In [None]:
# Clone the repo to a fixed absolute path, and only if it is not already there:
# later cells leave the kernel inside the repo, so re-running a plain relative
# clone from there would create a second, nested copy.
![ -d /content/ece405-assignment1-basics ] || git clone https://github.com/bushuyeu/ece405-assignment1-basics.git /content/ece405-assignment1-basics

In [None]:
# Absolute path (the same one the data-download cell returns to) so this works
# regardless of the kernel's current working directory, e.g. on a re-run.
%cd /content/ece405-assignment1-basics/

In [None]:
# Fetch the latest commits for the checked-out branch (a no-op right after a
# fresh clone; useful when re-running the notebook later).
!git pull

In [None]:
# Make sure the `main` branch is the one checked out.
!git checkout main

No Conda environment is needed on Colab; the package is installed directly into the runtime with pip.

In [None]:
# Editable install of the assignment package plus its [test] extras.
# %pip (rather than !pip) installs into the environment of the running kernel;
# the quotes stop the shell from treating the square brackets as a glob.
%pip install -e '.[test]'

In [None]:
# Run the repo's BPE training tests before relying on train_bpe in the cells below.
!pytest tests/test_train_bpe.py

In [None]:
%cd /content

# Remove old data folder and start fresh, so a previous partial download can
# never be mistaken for a complete one.
!rm -rf data
!mkdir -p data
%cd data

# Download TinyStories
!wget -q --show-progress https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt
!wget -q --show-progress https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-valid.txt

# Download OpenWebText sample
!wget -q --show-progress https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_train.txt.gz
!gunzip owt_train.txt.gz
!wget -q --show-progress https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_valid.txt.gz
!gunzip owt_valid.txt.gz

# Return to the repo root *before* verifying. Later cells write to the relative
# path "outputs/", so a failed check below must not leave the kernel inside data/.
%cd /content/ece405-assignment1-basics

# Verify all downloads. `!` shell commands do not raise when they fail, so this
# is where a failed wget/gunzip gets caught. Raise (not just print) so the problem
# surfaces here instead of as a FileNotFoundError in a later training cell.
import os
DATA_DIR = "/content/data"
print("\n=== Download verification ===")
missing = []
for fname in ["TinyStoriesV2-GPT4-train.txt", "TinyStoriesV2-GPT4-valid.txt", "owt_train.txt", "owt_valid.txt"]:
    path = os.path.join(DATA_DIR, fname)
    if os.path.exists(path):
        print(f"{fname}: {os.path.getsize(path) / 1e9:.2f} GB")
    else:
        print(f"{fname}: MISSING!")
        missing.append(fname)
if missing:
    raise FileNotFoundError(f"Missing data files in {DATA_DIR}: {missing}")

In [None]:
# Common imports and helper functions
import os
import time
import pickle
import tracemalloc
from ece496b_basics import train_bpe

# Directory for the pickled vocab/merges written by the training cells. The path
# is relative, so it is created under the current working directory (the repo
# root, once the data-download cell has run its final %cd).
os.makedirs("outputs", exist_ok=True)

def safe_decode(b: bytes) -> str:
    """Decode ``b`` as UTF-8, falling back to ``repr(b)`` if it is not valid UTF-8.

    Byte-level BPE tokens can stop part-way through a multi-byte UTF-8
    character, so a strict decode may legitimately fail. In that case the
    bytes' repr (e.g. ``b'\\xe2\\x80'``) is returned instead.
    """
    try:
        return b.decode('utf-8')
    # Catch only decode failures: a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit and hide real bugs (e.g. a non-bytes argument).
    except UnicodeDecodeError:
        return repr(b)

def analyze_vocab(vocab, name, top_k=5):
    """Print summary statistics for a BPE vocabulary.

    Args:
        vocab: dict of integer token ID -> token bytes, as produced by
            ``train_bpe`` in this notebook.
        name: Label shown in the printed header.
        top_k: Number of longest tokens to list (default 5).

    An empty ``vocab`` prints the header and a zero count rather than raising
    (``max()`` and the average length are undefined for it).
    """
    print(f"\n{name} Vocabulary Analysis:")
    print(f"  Total tokens: {len(vocab)}")
    if not vocab:
        print("  (vocabulary is empty; nothing else to report)")
        return

    # NOTE(review): "merged" assumes IDs 0-255 are the single-byte base tokens and
    # every ID >= 256 is a merge (special tokens may also land here) - confirm
    # against train_bpe's ID layout.
    n_merged = sum(1 for tid in vocab if tid >= 256)
    avg_len = sum(len(tok) for tok in vocab.values()) / len(vocab)
    # sorted() is stable even with reverse=True, so among equal-length tokens the
    # vocab's iteration order is kept and ranked[0] is exactly the token that
    # max(vocab.values(), key=len) would return.
    ranked = sorted(vocab.items(), key=lambda item: len(item[1]), reverse=True)
    longest_token = ranked[0][1]

    print(f"  Merged tokens (non-byte): {n_merged}")
    print(f"  Average token length: {avg_len:.2f} bytes")
    print(f"  Longest token: '{safe_decode(longest_token)}' ({len(longest_token)} bytes)")

    print(f"  Top {top_k} longest tokens:")
    for tid, tbytes in ranked[:top_k]:
        print(f"    ID {tid}: '{safe_decode(tbytes)}' ({len(tbytes)} bytes)")

print("Imports and helpers loaded.")

## TinyStories BPE Training (vocab_size=10000)

In [None]:
# Train BPE on TinyStories.
# Timing uses time.perf_counter() (monotonic, high resolution) instead of
# time.time(), which can jump if the system clock is adjusted mid-run.
# NOTE: tracemalloc only tracks memory blocks allocated by Python in this process
# (memory used by any worker processes train_bpe might start is not counted), and
# tracing can noticeably slow allocation-heavy code, so treat the reported time as
# an upper bound.
tracemalloc.start()
start_time = time.perf_counter()
try:
    ts_vocab, ts_merges = train_bpe(
        input_path="/content/data/TinyStoriesV2-GPT4-train.txt",
        vocab_size=10000,
        special_tokens=["<|endoftext|>"]
    )
    elapsed_time = time.perf_counter() - start_time
    _, peak_mem = tracemalloc.get_traced_memory()
finally:
    # Stop tracing even if training fails or is interrupted; otherwise every later
    # cell in this kernel keeps paying tracemalloc's overhead.
    tracemalloc.stop()

print(f"Training time: {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")
print(f"Peak memory: {peak_mem / 1e9:.2f} GB")
print(f"Number of merges: {len(ts_merges)}")

analyze_vocab(ts_vocab, "TinyStories")

# Save immediately so the result survives a kernel restart.
with open("outputs/ts_vocab_10k.pkl", "wb") as f:
    pickle.dump(ts_vocab, f)
with open("outputs/ts_merges_10k.pkl", "wb") as f:
    pickle.dump(ts_merges, f)
print("\nSaved to outputs/ts_vocab_10k.pkl and outputs/ts_merges_10k.pkl")

## OpenWebText BPE Training (vocab_size=32000)

**Problem (train_bpe_expts_owt)**: Train a byte-level BPE tokenizer on OpenWebText with vocab_size=32,000.

Resource requirements: ≤12 hours (no GPUs), ≤100GB RAM

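**Note**: This will take several hours, so consider running it overnight. On Colab, everything under `/content` (including the pickles written to `outputs/`) is lost when the runtime disconnects or is recycled. Keep the session active, and copy the saved files somewhere persistent (e.g. Google Drive) as soon as training finishes.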

In [None]:
# Train BPE on OpenWebText with vocab_size=32000.
# Timing uses time.perf_counter() (monotonic, high resolution) instead of
# time.time(), which can jump if the system clock is adjusted during a long run.
# NOTE: tracemalloc only tracks memory blocks allocated by Python in this process
# (memory used by any worker processes train_bpe might start is not counted), and
# tracing can noticeably slow allocation-heavy code, so treat the reported time as
# an upper bound.
tracemalloc.start()
start_time = time.perf_counter()
try:
    owt_vocab, owt_merges = train_bpe(
        input_path="/content/data/owt_train.txt",
        vocab_size=32000,
        special_tokens=["<|endoftext|>"]
    )
    elapsed_time = time.perf_counter() - start_time
    _, peak_mem = tracemalloc.get_traced_memory()
finally:
    # Stop tracing even if training fails or is interrupted; otherwise every later
    # cell in this kernel keeps paying tracemalloc's overhead.
    tracemalloc.stop()

print("="*50)
print("OpenWebText BPE Training Complete!")
print("="*50)
print(f"Training time: {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes, {elapsed_time/3600:.2f} hours)")
print(f"Peak memory: {peak_mem / 1e9:.2f} GB")
print(f"Number of merges: {len(owt_merges)}")

analyze_vocab(owt_vocab, "OpenWebText")

# Save immediately: this run takes hours and must not be lost to a kernel restart.
with open("outputs/owt_vocab_32k.pkl", "wb") as f:
    pickle.dump(owt_vocab, f)
with open("outputs/owt_merges_32k.pkl", "wb") as f:
    pickle.dump(owt_merges, f)
print("\nSaved to outputs/owt_vocab_32k.pkl and outputs/owt_merges_32k.pkl")

## Compare TinyStories vs OpenWebText Tokenizers

**Problem (train_bpe_expts_owt) Part (b)**: Compare and contrast the tokenizers trained on TinyStories vs OpenWebText.

In [None]:
# Reuse the vocabularies if they are still in the kernel; otherwise read the
# pickles saved by the training cells (e.g. after a restart following the long
# OpenWebText run).
import pickle

for _var_name, _pkl_path, _label in (
    ("ts_vocab", "outputs/ts_vocab_10k.pkl", "TinyStories vocab"),
    ("owt_vocab", "outputs/owt_vocab_32k.pkl", "OpenWebText vocab"),
):
    if _var_name in globals():
        print(f"{_label} already in memory")
        continue
    with open(_pkl_path, "rb") as _fh:
        globals()[_var_name] = pickle.load(_fh)
    print(f"Loaded {_label} from disk")

In [None]:
# Compare the learned token sets, leaving out IDs 0-255.
# NOTE(review): assumes IDs 0-255 are the raw single bytes and every ID >= 256 is
# a merge (or special token) in both vocabularies - confirm against train_bpe.
ts_tokens = {tok for tid, tok in ts_vocab.items() if tid >= 256}
owt_tokens = {tok for tid, tok in owt_vocab.items() if tid >= 256}

shared, ts_only, owt_only = (
    ts_tokens & owt_tokens,
    ts_tokens - owt_tokens,
    owt_tokens - ts_tokens,
)

n_shared = len(shared)
print(f"TinyStories merged tokens: {len(ts_tokens)}")
print(f"OpenWebText merged tokens: {len(owt_tokens)}")
print(f"\nShared tokens: {n_shared}")
print(f"  ({n_shared/len(ts_tokens)*100:.1f}% of TinyStories tokens are in OWT)")
print(f"  ({n_shared/len(owt_tokens)*100:.1f}% of OWT tokens are in TinyStories)")
print(f"\nTokens unique to TinyStories: {len(ts_only)}")
print(f"Tokens unique to OpenWebText: {len(owt_only)}")

In [None]:
# Define the decode helper here too, so this cell also works on a fresh kernel
# where the "Common imports and helper functions" cell was not run.
def safe_decode(b):
    """Decode bytes as UTF-8, falling back to repr(b) when they are not valid UTF-8."""
    try:
        return b.decode('utf-8')
    # Only decode failures are expected here (BPE tokens can split a multi-byte
    # character); a bare `except:` would also mask interrupts and real bugs.
    except UnicodeDecodeError:
        return repr(b)

# For each tokenizer, list the longest tokens the other one never learned
# (longest first, which tends to give the most telling examples).
for _heading, _only_tokens in (
    ("Top 15 longest tokens UNIQUE to TinyStories (children's stories):", ts_only),
    ("\nTop 15 longest tokens UNIQUE to OpenWebText (web text):", owt_only),
):
    print(_heading)
    _longest_first = sorted(_only_tokens, key=len, reverse=True)
    for _tok in _longest_first[:15]:
        print(f"  '{safe_decode(_tok)}'")

In [None]:
# Side-by-side summary statistics for the two vocabularies.
def _longest_and_mean_len(vocab):
    """Return (first longest token, mean token length in bytes) for an id -> bytes vocab."""
    tokens = list(vocab.values())
    return max(tokens, key=len), sum(map(len, tokens)) / len(tokens)

ts_longest, ts_avg_len = _longest_and_mean_len(ts_vocab)
owt_longest, owt_avg_len = _longest_and_mean_len(owt_vocab)

heavy_rule, light_rule = "=" * 50, "-" * 50
print(heavy_rule)
print("Summary Statistics")
print(heavy_rule)
print(f"{'Metric':<30} {'TinyStories':>12} {'OpenWebText':>12}")
print(light_rule)
print(f"{'Vocab size':<30} {len(ts_vocab):>12} {len(owt_vocab):>12}")
print(f"{'Avg token length (bytes)':<30} {ts_avg_len:>12.2f} {owt_avg_len:>12.2f}")
print(f"{'Longest token (bytes)':<30} {len(ts_longest):>12} {len(owt_longest):>12}")
print(light_rule)
print(f"TinyStories longest: '{safe_decode(ts_longest)}'")
print(f"OpenWebText longest: '{safe_decode(owt_longest)}'")

## Answers to Assignment Questions

### Part (a): What is the longest token in the OWT vocabulary? Does it make sense?

*Fill in your answer here after running the cells above.*

### Part (b): Compare and contrast TinyStories vs OpenWebText tokenizers.

**Important**: Note that this comparison uses different vocab sizes (TinyStories: 10k, OWT: 32k) as specified by the assignment. When writing your answer, acknowledge this limitation and focus on qualitative differences (token types, domain-specific vocabulary) rather than raw token counts.

*Fill in your answer here. Consider:*
- *The vocab size difference (10k vs 32k) means OWT naturally has more tokens*
- *What % of TinyStories tokens also appear in OWT? What does this suggest?*
- *Types of tokens unique to each (children's vocabulary vs web/technical terms)*
- *Average token length differences and what they indicate about text complexity*

## Optional: Profiling (TinyStories)

In [None]:
# Profile one TinyStories training run to see what takes the most time.
# NOTE: cProfile records only the calling (kernel) process; if train_bpe hands
# work to worker processes, their time appears only as waiting in the parent.
import cProfile
import pstats
from io import StringIO

profiler = cProfile.Profile()
profiler.enable()
try:
    vocab_prof, merges_prof = train_bpe(
        input_path="/content/data/TinyStoriesV2-GPT4-train.txt",
        vocab_size=10000,
        special_tokens=["<|endoftext|>"]
    )
finally:
    # Disable even on failure/interrupt so the profiler does not keep hooking
    # every function call made by later cells in this kernel.
    profiler.disable()

# Show the 20 entries with the largest cumulative time.
stats_buf = StringIO()
ps = pstats.Stats(profiler, stream=stats_buf).sort_stats('cumulative')
ps.print_stats(20)
print(stats_buf.getvalue())