In [None]:
#| default_exp tokenizers

%load_ext autoreload
%autoreload 2


# Tokenizers

## BPE

In [None]:
# | export
from collections import defaultdict
import re


In [None]:
# Load the Hugging Face GPT-2 tokenizer as a reference to validate this notebook's code against.
# NOTE(review): `tokenizer` is rebound to a `tokenizers.Tokenizer` in the next cell before
# this GPT-2 instance is used anywhere visible in this notebook — confirm it is still needed.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")


In [None]:
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders
# Reference byte-level pre-tokenizer / decoder pair from the Hugging Face `tokenizers` library.
pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True, use_regex=True)
decoder = decoders.ByteLevel()

# An untrained BPE model wired with the reference pre-tokenizer and decoder.
tokenizer = Tokenizer(model=models.BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = pre_tokenizer
tokenizer.decoder = decoder

# `pre_tokenize_str` yields pairs; only the first element (the token string) is kept.
toks = [
    token
    for token, _ in pre_tokenizer.pre_tokenize_str("Hello, world! How  are you?good .")
]
print(toks)
decoder.decode(toks)

# Outputs observed for the sentence above with different ByteLevel settings:
# no prefix space  ['Hello', ',', 'Ġworld', '!', 'ĠHow', 'Ġ', 'Ġare', 'Ġyou', '?', 'good', 'Ġ.']
# prefix space     ['ĠHello', ',', 'Ġworld', '!', 'ĠHow', 'Ġ', 'Ġare', 'Ġyou', '?', 'good', 'Ġ.']
# prefix + regex   ['ĠHello', ',', 'Ġworld', '!', 'ĠHow', 'Ġ', 'Ġare', 'Ġyou', '?', 'good', 'Ġ.']
# prefix, no regex ['ĠHello,Ġworld!ĠHowĠĠareĠyou?goodĠ.']

In [None]:
# | export


class ByteLevelPreTokenizer:
    """Split text into word and punctuation pieces, marking leading spaces with "Ġ".

    The text is split on whitespace and on the punctuation characters
    ``, . ; : ! ?``. Each plain space is rewritten to the marker "Ġ" and glued
    to the piece that follows it, so ``"Hello world"`` becomes
    ``["Hello", "Ġworld"]``. Only the plain space character is rewritten; other
    whitespace (tab, newline) is emitted as its own unchanged piece.
    """

    def __init__(self, add_prefix_space:bool=False):
        """
        Arguments:
            add_prefix_space -- when True, a space is prepended to the input
                unless it already starts with one, so the first word is also
                marked with "Ġ".
        """
        self.add_prefix_space = add_prefix_space

    def pre_tokenize_str(self, s):
        """Return the list of pre-tokens for the string ``s``.

        Arguments:
            s -- the text to split.

        Returns:
            A list of strings; empty for an empty input without prefix space.
        """
        # Only add the prefix space when there isn't one already; otherwise an
        # input such as " Hello" would gain a spurious standalone "Ġ" token.
        if self.add_prefix_space and not s.startswith(" "):
            s = " " + s
        # The capturing group keeps the separators; empty pieces are discarded.
        pieces = [p for p in re.split(r"(\s|,|\.|;|:|!|\?|\t)", s) if p]
        tokens = []
        for piece in pieces:
            if piece == " ":
                piece = "Ġ"
            if tokens and tokens[-1] == "Ġ" and piece != "Ġ":
                # A pending space marker is glued to the piece that follows it.
                tokens[-1] = "Ġ" + piece
            else:
                # Consecutive spaces stay as separate "Ġ" tokens.
                tokens.append(piece)
        return tokens
    
class ByteLevelDecoder:
    """Turn a sequence of pre-tokens back into plain text."""

    def decode(self, tokens):
        """Concatenate ``tokens`` and map every "Ġ" marker back to a space."""
        text = "".join(tokens)
        return text.replace("Ġ", " ")
    
    
# NOTE(review): these checks live in a cell tagged `# | export` but rely on
# `pre_tokenizers` / `decoders`, which are imported in a non-exported cell —
# confirm whether they should move to their own (non-exported) cell.
ByteLevelPreTokenizer(add_prefix_space=False).pre_tokenize_str("Hello, world! How  are you?good .")
sents = [
    "Hello, world! How  are you?good .",
    "Hello world,  How are you ? Go !",
]
trf_pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
my_pre_tokenizer = ByteLevelPreTokenizer(add_prefix_space=True)
trf_decoder = decoders.ByteLevel()
my_decoder = ByteLevelDecoder()

# The home-made pre-tokenizer must agree with the reference one, and decoding
# its output must give back the sentence with the added prefix space.
for s in sents:
    mine = my_pre_tokenizer.pre_tokenize_str(s)
    reference = [t for t, p in trf_pre_tokenizer.pre_tokenize_str(s)]
    assert mine == reference, f"{mine} != {reference}"

    assert my_decoder.decode(mine) == " " + s



In [None]:
# | export


class BPETokenizer:
    """A small byte-pair-encoding (BPE) tokenizer.

    Training pre-tokenizes the text, starts from a character-level vocabulary
    (plus the special tokens "<|OOV|>" and "<|EOT|>") and repeatedly merges the
    most frequent adjacent pair of tokens until the requested vocabulary size
    is reached or no pair is left.
    """

    def __init__(self) -> None:
        # Single characters seen in the training text, sorted (filled by `train`).
        self.alphabet: list[str] = []
        # Special tokens, then the alphabet, then the merged tokens in merge order.
        self.vocab: list[str] = []
        # Learned merges, in the order they were learned: (left, right) -> merged token.
        self.merges: dict[tuple[str, str], str] = {}
        # Index <-> token lookup tables built from `vocab` at the end of `train`.
        self.itov: dict[int, str] = {}
        self.vtoi: dict[str, int] = {}
        self.pre_tokenizer = ByteLevelPreTokenizer(add_prefix_space=False)

    @staticmethod
    def _merge_tokens(split: list[str], a: str, b: str, merged: str) -> list[str]:
        """Return a copy of `split` where each adjacent (a, b) is replaced by `merged`."""
        result = []
        i = 0
        last = len(split) - 1
        while i <= last:
            if i < last and split[i] == a and split[i + 1] == b:
                result.append(merged)
                i += 2  # both members of the pair are consumed
            else:
                result.append(split[i])
                i += 1
        return result

    def pre_tokenize(self, text: str) -> list[str]:
        """Split `text` into pre-tokens using the configured pre-tokenizer."""
        return self.pre_tokenizer.pre_tokenize_str(text)

    def compute_words_freqs(self, text: str) -> dict[str, int]:
        """Count how many times each pre-token occurs in `text`."""
        word_freqs = defaultdict(int)
        for word in self.pre_tokenize(text):
            word_freqs[word] += 1
        return word_freqs

    def init_vocab(self, word_freqs: dict[str, int]):
        """Build the initial vocabulary: special tokens followed by the sorted characters."""
        chars = {c for word in word_freqs for c in word}
        return ["<|OOV|>", "<|EOT|>"] + sorted(chars)

    def compute_pair_freqs(self, splits: dict[str, list[str]], word_freqs: dict[str, int]):
        """Count adjacent token pairs over all words, weighted by each word's frequency."""
        pair_freqs = defaultdict(int)
        for w, freq in word_freqs.items():
            chars = splits[w]
            for left, right in zip(chars, chars[1:]):
                pair_freqs[(left, right)] += freq
        return pair_freqs

    def merge_pair(self, a, b, word_splits: dict[str, list[str]], word_freqs: dict[str, int]):
        """Replace every adjacent (a, b) by a+b in the split of each word.

        `word_splits` is updated in place (for the words listed in `word_freqs`)
        and also returned.
        """
        merged = a + b
        for w in word_freqs:
            # The previous index-based loop kept using the pre-merge length and
            # could raise IndexError once a split had shrunk (e.g. ["a","b","a"]).
            word_splits[w] = self._merge_tokens(word_splits[w], a, b, merged)
        return word_splits

    def find_best_pair(self, pair_freqs):
        """Return the most frequent pair (first one wins on ties), or "" when there is none."""
        if not pair_freqs:
            return ""
        return max(pair_freqs, key=pair_freqs.get)

    def train(self, text: str, vocab_size: int = 1000) -> None:
        """Train the tokenizer on `text`.

        Any previously learned vocabulary and merges are discarded.

        Arguments:
            text -- the training corpus, as a single string.

        Keyword Arguments:
            vocab_size -- the maximum total number of tokens in the vocabulary,
                counting the special tokens, the alphabet and the merged tokens.
                Training stops earlier when no pair is left to merge
                (default: {1000}).
        """
        word_freqs = self.compute_words_freqs(text)
        self.vocab = self.init_vocab(word_freqs)
        # Alphabet entries are single characters; the special tokens are longer.
        self.alphabet = [tok for tok in self.vocab if len(tok) == 1]
        # Start from a clean slate so retraining does not keep stale merges.
        self.merges = {}
        word_splits = {w: list(w) for w in word_freqs}

        while len(self.vocab) < vocab_size:
            pair_freqs = self.compute_pair_freqs(word_splits, word_freqs)

            best_pair = self.find_best_pair(pair_freqs)
            if len(best_pair) != 2:
                # when we can't find a pair, we stop, even if we don't reach the vocab size.
                break
            a, b = best_pair
            word_splits = self.merge_pair(a, b, word_splits=word_splits, word_freqs=word_freqs)
            self.merges[best_pair] = a + b
            self.vocab.append(a + b)

        self.vtoi = {v: i for i, v in enumerate(self.vocab)}
        self.itov = {i: v for i, v in enumerate(self.vocab)}

    def encode(self, txt: str) -> list[int]:
        """Encode `txt` into token ids, applying the learned merges in training order.

        Tokens missing from the vocabulary are mapped to id 0 ("<|OOV|>").
        """
        # NOTE(review): the debug prints are kept to preserve current output;
        # consider removing them or switching to logging.
        print("-- encode:")
        words = self.pre_tokenize(txt)
        word_splits = [list(w) for w in words]
        print(word_splits)
        for (a, b), merge in self.merges.items():
            word_splits = [self._merge_tokens(split, a, b, merge) for split in word_splits]
        print(f"word splits: {word_splits}")
        # if not in vocab, replace by <|OOV|>
        encoded = [self.vtoi.get(c, 0) for w in word_splits for c in w]
        return encoded

    def decode(self, encoded: list[int]) -> str:
        """Concatenate the tokens for the given ids; unknown ids become "<|OOV|>".

        The tokens are joined as stored in the vocabulary, so any "Ġ" marker
        is left as-is in the returned string.
        """
        decoded = [self.itov.get(i, "<|OOV|>") for i in encoded]
        return "".join(decoded)


tk = BPETokenizer()
ptk = ByteLevelPreTokenizer(add_prefix_space=False)

# The tokenizer's pre-tokenization must match a standalone byte-level pre-tokenizer.
sents = ["Hello, world! How  are you?good .", "Hello world,  How are you ? Go !"]
for s in sents:
    gpt_pre_tokenize = ptk.pre_tokenize_str(s)
    # The message previously called an undefined free function `pre_tokenize`,
    # which would have raised NameError instead of reporting the mismatch.
    assert (
        tk.pre_tokenize(s) == gpt_pre_tokenize
    ), f"{tk.pre_tokenize(s)} != {gpt_pre_tokenize} (gpt byte level pretokenizer)"


# Train on a tiny corpus, then round-trip a sentence containing unseen words.
text = "A Hello world, this is a test. A new world is coming, Hell yes."
tk.train(text, vocab_size=50)
text_enc = tk.encode("Hello, I love to test this new thing")
tk.decode(text_enc)


In [None]:
# | hide
import nbdev

# Export this notebook to its Python module via nbdev
# (presumably the module named by the `default_exp` directive in the first cell).
nbdev.nbdev_export("./tokenizers.ipynb")
