# GPT Tokenizer
This tokenizer follows the video created by [Andrej Karpathy](https://www.youtube.com/watch?v=zduSFxRajkE). <br> <br>
We use the Byte Pair Encoding algorithm to turn characters into tokens. We start by encoding each character into its UTF-8 byte representation, then we find the pair of tokens that occur most frequently. This new pair of tokens are replaced by a new minted token that was not used before. We keep repeating this until we are happy with our vocabulary.

In [1]:
#Create sample text
text = "Hello tokenizer 😄🤪🧐!"

#Encode to UTF-8 bytes
utf = text.encode('utf-8')

#Map bytes to integers
tokens = list(utf)


print(f'Text: {text} length: {len(text)} \nBytes: {utf} length: {len(utf)}\nIntegers: {tokens} length: {len(tokens)}\n')

Text: Hello tokenizer 😄🤪🧐! length: 20 
Bytes: b'Hello tokenizer \xf0\x9f\x98\x84\xf0\x9f\xa4\xaa\xf0\x9f\xa7\x90!' length: 29
Integers: [72, 101, 108, 108, 111, 32, 116, 111, 107, 101, 110, 105, 122, 101, 114, 32, 240, 159, 152, 132, 240, 159, 164, 170, 240, 159, 167, 144, 33] length: 29



In [2]:
#Go through each pair of integers and count how many times they appear
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        if pair in counts:
            counts[pair] += 1
        else:
            counts[pair] = 1
    return counts

In [3]:
#Get the raw stats
stats = get_stats(tokens)

#Print the stats
print(sorted(stats.items(), key=lambda x: x[1], reverse=True))

[((240, 159), 3), ((72, 101), 1), ((101, 108), 1), ((108, 108), 1), ((108, 111), 1), ((111, 32), 1), ((32, 116), 1), ((116, 111), 1), ((111, 107), 1), ((107, 101), 1), ((101, 110), 1), ((110, 105), 1), ((105, 122), 1), ((122, 101), 1), ((101, 114), 1), ((114, 32), 1), ((32, 240), 1), ((159, 152), 1), ((152, 132), 1), ((132, 240), 1), ((159, 164), 1), ((164, 170), 1), ((170, 240), 1), ((159, 167), 1), ((167, 144), 1), ((144, 33), 1)]


In [4]:
top_pair = max(stats, key=lambda x: x[1])
print(f'Top pair: {top_pair[0] } appears {top_pair[1]} times')

Top pair: 32 appears 240 times


In [5]:
#Take a list of ids, and a tuple pair, then search and replace that pair with the new idx
def merge_tokens(ids, pair, idx):
    new_ids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i+1]) == pair:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    return new_ids

In [6]:
merge_tokens(tokens, top_pair[0], -42)

[72,
 101,
 108,
 108,
 111,
 32,
 116,
 111,
 107,
 101,
 110,
 105,
 122,
 101,
 114,
 32,
 240,
 159,
 152,
 132,
 240,
 159,
 164,
 170,
 240,
 159,
 167,
 144,
 33]

In [7]:
num_merges = 10

merges = {}
idx = 256 #Start from the first non-ascii integer
ids = list(tokens) #Copy the list of tokens so that we don't modify the original list
for i in range(num_merges):
    stats = get_stats(ids)
    top_pair = max(stats, key=lambda x: x[1])
    idx += 1
    print(f'Merging: {top_pair} to the new token: {idx}')
    ids = merge_tokens(ids, top_pair, idx)
    merges[top_pair] = idx

    



Merging: (32, 240) to the new token: 257
Merging: (114, 257) to the new token: 258
Merging: (101, 258) to the new token: 259
Merging: (122, 259) to the new token: 260
Merging: (105, 260) to the new token: 261
Merging: (110, 261) to the new token: 262
Merging: (101, 262) to the new token: 263
Merging: (107, 263) to the new token: 264
Merging: (111, 264) to the new token: 265
Merging: (116, 265) to the new token: 266


In [8]:
#Compare the original token list with the new token list
print(f'Original tokens: {len(tokens)}, New tokens: {len(ids)}, compression ratio: {len(ids)/len(tokens):.4f}\n\n')
#Compression ration

Original tokens: 29, New tokens: 19, compression ratio: 0.6552




In [68]:
#Create a vocabulary, we start with the normal characters from 0-255, then we add the new tokens
int2vocab = {i: bytes([i]) for i in range(256)}

#Add the new tokens to the vocabulary
int2vocab.update({v:k for k, v in merges.items()})

#Create a reverse vocabulary
vocab2int = {v:k for k, v in int2vocab.items()}

ValueError: invalid literal for int() with base 10: b'\x00'

In [56]:
int2vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [69]:
vocab2int

{b'\x00': 0,
 b'\x01': 1,
 b'\x02': 2,
 b'\x03': 3,
 b'\x04': 4,
 b'\x05': 5,
 b'\x06': 6,
 b'\x07': 7,
 b'\x08': 8,
 b'\t': 9,
 b'\n': 10,
 b'\x0b': 11,
 b'\x0c': 12,
 b'\r': 13,
 b'\x0e': 14,
 b'\x0f': 15,
 b'\x10': 16,
 b'\x11': 17,
 b'\x12': 18,
 b'\x13': 19,
 b'\x14': 20,
 b'\x15': 21,
 b'\x16': 22,
 b'\x17': 23,
 b'\x18': 24,
 b'\x19': 25,
 b'\x1a': 26,
 b'\x1b': 27,
 b'\x1c': 28,
 b'\x1d': 29,
 b'\x1e': 30,
 b'\x1f': 31,
 b' ': 32,
 b'!': 33,
 b'"': 34,
 b'#': 35,
 b'$': 36,
 b'%': 37,
 b'&': 38,
 b"'": 39,
 b'(': 40,
 b')': 41,
 b'*': 42,
 b'+': 43,
 b',': 44,
 b'-': 45,
 b'.': 46,
 b'/': 47,
 b'0': 48,
 b'1': 49,
 b'2': 50,
 b'3': 51,
 b'4': 52,
 b'5': 53,
 b'6': 54,
 b'7': 55,
 b'8': 56,
 b'9': 57,
 b':': 58,
 b';': 59,
 b'<': 60,
 b'=': 61,
 b'>': 62,
 b'?': 63,
 b'@': 64,
 b'A': 65,
 b'B': 66,
 b'C': 67,
 b'D': 68,
 b'E': 69,
 b'F': 70,
 b'G': 71,
 b'H': 72,
 b'I': 73,
 b'J': 74,
 b'K': 75,
 b'L': 76,
 b'M': 77,
 b'N': 78,
 b'O': 79,
 b'P': 80,
 b'Q': 81,
 b'R': 82,
 b'S': 

In [74]:
def encode(text, vocab2int):
    ids = text.encode('utf-8')
    print(ids)
    return 'fisk'
        
    
    encode_more = True  
    while encode_more:
        encode_more = False
        encoded = []
        i = 0
        for _ in ids:
            if i < len(ids) - 1 and (ids[i], ids[i+1]) in vocab2int:
                encoded.append(vocab2int[(ids[i], ids[i+1])])
                i += 2
                encode_more = True #We need to keep encoding
            else:
                if i < len(ids):
                    if ids[i] in vocab2int:
                        encoded.append(vocab2int[ids[i]])
                    else:
                        encoded.append(-1) #We don't have this token in the vocabulary
                i += 1
                
        ids = encoded
    return encoded
        
    
    

In [67]:
vocab2int

{b'\x00': 0,
 b'\x01': 1,
 b'\x02': 2,
 b'\x03': 3,
 b'\x04': 4,
 b'\x05': 5,
 b'\x06': 6,
 b'\x07': 7,
 b'\x08': 8,
 b'\t': 9,
 b'\n': 10,
 b'\x0b': 11,
 b'\x0c': 12,
 b'\r': 13,
 b'\x0e': 14,
 b'\x0f': 15,
 b'\x10': 16,
 b'\x11': 17,
 b'\x12': 18,
 b'\x13': 19,
 b'\x14': 20,
 b'\x15': 21,
 b'\x16': 22,
 b'\x17': 23,
 b'\x18': 24,
 b'\x19': 25,
 b'\x1a': 26,
 b'\x1b': 27,
 b'\x1c': 28,
 b'\x1d': 29,
 b'\x1e': 30,
 b'\x1f': 31,
 b' ': 32,
 b'!': 33,
 b'"': 34,
 b'#': 35,
 b'$': 36,
 b'%': 37,
 b'&': 38,
 b"'": 39,
 b'(': 40,
 b')': 41,
 b'*': 42,
 b'+': 43,
 b',': 44,
 b'-': 45,
 b'.': 46,
 b'/': 47,
 b'0': 48,
 b'1': 49,
 b'2': 50,
 b'3': 51,
 b'4': 52,
 b'5': 53,
 b'6': 54,
 b'7': 55,
 b'8': 56,
 b'9': 57,
 b':': 58,
 b';': 59,
 b'<': 60,
 b'=': 61,
 b'>': 62,
 b'?': 63,
 b'@': 64,
 b'A': 65,
 b'B': 66,
 b'C': 67,
 b'D': 68,
 b'E': 69,
 b'F': 70,
 b'G': 71,
 b'H': 72,
 b'I': 73,
 b'J': 74,
 b'K': 75,
 b'L': 76,
 b'M': 77,
 b'N': 78,
 b'O': 79,
 b'P': 80,
 b'Q': 81,
 b'R': 82,
 b'S': 

In [73]:
vocab2int

{b'\x00': 0,
 b'\x01': 1,
 b'\x02': 2,
 b'\x03': 3,
 b'\x04': 4,
 b'\x05': 5,
 b'\x06': 6,
 b'\x07': 7,
 b'\x08': 8,
 b'\t': 9,
 b'\n': 10,
 b'\x0b': 11,
 b'\x0c': 12,
 b'\r': 13,
 b'\x0e': 14,
 b'\x0f': 15,
 b'\x10': 16,
 b'\x11': 17,
 b'\x12': 18,
 b'\x13': 19,
 b'\x14': 20,
 b'\x15': 21,
 b'\x16': 22,
 b'\x17': 23,
 b'\x18': 24,
 b'\x19': 25,
 b'\x1a': 26,
 b'\x1b': 27,
 b'\x1c': 28,
 b'\x1d': 29,
 b'\x1e': 30,
 b'\x1f': 31,
 b' ': 32,
 b'!': 33,
 b'"': 34,
 b'#': 35,
 b'$': 36,
 b'%': 37,
 b'&': 38,
 b"'": 39,
 b'(': 40,
 b')': 41,
 b'*': 42,
 b'+': 43,
 b',': 44,
 b'-': 45,
 b'.': 46,
 b'/': 47,
 b'0': 48,
 b'1': 49,
 b'2': 50,
 b'3': 51,
 b'4': 52,
 b'5': 53,
 b'6': 54,
 b'7': 55,
 b'8': 56,
 b'9': 57,
 b':': 58,
 b';': 59,
 b'<': 60,
 b'=': 61,
 b'>': 62,
 b'?': 63,
 b'@': 64,
 b'A': 65,
 b'B': 66,
 b'C': 67,
 b'D': 68,
 b'E': 69,
 b'F': 70,
 b'G': 71,
 b'H': 72,
 b'I': 73,
 b'J': 74,
 b'K': 75,
 b'L': 76,
 b'M': 77,
 b'N': 78,
 b'O': 79,
 b'P': 80,
 b'Q': 81,
 b'R': 82,
 b'S': 

In [75]:
encoded = encode(text, vocab2int)

b'Hello tokenizer \xf0\x9f\x98\x84\xf0\x9f\xa4\xaa\xf0\x9f\xa7\x90!'


In [72]:
encoded

[-1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1]

In [49]:
encoded

[b'H',
 b'e',
 b'l',
 b'l',
 b'o',
 b' ',
 b't',
 b'o',
 b'k',
 b'e',
 b'n',
 b'i',
 b'z',
 b'e',
 b'r',
 b' ',
 b'\xf0',
 b'\x9f',
 b'\x98',
 b'\x84',
 b'\xf0',
 b'\x9f',
 b'\xa4',
 b'\xaa',
 b'\xf0',
 b'\x9f',
 b'\xa7',
 b'\x90',
 b'!']

In [30]:
def decode(ids, vocab):
    tokens = b''.join([vocab[t] for t in ids])
    text = tokens.decode('utf-8', errors='replace')
    return text

In [33]:
decode([65, 128], vocab)

'A�'

In [42]:
encode('Hello tokenizer 😄🤪🧐!', vocab)

ValueError: invalid literal for int() with base 10: b'H'

In [34]:
decode(encode('Hello tokenizer 😄🤪🧐!', vocab), vocab)

KeyError: b'H'