In [1]:
import re, collections
from IPython.display import display, Markdown, Latex

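- `num_merges`: the number of BPE (Byte Pair Encoding) merge operations to learn
- `dictionary`: the training vocabulary — each word is split into space-separated symbols ending with the end-of-word marker `</w>`, and mapped to its frequency in the corpus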

In [12]:
# Number of BPE merge operations to learn.
num_merges = 10

# Toy training corpus: every word is pre-split into space-separated symbols
# and terminated by the end-of-word marker '</w>'; values are word frequencies.
dictionary = dict([
    ('l o w </w>', 5),
    ('l o w e r </w>', 2),
    ('n e w e s t </w>', 6),
    ('w i d e s t </w>', 3),
])

In [13]:
def get_stats(dictionary, verbose=True):
    """Count how often each pair of adjacent symbols occurs in the vocabulary.

    Parameters
    ----------
    dictionary : dict[str, int]
        Maps a word, written as whitespace-separated symbols, to its frequency.
    verbose : bool, default True
        When true, print the resulting pair counts (the notebook relies on
        this for its step-by-step trace). Pass False to use it silently.

    Returns
    -------
    collections.defaultdict
        Maps each adjacent ``(left, right)`` symbol pair to the sum of the
        frequencies of the words it occurs in (counted once per occurrence).
    """
    pairs = collections.defaultdict(int)

    for word, freq in dictionary.items():
        symbols = word.split()
        # Every occurrence of a pair contributes the whole word's frequency.
        for x, y in zip(symbols, symbols[1:]):
            pairs[x, y] += freq

    if verbose:
        print(f"current frequency of pairs : {dict(pairs)}")
    return pairs

def merge_dictionary(pair, v_in):
    """Apply one BPE merge to every word of a vocabulary.

    Parameters
    ----------
    pair : tuple[str, str]
        The two adjacent symbols to merge.
    v_in : dict[str, int]
        Vocabulary mapping space-separated symbol strings to frequencies.

    Returns
    -------
    dict[str, int]
        A new vocabulary in which every standalone occurrence of
        ``"<left> <right>"`` becomes ``"<left><right>"``; frequencies are
        carried over unchanged.
    """
    merged_symbol = ''.join(pair)
    # The look-arounds require whitespace (or a string edge) on both sides, so
    # only whole symbols match, never a fragment of a longer symbol.
    # re.escape guards against symbols containing regex metacharacters.
    pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')
    return {pattern.sub(merged_symbol, word): freq for word, freq in v_in.items()}

In [14]:
# Learn the BPE merge table: at each step, merge the most frequent adjacent
# symbol pair across the frequency-weighted vocabulary.
#   bpe_codes         : merged pair -> merge rank (0 = learned first); `encode`
#                       uses this rank as the merge priority.
#   bpe_codes_reverse : merged symbol string -> the pair it was built from.
# NOTE: this cell rebinds `dictionary` to the merged vocabulary, so re-running
# it without first re-running the cell that defines `dictionary` continues
# merging from the already-merged state.
bpe_codes, bpe_codes_reverse = {}, {}

for i in range(num_merges):
    display(Markdown(f"### Iteration {i + 1}"))
    
    pairs = get_stats(dictionary)
    if not pairs:
        # Every word is already a single symbol; max() over an empty mapping
        # would raise ValueError, so stop learning early.
        print("no more pairs to merge")
        break
    # Ties are broken by first occurrence (max keeps the earliest maximum).
    best = max(pairs, key=pairs.get)
    
    dictionary = merge_dictionary(best, dictionary)
    
    bpe_codes[best] = i
    bpe_codes_reverse[best[0] + best[1]] = best
    
    print(f"new merge: {best}")
    print(f"dictionary: {dictionary}")

### Iteration 1

current frequency of pairs : {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 8, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('e', 's'): 9, ('s', 't'): 9, ('t', '</w>'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3}
new merge: ('e', 's')
dictionary: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}


### Iteration 2

current frequency of pairs : {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'es'): 6, ('es', 't'): 9, ('t', '</w>'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'es'): 3}
new merge: ('es', 't')
dictionary: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3}


### Iteration 3

current frequency of pairs : {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est'): 6, ('est', '</w>'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3}
new merge: ('est', '</w>')
dictionary: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}


### Iteration 4

current frequency of pairs : {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}
new merge: ('l', 'o')
dictionary: {'lo w </w>': 5, 'lo w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}


### Iteration 5

current frequency of pairs : {('lo', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}
new merge: ('lo', 'w')
dictionary: {'low </w>': 5, 'low e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}


### Iteration 6

current frequency of pairs : {('low', '</w>'): 5, ('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}
new merge: ('n', 'e')
dictionary: {'low </w>': 5, 'low e r </w>': 2, 'ne w est</w>': 6, 'w i d est</w>': 3}


### Iteration 7

current frequency of pairs : {('low', '</w>'): 5, ('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('ne', 'w'): 6, ('w', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}
new merge: ('ne', 'w')
dictionary: {'low </w>': 5, 'low e r </w>': 2, 'new est</w>': 6, 'w i d est</w>': 3}


### Iteration 8

current frequency of pairs : {('low', '</w>'): 5, ('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('new', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}
new merge: ('new', 'est</w>')
dictionary: {'low </w>': 5, 'low e r </w>': 2, 'newest</w>': 6, 'w i d est</w>': 3}


### Iteration 9

current frequency of pairs : {('low', '</w>'): 5, ('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}
new merge: ('low', '</w>')
dictionary: {'low</w>': 5, 'low e r </w>': 2, 'newest</w>': 6, 'w i d est</w>': 3}


### Iteration 10

current frequency of pairs : {('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}
new merge: ('w', 'i')
dictionary: {'low</w>': 5, 'low e r </w>': 2, 'newest</w>': 6, 'wi d est</w>': 3}


In [15]:
def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in a word.

    ``word`` is a sequence of symbols, where each symbol may be a string of
    any length (e.g. ``('lo', 'w', '</w>')``). Fewer than two symbols yield
    an empty set.
    """
    # Pair each symbol with its right-hand neighbour.
    return set(zip(word, word[1:]))

def _merge_bigram(word, bigram):
    """Return a copy of ``word`` with every occurrence of ``bigram`` merged.

    Scans the tuple of symbols ``word`` left to right; each time ``bigram[0]``
    is immediately followed by ``bigram[1]``, the two are replaced by their
    concatenation. Occurrences do not overlap. Returns a new tuple.
    """
    first, second = bigram
    merged = []
    i = 0

    while i < len(word):
        try:
            j = word.index(first, i)
        except ValueError:
            # No further occurrence of `first`: copy the rest unchanged.
            merged.extend(word[i:])
            break
        merged.extend(word[i:j])
        i = j

        if i < len(word) - 1 and word[i + 1] == second:
            merged.append(first + second)
            i += 2
        else:
            merged.append(word[i])
            i += 1

    return tuple(merged)


def encode(orig):
    """
    Encode word based on list of BPE merge operations, which are applied consecutively.

    The word is split into characters plus an end-of-word marker ``'</w>'``.
    Then, repeatedly, the adjacent pair with the lowest rank in the
    module-level ``bpe_codes`` (the merge learned earliest) is merged. This
    stops when the best candidate is not a known merge or the word has
    become a single symbol. Each step is traced with ``display``/``print``.

    Parameters
    ----------
    orig : str
        The word to encode.

    Returns
    -------
    tuple[str, ...]
        The subword units, with the ``'</w>'`` marker removed (whether it is a
        separate final symbol or a suffix of the final symbol). An empty input
        yields an empty tuple.
    """
    word = tuple(orig) + ('</w>', )
    # A double-backtick code span shows the tuple literally. Inside raw HTML
    # (e.g. <tt>...</tt>) the renderer may treat '</w>' as a tag and hide it.
    display(Markdown(f'__word split into characters:__ `` {word} ``'))
    
    pairs = get_pairs(word)
    
    if not pairs:
        # Only an empty `orig` gets here (word == ('</w>',)); return a tuple so
        # the result type matches the normal path.
        return tuple(orig)
    
    iteration = 0
    
    while True:
        iteration += 1
        display(Markdown(f"__iteration {iteration}:__"))
        
        print(f"bigrams in the word: {pairs}")
        # Prefer the pair learned earliest; pairs never learned rank as +inf.
        bigram = min(pairs, key = lambda pair: bpe_codes.get(pair, float('inf')))
        print(f"candidate for merging: {bigram}")
        
        if bigram not in bpe_codes:
            display(Markdown("__Candidate not in BPE merges, algorithm stop__"))
            break
        
        word = _merge_bigram(word, bigram)
        print(f'word after merging: {word}')
        if len(word) == 1:
            break
        pairs = get_pairs(word)
    
    # Strip the end-of-word marker from the end only, whether it is still a
    # separate symbol or has been merged into the last subword.
    end_marker = '</w>'
    if word[-1] == end_marker:
        word = word[:-1]
    elif word[-1].endswith(end_marker):
        word = word[:-1] + (word[-1][:-len(end_marker)], )
        
    return word
    

In [16]:
# 'loki' is unseen in training: only learned merges such as ('l', 'o') apply; the rest stays as characters.
encode(orig='loki')

__word split into characters:__ <tt>('l', 'o', 'k', 'i', '</w>')<tt>

__iteration 1:__

bigrams in the word: {('o', 'k'), ('k', 'i'), ('l', 'o'), ('i', '</w>')}
candidate for merging: ('l', 'o')
word after merging: ('lo', 'k', 'i', '</w>')


__iteration 2:__

bigrams in the word: {('lo', 'k'), ('k', 'i'), ('i', '</w>')}
candidate for merging: ('lo', 'k')


__Candidate not in BPE merges, algorithm stop__

('lo', 'k', 'i')

In [17]:
# Unseen word built from pieces learned from 'low' and 'newest'/'widest'.
encode(orig="lowest")

__word split into characters:__ <tt>('l', 'o', 'w', 'e', 's', 't', '</w>')<tt>

__iteration 1:__

bigrams in the word: {('s', 't'), ('e', 's'), ('o', 'w'), ('w', 'e'), ('l', 'o'), ('t', '</w>')}
candidate for merging: ('e', 's')
word after merging: ('l', 'o', 'w', 'es', 't', '</w>')


__iteration 2:__

bigrams in the word: {('o', 'w'), ('es', 't'), ('l', 'o'), ('w', 'es'), ('t', '</w>')}
candidate for merging: ('es', 't')
word after merging: ('l', 'o', 'w', 'est', '</w>')


__iteration 3:__

bigrams in the word: {('est', '</w>'), ('w', 'est'), ('o', 'w'), ('l', 'o')}
candidate for merging: ('est', '</w>')
word after merging: ('l', 'o', 'w', 'est</w>')


__iteration 4:__

bigrams in the word: {('w', 'est</w>'), ('o', 'w'), ('l', 'o')}
candidate for merging: ('l', 'o')
word after merging: ('lo', 'w', 'est</w>')


__iteration 5:__

bigrams in the word: {('w', 'est</w>'), ('lo', 'w')}
candidate for merging: ('lo', 'w')
word after merging: ('low', 'est</w>')


__iteration 6:__

bigrams in the word: {('low', 'est</w>')}
candidate for merging: ('low', 'est</w>')


__Candidate not in BPE merges, algorithm stop__

('low', 'est')

In [18]:
# 'low' is a learned unit, but no merge covers 'ing', so those stay as characters.
encode(orig="lowing")

__word split into characters:__ <tt>('l', 'o', 'w', 'i', 'n', 'g', '</w>')<tt>

__iteration 1:__

bigrams in the word: {('w', 'i'), ('n', 'g'), ('g', '</w>'), ('i', 'n'), ('o', 'w'), ('l', 'o')}
candidate for merging: ('l', 'o')
word after merging: ('lo', 'w', 'i', 'n', 'g', '</w>')


__iteration 2:__

bigrams in the word: {('w', 'i'), ('g', '</w>'), ('n', 'g'), ('lo', 'w'), ('i', 'n')}
candidate for merging: ('lo', 'w')
word after merging: ('low', 'i', 'n', 'g', '</w>')


__iteration 3:__

bigrams in the word: {('n', 'g'), ('g', '</w>'), ('low', 'i'), ('i', 'n')}
candidate for merging: ('n', 'g')


__Candidate not in BPE merges, algorithm stop__

('low', 'i', 'n', 'g')

In [19]:
# No learned merge applies to any pair here, so the word stays fully split.
encode(orig="highing")

__word split into characters:__ <tt>('h', 'i', 'g', 'h', 'i', 'n', 'g', '</w>')<tt>

__iteration 1:__

bigrams in the word: {('n', 'g'), ('g', '</w>'), ('h', 'i'), ('i', 'n'), ('g', 'h'), ('i', 'g')}
candidate for merging: ('n', 'g')


__Candidate not in BPE merges, algorithm stop__

('h', 'i', 'g', 'h', 'i', 'n', 'g')