### Custom implementation of a BPE Tokenizer from scratch

References:
1. GPT Tokenizer by Andrej Karpathy: https://www.youtube.com/watch?v=zduSFxRajkE
2. Byte pair encoding on Wikipedia: https://en.wikipedia.org/wiki/Byte_pair_encoding

In [28]:
# Inspect how one emoji maps to a Unicode code point and to its UTF-8 bytes.
char = '😂'
enc = char.encode('utf-8')  # a bytes object (4 bytes for this emoji, per the output below)
print(f"Unicode code point of {char}: {ord(char)}")
print(f"UTF-8 representation of {char}: {enc}")
# iterating a bytes object already yields ints 0-255, so no map(int, ...) is needed
print(f"Byte encoded form of {char}: {list(enc)}")

Unicode code point of 😂: 128514
UTF-8 representation of 😂: b'\xf0\x9f\x98\x82'
Byte encoded form of 😂: [240, 159, 152, 130]


In [29]:
words = "hello world!"


def encode(str):
    """Return the UTF-8 encoding of a string as a list of ints (0-255).

    Note: the parameter name shadows the built-in ``str`` inside this
    function; it is kept as-is so keyword callers keep working.
    """
    raw = str.encode('utf-8')
    return [byte for byte in raw]


# Byte encoded form of words
encode(words)

[104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100, 33]

In [30]:
def get_pairs(tokens):
    """Count how often each adjacent pair occurs in ``tokens``.

    Returns a plain dict mapping ``(left, right)`` tuples to counts. Keys are
    inserted in order of first appearance, which decides ties when the result
    is passed to ``max(..., key=counts.get)``.
    """
    counts = {}
    successors = tokens[1:]
    for pair in zip(tokens, successors):
        counts[pair] = counts[pair] + 1 if pair in counts else 1
    return counts
    
sen = "I'm thinking of going to the gym today!"
stats = get_pairs(encode(sen))
# most frequent adjacent pair; max() returns the first one seen on ties
top_pair = max(stats, key=stats.get)
# every (pair, count) item, most frequent first (sorted() is stable)
ordered = sorted(stats.items(), key=lambda kv: kv[1], reverse=True)

print(stats)
print(f"Top pair: {top_pair}")


{(73, 39): 1, (39, 109): 1, (109, 32): 2, (32, 116): 4, (116, 104): 2, (104, 105): 1, (105, 110): 3, (110, 107): 1, (107, 105): 1, (110, 103): 2, (103, 32): 2, (32, 111): 1, (111, 102): 1, (102, 32): 1, (32, 103): 2, (103, 111): 1, (111, 105): 1, (116, 111): 2, (111, 32): 1, (104, 101): 1, (101, 32): 1, (103, 121): 1, (121, 109): 1, (111, 100): 1, (100, 97): 1, (97, 121): 1, (121, 33): 1}
Top pair: (32, 116)


In [31]:
def merge(ids, pair, idx):
    """Replace every non-overlapping occurrence of ``pair`` in ``ids`` with ``idx``.

    Scans left to right; after a match, both matched elements are consumed,
    so ``merge([1, 1, 1], (1, 1), 9)`` gives ``[9, 1]``. Returns a new list
    and leaves ``ids`` unchanged.
    """
    out = []
    pos = 0
    last = len(ids) - 1
    while pos <= last:
        if pos < last and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
            out.append(idx)
            pos += 2
            continue
        out.append(ids[pos])
        pos += 1
    return out

num = [1, 2, 3, 4, 5, 6]
# Two successive merges: first (5, 6) -> 99, then the newly formed pair (4, 99) -> 100.
m1 = merge(num, (5, 6), 99)
m2 = merge(m1, (4, 99), 100)
print(m1)
print(m2)

[1, 2, 3, 4, 99]
[1, 2, 3, 100]


In [32]:
text = "Byte pair encoding[1][2] (also known as digram coding)[3] is an algorithm, first described in 1994 by Philip Gage for encoding strings of text into tabular form for use in downstream modeling.[4] Its modification is notable as the large language model tokenizer with an ability to combine both tokens that encode single characters (including single digits or single punctuation marks) and those that encode whole words (even the longest compound words).[5][6][7] This modification, in the first step, assumes all unique characters to be an initial set of 1-character long n-grams (i.e. initial \"tokens\"). Then, successively, the most frequent pair of adjacent characters is merged into a new, 2-character long n-gram and all instances of the pair are replaced by this new token. This is repeated until a vocabulary of prescribed size is obtained. Note that new words can always be constructed from final vocabulary tokens and initial-set characters.[8] This algorithmic approach has been extended from spoken language to sign language in recent years.[9]"
# The UTF-8 bytes of the text as a list of ints in range 0-255.
tokens = list(text.encode('utf-8'))

# Before: find the most frequent adjacent byte pair in the text.
pairs = get_pairs(tokens)
top_pair = max(pairs, key=pairs.get)
print(f"Top pair in the tokens is: {top_pair}")
print(f"Length of tokens before merging: {len(tokens)}")

# After: replace that pair with the first id beyond the byte range (256).
tokens2 = merge(tokens, top_pair, 256)
print(f"Length of tokens after merging: {len(tokens2)}")

Top pair in the tokens is: (101, 32)
Length of tokens before merging: 1054
Length of tokens after merging: 1027


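In general, a tokenizer's base vocabulary contains the individual units that text is split into, plus any special tokens. In this byte-level implementation, the base set is the 256 possible byte values (0–255). That is not the ASCII character set, which has only 128 characters. Every UTF-8 string is a sequence of such bytes, so any text can be represented with this base set. Each merge adds exactly one new token. Therefore, `num_merges` is calculated as `vocab_size - 256` to determine the number of merges needed to reach the desired vocabulary size beyond the initial byte set.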

In [33]:
def get_pairs(tokens):
    """Return a dict of adjacent-pair frequencies in ``tokens``.

    Keys are ``(left, right)`` tuples in order of first occurrence, so
    ``max(counts, key=counts.get)`` picks the earliest pair on ties.
    """
    counts = {}
    for left, right in zip(tokens, tokens[1:]):
        key = (left, right)
        counts[key] = counts.setdefault(key, 0) + 1
    return counts

def merge(ids, pair, idx):
    """Return a copy of ``ids`` with each non-overlapping ``pair`` replaced by ``idx``.

    Matching is greedy left to right: ``merge([1, 1, 1], (1, 1), 9) == [9, 1]``.
    """
    merged = []
    n = len(ids)
    i = 0
    while i < n:
        is_match = i + 1 < n and ids[i] == pair[0] and ids[i + 1] == pair[1]
        merged.append(idx if is_match else ids[i])
        i += 2 if is_match else 1
    return merged

vocab_size = 376  # target vocabulary size
num_merges = vocab_size - 256  # each merge adds one token on top of the 256 byte values
ids = list(tokens)  # working copy that the merge loop rewrites; `tokens` stays intact

# Bookkeeping filled in while merging in a loop to reach the desired vocab size.
merges = {}           # (left id, right id) -> new id, in merge order
char_merges = {}      # (left str, right str) -> merged str, for display
chr_idx_mapping = {}  # new id -> its display string

def resolve_index(idx):
    """Return a display string for token id ``idx``.

    Ids 0-255 are rendered with ``chr``, i.e. each raw byte is shown as the
    character with that code point (Latin-1 range). Bytes belonging to a
    multi-byte UTF-8 character therefore appear as separate symbols rather
    than the original character. Larger ids are looked up in the module-level
    ``chr_idx_mapping``; an id not recorded there raises ``KeyError``.
    """
    return chr(idx) if idx <= 255 else chr_idx_mapping[idx]

# Greedy BPE training: repeatedly replace the most frequent adjacent pair with a
# new id, recording both the id-level merge and a readable string form of it.
for i in range(num_merges):
    pairs = get_pairs(ids)
    if not pairs:
        # Fewer than two ids remain, so nothing can be merged. Stop early instead
        # of letting max() raise ValueError on an empty dict.
        break
    top_pair = max(pairs, key=pairs.get)  # ties go to the pair seen first
    idx = 256 + i  # ids 0-255 are the raw byte values, so new ids start at 256
    ids = merge(ids, top_pair, idx)
    merges[top_pair] = idx

    # Record a human-readable form of the merge. Each operand is either a byte
    # (0-255) or an id created by an earlier iteration, which already has a
    # chr_idx_mapping entry, so resolve_index always succeeds here.
    char1 = resolve_index(top_pair[0])
    char2 = resolve_index(top_pair[1])
    chr_idx_mapping[idx] = char1 + char2
    char_merges[char1, char2] = chr_idx_mapping[idx]

In [34]:
# Effect of all merges on sequence length: original bytes vs. merged ids.
print("Length of tokens before merging:", len(tokens))
print("Length of tokens(ids) after merging:", len(ids))
print(f"Compression Ratio: {len(tokens) / len(ids):.2f}X")

Length of tokens before merging: 1054
Length of tokens(ids) after merging: 474
Compression Ratio: 2.22X


In [35]:
# Learned merge table: (left id, right id) -> new id, in the order merges were learned.
merges

{(101, 32): 256,
 (115, 32): 257,
 (105, 110): 258,
 (101, 110): 259,
 (32, 97): 260,
 (116, 104): 261,
 (116, 32): 262,
 (97, 114): 263,
 (100, 32): 264,
 (116, 101): 265,
 (116, 111): 266,
 (111, 100): 267,
 (258, 103): 268,
 (105, 257): 269,
 (111, 114): 270,
 (97, 99): 271,
 (97, 110): 272,
 (97, 108): 273,
 (32, 115): 274,
 (116, 105): 275,
 (111, 110): 276,
 (114, 101): 277,
 (107, 259): 278,
 (99, 104): 279,
 (279, 263): 280,
 (280, 271): 281,
 (281, 265): 282,
 (282, 114): 283,
 (105, 114): 284,
 (99, 267): 285,
 (93, 32): 286,
 (101, 264): 287,
 (121, 32): 288,
 (103, 256): 289,
 (111, 102): 290,
 (108, 256): 291,
 (261, 256): 292,
 (266, 278): 293,
 (259, 285): 294,
 (97, 109): 295,
 (44, 32): 296,
 (115, 262): 297,
 (258, 32): 298,
 (97, 289): 299,
 (290, 32): 300,
 (97, 98): 301,
 (108, 263): 302,
 (109, 267): 303,
 (46, 91): 304,
 (111, 109): 305,
 (117, 110): 306,
 (84, 104): 307,
 (32, 108): 308,
 (46, 32): 309,
 (112, 97): 310,
 (310, 284): 311,
 (93, 91): 312,
 (105, 1

In [36]:
# Show each learned merge in readable form: 'left' + 'right' --> merged
for (ch1, ch2), merged in char_merges.items():
    print(f"'{ch1}' + '{ch2}' --> {merged}")

'e' + ' ' --> e 
's' + ' ' --> s 
'i' + 'n' --> in
'e' + 'n' --> en
' ' + 'a' -->  a
't' + 'h' --> th
't' + ' ' --> t 
'a' + 'r' --> ar
'd' + ' ' --> d 
't' + 'e' --> te
't' + 'o' --> to
'o' + 'd' --> od
'in' + 'g' --> ing
'i' + 's ' --> is 
'o' + 'r' --> or
'a' + 'c' --> ac
'a' + 'n' --> an
'a' + 'l' --> al
' ' + 's' -->  s
't' + 'i' --> ti
'o' + 'n' --> on
'r' + 'e' --> re
'k' + 'en' --> ken
'c' + 'h' --> ch
'ch' + 'ar' --> char
'char' + 'ac' --> charac
'charac' + 'te' --> characte
'characte' + 'r' --> character
'i' + 'r' --> ir
'c' + 'od' --> cod
']' + ' ' --> ] 
'e' + 'd ' --> ed 
'y' + ' ' --> y 
'g' + 'e ' --> ge 
'o' + 'f' --> of
'l' + 'e ' --> le 
'th' + 'e ' --> the 
'to' + 'ken' --> token
'en' + 'cod' --> encod
'a' + 'm' --> am
',' + ' ' --> , 
's' + 't ' --> st 
'in' + ' ' --> in 
'a' + 'ge ' --> age 
'of' + ' ' --> of 
'a' + 'b' --> ab
'l' + 'ar' --> lar
'm' + 'od' --> mod
'.' + '[' --> .[
'o' + 'm' --> om
'u' + 'n' --> un
'T' + 'h' --> Th
' ' + 'l' -->  l
'.' + ' ' --> . 


In [77]:
class BPETokenizer:
    """Byte-level BPE tokenizer trained with greedy most-frequent-pair merges.

    Token ids 0-255 are raw UTF-8 byte values; each merge learned by ``fit``
    adds one new id, starting at ``base_vocab`` (256).

    Attributes:
        vocab_size: requested final vocabulary size (256 bytes + merges).
        base_vocab: number of base tokens, i.e. the 256 possible byte values.
        num_merges: number of merges ``fit`` will try to learn (a vocab_size
            below 256 gives a non-positive value, meaning no merges).
        merges: ``(left_id, right_id) -> new_id``, in the order learned.
        vocab: ``id -> bytes`` for every known id; populated by ``fit``.
    """

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size
        self.base_vocab = 256
        self.num_merges = self.vocab_size - self.base_vocab
        self.merges = {}
        self.vocab = {}

    def get_pairs(self, ids):
        """Count adjacent pairs in ``ids``; keys are ordered by first occurrence."""
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    def merge(self, ids, pair, idx):
        """Return a new list with each non-overlapping ``pair`` in ``ids`` replaced by ``idx``."""
        new_ids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
                new_ids.append(idx)
                i += 2
            else:
                new_ids.append(ids[i])
                i += 1
        return new_ids

    def _build_vocab(self):
        """Rebuild ``id -> bytes`` from the base bytes plus ``self.merges``.

        Relies on ``self.merges`` being in learning order (as ``fit`` stores
        it), so both operands of every merge are already present.
        """
        vocab = {idx: bytes([idx]) for idx in range(self.base_vocab)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        return vocab

    def display_merges(self):
        """Print every learned merge as ``'left' + 'right' --> merged``.

        Token bytes are decoded as UTF-8 with replacement, so a token holding
        only part of a multi-byte character shows U+FFFD. Prints nothing
        before ``fit``.
        """
        vocab = self.vocab or self._build_vocab()
        for (p0, p1), idx in self.merges.items():
            ch1 = vocab[p0].decode("utf-8", errors="replace")
            ch2 = vocab[p1].decode("utf-8", errors="replace")
            merged = vocab[idx].decode("utf-8", errors="replace")
            print(f"'{ch1}' + '{ch2}' --> {merged}")

    # fit the model on the input vocab
    def fit(self, text):
        """Learn up to ``num_merges`` merges from ``text`` (encoded as UTF-8).

        Training stops early once fewer than two ids remain. Replaces any
        previously learned ``merges`` and ``vocab``.
        """
        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(self.base_vocab)}  # int -> bytes

        ids = list(text.encode('utf-8'))
        for i in range(self.num_merges):
            pairs = self.get_pairs(ids)
            if not pairs:
                break  # nothing left to merge; max() would raise on an empty dict
            top_pair = max(pairs, key=pairs.get)  # ties go to the pair seen first
            idx = self.base_vocab + i
            ids = self.merge(ids, top_pair, idx)

            # save the merge
            merges[top_pair] = idx
            vocab[idx] = vocab[top_pair[0]] + vocab[top_pair[1]]

        self.merges = merges
        self.vocab = vocab

    # encode a string of text into tokens
    def encode(self, text):
        """Encode ``text`` into token ids by applying the learned merges.

        At each step, the present pair with the lowest merge id (i.e. learned
        earliest) is replaced everywhere, mirroring the order used in ``fit``.
        """
        tokens = list(text.encode('utf-8'))
        while len(tokens) >= 2:
            pairs = self.get_pairs(tokens)
            # pairs without a learned merge rank as +inf, so they are picked last
            pair = min(pairs, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # nothing else can be merged
            tokens = self.merge(tokens, pair, self.merges[pair])
        return tokens

    # decode a list of tokens back into string
    def decode(self, ids):
        """Decode token ids back into a string.

        Works before ``fit`` as well, since the vocab then falls back to the
        256 base bytes. Invalid UTF-8 (e.g. ids that split a multi-byte
        character) is replaced with U+FFFD instead of raising. An unknown id
        raises ``KeyError``.
        """
        vocab = self.vocab or self._build_vocab()
        data = b"".join(vocab[idx] for idx in ids)
        return data.decode("utf-8", errors="replace")

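Checking the performance of our BPE Tokenizer on some sample text.

Results:
Training on a sample of approx `5.5M` characters with `vocab_size=8000`. The cell below uses `vocab_size=300` for a quicker run; set it to `8000` to reproduce that configuration.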

In [78]:
# Train the tokenizer on a local text file. Read it explicitly as UTF-8 so the
# result does not depend on the platform's default encoding (e.g. cp1252 on Windows).
with open('sample.txt', 'r', encoding='utf-8') as file:
    text = file.read()

bpe = BPETokenizer(vocab_size=300)
bpe.fit(text)   # fitting/training the tokenizer on the text