# Demonstration of byte-pair encoding

Code adapted from https://leimao.github.io/blog/Byte-Pair-Encoding/


In [7]:
import collections
import re
from nltk.corpus import brown
import nltk

In [2]:
def get_vocab(word_gen):
    """Build the initial character-level vocabulary from an iterable of words

    Each word becomes a key made of its characters separated by single
    spaces and terminated by the end-of-word symbol '</w>'; the value is
    the number of times that word was seen.
    """
    counts = collections.defaultdict(int)
    for raw_word in word_gen:
        key = ' '.join(raw_word) + ' </w>'
        counts[key] += 1
    return counts

def get_stats(vocab):
    """Count adjacent symbol pairs across the vocabulary

    Every pair of neighbouring symbols in a vocabulary key contributes the
    frequency of that key to the pair's total. Returns a defaultdict keyed
    by (left, right) tuples.
    """
    pair_counts = collections.defaultdict(int)
    for entry, count in vocab.items():
        pieces = entry.split()
        # Pair each symbol with its right-hand neighbour.
        for left, right in zip(pieces, pieces[1:]):
            pair_counts[left, right] += count
    return pair_counts

def merge_vocab(pair, v_in):
    """Merge two tokens and update the resulting vocabulary

    Args:
        pair: Two symbols (strings) that should be fused into one symbol.
        v_in: Mapping from space-separated symbol strings to frequencies.

    Returns:
        A new dict whose keys have every whole-symbol occurrence of the
        pair replaced by the concatenated symbol. ``v_in`` is not modified.
    """
    merged = ''.join(pair)
    # The lookarounds require whitespace (or the string boundary) on both
    # sides, so merging ('a', 'b') never touches the tail of 'xa' or the
    # head of 'bc'.
    pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')
    v_out = {}
    for word, freq in v_in.items():
        # A callable replacement is inserted verbatim; a plain replacement
        # string would have any backslash in the symbols processed as a
        # regex template escape.
        w_out = pattern.sub(lambda _match: merged, word)
        # Accumulate instead of assigning so that entries collapsing onto
        # the same key keep their combined frequency.
        v_out[w_out] = v_out.get(w_out, 0) + freq
    return v_out

def get_tokens_from_vocab(vocab):
    """Derive the token inventory and per-word segmentation from a vocabulary

    Returns a pair: a defaultdict mapping each token to its total frequency
    over all vocabulary entries, and a dict mapping each word (its tokens
    concatenated without spaces) to the list of tokens it is split into.
    """
    token_counts = collections.defaultdict(int)
    segmentations = {}
    for entry in vocab:
        count = vocab[entry]
        pieces = entry.split()
        for piece in pieces:
            token_counts[piece] += count
        segmentations[''.join(pieces)] = pieces
    return token_counts, segmentations

def measure_token_length(token):
    """Return the length of a token, counting a trailing '</w>' as one character
    """
    end_marker = '</w>'
    if not token.endswith(end_marker):
        return len(token)
    # Drop the marker's own characters and count it as a single unit instead.
    return len(token) - len(end_marker) + 1


In [8]:
# Download the Brown corpus through NLTK.
nltk.download('brown')

# Keep only the purely alphabetic corpus words, lower-cased, and turn them
# into the initial character-level vocabulary.
vocab = get_vocab(
    word.lower() for word in brown.words() if word.isalpha()
)

# Report the token inventory before any merge has been applied.
print('==========')
print('Tokens Before BPE')
tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
print('All tokens: {}'.format(tokens_frequencies.keys()))
print('Number of tokens: {}'.format(len(tokens_frequencies)))
print('==========')


[nltk_data] Downloading package brown to
[nltk_data]     /Users/BharathBandaru/nltk_data...
[nltk_data]   Unzipping corpora/brown.zip.


Tokens Before BPE
All tokens: dict_keys(['t', 'h', 'e', '</w>', 'f', 'u', 'l', 'o', 'n', 'c', 'y', 'g', 'r', 'a', 'd', 'j', 's', 'i', 'v', 'p', 'm', 'k', 'x', 'w', 'b', 'z', 'q'])
Number of tokens: 27


In [9]:
def step_bpe(v):
    """Perform a single BPE merge on the vocabulary

    Args:
        v: Mapping from space-separated symbol strings to frequencies.

    Returns:
        A ``(vocab, best)`` tuple. ``best`` is the most frequent adjacent
        symbol pair and ``vocab`` is the vocabulary after merging it. When
        no pairs are left to merge, ``v`` is returned unchanged together
        with ``None`` so callers can always unpack two values.
    """
    pairs = get_stats(v)
    if not pairs:
        # Keep the return shape consistent with the normal path; returning
        # the bare vocabulary here would break `vocab, best = step_bpe(...)`.
        return v, None
    # On ties, max() keeps the first pair encountered in pairs.
    best = max(pairs, key=pairs.get)
    return merge_vocab(best, v), best

In [10]:
# Run a small number of merges, printing the pair chosen at each step and
# the token inventory that results from it.
num_merges = 10
for i in range(num_merges):
    print('Iter: {}'.format(i))
    # Rebinds the global vocab to the vocabulary after this merge.
    vocab, best = step_bpe(vocab)
    # Recompute the inventory so the printout reflects the merged vocab.
    tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
    print('Best pair: {}'.format(best))
    print('All tokens: {}'.format(tokens_frequencies.keys()))
    print('Number of tokens: {}'.format(len(tokens_frequencies.keys())))
    print('==========')

Iter: 0
Best pair: ('e', '</w>')
All tokens: dict_keys(['t', 'h', 'e</w>', 'f', 'u', 'l', 'o', 'n', '</w>', 'c', 'y', 'g', 'r', 'a', 'd', 'j', 's', 'i', 'v', 'e', 'p', 'm', 'k', 'x', 'w', 'b', 'z', 'q'])
Number of tokens: 28
Iter: 1
Best pair: ('t', 'h')
All tokens: dict_keys(['th', 'e</w>', 'f', 'u', 'l', 't', 'o', 'n', '</w>', 'c', 'y', 'g', 'r', 'a', 'd', 'j', 's', 'i', 'v', 'e', 'p', 'm', 'k', 'x', 'w', 'h', 'b', 'z', 'q'])
Number of tokens: 29
Iter: 2
Best pair: ('s', '</w>')
All tokens: dict_keys(['th', 'e</w>', 'f', 'u', 'l', 't', 'o', 'n', '</w>', 'c', 'y', 'g', 'r', 'a', 'd', 'j', 's', 'i', 'v', 'e', 'p', 'm', 's</w>', 'k', 'x', 'w', 'h', 'b', 'z', 'q'])
Number of tokens: 30
Iter: 3
Best pair: ('d', '</w>')
All tokens: dict_keys(['th', 'e</w>', 'f', 'u', 'l', 't', 'o', 'n', '</w>', 'c', 'y', 'g', 'r', 'a', 'd</w>', 'j', 's', 'i', 'd', 'v', 'e', 'p', 'm', 's</w>', 'k', 'x', 'w', 'h', 'b', 'z', 'q'])
Number of tokens: 31
Iter: 4
Best pair: ('t', '</w>')
All tokens: dict_keys(['t

In [11]:
# Show the tokens from the most recent get_tokens_from_vocab() call.
list(tokens_frequencies.keys())

['the</w>',
 'f',
 'u',
 'l',
 't',
 'on',
 '</w>',
 'c',
 'o',
 'n',
 'y',
 'g',
 'r',
 'an',
 'd</w>',
 'j',
 's',
 'a',
 'i',
 'd',
 'in',
 'v',
 'e',
 't</w>',
 'p',
 'm',
 'e</w>',
 'th',
 's</w>',
 'k',
 'er',
 'x',
 'w',
 'h',
 'b',
 'z',
 'q']

In [12]:
# Apply 500 merges to the current vocab; to keep the output short, only
# every 100th iteration (i = 0, 100, 200, ...) is reported.
num_merges = 500
for i in range(num_merges):
    vocab, best = step_bpe(vocab)
    if i % 100 == 0:
        print('Iter: {}'.format(i))
        print('Best pair: {}'.format(best))
        print('==========')

Iter: 0
Best pair: ('y', '</w>')


In [8]:
# Rebuild the token inventory from the current vocab, then rank the tokens
# longest first, breaking ties in length by higher frequency.
tokens_frequencies, vocab_tokenization = get_tokens_from_vocab(vocab)
sorted_tokens_tuple = sorted(
    tokens_frequencies.items(),
    key=lambda entry: (measure_token_length(entry[0]), entry[1]),
    reverse=True,
)
sorted_tokens = [entry[0] for entry in sorted_tokens_tuple]

print(sorted_tokens[:10])


['through</w>', 'ations</w>', 'tional</w>', 'before</w>', 'ation</w>', 'which</w>', 'there</w>', 'would</w>', 'their</w>', 'other</w>']


In [9]:
# Number of distinct tokens in the current inventory.
len(tokens_frequencies)

537

In [10]:
# Look up how a word is segmented; keys are the word followed by '</w>',
# and .get() yields None for a word that is not in the vocabulary.
vocab_tokenization.get("mountains</w>")

['m', 'oun', 'ta', 'in', 's</w>']

In [11]:
# Segmentation of another word; the key is the word followed by '</w>'.
vocab_tokenization.get("fictional</w>")

['f', 'ic', 'tional</w>']

In [12]:
# Segmentation of a further word; the key is the word followed by '</w>'.
vocab_tokenization.get("liveliness</w>")

['li', 'v', 'el', 'in', 'ess</w>']