# Naive BPE implementation

In [1]:
import regex as re

# A humble corpus
texts = [
    "low low low low low",
    "lower lower widest widest widest",
    "newest newest newest newest newest newest"
]


In [2]:
vocabulary = ["<EOS>"] + [chr(k) for k in range(ord('a'), ord('z')+1)]

print(" - ".join(vocabulary))

<EOS> - a - b - c - d - e - f - g - h - i - j - k - l - m - n - o - p - q - r - s - t - u - v - w - x - y - z


## Pretokenize texts


In [3]:
# Gpt-2 pre-tokenization regex
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

def pre_tokenize(text: str) -> list[str]:
    # A more complex pre-tokenizer
    return re.findall(PAT, text)

def whitespace_pretokenize(text: str) -> list[str]:
    # Simple whitespace tokenizer
    return text.split()

In [4]:
pre_tokenize("hola mundo...")

['hola', ' mundo', '...']

In [5]:
pre_tokenize("hello! こんにちは!")

['hello', '!', ' こんにちは', '!']

In [6]:
from collections import defaultdict
frequency_table: defaultdict[tuple[bytes], int] = defaultdict(int)

for text in texts:
    words = whitespace_pretokenize(text)

    for word in words:
        bytes_tuple = tuple(word.encode("utf-8"))
        frequency_table[bytes_tuple] += 1

frequency_table


defaultdict(int,
            {(108, 111, 119): 5,
             (108, 111, 119, 101, 114): 2,
             (119, 105, 100, 101, 115, 116): 3,
             (110, 101, 119, 101, 115, 116): 6})

In [7]:
b'low'[0:2].decode("utf-16")

'潬'

In [92]:
from dataclasses import dataclass, Field


# Frozen makes them hashable
@dataclass(frozen=True)
class Token:

    def __post_init__(self):
        # wtf is this shit.
        # using to cast this
        object.__setattr__(self, "byte_list", tuple(self.byte_list))

    byte_list: tuple[bytes]

ao_token = Token(chr(259).encode() + b"o") # strange way to write this

ao_token

Token(byte_list=(196, 131, 111))

In [80]:
ao_token == Token(ao_token.byte_list)

True

In [81]:
bytes([255]), bytes(b'\xff')

(b'\xff', b'\xff')

In [82]:
tokens = [Token((b,)) for b in range(256)]

tokens

[Token(byte_list=(0,)),
 Token(byte_list=(1,)),
 Token(byte_list=(2,)),
 Token(byte_list=(3,)),
 Token(byte_list=(4,)),
 Token(byte_list=(5,)),
 Token(byte_list=(6,)),
 Token(byte_list=(7,)),
 Token(byte_list=(8,)),
 Token(byte_list=(9,)),
 Token(byte_list=(10,)),
 Token(byte_list=(11,)),
 Token(byte_list=(12,)),
 Token(byte_list=(13,)),
 Token(byte_list=(14,)),
 Token(byte_list=(15,)),
 Token(byte_list=(16,)),
 Token(byte_list=(17,)),
 Token(byte_list=(18,)),
 Token(byte_list=(19,)),
 Token(byte_list=(20,)),
 Token(byte_list=(21,)),
 Token(byte_list=(22,)),
 Token(byte_list=(23,)),
 Token(byte_list=(24,)),
 Token(byte_list=(25,)),
 Token(byte_list=(26,)),
 Token(byte_list=(27,)),
 Token(byte_list=(28,)),
 Token(byte_list=(29,)),
 Token(byte_list=(30,)),
 Token(byte_list=(31,)),
 Token(byte_list=(32,)),
 Token(byte_list=(33,)),
 Token(byte_list=(34,)),
 Token(byte_list=(35,)),
 Token(byte_list=(36,)),
 Token(byte_list=(37,)),
 Token(byte_list=(38,)),
 Token(byte_list=(39,)),
 Token(byt

In [83]:
# Create a Token with multiple stuff

Token(b'\xff\x01\x03')

Token(byte_list=b'\xff\x01\x03')

In [None]:
# Get all merges

def find_next_merge(frequency_table: dict[tuple[bytes], int]):
    pair_count = d

    current_max_value = None
    current_max_count = None
    for word in frequency_table:
        for i in range(len(word)-1):
            pair = word[i:i+2]
            count = pair_count[pair] + 1
            pair_count[pair] = count

            if current_max_count is None:
                current_max_value = pair
                current_max_count = pair_count[pair]
            elif count > current_max_count or (current_max_count == count and pair < current_max_value):
                current_max_value = pair
                current_max_count = count

current_max_value, current_max_count

(b'es', 2)

In [35]:
b'\x1A\x30'[1]

48