In [None]:
"""
Train and save the tokenizer
"""

import os
import pickle
import sys

# Root of the project checkout that provides the local `data` package.
# Defaults to the original hard-coded location; export POSITIONAL_GPT2_ROOT
# to run this cell from a different machine or checkout.
PROJECT_ROOT = (
    os.environ.get("POSITIONAL_GPT2_ROOT")
    or "/Users/benjawesome/coding/positional-gpt-2"
)
# Re-running the cell must not pile up duplicate sys.path entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# These imports must come after the sys.path tweak so `data` can be resolved.
from data.byte_pair_encoding import BytePairTokenizer
from datasets import load_dataset

# The pickled tokenizer state is written to ../files (relative to the
# current working directory); create the directory if it is missing.
save_dir = os.path.join("..", "files")
os.makedirs(save_dir, exist_ok=True)

save_path = os.path.join(save_dir, "my_object.pkl")

print(save_path)

raw_datasets = load_dataset("wikitext", "wikitext-2-raw-v1")

# get our data: the "text" column of the train and validation splits
train_text_list = raw_datasets["train"]["text"]
val_text_list = raw_datasets["validation"]["text"]

# turn the data into single newline-joined strings for input
full_training_corpus = "\n".join(train_text_list)
val_training_corpus = "\n".join(val_text_list)

# create tokenizer and train; only the first 20000 characters of the
# training corpus are handed to the tokenizer.
tokenizer = BytePairTokenizer(full_training_corpus[:20000])
# NOTE(review): the meaning of the argument is defined by
# BytePairTokenizer.train -- presumably the number of merges; confirm.
tokenizer.train(5)

# save the trained tokenizer as a plain dict of its three state attributes
tokenizer_state = {
    'merges': tokenizer.merges,
    'token_to_id': tokenizer.token_ids,
    'byte_to_char': tokenizer.byte_to_char
}

with open(save_path, 'wb') as f:
    pickle.dump(tokenizer_state, f)

# test encoding on the first 1000 characters of the validation corpus
encoded = tokenizer.encode(val_training_corpus[:1000])

# Show the ID stored for "," (None if the vocabulary has no such entry).
print(tokenizer.token_ids.get(","))

print(encoded)

../files/my_object.pkl
11
11
11
11
11
11
11
11
11
11
11
11
[199, 221, 28, 221, 39, 78, 76, 64, 81, 84, 257, 70, 64, 76, 76, 64, 81, 84, 257, 28, 221, 199, 199, 199, 221, 39, 78, 76, 64, 81, 84, 257, 70, 64, 76, 76, 64, 81, 84, 257, 11, 221, 74, 77, 78, 86, 260, 64, 257, 258, 256, 36, 84, 81, 78, 79, 68, 64, 260, 75, 78, 65, 82, 83, 68, 81, 221, 78, 81, 221, 66, 78, 76, 76, 78, 260, 75, 78, 65, 82, 83, 68, 81, 221, 11, 221, 72, 257, 64, 221, 82, 79, 68, 66, 72, 68, 257, 78, 69, 221, 66, 75, 64, 86, 68, 259, 75, 78, 65, 82, 83, 68, 81, 221, 69, 81, 78, 76, 221, 258, 256, 68, 64, 82, 83, 68, 81, 260, 32, 83, 75, 64, 77, 83, 72, 66, 221, 46, 66, 68, 64, 260, 11, 221, 44, 68, 67, 72, 83, 68, 81, 81, 64, 77, 68, 64, 260, 50, 68, 64, 221, 64, 77, 259, 79, 64, 81, 83, 257, 78, 69, 221, 258, 256, 33, 75, 64, 66, 74, 221, 50, 68, 64, 221, 13, 221, 40, 83, 221, 72, 257, 66, 75, 78, 82, 68, 75, 88, 221, 81, 68, 75, 64, 83, 68, 259, 83, 78, 221, 258, 256, 32, 76, 68, 81, 72, 66, 64, 260, 75, 78, 65

In [None]:
"""
Load and test the tokenizer object
"""

import os
import pickle
import sys

# Root of the project checkout that provides the local `data` package.
# Defaults to the original hard-coded location; export POSITIONAL_GPT2_ROOT
# to run this cell from a different machine or checkout.
PROJECT_ROOT = (
    os.environ.get("POSITIONAL_GPT2_ROOT")
    or "/Users/benjawesome/coding/positional-gpt-2"
)
# Re-running the cell must not pile up duplicate sys.path entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# These imports must come after the sys.path tweak so `data` can be resolved.
from data.byte_pair_encoding import BytePairTokenizer
from datasets import load_dataset

save_dir = os.path.join("..", "files")
os.makedirs(save_dir, exist_ok=True)

save_path = os.path.join(save_dir, "my_object.pkl")

# WARNING: pickle.load can execute arbitrary code embedded in the file.
# Only load pickles that were produced by this project.
with open(save_path, 'rb') as f:
    tokenizer_state = pickle.load(f)

# Token -> ID mapping exactly as stored in the pickle (kept for comparison
# with the mapping held by the loaded tokenizer object).
token_ids = tokenizer_state.get('token_to_id')

tokenizer = BytePairTokenizer.load(save_path)

print("comma id: " + str(tokenizer.token_ids.get(",")))

raw_datasets = load_dataset("wikitext", "wikitext-2-raw-v1")

# get our data: the "text" column of the train and validation splits
train_text_list = raw_datasets["train"]["text"]
val_text_list = raw_datasets["validation"]["text"]

# turn the data into single newline-joined strings for input
full_training_corpus = "\n".join(train_text_list)
val_training_corpus = "\n".join(val_text_list)

print(token_ids.get(","))

# The tokenizer did not handle punctuation correctly: some punctuation
# (and the space character) may have no entry in the vocabulary.
# Fix the missing punctuation by assigning fresh IDs to those tokens.

MISSING_TOKENS = [",", ".", "!", "?", " "]

# New IDs start at 5000 as before, but never at or below an ID that is
# already in use. Without this, re-running the cell on a pickle that
# already contains some of the patched tokens (or a vocabulary that
# reaches 5000) would hand out an ID that is already taken.
# NOTE(review): assumes tokenizer.token_ids is a dict of token -> int ID.
taken_ids = [tid for tid in tokenizer.token_ids.values() if isinstance(tid, int)]
next_available_id = max(5000, max(taken_ids, default=-1) + 1)

for token_string in MISSING_TOKENS:
    if token_string not in tokenizer.token_ids:
        tokenizer.token_ids[token_string] = next_available_id
        next_available_id += 1
        print(f"Assigned ID {tokenizer.token_ids[token_string]} to token: '{token_string}'")

# Persist the patched state back to the same file it was loaded from.
tokenizer_state = {
    'merges': tokenizer.merges,
    'token_to_id': tokenizer.token_ids,
    'byte_to_char': tokenizer.byte_to_char
}

with open(save_path, 'wb') as f:
    pickle.dump(tokenizer_state, f)

print(tokenizer.token_ids.get(","))

# test encoding on the first 1000 characters of the validation corpus
encoded = tokenizer.encode(val_training_corpus[:1000])

print(encoded)


comma id: 5000




5000
5000
H
H
l
l
l
M
l
H
g
c
r
l
u
l
u
l
c
g
h
v
e
c
l
p
H
c
w
r
l
l
c
l
c
@
None
[373, 298, None, 315, 271, 630, 1041, 748, 630, 1649, None, 315, 271, 630, 1041, 748, 2712, 3198, 2684, None, 597, 2833, 312, 1998, None, 597, 290, 588, 919, 1323, 278, 407, 585, 268, None, 597, 2833, 612, 2552, 2835, 1174, 1709, 266, None, 3152, 261, 535, 1709, 2081, 1188, 448, 3777, 3226, 2081, 279, 2296, 1744, 584, 4275, 418, 1113, None, 597, 290, 588, None, 270, 318, 261, 294, 265, 630, 709, 1110, None, 2111, 1159, 2813, 3195, 2262, None, 1861, 2202, 2574, 1455, 3322, 542, 517, 3781, 502, None, 2921, 362, 1811, None, 926, 2655, 358, 271, 1702, 359, 420, 294, None, 618, 4464, 278, 407, 585, 257, 540, 1070, 1200, None, 597, 290, 390, 455, 505, None, 370, 736, 2376, 2135, None, 597, 2833, 1153, 308, 284, None, 419, 830, 844, 885, 2016, 3349, 4674, 588, 852, 286, 670, None, 257, 456, 455, 2514, 1049, 2155, 2873, 351, 706, 1159, 1096, 915, None, 292, 2732, 848, 340, 1814, 3094, 434, 744, None, 3959, 487, 