In [8]:
import json
import os
from typing import Any, List, Dict, Tuple, Iterable, Iterator

In [9]:
# uv run pytest tests/test_tokenizer.py
# adapter code
def get_tokenizer(
    vocab: dict[int, bytes],
    merges: list[tuple[bytes, bytes]],
    special_tokens: list[str] | None = None,
) -> Any:
    """Given a vocabulary, a list of merges, and a list of special tokens,
    return a BPE tokenizer that uses the provided vocab, merges, and special tokens.

    Args:
        vocab (dict[int, bytes]): The tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
            to bytes (token bytes)
        merges (list[tuple[bytes, bytes]]): BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
            representing that <token1> was merged with <token2>.
            Merges are ordered by order of creation.
        special_tokens (list[str] | None): A list of string special tokens for the tokenizer. These strings will never
            be split into multiple tokens, and will always be kept as a single token.

    Returns:
        A BPE tokenizer that uses the provided vocab, merges, and special tokens.
    """
    raise NotImplementedError

In [22]:
import regex as re

class BPETokenizer:
    def __init__(self, vocab, merges, special_tokens=None):
        self.vocab=vocab
        self.merges=[
            (pair[0].encode('utf-8'), pair[1].encode('utf-8'))
            for pair in merges
        ]
        self.special_tokens=special_tokens
        
        if special_tokens:
            self.split_pat = re.compile(
                "(" + "|".join(re.escape(tok) for tok in special_tokens) + ")"
            )
        else:
            self.split_pat = None
    
    @classmethod
    def from_files(cls, vocab_filepath, merges_filepath, special_tokens=None):
        with open(vocab_filepath) as f:
            vocab = json.load(f)
        
        with open(merges_filepath) as f:
            merges = [tuple(line.rstrip().split(" ")) for line in f]
        
        tokenizer = cls(
            vocab=vocab,
            merges=merges,
            special_tokens=special_tokens
        )
        return tokenizer
    
    def merge(self, indices: List[bytes], merge_pair: Tuple[bytes, bytes]) -> List[bytes]:
        merged_index = b''.join(merge_pair)
        new_indices = []
        
        i=0
        while i < len(indices):
            if i+1 < len(indices):
                pair = (indices[i], indices[i+1])
                # pair = (bytes([indices[i]]), bytes([indices[i+1]]))
                if pair==merge_pair:
                    new_indices.append(merged_index)
                    i+=2
                else:
                    new_indices.append(indices[i])
                    i+=1
            else:
                new_indices.append(indices[i])
                i+=1
        return new_indices
    
    def encode(self, text: str) -> list[int]:
        if self.special_tokens:
            parts = self.split_pat.split(text)
        else:
            parts = [text]
        
        indices = []
        for part in parts:
            # Handle special tokens
            if self.special_tokens and part in self.special_tokens:
                index = self.vocab[part.encode('utf-8')]
                indices.append(index)
                continue
                
            # Tokenize
            part_bytes = part.encode('utf-8')
            part_indices = list(map(lambda x: bytes([x]), part_bytes))
            
            # Merge
            for merge_pair in self.merges:
                part_indices = self.merge(part_indices, merge_pair)
            
            indices.extend(part_indices)
        return indices
             
    
    def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
        pass
    
    def decode(self, ids: list[int]) -> str:
        pass

In [23]:
pair = (b'a', b'c')
b''.join(pair)

b'ac'

In [24]:
for x in "abc".encode('utf-8'):
    print(x, bytes([x]))
    
part_bytes = 'abc'.encode('utf-8')
for x in map(lambda x: bytes([x]), part_bytes):
    print(x)

97 b'a'
98 b'b'
99 b'c'
b'a'
b'b'
b'c'


In [25]:
FIXTURES_PATH="../../tests/fixtures"
vocab_path = os.path.join(FIXTURES_PATH, "gpt2_vocab.json")
merges_path = os.path.join(FIXTURES_PATH, "gpt2_merges.txt")


In [26]:
tokenizer = BPETokenizer.from_files(
    vocab_filepath=vocab_path,
    merges_filepath=merges_path,
    # special_tokens=["<|endoftext|>"]
)

In [27]:
tokenizer.merges

[(b'\xc4\xa0', b't'),
 (b'\xc4\xa0', b'a'),
 (b'h', b'e'),
 (b'i', b'n'),
 (b'r', b'e'),
 (b'o', b'n'),
 (b'\xc4\xa0t', b'he'),
 (b'e', b'r'),
 (b'\xc4\xa0', b's'),
 (b'a', b't'),
 (b'\xc4\xa0', b'w'),
 (b'\xc4\xa0', b'o'),
 (b'e', b'n'),
 (b'\xc4\xa0', b'c'),
 (b'i', b't'),
 (b'i', b's'),
 (b'a', b'n'),
 (b'o', b'r'),
 (b'e', b's'),
 (b'\xc4\xa0', b'b'),
 (b'e', b'd'),
 (b'\xc4\xa0', b'f'),
 (b'in', b'g'),
 (b'\xc4\xa0', b'p'),
 (b'o', b'u'),
 (b'\xc4\xa0a', b'n'),
 (b'a', b'l'),
 (b'a', b'r'),
 (b'\xc4\xa0t', b'o'),
 (b'\xc4\xa0', b'm'),
 (b'\xc4\xa0o', b'f'),
 (b'\xc4\xa0', b'in'),
 (b'\xc4\xa0', b'd'),
 (b'\xc4\xa0', b'h'),
 (b'\xc4\xa0an', b'd'),
 (b'i', b'c'),
 (b'a', b's'),
 (b'l', b'e'),
 (b'\xc4\xa0t', b'h'),
 (b'i', b'on'),
 (b'o', b'm'),
 (b'l', b'l'),
 (b'en', b't'),
 (b'\xc4\xa0', b'n'),
 (b'\xc4\xa0', b'l'),
 (b's', b't'),
 (b'\xc4\xa0', b're'),
 (b'v', b'e'),
 (b'\xc4\xa0', b'e'),
 (b'r', b'o'),
 (b'l', b'y'),
 (b'\xc4\xa0b', b'e'),
 (b'\xc4\xa0', b'g'),
 (b'\xc4\xa0', b

In [28]:
tokenizer.encode('hello world<|endoftext|>hehe')

[b'hello',
 b' ',
 b'world',
 b'<',
 b'|',
 b'end',
 b'of',
 b'text',
 b'|',
 b'>',
 b'he',
 b'he']