# Simple Tokenizer

This code is a cleaned-up version of [this Colab notebook](https://colab.research.google.com/drive/1y0KnCFZvGVf_odSfcNAws6kcDD7HsI0L?usp=sharing#scrollTo=_paQxu7EOhvg), which Andrej Karpathy wrote for his [YouTube lesson on tokenizers](https://www.youtube.com/watch?v=zduSFxRajkE). I'll be using it as the tokenizer in all of my test models.

In [60]:
len("안녕하세요 👋 (hello in Korean!)")

26

In [61]:
# prints the Unicode code point (via Python's ord()) of each character
l = [ord(x) for x in "안녕하세요 👋 (hello in Korean!)"]
print(l)
print(len(l))

[50504, 45397, 54616, 49464, 50836, 32, 128075, 32, 40, 104, 101, 108, 108, 111, 32, 105, 110, 32, 75, 111, 114, 101, 97, 110, 33, 41]
26


In [62]:
# prints the UTF-8 bytes of each character
l = list("안녕하세요 👋 (hello in Korean!)".encode("utf-8"))
print(l)

# notice that UTF-8 is a variable-length encoding: each character becomes 1-4 bytes, and each byte is a number from 0-255.
# plain ASCII characters take a single byte, while each Korean syllable here takes 3 bytes and the emoji takes 4
print(len(l))
# for the purpose of tokenization, it's perfectly reasonable to work on UTF-8 bytes even though your tokens may end up
# containing partial characters. it doesn't matter bc the bytes stay in order, so joining the tokens back together restores the original text

[236, 149, 136, 235, 133, 149, 237, 149, 152, 236, 132, 184, 236, 154, 148, 32, 240, 159, 145, 139, 32, 40, 104, 101, 108, 108, 111, 32, 105, 110, 32, 75, 111, 114, 101, 97, 110, 33, 41]
39
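To see the variable-length rule directly, here's a tiny sketch that prints how many UTF-8 bytes a few sample characters need. The sample characters are my own picks, not from the original notebook:

```python
# How many bytes does UTF-8 use for each character?
samples = ["a", "é", "안", "👋"]  # ASCII letter, accented Latin letter, Korean syllable, emoji

byte_lengths = {}
for ch in samples:
    encoded = ch.encode("utf-8")
    byte_lengths[ch] = len(encoded)
    # code point in hex, number of bytes, and the raw byte values
    print(f"{ch!r}: U+{ord(ch):04X} -> {len(encoded)} byte(s) {list(encoded)}")
```

Higher code points need more bytes: everything up to U+007F fits in 1 byte, up to U+07FF in 2, up to U+FFFF in 3, and anything above that takes 4.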


In [63]:
# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# the first 200 characters. It's just one continuous text document: a subset of Shakespeare's works concatenated back-to-back
print(text[:200])

# here are all the unique characters that occur in this text and how many there are
chars = sorted(list(set(text)))
v = len(chars)
print('\n', chars, v)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you

 ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


In [64]:
tokens = text[:200].encode("utf-8") # raw bytes
tokens = list(map(int, tokens)) # convert to a list of integers in range 0..255 for convenience
print('---')
print(text[:200])
print("length:", len(text[:200]))
print('---')
print(tokens)
print("length:", len(tokens))

---
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you
length: 200
---
[70, 105, 114, 115, 116, 32, 67, 105, 116, 105, 122, 101, 110, 58, 10, 66, 101, 102, 111, 114, 101, 32, 119, 101, 32, 112, 114, 111, 99, 101, 101, 100, 32, 97, 110, 121, 32, 102, 117, 114, 116, 104, 101, 114, 44, 32, 104, 101, 97, 114, 32, 109, 101, 32, 115, 112, 101, 97, 107, 46, 10, 10, 65, 108, 108, 58, 10, 83, 112, 101, 97, 107, 44, 32, 115, 112, 101, 97, 107, 46, 10, 10, 70, 105, 114, 115, 116, 32, 67, 105, 116, 105, 122, 101, 110, 58, 10, 89, 111, 117, 32, 97, 114, 101, 32, 97, 108, 108, 32, 114, 101, 115, 111, 108, 118, 101, 100, 32, 114, 97, 116, 104, 101, 114, 32, 116, 111, 32, 100, 105, 101, 32, 116, 104, 97, 110, 32, 116, 111, 32, 102, 97, 109, 105, 115, 104, 63, 10, 10, 65, 108, 108, 58, 10, 82, 101, 115, 111, 108, 118, 101, 100, 46, 32, 114, 101, 115, 111,

We happen to get the same length here because every character in this snippet is plain ASCII, which UTF-8 stores in a single byte. If we used foreign characters or emoji, they'd each take up 2-4 bytes. But that won't really be an issue with TinyShakespeare, since all 65 of its unique characters are ASCII.

In [65]:
# just to prove what i mean
# text from https://www.reedbeta.com/blog/programmers-intro-to-unicode/
example_text = "Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception."
tokens = example_text.encode("utf-8") # raw bytes
tokens = list(map(int, tokens)) # convert to a list of integers in range 0..255 for convenience
print('---')
print(example_text)
print("length:", len(example_text))
print('---')
print(tokens)
print("length:", len(tokens))

---
Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception.
length: 533
---
[239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131, 239, 189, 143, 239, 189, 132, 239, 189, 133, 33, 32, 240, 159, 133, 164, 240, 159, 133, 157, 240, 159, 133, 152, 240, 159, 133, 146, 240, 159, 133, 158, 240, 159, 133, 147, 240, 159, 133, 148, 226, 128, 189, 32, 240, 159, 135, 186, 226, 128, 140, 240, 159, 135, 179, 226, 128, 140, 240, 159, 135, 174, 226, 128, 140, 240, 159, 135, 168, 226, 128, 140, 240, 159, 135, 180, 226, 128, 140

In [66]:
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]): # Pythonic way to iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts

# let's only do the first 200 characters for now for demonstration purposes
tokens = text[:200].encode("utf-8")

stats = get_stats(tokens)

#print(stats)
print(sorted(((v,k) for k,v in stats.items()), reverse=True))
# so these are all the pairs of tokens in the text, sorted from most to least frequent

[(5, (101, 32)), (5, (58, 10)), (4, (115, 116)), (4, (114, 115)), (4, (114, 101)), (4, (105, 114)), (4, (101, 100)), (4, (101, 97)), (4, (70, 105)), (4, (10, 10)), (3, (122, 101)), (3, (118, 101)), (3, (116, 105)), (3, (116, 104)), (3, (116, 32)), (3, (115, 111)), (3, (112, 101)), (3, (111, 108)), (3, (110, 58)), (3, (108, 118)), (3, (108, 108)), (3, (105, 122)), (3, (105, 116)), (3, (104, 101)), (3, (101, 115)), (3, (101, 110)), (3, (97, 107)), (3, (67, 105)), (3, (46, 10)), (3, (44, 32)), (3, (32, 116)), (3, (32, 114)), (3, (32, 97)), (3, (32, 67)), (3, (10, 70)), (2, (116, 111)), (2, (115, 112)), (2, (114, 32)), (2, (111, 117)), (2, (111, 32)), (2, (108, 58)), (2, (107, 46)), (2, (101, 114)), (2, (100, 46)), (2, (100, 32)), (2, (97, 114)), (2, (97, 110)), (2, (65, 108)), (2, (32, 115)), (2, (32, 102)), (2, (10, 65)), (1, (121, 111)), (1, (121, 32)), (1, (119, 101)), (1, (117, 114)), (1, (117, 32)), (1, (116, 44)), (1, (115, 104)), (1, (114, 116)), (1, (114, 111)), (1, (114, 97)), (1
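The same pair counting can also be written with `collections.Counter`, whose `most_common()` returns the frequency-sorted list directly. This is just an equivalent sketch, not what the original notebook uses; `get_stats_counter` is a hypothetical name:

```python
from collections import Counter


def get_stats(ids):
    # the notebook's version: count each consecutive pair with a plain dict
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def get_stats_counter(ids):
    # Counter tallies each consecutive (left, right) pair in one pass
    return Counter(zip(ids, ids[1:]))


sample = list("First Citizen:\nFirst, you".encode("utf-8"))
print(get_stats_counter(sample).most_common(3))
```

One small difference: `most_common()` orders ties by first appearance, whereas `sorted(..., reverse=True)` above breaks ties by descending pair value.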

In [67]:
# this was the most common pair (tied with (101, 32) at 5 occurrences; max() returns whichever it sees first)
top_pair = max(stats, key=stats.get)
top_pair

(58, 10)

In [68]:
# so this function will merge a single pair for us
def merge(ids, pair, idx):
  # in the list of ints (ids), replace all consecutive occurrences of pair with the new token idx
  newids = []
  i = 0
  while i < len(ids):
    # if we are not at the very last position AND the pair matches, replace it
    if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
      newids.append(idx)
      i += 2
    else:
      newids.append(ids[i])
      i += 1
  return newids

# an example of how it works
print(merge([5, 6, 6, 7, 9, 1], (6, 7), 99))

# let's do it with our actual top pair
tokens2 = merge(tokens, top_pair, 256) # 256 is the id of our new token
print(tokens2)
print("length:", len(tokens2))

[5, 6, 99, 9, 1]
[70, 105, 114, 115, 116, 32, 67, 105, 116, 105, 122, 101, 110, 256, 66, 101, 102, 111, 114, 101, 32, 119, 101, 32, 112, 114, 111, 99, 101, 101, 100, 32, 97, 110, 121, 32, 102, 117, 114, 116, 104, 101, 114, 44, 32, 104, 101, 97, 114, 32, 109, 101, 32, 115, 112, 101, 97, 107, 46, 10, 10, 65, 108, 108, 256, 83, 112, 101, 97, 107, 44, 32, 115, 112, 101, 97, 107, 46, 10, 10, 70, 105, 114, 115, 116, 32, 67, 105, 116, 105, 122, 101, 110, 256, 89, 111, 117, 32, 97, 114, 101, 32, 97, 108, 108, 32, 114, 101, 115, 111, 108, 118, 101, 100, 32, 114, 97, 116, 104, 101, 114, 32, 116, 111, 32, 100, 105, 101, 32, 116, 104, 97, 110, 32, 116, 111, 32, 102, 97, 109, 105, 115, 104, 63, 10, 10, 65, 108, 108, 256, 82, 101, 115, 111, 108, 118, 101, 100, 46, 32, 114, 101, 115, 111, 108, 118, 101, 100, 46, 10, 10, 70, 105, 114, 115, 116, 32, 67, 105, 116, 105, 122, 101, 110, 256, 70, 105, 114, 115, 116, 44, 32, 121, 111, 117]
length: 195


In [69]:
vocab_size = 300 # the desired final vocabulary size
num_merges = vocab_size - 256
ids = list(tokens) # copy so we don't destroy the original list

# now let's actually do it
merges = {} # (int, int) -> int
for i in range(num_merges):
  stats = get_stats(ids)
  pair = max(stats, key=stats.get)
  idx = 256 + i
  print(f"merging {pair} into a new token {idx}")
  ids = merge(ids, pair, idx)
  merges[pair] = idx

merging (58, 10) into a new token 256
merging (101, 32) into a new token 257
merging (70, 105) into a new token 258
merging (258, 114) into a new token 259
merging (259, 115) into a new token 260
merging (260, 116) into a new token 261
merging (101, 100) into a new token 262
merging (101, 97) into a new token 263
merging (10, 10) into a new token 264
merging (261, 32) into a new token 265
merging (265, 67) into a new token 266
merging (266, 105) into a new token 267
merging (267, 116) into a new token 268
merging (268, 105) into a new token 269
merging (269, 122) into a new token 270
merging (270, 101) into a new token 271
merging (271, 110) into a new token 272
merging (272, 256) into a new token 273
merging (116, 104) into a new token 274
merging (44, 32) into a new token 275
merging (112, 263) into a new token 276
merging (276, 107) into a new token 277
merging (46, 264) into a new token 278
merging (108, 108) into a new token 279
merging (32, 114) into a new token 280
merging (101,

In [70]:
print("tokens length:", len(tokens)) # remember tokens are our original tokens
print("ids length:", len(ids)) # and ids are new tokens we've made
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")

tokens length: 200
ids length: 72
compression ratio: 2.78X


In [71]:
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]

def decode(ids):
  # given ids (list of integers), return Python string
  tokens = b"".join(vocab[idx] for idx in ids)
  text = tokens.decode("utf-8", errors="replace")
  return text

print(decode([128]))

�
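Byte 128 (0x80) is a UTF-8 continuation byte, so it can never start a character and isn't valid on its own. That's why `decode` passes `errors="replace"`: without it, Python raises an error instead of printing the replacement character. A quick check:

```python
raw = bytes([128])  # 0x80: a continuation byte with no leading byte

try:
    raw.decode("utf-8")  # default errors="strict"
except UnicodeDecodeError as e:
    print("strict decoding fails:", e.reason)

# errors="replace" swaps each undecodable byte for U+FFFD
replaced = raw.decode("utf-8", errors="replace")
print("replace gives:", replaced)
```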


In [72]:
def encode(text):
  # given a string, return list of integers (the tokens)
  tokens = list(text.encode("utf-8"))
  while len(tokens) >= 2:
    stats = get_stats(tokens)
    pair = min(stats, key=lambda p: merges.get(p, float("inf")))
    if pair not in merges:
      break # nothing else can be merged
    idx = merges[pair]
    tokens = merge(tokens, pair, idx)
  return tokens

print(encode("Hello World!"))

[72, 101, 279, 111, 32, 87, 111, 114, 108, 100, 33]


In [73]:
print(decode(encode("hello world")))

text2 = decode(encode(text))
print(text2 == text)

hello world
True


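Now, instead of doing pure byte-pair encoding, real tokenizers also enforce what are called "regex" rules that prevent the tokenizer from combining certain bytes. GPT-2's tokenizer first splits the text into chunks with a regex pattern (roughly: runs of letters, runs of numbers, runs of punctuation, and runs of whitespace become separate chunks, with a single leading space attached to the chunk after it), and BPE merges only ever happen inside a chunk. For example, we humans think of special characters like `!` as notably different from letters, so we don't want the tokenizer to combine `n` and `!` into one token `n!`. Tokenizers also differ in how they handle whitespace: as the cell below shows, GPT-4's tokenizer merges long runs of spaces into single tokens while GPT-2's doesn't.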
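Here's a rough sketch of that pre-splitting step. It uses the standard-library `re` module with an ASCII-only approximation of the GPT-2 split pattern: the real pattern relies on the `regex` module's Unicode classes `\p{L}` and `\p{N}`, which I've swapped for `[A-Za-z]` and `[0-9]` here. BPE would then run on each chunk separately, so no merge can ever cross a chunk boundary:

```python
import re

# ASCII-only approximation of GPT-2's split pattern (assumption: \p{L} -> [A-Za-z], \p{N} -> [0-9])
SPLIT_PATTERN = re.compile(
    r"'s|'t|'re|'ve|'m|'ll|'d| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+"
)

sample = "Hello world!!! how's it going"
chunks = SPLIT_PATTERN.findall(sample)
print(chunks)

# pairs are only collected inside each chunk, so ('d', '!') is never a merge candidate
pairs = set()
for chunk in chunks:
    pairs.update(zip(chunk, chunk[1:]))
print(("d", "!") in pairs)
```

Even though `d` and `!` sit next to each other in the raw string, they land in different chunks (`' world'` and `'!!!'`), so BPE never gets the chance to merge them.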

In [74]:
import regex as re
import tiktoken

In [75]:
# GPT-2 (does not merge spaces)
enc = tiktoken.get_encoding("gpt2")
print(enc.encode("    hello world!!!")) # 220 is the " " token

# GPT-4 (merges spaces)
enc = tiktoken.get_encoding("cl100k_base")
print(enc.encode("    hello world!!!"))

[220, 220, 220, 23748, 995, 10185]
[262, 24748, 1917, 12340]


Personally tho, I have no interest in doing all that the way they did, given my tiny dataset. I just want a tokenizer with more than 65 tokens to give my tiny test model a more realistic modeling experience, and TinyShakespeare doesn't have a whole lot of special characters anyways. So we're gonna enforce an over-simplified version of the regex rule, which takes advantage of the fact that, in sorted order, the first 13 of our 65 characters are non-letters (newline, space, punctuation/symbols, and the digit `3`) and the remaining 52 are letters.

# Actually Building It

We're gonna make it hella simple: start with our 65 unique characters and grow them into 128 total tokens, each made up of either only letters or only non-letters, never both.
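Before building it, a quick sanity check of the index-13 cutoff. This sketch rebuilds the 65-character list printed above by hand (an assumption: it should match `sorted(set(text))` for `input.txt`), confirms that the first 13 are exactly the non-letters, and expresses the "same origin" rule as a small hypothetical helper, `can_merge`:

```python
# The 65 unique TinyShakespeare characters, rebuilt by hand from the list printed above
symbols = "\n !$&',-.3:;?"
letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
all_chars = sorted(set(symbols + letters))

# the cutoff rule: indices 0-12 are symbols, 13-64 are letters
cutoff_origin = ["symbol" if i < 13 else "letter" for i in range(len(all_chars))]

# confirm the cutoff agrees with str.isalpha() for every character
assert len(all_chars) == 65
assert all((o == "letter") == c.isalpha() for c, o in zip(all_chars, cutoff_origin))

char_index = {ch: i for i, ch in enumerate(all_chars)}


def can_merge(a, b):
    # hypothetical helper: a pair is allowed only if both token ids share an origin
    return cutoff_origin[a] == cutoff_origin[b]


print(can_merge(char_index["t"], char_index["h"]))  # True: letter + letter
print(can_merge(char_index["n"], char_index["!"]))  # False: letter + symbol
print(can_merge(char_index[","], char_index[" "]))  # True: symbol + symbol
```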

In [1]:
# load the dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# the first 200 characters. It's just one continuous text document: a subset of Shakespeare's works concatenated back-to-back
print(text[:200])

# here are all the unique characters that occur in this text and how many there are
chars = sorted(list(set(text)))
v = len(chars)
print('\n', chars, v)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you

 ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 65


In [2]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
char_encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers

tokens = char_encode(text)

In [3]:
vocab_size = 128 # the desired final vocabulary size
num_merges = vocab_size - v
ids = list(tokens) # copy so we don't destroy the original list

In [4]:
base_indices = char_encode(chars)
print(base_indices)
origin = [ "symbol" if i < 13 else "letter" for i in base_indices]  # Track token origin
print(origin)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64]
['symbol', 'symbol', 'symbol', 'symbol', 'symbol', 'symbol', 'symbol', 'symbol', 'symbol', 'symbol', 'symbol', 'symbol', 'symbol', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter', 'letter']


In [5]:
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]): # Pythonic way to iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
  # in the list of ints (ids), replace all consecutive occurrences of pair with the new token idx
  newids = []
  i = 0
  while i < len(ids):
    # if we are not at the very last position AND the pair matches, replace it
    if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
      newids.append(idx)
      i += 2
    else:
      newids.append(ids[i])
      i += 1
  return newids

In [6]:
# now let's actually do it
merges = {} # (int, int) -> int
for i in range(num_merges):
    stats = get_stats(ids)

    # only keep pairs whose two tokens share an origin (letter+letter or symbol+symbol)
    valid = {pair: count for pair, count in stats.items() if origin[pair[0]] == origin[pair[1]]}
    if not valid:
        break  # no valid pairs left to merge

    pair = max(valid, key=valid.get)  # most frequent valid pair (ties go to the first one seen)

    idx = v + i
    print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx
    origin.append(origin[pair[0]])  # the new token inherits the shared origin of its two parts

merging (58, 46) into a new token 65
merging (6, 1) into a new token 66
merging (53, 59) into a new token 67
merging (43, 56) into a new token 68
merging (47, 52) into a new token 69
merging (39, 52) into a new token 70
merging (10, 0) into a new token 71
merging (65, 43) into a new token 72
merging (53, 56) into a new token 73
merging (47, 57) into a new token 74
merging (0, 0) into a new token 75
merging (43, 52) into a new token 76
merging (39, 56) into a new token 77
merging (39, 58) into a new token 78
merging (53, 52) into a new token 79
merging (57, 58) into a new token 80
merging (50, 50) into a new token 81
merging (6, 0) into a new token 82
merging (51, 43) into a new token 83
merging (58, 53) into a new token 84
merging (8, 75) into a new token 85
merging (70, 42) into a new token 86
merging (46, 43) into a new token 87
merging (63, 67) into a new token 88
merging (43, 57) into a new token 89
merging (52, 53) into a new token 90
merging (57, 43) into a new token 91
merging (

In [7]:
print("tokens length:", len(tokens)) # remember tokens are our original tokens
print("ids length:", len(ids)) # and ids are new tokens we've made
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")

tokens length: 1115394
ids length: 809324
compression ratio: 1.38X


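This is pretty ideal, because if we went for too large a compression ratio we'd shrink our dataset (measured in tokens) too much for it to be usable. Also, if we hadn't implemented our symbols-vs-letters rule, we'd likely end up with a higher compression ratio, since BPE would then be free to merge frequent cross-category pairs like a letter followed by a space.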
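To illustrate the point about the symbols-vs-letters rule and compression, here's a toy sketch (not a proof) that runs the same greedy BPE loop twice on the made-up string `"hi! hi! hi!"` for 2 merges, once with the same-origin rule and once without. The helper `train_bpe` is hypothetical, and it uses `str.isalpha()` instead of the index-13 cutoff to decide origin (the two agree for our 65 characters):

```python
def get_stats(ids):
    # count each consecutive pair of token ids
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge(ids, pair, idx):
    # replace every consecutive occurrence of pair with the new token idx
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


def train_bpe(text, num_merges, same_origin_only):
    # hypothetical helper: greedy character-level BPE, optionally with the letters-vs-symbols rule
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    ids = [stoi[c] for c in text]
    # isalpha() stands in for the index-13 cutoff used in the notebook
    origin = ["letter" if c.isalpha() else "symbol" for c in chars]
    for i in range(num_merges):
        stats = get_stats(ids)
        if same_origin_only:
            stats = {p: n for p, n in stats.items() if origin[p[0]] == origin[p[1]]}
        if not stats:
            break  # nothing left that we're allowed to merge
        pair = max(stats, key=stats.get)
        idx = len(chars) + i
        ids = merge(ids, pair, idx)
        origin.append(origin[pair[0]])
    return ids


toy_text = "hi! hi! hi!"
unconstrained = train_bpe(toy_text, 2, same_origin_only=False)
constrained = train_bpe(toy_text, 2, same_origin_only=True)
print(len(toy_text), len(unconstrained), len(constrained))
print(f"{len(toy_text) / len(unconstrained):.2f}X vs {len(toy_text) / len(constrained):.2f}X")
```

Both runs first merge `h` + `i`. Without the rule, the second merge glues `hi` onto `!` (a letter token plus a symbol), leaving 5 tokens (2.20X); with the rule that pair is off-limits, so it merges `!` + space instead and ends at 6 tokens (1.83X).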

In [8]:
import pickle
import os

# Ensure the tokenizers directory exists
os.makedirs('./tokenizers', exist_ok=True)

# Prepare the tokenizer data to be saved
tokenizer_data = {
    'stoi': stoi,  # Character to integer mapping
    'merges': merges  # Merges dictionary
}

# Save the tokenizer data using pickle
with open('./tokenizers/tokenizer.model', 'wb') as f:
    pickle.dump(tokenizer_data, f)

print("Tokenizer saved successfully.")

Tokenizer saved successfully.


In [17]:
# Load the tokenizer data using pickle
with open('./tokenizers/tokenizer.model', 'rb') as f:
    loaded_tokenizer_data = pickle.load(f)

# Extract the stoi mapping and merges from the loaded data
loaded_stoi = loaded_tokenizer_data['stoi']
loaded_merges = loaded_tokenizer_data['merges']

print("Tokenizer loaded successfully.")

Tokenizer loaded successfully.


In [18]:
print(stoi)
print(merges)

{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}
{(58, 46): 65, (6, 1): 66, (53, 59): 67, (43, 56): 68, (47, 52): 69, (39, 52): 70, (10, 0): 71, (65, 43): 72, (53, 56): 73, (47, 57): 74, (0, 0): 75, (43, 52): 76, (39, 56): 77, (39, 58): 78, (53, 52): 79, (57, 58): 80, (50, 50): 81, (6, 0): 82, (51, 43): 83, (58, 53): 84, (8, 75): 85, (70, 42): 86, (46, 43): 87, (63, 67): 88, (43, 57): 89, (52, 53): 90, (57, 43): 91, (46, 39): 92, (56, 43): 93, (53, 44): 94, (60, 43): 

In [24]:
class SimpleTokenizer:
    def __init__(self, stoi, merges):
        self.stoi = stoi
        self.merges = merges
        self.itos = {i: s for s, i in stoi.items()}  # Inverse mapping for decoding

        self.vocab_len = len(stoi) + len(merges)

    def encode(self, text):
        # Convert the text to a list of token IDs, using space for unknown characters
        tokens = [self.stoi.get(c, self.stoi[' ']) for c in text]

        # Perform merging with the possibility of nested merges
        i = 0
        while i < len(tokens) - 1:
            pair = (tokens[i], tokens[i + 1])
            if pair in self.merges:
                # Replace the current pair with its merged token
                merged_token = self.merges[pair]
                tokens[i] = merged_token
                del tokens[i + 1]

                # Move back to handle possible nested merges
                if i > 0:
                    i -= 1
            else:
                i += 1

        return tokens

    def decode(self, tokens):
        def expand_token(token):
            # Base case: if the token is a direct mapping, return its character
            if token in self.itos:
                return self.itos[token]
            # Recursive case: if the token is a merged token, expand its constituents
            elif token in self.merges.values():
                pair = next(key for key, value in self.merges.items() if value == token)
                return ''.join(expand_token(t) for t in pair)
            # Fallback for unknown tokens
            else:
                return ''

        # Decode each token in the list, handling nested merges recursively
        return ''.join(expand_token(token) for token in tokens)

# Example usage
# Assuming loaded_stoi and loaded_merges are already loaded from the tokenizer.model file

tokenizer = SimpleTokenizer(loaded_stoi, loaded_merges)

# Encoding text
encoded_text = tokenizer.encode("JULIET:\nO Romeo, Romeo! wherefore art thou R")
print("Encoded:", encoded_text)

# Decoding back
decoded_text = tokenizer.decode(encoded_text)
print("Decoded:", decoded_text)

Encoded: [22, 33, 24, 21, 17, 32, 71, 27, 1, 30, 53, 83, 53, 66, 30, 53, 83, 53, 2, 1, 61, 87, 93, 105, 43, 1, 77, 58, 1, 65, 67, 1, 30]
Decoded: JULIET:
O Romeo, Romeo! wherefore art thou R
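One thing worth flagging about `SimpleTokenizer.encode`: the greedy left-to-right loop merges whatever pair it hits first, not the pair that was learned first, so it can segment text differently from how the training loop segmented the same text. For example, `er` (68) was learned before `he` (87) and `re` (93), so the training ids never split "where" as w|he|re, yet that's exactly what the encoding of "wherefore" above contains (61, 87, 93, ...). Decoding still round-trips, but the model may see token sequences at inference time that it never saw in training. Here's a minimal sketch of an alternative (the class name `RankedTokenizer` is hypothetical) that applies merges in the order they were learned, the same way Karpathy's `encode` earlier does, and precomputes a vocab table so `decode` doesn't have to search `merges`. It's demonstrated on a made-up 5-character vocabulary:

```python
def merge(ids, pair, idx):
    # replace every consecutive occurrence of pair with idx (same as the notebook's merge)
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


class RankedTokenizer:
    # hypothetical variant of SimpleTokenizer that applies merges in learned order
    def __init__(self, stoi, merges):
        self.stoi = stoi
        self.merges = merges
        # build the string for every token id once; a merged id always refers to smaller ids
        self.vocab = {i: s for s, i in stoi.items()}
        for (p0, p1), idx in sorted(merges.items(), key=lambda item: item[1]):
            self.vocab[idx] = self.vocab[p0] + self.vocab[p1]

    def encode(self, text):
        # unknown characters fall back to the space token, like SimpleTokenizer
        tokens = [self.stoi.get(c, self.stoi[" "]) for c in text]
        while len(tokens) >= 2:
            pairs = set(zip(tokens, tokens[1:]))
            # the earliest-learned merge has the smallest id
            pair = min(pairs, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # nothing left to merge
            tokens = merge(tokens, pair, self.merges[pair])
        return tokens

    def decode(self, tokens):
        # unknown ids decode to '' like SimpleTokenizer's fallback
        return "".join(self.vocab.get(t, "") for t in tokens)


# made-up example: 'er' is learned before 'he' and 're'
toy_stoi = {" ": 0, "e": 1, "h": 2, "r": 3, "w": 4}
toy_merges = {(1, 3): 5, (2, 1): 6, (3, 1): 7}  # er, he, re
toy_tok = RankedTokenizer(toy_stoi, toy_merges)
toy_ids = toy_tok.encode("where")
print(toy_ids, repr(toy_tok.decode(toy_ids)))
```

This gives w|h|er|e (`[4, 2, 5, 1]`), whereas the greedy loop in `SimpleTokenizer.encode` would turn the same input into `[4, 6, 7]` (w|he|re). Both decode back to "where", but only the first matches how the merges were learned.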


In [25]:
for i in range(128):
    print(f"{i}: '{tokenizer.decode([i])}'")

0: '
'
1: ' '
2: '!'
3: '$'
4: '&'
5: '''
6: ','
7: '-'
8: '.'
9: '3'
10: ':'
11: ';'
12: '?'
13: 'A'
14: 'B'
15: 'C'
16: 'D'
17: 'E'
18: 'F'
19: 'G'
20: 'H'
21: 'I'
22: 'J'
23: 'K'
24: 'L'
25: 'M'
26: 'N'
27: 'O'
28: 'P'
29: 'Q'
30: 'R'
31: 'S'
32: 'T'
33: 'U'
34: 'V'
35: 'W'
36: 'X'
37: 'Y'
38: 'Z'
39: 'a'
40: 'b'
41: 'c'
42: 'd'
43: 'e'
44: 'f'
45: 'g'
46: 'h'
47: 'i'
48: 'j'
49: 'k'
50: 'l'
51: 'm'
52: 'n'
53: 'o'
54: 'p'
55: 'q'
56: 'r'
57: 's'
58: 't'
59: 'u'
60: 'v'
61: 'w'
62: 'x'
63: 'y'
64: 'z'
65: 'th'
66: ', '
67: 'ou'
68: 'er'
69: 'in'
70: 'an'
71: ':
'
72: 'the'
73: 'or'
74: 'is'
75: '

'
76: 'en'
77: 'ar'
78: 'at'
79: 'on'
80: 'st'
81: 'll'
82: ',
'
83: 'me'
84: 'to'
85: '.

'
86: 'and'
87: 'he'
88: 'you'
89: 'es'
90: 'no'
91: 'se'
92: 'ha'
93: 're'
94: 'of'
95: 've'
96: 'it'
97: 'ing'
98: 'be'
99: 'le'
100: 'wi'
101: 'my'
102: 'hi'
103: 'ow'
104: 'ce'
105: 'for'
106: 'ay'
107: 'as'
108: 'ch'
109: 'nd'
110: 'ere'
111: 'ld'
112: 'ir'
113: 'ed'
114: 'ut'
115: 'ro'
116: 'no