In [3]:
from collections import Counter, defaultdict

# Toy corpus
corpus_words = ["low", "lowest", "new", "newer", "wide", "widest"]

def word_to_symbols(word):
    return list(word) + ["_"]  # add end-of-word marker

corpus = [word_to_symbols(w) for w in corpus_words]

def get_pair_counts(corpus):
    pair_counts = Counter()
    for word in corpus:
        for i in range(len(word) - 1):
            pair = (word[i], word[i+1])
            pair_counts[pair] += 1
    return pair_counts

def merge_pair(corpus, pair):
    new_corpus = []
    bigram = list(pair)
    merged = "".join(bigram)
    for word in corpus:
        i = 0
        new_word = []
        while i < len(word):
            if i < len(word) - 1 and word[i] == bigram[0] and word[i+1] == bigram[1]:
                new_word.append(merged)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_corpus.append(new_word)
    return new_corpus

def vocab_from_corpus(corpus):
    vocab = set()
    for word in corpus:
        vocab.update(word)
    return vocab

# Learn BPE merges
num_merges = 10
corpus_bpe = corpus[:]

for step in range(1, num_merges + 1):
    pair_counts = get_pair_counts(corpus_bpe)
    if not pair_counts:
        break
    best_pair, best_count = pair_counts.most_common(1)[0]
    vocab_size = len(vocab_from_corpus(corpus_bpe))
    print(f"Step {step}: best pair = {best_pair}, count = {best_count}, vocab size = {vocab_size}")
    corpus_bpe = merge_pair(corpus_bpe, best_pair)
    merges = ["low", "new", "wide", "er_", "est_"]  # example learned subwords

def segment_word(word, merges):
    symbols = list(word) + ["_"]
    # Greedy longest-match over merges
    changed = True
    while changed:
        changed = False
        for m in sorted(merges, key=len, reverse=True):
            m_list = list(m)
            i = 0
            while i <= len(symbols) - len(m_list):
                if symbols[i:i+len(m_list)] == m_list:
                    symbols = symbols[:i] + ["".join(m_list)] + symbols[i+len(m_list):]
                    changed = True
                else:
                    i += 1
    return symbols

words_to_segment = ["new", "newer", "lowest", "widest", "earnest"]

for w in words_to_segment:
    print(w, "->", segment_word(w, merges))

Step 1: best pair = ('l', 'o'), count = 2, vocab size = 11
Step 2: best pair = ('lo', 'w'), count = 2, vocab size = 10
Step 3: best pair = ('e', 's'), count = 2, vocab size = 10
Step 4: best pair = ('es', 't'), count = 2, vocab size = 10
Step 5: best pair = ('est', '_'), count = 2, vocab size = 9
Step 6: best pair = ('n', 'e'), count = 2, vocab size = 9
Step 7: best pair = ('ne', 'w'), count = 2, vocab size = 9
Step 8: best pair = ('w', 'i'), count = 2, vocab size = 9
Step 9: best pair = ('wi', 'd'), count = 2, vocab size = 8
Step 10: best pair = ('low', '_'), count = 1, vocab size = 7
new -> ['new', '_']
newer -> ['new', 'er_']
lowest -> ['low', 'est_']
widest -> ['wide', 's', 't', '_']
earnest -> ['e', 'a', 'r', 'n', 'est_']
