In [1]:
from itertools import islice
from collections import defaultdict

import pandas as pd
import numpy as np

## Corpus

In [2]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("manashjyotiborah/top-10000-movies-hosted-on-tmdb")

Downloading from https://www.kaggle.com/api/v1/datasets/download/manashjyotiborah/top-10000-movies-hosted-on-tmdb?dataset_version_number=2...


100%|██████████| 16.9M/16.9M [00:00<00:00, 81.5MB/s]

Extracting files...





In [3]:
overview_df = pd.read_csv(f'{path}/movies_dataset.csv', index_col='id')
overview_df.dropna(axis=0, inplace=True)

corpus = overview_df['overview'].to_list()
corpus[:3]

["Over many missions and against impossible odds, Dom Toretto and his family have outsmarted, out-nerved and outdriven every foe in their path. Now, they confront the most lethal opponent they've ever faced: A terrifying threat emerging from the shadows of the past who's fueled by blood revenge, and who is determined to shatter this family and destroy everything—and everyone—that Dom loves, forever.",
 "Tasked with extracting a family who is at the mercy of a Georgian gangster, Tyler Rake infiltrates one of the world's deadliest prisons in order to save them. But when the extraction gets hot, and the gangster dies in the heat of battle, his equally ruthless brother tracks down Rake and his team to Sydney, in order to get revenge.",
 'With the price on his head ever increasing, John Wick uncovers a path to defeating The High Table. But before he can earn his freedom, Wick must face off against a new enemy with powerful alliances across the globe and forces that turn old friends into foe

In [4]:
toy_corpus = [
    'This is the first document.',
    'This document is the second document.',
    'And this is the third one.',
    'Is this the first document?',
]

## Tokenization

Source: https://github.com/vukrosic/courses/tree/main/llama4

In [5]:
# Helpers
END_OF_WORD = '/w'

def get_slice(d:dict, n:int)->list:
    return [(key, val) for key, val in zip(range(n), d.items())]


In [6]:
def get_unique_chars(corpus):
    unique_chars = set()
    for overview in corpus:
        for char in overview:
            unique_chars.add(char)

    vocab = list(unique_chars)
    vocab.sort()
    vocab.append(END_OF_WORD)

    return vocab


def get_word_splits(corpus):
    word_splits = defaultdict(int)
    for doc in corpus:
        for word in doc.split(' '):
            if word:
                word_tuple = tuple(list(word) + [END_OF_WORD])
                word_splits[word_tuple] += 1

    return word_splits


def get_pair_stats(splits:dict)->list:
    pair_counts = defaultdict(int)
    for word_tuple, freq in splits.items():
        symbols = list(word_tuple)
        for i in range(len(symbols)-1):
            pair = (symbols[i], symbols[i+1])
            pair_counts[pair] += freq
    return pair_counts


def merge_pair(pair_to_merge:tuple, splits:dict)->list:
    new_splits = {}
    (first, second) = pair_to_merge
    merged_token = first + second
    for word_tuple, freq in splits.items():
        symbols = list(word_tuple)
        new_symbols = []
        i = 0
        while i < len(symbols):
            if i<len(symbols)-1 and symbols[i]==first and symbols[i+1]==second:
                new_symbols.append(merged_token)
                i += 2
            else:
                new_symbols.append(symbols[i])
                i += 1
        new_splits[tuple(new_symbols)] = freq
    return new_splits



In [None]:
vocab = get_unique_chars(corpus)
print(vocab[:10], len(vocab), sep='\n')

['\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(']
129


In [None]:
toy_vocab = get_unique_chars(toy_corpus)
print(toy_vocab, len(toy_vocab), sep='\n')

[' ', '.', '?', 'A', 'I', 'T', 'c', 'd', 'e', 'f', 'h', 'i', 'm', 'n', 'o', 'r', 's', 't', 'u', '/w']
20


In [None]:
word_splits = get_word_splits(corpus)
get_slice(word_splits, 10)

[(0, (('O', 'v', 'e', 'r', '/w'), 17)),
 (1, (('m', 'a', 'n', 'y', '/w'), 104)),
 (2, (('m', 'i', 's', 's', 'i', 'o', 'n', 's', '/w'), 14)),
 (3, (('a', 'n', 'd', '/w'), 13235)),
 (4, (('a', 'g', 'a', 'i', 'n', 's', 't', '/w'), 470)),
 (5, (('i', 'm', 'p', 'o', 's', 's', 'i', 'b', 'l', 'e', '/w'), 48)),
 (6, (('o', 'd', 'd', 's', ',', '/w'), 17)),
 (7, (('D', 'o', 'm', '/w'), 7)),
 (8, (('T', 'o', 'r', 'e', 't', 't', 'o', '/w'), 4)),
 (9, (('h', 'i', 's', '/w'), 6924))]

In [None]:
toy_word_splits = get_word_splits(toy_corpus)
toy_word_splits

defaultdict(int,
            {('T', 'h', 'i', 's', '/w'): 2,
             ('i', 's', '/w'): 3,
             ('t', 'h', 'e', '/w'): 4,
             ('f', 'i', 'r', 's', 't', '/w'): 2,
             ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '/w'): 2,
             ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '/w'): 1,
             ('s', 'e', 'c', 'o', 'n', 'd', '/w'): 1,
             ('A', 'n', 'd', '/w'): 1,
             ('t', 'h', 'i', 's', '/w'): 2,
             ('t', 'h', 'i', 'r', 'd', '/w'): 1,
             ('o', 'n', 'e', '.', '/w'): 1,
             ('I', 's', '/w'): 1,
             ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '/w'): 1})

In [None]:
pair_stats = get_pair_stats(word_splits)
get_slice(pair_stats, 10)

[(0, (('O', 'v'), 30)),
 (1, (('v', 'e'), 15729)),
 (2, (('e', 'r'), 39971)),
 (3, (('r', '/w'), 28195)),
 (4, (('m', 'a'), 9653)),
 (5, (('a', 'n'), 37733)),
 (6, (('n', 'y'), 1236)),
 (7, (('y', '/w'), 22819)),
 (8, (('m', 'i'), 5564)),
 (9, (('i', 's'), 23459))]

In [None]:
toy_pair_stats = get_pair_stats(toy_word_splits)
toy_pair_stats

defaultdict(int,
            {('T', 'h'): 2,
             ('h', 'i'): 5,
             ('i', 's'): 7,
             ('s', '/w'): 8,
             ('t', 'h'): 7,
             ('h', 'e'): 4,
             ('e', '/w'): 4,
             ('f', 'i'): 2,
             ('i', 'r'): 3,
             ('r', 's'): 2,
             ('s', 't'): 2,
             ('t', '/w'): 3,
             ('d', 'o'): 4,
             ('o', 'c'): 4,
             ('c', 'u'): 4,
             ('u', 'm'): 4,
             ('m', 'e'): 4,
             ('e', 'n'): 4,
             ('n', 't'): 4,
             ('t', '.'): 2,
             ('.', '/w'): 3,
             ('s', 'e'): 1,
             ('e', 'c'): 1,
             ('c', 'o'): 1,
             ('o', 'n'): 2,
             ('n', 'd'): 2,
             ('d', '/w'): 3,
             ('A', 'n'): 1,
             ('r', 'd'): 1,
             ('n', 'e'): 1,
             ('e', '.'): 1,
             ('I', 's'): 1,
             ('t', '?'): 1,
             ('?', '/w'): 1})

In [7]:
def fun(corpus:list, num_merges:int=15, verbose=False)->list:
    merges = {}
    vocab = get_unique_chars(corpus)
    current_splits = get_word_splits(corpus)
    for i in range(num_merges):
        # 1. Calculate Pair Frequencies
        pair_stats = get_pair_stats(current_splits)
        if not pair_stats:
            print('No more pairs to merge.')
            break

        # 2. Find Best Pair
        # The 'max' function iterates over all key-value pairs in the 'pair_stats' dictionary
        # The 'key=pair_stats.get' tells 'max' to use the frequency (value) for comparison, not the pair (key) itself
        # This way, 'max' selects the pair with the highest frequency
        best_pair = max(pair_stats, key=pair_stats.get)
        best_freq = pair_stats[best_pair]

        # 3. Merge the Best Pair
        current_splits = merge_pair(best_pair, current_splits)
        new_token = best_pair[0] + best_pair[1]

        # 4. Update Vocabulary
        vocab.append(new_token)

        # 5. Store Merge Rule
        merges[best_pair] = new_token

        if verbose:
            print(f'Merge Iteration {i+1}/{num_merges}')
            sorted_pairs = sorted(pair_stats.items(), key=lambda item: item[1], reverse=True)
            print(f'Top 5 Pair Frequencies: {sorted_pairs[:5]}')
            print(f'Found Best Pair: {best_pair} with Frequency: {best_freq}')
            print(f"Merging {best_pair} into '{new_token}'")
            print(f'Splits after merge: {current_splits}')
            print(f"Updated Vocabulary: {vocab}")
            print(f'Updated Merges: {merges}')
            print('-' * 30)

    return vocab, current_splits



In [None]:
vocab, splits = fun(toy_corpus)

In [None]:
vocab

[' ',
 '.',
 '?',
 'A',
 'I',
 'T',
 'c',
 'd',
 'e',
 'f',
 'h',
 'i',
 'm',
 'n',
 'o',
 'r',
 's',
 't',
 'u',
 '/w',
 's/w',
 'is/w',
 'th',
 'the',
 'the/w',
 'do',
 'doc',
 'docu',
 'docum',
 'docume',
 'documen',
 'document',
 'ir',
 './w',
 'd/w']

In [None]:
get_unique_chars(toy_corpus) # initial vocab

[' ',
 '.',
 '?',
 'A',
 'I',
 'T',
 'c',
 'd',
 'e',
 'f',
 'h',
 'i',
 'm',
 'n',
 'o',
 'r',
 's',
 't',
 'u',
 '/w']

In [None]:
splits

{('T', 'h', 'is/w'): 2,
 ('is/w',): 3,
 ('the/w',): 4,
 ('f', 'ir', 's', 't', '/w'): 2,
 ('document', './w'): 2,
 ('document', '/w'): 1,
 ('s', 'e', 'c', 'o', 'n', 'd/w'): 1,
 ('A', 'n', 'd/w'): 1,
 ('th', 'is/w'): 2,
 ('th', 'ir', 'd/w'): 1,
 ('o', 'n', 'e', './w'): 1,
 ('I', 's/w'): 1,
 ('document', '?', '/w'): 1}

In [None]:
get_word_splits(toy_corpus)

defaultdict(int,
            {('T', 'h', 'i', 's', '/w'): 2,
             ('i', 's', '/w'): 3,
             ('t', 'h', 'e', '/w'): 4,
             ('f', 'i', 'r', 's', 't', '/w'): 2,
             ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '.', '/w'): 2,
             ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '/w'): 1,
             ('s', 'e', 'c', 'o', 'n', 'd', '/w'): 1,
             ('A', 'n', 'd', '/w'): 1,
             ('t', 'h', 'i', 's', '/w'): 2,
             ('t', 'h', 'i', 'r', 'd', '/w'): 1,
             ('o', 'n', 'e', '.', '/w'): 1,
             ('I', 's', '/w'): 1,
             ('d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '?', '/w'): 1})

## Attention

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional

In [9]:
hidden_size = 128  # Dimensionality of the model's hidden states
num_attention_heads = 16 # Total number of query heads
num_key_value_heads = 4  # Number of key/value heads (for GQA)
head_dim = hidden_size // num_attention_heads # Dimension of each attention head
max_position_embeddings = 256 # Maximum sequence length the model expects
rope_theta = 10000.0 # Base for RoPE frequency calculation
rms_norm_eps = 1e-5 # Epsilon for RMSNorm
attention_bias = False # Whether to use bias in Q
attention_dropout = 0.0 # Dropout probability for attention weights
use_qk_norm = True # Whether to apply L2 norm to Q and K before attention

# Sample Input
batch_size = 2
sequence_length = 10
hidden_states = torch.randn(batch_size, sequence_length, hidden_size)
# Create position IDs for each token in the sequence, repeated for each batch
# torch.arange(0, sequence_length) generates a 1D tensor with values from 0 to sequence_length-1
# The unsqueeze(0) adds an extra dimension at the 0th position, making it (1, sequence_length)
# This allows repeat(batch_size, 1) to create a tensor of shape (batch_size, sequence_length)
position_ids = torch.arange(0, sequence_length).unsqueeze(0).repeat(batch_size, 1) # Shape: (batch_size, sequence_length)
# Simple causal mask (upper triangular) for demonstration
# In reality, Llama4 uses a more complex mask creation including padding handling
attention_mask = torch.triu(torch.ones(sequence_length, sequence_length) * -torch.inf, diagonal=1)
attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) # Shape: (1, 1, sequence_length, sequence_length)
attention_mask = attention_mask.expand(batch_size, 1, -1, -1) # Shape: (batch_size, 1, sequence_length, sequence_length)

In [11]:
print('Sample Input Shapes:')
print(f'\thidden_states: {hidden_states.shape}')
print(f'\tposition_ids: {position_ids.shape}')
print(f'\tattention_mask: {attention_mask.shape}')

Sample Input Shapes:
	hidden_states: torch.Size([2, 10, 128])
	position_ids: torch.Size([2, 10])
	attention_mask: torch.Size([2, 1, 10, 10])


In [12]:
# Define projection layers
q_proj = nn.Linear(hidden_size, num_attention_heads*head_dim, bias=attention_bias)
k_proj = nn.Linear(hidden_size, num_key_value_heads*head_dim, bias=attention_bias)
v_proj = nn.Linear(hidden_size, num_key_value_heads*head_dim, bias=attention_bias)
o_proj = nn.Linear(num_attention_heads*head_dim, hidden_size, bias=attention_bias)

In [13]:
# Calculate projections
query_states = q_proj(hidden_states)
key_states = k_proj(hidden_states)
value_states = v_proj(hidden_states)

In [14]:
# Reshape Q, K, V for multi-head attention
# Target shape: (batch_size, num_heads, sequence_length, head_dim)
query_states = query_states.view(batch_size, sequence_length, num_attention_heads, head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, sequence_length, num_key_value_heads, head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, sequence_length, num_key_value_heads, head_dim).transpose(1, 2)

In [16]:
print('Projected Shapes:')
print(f'\tquery_states: {query_states.shape}') # (batch_size, num_attention_heads, sequence_length, head_dim)
print(f'\tkey_states: {key_states.shape}')     # (batch_size, num_key_value_heads, sequence_length, head_dim)
print(f'\tvalue_states: {value_states.shape}')   # (batch_size, num_key_value_heads, sequence_length, head_dim)

Projected Shapes:
	query_states: torch.Size([2, 16, 10, 8])
	key_states: torch.Size([2, 4, 10, 8])
	value_states: torch.Size([2, 4, 10, 8])


In [17]:
num_key_value_groups = num_attention_heads // num_key_value_heads
print(f'Num Key/Value Groups (Q heads per K/V head): {num_key_value_groups}')

Num Key/Value Groups (Q heads per K/V head): 4
