<a href="https://colab.research.google.com/github/blooming-ai/generativeai/blob/main/text/byte_pair_encoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to Byte Pair Encoding





### The following BPE code is adapted from the paper:
#### [Neural Machine Translation of Rare Words with Subword Units](https://arxiv.org/pdf/1508.07909.pdf)
#### A few code snippets are taken from https://github.com/karpathy/minGPT


In [None]:
import re
import pdb
import string
from collections import defaultdict

#Byte pair encoding
def word_to_charachter_tuple(word:str)->tuple:
    r'''
    Converts a word into a tuple of lower-case characters followed by the
    end-of-word marker '<\w>'.
    "Word!" -> ('w','o','r','d','<\w>')

    Every character that is not alpha-numeric (punctuation, whitespace, ...) is
    dropped, so a word without any alpha-numeric character yields ('<\w>',).
    '''
    kept = "".join(ch for ch in word if ch.isalnum()) # keep only alpha-numeric characters
    # The marker is a literal backslash followed by 'w'. The backslash is escaped
    # explicitly because "\w" is not a valid escape sequence in a normal string.
    return (*kept.lower(), "<\\w>")

def get_pairs(word_as_tuple:tuple)->list:
    '''
    Lists every pair of neighbouring symbols, from left to right:
    ('w','o','r','d','</w>') -> [('w','o'),('o','r'),('r','d'),('d','</w>')]
    An input with fewer than two symbols gives an empty list.
    '''
    # pairing the sequence with itself shifted by one walks over all neighbours
    return list(zip(word_as_tuple, word_as_tuple[1:]))

def replace_pair(word_as_tuple:tuple, pair:tuple)->tuple:
    '''
    Given word = ('w','o','r','d','</w>') and pair = ('o','r')
    returns ('w','or','d','</w>'). Replacement happens for each occurrence of the pair,
    scanning left to right without overlap: ('a','a','a') with ('a','a') -> ('aa','a').
    An empty word is returned as an empty tuple.
    '''
    merged = []
    last = len(word_as_tuple) - 1
    i = 0
    while i <= last:
        # a pair can only start at a position that still has a right neighbour
        if i < last and (word_as_tuple[i], word_as_tuple[i+1]) == pair:
            merged.append(word_as_tuple[i] + word_as_tuple[i+1])
            i += 2 # both symbols of the pair are consumed
        else:
            merged.append(word_as_tuple[i])
            i += 1

    return tuple(merged)

def construct_word_vocab(word_vocab:defaultdict, file_path:str)->dict:
    '''
    Read the file at file_path and update word_vocab in place.
    word_vocab has format word_vocab[('w','o','r','d',<end-of-word marker>)]->freq.
    Words are split on whitespace. A word that has no alpha-numeric character
    (e.g. a lone "--") is skipped. Returns the same word_vocab object.
    '''
    with open(file_path) as fp:
        for line in fp:
            for item in line.split(): #split ignore multiple spaces
                item_as_tuple = word_to_charachter_tuple(item) # tuple to make a hashable dict key
                # word_to_charachter_tuple always appends the end-of-word marker, so a
                # word without usable characters has length 1, never 0.
                if len(item_as_tuple) < 2: continue #ignore empty word
                word_vocab[item_as_tuple] += 1

    return word_vocab

def get_byte_pair_hist(word_vocab:defaultdict)->dict:
    '''
    Builds the histogram of neighbouring symbol pairs over the whole vocabulary.
    Takes word_vocab[('w','o','r','d')]->freq and returns hist[('w','o')]->freq:
    each occurrence of a pair inside a word adds that word's freq to the pair's count.
    '''
    hist = defaultdict(int)
    for symbols in word_vocab:
        count = word_vocab[symbols]
        for neighbours in get_pairs(symbols):
            hist[neighbours] += count

    return hist

def merge_pair(pair:tuple, word_vocab_in:dict)->dict:
    '''
    Merge the input pair in every key of the word_vocab dict. E.g. pair=('w','o') then update
    word_vocab_in[('w','o','r','d')]->freq to
    word_vocab_out[('wo','r','d')]->freq
    word_vocab_in is not modified; a new plain dict is returned.
    '''
    word_vocab_out = {}
    for word, freq in word_vocab_in.items():
        new_word = replace_pair(word, pair)
        # Sum instead of assign, so that no frequency is lost if two input keys
        # end up as the same key after the merge.
        word_vocab_out[new_word] = word_vocab_out.get(new_word, 0) + freq
    return word_vocab_out

def byte_pair_encoding(word_vocab:defaultdict, n:int=10)->tuple:
    '''
    Given word_vocab[('w','o','r','d')]->freq merges the pair with the highest frequency, up to 'n' times.
    E.g. a merge replaces ('w','o'), the bigram with the highest freq., giving word_vocab[('wo','r','d')]->freq.
    Merging stops early when no pair is left to merge (every word is a single token).
    returns a tuple (bpe_ranks, word_vocab):
    Merge rank bpe_ranks[ ('w','o') ]-> 0 (implies first merge),
    Merged word vocab - word_vocab[('wo','r','d')]->freq
    '''
    merges = []
    for _ in range(n):
        pairs = get_byte_pair_hist(word_vocab)
        if not pairs: break # nothing left to merge; max() would fail on an empty histogram
        best = max(pairs, key=pairs.get)
        word_vocab = merge_pair(best, word_vocab)
        merges.append(best)

    # bpe merge list that defines the bpe "tree", of tuples (a,b) that are to merge to token ab
    bpe_ranks = {pair: rank for rank, pair in enumerate(merges)}
    return bpe_ranks, word_vocab

def get_bpe_encoder_decoder_map(word_vocab:defaultdict)->tuple:
    '''
    Numbers every distinct token found in the keys of the merged vocabulary
    word_vocab[('wo','r','d')]->freq, in order of first appearance, starting at 0.
    returns:
    Encoder encoder['wo'] -> id
    Decoder decoder[ id ] -> 'wo'
    '''
    token_to_id = {}
    id_to_token = {}
    for word in word_vocab:
        for token in word:
            if token in token_to_id:
                continue # already numbered
            next_id = len(token_to_id) # ids are handed out consecutively
            token_to_id[token] = next_id
            id_to_token[next_id] = token

    return token_to_id, id_to_token



In [None]:
#Get data file: download input.txt (tiny shakespeare) and move it into sample_data/
# NOTE(review): the leading "!" makes these shell commands, so this cell needs an
# IPython/Jupyter session; sample_data/ is presumably already present (as on Colab) - confirm.
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
!mv input.txt sample_data/

In [None]:
# Path of the text corpus that construct_word_vocab reads in the cells below
file_path = "sample_data/input.txt"

In [None]:
#Testing: run each helper once on the corpus and print the intermediate results
word_vocab = defaultdict(int)
word_vocab = construct_word_vocab(word_vocab, file_path) # word tuple -> frequency
print("is empty string a key in word_vocab:",() in word_vocab)
pairs = get_byte_pair_hist(word_vocab) # neighbouring pair -> frequency
best = max(pairs, key=pairs.get) # pair with the highest frequency
word_vocab_out = merge_pair(('f','i'), word_vocab) # merge one hand-picked pair as an example

print("word vocab -> ",word_vocab)
print("byte pair hist -> ",pairs)
print("best pair -> ",best)
print("merge f,i-> ",word_vocab_out)





In [None]:
# initialize word_vocab with every lower-case letter and digit (frequency 1 each),
# so that each of these single characters is present as a token
# Keys are one-element tuples, the same shape as the word keys added from the file;
# note that (x) is just x, a one-element tuple needs the trailing comma: (x,)
initial_vocab = [(ch,) for ch in string.ascii_lowercase + string.digits]
word_vocab = defaultdict(int, ((key, 1) for key in initial_vocab))

#construct from a file
word_vocab = construct_word_vocab(word_vocab, file_path)
merge_ranks, word_vocab = byte_pair_encoding(word_vocab, 200)
bpe_encoder, bpe_decoder = get_bpe_encoder_decoder_map(word_vocab)

print("number of unique words: ",len(word_vocab))
print("number of tokens: ", len(bpe_encoder))
print("bpe encoder map-> ", bpe_encoder)
print("bpe decoder map-> ", bpe_decoder)
print("merges -> ", merge_ranks)


In [None]:
cache={} # raw word -> list of token ids, filled by bpe_tokenize
def bpe_tokenize(input:str, merge_ranks:dict, bpe_encoder:dict)->list:
    '''
    Tokenizes the whitespace separated words of input and returns one flat list of token ids.
    Each word is split into characters and its pairs are merged in the order given by
    merge_ranks (lowest rank first) until no pair of the word is in merge_ranks.
    The resulting tokens are mapped to ids with bpe_encoder.

    Results are memoized per word in the module level dict 'cache'. The cache is keyed by
    the word only, so call cache.clear() whenever merge_ranks or bpe_encoder change.

    Raises Exception if a token is not in bpe_encoder.
    '''
    tokens = []
    for word in input.split():
        if word not in cache:
            word_tuple = word_to_charachter_tuple(word)
            while len(word_tuple) > 1: # cannot get a pair from fewer than two elements
                pairs = get_pairs(word_tuple)
                bigram = min(pairs, key = lambda pair: merge_ranks.get(pair, float('inf'))) # find the next lowest rank bigram that can be merged
                if bigram not in merge_ranks: break # no more bigrams are eligible to be merged
                word_tuple = replace_pair(word_tuple, bigram)

            # collect the ids of the whole word first, so a failing word is never cached
            word_ids = []
            for token in word_tuple:
                if token not in bpe_encoder: raise Exception("unknown token: "+ token)
                word_ids.append( bpe_encoder[token] )
            cache[word] = word_ids

        # a cache hit only supplies this word's ids; the remaining words are still processed
        tokens.extend(cache[word])

    return tokens


In [None]:
#Test: tokenize a line, map the ids back to tokens and rebuild the text
line = "I am tokenizing"
tokens = bpe_tokenize(line, merge_ranks, bpe_encoder)
decoded_tokens = [bpe_decoder[key] for key in tokens] # decode once, reuse below
print("Tokens ids -> ",tokens)
print("Tokens -> ",decoded_tokens)
reconstruction = "".join(decoded_tokens)
# '<\\w>' is the end-of-word marker (a literal backslash followed by 'w'); the backslash
# is escaped explicitly because "\w" is not a valid escape sequence in a normal string
print("Reconstruction-> ",reconstruction.replace('<\\w>', ' '))