In [2]:
# https://www.reedbeta.com/blog/programmers-intro-to-unicode/
text = "Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception."

text.encode("utf-8")[:10]

b'\xef\xbc\xb5\xef\xbd\x8e\xef\xbd\x89\xef'

In [3]:
tokens = list(map(int, text.encode("utf-8")))

print("---")
print(text)
print(f"length: {len(text)}")
print("---")
print(tokens)
print(f"length: {len(tokens)}")

---
Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception.
length: 533
---
[239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 240, 159, 133, 164, 240, 159, 133, 157, 240, 159, 133, 152, 240, 159, 133, 146, 240, 159, 133, 158, 240, 159, 133, 147, 240, 159, 133, 148, 226, 128, 189, 32, 240, 159, 135, 186, 226, 128, 140, 240, 159, 135, 179, 226, 128, 140, 240, 159, 135, 174, 226, 128, 140, 240, 159, 135, 168, 226, 128, 140, 240, 159, 135, 180, 226, 128, 140

Should we identify byte pairs? Or grapheme cluster pairs?

BPE (byte pair encoding) implies byte-level tokenization, so that is what we will implement here.

However, grapheme cluster pair encoding (GCPE) might be more meaningful since it operates at the "character" level, i.e., the "perceptual atom" of text.
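To see why the choice of unit matters, here is a small stdlib-only illustration. The example strings are our own and do not come from the text above. One perceived character can span several code points and even more UTF-8 bytes. Python's `str` counts code points, not grapheme clusters. Real cluster segmentation needs something beyond the standard library, such as the third-party `regex` module's `\X`, which is not used here.

```python
# Each string below is one perceived character (one grapheme cluster),
# yet it is made of one or more code points and several UTF-8 bytes.
examples = {
    "e + combining acute": "e\u0301",   # renders as é
    "US flag": "\U0001F1FA\U0001F1F8",  # two regional indicator symbols
    "fullwidth U": "\uFF35",            # a single code point, 3 bytes
}

for name, s in examples.items():
    n_bytes = len(s.encode("utf-8"))
    print(f"{name}: {len(s)} code point(s), {n_bytes} byte(s)")
```

A byte-level BPE starts from the bytes, a code-point tokenizer from the middle column, and a GCPE would start from one unit per row.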

It may turn out that, as BPE iteratively "mints" new tokens, those tokens come to align with grapheme clusters.

Either way, Unicode normalization is important to ensure that the binary representation of text is in a consistent form before tokenization. This is beneficial for BPE, but significantly more important for grapheme cluster tokenization.
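Here is a minimal illustration using only `unicodedata`. The example word is our own choice. A precomposed "é" (U+00E9) and "e" followed by U+0301 COMBINING ACUTE ACCENT are canonically equivalent, but their UTF-8 bytes differ. Without normalization, a tokenizer would treat the two spellings as unrelated inputs. After NFD, which is the form `utf8_encode` uses below, their bytes coincide.

```python
import unicodedata

precomposed = "caf\u00e9"   # "café" with a single code point for é
decomposed = "cafe\u0301"   # "café" as e + combining acute accent

# Different byte sequences, hence different BPE inputs, for the "same" text.
print(precomposed.encode("utf-8"))  # b'caf\xc3\xa9'
print(decomposed.encode("utf-8"))   # b'cafe\xcc\x81'

# After canonical decomposition (NFD), both spellings have identical bytes.
nfd_a = unicodedata.normalize("NFD", precomposed)
nfd_b = unicodedata.normalize("NFD", decomposed)
print(nfd_a.encode("utf-8") == nfd_b.encode("utf-8"))  # True
```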

In [4]:
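from typing import Collection, Sequence, Tuple
from collections import Counter

def most_common_pair(tokens: Sequence[int]) -> Tuple[int, int]:
    # slicing requires a Sequence (e.g. a list), not just a Collection;
    # on ties, Counter.most_common returns the pair encountered first
    if len(tokens) < 2:
        raise ValueError("need at least two tokens to form a pair")
    pairs = zip(tokens[:-1], tokens[1:])
    counter = Counter(pairs)
    return counter.most_common(1)[0][0]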

assert most_common_pair([2, 3, 2, 3, 4]) == (2, 3)
assert most_common_pair([1, 2, 1, 2, 3, 4, 3]) == (1, 2)

In [5]:
mcp = most_common_pair(tokens)
(chr(mcp[0]), chr(mcp[1]))

('e', ' ')

In [6]:
from typing import List, Sequence, Tuple

def replace_pair(tokens: Sequence[int], pair: Tuple[int, int], replacement: int) -> List[int]:
    # scan left to right, replacing non-overlapping occurrences of `pair`
    result = []
    i = 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            result.append(replacement)
            i += 2  # skip both members of the matched pair
        else:
            result.append(tokens[i])
            i += 1
    return result

assert replace_pair([2, 3, 2, 3, 4], (2, 3), 5) == [5, 5, 4]
assert replace_pair([1, 2, 1, 1, 3, 1, 2], (1, 2), 4) == [4, 1, 1, 3, 4]

Though we are replacing instances of tokens with "minted" tokens, we must retain the original set of base tokens. We are replacing instances, not removing entries from the vocabulary.

Replacing a pair should not prevent its members from being used later. For example, suppose BPE merges away every occurrence of a specific character, e.g., "t", in the training text. We still need the vocab entry for "t". Otherwise, we could not encode "t" when it appears in a new context.
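The following toy byte-level sketch shows why the base entry must survive. The strings and the merged id 256 are our own illustrative choices. `replace_pair` is repeated so the block runs on its own. Merging ("t", "h") in "the the" removes every standalone "t" from the training tokens. Encoding a new word like "tee" still requires "t" as a base token.

```python
from typing import List, Sequence, Tuple


def replace_pair(tokens: Sequence[int], pair: Tuple[int, int], replacement: int) -> List[int]:
    # same left-to-right, non-overlapping replacement as in the notebook
    result = []
    i = 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            result.append(replacement)
            i += 2
        else:
            result.append(tokens[i])
            i += 1
    return result


training = list(b"the the")
vocab = set(training)                     # base byte values seen in training
th = (ord("t"), ord("h"))
merged = replace_pair(training, th, 256)  # 256: hypothetical id for "th"

print(ord("t") in merged)  # False: no standalone "t" left in the training stream
print(ord("t") in vocab)   # True: the base entry is kept

# New text with no "th" pair still needs the base "t" token.
new_tokens = replace_pair(list(b"tee"), th, 256)
print(new_tokens)  # [116, 101, 101]
```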

The process of building a "forest" of tokens and merging towards larger trees is reminiscent of Huffman coding. Huffman coding aims to turn a "forest" into a single tree. BPE aims to produce a forest with a target number of nodes across all trees.

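Huffman coding also operates on individual elements. It repeatedly joins the two least frequent trees, regardless of where their symbols sit in the text. BPE operates on sequences. It joins the most frequent adjacent pair, so the order of tokens matters.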
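For contrast, here is a textbook Huffman construction sketched with `heapq`. This is the standard algorithm, not code from this notebook, and the input string is made up. The loop pops the two least frequent trees and joins them until a single tree remains. Each symbol ends up with a code length equal to its depth in that tree.

```python
import heapq
from collections import Counter
from itertools import count
from typing import Dict


def huffman_code_lengths(text: str) -> Dict[str, int]:
    """Huffman code length (depth in the final tree) of each symbol in `text`."""
    freqs = Counter(text)
    if len(freqs) <= 1:
        # degenerate cases: no symbols, or a lone symbol that still needs 1 bit
        return {sym: 1 for sym in freqs}
    tiebreak = count()  # unique counter so the heap never compares dicts
    # heap entries: (total weight, tiebreak, {symbol: depth within this tree})
    heap = [(w, next(tiebreak), {sym: 0}) for sym, w in freqs.items()]
    heapq.heapify(heap)
    while len(heap) > 1:
        # join the two LEAST frequent trees; their symbols sink one level
        w1, _, left = heapq.heappop(heap)
        w2, _, right = heapq.heappop(heap)
        joined = {sym: depth + 1 for sym, depth in {**left, **right}.items()}
        heapq.heappush(heap, (w1 + w2, next(tiebreak), joined))
    return heap[0][2]


print(huffman_code_lengths("aaaabbc"))  # a: 1 bit, b and c: 2 bits each
```

BPE's loop has the same "pick a pair, merge, repeat" shape. However, it picks the most frequent adjacent pair and stops at a target vocabulary size rather than at one tree.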

In [7]:
import unicodedata
from typing import Collection

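def utf8_encode(text: str) -> Collection[int]:
    # canonical decomposition (NFD) so that equivalent spellings, e.g. a
    # precomposed "é" vs "e" + combining acute accent, yield the same bytes
    normalized = unicodedata.normalize("NFD", text)
    return list(map(int, normalized.encode("utf-8")))

def utf8_decode(tokens: Collection[int]) -> str:
    encoded = bytes(tokens)
    return encoded.decode("utf-8")

assert utf8_decode(utf8_encode("hello")) == "hello"
assert utf8_decode(utf8_encode(text)) == text  # holds because NFD leaves `text` unchanged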
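Normalizing inside `utf8_encode` has one consequence worth checking. Decoding returns the NFD form, so the round trip is exact only for input that NFD leaves unchanged. The quick check below uses an example word of our own. A precomposed "ï" comes back decomposed, which is canonically equivalent but not `==` to the original. The two helpers mirror the notebook's so the block runs on its own.

```python
import unicodedata
from typing import Iterable, List


def utf8_encode(text: str) -> List[int]:
    # mirrors the notebook helper: NFD, then UTF-8 bytes as ints
    return list(unicodedata.normalize("NFD", text).encode("utf-8"))


def utf8_decode(tokens: Iterable[int]) -> str:
    return bytes(tokens).decode("utf-8")


original = "na\u00efve"  # "naïve" with a precomposed ï
roundtrip = utf8_decode(utf8_encode(original))

print(roundtrip == original)                                # False
print(roundtrip == unicodedata.normalize("NFD", original))  # True
print(unicodedata.normalize("NFC", roundtrip) == original)  # True
```

To return text in composed form, one option is to apply NFC after decoding.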

In [8]:
tokens = utf8_encode(text)
print(f"length: {len(tokens)}")
print(f"min: {min(tokens)}")
print(f"max: {max(tokens)}")
print(f"distinct: {len(set(tokens))}")

length: 616
min: 32
max: 240
distinct: 73


In [14]:
from typing import Tuple, Collection

Vocab = Collection[int]
Merges = Collection[Tuple[int, int]]

def bpe_train(text: str, vocab_size: int) -> Tuple[Vocab, Merges]:
    """
    output:
        - vocab: array of base tokens, where index is token id
        - merges: array of merged token pairs, where len(vocab) + index is token id
    """
    tokens = utf8_encode(text)

    # base vocab: only the distinct byte values that occur in the training text
    vocab = sorted(set(tokens))

    # re-index bytes as dense ids 0..len(vocab)-1
    vocab_index = {token: i for i, token in enumerate(vocab)}
    tokens = [vocab_index[token] for token in tokens]

    # each merge mints one new id, starting right after the base vocab
    num_merges = vocab_size - len(vocab)
    merges = []
    new_id = len(vocab)
    while len(merges) < num_merges:
        pair = most_common_pair(tokens)
        merges.append(pair)
        tokens = replace_pair(tokens, pair, new_id)
        new_id += 1

    return vocab, merges

def bpe_encode(text: str, vocab: Vocab, merges: Merges) -> Collection[int]:
    tokens = utf8_encode(text)

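    # map bytes to base ids; a byte absent from the training text has no
    # entry, so encoding such text raises KeyError
    vocab_index = {token: i for i, token in enumerate(vocab)}
    tokens = [vocab_index[token] for token in tokens]

    # replay merges in the order they were learned; skip ones that cannot apply
    for i, merge in enumerate(merges):
        pairs = set(zip(tokens[:-1], tokens[1:]))
        if merge in pairs:
            tokens = replace_pair(tokens, merge, len(vocab) + i)
    return tokens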

def bpe_decode(tokens: Collection[int], vocab: Vocab, merges: Merges) -> str:
    # id -> list of raw bytes; each merge can only reference base ids or
    # earlier merges, so a single forward pass fills the table
    vm_index = [[t] for t in vocab]
    for left, right in merges:
        vm_index.append(vm_index[left] + vm_index[right])

    result = []
    for token in tokens:
        result.extend(vm_index[token])
    return utf8_decode(result)

vocab, merges = bpe_train(text, 256)
encoded = bpe_encode(text, vocab, merges)
assert bpe_decode(encoded, vocab, merges) == text
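Each token id can be expanded back into its bytes to inspect what BPE actually minted. This could help test the earlier hunch that merged tokens drift toward grapheme clusters. The sketch assumes the same vocab/merges layout as above, where base ids index `vocab` and merge `k` gets id `len(vocab) + k`. The helper name `token_bytes` and the toy vocab are hypothetical.

```python
from typing import List, Sequence, Tuple


def token_bytes(vocab: Sequence[int], merges: Sequence[Tuple[int, int]]) -> List[bytes]:
    """Hypothetical helper: the byte string spelled by every token id."""
    table = [bytes([b]) for b in vocab]  # base tokens: one byte each
    for left, right in merges:           # merge k becomes id len(vocab) + k
        table.append(table[left] + table[right])
    return table


# toy example: base vocab "a" (id 0) and "b" (id 1), then two merges
toy_vocab = [ord("a"), ord("b")]
toy_merges = [(0, 1), (2, 0)]  # id 2 = "ab", id 3 = "ab" + "a"
table = token_bytes(toy_vocab, toy_merges)
print(table)  # [b'a', b'b', b'ab', b'aba']
```

With the trained `vocab` and `merges`, `token_bytes(vocab, merges)[len(vocab):]` would list the minted tokens in learning order. Some may be partial UTF-8 sequences, so `.decode("utf-8", errors="replace")` is a reasonable way to display them.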

In [19]:
print(f"vocab size: {len(vocab)}")
print(f"merges size: {len(merges)}")
print(f"raw length: {len(utf8_encode(text))}")
print(f"encoded length: {len(encoded)}")
print(f"compression: {len(utf8_encode(text)) / len(encoded)}")

vocab size: 73
merges size: 183
raw length: 616
encoded length: 197
compression: 3.1269035532994924


TODO:
1. implement BPE in Neuralink compression challenge notebook
2. implement optional rules parameter for BPE training (review GPT-3/4 rules)
3. implement minbpe