# Byte-Pair Encoding Tokenizer
---
Following Andrej Karpathy's [video](https://www.youtube.com/watch?v=zduSFxRajkE) on tokenizers, I build my own BPE tokenizer here. I will then use it as the basis for training my own Word2Vec-style embedding model and analyze the results of that training.

## 1) Import Dependencies and Test Data

In [1]:
# Import dependencies (the third-party `regex` package is needed for the \p{L} / \p{N} classes in the pattern below)
import regex as re

# Also defining our regex up here (from GPT-4 tokenizer)
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
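
The split pattern matters because merges are only ever counted inside a chunk, never across chunk boundaries. A minimal sketch of that pre-tokenization step, using a simplified ASCII-only pattern that runs on the standard-library `re` module. `SIMPLE_SPLIT` is a hypothetical stand-in written for this illustration, not the GPT-4 pattern, and it chunks punctuation and newlines differently.

```python
import re

# Hypothetical, ASCII-only stand-in for the GPT-4 pattern (assumption, for illustration only)
SIMPLE_SPLIT = r"'(?:[sdmt]|ll|ve|re)| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+"

text = "First Citizen:\nBefore we"
chunks = re.findall(SIMPLE_SPLIT, text)

# Each chunk becomes its own block of byte values; pairs are counted per block
blocks = [list(chunk.encode("ascii")) for chunk in chunks]

print(chunks)   # ['First', ' Citizen', ':', '\n', 'Before', ' we']
```

Joining the chunks gives back the original string, so splitting loses no text by itself.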

In [2]:
# Load the data. We'll use Tiny Shakespeare, which makes it easy to see which words were common in ye olde times
file_path = '../data/tinyshakespeare.txt'
with open(file_path, 'r') as file:
    data = file.read()

# Print first 100 chars
print(data[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


## 2) Defining our BPETokenizer Class

### Approach
---
For our tokenizer, we'll create a class called `BPETokenizer`. On instantiation it optionally takes a regex pattern. Its methods are:
1. `_encodetext(text)`: takes raw text and returns a 2-D Python list of ASCII values (split into blocks by the regex if one is set; otherwise a single block holding the whole text)
2. `_mergepair(tokens, pair_tok)`: helper that takes the encoded text and a tuple `(pair, token)` and replaces each occurrence of `pair` with `token`
3. `train(text, vocab_size)`: iteratively replaces the most frequent byte pair with a new token and saves the resulting encoding map on the instance
4. `encode(text)`: converts text to tokens; returns a 1-D Python list of tokens
5. `decode(tokens)`: converts tokens back to text; returns a string
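
The core of the training step listed above, counting adjacent pairs and merging the most frequent one, can be sketched on a toy sequence. `count_pairs` and `merge` are hypothetical helper names used only for this illustration; they work on a flat list rather than the 2-D blocks the class uses.

```python
from collections import Counter

def count_pairs(ids):
    """Count how often each adjacent pair occurs in a flat list of ints."""
    return Counter(zip(ids, ids[1:]))

def merge(ids, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` with `new_id`, left to right."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2          # consume both halves of the pair
        else:
            out.append(ids[i])
            i += 1
    return out

ids = list("aaabdaaabac".encode("ascii"))
pairs = count_pairs(ids)
top_pair = max(pairs, key=pairs.get)   # (97, 97), i.e. "aa", which occurs 4 times
merged = merge(ids, top_pair, 128)     # first new token id after the 128 ASCII values

print(merged)   # [128, 97, 98, 100, 128, 97, 98, 97, 99]
```

One training iteration is exactly this count-then-merge step; repeating it `vocab_size - 128` times gives the full merge table.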

In [23]:
class BPETokenizer():

    def __init__(self, regex=None):
        self.encoding_map = {}
        self.decoding_map = {}
        self.regex = regex

    def regex_setter(self, pattern):
        """ Simple setter function for regex (parameter renamed so it no longer shadows the imported `re` module) """
        self.regex = pattern

    def encmap_getter(self):
        """ Simple getter to return our encoding map """
        return self.encoding_map
        
    def decmap_getter(self):
        """ Simple getter to return our decoding map """
        return self.decoding_map

    def _encodetext(self, text):
        """
            Takes in raw text and returns 2-D python array of ASCII values (e.g., [[70, 105, 114, 115, 116], [32, 67,...]])
            Replaces non-ASCII with '?'
        """
        if self.regex is None:
            text = [text]
        else:
            text = re.findall(self.regex, text)     # Converts text to ['First', ' Citizen', ':\n', 'Before', ' we', ' proceed',...]
        return list( list(t.encode("ascii", errors="replace")) for t in text )

    def _mergepair(self, tokens, pair_tok):
        """
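            Takes in 2-D python list of tokens and a (pair, tok) tuple; iterates through each block, putting 'tok' wherever 'pair' occurs
            NOTE: 'replaced' is initialised once, outside the block loop. A merge on the last pair of a block leaves it True,
            which skips index 0 of the next block that has more than one element. Also, when a merge lands on the
            second-to-last pair (idx == len(block)-3), the block's final element is never appended.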
        """
        pair, tok = pair_tok
        merged = []
        replaced = False
        for block in tokens:
            merged_block = []
            if len(block) <= 1:
                merged_block.extend(block)    # Simply append element and move onto next block
            else:
                for idx in range(len(block)-1):
                    if replaced:
                        replaced = False   # Knows to skip next idx if you already replaced it with a bpe token 
                    elif block[idx]==pair[0] and block[idx+1]==pair[1]:
                        merged_block.append(tok)
                        replaced = True
                    elif idx == len(block)-2:
                        merged_block.extend([block[idx], block[idx+1]])
                    else:
                        merged_block.append(block[idx])
            merged.append(merged_block)
        return merged

    def train(self, text, vocab_size):
        """
            Takes in raw text and a target vocabulary size (the 128 ASCII values plus vocab_size - 128 learned merges)
            Upon completion, sets encoding_map and decoding_map to the BPE mappings (forward, backward resp.)
        """
        encoded_text = self._encodetext(text)
        encoding_map = {}                 # Create python dictionary of merges
        num_merges = vocab_size - 128     # Number of iterations of BPE
        
        for i in range(num_merges):
            bytepair_count = {}
            for block in encoded_text:
                if len(block) > 1:
                    for idx in range(len(block)-1):
                        pair = (block[idx], block[idx+1])
                        count = bytepair_count.get(pair, 0)
                        bytepair_count[pair] = count+1
            
            # Once done iterating through all the ascii values, sort and assign most freq bytepair to new token
            freq_pair = max(bytepair_count, key=bytepair_count.get)
            new_token = 128 + i
            encoding_map[freq_pair] = new_token
            encoded_text = self._mergepair(encoded_text, (freq_pair, new_token))

        self.encoding_map = encoding_map
        self.decoding_map = {value: key for key, value in encoding_map.items()}

    def encode(self, text):
        """
            Takes in raw text and returns tokenized text (1-D array)
        """
        encodings = list(self.encoding_map.items())
        encoded_text = self._encodetext(text)
        for pair_tok in encodings:
            encoded_text = self._mergepair(encoded_text, pair_tok)
        tokenized_text = [tok for sublist in encoded_text for tok in sublist]
        return tokenized_text

    def decode(self, tokens):
        """
            Takes in tokenized text (1-D array) and returns raw text
        """
        decodings = list(self.decoding_map.items())
        for tok, pair in decodings[::-1]:    # Undo merges newest-first so nested tokens expand fully
            decoded_tokens = []
            for t in tokens:
                if t != tok:
                    decoded_tokens.append(t)
                else:
                    decoded_tokens.extend(pair)
            tokens = decoded_tokens

        decoded_text = ''.join(chr(value) for value in tokens)
        return decoded_text

In [40]:
trimmed_text = data[:500]
tokenizer = BPETokenizer(GPT4_SPLIT_PATTERN)
tokenizer.train(trimmed_text, 150)

In [41]:
en = tokenizer.encode(trimmed_text)

In [43]:
tokenizer.decode(en)

"Firs Citiz:\nor w proceedny furthe, hear me speak.\n\nll:\n, speak.\n\n Citiz:\n allresolvedrathe t die than t famish?\n\nll:\nsolved.resolved.\n\n Citiz:\nrs,ou knowaius Marcius is chief enemy t thepeople.\n\nll:\nnow, w know.\n\n Citiz:\nus killhim, and w'llhave cor a our ow price.\nIs't averdict?\n\nll:\nmor talking on't; let itbe done: away, away!\n\necond Citiz:\nwor, good citiz.\n\n Citiz:\n accountedpoor"

In [31]:
tokenizer.decode(en)

"Firs Citiz:\nore we proceed any further, hear me speak.\n\nll:\nak, speak.\n\n Citiz:\nre allresolved rather to die than to famish?\n\nll:\nsolved. resolved.\n\n Citiz:\nrs,you know Caius Marcius is chief enemy to the people.\n\nll:\nnow't, we know't.\n\n Citiz:\nus killhim, and we'llhave corn a our own price.\nIs't averdict?\n\nll:\nmore talking on't; let itbe done: away, away!\n\necond Citiz:\nword, good citiz.\n\n Citiz:\n accounted poor"

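Both decoded strings above lose characters ("Firs Citiz" where the input reads "First Citizen"), so the round trip is not yet lossless. Tracing `_mergepair` suggests why: the `replaced` flag survives from one block to the next, and a merge just before the end of a block leaves the last element unappended. A minimal sketch of a merge that keeps every element, using an explicit index instead of a flag. `merge_blocks` is a hypothetical standalone name; swapping it in for `_mergepair` would change the learned merges, so the outputs recorded in this notebook would no longer match.

```python
def merge_blocks(blocks, pair, tok):
    """Replace `pair` with `tok` inside each block; merges never cross block boundaries."""
    merged = []
    for block in blocks:
        out, i = [], 0      # state is local to the block, so nothing leaks into the next one
        while i < len(block):
            if i + 1 < len(block) and (block[i], block[i + 1]) == pair:
                out.append(tok)
                i += 2      # skip both elements of the merged pair
            else:
                out.append(block[i])
                i += 1      # the last element is always reached and appended
        merged.append(out)
    return merged

blocks = [list(b"First"), list(b" Citizen")]
merged = merge_blocks(blocks, (ord("r"), ord("s")), 137)

print(merged[0])   # [70, 105, 137, 116]: the trailing 't' (116) survives the merge
```

With a merge like this, `decode(encode(text))` should return the input unchanged for ASCII text, which makes a convenient sanity check after training.
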
## 3) Analyze our Results

In [38]:
def print_tokens(encoding_map):
    """
    Takes in a dictionary called encoding_map and prints out the text that each token relates to
    """
    encodings = list(encoding_map.items())
    tok_text = {}
    for (tok_1, tok_2), mapped_token in encodings:
        text = []
        # Append the text representation of tok_1
        if tok_1 in tok_text:
            text.append(tok_text[tok_1])
        else:
            text.append(chr(tok_1))
        # Append the text representation of tok_2
        if tok_2 in tok_text:
            text.append(tok_text[tok_2])
        else:
            text.append(chr(tok_2))
        # Join the characters or strings in 'text' list into a single string
        tok_text[mapped_token] = ''.join(text)
    
    tok_text_list = list(tok_text.items())
    for tok, text in tok_text_list:
        # Replace newline characters with a visible representation
        safe_text = text.replace('\n', '\\n')
        print(f"Token: {tok}, Text: '{safe_text}'")

In [39]:
## Let's print out the text associated with our tokens that we encoded
print_tokens(tokenizer.encmap_getter())

Token: 128, Text: ':\n'
Token: 129, Text: ' a'
Token: 130, Text: '\n\n'
Token: 131, Text: 'it'
Token: 132, Text: 'en'
Token: 133, Text: ' C'
Token: 134, Text: 'iti'
Token: 135, Text: 'itiz'
Token: 136, Text: 'll'
Token: 137, Text: 'rs'
Token: 138, Text: ' Citiz'
Token: 139, Text: '.\n\n'
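
The table above expands each token by looking up its two halves. The same idea gives a one-pass decoder: build a map from token id to bytes once, then join. This is a sketch under assumptions: `build_vocab` and `decode_with_vocab` are hypothetical names, and the `merges` dictionary is hand-written to mirror a few rows of the table rather than taken from the trained tokenizer.

```python
def build_vocab(merges, base_size=128):
    """Map each token id to the bytes it stands for; relies on merges being in creation order."""
    vocab = {i: bytes([i]) for i in range(base_size)}
    for (left, right), tok in merges.items():
        vocab[tok] = vocab[left] + vocab[right]   # both halves are already in vocab
    return vocab

def decode_with_vocab(tokens, vocab):
    """Concatenate the bytes of every token and decode them as ASCII."""
    return b"".join(vocab[t] for t in tokens).decode("ascii")

# Illustrative merges (assumption): 'it', ' C', 'iti', 'itiz', ' Citiz'
merges = {(105, 116): 131, (32, 67): 133, (131, 105): 134, (134, 122): 135, (133, 135): 138}
vocab = build_vocab(merges)
text = decode_with_vocab([138, 101, 110], vocab)

print(repr(text))   # ' Citizen'
```

Compared with undoing merges one at a time, this trades a little memory for a decode that touches each token once.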
