Let's build the GPT Tokenizer (Andrej Karpathy): https://www.youtube.com/watch?v=zduSFxRajkE

In [1]:
# From https://en.wikipedia.org/wiki/Byte_pair_encoding
text = "Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 Byte pair encoding[1][2] (also known as digram coding)[3] is an algorithm, first described in 1994 by Philip Gage for encoding strings of text into tabular form for use in downstream modeling.[4] Its modification is notable as the large language model tokenizer with an ability to combine both tokens that encode single characters (including single digits or single punctuation marks) and those that encode whole words (even the longest compound words).[5][6][7] This modification, in the first step, assumes all unique characters to be an initial set of 1-character long n-grams (i.e. initial tokens). Then, successively, the most frequent pair of adjacent characters is merged into a new, 2-character long n-gram and all instances of the pair are replaced by this new token. This is repeated until a vocabulary of prescribed size is obtained. Note that new words can always be constructed from final vocabulary tokens and initial-set characters.[8] This algorithmic approach has been extended from spoken language to sign language in recent years.[9]"
tokens = text.encode("utf-8")
tokens = list(map(int, tokens))
print(len(text), len(tokens))
print(tokens)

1087 1160
[239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 240, 159, 133, 164, 240, 159, 133, 157, 240, 159, 133, 152, 240, 159, 133, 146, 240, 159, 133, 158, 240, 159, 133, 147, 240, 159, 133, 148, 226, 128, 189, 32, 240, 159, 135, 186, 226, 128, 140, 240, 159, 135, 179, 226, 128, 140, 240, 159, 135, 174, 226, 128, 140, 240, 159, 135, 168, 226, 128, 140, 240, 159, 135, 180, 226, 128, 140, 240, 159, 135, 169, 226, 128, 140, 240, 159, 135, 170, 33, 32, 240, 159, 152, 132, 32, 66, 121, 116, 101, 32, 112, 97, 105, 114, 32, 101, 110, 99, 111, 100, 105, 110, 103, 91, 49, 93, 91, 50, 93, 32, 40, 97, 108, 115, 111, 32, 107, 110, 111, 119, 110, 32, 97, 115, 32, 100, 105, 103, 114, 97, 109, 32, 99, 111, 100, 105, 110, 103, 41, 91, 51, 93, 32, 105, 115, 32, 97, 110, 32, 97, 108, 103, 111, 114, 105, 116, 104, 109, 44, 32, 102, 105, 114, 115, 116, 32, 100, 101, 115, 99, 114, 105, 98, 101, 100, 32, 105, 110, 32, 49, 57, 57, 52, 32, 98

In [2]:
def decode_pair(pair):
    return chr(pair[0])+chr(pair[1])

def get_pair_counts(tokens):
    pairs = {}
    for pair in zip(tokens, tokens[1:]):
        if not pair in pairs:
            pairs[pair] = 0
        pairs[pair] += 1
    return pairs

pair_counts = get_pair_counts(tokens)
sorted_pairs = sorted(((v,k) for k,v in pair_counts.items()), reverse=True)

print('"' + decode_pair(sorted_pairs[0][1]) + '"')
print(sorted_pairs)

"e "
[(27, (101, 32)), (26, (115, 32)), (23, (105, 110)), (20, (32, 116)), (20, (32, 97)), (17, (101, 110)), (15, (240, 159)), (15, (110, 103)), (15, (110, 32)), (15, (32, 105)), (14, (116, 104)), (13, (116, 101)), (13, (116, 32)), (13, (97, 114)), (11, (100, 32)), (11, (97, 110)), (10, (116, 111)), (10, (114, 32)), (10, (104, 97)), (10, (97, 108)), (10, (32, 115)), (9, (115, 116)), (9, (114, 97)), (9, (111, 114)), (9, (111, 100)), (9, (105, 115)), (9, (97, 99)), (9, (32, 102)), (9, (32, 99)), (8, (108, 97)), (8, (105, 116)), (8, (101, 114)), (8, (101, 100)), (8, (100, 105)), (8, (99, 116)), (8, (99, 111)), (7, (226, 128)), (7, (159, 135)), (7, (159, 133)), (7, (116, 105)), (7, (114, 115)), (7, (114, 101)), (7, (111, 110)), (7, (110, 99)), (7, (108, 32)), (7, (103, 101)), (7, (99, 104)), (7, (97, 116)), (7, (32, 111)), (7, (32, 109)), (7, (32, 108)), (6, (239, 189)), (6, (140, 240)), (6, (128, 140)), (6, (115, 105)), (6, (111, 107)), (6, (111, 32)), (6, (110, 116)), (6, (110, 115)), (6

In [3]:
def merge_pair(tokens, pair, replace_token):
    updated = []
    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
            updated.append(replace_token)
            i += 2
        else:
            updated.append(tokens[i])
            i += 1
    return updated

updated = merge_pair(tokens, sorted_pairs[0][1], 256)
print(len(tokens))
print(len(updated))

1160
1133


In [4]:
vocab_size = 256 + 20
num_merges = vocab_size - 256

merges = {}
encoded_tokens = list(tokens)
for i in range(num_merges):
    token_id = vocab_size + i
    pair_count = get_pair_counts(encoded_tokens)
    max_pair = max(pair_count, key=pair_count.get)
    print(f"merge {max_pair} to {token_id}")
    merges[max_pair] = token_id

    encoded_tokens = merge_pair(encoded_tokens, max_pair, token_id)

print(len(tokens), len(encoded_tokens))
print(f"compression ratio: {len(tokens) / len(encoded_tokens):0.2f}")

print(encoded_tokens)

merge (101, 32) to 276
merge (115, 32) to 277
merge (105, 110) to 278
merge (101, 110) to 279
merge (240, 159) to 280
merge (32, 97) to 281
merge (116, 104) to 282
merge (116, 32) to 283
merge (97, 114) to 284
merge (100, 32) to 285
merge (116, 101) to 286
merge (116, 111) to 287
merge (111, 100) to 288
merge (278, 103) to 289
merge (105, 277) to 290
merge (111, 114) to 291
merge (97, 99) to 292
merge (97, 110) to 293
merge (280, 133) to 294
merge (226, 128) to 295
1160 902
compression ratio: 1.29
[239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 294, 164, 294, 157, 294, 152, 294, 146, 294, 158, 294, 147, 294, 148, 295, 189, 32, 280, 135, 186, 295, 140, 280, 135, 179, 295, 140, 280, 135, 174, 295, 140, 280, 135, 168, 295, 140, 280, 135, 180, 295, 140, 280, 135, 169, 295, 140, 280, 135, 170, 33, 32, 280, 152, 132, 32, 66, 121, 116, 276, 112, 97, 105, 114, 32, 279, 99, 288, 289, 91, 49, 93, 91, 50, 93, 32, 40, 97, 108, 115, 

In [5]:
reversed_merges = {v: k for k, v in merges.items()}

# TODO predecode all id's to their string, don't do the recursive lookup
def decode_token(id):
    if id < 256:
        return [id]
    
    pair = reversed_merges[id]
    return [*decode_token(pair[0]), *decode_token(pair[1])]


def decode(idx):
    out = []
    for id in idx:
        out.extend(decode_token(id))

    out = bytes(out).decode("utf-8", errors="replace")
    return out


out = decode([239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 294, 164, 294, 157, 294, 152, 294, 146, 294, 158, 294, 147, 294, 148, 295, 189, 32, 280, 135, 186, 295, 140, 280, 135, 179, 295, 140, 280, 135, 174, 295, 140, 280, 135, 168, 295, 140, 280, 135, 180, 295, 140, 280, 135, 169, 295, 140, 280, 135, 170, 33, 32, 280, 152, 132, 32, 66, 121, 116, 276, 112, 97, 105, 114, 32, 279, 99, 288, 289])
print(out)

Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 Byte pair encoding


In [6]:
def encode(text):
    tokens = text.encode("utf-8")
    tokens = list(map(int, tokens))
    for pair, replace_token in merges.items():
        # TODO check if the token needs merging
        tokens = merge_pair(tokens, pair, replace_token)
    return tokens

encode("Hello World")

[72, 101, 108, 108, 111, 32, 87, 291, 108, 100]

In [7]:
decode(encode("Hello World"))

'Hello World'