# Understanding RustBPETokenizer.train_from_iterator()

This notebook breaks down the `train_from_iterator()` function step by step to understand how BPE tokenization training works.

## What is BPE (Byte Pair Encoding)?

BPE is a compression algorithm adapted for tokenization:
1. Start with individual bytes (256 base tokens)
2. Find the most frequent pair of tokens
3. Merge them into a new token
4. Repeat until desired vocabulary size is reached

**Example:**
- Text: `"aaabdaaabac"`
- Most frequent pair: `"aa"` (appears 4 times, counting overlapping occurrences)
- After merge: `"ZabdZabac"` (where Z = "aa")
- Continue merging...

This allows the tokenizer to learn common subwords like "ing", "tion", "un", etc.
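
To make the example concrete, here is a minimal pure-Python sketch of one BPE step: count adjacent pairs, pick the most frequent one, and merge it. It works on characters rather than bytes to stay readable, and the helper names `count_pairs` and `merge_pair` are made up for this illustration; they are not part of rustbpe or tiktoken.

```python
from collections import Counter

def count_pairs(tokens):
    """Count every adjacent pair of tokens (overlapping occurrences included)."""
    return Counter(zip(tokens, tokens[1:]))

def merge_pair(tokens, pair, new_token):
    """Replace each non-overlapping occurrence of `pair`, scanning left to right."""
    merged, i = [], 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            merged.append(new_token)
            i += 2  # skip both halves of the merged pair
        else:
            merged.append(tokens[i])
            i += 1
    return merged

tokens = list("aaabdaaabac")
pair_counts = count_pairs(tokens)
best_pair, best_count = pair_counts.most_common(1)[0]
merged_text = "".join(merge_pair(tokens, best_pair, "Z"))

print(best_pair, best_count)  # ('a', 'a') 4
print(merged_text)            # ZabdZabac
```

Note that the pair is counted 4 times but only 2 merges happen, because merged occurrences cannot overlap.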

## Step 0: Imports and Setup

In [None]:
import rustbpe
import tiktoken
import pickle

# Special tokens used in the tokenizer
SPECIAL_TOKENS = [
    "<|bos|>",  # Beginning of Sequence
    "<|user_start|>", "<|user_end|>",
    "<|assistant_start|>", "<|assistant_end|>",
    "<|python_start|>", "<|python_end|>",
    "<|output_start|>", "<|output_end|>",
]

# GPT-4 style regex pattern for pre-tokenization
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

print(f"Number of special tokens: {len(SPECIAL_TOKENS)}")
print(f"Split pattern: {SPLIT_PATTERN[:50]}...")

## Step 1: Prepare Sample Training Data

We'll create a simple iterator of text that simulates what would be passed to the training function.

In [None]:
# Sample training data - in real use, this would be millions of documents
sample_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Hello, world! This is a test of tokenization.",
    "Machine learning models need lots of training data.",
    "The weather is beautiful today. The sun is shining.",
    "Python programming is fun and powerful.",
    "The cat sat on the mat. The dog ran in the park.",
] * 100  # Repeat to give more training data

# Create an iterator (this is what train_from_iterator expects)
text_iterator = iter(sample_texts)

# Let's also set a small vocabulary size for this demo
VOCAB_SIZE = 512  # Much smaller than real models (GPT uses 50k-100k)

print(f"Training data: {len(sample_texts)} documents")
print(f"Target vocabulary size: {VOCAB_SIZE}")
print(f"First document: {sample_texts[0]}")
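
One detail worth seeing in isolation: an iterator can only be walked once. The sketch below uses a hypothetical generator, `stream_documents`, to mimic how a large corpus would be streamed without holding everything in memory, and shows that a second pass over the same iterator yields nothing.

```python
def stream_documents(documents):
    """Yield documents one at a time, the way a large corpus would be streamed."""
    for doc in documents:
        yield doc

docs = ["first document", "second document", "third document"]
doc_iterator = stream_documents(docs)

first_pass = list(doc_iterator)   # walks the iterator to the end
second_pass = list(doc_iterator)  # the iterator is already exhausted

print(len(first_pass), len(second_pass))  # 3 0
```

This is why a fresh iterator should be created right before each training run.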

## Step 2: The Complete train_from_iterator Function

Here's the full function we're going to dissect:

In [None]:
def train_from_iterator_complete(text_iterator, vocab_size):
    """
    Complete train_from_iterator function from RustBPETokenizer
    """
    # 1) Train with rustbpe, a Rust-based trainer that is faster than pure Python
    tokenizer = rustbpe.Tokenizer()

    # Special tokens are inserted later; they are not part of BPE training
    vocab_size_no_special = vocab_size - len(SPECIAL_TOKENS)

    assert vocab_size_no_special >= 256, f"vocab_size_no_special must be at least 256, got {vocab_size_no_special}"
    
    # Train the tokenizer
    tokenizer.train_from_iterator(text_iterator, vocab_size_no_special, pattern=SPLIT_PATTERN)
    
    # 2) construct the associated tiktoken encoding for inference
    pattern = tokenizer.get_pattern()
    mergeable_ranks_list = tokenizer.get_mergeable_ranks()
    mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
    tokens_offset = len(mergeable_ranks)
    special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
    
    enc = tiktoken.Encoding(
        name="rustbpe",
        pat_str=pattern,
        mergeable_ranks=mergeable_ranks, # dict[bytes, int] (token bytes -> merge priority rank)
        special_tokens=special_tokens, # dict[str, int] (special token name -> token id)
    )
    
    return enc

print("Function defined. Let's break it down step by step next!")

---

## STEP-BY-STEP BREAKDOWN

Now let's execute each part of the function separately to understand what's happening.

### STEP 2.1: Initialize the RustBPE Tokenizer

In [None]:
# STEP 2.1: Create a rustbpe.Tokenizer instance
tokenizer = rustbpe.Tokenizer()

print("✓ Created rustbpe.Tokenizer instance")
print(f"  Type: {type(tokenizer)}")
print(f"  This is an empty tokenizer that will learn BPE merges from data")

### STEP 2.2: Calculate Vocabulary Size (Reserve Space for Special Tokens)

In [None]:
# STEP 2.2: Reserve space for special tokens
vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS)

print(f"✓ Calculated vocabulary sizes:")
print(f"  Total vocab size: {VOCAB_SIZE}")
print(f"  Special tokens: {len(SPECIAL_TOKENS)}")
print(f"  Vocab for BPE training: {vocab_size_no_special}")
print()
print(f"  Why? Special tokens are added AFTER training, so we train with fewer slots")
print(f"  The BPE algorithm will learn {vocab_size_no_special} tokens from the data")

# Sanity check
assert vocab_size_no_special >= 256, "Need at least 256 for base bytes!"
print("\n  ✓ Passed sanity check (>= 256 for all byte values)")

### STEP 2.3: Train the Tokenizer (THE MAIN EVENT!)

This is where the magic happens! The tokenizer will:
1. Split text according to the SPLIT_PATTERN regex
2. Start with 256 base byte tokens
3. Iteratively find the most frequent pair of tokens and merge them
4. Continue until we have `vocab_size_no_special` tokens

In [None]:
# STEP 2.3: Train the tokenizer!
# Note: Iterators are single-use, so we create a fresh one right before training
text_iterator = iter(sample_texts)

print("🚀 Starting BPE training...")
print(f"   Training on {len(sample_texts)} documents")
print(f"   Target: {vocab_size_no_special} tokens (256 base + {vocab_size_no_special - 256} merges)")
print()

# This is the heavy lifting - learning which byte pairs to merge
tokenizer.train_from_iterator(
    text_iterator, 
    vocab_size_no_special, 
    pattern=SPLIT_PATTERN
)

print("\n✓ Training complete!")
print("  The tokenizer has learned which byte pairs appear frequently")
print("  and should be merged into single tokens")
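
The training call above is a black box, so here is a toy pure-Python trainer that follows the same four steps. It is only a sketch: `TOY_PATTERN` is a simplified ASCII-only stand-in for `SPLIT_PATTERN`, the names `toy_train` and `merge_ids` are invented for this example, and details such as tie-breaking are assumptions that may differ from what rustbpe does internally.

```python
import re
from collections import Counter

# Simplified ASCII-only stand-in for SPLIT_PATTERN (an assumption made for this sketch)
TOY_PATTERN = re.compile(r"[^\r\nA-Za-z\d]?[A-Za-z]+|\d{1,2}| ?[^\sA-Za-z\d]+|\s+")

def merge_ids(ids, pair, new_id):
    """Replace each left-to-right, non-overlapping occurrence of `pair` with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return tuple(out)

def toy_train(texts, vocab_size):
    """Learn BPE merges; returns (token_bytes, rank) pairs, base bytes first."""
    assert vocab_size >= 256, "need room for the 256 base bytes"
    # Pre-tokenize, then count each distinct chunk once (as a tuple of byte values)
    chunk_counts = Counter(
        tuple(chunk.encode("utf-8"))
        for text in texts
        for chunk in TOY_PATTERN.findall(text)
    )
    vocab = [bytes([b]) for b in range(256)]  # ids 0..255 are the raw bytes
    while len(vocab) < vocab_size:
        # Count adjacent pairs, weighted by how often each chunk occurs
        pair_counts = Counter()
        for ids, n in chunk_counts.items():
            for pair in zip(ids, ids[1:]):
                pair_counts[pair] += n
        if not pair_counts:
            break  # every chunk is a single token: nothing left to merge
        # Most frequent pair wins; ties go to the pair with the smallest ids
        best = max(pair_counts, key=lambda p: (pair_counts[p], -p[0], -p[1]))
        new_id = len(vocab)
        vocab.append(vocab[best[0]] + vocab[best[1]])
        merged = Counter()
        for ids, n in chunk_counts.items():
            merged[merge_ids(ids, best, new_id)] += n
        chunk_counts = merged
    return [(token_bytes, rank) for rank, token_bytes in enumerate(vocab)]

toy_ranks_list = toy_train(["aaabdaaabac"], vocab_size=259)
learned = [token_bytes for token_bytes, _ in toy_ranks_list[256:]]
print(learned)  # [b'aa', b'ab', b'aaab']

early_stop = toy_train(["ab"], vocab_size=300)
print(len(early_stop))  # 257: only one merge was possible
```

Because pairs are counted inside chunks only, a merge can never span a chunk boundary. The second call also shows that a tiny corpus can run out of pairs before the requested vocabulary size is reached.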

### STEP 2.4: Extract the Pattern

Retrieve the regex pattern used for pre-tokenization (should be the same one we passed in).

In [None]:
# STEP 2.4: Get the pattern back from the trained tokenizer
pattern = tokenizer.get_pattern()

print("✓ Retrieved pattern from trained tokenizer")
print(f"  Pattern: {pattern[:80]}...")
print()
print("  This is the same regex we passed in - it's needed for tiktoken later")
print("  so that inference uses the same pre-tokenization as training")
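
To get a feel for what pre-tokenization produces, the sketch below splits a string with an ASCII-only approximation of the pattern. Python's built-in `re` module does not support `\p{L}` or `\p{N}`, so the approximation swaps in `A-Za-z` and `0-9` and uses ordinary greedy quantifiers instead of possessive ones; `ASCII_SPLIT_PATTERN` and `pre_tokenize` are names made up for this example, and the result may differ from the real pattern on non-ASCII text.

```python
import re

# ASCII-only approximation of SPLIT_PATTERN (assumption): \p{L} -> A-Za-z, \p{N} -> 0-9,
# and possessive quantifiers (?+, ++) replaced by ordinary greedy ones.
ASCII_SPLIT_PATTERN = re.compile(
    r"'(?i:[sdmt]|ll|ve|re)"        # contractions such as 's, 't, 'll
    r"|[^\r\nA-Za-z0-9]?[A-Za-z]+"  # a word, optionally with one leading non-alphanumeric char
    r"|[0-9]{1,2}"                  # digits, in groups of at most two
    r"| ?[^\sA-Za-z0-9]+[\r\n]*"    # runs of punctuation
    r"|\s*[\r\n]|\s+(?!\S)|\s+"     # whitespace
)

def pre_tokenize(text):
    """Split text into the chunks that BPE merges are confined to."""
    return ASCII_SPLIT_PATTERN.findall(text)

chunks = pre_tokenize("Hello, world! 12345")
contraction_chunks = pre_tokenize("don't stop")

print(chunks)              # ['Hello', ',', ' world', '!', ' ', '12', '34', '5']
print(contraction_chunks)  # ['don', "'t", ' stop']
```

Words keep their leading space, punctuation is split off, and long numbers are broken into short groups, which keeps the learned merges from gluing unrelated pieces together.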

### STEP 2.5: Extract Mergeable Ranks (The Learned BPE Merges!)

This is the core output of training - the learned merge rules.

In [None]:
# STEP 2.5: Get the mergeable ranks (the learned BPE vocabulary)
mergeable_ranks_list = tokenizer.get_mergeable_ranks()

print("✓ Retrieved mergeable ranks from trained tokenizer")
print(f"  Type: {type(mergeable_ranks_list)}")
print(f"  Length: {len(mergeable_ranks_list)} tokens")
print()
print("  What is this? A list of (token_bytes, rank) pairs")
print("  - token_bytes: The bytes that make up this token")
print("  - rank: The token ID (0-255 = base bytes, 256+ = merges in the order they were learned)")
print()
print("  First 10 tokens (these are the base byte values 0-9):")
for i in range(10):
    token_bytes, rank = mergeable_ranks_list[i]
    print(f"    {i}: bytes={list(token_bytes)}, rank={rank}, char='{chr(token_bytes[0]) if len(token_bytes)==1 and 32<=token_bytes[0]<127 else '?'}'")
print()
print("  Last 10 tokens (these are the most recently learned merges):")
for i in range(-10, 0):
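    token_bytes, rank = mergeable_ranks_list[i]
    token_bytes = bytes(token_bytes)  # normalize in case tokens arrive as lists of ints
    try:
        text = token_bytes.decode('utf-8')
    except UnicodeDecodeError:
        text = repr(token_bytes)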
    print(f"    Token {len(mergeable_ranks_list)+i}: rank={rank}, bytes={text}")

### STEP 2.6: Convert to Dictionary Format for tiktoken

tiktoken expects `mergeable_ranks` as a dict mapping `bytes -> int` (rank).

In [None]:
# STEP 2.6: Convert list to dict for tiktoken
# tiktoken wants: dict[bytes, int] where int is the merge rank
mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}

print("✓ Converted to dictionary format")
print(f"  Type: {type(mergeable_ranks)}")
print(f"  Length: {len(mergeable_ranks)}")
print()
print("  Example entries:")
for token_bytes, rank in list(mergeable_ranks.items())[:5]:
    try:
        text = token_bytes.decode('utf-8')
    except UnicodeDecodeError:
        text = repr(token_bytes)
    print(f"    {text!r} -> rank {rank}")
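
Why wrap each key in `bytes(...)`? The small sketch below uses made-up data and assumes the tokens might arrive as lists of integers (the notebook does not show which type rustbpe returns). Lists cannot be dictionary keys, while `bytes` objects can, and calling `bytes()` on something that is already `bytes` is harmless.

```python
# Hypothetical training output: token bytes given as lists of ints (an assumption)
toy_ranks_as_lists = [([97], 97), ([98], 98), ([97, 98], 256)]

# bytes() turns each list into an immutable, hashable key
toy_ranks = {bytes(token): rank for token, rank in toy_ranks_as_lists}

# The same expression also works if the tokens are already bytes
toy_ranks_from_bytes = {bytes(token): rank for token, rank in [(b"a", 97), (b"b", 98), (b"ab", 256)]}

# A list itself cannot be used as a dictionary key
try:
    {[97, 98]: 256}
    list_key_failed = False
except TypeError:
    list_key_failed = True

# The reverse mapping (token ID -> bytes) is what decoding needs
id_to_bytes = {rank: token for token, rank in toy_ranks.items()}

print(toy_ranks)         # {b'a': 97, b'b': 98, b'ab': 256}
print(id_to_bytes[256])  # b'ab'
```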

### STEP 2.7: Add Special Tokens

Special tokens get IDs starting AFTER all the learned BPE tokens.

In [None]:
# STEP 2.7: Create special tokens mapping
# Special tokens get IDs starting from tokens_offset (after all BPE tokens)
tokens_offset = len(mergeable_ranks)
special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}

print("✓ Created special tokens mapping")
print(f"  Offset (first special token ID): {tokens_offset}")
print(f"  Number of special tokens: {len(special_tokens)}")
print()
print("  Special token mappings:")
for name, token_id in special_tokens.items():
    print(f"    '{name}' -> ID {token_id}")
print()
print(f"  Total vocabulary size: {tokens_offset + len(special_tokens)}")
print(f"  (Should equal our target of {VOCAB_SIZE}; a tiny corpus may run out of pairs to merge and end up smaller)")
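
Using `len(mergeable_ranks)` as the offset relies on the BPE ranks filling the range `0..n-1` with no gaps. The sketch below checks that on a toy vocabulary; the names `toy_special`, `toy_bpe_ranks` and `toy_special_ids` are invented for this example.

```python
toy_special = ["<|bos|>", "<|user_start|>", "<|user_end|>"]

# Toy vocabulary: 256 base bytes plus two learned merges
toy_bpe_ranks = {bytes([b]): b for b in range(256)}
toy_bpe_ranks[b"th"] = 256
toy_bpe_ranks[b"the"] = 257

# Special tokens take the next free IDs after the BPE tokens
toy_offset = len(toy_bpe_ranks)
toy_special_ids = {name: toy_offset + i for i, name in enumerate(toy_special)}

# Together, the IDs should cover 0..N-1 exactly once (no gaps, no collisions)
all_ids = sorted(list(toy_bpe_ranks.values()) + list(toy_special_ids.values()))
ids_are_contiguous = all_ids == list(range(len(all_ids)))

print(toy_special_ids)     # {'<|bos|>': 258, '<|user_start|>': 259, '<|user_end|>': 260}
print(ids_are_contiguous)  # True
```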

### STEP 2.8: Create the tiktoken Encoding Object (Final Output!)

This is what gets returned and used for fast inference.

In [None]:
# STEP 2.8: Create the tiktoken.Encoding object
enc = tiktoken.Encoding(
    name="rustbpe",
    pat_str=pattern,              # The regex pattern for pre-tokenization
    mergeable_ranks=mergeable_ranks,  # The learned BPE merges
    special_tokens=special_tokens,    # Our special tokens
)

print("✓ Created tiktoken.Encoding object!")
print(f"  Type: {type(enc)}")
print(f"  Name: {enc.name}")
print(f"  Vocabulary size: {enc.n_vocab}")
print()
print("  This encoding object can now be used for FAST tokenization:")
print("  - encode(): text -> token IDs")
print("  - decode(): token IDs -> text")
print()
print("🎉 Training complete! The tokenizer is ready to use.")
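
How does an encoder use `mergeable_ranks` at inference time? Conceptually, each pre-tokenized chunk starts as single bytes, and the adjacent pair whose concatenation has the lowest rank is merged until no known pair remains. The function below, `bpe_encode_chunk`, is a slow pure-Python sketch of that idea written for this notebook; it is not tiktoken's actual implementation, which is optimized and written in Rust.

```python
def bpe_encode_chunk(mergeable_ranks, chunk):
    """Encode one chunk (bytes) by repeatedly merging the lowest-rank adjacent pair."""
    parts = [bytes([b]) for b in chunk]
    while len(parts) > 1:
        best_index, best_rank = None, None
        for i in range(len(parts) - 1):
            rank = mergeable_ranks.get(parts[i] + parts[i + 1])
            if rank is not None and (best_rank is None or rank < best_rank):
                best_index, best_rank = i, rank
        if best_index is None:
            break  # no adjacent pair is a known token
        parts[best_index:best_index + 2] = [parts[best_index] + parts[best_index + 1]]
    return [mergeable_ranks[part] for part in parts]

# Toy vocabulary: 256 base bytes plus three merges
demo_ranks = {bytes([b]): b for b in range(256)}
demo_ranks.update({b"aa": 256, b"ab": 257, b"aaab": 258})
demo_id_to_bytes = {rank: token for token, rank in demo_ranks.items()}

demo_ids = bpe_encode_chunk(demo_ranks, b"aaabdaaabac")
demo_decoded = b"".join(demo_id_to_bytes[i] for i in demo_ids)

print(demo_ids)      # [258, 100, 258, 97, 99]
print(demo_decoded)  # b'aaabdaaabac'
```

Lower rank means the merge was learned earlier, so encoding replays the training merges in the same order.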

---

## Step 3: Test the Trained Tokenizer!

Let's see the tokenizer in action.

In [None]:
# Test encoding
test_text = "Hello, world! The quick brown fox jumps."

# Encode
token_ids = enc.encode_ordinary(test_text)
print(f"Original text: {test_text}")
print(f"Token IDs: {token_ids}")
print(f"Number of tokens: {len(token_ids)}")
print()

# Decode
decoded_text = enc.decode(token_ids)
print(f"Decoded text: {decoded_text}")
print(f"Match original? {decoded_text == test_text}")
print()

# Show individual tokens
print("Individual tokens:")
for i, token_id in enumerate(token_ids):
    token_text = enc.decode([token_id])
    print(f"  {i}: ID={token_id:4d} -> {token_text!r}")

### Test Special Tokens

In [None]:
# Test special token encoding
bos_id = enc.encode_single_token("<|bos|>")
user_start_id = enc.encode_single_token("<|user_start|>")

print("Special token IDs:")
print(f"  <|bos|> -> {bos_id}")
print(f"  <|user_start|> -> {user_start_id}")
print()

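# Test with a simple conversation-like sequence
conversation_text = "<|bos|><|user_start|>Hello!<|user_end|><|assistant_start|>Hi there!<|assistant_end|>"

# encode_ordinary treats special-token strings as plain text and splits them into regular tokens
ordinary_ids = enc.encode_ordinary(conversation_text)
# encode() maps them to their special IDs only when that is explicitly allowed
special_ids = enc.encode(conversation_text, allowed_special="all")
print(f"encode_ordinary: {len(ordinary_ids)} tokens")
print(f"encode with allowed_special='all': {len(special_ids)} tokens")
print()

# In real usage, RustBPETokenizer builds these sequences programmatically
print("In real usage, special tokens are added programmatically,")
print("not as text to be encoded.")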
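
Here is a rough sketch of what "added programmatically" can look like. Everything in it is hypothetical: the IDs in `SPECIAL_IDS` are example values, `encode_text` is a byte-level stand-in for `enc.encode_ordinary`, and `render_turns` is an invented helper, not the method used in `nanochat/tokenizer.py`.

```python
# Hypothetical IDs; in the notebook they would come from the special_tokens mapping
SPECIAL_IDS = {
    "<|bos|>": 503,
    "<|user_start|>": 504, "<|user_end|>": 505,
    "<|assistant_start|>": 506, "<|assistant_end|>": 507,
}

def encode_text(text):
    """Stand-in for enc.encode_ordinary: one token per UTF-8 byte."""
    return list(text.encode("utf-8"))

def render_turns(turns):
    """Build a token sequence from (role, content) pairs, inserting special IDs directly."""
    ids = [SPECIAL_IDS["<|bos|>"]]
    for role, content in turns:
        ids.append(SPECIAL_IDS[f"<|{role}_start|>"])
        ids.extend(encode_text(content))  # content is always treated as ordinary text
        ids.append(SPECIAL_IDS[f"<|{role}_end|>"])
    return ids

conversation_ids = render_turns([("user", "Hello!"), ("assistant", "Hi there!")])
print(len(conversation_ids))  # 20

# Text that merely looks like a special token stays ordinary text
injected_ids = render_turns([("user", "<|bos|>")])
print(injected_ids.count(SPECIAL_IDS["<|bos|>"]))  # 1
```

The design benefit is that user-supplied text can never smuggle in a control token, because only the code inserts special IDs.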

---

## Summary: What train_from_iterator Does

### Input:
- `text_iterator`: Iterator over text documents
- `vocab_size`: Desired vocabulary size (e.g., 512, 50000, 100000)

### Process:
1. **Initialize** rustbpe tokenizer
2. **Calculate** vocab size without special tokens (need room for them)
3. **Train** BPE algorithm:
   - Pre-tokenize text using SPLIT_PATTERN regex
   - Start with 256 base byte tokens
   - Iteratively merge most frequent byte pairs
   - Continue until reaching target vocabulary size
4. **Extract** learned merge rules (mergeable_ranks)
5. **Convert** to tiktoken-compatible format
6. **Add** special tokens at the end
7. **Create** tiktoken.Encoding object for fast inference

### Output:
- A `tiktoken.Encoding` object that can:
  - `encode()`: Convert text to token IDs
  - `decode()`: Convert token IDs to text
  - Handle special tokens
  - Run efficiently in production

### Key Insight:
**Train with Rust (fast), Infer with tiktoken (also fast)** - Best of both worlds!

---

## Bonus: What Did BPE Learn?

Let's examine some of the merged tokens to see what patterns BPE discovered.

In [None]:
# Look at tokens beyond the base 256 bytes (these are the learned merges)
print("Learned multi-byte tokens (showing a sample):")
print()

learned_tokens = []
for token_bytes, rank in mergeable_ranks.items():
    if len(token_bytes) > 1:  # Multi-byte tokens (learned merges)
        try:
            text = token_bytes.decode('utf-8')
            learned_tokens.append((rank, text, token_bytes))
        except UnicodeDecodeError:
            pass  # Skip non-UTF8 tokens

# Sort by rank (earlier ranks = more frequent merges)
learned_tokens.sort()

print(f"Total learned multi-byte tokens: {len(learned_tokens)}")
print()
print("First 30 learned merges (most frequent patterns):")
for i, (rank, text, token_bytes) in enumerate(learned_tokens[:30]):
    print(f"  Rank {rank:3d}: {text!r:20s} (len={len(token_bytes)})")

print()
print("These are common patterns BPE discovered in our training data!")
print("For example, you might see tokens like:")
print("  - ' the' (space + the)")
print("  - 'ing' (common suffix)")
print("  - 'er', 'ed' (common endings)")
print("  - Common words that appear frequently")

### Why is BPE Better Than Character-Level?

Let's compare BPE tokenization to simple character-level tokenization.

In [None]:
comparison_text = "The quick brown fox jumps over the lazy dog. Machine learning is amazing!"

# BPE tokenization
bpe_tokens = enc.encode_ordinary(comparison_text)
bpe_count = len(bpe_tokens)

# Character-level (naive approach)
char_count = len(comparison_text)

print(f"Text: {comparison_text}")
print()
print(f"Character-level: {char_count} tokens")
print(f"BPE tokenization: {bpe_count} tokens")
print(f"Compression ratio: {char_count / bpe_count:.2f}x")
print()
print("Benefits of BPE:")
print("  1. Fewer tokens = shorter sequences for the model to process")
print("  2. Common words/subwords are single tokens (more efficient)")
print("  3. Rare words split into known subwords (better generalization)")
print("  4. Variable-length encoding (optimal for language structure)")
print()
print("BPE tokens:")
for i, tid in enumerate(bpe_tokens[:20]):  # Show first 20
    print(f"  {enc.decode([tid])!r}", end=" ")
    if i > 0 and (i+1) % 10 == 0:
        print()
print("\n...")

---

## Visual Flow Diagram

```
train_from_iterator(text_iterator, vocab_size)
│
├─ Step 1: Initialize rustbpe.Tokenizer()
│          └─> Empty tokenizer ready to learn
│
├─ Step 2: Calculate vocab_size_no_special
│          └─> vocab_size - 9 special tokens
│
├─ Step 3: Train BPE Algorithm
│          ├─> Input: text documents
│          ├─> Pre-tokenize with SPLIT_PATTERN regex
│          ├─> Start with 256 base bytes
│          ├─> Merge frequent pairs iteratively
│          └─> Output: learned merge rules
│
├─ Step 4: Extract Training Results
│          ├─> pattern = tokenizer.get_pattern()
│          └─> mergeable_ranks = tokenizer.get_mergeable_ranks()
│                 │
│                 └─> List[(bytes, rank)] of all tokens
│
├─ Step 5: Convert to tiktoken Format
│          └─> Dict[bytes -> int]
│
├─ Step 6: Add Special Tokens
│          └─> Map special token names to IDs
│             (IDs start after all BPE tokens)
│
└─ Step 7: Create tiktoken.Encoding
           └─> Fast inference-ready tokenizer
               ├─> encode(text) -> [token_ids]
               └─> decode([token_ids]) -> text
```

---

## Next Steps

To use this in the `RustBPETokenizer` class:
1. Wrap the tiktoken.Encoding in the class
2. Add helper methods (get_bos_token_id, render_conversation, etc.)
3. Implement save/load functionality

See `nanochat/tokenizer.py` for the full implementation!
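
As a starting point for the save/load step, the sketch below pickles the three pieces that define the tokenizer (pattern, mergeable ranks, special tokens) and reads them back. This is one possible approach under stated assumptions, not necessarily what the full implementation does; `save_tokenizer_state` and `load_tokenizer_state` are hypothetical names, and the toy state stands in for real training output.

```python
import os
import pickle
import tempfile

def save_tokenizer_state(path, pattern, mergeable_ranks, special_tokens):
    """Write the pieces needed to rebuild the encoding into one pickle file."""
    state = {
        "pattern": pattern,
        "mergeable_ranks": mergeable_ranks,
        "special_tokens": special_tokens,
    }
    with open(path, "wb") as f:
        pickle.dump(state, f)

def load_tokenizer_state(path):
    """Read the state back; only unpickle files you trust."""
    with open(path, "rb") as f:
        return pickle.load(f)

# Toy state standing in for real training output
toy_pattern = r"\s+|\S+"
toy_mergeable_ranks = {bytes([b]): b for b in range(256)}
toy_mergeable_ranks[b"th"] = 256
toy_special_tokens = {"<|bos|>": 257}

with tempfile.TemporaryDirectory() as tmp_dir:
    state_path = os.path.join(tmp_dir, "tokenizer.pkl")
    save_tokenizer_state(state_path, toy_pattern, toy_mergeable_ranks, toy_special_tokens)
    restored = load_tokenizer_state(state_path)

print(restored["special_tokens"])        # {'<|bos|>': 257}
print(len(restored["mergeable_ranks"]))  # 257
```

The restored values can be passed to `tiktoken.Encoding(...)` in the same way as in Step 2.8.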