In [1]:
import pickle
import os
from tqdm import tqdm

In [2]:
# Pull a single batch of TinyStories training stories; the merges below are learned from it.
from tools import get_data_loader
data_loader = get_data_loader(split='train', batch_size=5_000)

batch = next(iter(data_loader))
print(f"{len(batch)} {batch[0]}")

Found cached dataset json (/Users/tunadorable/.cache/huggingface/datasets/noanabeshima___json/noanabeshima--TinyStoriesV2-226173b7dd235c68/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


5000 Once upon a time, there was a smelly dog named Max. Max did not like being smelly, so he tried to find a way to not be smelly. He went to a big tree and found a staff. The staff could talk and help him.
"Hello, I am a magic staff," the staff said to Max. "I can explain how to not be smelly." Max was very happy to hear this. He asked the staff what he needed to do. The staff told him that he needed to take a bath in the river.
Max went to the river and jumped in. He splashed and played in the water. The staff watched and smiled. When Max was done, he came out of the river. He was not smelly anymore.
Max thanked the staff for helping him. They became best friends and went on many fun adventures together. And Max was never smelly again.


In [3]:
# Join the stories into one training string, separating them with a blank line.
combined_string = '\n\n'.join(batch)

# Fixed character vocabulary: the largest set of distinct characters found earlier
# in a 1_000_000-story sample (95 characters), i.e. sorted(set(text)) over that sample.
# Using a fixed list means the character ids do not depend on which batch was sampled.
chars = ['\t', '\n', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~']
v = len(chars)

# Fail fast if this batch contains a character the fixed vocabulary lacks;
# otherwise encoding would raise a bare KeyError much later, during training.
unknown_chars = set(combined_string) - set(chars)
assert not unknown_chars, f"characters missing from `chars`: {sorted(unknown_chars)}"

print('\n', chars, v)


 ['\t', '\n', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~'] 95


# BPE tokenization
Rather than true byte-pair encoding, we'll do character-pair encoding: the merges build on our base set of characters instead of on raw bytes.

In [4]:
# Map each vocabulary character to its integer id (its position in `chars`).
stoi = {ch: i for i, ch in enumerate(chars)}


def char_encode(s):
    """Encode string ``s`` as a list of character ids using ``stoi``.

    Raises KeyError if ``s`` contains a character that is not in ``chars``.
    """
    return [stoi[c] for c in s]

In [5]:
vocab_size = 2_048  # target size of the final vocabulary: base characters + merged tokens
num_merges = vocab_size - v  # every merge mints exactly one new token id

In [6]:
# Base token ids, one per character. Byte-level BPE starts from raw byte values;
# here the starting alphabet is simply each character's index in `chars`.
base_indices = [stoi[ch] for ch in chars]
print(base_indices)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94]


In [7]:
def classify_char(c):
    """Return a one-element list naming the category of character ``c``.

    Categories: "upper_letter", "lower_letter", "digit", "space" (only ' '),
    "apostrophe" (only "'"), and "symbol" for everything else, including
    tab and newline. A list is returned so merged tokens can carry several
    categories.
    """
    if c.isupper():
        category = "upper_letter"
    elif c.islower():
        category = "lower_letter"
    elif c.isdigit():
        category = "digit"
    else:
        # The two single-character special cases; anything else is a symbol.
        category = {" ": "space", "'": "apostrophe"}.get(c, "symbol")
    return [category]

# Category list for every base token id; the training loop appends one entry per merge.
origin = list(map(classify_char, chars))
print(origin)

[['symbol'], ['symbol'], ['space'], ['symbol'], ['symbol'], ['symbol'], ['symbol'], ['symbol'], ['symbol'], ['apostrophe'], ['symbol'], ['symbol'], ['symbol'], ['symbol'], ['symbol'], ['symbol'], ['symbol'], ['symbol'], ['digit'], ['digit'], ['digit'], ['digit'], ['digit'], ['digit'], ['digit'], ['digit'], ['digit'], ['digit'], ['symbol'], ['symbol'], ['symbol'], ['symbol'], ['symbol'], ['symbol'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['upper_letter'], ['symbol'], ['symbol'], ['symbol'], ['symbol'], ['symbol'], ['lower_letter'], ['lower_letter'], ['lower_letter'], ['lower_letter']

In [8]:
def can_merge(a, b):
    # Define merge rules
    rules = {
        "upper_letter": {"upper_letter", "lower_letter", "space", "apostrophe"},
        "lower_letter": {"upper_letter", "lower_letter", "space", "apostrophe"},
        "digit": {"digit", "space"}, 
        "symbol": {"symbol", "space"}, 
        "space": {"symbol", "digit", "upper_letter", "lower_letter", "apostrophe"},
        "apostrophe": {"upper_letter", "lower_letter", "space"}
    }

    # Check if all elements in a can merge with all elements in b
    for x in a:
        for y in b:
            if y not in rules[x]:
                return False
    return True

def test_can_merge():
    """Check can_merge against hand-written cases, printing one line per passing case."""
    # Each case: (categories of left token, categories of right token, expected result).
    cases = [
        (["lower_letter", "upper_letter"], ["space"], True),
        (["lower_letter", "upper_letter"], ["digit"], False),
        (["space", "symbol"], ["space"], False),
        (["space"], ["space"], False),
        (["space", "apostrophe", "lower_letter"], ["upper_letter"], True),
        (["space", "digit"], ["symbol"], False),
        (["symbol", "space"], ["upper_letter"], False),
        (["symbol"], ["upper_letter", "space"], False),
        # Add more test cases as needed
    ]

    for left, right, want in cases:
        got = can_merge(left, right)
        assert got == want, f"Test with a={left} and b={right} failed. Expected {want}, got {got}"
        print(f"Test with a={left} and b={right} passed as expected.")


# Run the checks right away so a broken rule table is caught before training.
test_can_merge()


Test with a=['lower_letter', 'upper_letter'] and b=['space'] passed as expected.
Test with a=['lower_letter', 'upper_letter'] and b=['digit'] passed as expected.
Test with a=['space', 'symbol'] and b=['space'] passed as expected.
Test with a=['space'] and b=['space'] passed as expected.
Test with a=['space', 'apostrophe', 'lower_letter'] and b=['upper_letter'] passed as expected.
Test with a=['space', 'digit'] and b=['symbol'] passed as expected.
Test with a=['symbol', 'space'] and b=['upper_letter'] passed as expected.
Test with a=['symbol'] and b=['upper_letter', 'space'] passed as expected.


In [9]:
def get_stats(ids):
    """Count occurrences of each adjacent pair of ids in the sequence ``ids``.

    Returns a dict mapping (left_id, right_id) -> count. Keys appear in order of
    first occurrence, which determines how max() breaks ties between equally
    frequent pairs.
    """
    pair_counts = {}
    for neighbour_pair in zip(ids, ids[1:]):
        if neighbour_pair in pair_counts:
            pair_counts[neighbour_pair] += 1
        else:
            pair_counts[neighbour_pair] = 1
    return pair_counts

def merge(ids, pair, idx):
    """Return a new list where each occurrence of ``pair`` in ``ids`` becomes ``idx``.

    ``ids`` is scanned left to right and matches do not overlap, so
    (1, 1) in [1, 1, 1] is replaced once. ``ids`` itself is not modified.
    """
    merged = []
    last = len(ids) - 1
    pos = 0
    while pos <= last:
        current = ids[pos]
        if pos < last and current == pair[0] and ids[pos + 1] == pair[1]:
            merged.append(idx)
            pos += 2
            continue
        merged.append(current)
        pos += 1
    return merged

In [10]:
# Learn up to `num_merges` merges greedily. At each step, merge the most frequent
# adjacent pair whose category sets are compatible according to can_merge.
tokens = char_encode(combined_string)
ids = list(tokens)  # copy so we don't destroy the original list

# `origin` is extended in place below (one entry per new token id). Truncate it back
# to the base characters first. Otherwise, re-running this cell without re-running the
# classify_char cell would leave stale entries that misalign the new token ids.
del origin[v:]

merges = {}  # (int, int) -> int
for i in tqdm(range(num_merges)):
    stats = get_stats(ids)

    # Keep only pairs whose left/right category sets may merge. Every id in `ids`
    # is < len(origin) by construction, so no bounds check is needed. Iterating
    # `stats` in insertion order keeps max()'s tie-breaking unchanged.
    valid_pairs = {
        pair: freq
        for pair, freq in stats.items()
        if can_merge(origin[pair[0]], origin[pair[1]])
    }
    if not valid_pairs:
        break  # No more valid pairs to merge

    pair = max(valid_pairs, key=valid_pairs.get)
    idx = v + i  # new token ids follow the base characters sequentially
    assert len(origin) == idx, "origin must have exactly one entry per existing token id"
    ids = merge(ids, pair, idx)
    merges[pair] = idx
    # A merged token inherits the union of its parts' categories (sorted for reproducibility).
    origin.append(sorted(set(origin[pair[0]]) | set(origin[pair[1]])))

100%|████████████████████████| 1953/1953 [09:30<00:00,  3.42it/s]


In [11]:
# Compare sequence lengths before (character ids) and after (merged token ids) training.
print(f"tokens length: {len(tokens)}")
print(f"ids length: {len(ids)}")
print("compression ratio: {:.2f}X".format(len(tokens) / len(ids)))

tokens length: 4015765
ids length: 1101826
compression ratio: 3.64X


In [22]:
# Make sure the ./models output directory exists (no-op if it already does).
os.makedirs('./models', exist_ok=True)

# Everything needed to rebuild the tokenizer: the character -> id map and the
# learned merges, (left_id, right_id) -> new_id, in the order they were learned.
tokenizer_data = {
    'stoi': stoi,  # Character to integer mapping
    'merges': merges,  # Merges dictionary
}

# Save the tokenizer data with pickle; the file name records the vocabulary size.
with open(f'./models/{vocab_size}.model', 'wb') as f:
    pickle.dump(tokenizer_data, f)

In [None]:
# Trim the trained tokenizer down to a smaller vocabulary. Merge ids are assigned
# sequentially and each merge only refers to ids created before it, so keeping the
# merges whose new id is < vocab_size yields a self-consistent smaller tokenizer.
# This cell and the save cell below were re-run with successively smaller sizes
# (originally via vocab_size // 2). 95 == len(chars) keeps only the base characters.
vocab_size = 95
assert vocab_size >= v, f"vocab_size must be at least the {v} base characters"
merges = {pair: new_id for pair, new_id in merges.items() if new_id < vocab_size}

In [None]:
# Make sure the ./models output directory exists (no-op if it already does).
os.makedirs('./models', exist_ok=True)

# Everything needed to rebuild the tokenizer at the current `vocab_size`: the
# character -> id map and the (possibly trimmed) merges, (left_id, right_id) -> new_id.
tokenizer_data = {
    'stoi': stoi,  # Character to integer mapping
    'merges': merges,  # Merges dictionary
}

# Save the tokenizer data with pickle; the file name records the vocabulary size.
with open(f'./models/{vocab_size}.model', 'wb') as f:
    pickle.dump(tokenizer_data, f)

In [10]:
# Load a tokenizer through the project's `tokenizer` module; presumably this reads
# ./models/2048.model written by the save cell above (TODO confirm in tokenizer.py).
# NOTE(review): the size is hard-coded rather than taken from `vocab_size`, which the
# trimming cell may have changed; confirm which tokenizer is intended here.
from tokenizer import get_tokenizer
tokenizer = get_tokenizer(size=2048)

In [14]:
# List every token id with its decoded text. The !r conversion makes whitespace and
# quote tokens (tab, newline, space, apostrophe) visible instead of printing them raw.
# NOTE(review): iterates range(vocab_size) while the tokenizer was loaded with
# size=2048; confirm the two agree.
for token_id in range(vocab_size):
    print(f"{token_id}: {tokenizer.decode([token_id])!r}")

0: '	'
1: '
'
2: ' '
3: '!'
4: '"'
5: '#'
6: '$'
7: '%'
8: '&'
9: '''
10: '('
11: ')'
12: '*'
13: '+'
14: ','
15: '-'
16: '.'
17: '/'
18: '0'
19: '1'
20: '2'
21: '3'
22: '4'
23: '5'
24: '6'
25: '7'
26: '8'
27: '9'
28: ':'
29: ';'
30: '<'
31: '='
32: '>'
33: '?'
34: 'A'
35: 'B'
36: 'C'
37: 'D'
38: 'E'
39: 'F'
40: 'G'
41: 'H'
42: 'I'
43: 'J'
44: 'K'
45: 'L'
46: 'M'
47: 'N'
48: 'O'
49: 'P'
50: 'Q'
51: 'R'
52: 'S'
53: 'T'
54: 'U'
55: 'V'
56: 'W'
57: 'X'
58: 'Y'
59: 'Z'
60: '['
61: '\'
62: ']'
63: '_'
64: '`'
65: 'a'
66: 'b'
67: 'c'
68: 'd'
69: 'e'
70: 'f'
71: 'g'
72: 'h'
73: 'i'
74: 'j'
75: 'k'
76: 'l'
77: 'm'
78: 'n'
79: 'o'
80: 'p'
81: 'q'
82: 'r'
83: 's'
84: 't'
85: 'u'
86: 'v'
87: 'w'
88: 'x'
89: 'y'
90: 'z'
91: '{'
92: '|'
93: '}'
94: '~'
95: 'e '
96: 'd '
97: 'th'
98: ' a'
99: '. '
100: 't '
101: 'y '
102: 's '
103: 'nd '
104: 'to'
105: 'er'
106: 'ed '
107: 'the '
108: ', '
109: 'wa'
110: 'in'
111: 'he '
112: 'to '
113: 'ou'
114: 'en'
115: 'ha'
116: 'om'
117: 'sa'
118: 'ar'
119: '.
'