following along:
* https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing
* https://www.youtube.com/watch?v=zduSFxRajkE

Tom Lehrer's songs: https://tomlehrersongs.com/

"The Elements" song: https://tomlehrersongs.com/wp-content/uploads/2018/12/the-elements.pdf

In [None]:
lyrics = """THE ELEMENTS

There's antimony, arsenic, aluminum, selenium,
And hydrogen and oxygen and nitrogen and rhenium,
And nickel, neodymium, neptunium, germanium,
And iron, americium, ruthenium, uranium,

Europium, zirconium, lutetium, vanadium,
And lanthanum and osmium and astatine and radium,
And gold and protactinium and indium and gallium,
And iodine and thorium and thulium and thallium.

There's yttrium, ytterbium, actinium, rubidium,
And boron, gadolinium, niobium, iridium,
And strontium and silicon and silver and samarium,
And bismuth, bromine, lithium, beryllium, and barium.

There's holmium and helium and hafnium and erbium,
And phosphorus and francium and fluorine and terbium,
And manganese and mercury, molybdenum, magnesium,
Dysprosium and scandium and cerium and cesium.

And lead, praseodymium and platinum, plutonium,
Palladium, promethium, potassium, polonium,
And tantalum, technetium, titanium, tellurium,
And cadmium and calcium and chromium and curium.

There's sulfur, californium and fermium, berkelium,
And also mendelevium, einsteinium, nobelium,
And argon, krypton, neon, radon, xenon, zinc and rhodium,
And chlorine, carbon, cobalt, copper, tungsten, tin and sodium.

These are the only ones o_f which the news has come to Ha'vard,
And there may be many others but they haven't been discavard."""

## basic byte pair encoding

In [None]:
import abc
import string
import typing as T
from collections import Counter

import regex
import tqdm

import random_neural_net_models.utils as utils

# The lyrics as a list of UTF-8 byte values; these bytes are the initial tokens.
tokens = [int(v) for v in lyrics.encode("utf-8")]
# Show the first and last tokens, the total count and the number of distinct values.
tokens[:5], tokens[-5:], len(tokens), len(set(tokens))

In [None]:
def get_stats(tokens: T.List[int]) -> Counter:
    """Count how often each pair of adjacent tokens occurs.

    Args:
        tokens: sequence of token ids.

    Returns:
        Counter keyed by ``(left, right)`` tuples. It is empty when
        ``tokens`` holds fewer than two elements.
    """
    # zip stops at the shorter input, so the last token is never used as a
    # left-hand element.
    adjacent_pairs = zip(tokens, tokens[1:])
    return Counter(adjacent_pairs)


stats = get_stats(tokens)
stats.most_common(5)

In [None]:
def merge(
    tokens: T.List[int],
    pair_to_replace: T.Tuple[int, int],
    replacement_token: int,
) -> T.List[int]:
    """Replace every occurrence of an adjacent token pair with a single token.

    The input is scanned left to right and matches do not overlap: once a
    pair has been replaced, scanning resumes after both of its elements.

    Args:
        tokens: sequence of token ids; it is not modified.
        pair_to_replace: the ``(left, right)`` pair to look for.
        replacement_token: id written in place of each matched pair.

    Returns:
        A new list with the replacements applied.
    """
    merged = []
    n_tokens = len(tokens)
    pos = 0
    while pos < n_tokens:
        has_successor = pos + 1 < n_tokens
        if has_successor and (tokens[pos], tokens[pos + 1]) == pair_to_replace:
            merged.append(replacement_token)
            pos += 2  # skip both members of the matched pair
            continue
        merged.append(tokens[pos])
        pos += 1
    return merged


merge([5, 6, 6, 7, 9, 1], (6, 7), 99)

In [None]:
base_symbols = string.ascii_letters + string.digits
base_symbols

In [None]:
base_tokens = [int(v) for v in base_symbols.encode("utf-8")]
base_tokens[:5], base_tokens[-5:]

In [None]:
replacement_token = max(tokens + base_tokens) + 1
replacement_token

In [None]:
pair_to_replace = stats.most_common()[0][0]
pair_to_replace

In [None]:
tokens2 = merge(tokens, pair_to_replace, replacement_token)

In [None]:
max(tokens2), max(tokens)

In [None]:
len(tokens), len(tokens2), len(set(tokens)), len(set(tokens2))

In [None]:
log = utils.logger


def repeated_merge(
    tokens: T.List[int],
    vocab_size: int,
    show_progress: bool,
    base_tokens: T.Optional[T.List[int]] = None,
) -> T.Tuple[T.List[int], T.Dict[T.Tuple[int, int], int]]:
    """Greedily merge the most frequent adjacent token pair, repeatedly.

    The number of merge rounds is ``vocab_size`` minus the number of
    distinct values in ``tokens``; ``base_tokens`` are not counted for that
    purpose. Merging stops early when no adjacent pair is left.

    Args:
        tokens: initial token ids (byte values of the training text).
        vocab_size: target used to derive the number of merge rounds.
        show_progress: whether to display a tqdm progress bar.
        base_tokens: extra ids that must not be reused as merge ids.

    Returns:
        The merged token list and a mapping ``(left, right) -> new id`` in
        the order the merges were created.
    """
    n0 = len(tokens)
    n_used_tokens = len(set(tokens))
    n_merges = vocab_size - n_used_tokens
    log.info(
        f"repeatedly merging tokens: {n_merges=} to achieve {vocab_size=} with {n_used_tokens=}"
    )

    # New ids must lie above every id already in use AND above every raw
    # byte value (0-255). Otherwise a merge id such as 123 would be
    # indistinguishable from the byte 123 ("{") when unseen text is encoded.
    known_ids = tokens + base_tokens if base_tokens else tokens
    replacement_token = max(max(known_ids, default=255), 255)

    pair_map: T.Dict[T.Tuple[int, int], int] = {}
    for _ in tqdm.tqdm(
        range(n_merges), total=n_merges, desc="merge", disable=not show_progress
    ):
        stats = get_stats(tokens)
        if not stats:
            # Fewer than two tokens left: nothing more can be merged.
            break
        # most_common(1) returns the same first-seen winner as a full sort
        # without ordering every pair.
        pair_to_replace = stats.most_common(1)[0][0]
        replacement_token += 1
        tokens = merge(tokens, pair_to_replace, replacement_token)
        pair_map[pair_to_replace] = replacement_token

    n1 = len(tokens)
    # Guard the ratio so that an empty input does not divide by zero.
    ratio = n1 / n0 if n0 else 1.0
    log.info(
        f"result: {n0:_d} -> {n1:_d} tokens = compression to {ratio:.2%} of tokens"
    )
    return tokens, pair_map


vocab_size = len(set(tokens)) + 20
tokens3, pair_map = repeated_merge(tokens, vocab_size, show_progress=True)

In [None]:
len(pair_map)

In [None]:
len(tokens), len(tokens3), len(set(tokens)), len(set(tokens3))

In [None]:
pair_map

In [None]:
# One single-byte entry per token id seen in the lyrics or in the base symbols.
vocab = {idx: bytes([idx]) for idx in set(tokens + base_tokens)}
# Each merged id expands to the concatenated bytes of the two ids it replaced.
# pair_map is iterated in insertion (merge) order, so both parts already have
# an entry by the time a merged id is added.
for (token0, token1), idx in pair_map.items():
    vocab[idx] = vocab[token0] + vocab[token1]
vocab

In [None]:
def decode(bpe_tokens: T.List[int], vocab: T.Dict[int, bytes]) -> str:
    """Turn a list of token ids back into text.

    Args:
        bpe_tokens: token ids to decode.
        vocab: mapping from token id to the bytes it stands for.

    Returns:
        The concatenated bytes decoded as UTF-8; invalid byte sequences are
        replaced rather than raising.

    Raises:
        KeyError: if a token id is missing from ``vocab``.
    """
    raw = b"".join(vocab[token_id] for token_id in bpe_tokens)
    return raw.decode("utf-8", errors="replace")


decode(tokens3, vocab)

In [None]:
def encode(
    text: str,
    pair_map: T.Dict[T.Tuple[int, int], int],
    show_progress: bool,
) -> T.List[int]:
    """Encode text into token ids by replaying learned merges.

    The text is converted to its UTF-8 byte values. Then, among the adjacent
    pairs present, the one with the lowest merge id is merged first, until
    no known pair remains.

    Args:
        text: the text to encode.
        pair_map: mapping ``(left, right) -> merged id``.
        show_progress: currently unused; kept so existing callers keep working.

    Returns:
        The list of token ids.
    """
    tokens = list(text.encode("utf-8"))
    log.info(f"{len(tokens)=:_d}")
    if len(tokens) < 2:
        # Zero or one byte: there is no pair to merge.
        return tokens

    # Each round applies one merge, so len(pair_map) rounds are an upper bound.
    for _ in range(len(pair_map)):
        candidates = [pair for pair in get_stats(tokens) if pair in pair_map]
        if not candidates:
            break

        # Lowest merge id first: later merges may be built from earlier ones.
        pair = min(candidates, key=pair_map.__getitem__)
        tokens = merge(tokens, pair, pair_map[pair])
        log.info(f"{len(tokens)=:_d}")

    return tokens


test_bpe_tokens = encode("bla bla and bla", pair_map, show_progress=True)
test_bpe_tokens

In [None]:
decode(test_bpe_tokens, vocab)

https://github.com/openai/gpt-2

https://github.com/openai/tiktoken

https://github.com/google/sentencepiece

## Tokenizer classes

In [None]:
# Default symbols whose UTF-8 byte values seed `base_tokens`.
BASE_SYMBOLS = string.ascii_letters + string.digits


class TokenizerBase(abc.ABC):
    """Abstract interface shared by the tokenizers.

    The constructor only stores the base symbols and their UTF-8 byte
    values; ``vocab`` and ``pair_map`` are declared here but not assigned.
    """

    base_symbols: str
    base_tokens: T.List[int]
    vocab: T.Dict[int, bytes]
    pair_map: T.Dict[T.Tuple[int, int], int]

    def __init__(self, base_symbols: str = None):
        # A falsy argument (None or an empty string) selects the default.
        self.base_symbols = base_symbols or BASE_SYMBOLS
        self.base_tokens = list(self.base_symbols.encode("utf-8"))

    @abc.abstractmethod
    def fit(self, text: str, vocab_size: int, verbose: int = False):
        """Learn the tokenizer state from ``text``."""

    @abc.abstractmethod
    def encode(self, text: str) -> T.List[int]:
        """Convert ``text`` into token ids."""

    @abc.abstractmethod
    def decode(self, tokens: T.List[int]) -> str:
        """Convert token ids back into text."""


class TokenizerSimple(TokenizerBase):
    """Byte-pair-encoding tokenizer that treats the text as one token stream."""

    def fit(self, text: str, vocab_size: int, verbose: int = False):
        """Learn the merges from ``text`` and build the decoding vocabulary.

        Args:
            text: training text.
            vocab_size: passed to ``repeated_merge`` to derive the number of
                merges.
            verbose: truthy to show the merge progress bar.
        """
        tokens = list(text.encode("utf-8", errors="replace"))
        _, self.pair_map = repeated_merge(
            tokens,
            vocab_size,
            show_progress=verbose,
            base_tokens=self.base_tokens,
        )
        # Seed the vocabulary with every possible byte value. `encode` can
        # emit any byte of the input text, so restricting the vocabulary to
        # bytes seen during training made `decode` raise KeyError on text
        # with unseen bytes (e.g. non-ASCII characters).
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        # pair_map is in merge order, so both parts exist before each use.
        for (token0, token1), idx in self.pair_map.items():
            self.vocab[idx] = self.vocab[token0] + self.vocab[token1]

    def encode(self, text: str) -> T.List[int]:
        """Encode ``text`` with the merges learned by ``fit``."""
        return encode(text, self.pair_map, show_progress=False)

    def decode(self, tokens: T.List[int]) -> str:
        """Decode token ids with the vocabulary built by ``fit``."""
        return decode(tokens, self.vocab)

In [None]:
tokenizer = TokenizerSimple()

In [None]:
vocab_size = 60
tokenizer.fit(lyrics, vocab_size, verbose=True)

In [None]:
phrase = "From adolescence to senility, bypassing maturity."
bpe_tokens = tokenizer.encode(phrase)
bpe_tokens[:3]

In [None]:
tokenizer.decode(bpe_tokens)

In [None]:
# Pattern used to cut text into chunks (via findall) before byte-pair merging.
# It uses \p{...} Unicode property classes, which the third-party `regex`
# module supports and the standard-library `re` module does not.
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

pattern = regex.compile(GPT4_SPLIT_PATTERN)

# Show how the example phrase is split into chunks.
pattern.findall(phrase)

In [None]:
c0 = Counter([1, 1, 1])
c1 = Counter([1, 1, 2])
c2 = Counter()
c2.update(c0)
c2.update(c1)
c0, c1, c2

https://www.lesswrong.com/posts/aPeJE8bSo6rAFoLqg/solidgoldmagikarp-plus-prompt-generation

In [None]:
class TokenizerRegex(TokenizerSimple):
    """Byte-pair-encoding tokenizer that never merges across regex chunks.

    The text is first split with a regex pattern. Pair statistics and merges
    are computed per chunk, so no learned token spans a chunk boundary.
    """

    def fit(
        self,
        text: str,
        vocab_size: int,
        pattern: regex.Pattern,
        verbose: int = False,
    ):
        """Learn the merges from ``text`` split by ``pattern``.

        Args:
            text: training text.
            vocab_size: the number of merges is ``vocab_size`` minus the
                number of distinct byte values found in ``text``.
            pattern: compiled pattern whose ``findall`` yields the chunks.
            verbose: truthy to print the set of initial token ids.

        Raises:
            ValueError: if ``vocab_size`` leaves no room for any merge.
        """
        self.pattern = pattern
        chunks = [
            list(chunk.encode("utf-8", errors="replace"))
            for chunk in self.pattern.findall(text)
        ]
        unique_tokens = {tok for chunk in chunks for tok in chunk}

        n_merges = vocab_size - len(unique_tokens)

        if n_merges <= 0:
            raise ValueError(f"{n_merges=} needs to be > 0")

        unique_tokens.update(self.base_tokens)
        if verbose:
            print(unique_tokens)

        self.pair_map = {}
        # Merge ids start above every raw byte value (0-255) so they can
        # never be confused with a byte produced when encoding unseen text.
        replacement_token = max(max(unique_tokens, default=255), 255)

        for _ in range(n_merges):
            stats = Counter()
            for chunk in chunks:
                stats.update(get_stats(chunk))

            if not stats:
                # Every chunk is down to a single token: nothing left to merge.
                break

            pair_to_replace = stats.most_common(1)[0][0]
            replacement_token += 1

            chunks = [
                merge(chunk, pair_to_replace, replacement_token)
                for chunk in chunks
            ]

            self.pair_map[pair_to_replace] = replacement_token

        # Cover every byte value so decoding unseen bytes cannot raise KeyError.
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        # pair_map is in merge order, so both parts exist before each use.
        for (token0, token1), idx in self.pair_map.items():
            self.vocab[idx] = self.vocab[token0] + self.vocab[token1]

    def encode(self, text: str) -> T.List[T.List[int]]:
        """Encode ``text`` chunk by chunk; returns one id list per chunk."""
        return [
            encode(chunk, self.pair_map, show_progress=False)
            for chunk in self.pattern.findall(text)
        ]

    def decode(self, tokens: T.List[T.List[int]]) -> str:
        """Decode per-chunk id lists and join the pieces into one string."""
        return "".join([decode(chunk, self.vocab) for chunk in tokens])


tokenizer = TokenizerRegex()
tokenizer.fit(lyrics, vocab_size, pattern)

In [None]:
pattern.findall(phrase)

In [None]:
bpe_tokens = tokenizer.encode(phrase)
bpe_tokens[:5]

In [None]:
# tokenizer.vocab
# tokenizer.pair_map

In [None]:
tokenizer.decode(bpe_tokens)