In [2]:
from datasets import load_dataset

In [3]:
# Load the TinyStories dataset from the Hugging Face Hub (downloaded and cached on
# first use) and show the first training example, a dict with a 'text' field.
ds = load_dataset("roneneldan/TinyStories")
ds["train"][0]

{'text': 'One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.\n\nLily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."\n\nTogether, they shared the needle and sewed the button on Lily\'s shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.'}

In [4]:
# Slicing rows first returns a dict of columns, so ['text'] gives the first two stories as a list of str.
ds['train'][:2]['text']

['One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.\n\nLily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."\n\nTogether, they shared the needle and sewed the button on Lily\'s shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.',
 'Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong.\n\nOne day, Beep was driving in the park when he saw a big tree. The tree had many leaves that we

In [5]:
# Training corpus: the first 5 stories concatenated into one string.
# NOTE: stories are joined with no separator, so the last character of one story
# ends up adjacent to the first character of the next; those cross-story pairs
# are counted by get_stats like any other pair.
text = ''.join(ds['train'][:5]['text'])
# Iterating over a `bytes` object already yields ints in range [0, 255].
tokens = list(text.encode('utf-8'))
# The two lengths differ whenever the text has non-ASCII characters (multi-byte
# UTF-8 sequences); equal lengths mean this slice is pure ASCII.
len(text), len(tokens)

(3711, 3711)

In [6]:
def get_stats(ids):
    """Count how often each pair of adjacent ids occurs.

    Args:
        ids: sliceable sequence of token ids (ints in this notebook).

    Returns:
        dict mapping ``(left_id, right_id)`` -> occurrence count, with keys in
        order of first appearance. Empty when ``ids`` has fewer than 2 items.
        Overlapping pairs are all counted, e.g. ``[a, a, a]`` gives ``{(a, a): 2}``.
    """
    pair_counts = {}
    for left, right in zip(ids[:-1], ids[1:]):
        key = (left, right)
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

# Adjacent-pair frequencies over the raw byte sequence of the corpus.
stats = get_stats(tokens)

In [7]:
# Most frequent adjacent pair; on a tie, the pair counted first wins.
top_pair = max(stats.items(), key=lambda item: item[1])[0]
top_pair

(101, 32)

In [8]:
# Render the two byte values of the top pair as characters.
tuple(chr(byte) for byte in top_pair)

('e', ' ')

In [9]:
def merge_pair(tokens, pair_to_merge, new_idx):
    """Replace every non-overlapping occurrence of ``pair_to_merge`` with ``new_idx``.

    The scan goes left to right, so for a run like ``[a, a, a]`` merging
    ``(a, a)`` yields ``[new_idx, a]``.

    Args:
        tokens: list of token ids; it is not modified.
        pair_to_merge: ``(left, right)`` tuple to replace. It is compared with
            ``==`` against a tuple, so a list with the same values never matches.
        new_idx: id emitted in place of each matched pair.

    Returns:
        A new list of token ids.
    """
    result = []
    last = len(tokens) - 1
    pos = 0
    while pos <= last:
        if pos < last and (tokens[pos], tokens[pos + 1]) == pair_to_merge:
            result.append(new_idx)
            pos += 2
            continue
        result.append(tokens[pos])
        pos += 1
    return result


In [10]:
# Sanity check: merge the single most frequent pair into id 256, the first id after
# the 256 raw byte values. The training loop below repeats this as its first merge;
# this result does not appear to be used by later visible cells.
updated_token_ids = merge_pair(tokens, top_pair, 256)


In [11]:
# Target vocabulary: the 256 raw byte values plus one new id per learned merge.
final_vocab_size = 276
num_merges = final_vocab_size - 256
token_ids = list(tokens)  # copy so the raw byte sequence in `tokens` stays intact

# merges maps (left_id, right_id) -> new_id, filled in the order merges are learned
# (later cells rely on that order when building the vocab).
merges = {}
for i in range(num_merges):
    stats = get_stats(token_ids)
    if not stats:
        # The sequence has collapsed to fewer than 2 ids, so there is nothing left
        # to merge; max() on an empty dict would raise ValueError.
        print(f"stopping early after {i} merges: no pairs left")
        break
    pair = max(stats, key=stats.get)  # ties go to the pair seen first
    idx = 256 + i
    print(f"merging {pair} into new token {idx}")
    token_ids = merge_pair(token_ids, pair, idx)
    merges[pair] = idx

merging (101, 32) into new token 256
merging (32, 116) into new token 257
merging (100, 32) into new token 258
merging (104, 256) into new token 259
merging (104, 101) into new token 260
merging (32, 97) into new token 261
merging (46, 32) into new token 262
merging (105, 110) into new token 263
merging (257, 259) into new token 264
merging (110, 258) into new token 265
merging (116, 32) into new token 266
merging (119, 97) into new token 267
merging (101, 101) into new token 268
merging (260, 114) into new token 269
merging (121, 32) into new token 270
merging (115, 32) into new token 271
merging (104, 97) into new token 272
merging (44, 32) into new token 273
merging (111, 32) into new token 274
merging (101, 114) into new token 275


In [12]:
# Compare raw byte count with the id count after all merges.
print(f"Token length :  {len(tokens)}")
print(f"New ids length :  {len(token_ids)}")
print(f"Compression ratio = {len(tokens) / len(token_ids):.2f}X")

Token length :  3711
New ids length :  2733
Compression ratio = 1.36X


In [46]:
# Base vocabulary: ids 0-255 each map to their single raw byte.
vocab = dict((b, bytes([b])) for b in range(256))
# Every merged id expands to the bytes of its two parts. This relies on `merges`
# being in learning order, so both parts are already present in `vocab`.
for (p0, p1), idx in merges.items():
    vocab[idx] = b"".join((vocab[p0], vocab[p1]))

def decode(ids):
    """Turn a sequence of token ids back into a string.

    Each id is expanded to its byte string via the module-level ``vocab``. The
    concatenated bytes are then decoded as UTF-8. Invalid byte sequences
    (e.g. ids that end in the middle of a multi-byte character) become U+FFFD
    instead of raising.

    Raises:
        KeyError: if an id is not present in ``vocab``.
    """
    raw_bytes = bytes().join(vocab[token_id] for token_id in ids)
    return raw_bytes.decode(encoding="utf-8", errors="replace")


# Ids below 256 are raw bytes: 77 -> 'M', 55 -> '7'.
print(decode([77, 55]))

M7


In [47]:
def encode(text):
    """Tokenize ``text`` into ids by replaying the learned merges.

    The text is first turned into its UTF-8 byte values. Then, repeatedly, the
    adjacent pair with the lowest merge id in ``merges`` (i.e. the merge learned
    earliest) is merged. This continues until no adjacent pair has a learned
    merge, or only one id is left.

    Returns:
        list of token ids.
    """
    ids = list(text.encode("utf-8"))
    while len(ids) > 1:
        candidates = get_stats(ids)
        # Pairs without a learned merge rank as +inf, so they are picked only
        # when no mergeable pair remains.
        best = min(candidates, key=lambda p: merges.get(p, float("inf")))
        if best not in merges:
            break  # no remaining pair has a learned merge
        ids = merge_pair(ids, best, merges[best])
    return ids

In [48]:
# Encode a short sample string and decode it back.
text = "hey, how are you"
encoded = encode(text)
print(encoded)
# Invariant: merges only concatenate bytes, so decoding must reproduce the input exactly.
assert decode(encoded) == text, "encode/decode round-trip failed"
decode(encoded)

[260, 121, 273, 104, 111, 119, 261, 114, 256, 121, 111, 117]


'hey, how are you'