In [137]:
def get_stats(ids, counts=None):
    """
    Tally how often each adjacent pair of integers occurs in `ids`.

    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}

    If `counts` is supplied it is updated in place (and returned), which
    lets a caller accumulate pair counts over several sequences.
    """
    if counts is None:
        counts = {}
    # pair every element with its right-hand neighbour
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        counts[key] = counts.get(key, 0) + 1
    return counts


def merge(ids, pair, idx):
    """
    Return a new list in which every non-overlapping, left-to-right
    occurrence of `pair` in `ids` has been collapsed into the token `idx`.

    Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
    """
    out = []
    pos = 0
    total = len(ids)
    while pos < total:
        current = ids[pos]
        # a match needs a following element, so the last position never matches
        if current == pair[0] and pos < total - 1 and ids[pos + 1] == pair[1]:
            out.append(idx)
            pos += 2
            continue
        out.append(current)
        pos += 1
    return out




    

In [138]:
import regex as re



# the main GPT text split patterns, see
# https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
# GPT-2 alternatives, tried in order: contractions ('s 'd 'm 't 'll 've 're),
# optional space + letters, optional space + digits, optional space + a run of
# other non-whitespace characters, whitespace not followed by non-whitespace,
# and finally any remaining whitespace.
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# GPT-4 variant: contractions are case-insensitive, a letter run may be preceded
# by one non-newline/non-letter/non-digit character, digits are taken at most
# three at a time, and punctuation runs also absorb trailing \r / \n.
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

# module-level tokenizer state read by the functions defined in the later cells
pattern = GPT4_SPLIT_PATTERN 
compiled_pattern = re.compile(pattern)
# special token registry: str -> int
special_tokens = {}
# reverse lookup of the registry: int -> str
inverse_special_tokens = {}

In [139]:

merges = {}
def train(text, vocab_size, verbose=True):
    """
    Learn BPE merges from `text` and publish them as module-level state.

    The text is split with `compiled_pattern`, every chunk is turned into
    its UTF-8 byte values, and the most frequent adjacent pair is repeatedly
    replaced by a freshly minted token id (256, 257, ...).

    Args:
        text: training text (str).
        vocab_size: target vocabulary size; must be >= 256 since the 256 raw
            byte values are always part of the vocabulary.
        verbose: when True, print one line per merge.

    Side effects:
        Rebinds the module-level `merges` ((int, int) -> int) and `vocab`
        (int -> bytes) so that the encode/decode functions use the result.

    Training stops early if the text has no adjacent pairs left to merge, in
    which case fewer than `vocab_size - 256` merges are learned.
    """
    # the results must be visible to _encode_chunk()/decode(), which read globals
    global merges, vocab

    assert vocab_size >= 256
    num_merges = vocab_size - 256

    # split the text up into text chunks
    text_chunks = re.findall(compiled_pattern, text)

    # input text preprocessing
    ids = [list(ch.encode("utf-8")) for ch in text_chunks]

    # iteratively merge the most common pairs to create new tokens
    new_merges = {} # (int, int) -> int
    new_vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
    for i in range(num_merges):
        # count the number of times every consecutive pair appears
        stats = {}
        for chunk_ids in ids:
            # passing in stats will update it in place, adding up counts
            get_stats(chunk_ids, stats)
        if not stats:
            # every chunk is down to a single token: max() on an empty dict
            # would raise, so stop with the merges learned so far
            if verbose:
                print(f"no more pairs to merge after {i}/{num_merges} merges")
            break
        # find the pair with the highest count
        pair = max(stats, key=stats.get)
        # mint a new token: assign it the next available id
        idx = 256 + i
        # replace all occurrences of pair in ids with idx
        ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
        # save the merge
        new_merges[pair] = idx
        new_vocab[idx] = new_vocab[pair[0]] + new_vocab[pair[1]]
        # prints
        if verbose:
            print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({new_vocab[idx]}) had {stats[pair]} occurrences")

    # publish the trained state
    merges = new_merges # used when encoding
    vocab = new_vocab   # used when decoding


In [140]:
def _build_vocab():
    # vocab is simply and deterministically derived from merges
    vocab = {idx: bytes([idx]) for idx in range(256)}
    for (p0, p1), idx in merges.items():
        vocab[idx] = vocab[p0] + vocab[p1]
    for special, idx in special_tokens.items():
        vocab[idx] = special.encode("utf-8")
    return vocab

In [141]:
# id -> bytes table, derived from whatever `merges` and `special_tokens`
# contain at the moment this cell runs
vocab = _build_vocab()


def register_special_tokens(special_tokens):
    """
    Register the special tokens used by encode() and decode().

    Args:
        special_tokens: dictionary of str -> int,
            example: {"<|endoftext|>": 100257}

    Side effects:
        Rebinds the module-level `special_tokens` to the given dictionary and
        `inverse_special_tokens` to its int -> str inverse.
    """
    global inverse_special_tokens
    # the parameter shadows the module-level name (and a name cannot be both a
    # parameter and `global`), so the module binding is set through globals()
    globals()["special_tokens"] = special_tokens
    inverse_special_tokens = {v: k for k, v in special_tokens.items()}

def decode(ids):
    """
    Turn a list of token ids back into a Python string.

    Each id is looked up in `vocab` first and in `inverse_special_tokens`
    second; an id found in neither raises ValueError. The concatenated bytes
    are decoded as UTF-8, with invalid sequences replaced rather than raising.
    """
    def _token_bytes(idx):
        # bytes for a single id, or ValueError when the id is unknown
        if idx in vocab:
            return vocab[idx]
        if idx in inverse_special_tokens:
            return inverse_special_tokens[idx].encode("utf-8")
        raise ValueError(f"invalid token id: {idx}")

    raw = b"".join(_token_bytes(idx) for idx in ids)
    return raw.decode("utf-8", errors="replace")

def _encode_chunk(text_bytes):
    """
    Encode one chunk of raw bytes into token ids using the module-level
    `merges` table.

    Starts from the byte values (0..255) and keeps applying the earliest
    learned merge that is present in the sequence until none applies.
    """
    tokens = list(text_bytes)
    while len(tokens) >= 2:
        present = get_stats(tokens)
        # the earliest-learned merge has the lowest index; unknown pairs rank
        # as +inf, so when nothing is mergeable min() returns an arbitrary pair
        best = min(present, key=lambda p: merges.get(p, float("inf")))
        if best in merges:
            tokens = merge(tokens, best, merges[best])
        else:
            # that arbitrary pair is not a learned merge: we are done
            break
    return tokens

def encode_ordinary(text):
    """Encoding that ignores any special tokens."""
    out = []
    # the regex pattern cuts the text into chunks; each chunk is encoded on
    # its own and the per-chunk ids are concatenated in order
    for piece in re.findall(compiled_pattern, text):
        out += _encode_chunk(piece.encode("utf-8"))
    return out

def encode(text, allowed_special="none_raise"):
    """
    Encode `text` into token ids, with handling of special tokens.

    allowed_special selects which registered special tokens are recognised:
    - "all": every token in `special_tokens`
    - "none": no special tokens; the text is encoded as ordinary text
    - "none_raise" (default): like "none", but an assertion fails if any
      registered special token occurs in the text
    - a set of strings: only the registered special tokens in that set
    Any other value raises ValueError.
    """
    # work out which special tokens are active for this call
    if allowed_special == "all":
        active = special_tokens
    elif allowed_special == "none":
        active = {}
    elif allowed_special == "none_raise":
        active = {}
        assert all(token not in text for token in special_tokens)
    elif isinstance(allowed_special, set):
        active = {k: v for k, v in special_tokens.items() if k in allowed_special}
    else:
        raise ValueError(f"allowed_special={allowed_special} not understood")

    if not active:
        # nothing special to look for: plain encoding is enough
        return encode_ordinary(text)

    # split on exact occurrences of the active special tokens; the capturing
    # group makes re.split keep the matched tokens in the result
    alternation = "|".join(map(re.escape, active))
    pieces = re.split("(" + alternation + ")", text)

    out = []
    for piece in pieces:
        if piece in active:
            # a special token maps straight to its id
            out.append(active[piece])
        else:
            # ordinary text between special tokens
            out.extend(encode_ordinary(piece))
    return out

In [143]:
train("""OpenAI's text generation models (often referred to as generative pre-trained transformers or "GPT" models for short), like GPT-4 and GPT-3.5, have been trained to understand natural and formal language. Models like GPT-4 allows text outputs in response to their inputs. The inputs to these models are also referred to as "prompts". Designing a prompt is essentially how you "program" a model like GPT-4, usually by providing instructions or some examples of how to successfully complete a task. Models like GPT-4 can be used across a great variety of tasks including content or code generation, summarization, conversation, creative writing, and more. Read more in our introductory text generation guide and in our prompt engineering guide.""", vocab_size=300)

merge 1/44: (105, 110) -> 256 (b'in') had 16 occurrences
merge 2/44: (32, 116) -> 257 (b' t') had 15 occurrences
merge 3/44: (32, 97) -> 258 (b' a') had 14 occurrences
merge 4/44: (114, 101) -> 259 (b're') had 11 occurrences
merge 5/44: (101, 110) -> 260 (b'en') had 10 occurrences
merge 6/44: (101, 114) -> 261 (b'er') had 10 occurrences
merge 7/44: (97, 116) -> 262 (b'at') had 9 occurrences
merge 8/44: (111, 110) -> 263 (b'on') had 9 occurrences
merge 9/44: (100, 101) -> 264 (b'de') had 9 occurrences
merge 10/44: (111, 114) -> 265 (b'or') had 8 occurrences
merge 11/44: (32, 256) -> 266 (b' in') had 8 occurrences
merge 12/44: (32, 103) -> 267 (b' g') had 7 occurrences
merge 13/44: (262, 105) -> 268 (b'ati') had 7 occurrences
merge 14/44: (111, 264) -> 269 (b'ode') had 7 occurrences
merge 15/44: (114, 111) -> 270 (b'ro') had 7 occurrences
merge 16/44: (32, 109) -> 271 (b' m') had 6 occurrences
merge 17/44: (269, 108) -> 272 (b'odel') had 6 occurrences
merge 18/44: (257, 111) -> 273 (b' t

# Implementation of class RegexTokenizer

In [134]:
"""
Minimal (byte-level) Byte Pair Encoding tokenizer.

Algorithmically follows along the GPT tokenizer:
https://github.com/openai/gpt-2/blob/master/src/encoder.py

Unlike BasicTokenizer:
- RegexTokenizer handles an optional regex splitting pattern.
- RegexTokenizer handles optional special tokens.
"""

import regex as re



# the main GPT text split patterns, see
# https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
# GPT-2: contractions, optional space + letters, optional space + digits,
# optional space + other non-whitespace runs, then whitespace.
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# GPT-4: case-insensitive contractions, digits taken at most three at a time,
# punctuation runs also absorb trailing \r / \n, whitespace ending in a newline.
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""


class RegexTokenizer():
    """
    Byte-level BPE tokenizer that pre-splits text with a regex pattern and
    supports optional special tokens.

    State:
    - pattern / compiled_pattern: the split pattern (string and compiled form)
    - merges: (int, int) -> int, the learned merges in training order
    - vocab: int -> bytes, the byte string behind every token id
    - special_tokens / inverse_special_tokens: str -> int and int -> str
    """

    def __init__(self, pattern=None):
        """
        - pattern: optional string to override the default (GPT-4 split pattern)

        Special tokens are not constructor arguments; add them afterwards
        with register_special_tokens(), e.g. {'<|endoftext|>': 100257}.
        """

        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {}
        self.inverse_special_tokens = {}
        self.merges = {}
        self.vocab = self._build_vocab()

    def train(self, text, vocab_size, verbose=True):
        """
        Learn BPE merges from `text`.

        Args:
            text: training text (str).
            vocab_size: target vocabulary size, must be >= 256 (the 256 raw
                byte values are always in the vocabulary).
            verbose: when True, print one line per merge.

        Replaces self.merges and self.vocab with the trained tables. If the
        text runs out of adjacent pairs before `vocab_size - 256` merges have
        been made, training stops early with the merges learned so far.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # split the text up into text chunks
        text_chunks = re.findall(self.compiled_pattern, text)

        # input text preprocessing
        ids = [list(ch.encode("utf-8")) for ch in text_chunks]

        # iteratively merge the most common pairs to create new tokens
        merges = {} # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
        for i in range(num_merges):
            # count the number of times every consecutive pair appears
            stats = {}
            for chunk_ids in ids:
                # passing in stats will update it in place, adding up counts
                get_stats(chunk_ids, stats)
            if not stats:
                # every chunk is down to a single token (or the text was
                # empty): max() on an empty dict would raise, so stop here
                if verbose:
                    print(f"no more pairs to merge after {i}/{num_merges} merges")
                break
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        # save class variables
        self.merges = merges # used in encode()
        self.vocab = vocab   # used in decode()

    def register_special_tokens(self, special_tokens):
        """
        Register special tokens for encode()/decode().

        special_tokens is a dictionary of str -> int,
        example: {"<|endoftext|>": 100257}
        """
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

    def decode(self, ids):
        """
        Given ids (list of integers), return a Python string.

        Ids are looked up in self.vocab first, then among the special tokens;
        an unknown id raises ValueError. Bytes that are not valid UTF-8 are
        replaced instead of raising.
        """
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"invalid token id: {idx}")
        text_bytes = b"".join(part_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def _build_vocab(self):
        """Derive the id -> bytes table from self.merges and self.special_tokens."""
        # vocab is simply and deterministically derived from merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def _encode_chunk(self, text_bytes):
        """Encode one chunk of raw bytes into token ids using self.merges."""
        # let's begin. first, convert all bytes to integers in range 0..255
        ids = list(text_bytes)
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

    def encode_ordinary(self, text):
        """Encoding that ignores any special tokens."""
        # split text into chunks of text by categories defined in regex pattern
        text_chunks = re.findall(self.compiled_pattern, text)
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8") # raw bytes
            chunk_ids = self._encode_chunk(chunk_bytes)
            ids.extend(chunk_ids)
        return ids

    def encode(self, text, allowed_special="none_raise"):
        """
        Unlike encode_ordinary, this function handles special tokens.
        allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
        if none_raise, then an error is raised if any special token is encountered in text
        this is the default tiktoken behavior right now as well
        any other behavior is either annoying, or a major footgun

        Raises ValueError when allowed_special is none of the above.
        """
        # decode the user desire w.r.t. handling of special tokens
        special = None
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens)
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        if not special:
            # shortcut: if no special tokens, just use the ordinary encoding
            return self.encode_ordinary(text)
        # otherwise, we have to be careful with potential special tokens in text
        # we handle special tokens by splitting the text
        # based on the occurrence of any exact match with any of the special tokens
        # we can use re.split for this. note that surrounding the pattern with ()
        # makes it into a capturing group, so the special tokens will be included
        special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
        special_chunks = re.split(special_pattern, text)
        # now all the special characters are separated from the rest of the text
        # all chunks of text are encoded separately, then results are joined
        ids = []

        for part in special_chunks:
            if part in special:
                # this is a special token, encode it separately as a special case
                ids.append(special[part])
            else:
                # this is an ordinary sequence, encode it normally
                ids.extend(self.encode_ordinary(part))
        return ids

In [135]:
tk = RegexTokenizer()

In [None]:
tk.train(text=""" """, vocab_size=300)

merge 1/44: (121, 111) -> 256 (b'yo') had 8 occurrences
merge 2/44: (101, 114) -> 257 (b'er') had 8 occurrences
merge 3/44: (32, 97) -> 258 (b' a') had 7 occurrences
merge 4/44: (32, 256) -> 259 (b' yo') had 7 occurrences
merge 5/44: (259, 117) -> 260 (b' you') had 7 occurrences
merge 6/44: (101, 118) -> 261 (b'ev') had 7 occurrences
merge 7/44: (32, 116) -> 262 (b' t') had 7 occurrences
merge 8/44: (32, 115) -> 263 (b' s') had 6 occurrences
merge 9/44: (105, 116) -> 264 (b'it') had 6 occurrences
merge 10/44: (32, 112) -> 265 (b' p') had 5 occurrences
merge 11/44: (114, 101) -> 266 (b're') had 5 occurrences
merge 12/44: (110, 101) -> 267 (b'ne') had 5 occurrences
merge 13/44: (32, 119) -> 268 (b' w') had 5 occurrences
merge 14/44: (101, 110) -> 269 (b'en') had 5 occurrences
merge 15/44: (104, 101) -> 270 (b'he') had 4 occurrences
merge 16/44: (32, 105) -> 271 (b' i') had 4 occurrences
merge 17/44: (32, 102) -> 272 (b' f') had 4 occurrences
merge 18/44: (118, 105) -> 273 (b'vi') had 4 o

In [127]:
tk.encode('hello au??')

idx 270
ids [104, 101, 108, 108, 111]
new ids [270, 108, 108, 111]
idx 276
ids [270, 108, 108, 111]
new ids [270, 276, 111]
idx 258
ids [32, 97, 117]
new ids [258, 117]


[270, 276, 111, 258, 117, 63, 63]

In [45]:
tk.decode([32])

' '

In [31]:
tk.merges

{(97, 117): 256, (256, 116): 257}

In [46]:
tk.vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'