# Byte-Pair Encoding tokenization

This notebook illustrates a implementation of Byte-Pair Encoding tokenization using python 

In [2]:
# black formatting with jupyter-black
import jupyter_black

jupyter_black.load(
    lab=True,
    line_length=140,
)

In [3]:
# import libaries
import re

from typing import List
from tqdm.notebook import tqdm
from datasets import load_dataset
from utils import text_preprocessing
from collections import defaultdict

In [4]:
# import IMDB dataset from huggingface
imdb = load_dataset("imdb")["unsupervised"].to_pandas()

In [5]:
# preprocess data
imdb.text = imdb.text.apply(text_preprocessing)

In [44]:
class BPETokenizer:

    def __init__(self, text: List[str], special_tokens=["<cls>", "<sep>", "<unk>", "<pad>", "<mask>"]) -> None:
        """
        Implementation of BPE algorithm

        Arguments:
        ----------
        text: List[str]
            Text to tokenize. It should be a list of setences.
        """

        self.special_tokens = special_tokens
        text = " ".join(text)
        self.text = self.pre_tokenization(text)
        self.vocabulary = list(set(text))
        self.vocabulary.remove(" ")
        self.token_to_id = {token: id for id, token in enumerate(special_tokens)}

    def pre_tokenization(self, text: str):
        """
        Pre tokenization of the text. It will remove the special tokens and replace the spaces with a special character.

        Arguments:
        ----------
        text: str
            Text to tokenize. It should be a list of setences.

        Returns:
        --------
        text: List[str]
        """

        pattern = "(" + "|".join(self.special_tokens) + "|\W|\s)"

        text = text.strip()
        text = re.sub("\s+", " Ñ", text)
        text = re.split(pattern, text)
        text = [char for char in text if char not in ["", " "]]

        return text

    def get_words_count(self):
        """
        Get the words count of the text.

        Returns:
        --------
            words_count: defaultdict(int)
        """

        words_count = defaultdict(int)

        for word in self.text:
            words_count[word] += 1

        return words_count

    def get_initial_corpus(self):
        """
        Get the initial corpus of the text.

        Returns:
        --------
            corpus: List[Tuple[List[str], int]]
        """

        corpus = []
        words_count = self.get_words_count()

        for word, freq in words_count.items():
            corpus.append((list(word), freq))

        return corpus

    def get_bigram_freq(self, corpus):
        """
        Get the bigram frequency of the corpus.

        Arguments:
        ----------
            corpus: List[Tuple[List[str], int]]

        Returns:
        --------
            bi_grams: defaultdict(int)
        """

        bi_grams = defaultdict(int)

        for char, freq in corpus:
            for i in range(len(char) - 1):
                bi_grams[char[i], char[i + 1]] += freq

        return bi_grams

    def update_corpus_and_vocab(self, bi_grams, corpus):
        """
        Update the corpus and vocabulary.

        Arguments:
        ----------
            bi_grams: defaultdict(int)
            corpus: List[Tuple[List[str], int]]

        Returns:
        --------
            new_corpus: List[Tuple[List[str], int]]
        """

        max_bi_gram = "".join(max(bi_grams, key=bi_grams.get))
        self.vocabulary.append(max_bi_gram)

        new_corpus = []

        for char, freq in corpus:

            if max_bi_gram in "".join(char):
                char = "".join(char)
                char = char.replace(max_bi_gram, " " + max_bi_gram + " ").split()

            new_corpus.append((char, freq))

        return new_corpus

    def train(self, vocab_size: int = 10):
        """
        Train the tokenizer.

        Arguments:
        ----------
            vocab_size: int
                Size of the vocabulary.
        """

        corpus = self.get_initial_corpus()

        bar = tqdm(range(vocab_size))
        for _ in bar:
            bi_grams = self.get_bigram_freq(corpus)
            corpus = self.update_corpus_and_vocab(bi_grams, corpus)

        self.corpus = corpus

        for id, token in enumerate(self.vocabulary, len(self.special_tokens)):
            self.token_to_id[token] = id

    def tokenize(self, text: str):
        """
        Tokenize the text.

        Arguments:
        ----------
            text: str
                Text to tokenize.

        Returns:
        --------
            tokenized_text: List[str]
        """

        text = self.pre_tokenization(text)
        tokenized_text = []

        for word in text:
            if word not in self.special_tokens:
                word = list(word)
                for voc in self.vocabulary:
                    i = 0
                    while i < len(word) - 1:
                        if word[i] + word[i + 1] == voc:
                            word = word[:i] + [voc] + word[i + 2 :]
                        else:
                            i += 1
            else:
                word = [word]

            tokenized_text += word

        for i in range(len(tokenized_text)):
            if (tokenized_text[i] not in self.vocabulary) and (tokenized_text[i] not in self.special_tokens):
                tokenized_text[i] = "<unk>"

        ids = [self.token_to_id[token] for token in tokenized_text]

        return tokenized_text, ids

In [45]:
# select a subset of the data for testing
text = imdb.text.tolist()[:1000]

In [46]:
# instance tokenizer
tokenizer = BPETokenizer(text)

# train tokenizer
tokenizer.train(vocab_size=100)

  0%|          | 0/100 [00:00<?, ?it/s]

In [51]:
tokens, ids = tokenizer.tokenize("<cls>i cant compare this movie with anything else <sep> maybe except the movie leon <mask> played")
print(tokens)
print(ids)

['<cls>', 'i', 'Ñc', 'a', 'n', 't', 'Ñc', 'o', 'm', 'p', 'a', 'r', 'e', 'Ñth', 'i', 's', 'Ñm', 'o', 'v', 'i', 'e', 'Ñw', 'i', 'th', 'Ñan', 'y', 'th', 'ing', 'Ñe', 'l', 's', 'e', '<unk>', '<sep>', 'Ñm', 'a', 'y', 'b', 'e', 'Ñe', 'x', 'c', 'e', 'p', 't', 'Ñthe', 'Ñm', 'o', 'v', 'i', 'e', 'Ñl', 'e', 'o', 'n', '<unk>', '<mask>', 'Ñp', 'l', 'a', 'y', 'e', 'd']
[0, 85, 98, 26, 55, 16, 98, 73, 38, 32, 26, 47, 10, 146, 85, 40, 96, 73, 30, 85, 10, 90, 85, 145, 138, 49, 145, 171, 108, 66, 40, 10, 2, 1, 96, 26, 49, 31, 10, 108, 75, 79, 10, 32, 16, 185, 96, 73, 30, 85, 10, 103, 10, 73, 55, 2, 4, 102, 66, 26, 49, 10, 67]
