In [9]:
# 2024/1/21
# zhangzhong
# Byte Pair Encoding

In [10]:
import torch
import collections


In [11]:
# Base vocabulary: the 26 lowercase letters, the "_" character that ends
# every word of the toy corpus below, and the "[UNK]" placeholder.
symbols = list("abcdefghijklmnopqrstuvwxyz") + ["_", "[UNK]"]

# Toy corpus: word -> number of occurrences.
raw_token_freqs = {'fast_': 4, 'faster_': 3, 'tall_': 5, 'taller_': 4}

# Re-key each word as its characters separated by single spaces, so that a
# later whitespace split recovers the current symbol sequence of the word.
token_freqs = {' '.join(word): count for word, count in raw_token_freqs.items()}
print(token_freqs)

{'f a s t _': 4, 'f a s t e r _': 3, 't a l l _': 5, 't a l l e r _': 4}


In [12]:
# returns the most frequent pair of consecutive symbols within a word
def get_max_freq_pair(token_freqs):
    """Return the adjacent symbol pair with the largest total frequency.

    Args:
        token_freqs: dict mapping a token, written as whitespace-separated
            symbols (e.g. ``'f a s t _'``), to its frequency.

    Returns:
        A ``(left, right)`` tuple of symbols. When several pairs share the
        top count, the one that was encountered first is returned.

    Raises:
        ValueError: if no token contains at least two symbols (``max`` is
            then called on an empty mapping).
    """
    pair_counts = collections.defaultdict(int)
    for word, count in token_freqs.items():
        parts = word.split()
        # zip(parts, parts[1:]) walks every pair of neighbouring symbols
        for left, right in zip(parts, parts[1:]):
            pair_counts[(left, right)] += count
    return max(pair_counts, key=lambda pair: pair_counts[pair])

In [13]:
# greedily merge the most frequent pair of consecutive symbols to produce a new symbol
def merge_symbols(max_freq_pair, token_freqs, symbols):
    """Merge every occurrence of ``max_freq_pair`` into one new symbol.

    Only whole, adjacent symbols are merged. A plain substring replacement
    on the space-joined token would also match across symbol boundaries
    (merging ``('a', 'b')`` inside ``'ta b'`` would wrongly give ``'tab'``),
    so the token is split into symbols and scanned instead.

    Args:
        max_freq_pair: 2-tuple ``(left, right)`` of symbols to merge.
        token_freqs: dict mapping a token, written as whitespace-separated
            symbols, to its frequency.
        symbols: vocabulary list; the merged symbol is appended in place.

    Returns:
        A new dict with the same frequencies, keyed by the tokens after
        merging. ``token_freqs`` itself is not modified.
    """
    left, right = max_freq_pair
    merged_symbol = left + right
    symbols.append(merged_symbol)

    new_token_freqs = {}
    for token, freq in token_freqs.items():
        parts = token.split()
        merged_parts = []
        i = 0
        # Left-to-right, non-overlapping scan over whole symbols.
        while i < len(parts):
            if i + 1 < len(parts) and parts[i] == left and parts[i + 1] == right:
                merged_parts.append(merged_symbol)
                i += 2  # both halves of the pair are consumed
            else:
                merged_parts.append(parts[i])
                i += 1
        new_token_freqs[' '.join(merged_parts)] = freq
    return new_token_freqs

In [14]:
# now we iteratively perform the byte pair encoding
num_merges = 10
# NOTE: this cell rebinds token_freqs and appends to symbols, so running it
# again continues merging from the already-merged state.
for merge_no in range(1, num_merges + 1):
    # pick the currently most frequent adjacent pair, then fuse it in every token
    max_freq_pair = get_max_freq_pair(token_freqs)
    token_freqs = merge_symbols(max_freq_pair, token_freqs, symbols)
    print(f'merge #{merge_no}:', max_freq_pair)
    # one value per line, followed by a separator line
    print(token_freqs, symbols, "---------------------", sep='\n')

merge #1: ('t', 'a')
{'f a s t _': 4, 'f a s t e r _': 3, 'ta l l _': 5, 'ta l l e r _': 4}
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '_', '[UNK]', 'ta']
---------------------
merge #2: ('ta', 'l')
{'f a s t _': 4, 'f a s t e r _': 3, 'tal l _': 5, 'tal l e r _': 4}
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '_', '[UNK]', 'ta', 'tal']
---------------------
merge #3: ('tal', 'l')
{'f a s t _': 4, 'f a s t e r _': 3, 'tall _': 5, 'tall e r _': 4}
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '_', '[UNK]', 'ta', 'tal', 'tall']
---------------------
merge #4: ('f', 'a')
{'fa s t _': 4, 'fa s t e r _': 3, 'tall _': 5, 'tall e r _': 4}
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't',

In [15]:
# break words into the longest possible subwords from the symbols
def segment_BPE(tokens, symbols):
    """Greedily split each token into the longest subwords found in ``symbols``.

    Args:
        tokens: iterable of strings to segment.
        symbols: container of known subwords, tested with ``in``.

    Returns:
        A list with one string per token: the chosen subwords joined by
        single spaces. If at some position no prefix of the remaining text
        is in ``symbols``, a single ``'[UNK]'`` is appended and segmentation
        of that token stops there.
    """
    segmented = []
    for word in tokens:
        pieces = []
        pos, length = 0, len(word)
        while pos < length:
            # try the longest candidate first, shrinking from the right
            for stop in range(length, pos, -1):
                candidate = word[pos:stop]
                if candidate in symbols:
                    pieces.append(candidate)
                    pos = stop
                    break
            else:
                # nothing starting at pos is known: mark the rest as unknown
                pieces.append('[UNK]')
                break
        segmented.append(' '.join(pieces))
    return segmented

In [None]:
# Segment two words that are not in the training corpus
tokens = ['tallest_', 'fatter_']
segmented = segment_BPE(tokens, symbols)
print(segmented)