In this notebook, we explore **byte-pair encoding** (BPE) as described in the paper [Neural machine translation of rare words with subword units](http://arxiv.org/abs/1508.07909v5). The authors also made a [reference implementation](https://github.com/rsennrich/subword-nmt) available on GitHub, which we will use as an additional reference.

The idea of BPE is as follows. We start with a vocabulary that initially contains only the basic characters which appear in our text, and tokenize our input according to this vocabulary. We then identify pairs of adjacent characters (originally called byte pairs, lending the algorithm its name) which occur frequently in the text. We introduce a new token for the most frequent pair and re-tokenize the text, taking the new, now larger vocabulary into account. This step, called a **merge**, is repeated until the vocabulary reaches a given size. In practice, the vocabulary will then consist of the most frequent words, a few subwords and still the original set of characters. If we now hit upon an unknown word during inference, we can split it until the individual parts appear in the vocabulary, if needed down to the character level. Therefore, unknown tokens can only appear if new characters show up that we have not seen before, which is easily excluded by fixing a defined character set like ASCII or a subset of the Unicode character set.

Before going into details, let us quickly discuss one subtle point - word boundaries. In a word-level tokenizer, there is no problem with word boundaries, as each token corresponds to a word. In a subword-level tokenizer, we need to be careful not to lose the information about where one word ends and the next word starts. In the original paper, a dedicated "end-of-word" token "</w>" was used for that purpose. This was an ordinary token and thus could get merged with any other token while going through the algorithm. In the current version of the reference implementation, a different approach was chosen - now the end-of-word marker is appended to the last character of each word while building the initial word list. We will see in a minute how this works.
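To make the difference between the two conventions concrete, here is a minimal sketch that writes the same word in both formats. The function names `with_separate_eow` and `with_attached_eow` are made up for this illustration and do not appear in the paper or in the reference implementation.

```python
def with_separate_eow(word):
    # Convention of the original paper: "</w>" is a symbol of its own
    # and can later be merged with its left neighbour like any other symbol.
    return " ".join(list(word) + ["</w>"])

def with_attached_eow(word):
    # Convention used in this notebook: "</w>" is glued to the last
    # character, so "w</w>" and "w" are two different symbols from the start.
    return " ".join(word) + "</w>"

print(with_separate_eow("low"))  # l o w </w>
print(with_attached_eow("low"))  # l o w</w>
```

With the second convention, the word "low" starts out as three symbols instead of four.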

Let us now go into the details of the algorithm, using the original paper (which even contains code snippets) and the reference implementation as a guide. The first step consists of building up a data structure that we will call the **word frequencies**, which is simply a Python dictionary containing all words in the text along with their frequencies (unfortunately this data structure is called the vocabulary in the reference implementation, but we will reserve this term for the set of tokens that we will identify in the course of the algorithm). Here is a simple function that takes a pre-tokenized example text and creates this data structure.

In [1]:
import collections

def get_word_frequencies(pre_tokenized_text):
    counter = collections.Counter(pre_tokenized_text)
    word_frequencies = {" ".join(word) + "</w>" : frequency for word, frequency in counter.items() if len(word) > 0}
    return word_frequencies

pre_tokenized_text = ["low", "lower",  "newest", "widest"]
word_frequencies = get_word_frequencies(pre_tokenized_text)
print(word_frequencies)

{'l o w</w>': 1, 'l o w e r</w>': 1, 'n e w e s t</w>': 1, 'w i d e s t</w>': 1}


Note that the input is already pre-tokenized, i.e. split into words. In practice, we could use a PyTorch tokenizer to do this, or any other tokenizer. The reference implementation simply splits along spaces, but the details are not that important. Also note that the keys in our dictionary are not the actual words, but the words as sequences of symbols, which we store as strings with spaces as separators in order to have hashable keys. It is therefore important that spaces in the input are removed before processing it, so a space should not appear as a word and no word should contain a space. Finally, as promised, we append the end-of-word marker to the last character of every word.
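If plain whitespace splitting is good enough, pre-tokenization can be sketched in a few lines. This is only an illustration under that assumption, not the way the reference implementation or any particular tokenizer library does it, and the helper name `pre_tokenize` is made up.

```python
import collections

def pre_tokenize(text):
    # str.split() without arguments splits on any run of whitespace and
    # never returns empty strings, so no word is empty or contains a space.
    return text.split()

text = "low lower  newest\nwidest low"
words = pre_tokenize(text)
print(words)                       # ['low', 'lower', 'newest', 'widest', 'low']
print(collections.Counter(words))  # 'low' is counted twice, all other words once
```

The resulting list has exactly the shape that a function like `get_word_frequencies` expects as input.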

A second data structure that we need is the actual **vocabulary**, i.e. the set of valid tokens. Initially, this is simply the set of all symbols appearing in any of the words in our input.

In [2]:
def build_vocabulary(word_frequencies):
    vocab = set()
    for word in word_frequencies.keys():
        for c in word.split():
            vocab.add(c)
    return vocab

vocab = build_vocabulary(word_frequencies)
print(vocab)

{'o', 'r</w>', 's', 'w</w>', 'n', 'i', 'e', 'l', 'd', 'w', 't</w>'}


Next, we need to count byte pairs. As we do not cross word boundaries, we can do this on word level, i.e. we go through the words, extract all byte pairs we find there and increase the count of each byte pair by the frequency of the word in the text, thus giving us eventually the frequency of this byte pair (inside word boundaries) in the text.

In [3]:
def get_stats(word_frequencies):
    pairs = collections.defaultdict(int)
    for word, freq in word_frequencies.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs

stats = get_stats(word_frequencies)
print(stats)

defaultdict(<class 'int'>, {('l', 'o'): 2, ('o', 'w</w>'): 1, ('o', 'w'): 1, ('w', 'e'): 2, ('e', 'r</w>'): 1, ('n', 'e'): 1, ('e', 'w'): 1, ('e', 's'): 2, ('s', 't</w>'): 2, ('w', 'i'): 1, ('i', 'd'): 1, ('d', 'e'): 1})


The last basic operation that we need is to conduct a **merge**. First, we identify the pair that occurs most frequently (if more than one pair appears with the same frequency, we follow the convention in the reference implementation and use the lexicographic ordering of the pairs as a tie-breaker, picking the largest pair). In our case, this is the pair w, e.

In [4]:
best_pair = max(stats, key=lambda x: (stats[x], x)) # return tuple in key function, so that comparison of tuples applies
print(f"Best pair: {best_pair}")

Best pair: ('w', 'e')
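The tie-breaking is easiest to see on a small hand-made example. The counts below are invented for illustration. Since the key function returns a tuple, Python compares the count first and falls back to the pair itself only when the counts are equal.

```python
# Three pairs share the highest count, one pair is less frequent.
stats = {("l", "o"): 2, ("w", "e"): 2, ("e", "s"): 2, ("i", "d"): 1}

# Sort by (count, pair) in descending order to make the full ranking visible.
ranked = sorted(stats, key=lambda pair: (stats[pair], pair), reverse=True)
print(ranked)  # [('w', 'e'), ('l', 'o'), ('e', 's'), ('i', 'd')]

# max() with the same key simply returns the first entry of that ranking.
best_pair = max(stats, key=lambda pair: (stats[pair], pair))
print(best_pair)  # ('w', 'e')
```

Among the pairs with count two, the lexicographically largest one wins, which is why "w, e" is chosen ahead of "l, o" and "e, s".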


We now go through our word frequency dictionary and whenever we encounter the sequence "w e", we replace it by "we" (at this point our convention to store a sequence of tokens as a string using spaces as separators pays off, as we can use Python string operations to do this). We also add "we" as a new token to our vocabulary.

In [5]:
import re

def do_merge(best_pair, word_frequencies, vocab):
    new_frequencies = dict()
    new_token = "".join(best_pair)
    pair = " ".join(best_pair)
    vocab.add(new_token)
    for word, freq in word_frequencies.items():
        new_word = re.sub(pair, new_token, word)
        new_frequencies[new_word] = word_frequencies[word]
    return new_frequencies, vocab


word_frequencies, vocab = do_merge(best_pair, word_frequencies, vocab)
print(f"Updated word frequencies: {word_frequencies}")
print(f"Updated vocab: {vocab}")
rules = []
rules.append(best_pair)

Updated word frequencies: {'l o w</w>': 1, 'l o we r</w>': 1, 'n e we s t</w>': 1, 'w i d e s t</w>': 1}
Updated vocab: {'o', 'r</w>', 's', 'w</w>', 'n', 'i', 'e', 'l', 'we', 'd', 'w', 't</w>'}


Unfortunately, our code contains a flaw. Suppose we wanted to merge the symbols "o" and "w". In the word "l o w</w>", the character w appears at the end of the word, i.e. in our encoding scheme, as part of "w</w>". This is one symbol, which is different from a free-standing "w". So we must **not** merge this with the "o" preceding it. Our code, however, would do this:

In [6]:
best_pair = ("o", "w")
new_token = "".join(best_pair)
pair = " ".join(best_pair)
word = "l o w</w>"
new_word = re.sub(pair, new_token, word)
print(new_word) # gives a merge which should not be the case

l ow</w>


To fix this, we have to use a more sophisticated regular expression that employs a lookahead assertion and a lookbehind assertion (more on this [here](https://docs.python.org/3/library/re.html)) to make sure that we only match our best pair if it is delimited by spaces or by the start or end of the string, i.e. not directly adjacent to other non-space characters. Also note that we need to escape our byte pair, as it might itself contain characters that have a special meaning inside regular expressions.

In [7]:
best_pair = ("o", "w")
new_token = "".join(best_pair)
pattern = r"(?<!\S)" + re.escape(" ".join(best_pair)) + r"(?!\S)"
word = "l o w</w>"
new_word = re.sub(pattern, new_token, word)
print(new_word) # merge not done with this RE

#
# Updated merge function
#
def do_merge(best_pair, word_frequencies, vocab):
    new_frequencies = dict()
    new_token = "".join(best_pair)
    pattern = r"(?<!\S)" + re.escape(" ".join(best_pair)) + r"(?!\S)"
    vocab.add(new_token)
    for word, freq in word_frequencies.items():
        # use a function as replacement so that a backslash in the new
        # token is not interpreted as an escape sequence by re.sub
        new_word = re.sub(pattern, lambda m: new_token, word)
        new_frequencies[new_word] = freq
    return new_frequencies, vocab


l o w</w>
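To see why escaping matters, consider a pair that contains a regular expression metacharacter. The sketch below uses a made-up helper `merge_pattern` and an artificial word. Without escaping, the dot in the pair matches any character, so an unrelated pair gets merged as well.

```python
import re

def merge_pattern(pair, escape=True):
    # Build the pattern for a pair, optionally without escaping it.
    joined = " ".join(pair)
    if escape:
        joined = re.escape(joined)
    return re.compile(r"(?<!\S)" + joined + r"(?!\S)")

pair = (".", "b")
new_token = "".join(pair)
word = "a b . b c</w>"

# Unescaped: "." also matches the leading "a", so "a b" is rewritten too.
print(merge_pattern(pair, escape=False).sub(new_token, word))  # .b .b c</w>
# Escaped: only the literal pair ". b" is merged.
print(merge_pattern(pair, escape=True).sub(new_token, word))   # a b .b c</w>
```

Only the escaped version leaves the unrelated symbols "a" and "b" at the start of the word untouched.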


As part of the output of a merge, we also maintain a list of the merges we have done, i.e. the best pairs, as we need to re-apply them later during tokenization. Let us repeat this process two more times, so that we have done three merges in total.

In [8]:
for i in range(2):
    stats = get_stats(word_frequencies)
    best_pair = max(stats, key=lambda x: (stats[x], x)) 
    print(f"Best pair: {best_pair}")
    word_frequencies, vocab = do_merge(best_pair, word_frequencies, vocab)
    rules.append(best_pair)
    
print(f"Final vocabulary: {vocab}")
print(f"Final words: {word_frequencies.keys()}")
print(f"Rules: {rules}")

Best pair: ('s', 't</w>')
Best pair: ('l', 'o')
Final vocabulary: {'o', 'st</w>', 'r</w>', 's', 'w</w>', 'n', 'i', 'lo', 'e', 'l', 'we', 'd', 'w', 't</w>'}
Final words: dict_keys(['lo w</w>', 'lo we r</w>', 'n e we st</w>', 'w i d e st</w>'])
Rules: [('w', 'e'), ('s', 't</w>'), ('l', 'o')]


The output of the algorithm which we will need later on consists of two parts - the vocabulary, i.e. the final set of valid tokens, and the merge rules that the algorithm has derived to arrive at the vocabulary, as we will need to apply the same set of rules (in the same order) to a word in order to tokenize it. Before doing this, let us quickly discuss some improvements that are present in the reference implementation but not in our code.

* the reference implementation stores the word frequencies as an array where each entry consists of the word and the frequency, not as a dictionary. This makes it easier to modify elements as the words are no longer the keys, and to index individual elements
* our code visits all words during a merge, even those words which do not contain the byte pair that we merge. The reference code maintains an index, which is filled when the pair statistics are calculated and contains essentially a list of all words that contain the respective byte pair. This index is used to only update those words which need to be changed
* in addition, the statistics are not entirely re-calculated after each merge, but updated incrementally only for those byte pairs that are actually affected by the merge
* instead of simply doing a fixed number of merges, the implementation stops if no byte pair can be found whose frequency is above a certain threshold (two by default), so that the algorithm might stop earlier
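The last point in this list can be sketched without touching the other optimizations. The following self-contained loop is a simplified illustration, not the reference code. The names `merge_pair`, `learn_rules`, `max_merges` and `min_frequency` are chosen for this sketch, and we assume that the threshold means "the best pair must occur at least `min_frequency` times".

```python
import collections
import re

def get_stats(word_frequencies):
    # Count adjacent symbol pairs, weighted by word frequency.
    pairs = collections.defaultdict(int)
    for word, freq in word_frequencies.items():
        symbols = word.split()
        for left, right in zip(symbols, symbols[1:]):
            pairs[left, right] += freq
    return pairs

def merge_pair(pair, word_frequencies):
    # Replace every occurrence of the pair by the merged symbol.
    pattern = re.compile(r"(?<!\S)" + re.escape(" ".join(pair)) + r"(?!\S)")
    new_token = "".join(pair)
    return {pattern.sub(lambda m: new_token, word): freq
            for word, freq in word_frequencies.items()}

def learn_rules(word_frequencies, max_merges, min_frequency=2):
    rules = []
    for _ in range(max_merges):
        stats = get_stats(word_frequencies)
        if not stats:
            break  # every word consists of a single symbol
        best_pair = max(stats, key=lambda p: (stats[p], p))
        if stats[best_pair] < min_frequency:
            break  # no remaining pair is frequent enough
        word_frequencies = merge_pair(best_pair, word_frequencies)
        rules.append(best_pair)
    return rules, word_frequencies

word_frequencies = {"l o w</w>": 1, "l o w e r</w>": 1,
                    "n e w e s t</w>": 1, "w i d e s t</w>": 1}
rules, final_words = learn_rules(word_frequencies, max_merges=10)
print(rules)              # [('w', 'e'), ('s', 't</w>'), ('l', 'o')]
print(list(final_words))  # ['lo w</w>', 'lo we r</w>', 'n e we st</w>', 'w i d e st</w>']
```

Although we allow up to ten merges, the loop stops after three, because after the third merge every remaining pair occurs only once.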

Let us now discuss how tokenization can be implemented once we have determined the vocabulary and the rules. The idea is again described in the original paper - given an unknown word, we simply apply the rules to the new word, in the order in which we have derived and recorded them. Of course, there are some shortcuts that we can take, for instance by maintaining a cache of words that we have already seen, but let us again focus on the straightforward implementation in this notebook. This is actually rather simple - to encode a word we first turn it into the same format as during training, i.e. into a chain of characters, separated by spaces and followed by an end-of-word marker, and then use regular expressions as during training to run the replacements per rule. We will also have to build a lookup table mapping items in the vocabulary to indices so that the encoding eventually returns a list of IDs.

In [9]:
stoi = dict()
for idx, symbol in enumerate(vocab):
    stoi[symbol] = idx

def encode(word):
    _word = " ".join(word) + "</w>"
    #
    # apply rules in the original order
    #
    for bp in rules:
        new_token = "".join(bp)
        pattern = re.compile(r'(?<!\S)' + re.escape(" ".join(bp)) + r'(?!\S)')
        # a function as replacement keeps backslashes in the token literal
        _word = pattern.sub(lambda m: new_token, _word)
    indices = [stoi[symbol] for symbol in _word.split()]
    return _word, indices

In [10]:
print(encode("low"))
print(encode("newest"))

('lo w</w>', [7, 4])
('n e we st</w>', [5, 8, 10, 1])


Of course there is still a way to go to turn this into a real implementation. One obvious point is that we do not yet handle unknown symbols, so we should add a second meta-symbol to our vocabulary when starting the learning phase that represents an unknown symbol. During the lookup at the end of the encoding, we would then turn any failed lookup into a default index. There are also some shortcuts we can take during encoding that speed up the process substantially:
* implement a cache
* if the word appears as-is in the vocabulary, return the corresponding symbol immediately
* if we process the rules and arrive at a point where the remaining word does not contain any additional spaces (i.e. consists of one final symbol), exit the loop as we are done
* compile all regular expressions upfront and store them along with the rules
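Some of these points can be combined in a short sketch. It covers the fallback for unknown symbols, a cache, pre-compiled patterns and the early exit, but not the whole-word shortcut. The meta-symbol `<unk>`, the names `UNK` and `compiled_rules`, and the choice to sort the vocabulary to get reproducible IDs are assumptions made for this illustration. The rules and the vocabulary are copied from the three merges done above.

```python
import functools
import re

UNK = "<unk>"  # hypothetical meta-symbol standing for any unknown symbol

rules = [("w", "e"), ("s", "t</w>"), ("l", "o")]
vocab = {"d", "e", "i", "l", "lo", "n", "o", "r</w>", "s", "st</w>",
         "t</w>", "w", "w</w>", "we"}

# Sorting makes the IDs reproducible; UNK gets index 0.
stoi = {symbol: idx for idx, symbol in enumerate([UNK] + sorted(vocab))}

# Compile the pattern of every rule once instead of on every call.
compiled_rules = [
    (re.compile(r"(?<!\S)" + re.escape(" ".join(pair)) + r"(?!\S)"), "".join(pair))
    for pair in rules
]

@functools.lru_cache(maxsize=None)  # cache results for words seen before
def encode(word):
    symbols = " ".join(word) + "</w>"
    for pattern, new_token in compiled_rules:
        if " " not in symbols:
            break  # a single symbol is left, no further rule can apply
        symbols = pattern.sub(lambda m: new_token, symbols)
    # Any symbol that is not in the vocabulary is mapped to the index of UNK.
    indices = tuple(stoi.get(symbol, stoi[UNK]) for symbol in symbols.split())
    return symbols, indices

print(encode("low"))  # ('lo w</w>', (5, 13))
print(encode("lox"))  # ('lo x</w>', (5, 0))
```

The indices are returned as a tuple so that cached results cannot be modified by the caller. The second call shows the fallback: "x</w>" has never been seen during training and is therefore mapped to index 0.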