In [230]:
import html
import random
from collections import Counter

import requests
import tqdm.notebook as tqdm
from IPython.display import HTML, display

## LLM Tokenization
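This notebook implements a small byte-level byte pair encoding (BPE) tokenizer. Background on the algorithm: https://en.wikipedia.org/wiki/Byte_pair_encoding#Modified_algorithm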

In [107]:
# Download the Tiny Shakespeare corpus used as training text for the tokenizer.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# A timeout keeps the cell from hanging forever on a stalled connection.
response = requests.get(url, timeout=30)
# Fail loudly on 4xx/5xx instead of silently training on an HTTP error page.
response.raise_for_status()
text = response.text

In [317]:
class BytePairEncodingTokenizer:
    """Byte-level byte pair encoding (BPE) tokenizer.

    Text is handled as its UTF-8 bytes. Every token stands for a tuple of one
    or more byte values. ``fit`` repeatedly merges the most frequent pair of
    adjacent tuples in the training text; the vocabulary is the set of tuples
    that survive in the final tokenization of the training text, plus one
    single-byte token for every byte value not already present. Because of
    that last step the final vocabulary can be larger than ``vocab_size``.
    """

    def __init__(self, vocab_size: int = 512):
        """Create an unfitted tokenizer.

        Args:
            vocab_size: number of distinct byte tuples at which ``fit`` stops
                merging (the single-byte fallback tokens are added on top).
        """
        self.vocab_size = vocab_size
        # token id -> tuple of UTF-8 byte values the token stands for
        self.tokens_to_byte_tuples: dict[int, tuple[int, ...]] = {}

    def _vocab_size_reached(self) -> bool:
        """Return True once the learned vocabulary has at least ``vocab_size`` entries."""
        # ">=" rather than "==": if the text already has more distinct bytes
        # than vocab_size, stop right away instead of merging the corpus down
        # into a handful of giant tuples.
        return len(self.tokens_to_byte_tuples) >= self.vocab_size

    @staticmethod
    def _find_most_common_byte_tuple_pair(
        byte_tuples: list[tuple[int, ...]],
    ) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Return the most frequent pair of adjacent byte tuples.

        Ties go to the pair that occurs first in ``byte_tuples``.

        Raises:
            IndexError: if ``byte_tuples`` has fewer than two elements, i.e.
                there is no adjacent pair at all.
        """
        pair_counts = Counter(zip(byte_tuples, byte_tuples[1:]))
        return pair_counts.most_common(1)[0][0]

    @staticmethod
    def _merge_most_common_byte_tuple_pair(
        byte_tuples: list[tuple[int, ...]],
        most_common_byte_tuple_pair: tuple[tuple[int, ...], tuple[int, ...]],
    ) -> list[tuple[int, ...]]:
        """Return a copy of ``byte_tuples`` with each occurrence of the pair merged.

        The scan runs left to right and occurrences do not overlap: after a
        merge the scan resumes behind the merged pair.
        """
        merged = []
        count = len(byte_tuples)
        i = 0
        while i < count:
            # A pair needs a right neighbour, so the last element is never merged here.
            has_next = i + 1 < count
            if has_next and (byte_tuples[i], byte_tuples[i + 1]) == most_common_byte_tuple_pair:
                merged.append(byte_tuples[i] + byte_tuples[i + 1])
                i += 2
            else:
                merged.append(byte_tuples[i])
                i += 1
        return merged

    def fill_missing_bytes(self) -> None:
        """Add a single-byte token for every byte value 0-255 not yet in the vocabulary.

        This allows encoding characters that weren't in the training data.
        New tokens get consecutive ids after the current largest id (starting
        at 0 when the vocabulary is empty).
        """
        known_byte_tuples = set(self.tokens_to_byte_tuples.values())
        next_token = max(self.tokens_to_byte_tuples, default=-1) + 1
        for byte_ in range(256):
            if (byte_,) not in known_byte_tuples:
                self.tokens_to_byte_tuples[next_token] = (byte_,)
                next_token += 1

    def fit(self, text: str) -> None:
        """Learn the vocabulary from ``text``, replacing any previous vocabulary.

        Merging stops when the number of distinct byte tuples reaches
        ``vocab_size`` or when the text has collapsed to fewer than two tuples
        (empty or very short text), whichever happens first.
        """
        # each token consists of a tuple of bytes, starting with 1 tuple = 1 byte
        byte_tuples = [(byte_,) for byte_ in text.encode("utf8")]
        self.tokens_to_byte_tuples = dict(enumerate(set(byte_tuples)))
        # With fewer than two tuples there is no pair left to merge.
        while not self._vocab_size_reached() and len(byte_tuples) > 1:
            most_common_byte_tuple_pair = self._find_most_common_byte_tuple_pair(byte_tuples)
            byte_tuples = self._merge_most_common_byte_tuple_pair(byte_tuples, most_common_byte_tuple_pair)
            self.tokens_to_byte_tuples = dict(enumerate(set(byte_tuples)))

        self.fill_missing_bytes()

    def get_sorted_tokens_to_byte_tuples(self) -> list[tuple[int, tuple[int, ...]]]:
        """Return ``(token, byte_tuple)`` pairs ordered from longest to shortest byte tuple."""
        return (sorted(self.tokens_to_byte_tuples.items(), key=lambda i: len(i[1])))[::-1]

    def encode(self, text: str) -> list[int]:
        """Encode ``text`` into token ids by greedy longest match on its UTF-8 bytes.

        Raises:
            ValueError: if no token matches at some position, e.g. when the
                tokenizer has not been fitted yet.
        """
        bytes_ = text.encode("utf8")
        # Sort once, not once per byte position.
        candidates = self.get_sorted_tokens_to_byte_tuples()
        tokens = []
        i = 0
        while i < len(bytes_):
            for token, byte_tuple in candidates:
                tuple_len = len(byte_tuple)
                if byte_tuple == tuple(bytes_[i:i + tuple_len]):
                    tokens.append(token)
                    i += tuple_len
                    break
            else:
                # Without this the loop would spin forever on the same position.
                raise ValueError(
                    f"no token matches byte {bytes_[i]} at position {i}; has the tokenizer been fitted?"
                )
        return tokens

    def decode(self, tokens: list[int], errors: str = "strict") -> str:
        """Decode token ids back into a string.

        Args:
            tokens: token ids as produced by ``encode``.
            errors: error handler passed to ``bytes.decode``. The default
                ``"strict"`` raises ``UnicodeDecodeError`` when the bytes are
                not valid UTF-8, which happens when ``tokens`` covers only
                part of a multi-byte character; ``"replace"`` substitutes
                U+FFFD instead.

        Raises:
            KeyError: if a token id is not in the vocabulary.
        """
        bytes_ = [byte_ for token in tokens for byte_ in self.tokens_to_byte_tuples[token]]
        return bytes(bytes_).decode("utf8", errors=errors)

In [330]:
def test(seed: int = 0):
    """Check that encode followed by decode reproduces random segments of the training text.

    Args:
        seed: seed for the segment sampler, so a failing run can be reproduced.
    """
    # Local, seeded generator: reproducible and independent of the global random state.
    rng = random.Random(seed)
    bpe = BytePairEncodingTokenizer(vocab_size=256)
    small_text = text[:10000]
    bpe.fit(small_text)

    segment_size = 512
    # Clamp at 0 so a text shorter than one segment is tested as a whole instead of raising.
    last_start = max(0, len(small_text) - segment_size)
    for _ in tqdm.tqdm(range(200)):
        start = rng.randint(0, last_start)
        text_to_encode = small_text[start:start + segment_size]
        assert bpe.decode(bpe.encode(text_to_encode)) == text_to_encode, (
            f"round trip failed for the segment starting at character {start}"
        )

test()

  0%|          | 0/200 [00:00<?, ?it/s]

In [323]:
def visualize(text: str, tokenizer: BytePairEncodingTokenizer) -> HTML:
    """Render ``text`` as HTML with each token on its own background colour.

    The text is encoded with ``tokenizer``; every token becomes one ``<span>``
    and the colours cycle through a fixed palette so neighbouring tokens can
    be told apart.

    Args:
        text: the text to tokenize and display.
        tokenizer: a fitted tokenizer.

    Returns:
        An ``IPython.display.HTML`` object; make it the last expression of a
        cell (or pass it to ``display``) to render it.
    """
    colors = [
        "#c9ff82",
        "#82e4ff",
        "#e7c8fa",
        "#fae3c8",
        "#f2c8fa",
        "#b3fff0",
    ]

    def get_color(i: int) -> str:
        return f"background-color:{colors[i % len(colors)]}"

    def get_token_text(token: int) -> str:
        # One token may hold only part of a multi-byte UTF-8 character, so a
        # strict decode of a single token can raise; show U+FFFD for such bytes.
        token_bytes = bytes(tokenizer.tokens_to_byte_tuples[token])
        # Escape so that "<", ">" and "&" in the text are shown, not parsed as markup.
        return html.escape(token_bytes.decode("utf8", errors="replace"))

    tokens = tokenizer.encode(text)
    html_elements = [
        f"""<span style="{get_color(i)}">{get_token_text(token)}</span>"""
        for i, token in enumerate(tokens)
    ]

    return HTML("".join(html_elements))

In [324]:
# Fit a demo tokenizer on the first 10,000 characters of the corpus.
small_text = text[:10000]
bpe = BytePairEncodingTokenizer(vocab_size=256)
bpe.fit(small_text)

In [325]:
# Token ids the fitted tokenizer assigns to a short English sentence.
bpe.encode("Before we proceed any further")

[204, 105, 41, 145, 89, 201, 30, 114, 189, 106, 15, 10, 143, 230]

In [326]:
# Round trip: decoding the encoded sentence should give back the original string.
bpe.decode(bpe.encode("Before we proceed any further"))

'Before we proceed any further'

In [327]:
# Round trip on non-ASCII input (accented letters and an emoji).
bpe.decode(bpe.encode("čšáuaweada😊"))

'čšáuaweada😊'

In [328]:
# Colour-code the tokenization of characters 1000-1999 of the training text.
visualize(small_text[1000:2000], bpe)

In [329]:
# Colour-code the tokenization of the first 1000 characters of the training text.
visualize(small_text[:1000], bpe)