# Byte Pair Encoding Example

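This notebook is a worked example of byte pair encoding (BPE) as described in [Neural Machine Translation of Rare Words with Subword Units](https://arxiv.org/abs/1508.07909) by Sennrich et al. (2015). Starting from a toy vocabulary of words split into characters, it repeatedly counts adjacent symbol pairs and merges the most frequent pair into a single symbol.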

In [1]:
import collections
import re

In [2]:
def get_stats(vocab: dict[str, int]) -> dict[tuple[str, str], int]:
    """Count adjacent symbol pairs, weighting each by its word's frequency.

    Every key of ``vocab`` is a word written as whitespace-separated symbols
    and every value is that word's frequency. The result maps each
    ``(left, right)`` pair of neighbouring symbols to the summed frequency of
    the words it occurs in (counted once per occurrence).
    """
    pair_counts = collections.defaultdict(int)
    for entry, count in vocab.items():
        tokens = entry.split()
        # Line each symbol up with its right-hand neighbour.
        for left, right in zip(tokens, tokens[1:]):
            pair_counts[left, right] += count
    return pair_counts

In [3]:
def merge_vocab(pair: tuple[str, str], v_in: dict[str, int]) -> dict[str, int]:
    """Merge every occurrence of ``pair`` and return the new vocabulary.

    Args:
        pair: The two adjacent symbols to fuse into a single symbol.
        v_in: Vocabulary mapping words (whitespace-separated symbols) to
            their frequencies. It is not modified.

    Returns:
        A new vocabulary in which each whole-symbol occurrence of
        ``pair[0] pair[1]`` is replaced by the concatenated symbol. If two
        input words become identical after the merge, their frequencies are
        summed rather than one overwriting the other.
    """
    merged = ''.join(pair)
    bigram = re.escape(' '.join(pair))
    # The lookarounds ensure the match starts and ends on symbol boundaries,
    # so e.g. the pair ('e', 's') does not match inside 'ne st'.
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    v_out: dict[str, int] = {}
    for word, freq in v_in.items():
        # A callable replacement inserts `merged` literally; a plain string
        # would have its backslashes interpreted as escapes/group references.
        w_out = p.sub(lambda _match: merged, word)
        v_out[w_out] = v_out.get(w_out, 0) + freq
    return v_out

In [4]:
# Toy corpus: each word is pre-split into characters and terminated with the
# end-of-word marker '</w>'; the value is the word's frequency.
vocab = {
    'l o w </w>': 5,
    'l o w e r </w>': 2,
    'n e w e s t </w>': 6,
    'w i d e s t </w>': 3,
}
num_merges = 10

for i in range(num_merges):
    print(f"{vocab=}")
    pairs = get_stats(vocab)
    # sorted() is stable, so ties keep the order in which pairs were first seen.
    top_pairs = sorted(pairs.items(), key=lambda x: x[1], reverse=True)[:5]
    print(f"{top_pairs=}")
    if not top_pairs:
        # Every word is already a single symbol; nothing is left to merge.
        break
    best = top_pairs[0][0]
    vocab = merge_vocab(best, vocab)
    print(f"best={best}: {pairs[best]}")

vocab={'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}
top_pairs=[(('e', 's'), 9), (('s', 't'), 9), (('t', '</w>'), 9), (('w', 'e'), 8), (('l', 'o'), 7)]
best=('e', 's'): 9
vocab={'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}
top_pairs=[(('es', 't'), 9), (('t', '</w>'), 9), (('l', 'o'), 7), (('o', 'w'), 7), (('n', 'e'), 6)]
best=('es', 't'): 9
vocab={'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3}
top_pairs=[(('est', '</w>'), 9), (('l', 'o'), 7), (('o', 'w'), 7), (('n', 'e'), 6), (('e', 'w'), 6)]
best=('est', '</w>'): 9
vocab={'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
top_pairs=[(('l', 'o'), 7), (('o', 'w'), 7), (('n', 'e'), 6), (('e', 'w'), 6), (('w', 'est</w>'), 6)]
best=('l', 'o'): 7
vocab={'lo w </w>': 5, 'lo w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
top_pairs=[(('lo', 'w'), 7), (('n', 'e'), 6), (('e', 'w'), 6), (('w', 'est</w>'),