Problem Statement
Title: Implement Byte Pair Encoding (BPE) from Scratch

Description: You are tasked with implementing Byte Pair Encoding (BPE), a subword tokenization algorithm used by Large Language Models (LLMs) such as the GPT family to build a vocabulary of subword units (BERT uses the closely related WordPiece algorithm). BPE iteratively merges the most frequent pair of adjacent characters or subwords in a text corpus, producing a vocabulary that sits between character-level and word-level representations. Your implementation should process a small text corpus, generate a vocabulary, and encode and decode text using the learned merges. The solution should be in pure Python (PyTorch and NumPy are unnecessary, since BPE is a text-processing algorithm) and include detailed comments explaining each step.

Mathematical Definition:

Vocabulary Creation:
Start with a character-level vocabulary (e.g., individual characters in the corpus).
Compute the frequency of all adjacent pairs of tokens (characters or subwords).
Merge the most frequent pair into a new token, updating the corpus.
Repeat for a specified number of merges (e.g., 100), or stop early when no adjacent pairs remain.
Encoding:
For a given word, split into characters.
Apply learned merge rules in order to form subword tokens.
Output a sequence of token IDs from the vocabulary.
Decoding:
Convert token IDs back to subword tokens.
Concatenate subwords to reconstruct the original text.
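The vocabulary-creation loop above can be sketched on a toy word-frequency table. This is a minimal illustration, not the reference solution: the helper names `count_pairs` and `merge_pair`, the toy table, and the choice to represent each word as a tuple of tokens are assumptions made for the example.

```python
from collections import Counter

def count_pairs(word_counts):
    """Count adjacent token pairs, weighted by how often each word occurs."""
    pairs = Counter()
    for tokens, count in word_counts.items():
        for left, right in zip(tokens, tokens[1:]):
            pairs[(left, right)] += count
    return pairs

def merge_pair(pair, word_counts):
    """Replace every adjacent occurrence of `pair` with the concatenated token."""
    merged = {}
    for tokens, count in word_counts.items():
        out, i = [], 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
                out.append(tokens[i] + tokens[i + 1])
                i += 2
            else:
                out.append(tokens[i])
                i += 1
        key = tuple(out)
        merged[key] = merged.get(key, 0) + count
    return merged

# Hypothetical toy table: each word is a tuple of single-character tokens.
word_counts = {tuple("hello"): 2, tuple("there"): 2, tuple("hi"): 1}

pairs = count_pairs(word_counts)
best = max(pairs, key=pairs.get)  # on ties, the first pair seen wins
word_counts = merge_pair(best, word_counts)

print(best)         # ('h', 'e')
print(word_counts)
```

Keeping each word as a tuple, rather than re-joining it into a string, is what lets the next iteration see `'he'` as a single token.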
Requirements:

Implement a BPE class with methods for:
train: Learn merge rules from a text corpus.
encode: Convert text to a sequence of token IDs.
decode: Convert token IDs back to text.
Use a small synthetic corpus (e.g., a few sentences) to train the BPE model.
Generate a vocabulary by requesting up to 100 merges (a corpus this small runs out of pairs to merge well before 100).
Handle edge cases (e.g., unknown characters, empty inputs).
Provide detailed Purpose and Theory comments for each line of code.
Test encoding and decoding on sample texts to verify correctness.
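One possible way to cover the edge cases listed above (empty input, unknown characters) is to reserve an explicit unknown-token entry in the vocabulary. The sketch below is an assumption-laden illustration: `encode_word`, the `<unk>` entry at ID 0, and the hand-written vocabulary and merge rules are all hypothetical and not taken from the reference solution.

```python
def encode_word(word, merge_rules, vocab, unk_token="<unk>"):
    """Encode one word: split into characters, apply merges in learned order."""
    if not word:
        return []  # empty input encodes to an empty ID list
    tokens = list(word)
    for left, right in merge_rules:
        i = 0
        while i < len(tokens) - 1:
            if tokens[i] == left and tokens[i + 1] == right:
                tokens[i:i + 2] = [left + right]  # merge in place
            else:
                i += 1
    # Tokens missing from the vocabulary map to the reserved unknown ID.
    return [vocab.get(tok, vocab[unk_token]) for tok in tokens]

# Hypothetical hand-written vocabulary and merge rules.
vocab = {"<unk>": 0, "h": 1, "e": 2, "l": 3, "o": 4, "he": 5, "hel": 6}
merge_rules = [("h", "e"), ("he", "l")]

print(encode_word("hello", merge_rules, vocab))  # [6, 3, 4]
print(encode_word("help", merge_rules, vocab))   # [6, 0]
print(encode_word("", merge_rules, vocab))       # []
```

An unknown-token ID keeps encoding total, but any character mapped to it cannot be recovered on decoding, so it trades losslessness for robustness.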
Constraints:

Use only Python standard libraries (e.g., collections, re).
No external libraries like tokenizers or sentencepiece.
Process text at the character level initially.
Ensure encoding/decoding is reversible (lossless).
Vocabulary size = initial characters + number of merges performed (at most 100).
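Reversibility depends on what happens to whitespace. If words are split on whitespace and tokens are later concatenated, the spaces between words are lost. One common remedy, sketched here as an assumption rather than as part of the required solution, is to append an end-of-word marker to every word before tokenizing. The marker string `</w>` and the helper names are illustrative.

```python
END = "</w>"  # assumed end-of-word marker; it must not occur in the text itself

def to_symbols(text):
    """Split text into words, then into characters, closing each word with END."""
    return [list(word) + [END] for word in text.split()]

def from_symbols(tokens):
    """Concatenate tokens and turn each END marker back into a space."""
    return "".join(tokens).replace(END, " ").rstrip(" ")

words = to_symbols("hello world")
flat = [tok for word in words for tok in word]

print(flat)
print(from_symbols(flat))  # hello world
```

This restores single spaces between words. Runs of whitespace and leading or trailing whitespace are still normalized away, so it is lossless only for text that is already in that form.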
Synthetic Dataset:

Corpus: A small text corpus (e.g., ["hello world", "hello there", "world peace", "hi there"]).
Test Inputs: Words like "hello", "world", "hi", and an unseen word "hellothere".
Vocabulary: Initial character set (e.g., {h, e, l, o, w, r, d, t, p, a, c, i}; the space is excluded because words are split on whitespace) plus one token per learned merge.
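Whether the space counts as a vocabulary symbol depends on how the corpus is pre-tokenized. As a quick check, the snippet below counts the distinct characters in the corpus both ways. The variable names are illustrative.

```python
corpus = ["hello world", "hello there", "world peace", "hi there"]

# Distinct characters if the space is kept as a symbol.
chars_with_space = sorted(set("".join(corpus)))

# Distinct characters if words are split on whitespace first.
chars_without_space = sorted(set("".join(corpus).replace(" ", "")))

print(len(chars_with_space), chars_with_space)
print(len(chars_without_space), chars_without_space)
```

The first count is 13 and the second is 12, so an implementation that splits on whitespace starts from a 12-character vocabulary.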
Expected Output:

Vocabulary: A dictionary mapping tokens (characters or subwords) to IDs: 12 initial characters plus one entry per learned merge (at most 112 tokens; far fewer on this corpus, which runs out of pairs early).
Merge Rules: A list of tuples (e.g., [('h', 'e'), ('he', 'l')]) defining merge operations.
Encoded Output: Token IDs for test words (e.g., encode("hello") → [7, 8, 9, 9, 10]).
Decoded Output: Reconstructed text (e.g., decode([7, 8, 9, 9, 10]) → "hello").
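Decoding, as in the example above, is a reverse lookup followed by concatenation. The sketch below uses a hand-written vocabulary with hypothetical IDs (not the ones training would assign) and checks that the reverse mapping is well defined, which requires that no two tokens share an ID.

```python
# Hypothetical vocabulary; real IDs depend on the training run.
vocab = {"h": 0, "e": 1, "l": 2, "o": 3, "he": 4, "ll": 5}

# Invert the mapping; this is only safe if every ID is unique.
id_to_token = {token_id: token for token, token_id in vocab.items()}
assert len(id_to_token) == len(vocab), "two tokens share an ID"

def decode(token_ids):
    """Map each ID back to its token and concatenate the results."""
    return "".join(id_to_token[token_id] for token_id in token_ids)

print(decode([4, 5, 3]))        # hello
print(decode([0, 1, 2, 2, 3]))  # hello
```

Both ID sequences decode to the same string: decoding is unambiguous, while the encoder's merge order decides which of the possible segmentations is produced.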

In [1]:
from collections import Counter, defaultdict
# Purpose: Import Counter for word-frequency counting and defaultdict for pair-frequency counting.
# Theory: Counter tallies how often each word occurs; defaultdict(int) lets pair counts start at 0 without explicit initialization.

class BPE:
    # Purpose: Define BPE class for training, encoding, and decoding text.
    # Theory: Encapsulates vocabulary, merge rules, and tokenization logic for LLMs.
    
    def __init__(self):
        # Purpose: Initialize BPE with empty vocabulary and merge rules.
        # Theory: Sets up data structures for training and tokenization.
        
        self.vocab = {}
        # Purpose: Store token-to-ID mapping.
        # Theory: Maps characters and merged subwords to unique integer IDs.
        
        self.merge_rules = []
        # Purpose: Store list of merge rules as (token1, token2) tuples.
        # Theory: Records the order of merges for consistent encoding.
        
        self.reverse_vocab = {}
        # Purpose: Store ID-to-token mapping for decoding.
        # Theory: Enables reverse lookup to reconstruct text from token IDs.
    
    def get_stats(self, word_counts):
        # Purpose: Compute frequency of adjacent token pairs in the corpus.
        # Theory: Identifies the most frequent pair for merging, a core step in BPE training.
        
        pairs = defaultdict(int)
        # Purpose: Initialize dictionary to count pair frequencies.
        # Theory: defaultdict(int) assigns 0 to new pairs, simplifying counting.
        
        for word, count in word_counts.items():
            # Purpose: Iterate over words and their frequencies.
            # Theory: Processes each word in the corpus, weighted by its frequency.
            
            tokens = list(word)
            # Purpose: Split word into list of tokens (initially characters).
            # Theory: Represents word as a sequence of tokens for pair analysis.
            
            for i in range(len(tokens) - 1):
                # Purpose: Iterate over adjacent token pairs.
                # Theory: Counts (t_i, t_{i+1}) pairs to find frequent combinations.
                
                pairs[(tokens[i], tokens[i + 1])] += count
                # Purpose: Increment frequency of the pair (t_i, t_{i+1}).
                # Theory: Aggregates pair counts across all occurrences in the corpus.
        
        return pairs
        # Purpose: Return dictionary of pair frequencies.
        # Theory: Provides data for selecting the most frequent pair to merge.
    
    def merge_pair(self, pair, word_counts):
        # Purpose: Merge a given pair in all words, updating word_counts.
        # Theory: Replaces adjacent occurrences of pair with a new token, advancing BPE training.
        
        new_token = pair[0] + pair[1]
        # Purpose: Create new token by concatenating the pair.
        # Theory: Combines two tokens (e.g., 'h', 'e' → 'he') to form a subword unit.
        
        new_word_counts = defaultdict(int)
        # Purpose: Initialize new dictionary for updated word representations.
        # Theory: Stores words after merging the pair, preserving frequencies.
        
        for word, count in word_counts.items():
            # Purpose: Iterate over current word representations.
            # Theory: Processes each word to apply the merge rule.
            
            tokens = list(word)
            # Purpose: Convert word to list of tokens.
            # Theory: Allows manipulation of token sequences for merging.
            
            i = 0
            new_tokens = []
            # Purpose: Initialize list for new token sequence after merging.
            # Theory: Builds the updated word representation.
            
            while i < len(tokens):
                # Purpose: Iterate through tokens to find and merge the pair.
                # Theory: Scans for adjacent tokens matching the pair to merge.
                
                if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
                    # Purpose: Check if current position matches the pair to merge.
                    # Theory: Identifies occurrences of the pair (e.g., 'h', 'e').
                    
                    new_tokens.append(new_token)
                    # Purpose: Add merged token to the new sequence.
                    # Theory: Replaces pair with new token (e.g., 'he').
                    
                    i += 2
                    # Purpose: Skip the merged pair.
                    # Theory: Advances past both tokens to avoid re-processing.
                else:
                    # Purpose: Keep non-matching token as is.
                    # Theory: Preserves tokens not involved in the merge.
                    
                    new_tokens.append(tokens[i])
                    i += 1
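            new_word = tuple(new_tokens)
            # Purpose: Store the merged token sequence as a tuple.
            # Theory: Keeping token boundaries (rather than joining into a string) lets later iterations see merged subwords such as 'he' as single tokens; joining would reset every word to characters, so the same pair would be merged repeatedly.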
            
            new_word_counts[new_word] += count
            # Purpose: Update frequency of the new word representation.
            # Theory: Maintains corpus frequency after merging.
        
        return new_word_counts
        # Purpose: Return updated word counts after merging.
        # Theory: Provides new corpus representation for the next merge iteration.
    
    def train(self, corpus, num_merges=100):
        # Purpose: Train BPE model by learning merge rules from the corpus.
        # Theory: Iteratively merges the most frequent token pairs to build vocabulary.
        
        # Split corpus into words and count frequencies
        word_counts = Counter(' '.join(corpus).split())
        # Purpose: Count frequency of each word in the corpus.
        # Theory: Treats corpus as a single string, splits on whitespace to get words.
        
        # Initialize vocabulary with characters
        vocab = set()
        for word in word_counts:
            # Purpose: Collect unique characters from all words.
            # Theory: Initial vocabulary includes all characters in the corpus.
            
            for char in word:
                vocab.add(char)
        
        self.vocab = {char: i for i, char in enumerate(sorted(vocab))}
        # Purpose: Assign IDs to initial characters.
        # Theory: Creates a mapping from characters to unique IDs, starting the vocabulary.
        
        # Represent each word as a tuple of single-character tokens
        word_counts = {tuple(word): count for word, count in word_counts.items()}
        # Purpose: Represent words as token sequences for merging.
        # Theory: A tuple keeps token boundaries explicit and is hashable, so it can serve as a dictionary key while preserving frequencies.
        
        # Perform merges
        for merge_idx in range(num_merges):
            # Purpose: Iterate for the specified number of merges.
            # Theory: Each iteration merges the most frequent pair, growing the vocabulary.
            
            pairs = self.get_stats(word_counts)
            # Purpose: Compute pair frequencies in the current corpus.
            # Theory: Identifies candidates for merging based on frequency.
            
            if not pairs:
                # Purpose: Break if no pairs are available.
                # Theory: Handles case where no further merges are possible (e.g., short words).
                
                break
            
            best_pair = max(pairs, key=pairs.get)
            # Purpose: Select the most frequent pair to merge.
            # Theory: Chooses the pair with the highest frequency; on ties, max() returns the first such pair in dictionary (insertion) order.
            
            word_counts = self.merge_pair(best_pair, word_counts)
            # Purpose: Update corpus by merging the selected pair.
            # Theory: Replaces all occurrences of the pair with a new token.
            
            new_token = best_pair[0] + best_pair[1]
            # Purpose: Create new token from the merged pair.
            # Theory: Adds the merged token to the vocabulary.
            
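            if new_token not in self.vocab:
                self.vocab[new_token] = len(self.vocab)
            # Purpose: Assign a new ID to the merged token, unless it already has one.
            # Theory: Expands the vocabulary with the new subword unit; the guard prevents an existing token from being reassigned, which would let two tokens share an ID.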
            
            self.merge_rules.append(best_pair)
            # Purpose: Record the merge rule.
            # Theory: Stores the pair for use in encoding new text.
        
        self.reverse_vocab = {v: k for k, v in self.vocab.items()}
        # Purpose: Create reverse mapping from IDs to tokens.
        # Theory: Enables decoding by mapping token IDs back to strings.
    
    def encode(self, text):
        # Purpose: Encode text into a sequence of token IDs.
        # Theory: Applies learned merge rules to tokenize text into subword units.
        
        if not text:
            # Purpose: Handle empty input.
            # Theory: Returns empty list for empty text to avoid errors.
            
            return []
        
        # Split text into words
        words = text.split()
        # Purpose: Split input text into words.
        # Theory: Processes each word independently, as in BPE training.
        
        token_ids = []
        # Purpose: Initialize list to store token IDs.
        # Theory: Collects IDs for the final encoded sequence.
        
        for word in words:
            # Purpose: Process each word in the input text.
            # Theory: Encodes each word separately; the whitespace between words is discarded, so multi-word input decodes without spaces.
            
            tokens = list(word)
            # Purpose: Split word into characters.
            # Theory: Initializes tokenization at the character level.
            
            # Apply merge rules
            for pair in self.merge_rules:
                # Purpose: Iterate through merge rules in order.
                # Theory: Applies merges in the order they were learned to ensure consistency.
                
                i = 0
                while i < len(tokens) - 1:
                    # Purpose: Scan tokens for the current merge rule.
                    # Theory: Checks for adjacent tokens matching the pair to merge.
                    
                    if (tokens[i], tokens[i + 1]) == pair:
                        # Purpose: Check if current pair matches the merge rule.
                        # Theory: Identifies mergeable tokens for replacement.
                        
                        tokens[i] = pair[0] + pair[1]
                        # Purpose: Merge the pair into a new token.
                        # Theory: Replaces two tokens with a single subword unit.
                        
                        tokens.pop(i + 1)
                        # Purpose: Remove the second token of the pair.
                        # Theory: Updates the token list after merging.
                    else:
                        i += 1
                        # Purpose: Move to the next position if no merge.
                        # Theory: Continues scanning for mergeable pairs.
            
            # Convert tokens to IDs
            for token in tokens:
                # Purpose: Map each token to its ID.
                # Theory: Uses vocabulary to convert subwords to IDs, falling back to character IDs for unknown tokens.
                
                if token in self.vocab:
                    token_ids.append(self.vocab[token])
                else:
                    # Handle unknown tokens
                    for char in token:
                        # Purpose: Split unknown token into characters.
                        # Theory: Falls back to character-level encoding for robustness.
                        
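                        token_ids.append(self.vocab.get(char, self.vocab.get('<unk>', -1)))
                        # Purpose: Append the character's ID, or the '<unk>' ID if one exists, or -1 otherwise.
                        # Theory: Training never adds '<unk>' to the vocabulary, so unseen characters encode as -1 and decode as the literal string '<unk>'; encoding continues, but such characters are not recovered.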
        
        return token_ids
        # Purpose: Return the sequence of token IDs.
        # Theory: Represents the tokenized text for LLM input.
    
    def decode(self, token_ids):
        # Purpose: Decode a sequence of token IDs back to text.
        # Theory: Reconstructs text by mapping IDs to tokens and concatenating.
        
        tokens = [self.reverse_vocab.get(token_id, '<unk>') for token_id in token_ids]
        # Purpose: Map token IDs to their string representations.
        # Theory: Uses reverse vocabulary to convert IDs to subwords or characters; the name token_id avoids shadowing the built-in id().
        
        return ''.join(tokens)
        # Purpose: Concatenate tokens to form the output text.
        # Theory: Joins subwords without spaces, so single words round-trip exactly but spaces between words are not restored.

# Test BPE implementation
if __name__ == "__main__":
    # Purpose: Run a test of the BPE implementation.
    # Theory: Demonstrates training, encoding, and decoding on a synthetic corpus.
    
    # Synthetic corpus
    corpus = ["hello world", "hello there", "world peace", "hi there"]
    # Purpose: Define a small text corpus for training.
    # Theory: Mimics a dataset for LLMs, with repeated words to learn meaningful merges.
    
    bpe = BPE()
    # Purpose: Initialize BPE model.
    # Theory: Sets up vocabulary and merge rules for training.
    
    bpe.train(corpus, num_merges=100)
    # Purpose: Train BPE on the corpus with 100 merges.
    # Theory: Builds vocabulary and merge rules based on frequent pairs.
    
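    print(f"Vocabulary Size: {len(bpe.vocab)}")
    # Purpose: Print the size of the learned vocabulary.
    # Theory: Equals the 12 initial characters plus the number of distinct merged tokens; this is well below 112 here because the small corpus cannot supply 100 distinct merges.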
    
    print(f"Sample Merge Rules: {bpe.merge_rules[:4]}")
    # Purpose: Print first few merge rules for inspection.
    # Theory: Shows the most frequent pairs merged during training.
    
    # Test encoding
    test_text = "hello"
    encoded = bpe.encode(test_text)
    # Purpose: Encode a test word.
    # Theory: Converts text to token IDs using learned merge rules.
    
    print(f"Encoded '{test_text}': {encoded}")
    # Purpose: Print encoded token IDs.
    # Theory: Shows the tokenized representation of the input.
    
    decoded = bpe.decode(encoded)
    # Purpose: Decode the token IDs back to text.
    # Theory: Verifies that encoding is reversible.
    
    print(f"Decoded {encoded}: {decoded}")
    # Purpose: Print decoded text.
    # Theory: Should match the original input if encoding/decoding is lossless.
    
    # Test unseen word
    test_text = "hellothere"
    encoded = bpe.encode(test_text)
    # Purpose: Encode an unseen word.
    # Theory: Tests a word absent from the training corpus; all of its characters are known, so the learned merges apply to its substrings.
    
    print(f"Encoded '{test_text}': {encoded}")
    # Purpose: Print encoded token IDs for unseen word.
    # Theory: Shows robustness to new words.
    
    decoded = bpe.decode(encoded)
    # Purpose: Decode the token IDs.
    # Theory: Verifies correct reconstruction of the unseen word.
    
    print(f"Decoded {encoded}: {decoded}")
    # Purpose: Print decoded text.
    # Theory: Confirms lossless encoding/decoding.

Vocabulary Size: 28
Sample Merge Rules: [('h', 'e'), ('he', 'l'), ('hel', 'l'), ('hell', 'o')]
Encoded 'hello': [15]
Decoded [15]: hello
Encoded 'hellothere': [15, 22]
Decoded [15, 22]: hellothere
