# Importing Libs

In [7]:
import wikipediaapi
import re
from tqdm import tqdm

# Byte pair encoding

In [8]:
class BPE:
    def __init__(self):
        pass

    def fit(self, corpus, num_iteration = 8):
        punc_removed_corpus = re.sub(r'[^a-zA-Z0-9]', ' ', corpus).lower()      # replacing all the punctuations with space
        corpus_vocab = [w+'_' for w in punc_removed_corpus.split()]             # putting '_' at word endings
                
                
        self.BPE_vocab = set()
        self.BPE_letters = set()
        self.BPE_merges = dict()
        self.corpus_vocab_counts = dict()

        for word in corpus_vocab:  
            self.BPE_vocab.update(word) 
            self.BPE_letters.update(word)          
            word = " ".join(word)
            self.corpus_vocab_counts[word] = self.corpus_vocab_counts.get(word,0) + 1
            
        self.BPE_vocab.add("<UNK>")

                
        for _ in tqdm(range(num_iteration)):
            
            bigram_counts = dict()
            for w in self.corpus_vocab_counts.keys():
                bigrams = zip(w.split(), w.split()[1:])
                for bigram in bigrams:
                    bigram_counts[bigram] = bigram_counts.get(bigram, 0) + self.corpus_vocab_counts[w]
                    
            max_bigram = max(bigram_counts, key=bigram_counts.get)
            # print(max_bigram, bigram_counts[max_bigram])

            old_keys = list(self.corpus_vocab_counts.keys())
            for w in old_keys:
                if " ".join(max_bigram) in w:
                    self.corpus_vocab_counts[w.replace(" ".join(max_bigram), 
                                                "".join(max_bigram), 1)] = self.corpus_vocab_counts.pop(w)
            
            self.BPE_vocab.add("".join(max_bigram))
            self.BPE_merges[max_bigram] = bigram_counts[max_bigram]
            
            
            
    def tokenize(self, test_data):
        punc_removed_test = re.sub(r'[^a-zA-Z0-9]', ' ', test_data).lower()      # replacing all the punctuations with space
        splits = [[l for l in word]+["_"] for word in punc_removed_test.split()]
            
        for pair in self.BPE_merges:
            for idx, split in enumerate(splits):
                i = 0
                while i < len(split) - 1:
                    if split[i] == pair[0] and split[i + 1] == pair[1]:
                        split = split[:i] + ["".join(pair)] + split[i + 2 :]
                    else:
                        i += 1
                splits[idx] = split

        for split in splits:
            for i in range(len(split)):
                if (len(split[i]) == 1) & (split[i] not in self.BPE_letters):
                    split[i] = '<UNK>'
                
        return splits


## BPE on wikipedia data

In [9]:
wiki = wikipediaapi.Wikipedia('dipanbanik/0.0 (dipanthedataguy@gmail.com)','en')
page = wiki.page("Natural_language_processing")

corpus = page.text

In [10]:
corpus[:100]


'Natural language processing (NLP) is a subfield of computer science and especially artificial intell'

In [11]:
bpe = BPE()
bpe.fit(corpus, num_iteration=1000)


100%|██████████| 1000/1000 [00:03<00:00, 268.79it/s]


In [12]:
bpe.BPE_merges

{('e', '_'): 893,
 ('s', '_'): 710,
 ('i', 'n'): 513,
 ('t', 'h'): 483,
 ('t', 'i'): 410,
 ('a', 'n'): 407,
 ('e', 'n'): 6,
 ('d', '_'): 380,
 ('o', 'n'): 350,
 ('e', 'r'): 311,
 ('t', '_'): 308,
 ('o', 'r'): 308,
 ('a', 'l'): 298,
 ('y', '_'): 244,
 ('th', 'e_'): 236,
 ('g', '_'): 232,
 ('o', 'f'): 220,
 ('a', 'r'): 219,
 ('r', 'e'): 210,
 ('of', '_'): 208,
 ('on', '_'): 188,
 ('t', 'e'): 176,
 ('a', '_'): 175,
 ('c', 'h'): 170,
 ('al', '_'): 165,
 ('s', 'e'): 157,
 ('o', '_'): 154,
 ('in', 'g_'): 154,
 ('a', 'ti'): 146,
 ('r', 'o'): 138,
 ('i', 'c'): 126,
 ('e', 's_'): 125,
 ('an', 'd_'): 114,
 ('t', 'o_'): 108,
 ('i', 's_'): 107,
 ('in', '_'): 106,
 ('g', 'u'): 100,
 ('c', 'o'): 98,
 ('d', 'e'): 95,
 ('u', 'r'): 94,
 ('l', 'y_'): 90,
 ('or', '_'): 89,
 ('l', 'e'): 88,
 ('a', 'g'): 87,
 ('a', 'm'): 86,
 ('p', 'ro'): 86,
 ('s', 't'): 83,
 ('ch', '_'): 81,
 ('a', 't'): 76,
 ('w', 'or'): 76,
 ('ati', 'on_'): 75,
 ('a', 's_'): 74,
 ('a', 's'): 71,
 ('s', 'u'): 71,
 ('c', 'e_'): 71,
 ('e'

In [13]:
test_data  = "A lot of Machine Learning algos are used in NLP, but transformers are the best right now."

tokenized_test_data = bpe.tokenize(test_data)
tokenized_test_data

[['a_'],
 ['lo', 't_'],
 ['of_'],
 ['ma', 'ch', 'in', 'e_'],
 ['lear', 'n', 'ing_'],
 ['alg', 'o', 's_'],
 ['ar', 'e_'],
 ['u', 'se', 'd_'],
 ['in_'],
 ['nlp_'],
 ['but_'],
 ['trans', 'form', 'ers_'],
 ['ar', 'e_'],
 ['the_'],
 ['be', 's', 't_'],
 ['right_'],
 ['now_']]

# Wordpiece Encoding

In [None]:
class WordPiece:
    def __init__(self):
        pass

    def fit(self, corpus, num_iteration = 8):
        punc_removed_corpus = re.sub(r'[^a-zA-Z0-9]', ' ', corpus).lower()      # replacing all the punctuations with space
        corpus_vocab = [w+'_' for w in punc_removed_corpus.split()]             # putting '_' at word endings
                
                
        self.BPE_vocab = set([chr(i) for i in range(ord('a'), ord('z') + 1)])
        self.corpus_vocab_counts = dict()

        for word in corpus_vocab:            
            word = " ".join(word)
            self.corpus_vocab_counts[word] = self.corpus_vocab_counts.get(word,0) + 1

                
        for _ in range(num_iteration):
            
            bigram_counts = dict()
            mono_counts = dict()
            
            for w in self.corpus_vocab_counts.keys():
                bigrams = zip(w.split(), w.split()[1:])
                for bigram in bigrams:
                    bigram_counts[bigram] = bigram_counts.get(bigram, 0) + self.corpus_vocab_counts[w]
                    mono_counts[bigram[0]] = mono_counts.get(bigram[0], 0) + self.corpus_vocab_counts[w]
                mono_counts[bigram[1]] = mono_counts.get(bigram[1], 0) + self.corpus_vocab_counts[w]
            
            for bigram in bigram_counts:
                bigram_counts[bigram] = bigram_counts[bigram]/(mono_counts[bigram[0]] * mono_counts[bigram[1]])
                    
            max_bigram = max(bigram_counts, key=bigram_counts.get)
            print(max_bigram, bigram_counts[max_bigram])

            old_keys = list(self.corpus_vocab_counts.keys())
            for w in old_keys:
                if " ".join(max_bigram) in w:
                    self.corpus_vocab_counts[w.replace(" ".join(max_bigram), 
                                                "".join(max_bigram))] = self.corpus_vocab_counts.pop(w)
            
            self.BPE_vocab.add("".join(max_bigram))      
            
    def tokenize(self, test_data):
        BPE_vocab_sorted = sorted(list(self.BPE_vocab), key=len, reverse=True)
        punc_removed_test = re.sub(r'[^a-zA-Z0-9]', ' ', test_data).lower()      # replacing all the punctuations with space
        test_vocab = [w+'_' for w in punc_removed_test.split()]

        tokenized_test_data = []
        for word in test_vocab:
            # print(word)
            tokenized_word = []
            for token in BPE_vocab_sorted:
                if word == '':
                    break
                else:
                    if token in word:
                        tokenized_word.append(token)
                        word = word.replace(token,"")
                        # print(word)
                        
                        
            tokenized_test_data.append(tokenized_word)
            
        return tokenized_test_data