In [53]:
# https://github.com/karpathy/minGPT/blob/master/mingpt/bpe.py

In [1]:
import os
import json
import regex as re
import requests

import torch


def bytes_to_unicode():
    """
    Build the reversible byte -> printable-unicode table used by byte-level BPE.

    Each of the 256 byte values is assigned exactly one unicode character:
      * 188 bytes that already render as a visible glyph ('!'..'~', '¡'..'¬',
        '®'..'ÿ') map to themselves, e.g. 33 -> '!'.
      * The remaining 68 "ugly" bytes (control characters, space, etc.) are
        assigned, in ascending byte order, the code points 256, 257, ... 323.
        So 0 -> 'Ā' (chr(256)) and the space byte 32 -> 'Ġ' (chr(288)).

    Returns:
        dict[int, str] keyed by byte value. The 188 printable bytes come first,
        then the shifted ones. The mapping is one-to-one, so it can be inverted.
    """
    # bytes whose own character is already a nice visible glyph
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    printable_lookup = set(printable)

    byte_values = list(printable)
    code_points = list(printable)  # identity mapping for the printable ones

    # every other byte gets the next free code point starting at 256
    offset = 0
    for byte in range(256):
        if byte in printable_lookup:
            continue
        byte_values.append(byte)
        code_points.append(256 + offset)
        offset += 1

    return {byte: chr(point) for byte, point in zip(byte_values, code_points)}

def get_pairs(word):
    """
    Return all bigrams of consecutive elements in `word` as a set of tuples.

    Args:
        word: an indexable, sliceable sequence (e.g. a tuple of symbol strings
            or a str).

    Returns:
        set of (word[i], word[i + 1]) tuples. Empty sequences and
        single-element sequences yield an empty set instead of raising
        IndexError.
    """
    # zip stops at the shorter argument, so len-0/len-1 inputs give no pairs
    return set(zip(word, word[1:]))

In [112]:
class Encoder:
    """
    Byte-level BPE tokenizer with support for literal special ("added") tokens.

    Adapted from minGPT's bpe.py. Ordinary text is:
      1. pre-tokenized with a regex,
      2. mapped byte-by-byte to printable unicode characters (bytes_to_unicode),
      3. merged with BPE according to merge rank,
      4. looked up in the vocabulary.
    Special tokens bypass all of this and map directly to their vocabulary id.
    """

    def __init__(self, encoder, bpe_merges, added_tokens):
        """
        Args:
            encoder: dict mapping BPE token string -> integer id.
            bpe_merges: ordered sequence of (a, b) tuples. Earlier entries have
                lower rank and are merged first.
            added_tokens: sequence of dicts with at least a "content" key. Each
                content string is treated as an atomic special token whose id is
                looked up in `encoder`.
        """
        # byte encoder/decoder
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        # bpe token encoder/decoder
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}

        # merge list defining the bpe "tree": (a, b) -> rank; lower rank merges first
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))

        # Pre-tokenization pattern. The alternatives are tried left to right:
        # - [^\r\n\p{L}\p{N}]?[Lu Lt Lm Lo M]*[Ll Lm Lo M]+ ('s|'t|...)?
        #     An optional single leading non-letter/non-digit/non-newline char
        #     (typically a space), then a word that is lowercase at the end, e.g.
        #     " Hello". It may carry a case-insensitive contraction suffix
        #     ("'s", "'re", ...).
        # - same shape, but 1+ upper/title-case chars then 0+ lowercase chars:
        #     catches all-caps words such as " HTTP" or "I".
        # - \p{N}{1,3}: digits in chunks of at most three ("88888888" -> 888|888|88).
        #     Note there is no optional leading space, so a space before a number
        #     becomes its own piece.
        # - ' ?[^\s\p{L}\p{N}]+[\r\n/]*': optional space, a run of
        #     punctuation/symbols (including emoji), then any trailing CR/LF or '/'.
        # - \s*[\r\n]+: whitespace that ends in one or more newlines.
        # - \s+(?!\S): whitespace not followed by a non-space. This leaves the last
        #     space of a run to be picked up as the leading char of the next word.
        # - \s+: any remaining whitespace (e.g. at the end of the string).
        # NOTE(review): this appears to be the same pattern as tiktoken's
        # o200k_base; confirm it matches the tokenizer the vocab was trained with.
        pattern = "|".join(
            [
                r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
                r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
                r"""\p{N}{1,3}""",
                r""" ?[^\s\p{L}\p{N}]+[\r\n/]*""",
                r"""\s*[\r\n]+""",
                r"""\s+(?!\S)""",
                r"""\s+""",
            ]
        )
        self.pat = re.compile(pattern)

        # Special tokens are matched literally. Escape every regex metacharacter,
        # and try longer tokens first so that a token which is a prefix of
        # another cannot shadow it in the alternation. Ties are broken
        # alphabetically so the pattern string is deterministic.
        special_contents = sorted(
            {tok["content"] for tok in added_tokens if tok["content"]},
            key=lambda s: (-len(s), s),
        )
        self.special_tokens = frozenset(special_contents)
        self.special_tokens_pattern = "|".join(re.escape(s) for s in special_contents)
        # With no special tokens the joined pattern would be "", which matches
        # the empty string at every position. Leave the regex unset instead.
        self.special_token_regex = (
            re.compile(self.special_tokens_pattern) if special_contents else None
        )
        self.cache = {}

    def bpe(self, token):
        """
        Iteratively apply the lowest-rank applicable merge from self.bpe_ranks
        to one byte-encoded pre-token (e.g. 'Ġthere').

        Returns:
            The resulting sub-tokens joined by a single ' '. ' ' can be used as a
            delimiter because byte encoding maps the real space byte to 'Ġ'.
        """
        # memoization, for efficiency
        if token in self.cache:
            return self.cache[token]

        word = tuple(token)  # individual characters that make up the token
        pairs = get_pairs(word)  # all bigrams

        if not pairs:
            return token

        while True:
            # find the lowest-rank bigram that can be merged
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break  # no more bigrams are eligible to be merged
            first, second = bigram

            # replace every occurrence of (first, second) with first+second
            new_word = []
            i = 0
            while i < len(word):
                # jump to the next occurrence of `first`
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                # word[i] == first here; merge if `second` follows
                if i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1

            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)

        word = ' '.join(word)

        # cache the result and return
        self.cache[token] = word
        return word

    def _encode_piece(self, token):
        """Byte-encode, BPE-merge and look up one pre-token; return all intermediates."""
        token_bytes = token.encode('utf-8')
        token_translated = ''.join(self.byte_encoder[b] for b in token_bytes)
        token_merged = self.bpe(token_translated).split(' ')
        token_ix = [self.encoder[bpe_token] for bpe_token in token_merged]
        return {
            'token': token,
            'token_bytes': token_bytes,
            'token_translated': token_translated,
            'token_merged': token_merged,
            'token_ix': token_ix,
        }

    def _split_special(self, text):
        """Yield (segment, is_special) pairs that cover all of `text`, in order."""
        start = 0
        if self.special_token_regex is not None:
            for match in self.special_token_regex.finditer(text):
                if match.start() > start:
                    yield text[start:match.start()], False
                yield match.group(), True
                start = match.end()
        # trailing ordinary text after the last special token (or the whole
        # string when there is no special token at all)
        if start < len(text):
            yield text[start:], False

    def encode(self, text):
        """String goes in, list of integers comes out. Special tokens are NOT recognized."""
        bpe_idx = []
        for token in self.pat.findall(text):
            bpe_idx.extend(self._encode_piece(token)['token_ix'])
        return bpe_idx

    def encode_with_spl_token(self, text):
        """
        Like encode(), but every occurrence of a special token in `text` is
        emitted as that token's single id instead of being run through BPE.
        """
        bpe_idx = []
        for segment, is_special in self._split_special(text):
            if is_special:
                bpe_idx.append(self.encoder[segment])
            else:
                bpe_idx.extend(self.encode(segment))
        return bpe_idx

    def encode_with_spl_token_show_progess(self, text):
        """
        Debugging variant of encode_with_spl_token that also returns per-chunk
        intermediates.

        Returns:
            dict with 'bpe_idx' (ids), 'tokens' (id -> token string) and 'parts'.
            'parts' is a list of per-chunk dicts. For special tokens, every field
            holds the raw token string, except 'token_ix', which is the bare int
            id (not a list).
        """
        bpe_idx = []
        parts = []
        for segment, is_special in self._split_special(text):
            if is_special:
                token_id = self.encoder[segment]
                bpe_idx.append(token_id)
                parts.append({
                    'token': segment,
                    'token_bytes': segment,
                    'token_translated': segment,
                    'token_merged': segment,
                    'token_ix': token_id,
                    'is_special_token': True,
                })
            else:
                for token in self.pat.findall(segment):
                    part = self._encode_piece(token)
                    part['is_special_token'] = False
                    bpe_idx.extend(part['token_ix'])
                    parts.append(part)
        out = {
            'bpe_idx': bpe_idx,  # the actual output sequence
            'tokens': [self.decoder[i] for i in bpe_idx],  # token string per id
            'parts': parts,  # intermediates for each token part
        }
        return out

    def encode_and_show_work(self, text):
        """Debugging function, same as encode() but returns all intermediate work."""
        parts = [self._encode_piece(token) for token in self.pat.findall(text)]
        bpe_idx = [ix for part in parts for ix in part['token_ix']]
        out = {
            'bpe_idx': bpe_idx,  # the actual output sequence
            'tokens': [self.decoder[i] for i in bpe_idx],  # token string per id
            'parts': parts,  # intermediates for each token part
        }
        return out

    def decode(self, bpe_idx):
        """
        List of integers comes in, string comes out.

        Special-token ids are emitted verbatim as their content string. Their
        characters need not belong to the byte-level alphabet; for example a
        plain ' ' has no entry in byte_decoder. All other ids are mapped back
        through byte_decoder and decoded as UTF-8, with invalid sequences
        replaced.
        """
        pieces = []
        byte_buf = bytearray()
        for idx in bpe_idx:
            token = self.decoder[idx]
            if token in self.special_tokens:
                # flush the pending ordinary bytes before the literal special token
                pieces.append(byte_buf.decode('utf-8', errors='replace'))
                byte_buf = bytearray()
                pieces.append(token)
            else:
                byte_buf.extend(self.byte_decoder[c] for c in token)
        pieces.append(byte_buf.decode('utf-8', errors='replace'))
        return ''.join(pieces)

In [113]:
def get_encoder(encoder_local_file="trained_tokenizer/post_processed_tokenizer.json"):
    """
    Build an Encoder from a HuggingFace-style tokenizer JSON file.

    Args:
        encoder_local_file: path to the tokenizer JSON. It must contain
            model.vocab (token -> id) and model.merges. It may also contain
            added_tokens (list of dicts with a "content" key). Defaults to the
            path previously hard-coded here.

    Returns:
        An Encoder instance. Nothing is cached or downloaded; the file is read
        on every call.
    """
    # Read as UTF-8 explicitly so loading does not depend on the platform's
    # default locale encoding (the vocab contains non-ASCII characters).
    with open(encoder_local_file, 'r', encoding='utf-8') as f:
        tokenizer_json = json.load(f)

    vocab = tokenizer_json["model"]["vocab"]
    merges_raw = tokenizer_json["model"]["merges"]
    # A merge may be serialized as an "a b" string or, in some `tokenizers`
    # versions, as an [a, b] list. Normalize both to an (a, b) tuple.
    bpe_merges = [
        tuple(merge) if isinstance(merge, (list, tuple)) else tuple(merge.split())
        for merge in merges_raw
    ]

    added_tokens = tokenizer_json.get("added_tokens", [])
    return Encoder(vocab, bpe_merges, added_tokens)

In [114]:
# Build the tokenizer from the local tokenizer JSON. Later cells use the
# global name `encoder`.
encoder = get_encoder()

In [117]:
# Round-trip a sample mixing ordinary text, an emoji, a long number, a newline
# and several special tokens. The name `input` is kept because the next cell
# reads it; note that it shadows the built-in input().
input = "<s>Hi this is jagan 😘 88888888 \n</s>this is next sentence<|start_of_turn|>"

round_trip_ids = encoder.encode_with_spl_token(input)
round_trip_text = encoder.decode(round_trip_ids)
# fail loudly on a fresh Run All instead of relying on eyeballing the output
assert round_trip_text == input, "special-token encode/decode did not round-trip"
round_trip_text

'<s>Hi this is jagan 😘 88888888 \n</s>this is next sentence<|start_of_turn|>'

In [118]:
progress = encoder.encode_with_spl_token_show_progess(input)

# Flat list of token strings ('Ġ' stands for a space byte, 'Ċ' for a newline),
# followed by a blank line.
print(progress['tokens'], end="\n\n")

# Then one intermediate record per chunk, special tokens included.
print("Then we iterate over each chunk and process them in turn...")
for chunk_info in progress['parts']:
    print(chunk_info)

['<s>', 'Hi', 'Ġthis', 'Ġis', 'Ġj', 'agan', 'Ġ', 'ð', 'Ł', 'ĺ', 'ĺ', 'Ġ', '888', '888', '88', 'Ġ', 'Ċ', '</s>', 'this', 'Ġis', 'Ġnext', 'Ġsentence', '<|start_of_turn|>']

Then we iterate over each chunk and process them in turn...
{'token': '<s>', 'token_bytes': '<s>', 'token_translated': '<s>', 'token_merged': '<s>', 'token_ix': 0, 'is_special_token': True}
{'token': 'Hi', 'token_bytes': b'Hi', 'token_translated': 'Hi', 'token_merged': ['Hi'], 'token_ix': [17624], 'is_special_token': False}
{'token': ' this', 'token_bytes': b' this', 'token_translated': 'Ġthis', 'token_merged': ['Ġthis'], 'token_ix': [451], 'is_special_token': False}
{'token': ' is', 'token_bytes': b' is', 'token_translated': 'Ġis', 'token_merged': ['Ġis'], 'token_ix': [325], 'is_special_token': False}
{'token': ' jagan', 'token_bytes': b' jagan', 'token_translated': 'Ġjagan', 'token_merged': ['Ġj', 'agan'], 'token_ix': [551, 7994], 'is_special_token': False}
{'token': ' 😘', 'token_bytes': b' \xf0\x9f\x98\x98', 'token