In [85]:
import json
import os
import regex as re
from typing import Any, List, Dict, Tuple, Iterable, Iterator

import sys
# sys.path.append('../../tests')
from common import gpt2_bytes_to_unicode

In [86]:
import tiktoken
reference_tokenizer = tiktoken.get_encoding("gpt2")

In [87]:
import json
from typing import Any, List, Dict, Tuple, Iterable, Iterator

import regex as re

PAT=r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

class BPETokenizer:
    def __init__(
        self,
        vocab: dict[int, bytes],
        merges: list[tuple[bytes, bytes]],
        special_tokens: list[str] | None = None,
    ):
        self.vocab=vocab
        self.inv_vocab = {v:k for k,v in vocab.items()}

        self.merges=merges
        if isinstance(special_tokens, list) and special_tokens:
            # Sort to ensure case ['<|eot|><|eot|>', '<|eot|>']
            self.special_tokens=sorted(special_tokens, key=lambda x: (-len(x), x))
            self.split_pat = re.compile(
                "(" + "|".join(re.escape(tok) for tok in self.special_tokens) + ")"
            )
        else:
            self.special_tokens=special_tokens
            self.split_pat = None
        
    @classmethod
    def from_files(cls, vocab_filepath, merges_filepath, special_tokens=None):
        with open(vocab_filepath) as f:
            vocab = json.load(f)
        
        with open(merges_filepath) as f:
            merges = [tuple(line.rstrip().split(" ")) for line in f]
        
        tokenizer = cls(
            vocab=vocab,
            merges=merges,
            special_tokens=special_tokens
        )
        return tokenizer
    
    def id_to_token(self, token_id: int) -> bytes:
        return self.vocab[token_id]
    
    def token_to_id(self, token: bytes) -> int:
        if token not in self.inv_vocab:
            raise ValueError(
                "token {} not in vocab".format(token.decode('utf-8', errors='ignore'))
            )
        return self.inv_vocab[token]
        
    
    def merge(self, indices: List[bytes], merge_pair: Tuple[bytes, bytes]) -> List[bytes]:
        merged_index = b''.join(merge_pair)
        new_indices = []
        
        i=0
        while i < len(indices):
            if i+1 < len(indices):
                pair = (indices[i], indices[i+1])
                # pair = (bytes([indices[i]]), bytes([indices[i+1]]))
                if pair==merge_pair:
                    new_indices.append(merged_index)
                    i+=2
                else:
                    new_indices.append(indices[i])
                    i+=1
            else:
                new_indices.append(indices[i])
                i+=1
        return new_indices
    
    def tokenize(self, text: str) -> list[bytes]:
        if self.special_tokens:
            parts = self.split_pat.split(text)
        else:
            parts = [text]
        
        if self.special_tokens:
            parts = self.split_pat.split(text)
        else:
            parts = [text]
        
        indices = []
        for part in parts:
            # Handle special tokens
            if self.special_tokens and part in self.special_tokens:
                indices.append(part.encode('utf-8'))
                continue
            
            for pretok_match in re.finditer(PAT, part):
                pretok = pretok_match.group()
                # Tokenize
                part_bytes = pretok.encode('utf-8')
                part_indices = list(map(lambda x: bytes([x]), part_bytes))
                
                # Merge
                for merge_pair in self.merges:
                    part_indices = self.merge(part_indices, merge_pair)
            
                indices.extend(part_indices)
        return indices
    
    def encode(self, text: str) -> list[int]:
        tokens = self.tokenize(text)
        indices = [self.token_to_id(x) for x in tokens]
        return indices
             
    
    def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
        for x in iterable:
            for token in self.encode(x):
                yield token
    
    def decode(self, ids: list[int]) -> str:
        tokens = [self.id_to_token(x) for x in ids]
        return b''.join(tokens).decode('utf-8', errors='ignore')

In [88]:
# uv run pytest tests/test_tokenizer.py
# adapter code
def get_tokenizer(
    vocab: dict[int, bytes],
    merges: list[tuple[bytes, bytes]],
    special_tokens: list[str] | None = None,
) -> Any:
    """Given a vocabulary, a list of merges, and a list of special tokens,
    return a BPE tokenizer that uses the provided vocab, merges, and special tokens.

    Args:
        vocab (dict[int, bytes]): The tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
            to bytes (token bytes)
        merges (list[tuple[bytes, bytes]]): BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
            representing that <token1> was merged with <token2>.
            Merges are ordered by order of creation.
        special_tokens (list[str] | None): A list of string special tokens for the tokenizer. These strings will never
            be split into multiple tokens, and will always be kept as a single token.

    Returns:
        A BPE tokenizer that uses the provided vocab, merges, and special tokens.
    """
    raise NotImplementedError

In [89]:
pair = (b'a', b'c')
b''.join(pair)

b'ac'

In [90]:
for x in "abc".encode('utf-8'):
    print(x, bytes([x]))
    
part_bytes = 'abc'.encode('utf-8')
for x in map(lambda x: bytes([x]), part_bytes):
    print(x)

97 b'a'
98 b'b'
99 b'c'
b'a'
b'b'
b'c'


In [91]:
FIXTURES_PATH="../../tests/fixtures"
vocab_path = os.path.join(FIXTURES_PATH, "gpt2_vocab.json")
merges_path = os.path.join(FIXTURES_PATH, "gpt2_merges.txt")

special_tokens = ['<|endoftext|>']
special_tokens = ["<|endoftext|>", "<|endoftext|><|endoftext|>"]

In [92]:
# GPT2
gpt2_byte_decoder = {v: k for k, v in gpt2_bytes_to_unicode().items()}
with open(vocab_path) as vocab_f:
    gpt2_vocab = json.load(vocab_f)
gpt2_bpe_merges = []
with open(merges_path) as f:
    for line in f:
        cleaned_line = line.rstrip()
        if cleaned_line and len(cleaned_line.split(" ")) == 2:
            gpt2_bpe_merges.append(tuple(cleaned_line.split(" ")))
# The GPT-2 tokenizer uses a remapped unicode encoding for bytes. Let's
# just return the original bytes, so we don't force students to use
# any particular encoding scheme.
vocab = {
    gpt2_vocab_index: bytes([gpt2_byte_decoder[token] for token in gpt2_vocab_item])
    for gpt2_vocab_item, gpt2_vocab_index in gpt2_vocab.items()
}
# If any of the special tokens don't exist in the vocab, append them to the vocab.
if special_tokens:
    for special_token in special_tokens:
        byte_encoded_special_token = special_token.encode("utf-8")
        if byte_encoded_special_token not in set(vocab.values()):
            vocab[len(vocab)] = byte_encoded_special_token

merges = [
    (
        bytes([gpt2_byte_decoder[token] for token in merge_token_1]),
        bytes([gpt2_byte_decoder[token] for token in merge_token_2]),
    )
    for merge_token_1, merge_token_2 in gpt2_bpe_merges
]

In [93]:
tokenizer = BPETokenizer(
    vocab=vocab,
    merges=merges,
    special_tokens=special_tokens
)
# tokenizer = BPETokenizer.from_files(
#     vocab_filepath=vocab_path,
#     merges_filepath=merges_path,
#     # special_tokens=["<|endoftext|>"]
# )

In [94]:
tokenizer.special_tokens
tokenizer.split_pat

regex.Regex('(<\\|endoftext\\|><\\|endoftext\\|>|<\\|endoftext\\|>)', flags=regex.V0)

In [95]:
# tokenizer.vocab#.keys()

In [96]:
tokenizer.tokenize('hello world<|endoftext|><|endoftext|>hehe')

[b'hello', b' world', b'<|endoftext|><|endoftext|>', b'he', b'he']

In [97]:
tokenizer.tokenize('hello world<|endoftext|>hehe')

[b'hello', b' world', b'<|endoftext|>', b'he', b'he']

In [98]:
tokenizer.encode('hello world<|endoftext|>hehe')

[31373, 995, 50256, 258, 258]

In [99]:
# debugging test_encode_special_token_double_newline_non_whitespace
print(tokenizer.id_to_token(198))
print(tokenizer.id_to_token(628))

b'\n'
b'\n\n'


In [100]:
x = '''<|endoftext|>

testing!'''
print(repr(x))

'<|endoftext|>\n\ntesting!'


In [101]:
reference_tokenizer.encode(x, allowed_special={"<|endoftext|>"})

[50256, 198, 198, 33407, 0]

In [102]:
tokenizer.encode(x)

[50256, 198, 198, 33407, 0]

In [103]:
PAT=r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

for match in re.finditer(PAT, x):
    print(match)

<regex.Match object; span=(0, 2), match='<|'>
<regex.Match object; span=(2, 11), match='endoftext'>
<regex.Match object; span=(11, 13), match='|>'>
<regex.Match object; span=(13, 14), match='\n'>
<regex.Match object; span=(14, 15), match='\n'>
<regex.Match object; span=(15, 22), match='testing'>
<regex.Match object; span=(22, 23), match='!'>
