In [4]:
from IPython.display import Image

- References
    - Byte Pair Encoding: The Dark Horse of Modern NLP — https://towardsdatascience.com/byte-pair-encoding-the-dark-horse-of-modern-nlp-eb36c7df4f10

In [6]:
# Illustration for the BPE walkthrough below (remote image; needs network access to render).
Image(
    width=300,
    url='https://miro.medium.com/v2/resize:fit:1400/format:webp/1*x1Y_n3sXGygUPSdfXTm9pQ.gif',
)

In [7]:
import re
from collections import Counter, defaultdict


def build_vocab(corpus: str) -> dict:
    """Step 1. Build the initial symbol vocabulary from a text corpus.

    The corpus is split on whitespace. Each word is spelled out as its
    individual characters separated by single spaces, followed by the
    end-of-word marker ``</w>``. For example, ``"low"`` becomes
    ``"l o w </w>"``.

    Args:
        corpus: Raw text; words are separated by any whitespace.

    Returns:
        A ``collections.Counter`` mapping each spelled-out word to the number
        of times it occurs in ``corpus``.
    """
    # Appending '</w>' as its own symbol keeps word-final pairs distinct
    # from word-internal ones when pairs are counted later.
    spelled_out = (" ".join(list(w) + ["</w>"]) for w in corpus.split())
    return Counter(spelled_out)


def get_stats(vocab: dict) -> dict:
    """Step 2. Count how often each pair of adjacent symbols occurs.

    Args:
        vocab: Mapping from a space-separated symbol sequence to its frequency.

    Returns:
        A ``defaultdict(int)`` keyed by ``(left_symbol, right_symbol)`` tuples.
        Each adjacent occurrence of a pair contributes the frequency of the
        entry it appears in.
    """
    pair_counts = defaultdict(int)
    for entry, count in vocab.items():
        symbols = entry.split()
        # zip with a one-step offset walks every adjacent (left, right) pair
        for left, right in zip(symbols, symbols[1:]):
            pair_counts[left, right] += count
    return pair_counts


def merge_vocab(pair: tuple, v_in: dict) -> dict:
    """Step 3. Merge every occurrence of ``pair`` into a single symbol.

    Args:
        pair: ``(left, right)`` symbols to fuse, typically the most frequent
            pair returned by ``get_stats``.
        v_in: Mapping from a space-separated symbol sequence to its frequency.

    Returns:
        A new dict with the same frequencies. In each key, every standalone
        occurrence of ``left right`` is replaced by ``leftright``. ``v_in`` is
        not modified.
    """
    bigram = re.escape(' '.join(pair))
    # (?<!\S) / (?!\S) anchor the match to whole symbols. The pair must be
    # bounded by whitespace or the string edges, so ('a', 'b') does not
    # match inside 'xa b' or 'a bc'.
    pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    merged = ''.join(pair)

    v_out = {}
    for word, frequency in v_in.items():
        # A callable replacement inserts `merged` literally. A plain string
        # would be treated as a template, so a symbol containing a backslash
        # (e.g. '\\' or '\\1') would raise re.error or expand a group reference.
        v_out[pattern.sub(lambda _match: merged, word)] = frequency

    return v_out

In [12]:
corpus = 'aaabdaaabac'
vocab = build_vocab(corpus)  # Step 1: each word spelled out as symbols + '</w>'

num_merges = 3  # hyperparameter: upper bound on the number of merges learned
for i in range(num_merges):
    # Step 2: frequency of every adjacent symbol pair
    pairs = get_stats(vocab)
    if not pairs:
        # no adjacent pairs remain, so there is nothing left to merge
        break

    # Step 3: merge the most frequent pair everywhere (max() keeps the first on ties)
    best, best_count = max(pairs.items(), key=lambda item: item[1])
    print(f"{best} {best_count}")
    vocab = merge_vocab(best, vocab)
    print(f"{i} {vocab}\n")

('a', 'a') 4
0 {'aa a b d aa a b a c </w>': 1}

('aa', 'a') 2
1 {'aaa b d aaa b a c </w>': 1}

('aaa', 'b') 2
2 {'aaab d aaab a c </w>': 1}



In [8]:
corpus = 'low lower newest wildest'
vocab = build_vocab(corpus)  # Step 1: characters plus '</w>' for every word

num_merges = 6  # hyperparameter: maximum number of merges to learn
for i in range(num_merges):
    pairs = get_stats(vocab)  # Step 2: count adjacent symbol pairs
    if len(pairs) == 0:
        break  # every word is a single symbol; stop early

    # Step 3: take the most frequent pair (first one wins on ties) and merge it
    best = max(pairs, key=lambda pair: pairs[pair])
    print(best, pairs[best], sep=' ')
    vocab = merge_vocab(best, vocab)
    print(i, vocab, end='\n\n')

('l', 'o') 2
0 {'lo w </w>': 1, 'lo w e r </w>': 1, 'n e w e s t </w>': 1, 'w i l d e s t </w>': 1}

('lo', 'w') 2
1 {'low </w>': 1, 'low e r </w>': 1, 'n e w e s t </w>': 1, 'w i l d e s t </w>': 1}

('e', 's') 2
2 {'low </w>': 1, 'low e r </w>': 1, 'n e w es t </w>': 1, 'w i l d es t </w>': 1}

('es', 't') 2
3 {'low </w>': 1, 'low e r </w>': 1, 'n e w est </w>': 1, 'w i l d est </w>': 1}

('est', '</w>') 2
4 {'low </w>': 1, 'low e r </w>': 1, 'n e w est</w>': 1, 'w i l d est</w>': 1}

('low', '</w>') 1
5 {'low</w>': 1, 'low e r </w>': 1, 'n e w est</w>': 1, 'w i l d est</w>': 1}

