In [None]:
#| default_exp tokenizers

%load_ext autoreload
%autoreload 2


# Tokenizers

## BPE

In [None]:
# | export
from collections import defaultdict
import re
from typing import Optional

In [None]:
# We use the Hugging Face `transformers` / `tokenizers` libraries as a reference to validate this implementation.
from transformers import AutoTokenizer
from tokenizers import (
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
    Tokenizer,
)

In [None]:
# | export
class ByteLevelPreTokenizer:
    """Split raw text into word-level pieces, approximating Hugging Face's `pre_tokenizers.ByteLevel`.

    Text is split on whitespace and on the characters ``, . ; : ! ?`` (each separator is kept as its
    own piece). A space is rendered as ``"Ġ"`` and glued onto the following piece, so ``"a b"`` gives
    ``["a", "Ġb"]``; a space followed by another space stays a standalone ``"Ġ"``.

    This is a simplified version of the reference split regex: only the separators listed above are
    recognised (e.g. ``"don't"`` or ``"abc123"`` stay single pieces), runs of punctuation are split into
    single characters, and whitespace other than " " (e.g. ``"\\n"``) is kept unmapped.
    """

    # A single capture group makes re.split keep the separators ("\t" is already covered by "\s").
    _SPLIT_RE = re.compile(r"(\s|,|\.|;|:|!|\?)")

    def __init__(self, add_prefix_space: bool = False):
        self.add_prefix_space = add_prefix_space

    def pre_tokenize_str(self, s):
        """Pre-tokenize `s`.

        Arguments:
            s -- the text to split.

        Returns:
            A list of ``(piece, (start, end))`` tuples, where ``(start, end)`` is the character span of
            the piece in `s`. When a prefix space was added, that space is mapped onto position 0 so the
            offsets still refer to the caller's string.
        """
        # Like the reference, only add the prefix space when the text does not already start with one.
        shift = 1 if self.add_prefix_space and not s.startswith(" ") else 0
        text = " " * shift + s

        tokens = []
        end = 0
        for raw in self._SPLIT_RE.split(text):
            if not raw:
                continue
            # re.split pieces are contiguous, so each piece starts where the previous one ended.
            start, end = end, end + len(raw)
            piece = "Ġ" if raw == " " else raw
            if tokens and tokens[-1][0] == "Ġ" and piece != "Ġ":
                # Attach a lone preceding space to this piece and extend the span back over it.
                piece = "Ġ" + piece
                start = tokens.pop()[1][0]
            tokens.append((piece, (start, end)))

        if shift:
            tokens = [(piece, (max(start - shift, 0), end - shift)) for piece, (start, end) in tokens]
        return tokens


class ByteLevelDecoder:
    """Undo the space marking of ByteLevelPreTokenizer: concatenate pieces and turn "Ġ" back into " "."""

    def decode(self, tokens):
        """Rebuild text from ``(piece, offsets)`` pairs; the offsets are ignored."""
        pieces = [piece for piece, _offsets in tokens]
        joined = "".join(pieces)
        return joined.replace("Ġ", " ")

In [None]:
def test_encode_decode():
    """Round trip: decoding the pre-tokenized pieces must give back the original text exactly."""
    sample = "Hello, world! How  are you?good ."
    pieces = ByteLevelPreTokenizer(add_prefix_space=False).pre_tokenize_str(sample)
    round_trip = ByteLevelDecoder().decode(pieces)
    assert round_trip == sample


test_encode_decode()

In [None]:
def test_pre_tokenizer_against_transformers(add_prefix_space: bool = False):
    """Test ByteLevelPreTokenizer against transformers' ByteLevel pre-tokenizer.

    For each sample sentence, our pieces and their (start, end) offsets must match the reference
    one-for-one. Decoding our pieces must also give back the input, prefixed with a space when
    `add_prefix_space` is set.
    """
    sents = [
        "Hello, world! How  are you?good .",
        "Hello world,  How are you ? Go !",
    ]
    trf_pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
    my_pre_tokenizer = ByteLevelPreTokenizer(add_prefix_space=add_prefix_space)
    my_decoder = ByteLevelDecoder()

    for s in sents:
        toks = my_pre_tokenizer.pre_tokenize_str(s)
        trf_toks = trf_pre_tokenizer.pre_tokenize_str(s)
        # zip() stops at the shorter list, so compare lengths first or extra/missing pieces go unnoticed.
        assert len(toks) == len(trf_toks), f"{toks} != {trf_toks}"
        for t, trf_t in zip(toks, trf_toks):
            assert t[0] == trf_t[0], f"{t} != {trf_t}"
            assert t[1][0] == trf_t[1][0], f"{t} != {trf_t}"
            assert t[1][1] == trf_t[1][1], f"{t} != {trf_t}"

        # Compute the expectation separately: `a == b if c else d` parses as `(a == b) if c else d`,
        # which made the check vacuous (a non-empty string is truthy) when add_prefix_space was False.
        expected = " " + s if add_prefix_space else s
        decoded = my_decoder.decode(toks)
        assert decoded == expected, f"{decoded!r} != {expected!r}"


test_pre_tokenizer_against_transformers()

In [None]:
# | export


class Encoding:
    """Output of BPETokenizer.encode: parallel per-token lists of ids, token strings, offsets and mask."""

    def __init__(self, ids, tokens, offsets: Optional[list] = None, attention_mask: Optional[list] = None):
        self.ids, self.tokens = ids, tokens
        self.offsets, self.attention_mask = offsets, attention_mask

    def __repr__(self):
        # One "name: value" line per field.
        return f"ids: {self.ids}\ntokens: {self.tokens}\noffsets: {self.offsets}\nattention_mask: {self.attention_mask}"


class BPETokenizer:
    """A small byte-pair-encoding (BPE) tokenizer trained from scratch on raw text.

    Training (`train` / `train_from_iterator`) starts from the characters seen in the pre-tokenized
    corpus. It then repeatedly merges the most frequent adjacent pair of symbols until the vocabulary
    holds `vocab_size` tokens or no pair is left. `encode` replays the learned merges, in training
    order, on new text; `decode` maps ids back to text.

    The vocabulary built by training is: `padding_token` (id 0), `unknown_token` (id 1),
    `additional_special_tokens`, the sorted alphabet, then one token per distinct merge result.
    """

    def __init__(self, vocab_size: int = 1000) -> None:
        self.vocab_size: int = vocab_size
        self.alphabet: list[str] = []
        self.vocab: list[str] = []
        self.merges: dict[tuple[str, str], str] = {}
        self.itov: dict[int, str] = {}
        self.vtoi: dict[str, int] = {}
        self.pre_tokenizer = ByteLevelPreTokenizer(add_prefix_space=False)
        self.padding: bool = False
        self.length: int = -1
        self.train_length: int = -1
        self.padding_token: str = "[PAD]"

        self.padding_direction: str = "right"
        self.unknown_token: str = "[UNK]"
        self.additional_special_tokens: list[str] = []

    def enable_padding(self, padding_token: str = "[PAD]", length: int = 128, direction: str = "right"):
        """Pad every subsequent `encode` output to `length` tokens.

        Arguments:
            padding_token -- token used for padding. It must be in the vocabulary; training only adds
                the padding token that was configured at training time (default "[PAD]").
            length -- target number of tokens. Longer encodings are NOT truncated.
            direction -- "right" (append padding) or "left" (prepend padding).

        Raises:
            ValueError -- if `direction` is neither "right" nor "left".
        """
        if direction not in ("right", "left"):
            # An unknown direction used to be accepted and then silently skipped by `_pad`.
            raise ValueError(f"padding direction must be 'right' or 'left', got {direction!r}")
        self.padding = True
        self.padding_token = padding_token
        self.length = length
        self.padding_direction = direction

    def pre_tokenize(self, text: str) -> list[tuple[str, tuple[int, int]]]:
        """Split `text` into ``(piece, (start, end))`` tuples with the configured pre-tokenizer."""
        return self.pre_tokenizer.pre_tokenize_str(text)

    def compute_words_freqs(self, text: str) -> dict[str, int]:
        """Count how many times each pre-tokenized piece occurs in `text`."""
        word_freqs = defaultdict(int)
        for word, _span in self.pre_tokenize(text):
            word_freqs[word] += 1
        return word_freqs

    def init_vocab(self, word_freqs: dict[str, int]) -> list[str]:
        """Build the initial vocabulary: special tokens followed by every character of `word_freqs`, sorted.

        Also stores the sorted characters in `self.alphabet`.
        """
        self.alphabet = sorted({c for word in word_freqs for c in word})
        return [self.padding_token, self.unknown_token] + self.additional_special_tokens + list(self.alphabet)

    def compute_pair_freqs(self, splits: dict[str, list[str]], word_freqs: dict[str, int]):
        """Count adjacent symbol pairs over all words, weighting each occurrence by its word's frequency."""
        pair_freqs = defaultdict(int)
        for w, freq in word_freqs.items():
            symbols = splits[w]
            for left, right in zip(symbols, symbols[1:]):
                pair_freqs[(left, right)] += freq
        return pair_freqs

    @staticmethod
    def _merge_pair_in_split(split: list[str], a: str, b: str, merged: str) -> list[str]:
        """Return a new list where each non-overlapping adjacent (a, b), scanned left to right, becomes `merged`."""
        out: list[str] = []
        i = 0
        while i < len(split):
            if i + 1 < len(split) and split[i] == a and split[i + 1] == b:
                out.append(merged)
                i += 2
            else:
                out.append(split[i])
                i += 1
        return out

    def merge_pair(self, a, b, word_splits: dict[str, list[str]], word_freqs: dict[str, int]):
        """Replace every adjacent (a, b) in the split of each word of `word_freqs` by the symbol a + b.

        `word_splits` is updated in place and also returned.
        """
        # The previous version shrank `split` while iterating over a range computed from its original
        # length, which raised IndexError for words holding the pair more than once (e.g. "abab").
        merged = a + b
        for w in word_freqs:
            word_splits[w] = self._merge_pair_in_split(word_splits[w], a, b, merged)
        return word_splits

    def find_best_pair(self, pair_freqs):
        """Return the most frequent pair (first one seen on ties), or "" when `pair_freqs` is empty."""
        if not pair_freqs:
            return ""
        # max() keeps the first maximal key in iteration order, i.e. the first-seen pair on ties.
        return max(pair_freqs, key=pair_freqs.get)

    def train_from_iterator(self, txt_iterator) -> None:
        """Train on the texts of `txt_iterator`, joined with newlines into one corpus."""
        text = "\n".join(txt_iterator)
        self.train(text)

    def train(self, text: str) -> None:
        """Learn the vocabulary and merge rules from `text`, replacing any previous training.

        `self.vocab_size` caps the total vocabulary, special tokens and single characters included.
        The alphabet is never truncated, so a corpus with many distinct characters can exceed it.
        Training stops early when no adjacent pair is left to merge.

        Arguments:
            text -- the training corpus.
        """
        word_freqs = self.compute_words_freqs(text)
        self.vocab = self.init_vocab(word_freqs)
        # Start from scratch: merges left over from an earlier call would otherwise be replayed by `encode`.
        self.merges = {}
        known_tokens = set(self.vocab)
        word_splits = {w: list(w) for w in word_freqs}

        while len(self.vocab) < self.vocab_size:
            pair_freqs = self.compute_pair_freqs(word_splits, word_freqs)
            best_pair = self.find_best_pair(pair_freqs)
            if not best_pair:
                # Every word is a single symbol: stop even though vocab_size is not reached.
                break
            a, b = best_pair
            merged = a + b
            word_splits = self.merge_pair(a, b, word_splits=word_splits, word_freqs=word_freqs)
            self.merges[best_pair] = merged
            # Different pairs can concatenate to the same string (e.g. ("a", "bc") and ("ab", "c")).
            # Keep a single vocab entry per string so `vtoi` and `itov` agree.
            if merged not in known_tokens:
                known_tokens.add(merged)
                self.vocab.append(merged)

        self.vtoi = {v: i for i, v in enumerate(self.vocab)}
        self.itov = dict(enumerate(self.vocab))

    def encode(self, txt: str) -> "Encoding":
        """Tokenize `txt`.

        Each pre-tokenized piece is split into characters, and the learned merges are replayed in
        training order. Symbols missing from the vocabulary map to `unknown_token`. Offsets are the
        character spans of each token inside the pre-tokenizer's piece spans. If padding is enabled,
        the result is padded to `self.length` tokens (see `_pad`).
        """
        word_splits = [(list(word), span) for word, span in self.pre_tokenize(txt)]
        for (a, b), merged in self.merges.items():
            word_splits = [(self._merge_pair_in_split(split, a, b, merged), span) for split, span in word_splits]

        unk_id = self.vtoi.get(self.unknown_token)
        encoded, offsets, tokens, attention_mask = [], [], [], []
        # Padding length is counted in tokens, not characters.
        for split, (start, _end) in word_splits:
            pos = start
            for symbol in split:
                tok_id = self.vtoi.get(symbol, unk_id)
                encoded.append(tok_id)
                tokens.append(self.itov.get(tok_id))
                attention_mask.append(1)  # real token; padding added by `_pad` gets 0
                offsets.append((pos, pos + len(symbol)))
                pos += len(symbol)

        if self.padding:
            encoded, offsets, tokens, attention_mask = self._pad(encoded, offsets, tokens, attention_mask)

        return Encoding(ids=encoded, tokens=tokens, offsets=offsets, attention_mask=attention_mask)

    def _pad(self, encoded, offsets, tokens, attention_mask):
        """Pad the parallel lists up to `self.length` on the `self.padding_direction` side.

        Padding entries get the id of `padding_token`, attention 0 and offsets (0, 0). Lists already at
        or beyond `self.length` are returned unchanged (no truncation).
        """
        nb_tokens = len(encoded)
        if self.length > nb_tokens:
            padding_length = self.length - nb_tokens
            pad_id = self.vtoi[self.padding_token]
            if self.padding_direction == "right":
                encoded.extend([pad_id] * padding_length)
                tokens.extend([self.padding_token] * padding_length)
                attention_mask.extend([0] * padding_length)
                offsets.extend([(0, 0)] * padding_length)
            elif self.padding_direction == "left":
                encoded = [pad_id] * padding_length + encoded
                tokens = [self.padding_token] * padding_length + tokens
                attention_mask = [0] * padding_length + attention_mask
                offsets = [(0, 0)] * padding_length + offsets
        return encoded, offsets, tokens, attention_mask

    def decode(self, encoded: list[int]) -> str:
        """Map ids back to text: unknown ids become `unknown_token`, and "Ġ" becomes a space."""
        decoded = [self.itov.get(i, self.unknown_token).replace("Ġ", " ") for i in encoded]
        return "".join(decoded)

In [None]:
# Train a small BPE tokenizer, then check that encoding a sentence is lossless in three ways.
text = "A Hello world, this is a test. A new world is coming, Hell yes . I love this"
tk = BPETokenizer(vocab_size=50)
tk.train(text)

txt_to_enc = "Hello world, I love to test this new thing"
encodings = tk.encode(txt_to_enc)

# 1) token strings, 2) per-id decoding and 3) offsets must each reproduce the input exactly.
assert "".join(encodings.tokens).replace("Ġ", " ") == txt_to_enc
assert "".join(tk.decode([token_id]) for token_id in encodings.ids) == txt_to_enc
assert "".join(txt_to_enc[lo:hi] for lo, hi in encodings.offsets) == txt_to_enc

# Characters never seen in training ("T", "O") fall back to the unknown token.
encodings = tk.encode("Hello TOTO")
assert encodings.tokens[-5:] == ["Ġ", "[UNK]", "[UNK]", "[UNK]", "[UNK]"]

# Right padding fills the tail up to 15 tokens with [PAD].
tk.enable_padding(direction="right", length=15)
padded_encodings = tk.encode("Hello world TOTO")
assert padded_encodings.tokens[-5:] == ["[UNK]", "[PAD]", "[PAD]", "[PAD]", "[PAD]"]

In [None]:
# Reference: train Hugging Face's BPE on the same corpus (`text`) with the same vocab size and
# special tokens, so its output can be compared with BPETokenizer above.
tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
trainer = trainers.BpeTrainer(vocab_size=50, special_tokens=["[PAD]", "[UNK]"])  # pad is 0, unk is 1
tokenizer.train_from_iterator([text], trainer=trainer)
tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
tokenizer.decoder = decoders.ByteLevel()
tokenizer.enable_padding(length=20, pad_token="[PAD]")

sent_to_encode = "Hello World TOTO"
encodings = tokenizer.encode(sent_to_encode)
# Display everything in one value; the previous `print(a), print(b)` showed only `(None, None)`.
encodings.ids, encodings.tokens, encodings.offsets, encodings.attention_mask
# tokenizer.get_vocab()

In [None]:
# Re-pad the reference tokenizer to 30 tokens and inspect how it encodes the sentence used above.
# Encode once and display every value of interest, instead of recomputing results that were discarded.
tokenizer.enable_padding(length=30, pad_token="[PAD]")
trf_encoding = tokenizer.encode(txt_to_enc[:50])
trf_encoding.tokens, trf_encoding.attention_mask, tokenizer.token_to_id("TOTO")

In [None]:
# Train a fresh tokenizer on the demo corpus and print an encode -> decode round trip.
text = "A Hello world, this is a test. A new world is coming, Hell yes . I love this"
tk = BPETokenizer(vocab_size=50)
tk.train(text)

text_enc = tk.encode("Hello, I love to test this new thing")
print(tk.decode(text_enc.ids), text_enc, sep="\n")

In [None]:
# | hide
import nbdev

# Export this notebook's `# | export` cells with nbdev (module name set by `#| default_exp` at the top).
nbdev.nbdev_export("./tokenizers.ipynb")