---------------------------
#### Token Learning Example: Byte-Pair Encoding (BPE)
---------------------------

In [3]:
import os
import re, collections

In [10]:
def get_vocab(filename):
    """Build a character-level BPE vocabulary from a UTF-8 text file.

    Every whitespace-separated word in the file is spelled out as its
    individual characters joined by single spaces, followed by the
    end-of-word marker ``</w>``. No normalization is applied, so punctuation
    stays attached to the word, e.g.::

        {'A f t e r </w>': 1,
         'e a c h </w>': 1,
         'm e r g e , </w>': 1,
         't h e r e </w>': 1,
         'c o u l d </w>': 1, ...}

    Args:
        filename: Path of the text file to read (opened as UTF-8).

    Returns:
        collections.defaultdict(int) mapping each spelled-out word to the
        number of times it occurs, in first-seen order.
    """
    vocab = collections.defaultdict(int)
    with open(filename, 'r', encoding='utf-8') as fhand:
        words = (word for line in fhand for word in line.split())
        for word in words:
            # ' '.join over a str interleaves its characters with spaces.
            vocab[' '.join(word) + ' </w>'] += 1
    return vocab

In [5]:
filename = "bpe-example.txt"  # sample text used to demonstrate get_vocab

In [11]:
get_vocab(filename=filename)  # display the word -> frequency vocabulary

defaultdict(int,
            {'A f t e r </w>': 1,
             'e a c h </w>': 1,
             'm e r g e , </w>': 1,
             't h e r e </w>': 1,
             'c o u l d </w>': 1,
             'b e </w>': 1,
             't h r e e </w>': 1,
             's c e n a r i o s , </w>': 1,
             't h e </w>': 4,
             'n u m b e r </w>': 3,
             'o f </w>': 3,
             't o k e n s </w>': 2,
             'd e c r e a s e s </w>': 1,
             'b y </w>': 2,
             'o n e , </w>': 1,
             'r e m a i n s </w>': 1,
             's a m e </w>': 1,
             'o r </w>': 1,
             'i n c r e a s e s </w>': 2,
             'o n e . </w>': 1,
             'B u t </w>': 1,
             'i n </w>': 1,
             'p r a c t i c e , </w>': 1,
             'a s </w>': 1,
             'm e r g e s </w>': 1,
             'i n c r e a s e s , </w>': 1,
             'u s u a l l y </w>': 1,
             'f i r s t </w>': 1,
             't h e n <

In [7]:
def get_stats(vocab):
    """Count adjacent symbol pairs across a BPE vocabulary.

    Args:
        vocab: Mapping from space-separated symbol strings
            (e.g. ``'t h e </w>'``) to integer frequencies.

    Returns:
        collections.defaultdict(int) keyed by ``(left, right)`` symbol tuples.
        Each value is the sum of the word frequencies over every occurrence
        of that adjacent pair. Words with fewer than two symbols contribute
        nothing.
    """
    pair_counts = collections.defaultdict(int)
    for spelled, count in vocab.items():
        symbols = spelled.split()
        # zip with the shifted list yields each neighbouring (left, right) pair.
        for left, right in zip(symbols, symbols[1:]):
            pair_counts[left, right] += count
    return pair_counts

In [8]:
def merge_vocab(pair, v_in):
    """Merge every standalone occurrence of ``pair`` in each word of a BPE vocabulary.

    Args:
        pair: Two-element tuple of symbols ``(left, right)``, such as the
            most frequent key of ``get_stats`` output.
        v_in: Mapping from space-separated symbol strings
            (e.g. ``'t h e </w>'``) to integer frequencies. Not modified.

    Returns:
        A new dict in which each ``'left right'`` symbol sequence is replaced
        by the single symbol ``'leftright'``, with frequencies carried over.
    """
    v_out = {}
    merged = ''.join(pair)
    bigram = re.escape(' '.join(pair))
    # The lookarounds require whitespace or a string edge on both sides, so only
    # whole symbols match: merging ('e', 's') leaves 'ne s' untouched.
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word, freq in v_in.items():
        # Use a function as the replacement so `merged` is inserted literally.
        # A plain replacement string would have its backslashes parsed as
        # escapes / group references (a merged '\\n' would become a newline,
        # and '\\q' would raise re.error).
        w_out = p.sub(lambda _match: merged, word)
        # Accumulate rather than overwrite so that, should two spellings ever
        # collapse to the same key, no frequency is lost.
        v_out[w_out] = v_out.get(w_out, 0) + freq
    return v_out

In [14]:
def get_tokens(vocab):
    """Tally the frequency of every symbol (token) in a BPE vocabulary.

    Args:
        vocab: Mapping from space-separated symbol strings
            (e.g. ``'t h e</w>'``) to integer frequencies.

    Returns:
        collections.defaultdict(int) mapping each symbol to the sum of the
        frequencies of the words it appears in (counted once per occurrence),
        in first-seen order.
    """
    counts = collections.defaultdict(int)
    for spelled, freq in vocab.items():
        for symbol in spelled.split():
            counts[symbol] += freq
    return counts

In [12]:
# Corpus for the full BPE run. The default is a machine-specific local path;
# set the BPE_CORPUS environment variable to point at the file elsewhere.
filename = os.environ.get('BPE_CORPUS', r'D:\AI-DATASETS\01-MISC\pg.txt')

In [13]:
vocab = get_vocab(filename=filename)  # character-level vocabulary, before any merges

In [15]:
print('Tokens Before BPE')
tokens = get_tokens(vocab)
# Symbol inventory of the unmerged vocabulary, then its size.
print('Tokens: {}'.format(tokens),
      'Number of tokens: {}'.format(len(tokens)),
      sep='\n')

Tokens Before BPE
Tokens: defaultdict(<class 'int'>, {'T': 1607, 'h': 26103, 'e': 59190, '</w>': 101849, 'P': 780, 'r': 29562, 'o': 35007, 'j': 858, 'c': 13900, 't': 44238, 'G': 300, 'u': 13723, 'n': 32498, 'b': 7426, 'g': 8752, 'B': 1162, 'k': 2732, 'f': 10463, 'A': 1379, 'l': 20619, 'd': 17581, 'M': 1204, 'i': 31414, 's': 28311, 'a': 36695, 'y': 8828, 'w': 8155, 'U': 178, 'S': 865, 'm': 9751, 'p': 8030, 'v': 4878, '.': 4061, 'Y': 250, ',': 8065, '-': 1063, 'L': 426, 'I': 1428, ':': 201, 'J': 78, 'V': 102, 'E': 895, 'R': 369, '6': 73, '2': 160, '0': 402, '5': 124, '[': 32, '#': 1, '1': 291, '4': 99, '7': 60, ']': 32, 'D': 322, 'C': 862, 'K': 41, 'O': 510, '/': 31, '*': 22, 'F': 419, 'H': 688, 'N': 793, '"': 4064, '!': 1214, 'W': 576, '3': 104, "'": 1236, 'Q': 33, 'X': 49, 'Z': 10, '?': 651, '8': 73, '9': 36, '_': 1426, 'à': 3, 'x': 937, 'z': 364, '°': 41, 'q': 575, ';': 561, '(': 53, ')': 53, '{': 23, '}': 16, 'è': 2, 'é': 14, '+': 2, '=': 3, 'ö': 2, 'ê': 5, 'â': 1, 'ô': 1, 'Æ': 3, 'æ

In [16]:
# Number of BPE merge steps to learn in this run.
num_merges = 3

# NOTE: this cell rebinds `vocab`, so running it again continues merging
# from the already-merged vocabulary instead of starting over.
for i in range(num_merges):
    pairs = get_stats(vocab)
    if not pairs:
        # No word has two or more symbols left, so there is nothing to merge.
        break

    # Most frequent adjacent pair; on ties, max keeps the first one encountered.
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)

    print('Iter: {}'.format(i), 'Best pair: {}'.format(best), sep='\n')
    tokens = get_tokens(vocab)
    print('Tokens: {}'.format(tokens),
          'Number of tokens: {}'.format(len(tokens)),
          '==========',
          sep='\n')

Iter: 0
Best pair: ('e', '</w>')
Tokens: defaultdict(<class 'int'>, {'T': 1607, 'h': 26103, 'e</w>': 17758, 'P': 780, 'r': 29562, 'o': 35007, 'j': 858, 'e': 41432, 'c': 13900, 't': 44238, '</w>': 84091, 'G': 300, 'u': 13723, 'n': 32498, 'b': 7426, 'g': 8752, 'B': 1162, 'k': 2732, 'f': 10463, 'A': 1379, 'l': 20619, 'd': 17581, 'M': 1204, 'i': 31414, 's': 28311, 'a': 36695, 'y': 8828, 'w': 8155, 'U': 178, 'S': 865, 'm': 9751, 'p': 8030, 'v': 4878, '.': 4061, 'Y': 250, ',': 8065, '-': 1063, 'L': 426, 'I': 1428, ':': 201, 'J': 78, 'V': 102, 'E': 895, 'R': 369, '6': 73, '2': 160, '0': 402, '5': 124, '[': 32, '#': 1, '1': 291, '4': 99, '7': 60, ']': 32, 'D': 322, 'C': 862, 'K': 41, 'O': 510, '/': 31, '*': 22, 'F': 419, 'H': 688, 'N': 793, '"': 4064, '!': 1214, 'W': 576, '3': 104, "'": 1236, 'Q': 33, 'X': 49, 'Z': 10, '?': 651, '8': 73, '9': 36, '_': 1426, 'à': 3, 'x': 937, 'z': 364, '°': 41, 'q': 575, ';': 561, '(': 53, ')': 53, '{': 23, '}': 16, 'è': 2, 'é': 14, '+': 2, '=': 3, 'ö': 2, 'ê':