In [25]:
# import libraries (requires the third-party NLTK package, e.g. `pip install nltk`)
from collections import Counter, defaultdict
from nltk.tokenize import wordpunct_tokenize

class BPE:
    """Byte-Pair Encoding: Subword-based tokenization algorithm.

    Each pre-tokenized word gets a trailing ``_`` as its end-of-word marker
    and is split into characters. The most frequent adjacent symbol pair is
    then merged repeatedly. Training stops when the vocabulary reaches
    ``vocab_size`` entries or, if ``vocab_size`` is None, after ``num_iter``
    merges. It also stops early when no pair is left to merge.
    """

    def __init__(self, corpus, vocab_size: int = None, num_iter: int = None, tokenizer=None):
        """Initialize BPE tokenizer.

        Args:
            corpus: iterable of strings to train on.
            vocab_size: target vocabulary size, counting the reserved "</w>"
                entry, the character alphabet and all learned merges. Takes
                precedence over ``num_iter`` whenever it is not None.
            num_iter: number of merges to perform when ``vocab_size`` is None.
            tokenizer: callable that maps a string to an iterable of words.
                Defaults to NLTK's ``wordpunct_tokenize``.

        Raises:
            ValueError: if neither ``vocab_size`` nor ``num_iter`` is given.
        """
        # _iteration_criteria needs at least one stopping rule; without one it
        # would compare an int against None.
        if vocab_size is None and num_iter is None:
            raise ValueError("either vocab_size or num_iter must be given")
        self.corpus = corpus
        self.vocab_size = vocab_size
        self.num_iter = num_iter

        # Pre-tokenizer that splits text into words. The default is NLTK's
        # regex-based wordpunct tokenizer (runs of word characters / runs of
        # punctuation).
        self.tokenizer = tokenizer if tokenizer is not None else wordpunct_tokenize
        self.word_freqs = defaultdict(int)  # "word_" -> count in corpus
        self.splits = {}                     # "word_" -> current list of symbols
        self.merges = {}                     # (left, right) -> merged symbol, in merge order

        self._iter_vocab = []  # vocabulary being built during train()
        self._num_iter = 0     # merges performed so far in train()

    def _iteration_criteria(self) -> bool:
        """Return True while training should perform another merge."""
        if self.vocab_size is not None:
            return len(self._iter_vocab) < self.vocab_size
        return self._num_iter < self.num_iter

    def train(self):
        """Train BPE tokenizer on ``self.corpus``.

        Previously learned state is reset first, so calling ``train`` again
        gives the same result instead of double-counting word frequencies.

        Returns:
            dict mapping each merged pair ``(left, right)`` to the merged
            symbol, in the order the merges were learned.
        """
        self.word_freqs = defaultdict(int)
        self.merges = {}

        # compute the frequencies of each word in the corpus
        for text in self.corpus:
            for word in self.tokenizer(text):
                self.word_freqs[f'{word}_'] += 1

        # base vocabulary: every distinct character in the corpus, sorted
        alphabet = sorted({letter for word in self.word_freqs for letter in word})

        # NOTE: "</w>" is reserved at the front of the vocabulary, but words
        # actually end with "_", so "</w>" never appears in a split. It still
        # counts toward vocab_size.
        vocab = ["</w>"] + alphabet
        self._iter_vocab = vocab

        # split each word into individual characters before training
        self.splits = {word: list(word) for word in self.word_freqs}

        self._num_iter = 0
        while self._iteration_criteria():
            pair_freqs = self.compute_pair_freqs()
            if not pair_freqs:
                # Every word is already a single symbol: nothing left to merge.
                break

            # max() returns the first maximal key, so ties go to the pair
            # that was counted first.
            best_pair = max(pair_freqs, key=pair_freqs.get)
            merged = best_pair[0] + best_pair[1]

            self.splits = self.merge_pair(*best_pair)
            self.merges[best_pair] = merged
            vocab.append(merged)  # also grows self._iter_vocab (same list)
            self._num_iter += 1
        return self.merges

    def compute_pair_freqs(self):
        """Count adjacent symbol pairs in the current splits.

        Each pair is weighted by the frequency of the word it occurs in.

        Returns:
            defaultdict(int) mapping ``(left, right)`` to its weighted count.
        """
        pair_freqs = defaultdict(int)
        for word, freq in self.word_freqs.items():
            split = self.splits[word]
            for pair in zip(split, split[1:]):
                pair_freqs[pair] += freq
        return pair_freqs

    @staticmethod
    def _merge_symbols(split, a, b, merged):
        """Return a copy of ``split`` with each adjacent ``(a, b)`` replaced by ``merged``.

        Matches are scanned left to right and do not overlap.
        """
        out = []
        i = 0
        last = len(split) - 1
        while i <= last:
            if i < last and split[i] == a and split[i + 1] == b:
                out.append(merged)
                i += 2
            else:
                out.append(split[i])
                i += 1
        return out

    def merge_pair(self, a, b):
        """Merge every adjacent ``(a, b)`` in all word splits into ``a + b``.

        Updates ``self.splits`` in place and also returns it.
        """
        merged = a + b
        for word in self.word_freqs:
            self.splits[word] = self._merge_symbols(self.splits[word], a, b, merged)
        return self.splits

    def tokenize(self, text):
        """Tokenize a given text with the trained BPE tokenizer.

        Pre-tokenizes ``text``, adds the ``_`` end-of-word marker, splits each
        word into characters, then applies the learned merges in order.
        Characters never merged during training stay single-character tokens.

        Returns:
            flat list of subword tokens.
        """
        splits_text = [list(f'{word}_') for word in self.tokenizer(text)]

        for (a, b), merged in self.merges.items():
            splits_text = [self._merge_symbols(split, a, b, merged) for split in splits_text]

        return [token for split in splits_text for token in split]

In [35]:
# Demo 1: stop training once the vocabulary holds 14 entries, then tokenize
# two words that do not occur in the training text.
text = 'target pasta star star pasta star'

bpe = BPE(corpus=[text], vocab_size=14)
print(bpe.train())
print(bpe.tokenize('tapas'), bpe.tokenize('stata'), sep='\n')
print(bpe.tokenize('stata'))

{('t', 'a'): 'ta', ('s', 'ta'): 'sta', ('sta', 'r'): 'star', ('star', '_'): 'star_', ('p', 'a'): 'pa'}
['ta', 'pa', 's', '_']
['sta', 'ta', '_']


In [34]:
# Demo 2: stop training after exactly 6 merges, then tokenize the same two
# unseen words for comparison with the vocab_size run.
text = 'target pasta star star pasta star'

bpe = BPE(corpus=[text], num_iter=6)
print(bpe.train())
print(bpe.tokenize('tapas'), bpe.tokenize('stata'), sep='\n')

{('t', 'a'): 'ta', ('s', 'ta'): 'sta', ('sta', 'r'): 'star', ('star', '_'): 'star_', ('p', 'a'): 'pa', ('pa', 'sta'): 'pasta'}
['ta', 'pa', 's', '_']
['sta', 'ta', '_']
