In [3]:
import re 
from collections import defaultdict 
import string

def get_init_vocab(data): 
    """ 
    Given a list of strings, returns a dictionary of words mapping to their frequency  
    count in the data. 
    Args: 
        data: raw text with line breaks
        
    Returns: 
        (vocab, tokens) tuple, 
          vocab is a dictionary mapping space delimited characters to count (e.g. {'a b c </w>': 5})
          tokens is a set of basic characters. 
    """
    vocab = defaultdict(int)
    tokens = set()
    tokens.add('</w>')
    for line in data: 
        for word in line.split(): 
            vocab[' '.join(list(word)) + ' </w>'] += 1
            tokens.update(list(word))
    return vocab, tokens 
  
def count_cooccurance(vocab): 
    """ 
    Given a vocabulary (dictionary mapping words to frequency counts), returns a  
    dictionary of tuples representing the frequency count of pairs of characters  
    in the vocabulary. 
    Args:
        vocab: a dictionary mapping space-delimited tokens to count (e.g. {'a b c </w>': 5})
        
    Returns: 
        a dictionary mapping a tuple of tokens to count
    """
    pairs = defaultdict(int) 
    for word, freq in vocab.items(): 
        chars = word.split() # split the word by any white space
        for i in range(len(chars)-1): 
            pairs[chars[i], chars[i+1]] += freq 
    return pairs
  
def merge_vocab(token_pair, vocab_in): 
    """ 
    Given a pair of tokens and a vocabulary, returns a new vocabulary with the  
    pair of tokens merged together wherever they appear. 
    
    e.g. merge_vocab(('a', 'b'), {'a b c </w>': 5})
    returns {'ab c </w>': 5}
    
    Args: 
        token_pair: a tuple of two tokens
        vocab_in: a dictionary mapping space-delimited tokens to count (e.g. {'a b c </w>': 5})
        
    Returns: 
        a dictionary mapping space-delimited tokens to count (e.g. {'a b c </w>': 5})
    """
    vocab_out = defaultdict(int)  
    bigram = re.escape(' '.join(token_pair)) 
    new_token = ''.join(token_pair)
    # search for every occurance of bigram (token pairs with a space), 
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)') 
    for word in vocab_in:
        # replace the bigram (with space), with the new merged token (the concanated pair)
        w_out = p.sub(new_token, word)
        vocab_out[w_out] = vocab_in[word]
    return vocab_out

  
def byte_pair_encoding(data, n): 
    """ 
    Given a list of strings and an integer n, returns a list of n merged pairs 
    of characters found in the vocabulary of the input data. 
    
    Args: 
        data: raw text
        n: number of merge opperations
    
    Returns: 
        a list of tokens
        a dictionary mapping token to index (starting from 0)
    """
    vocab, init_tokens = get_init_vocab(data)
    tokens = list(init_tokens)
    for i in range(n): 
        pairs = count_cooccurance(vocab) 
        best_pair = max(pairs, key=pairs.get) 
        new_token = ''.join(best_pair)
        tokens.append(new_token)
        vocab = merge_vocab(best_pair, vocab)
        print('step {}: merging \"{}\" and \"{}\"'.format(i+1, best_pair[0], best_pair[1]))
    token_to_ids = dict([(tk, id) for id, tk in enumerate(tokens)])
    return tokens, token_to_ids

def tokenize(data, token_dict):
    """
    split the data into tokens and map into index. 
    It applies greedy split to text with longest matching.
    
    e.g. 
    tokenize("spiderman", {'spider':0,'man': 1})
    will return
     [0, 1]
        
    Args: 
        data: raw text
        token_dict: a dictionary mapping from token to id
        
    Returns: 
        a list of ids
        
    """
    encoded_ids = []
    for line in data: 
        for word in line.split():
            word = word + '</w>'
            last_idx = 0
            idx = len(word)
            while idx > last_idx:
                whole_word = word[last_idx:idx]
                if whole_word in token_dict:
                    encoded_ids.append(token_dict[whole_word])
                    last_idx = idx
                    idx = len(word)
                else:
                    idx = idx - 1
    return encoded_ids
  



In [4]:
# Example usage: 
corpus = '''Berman's parents divorced when he was seven. 
Thereafter, he split time between each parent's household until he entered college.[6] 
His father relocated to Dallas for a position as a lobbyist on behalf of foodservice businesses, 
while his mother moved back in with her parents in Wooster, Ohio, and became a teacher there'''
data = corpus.split('.') 
  
n = 20 # number of merge operations
id_to_tokens, token_to_ids = byte_pair_encoding(data, n)

token_ids = tokenize(data, token_to_ids)

print("The bpe tokens are: ")
for tk, tid in token_to_ids.items():
    print("{}: {}".format(tk, tid))

print("The ids of the tokenized sequence are: ")
print(token_ids)
print()
print("The sequence corresponding to ids is: ")
print(' '.join(id_to_tokens[tid] for tid in token_ids))

step 1: merging "e" and "r"
step 2: merging "s" and "</w>"
step 3: merging "e" and "</w>"
step 4: merging "e" and "n"
step 5: merging "d" and "</w>"
step 6: merging "h" and "er"
step 7: merging "en" and "t"
step 8: merging "e" and "d</w>"
step 9: merging "," and "</w>"
step 10: merging "her" and "</w>"
step 11: merging "n" and "</w>"
step 12: merging "p" and "a"
step 13: merging "pa" and "r"
step 14: merging "par" and "ent"
step 15: merging "en" and "</w>"
step 16: merging "h" and "e</w>"
step 17: merging "a" and "s</w>"
step 18: merging "s" and "e"
step 19: merging "e" and "a"
step 20: merging "i" and "t"
The bpe tokens are: 
W: 0
c: 1
[: 2
b: 3
w: 4
a: 5
s: 6
d: 7
m: 8
T: 9
f: 10
y: 11
</w>: 12
k: 13
t: 14
H: 15
o: 16
': 17
O: 18
p: 19
D: 20
e: 21
B: 22
n: 23
i: 24
]: 25
h: 26
,: 27
u: 28
6: 29
l: 30
g: 31
v: 32
r: 33
er: 34
s</w>: 35
e</w>: 36
en: 37
d</w>: 38
her: 39
ent: 40
ed</w>: 41
,</w>: 42
her</w>: 43
n</w>: 44
pa: 45
par: 46
parent: 47
en</w>: 48
he</w>: 49
as</w>: 50
se: 51