# Byte-Pair Encoding Tokenisation

First, the imports. I've just used basic Python to illustrate the algorithm.

In [1]:
from __future__ import annotations
from collections import Counter
from typing import Dict, List, Tuple
from dataclasses import dataclass
from IPython.display import HTML as html_print
from helpers import html

We need some text to train the tokeniser. We will use the source code of the infamous fast inverse square root algorithm, which is often falsely attributed to John Carmack of Quake fame. See more about it here: https://en.wikipedia.org/wiki/Fast_inverse_square_root

In [2]:
# Training text for the tokeniser, kept verbatim: the code below pre-tokenises it by splitting on single
# spaces, so the exact spacing and line breaks inside this literal affect which chunks and tokens are produced.
corpus: str = """
float Q_rsqrt( float number )
{
    long i;
    float x2, y;
    const float threehalfs = 1.5F;

    x2 = number * 0.5F;
    y  = number;
    i  = * ( long * ) &y;                       // evil floating point bit level hacking
    i  = 0x5f3759df - ( i >> 1 );               // what the fuck?
    y  = * ( float * ) &i;
    y  = y * ( threehalfs - ( x2 * y * y ) );   // 1st iteration
    // y  = y * ( threehalfs - ( x2 * y * y ) );// 2nd iteration, this can be removed

    return y;
}
"""

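The original byte-pair encoding algorithm was published in 1994 by Philip Gage as a data-compression technique; the article can be found here: ["A New Algorithm for Data Compression"](http://www.pennelynn.com/Documents/CUJ/HTML/94HTML/19940045.HTM). The cell below applies the same idea to tokenisation: starting from single characters, it repeatedly merges the most frequent pair of adjacent tokens into a new token.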

In [3]:
@dataclass(frozen=True)
class TokenPair:
    """An ordered pair of adjacent tokens that may be merged into one token.

    The dataclass is frozen, so instances are immutable and hashable and can serve as dictionary keys.
    """

    token_a: str  # Left-hand token of the pair.
    token_b: str  # Right-hand token of the pair.

    @property
    def merged_token(self) -> str:
        """The token obtained by concatenating `token_a` followed by `token_b`."""
        left, right = self.token_a, self.token_b
        return left + right


@dataclass
class TextChunk:
    """A piece of pre-tokenised text, stored as its current sequence of tokens."""

    # Weight applied when this chunk's token pairs are counted; `from_text_chunk` uses -1 when none is given.
    frequency: int
    # Current tokenisation of the chunk; `from_text_chunk` starts it as one token per character.
    tokens: List[str]

    def merge_token_pair(self, pair: TokenPair) -> None:
        """Replace each adjacent (`pair.token_a`, `pair.token_b`) occurrence with `pair.merged_token`.

        Occurrences are matched from left to right and never overlap: merging ("a", "a") in
        ["a", "a", "a"] gives ["aa", "a"]. The token list is rebuilt in a single pass instead of
        re-slicing the whole list on every match. Tokens are assumed to be non-empty strings.

        Args:
            pair: The token pair to merge wherever it occurs in this chunk.
        """
        tokens = self.tokens
        last = len(tokens) - 1
        merged: List[str] = []
        i = 0
        while i <= last:
            if i < last and tokens[i] == pair.token_a and tokens[i + 1] == pair.token_b:
                merged.append(pair.merged_token)
                # Skip both halves of the pair so the merged token is not matched again.
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        self.tokens = merged

    @property
    def is_single_token(self) -> bool:
        """Whether the chunk consists of exactly one token, i.e. nothing is left to merge."""
        return len(self.tokens) == 1

    @classmethod
    def from_text_chunk(cls, text_chunk: str, frequency: int = -1) -> TextChunk:
        """Build a chunk whose initial tokens are the individual characters of `text_chunk`.

        Args:
            text_chunk: The text to split into single-character tokens.
            frequency: Weight to store on the chunk; defaults to -1.

        Returns:
            A new `TextChunk` with one token per character.
        """
        return cls(frequency=frequency, tokens=list(text_chunk))


@dataclass
class Vocab:
    """Mapping from token strings to integer token IDs."""

    # Token string -> token ID.
    token_to_id: Dict[str, int]

    @classmethod
    def from_text_chunks_tokens(cls, text_chunks: List[TextChunk]) -> "Vocab":
        """Build a vocab from every distinct token found in `text_chunks`.

        IDs are assigned from 0 upwards in sorted token order.
        """
        unique_tokens = set()
        for chunk in text_chunks:
            unique_tokens.update(chunk.tokens)
        ordered_tokens = sorted(unique_tokens)
        return cls({token: index for index, token in enumerate(ordered_tokens)})

    def __getitem__(self, token_string: str) -> int:
        """Look up the integer ID of `token_string`."""
        return self.token_to_id[token_string]

    def __len__(self) -> int:
        """Number of tokens currently in the vocab."""
        return len(self.token_to_id)

    def add(self, token_string: str):
        """Register a token that is not yet in the vocab, giving it the next free ID."""
        next_id = len(self.token_to_id)
        assert token_string not in self.token_to_id
        self.token_to_id[token_string] = next_id
        


def pre_tokenize_str(text: str, delimiter: str = " ") -> List[str]:
    """Split `text` on `delimiter`, keeping the delimiter as a prefix of every chunk after the first.

    Joining the returned chunks therefore reproduces `text` exactly.
    """
    first_chunk, *later_chunks = text.split(delimiter)
    return [first_chunk] + [delimiter + chunk for chunk in later_chunks]


def most_common_token_pair(text_chunks: List[TextChunk]) -> TokenPair:
    """Return the adjacent token pair with the highest frequency-weighted count.

    Each adjacent pair inside a chunk adds that chunk's `frequency` to the pair's total. When several
    pairs share the highest total, the one encountered first is returned.

    Args:
        text_chunks: The chunks whose token sequences are scanned for adjacent pairs.

    Returns:
        The pair with the largest total.

    Raises:
        IndexError: If no chunk contains at least two tokens, so there is no pair to return.
    """
    pair_freqs: Counter[TokenPair] = Counter()
    for text_chunk in text_chunks:
        tokens = text_chunk.tokens
        # zip stops at the shorter sequence, so chunks with fewer than two tokens contribute nothing.
        for token_a, token_b in zip(tokens, tokens[1:]):
            pair_freqs[TokenPair(token_a=token_a, token_b=token_b)] += text_chunk.frequency
    if not pair_freqs:
        raise IndexError("no adjacent token pairs: every text chunk has fewer than two tokens")
    # Ask only for the top entry rather than sorting every pair.
    return pair_freqs.most_common(1)[0][0]


def tokenise(text: str, token_merge_history: List[TokenPair], vocab: Vocab) -> Tuple[List[str], List[int]]:
    """Tokenise `text` by replaying `token_merge_history` and look up each token's ID in `vocab`.

    The text is pre-tokenised into chunks, each chunk starts as single characters, and the merges are
    applied in the order they appear in the history.

    Returns:
        A `(token_strings, token_ids)` tuple of equal-length lists, in text order.
    """
    chunks = list(map(TextChunk.from_text_chunk, pre_tokenize_str(text)))
    for merge in token_merge_history:
        for chunk in chunks:
            chunk.merge_token_pair(merge)
    token_strings: List[str] = [token for chunk in chunks for token in chunk.tokens]
    token_ids: List[int] = [vocab[token] for token in token_strings]
    return token_strings, token_ids



# Pre-tokenise the corpus on the space character " " and keep one TextChunk per distinct chunk,
# weighted by how many times that chunk occurs.
text_chunks: List[TextChunk] = [
    TextChunk.from_text_chunk(text_chunk=chunk_text, frequency=chunk_count)
    for chunk_text, chunk_count in Counter(pre_tokenize_str(corpus)).items()
]
# Stop merging once the vocab holds this many tokens.
vocab_size: int = 100
# Seed the vocab with the tokens of the freshly split text_chunks, i.e. every character in the corpus.
vocab: Vocab = Vocab.from_text_chunks_tokens(text_chunks)
# Merges are recorded in the order they are learnt so that `tokenise` can replay them.
token_merge_history: List[TokenPair] = []
# Accumulates one HTML rendering of the tokenised corpus per step; displayed once at the end of the cell.
html_text_history: str = html.render_corpus_as_token_html(
    corpus=corpus, token_merge_history=token_merge_history, vocab=vocab, tokenise=tokenise,
)
# Keep going while the vocab has room and at least one text_chunk still has more than one token.
while len(vocab) < vocab_size and any(not text_chunk.is_single_token for text_chunk in text_chunks):
    # Pick the pair of adjacent tokens with the highest weighted count.
    token_pair: TokenPair = most_common_token_pair(text_chunks)
    # Apply that merge to every text_chunk.
    for text_chunk in text_chunks:
        text_chunk.merge_token_pair(token_pair)
    # Record the merge.
    token_merge_history.append(token_pair)
    # Register the newly created token in the vocab.
    vocab.add(token_pair.merged_token)
    # Add a rendering of the corpus as tokenised after this merge.
    html_text_history += html.render_corpus_as_token_html(
        corpus=corpus, token_merge_history=token_merge_history, vocab=vocab, tokenise=tokenise,
    )
# Display the accumulated renderings, one per merge step.
html_print(html_text_history)