In [1]:
import re, collections

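# Basics of Byte Pair Encoding
BPE starts from individual characters and repeatedly merges the most frequent adjacent pair of symbols into a new symbol.
## Reference: [Sennrich, Haddow & Birch (2016), *Neural Machine Translation of Rare Words with Subword Units*](https://arxiv.org/pdf/1508.07909.pdf)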

In [17]:
# Number of BPE merge operations to learn.
epochs = 10

# Toy corpus: each word is pre-split into space-separated characters, with
# '</w>' marking the end of the word; values are word frequencies.
dictionary = {
    'l o w </w>': 5,
    'l o w e r </w>': 2,
    'n e w e s t </w>': 6,
    'w i d e s t </w>': 3,
}

In [18]:
# Example: merge a single symbol pair inside one word.
# The pair must be an ordered tuple, not a set: a set has no defined
# iteration order (string hashing is randomized per process), so
# ' '.join({'l', 'o'}) may yield 'o l' and the merge would silently not match.
pair = ('l', 'o')
bigram = re.escape(' '.join(pair))  # 'l\ o'
# The lookarounds require the bigram to be bounded by whitespace or the
# string edges, so only whole space-separated symbols are merged.
p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')  # (?<!\S)l\ o(?!\S)
out = p.sub(repl=''.join(pair), string='l o w </w>')  # 'l o w </w>' -> 'lo w </w>'

print(out)

lo w </w>


In [19]:
def get_stats(dict):
    """Count adjacent symbol pairs across a space-segmented vocabulary.

    Args:
        dict: mapping from a word written as space-separated symbols
            (e.g. 'l o w </w>') to its frequency. The parameter name shadows
            the builtin ``dict`` inside this function; it is kept as-is for
            caller compatibility.

    Returns:
        collections.defaultdict(int) mapping each (left, right) pair of
        neighbouring symbols to the sum of the frequencies of the words it
        appears in, counted once per occurrence.
    """
    counts = collections.defaultdict(int)
    for segmented, weight in dict.items():
        tokens = segmented.split()
        # zip with the one-step-shifted list walks every neighbouring pair.
        for left, right in zip(tokens, tokens[1:]):
            counts[left, right] += weight
    return counts
            
def merge_dict(pair, dict_in):
    """Apply one BPE merge to every word of a vocabulary.

    Args:
        pair: the (left, right) symbols to fuse into one symbol.
        dict_in: mapping from space-separated symbol strings to frequencies.

    Returns:
        A new dict, in the same key order, in which every whole-symbol
        occurrence of "left right" is replaced by "leftright". Frequencies
        are copied unchanged; if two words become the same string, the later
        one's frequency is kept. ``dict_in`` itself is not modified.
    """
    merged = ''.join(pair)
    # Lookarounds anchor the match to whole symbols: the bigram must be
    # preceded and followed by whitespace or a string boundary.
    pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')
    return {pattern.sub(repl=merged, string=word): freq for word, freq in dict_in.items()}

#### collections.defaultdict(int)
 - behaves like a regular dictionary, but a missing key is automatically created with the default value `int()` (i.e. `0`) the first time it is accessed, so `pairs[key] += freq` works without initializing the key first.
#### re.escape
 - escapes regex metacharacters by prefixing them with a backslash, so the string is matched literally (e.g. `'l o'` becomes the pattern `l\ o`).
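#### re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
 - `(?<!\S)` (negative lookbehind) and `(?!\S)` (negative lookahead) require the match to be preceded and followed by whitespace or a string boundary, i.e. to consist of whole space-separated tokens.
 - e.g. `(?<!\S)hello world(?!\S)` matches in `hello world` and `say hello world now`, but not in `hello world!`, `hellohello world`, or `helloworld`.
 - In BPE this ensures only whole symbols are merged: for the pair `('s', 't')`, the `s t` inside `es t` is not matched, because `s` is preceded by `e`.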
#### re.sub(pattern, repl, string)
 - Returns the string obtained by replacing the leftmost non-overlapping occurrences of `pattern` in `string` with `repl`; if the pattern isn't found, `string` is returned unchanged. The code calls the equivalent method on a compiled pattern, `p.sub(repl, string)`.

In [20]:
# Learn up to `epochs` merges. Each iteration fuses the currently most
# frequent adjacent symbol pair (on ties, max() picks the pair encountered
# first) and records it in two tables:
#   bpe_codes:         pair -> merge rank (0 = learned first)
#   bpe_codes_reverse: merged symbol string -> the pair it was built from
# NOTE: this cell rebinds `dictionary`, so re-running it continues merging
# from the already-merged vocabulary; re-run the data cell first.
bpe_codes = {}
bpe_codes_reverse = {}

for i in range(epochs):
    pairs = get_stats(dictionary)
    if not pairs:
        # Every word is already a single symbol; max() on an empty mapping
        # would raise ValueError, so stop early instead.
        print('No more pairs to merge after {} iterations.'.format(i))
        break
    most_freq = max(pairs, key=pairs.get)
    dictionary = merge_dict(most_freq, dictionary)

    bpe_codes[most_freq] = i
    bpe_codes_reverse[''.join(most_freq)] = most_freq

    print('Iter: {}'.format(i))
    print('Most frequent pair: {}'.format(most_freq))
    print('Dictionary: {}'.format(dictionary))

Iter: 0
Most frequent pair: ('e', 's')
Dictionary: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}
Iter: 1
Most frequent pair: ('es', 't')
Dictionary: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3}
Iter: 2
Most frequent pair: ('est', '</w>')
Dictionary: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
Iter: 3
Most frequent pair: ('l', 'o')
Dictionary: {'lo w </w>': 5, 'lo w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
Iter: 4
Most frequent pair: ('lo', 'w')
Dictionary: {'low </w>': 5, 'low e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
Iter: 5
Most frequent pair: ('n', 'e')
Dictionary: {'low </w>': 5, 'low e r </w>': 2, 'ne w est</w>': 6, 'w i d est</w>': 3}
Iter: 6
Most frequent pair: ('ne', 'w')
Dictionary: {'low </w>': 5, 'low e r </w>': 2, 'new est</w>': 6, 'w i d est</w>': 3}
Iter: 7
Most frequent pair: ('new', 'est</w>')
Dictionary: {'low </w>': 5, 'low e r </w>': 2,

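#### key=pairs.get
 - The `key` argument specifies how to compute the value used to compare each element of the iterable.
 - `pairs.get` is a bound method that, when called with a key, returns the value associated with that key in the `pairs` dictionary.
 - Passing `pairs.get` as the key function makes `max()` return the key of `pairs` whose count is highest (iterating a dict yields its keys).
 - On ties, `max()` returns the first maximal key encountered, i.e. the earliest in the dict's insertion order. This is why `('w', 'i')` is chosen in iteration 9 over the other pairs that also have count 3.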

In [23]:
# Show the learned merge table and its reverse lookup, one per line.
print(bpe_codes, bpe_codes_reverse, sep='\n')


{('e', 's'): 0, ('es', 't'): 1, ('est', '</w>'): 2, ('l', 'o'): 3, ('lo', 'w'): 4, ('n', 'e'): 5, ('ne', 'w'): 6, ('new', 'est</w>'): 7, ('low', '</w>'): 8, ('w', 'i'): 9}
{'es': ('e', 's'), 'est': ('es', 't'), 'est</w>': ('est', '</w>'), 'lo': ('l', 'o'), 'low': ('lo', 'w'), 'ne': ('n', 'e'), 'new': ('ne', 'w'), 'newest</w>': ('new', 'est</w>'), 'low</w>': ('low', '</w>'), 'wi': ('w', 'i')}
