In [1]:
from itertools import islice
from collections import defaultdict

import pandas as pd
import numpy as np

## Corpus

In [2]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("manashjyotiborah/top-10000-movies-hosted-on-tmdb")

Downloading from https://www.kaggle.com/api/v1/datasets/download/manashjyotiborah/top-10000-movies-hosted-on-tmdb?dataset_version_number=2...


100%|██████████| 16.9M/16.9M [00:00<00:00, 149MB/s]

Extracting files...





In [3]:
overview_df = pd.read_csv(f'{path}/movies_dataset.csv', index_col='id')
overview_df.dropna(axis=0, inplace=True)

corpus = overview_df['overview'].to_list()
corpus[:3]

["Over many missions and against impossible odds, Dom Toretto and his family have outsmarted, out-nerved and outdriven every foe in their path. Now, they confront the most lethal opponent they've ever faced: A terrifying threat emerging from the shadows of the past who's fueled by blood revenge, and who is determined to shatter this family and destroy everything—and everyone—that Dom loves, forever.",
 "Tasked with extracting a family who is at the mercy of a Georgian gangster, Tyler Rake infiltrates one of the world's deadliest prisons in order to save them. But when the extraction gets hot, and the gangster dies in the heat of battle, his equally ruthless brother tracks down Rake and his team to Sydney, in order to get revenge.",
 'With the price on his head ever increasing, John Wick uncovers a path to defeating The High Table. But before he can earn his freedom, Wick must face off against a new enemy with powerful alliances across the globe and forces that turn old friends into foe

## Tokenization

Source: https://github.com/vukrosic/courses/tree/main/llama4

In [30]:
# Helpers

def get_slice(d:dict, n:int)->list:
    return [(key, val) for key, val in zip(range(n), d.items())]


In [4]:
end_of_word = '/w'

In [5]:
unique_chars = set()
for overview in corpus:
    for char in overview:
        unique_chars.add(char)

vocab = list(unique_chars)
vocab.sort()
vocab.append(end_of_word)

vocab[:10], len(vocab)

(['\r', ' ', '!', '"', '#', '$', '%', '&', "'", '('], 129)

In [8]:
word_splits = defaultdict(int)
for doc in corpus:
    for word in doc.split(' '):
        if word:
            word_tuple = tuple(list(word) + [end_of_word])
            word_splits[word_tuple] += 1

list(islice(word_splits, 10))

[('O', 'v', 'e', 'r', '/w'),
 ('m', 'a', 'n', 'y', '/w'),
 ('m', 'i', 's', 's', 'i', 'o', 'n', 's', '/w'),
 ('a', 'n', 'd', '/w'),
 ('a', 'g', 'a', 'i', 'n', 's', 't', '/w'),
 ('i', 'm', 'p', 'o', 's', 's', 'i', 'b', 'l', 'e', '/w'),
 ('o', 'd', 'd', 's', ',', '/w'),
 ('D', 'o', 'm', '/w'),
 ('T', 'o', 'r', 'e', 't', 't', 'o', '/w'),
 ('h', 'i', 's', '/w')]

In [26]:
get_slice(word_splits, 10)

[(0, (('O', 'v', 'e', 'r', '/w'), 17)),
 (1, (('m', 'a', 'n', 'y', '/w'), 104)),
 (2, (('m', 'i', 's', 's', 'i', 'o', 'n', 's', '/w'), 14)),
 (3, (('a', 'n', 'd', '/w'), 13235)),
 (4, (('a', 'g', 'a', 'i', 'n', 's', 't', '/w'), 470)),
 (5, (('i', 'm', 'p', 'o', 's', 's', 'i', 'b', 'l', 'e', '/w'), 48)),
 (6, (('o', 'd', 'd', 's', ',', '/w'), 17)),
 (7, (('D', 'o', 'm', '/w'), 7)),
 (8, (('T', 'o', 'r', 'e', 't', 't', 'o', '/w'), 4)),
 (9, (('h', 'i', 's', '/w'), 6924))]

In [27]:
def get_pair_stats(splits:dict)->list:
    pair_counts = defaultdict(int)
    for word_tuple, freq in splits.items():
        symbols = list(word_tuple)
        for i in range(len(symbols) - 1):
            pair = (symbols[i], symbols[i+1])
            pair_counts[pair] += freq
    return pair_counts

pair_stats = get_pair_stats(word_splits)

get_slice(pair_stats, 10)

[(0, (('O', 'v'), 30)),
 (1, (('v', 'e'), 15729)),
 (2, (('e', 'r'), 39971)),
 (3, (('r', '/w'), 28195)),
 (4, (('m', 'a'), 9653)),
 (5, (('a', 'n'), 37733)),
 (6, (('n', 'y'), 1236)),
 (7, (('y', '/w'), 22819)),
 (8, (('m', 'i'), 5564)),
 (9, (('i', 's'), 23459))]

In [None]:
def merge_pair(pair_to_merge:tuple, splits:dict)->list:
    new_splits = {}
    (first, second) = pair_to_merge
    merged_token = first + second
    for word_tuple, freq in splits.items():
        symbols = list(word_tuple)
        new_symbols = []
        i = 0
        while i < len(symbols):
            if i < len(symbols) - 1 and symbols[i] == first and symbols[i+1] == second:
                new_symbols.append(merged_token)
                i += 2
            else:
                new_symbols.append(symbols[i])
                i += 1
        new_splits[tuple(new_symbols)] = freq
    return new_splits


In [None]:
num_merges = 15
# Stores merge rules, e.g., {('a', 'b'): 'ab'}
# Example: {('T', 'h'): 'Th'}
merges = {}
# Initial word splits: {('T', 'h', 'i', 's', '</w>'): 2, ('i', 's', '</w>'): 2, ...}
current_splits = word_splits.copy() # Start with initial word splits

print("\n--- Starting BPE Merges ---")
print('Initial Splits:')
print(*list(islice(current_splits, 10)), sep='\n')
print("-" * 30)


--- Starting BPE Merges ---
Initial Splits:
('O', 'v', 'e', 'r', '/w')
('m', 'a', 'n', 'y', '/w')
('m', 'i', 's', 's', 'i', 'o', 'n', 's', '/w')
('a', 'n', 'd', '/w')
('a', 'g', 'a', 'i', 'n', 's', 't', '/w')
('i', 'm', 'p', 'o', 's', 's', 'i', 'b', 'l', 'e', '/w')
('o', 'd', 'd', 's', ',', '/w')
('D', 'o', 'm', '/w')
('T', 'o', 'r', 'e', 't', 't', 'o', '/w')
('h', 'i', 's', '/w')
------------------------------


In [None]:
corpus[0]

"Over many missions and against impossible odds, Dom Toretto and his family have outsmarted, out-nerved and outdriven every foe in their path. Now, they confront the most lethal opponent they've ever faced: A terrifying threat emerging from the shadows of the past who's fueled by blood revenge, and who is determined to shatter this family and destroy everything—and everyone—that Dom loves, forever."

In [None]:
for i, word in enumerate(corpus[0].split()):
    print(tuple(list(word)+[end_of_word]))
    if i>9: break

('O', 'v', 'e', 'r', '/w')
('m', 'a', 'n', 'y', '/w')
('m', 'i', 's', 's', 'i', 'o', 'n', 's', '/w')
('a', 'n', 'd', '/w')
('a', 'g', 'a', 'i', 'n', 's', 't', '/w')
('i', 'm', 'p', 'o', 's', 's', 'i', 'b', 'l', 'e', '/w')
('o', 'd', 'd', 's', ',', '/w')
('D', 'o', 'm', '/w')
('T', 'o', 'r', 'e', 't', 't', 'o', '/w')
('a', 'n', 'd', '/w')
('h', 'i', 's', '/w')
