In [1]:
import re, collections

# Basics of Byte Pair Encoding
## https://arxiv.org/pdf/1508.07909.pdf

In [2]:
epochs = 10

# data
dictionary = {'l o w </w>' : 5,
              'l o w e r </w>' : 2,
              'n e w e s t </w>':6,
              'w i d e s t </w>':3
              }

In [3]:
# example 
pair = {'l','o'}
bigram = re.escape(' '.join(pair)) # 'l\ o'
p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)') # (?<!\S)l\ o(?!\\S)
out = p.sub(repl=''.join(pair), string='l o w </w>') # l o w </w> -> low </w>

print(out)

l o w </w>


In [4]:
def get_stats(dict):
    pairs = collections.defaultdict(int)
    
    for word, freq in dict.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
            
    return pairs
            
def merge_dict(pair, dict_in):
    dict_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    
    for word in dict_in:
        word_out = p.sub(repl=''.join(pair), string=word)
        dict_out[word_out] = dict_in[word]
        
    return dict_out

#### collections.defaultdict(int)
 - acts like a regular dictionary but will automatically initialize non-existent keys with a default value of 0 upon first access.
#### re.escape
 - escape special characters = using \ to escape special characters.
#### re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
 - (?<!\\S)hello world(?!\\S) -> matches hello world, hello world!, hellohello world, but not helloworld.
### re.sub(pattern, repl, string)
 - return the string obtained by replacing the leftmost non-overlapping occurrences of pattern in string by the replacement repl. If the pattern isn’t found, string is returned unchanged. 

In [5]:
bpe_codes = {}
bpe_codes_reverse = {}

for i in range(epochs):
    pairs = get_stats(dictionary)
    most_freq = max(pairs, key=pairs.get)
    dictionary = merge_dict(most_freq, dictionary)
    
    bpe_codes[most_freq] = i
    bpe_codes_reverse[most_freq[0] + most_freq[1]] = most_freq
    
    print('Iter: {}'.format(i))
    print('Most frequent pair: {}'.format(most_freq))
    print('Dictionary: {}'.format(dictionary))

Iter: 0
Most frequent pair: ('e', 's')
Dictionary: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}
Iter: 1
Most frequent pair: ('es', 't')
Dictionary: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3}
Iter: 2
Most frequent pair: ('est', '</w>')
Dictionary: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
Iter: 3
Most frequent pair: ('l', 'o')
Dictionary: {'lo w </w>': 5, 'lo w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
Iter: 4
Most frequent pair: ('lo', 'w')
Dictionary: {'low </w>': 5, 'low e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}
Iter: 5
Most frequent pair: ('n', 'e')
Dictionary: {'low </w>': 5, 'low e r </w>': 2, 'ne w est</w>': 6, 'w i d est</w>': 3}
Iter: 6
Most frequent pair: ('ne', 'w')
Dictionary: {'low </w>': 5, 'low e r </w>': 2, 'new est</w>': 6, 'w i d est</w>': 3}
Iter: 7
Most frequent pair: ('new', 'est</w>')
Dictionary: {'low </w>': 5, 'low e r </w>': 2,

#### key=pairs.get
 - specifies how to compute the 'value' for each element in the iterable for comparison purposes. 
 - pairs.get is a method that, when called with a key, returns the value associated with that key in the pairs dictionary.
 - By passing pairs.get as the key function, you're telling max() to find the key in pairs whose associated value is the highest.

In [6]:
print(bpe_codes) # key: pair, value: iteration : record the order of merging
print(bpe_codes_reverse)

{('e', 's'): 0, ('es', 't'): 1, ('est', '</w>'): 2, ('l', 'o'): 3, ('lo', 'w'): 4, ('n', 'e'): 5, ('ne', 'w'): 6, ('new', 'est</w>'): 7, ('low', '</w>'): 8, ('w', 'i'): 9}
{'es': ('e', 's'), 'est': ('es', 't'), 'est</w>': ('est', '</w>'), 'lo': ('l', 'o'), 'low': ('lo', 'w'), 'ne': ('n', 'e'), 'new': ('ne', 'w'), 'newest</w>': ('new', 'est</w>'), 'low</w>': ('low', '</w>'), 'wi': ('w', 'i')}


# Handling OOV

In [14]:
def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as a tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def encode(orig):
    """Encode word based on list of BPE merge operations, which are applied consecutively"""

    word = tuple(orig) + ('</w>',)
    print(word)
    pairs = get_pairs(word)

    if not pairs:
        return orig

    iteration = 0
    while True:
        iteration += 1
        print("Iteration {}".format(iteration))

        print("bigrams in the word: {}".format(pairs))
        bigram = min(pairs, key = lambda pair: bpe_codes.get(pair, float('inf')))
        print("candidate for merging: {}".format(bigram))
        if bigram not in bpe_codes:
            print("Candidate not in BPE merges, algorithm stops.")
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            print("in")
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                print(j)
                print(new_word)
                i = j
            except:
                new_word.extend(word[i:])
                print("except")
                print(i)
                print(word[i:])
                print(new_word)
                break

            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                print("a")
                new_word.append(first+second)
                i += 2
            else:
                print("b")
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        print("word after merging: {}".format(word))
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)

    # 특별 토큰인 </w>는 출력하지 않는다.
    if word[-1] == '</w>':
        word = word[:-1]
    elif word[-1].endswith('</w>'):
        word = word[:-1] + (word[-1].replace('</w>',''),)

    return word

In [15]:
encode("lowest")

('l', 'o', 'w', 'e', 's', 't', '</w>')
Iteration 1
bigrams in the word: {('o', 'w'), ('t', '</w>'), ('l', 'o'), ('e', 's'), ('w', 'e'), ('s', 't')}
candidate for merging: ('e', 's')
in
3
['l', 'o', 'w']
a
in
except
5
('t', '</w>')
['l', 'o', 'w', 'es', 't', '</w>']
word after merging: ('l', 'o', 'w', 'es', 't', '</w>')
Iteration 2
bigrams in the word: {('o', 'w'), ('t', '</w>'), ('w', 'es'), ('es', 't'), ('l', 'o')}
candidate for merging: ('es', 't')
in
3
['l', 'o', 'w']
a
in
except
5
('</w>',)
['l', 'o', 'w', 'est', '</w>']
word after merging: ('l', 'o', 'w', 'est', '</w>')
Iteration 3
bigrams in the word: {('o', 'w'), ('est', '</w>'), ('l', 'o'), ('w', 'est')}
candidate for merging: ('est', '</w>')
in
3
['l', 'o', 'w']
a
word after merging: ('l', 'o', 'w', 'est</w>')
Iteration 4
bigrams in the word: {('o', 'w'), ('l', 'o'), ('w', 'est</w>')}
candidate for merging: ('l', 'o')
in
0
[]
a
in
except
2
('w', 'est</w>')
['lo', 'w', 'est</w>']
word after merging: ('lo', 'w', 'est</w>')
Itera

('low', 'est')