In [None]:
import regex as re

# GPT-2-style pre-tokenization pattern. Alternatives are tried left to right:
#   '(?:[sdmt]|ll|ve|re)   contractions: 's 'd 'm 't 'll 've 're
#    ?\p{L}+               a run of letters, optionally preceded by one space
#    ?\p{N}+               a run of numeric characters, optionally preceded by one space
#    ?[^\s\p{L}\p{N}]+     a run of other symbols, optionally preceded by one space
#   \s+(?!\S)              whitespace not followed by a non-space character; before a word
#                          this backs off by one so the last space attaches to that word
#   \s+                    any remaining whitespace
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
re.findall(PAT, "some text that i'll pre-tokenize")


['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']

In [None]:
import regex as re

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

text = "some text that i'll pre-tokenize some text that i'll pre-tokenize"

# re.finditer yields match objects lazily, so pre-tokens are tallied one at a time
# without first materializing a list of every pre-token.
token_counts = {}
for token_match in re.finditer(PAT, text):
    piece = token_match.group(0)  # the matched substring
    token_counts[piece] = token_counts.get(piece, 0) + 1

print("Pre-token counts:")
for token, count in token_counts.items():
    print(f"'{token}': {count}")


Pre-token counts:
'some': 1
' text': 2
' that': 2
' i': 2
''ll': 2
' pre': 2
'-': 2
'tokenize': 2
' some': 1


In [None]:
from datasets import load_dataset
from collections import defaultdict
import regex as re

NUM_STORIES = 5000  # size of the TinyStories training slice used for this toy BPE run
dataset = load_dataset("roneneldan/TinyStories", split=f"train[:{NUM_STORIES}]")

# Pre-token string -> number of occurrences across the slice (the input to train_bpe below).
word_count = defaultdict(int)
for example in dataset:
    for match in re.finditer(PAT, example["text"]):
        word_count[match.group(0)] += 1

# Printing the whole mapping floods the notebook; report its size and the most frequent entries.
print(f"{len(word_count)} distinct pre-tokens")
sorted(word_count.items(), key=lambda item: item[1], reverse=True)[:20]

  from .autonotebook import tqdm as notebook_tqdm




In [None]:
def merge(indices: list[int], pair: tuple[int, int], new_index: int) -> list[int]:
    """Return `indices` with every occurrence of `pair` replaced by `new_index`.

    Occurrences are matched greedily from left to right and never overlap, so for a
    pair like ``(a, a)`` the run ``[a, a, a]`` becomes ``[new_index, a]``.

    Args:
        indices: token ids to scan; never mutated.
        pair: the adjacent ``(left, right)`` ids to replace.
        new_index: id written in place of each matched pair.

    Returns:
        A new list when at least one occurrence was replaced. When `pair` does not
        occur, `indices` itself is returned (no copy), which keeps the common
        "pair not present" case cheap inside the BPE training loop.
    """
    left, right = pair
    # Locate the first occurrence; if there is none, hand back the input unchanged.
    start = next(
        (pos for pos in range(len(indices) - 1)
         if indices[pos] == left and indices[pos + 1] == right),
        None,
    )
    if start is None:
        return indices

    # Everything before the first match is copied verbatim; only the tail is scanned.
    merged = indices[:start]
    pos = start
    n = len(indices)
    while pos < n:
        if pos + 1 < n and indices[pos] == left and indices[pos + 1] == right:
            merged.append(new_index)
            pos += 2
        else:
            merged.append(indices[pos])
            pos += 1
    return merged

In [None]:
from dataclasses import dataclass

@dataclass(frozen=True)
class BPETokenizerParams:
    """All you need to specify a BPETokenizer.

    `frozen=True` only prevents reassigning the two fields; the dicts themselves stay
    mutable. Because both fields are dicts, hashing an instance (which a frozen
    dataclass enables) raises TypeError.
    """
    # Token id -> the bytes it stands for; ids 0-255 are the single bytes.
    vocab: dict[int, bytes]     # index -> bytes
    # Learned merges. Insertion order is the order they were learned, and
    # BPETokenizer.encode replays them in that order.
    merges: dict[tuple[int, int], int]  # index1,index2 -> new_index


def train_bpe(word_count, num_merges: int) -> BPETokenizerParams:
    """Learn up to `num_merges` byte-pair merges from pre-token frequencies.

    Args:
        word_count: mapping from pre-token string to its occurrence count. Each string
            is UTF-8 encoded and merged independently, so merges never span two
            pre-tokens.
        num_merges: maximum number of merges to learn. New token ids are assigned
            consecutively starting at 256.

    Returns:
        BPETokenizerParams whose `vocab` maps every id to its bytes and whose `merges`
        maps each merged pair to its new id, in the order the merges were learned.

    Training stops early, returning fewer than `num_merges` merges, once no adjacent
    pair remains (every pre-token has collapsed to a single token). Without this,
    `max()` would raise on an empty dict.
    """
    # Start each pre-token as its list of UTF-8 byte values, paired with its frequency.
    indices_count = [(list(string.encode("utf-8")), count) for string, count in word_count.items()]
    merges: dict[tuple[int, int], int] = {}  # index1, index2 => merged index
    vocab: dict[int, bytes] = {x: bytes([x]) for x in range(256)}  # index -> bytes
    for merge_step in range(num_merges):
        # Frequency-weighted count of every adjacent pair across all pre-tokens.
        counts = defaultdict(int)
        for indices, count in indices_count:
            for adjacent in zip(indices, indices[1:]):
                counts[adjacent] += count
        if not counts:
            break  # nothing left to merge
        # Most frequent pair; ties go to the pair counted first (dict insertion order).
        pair = max(counts, key=counts.get)
        index1, index2 = pair
        new_index = 256 + merge_step
        merges[pair] = new_index
        vocab[new_index] = vocab[index1] + vocab[index2]
        # Apply the new merge to every pre-token. This uses its own loop names rather
        # than reusing the outer step counter.
        indices_count = [(merge(indices, pair, new_index), count) for indices, count in indices_count]

    return BPETokenizerParams(vocab=vocab, merges=merges)


NUM_MERGES = 1000  # each learned merge adds one token id on top of the 256 byte tokens
bpe_params = train_bpe(word_count, NUM_MERGES)
# The full repr prints every vocab entry and merge; show the sizes instead.
len(bpe_params.vocab), len(bpe_params.merges)

BPETokenizerParams(vocab={0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80: b'P', 81: b'Q', 82: b'R', 83: b'S', 84: b'T', 85: b'U', 86: b'V', 87: b'W', 88: b'X', 89: 

In [None]:
# Ids from 256 upward are merged tokens, numbered in the order they were learned.
for token_id in range(256, 400):
    print(bpe_params.vocab[token_id])

b'he'
b' t'
b' a'
b' s'
b' w'
b'nd'
b' the'
b'ed'
b'in'
b' and'
b' b'
b' to'
b' wa'
b' h'
b're'
b'it'
b' f'
b'ou'
b'er'
b' l'
b' he'
b' was'
b' c'
b' d'
b' m'
b' p'
b' o'
b'ing'
b'om'
b'ar'
b'ay'
b'is'
b' g'
b'The'
b'id'
b'at'
b'll'
b'en'
b' sa'
b'ne'
b' ha'
b'im'
b'le'
b' S'
b'an'
b'or'
b' it'
b' th'
b' The'
b'et'
b' H'
b'il'
b'on'
b' her'
b'ir'
b'ver'
b' in'
b' e'
b' He'
b' n'
b'ot'
b'ld'
b' She'
b'ut'
b'ow'
b' u'
b' be'
b'ck'
b'ce'
b' said'
b' she'
b'ig'
b' st'
b'oo'
b' so'
b'pp'
b' r'
b' "'
b'am'
b' y'
b'ke'
b' of'
b've'
b'ith'
b'st'
b'ked'
b' his'
b'very'
b' with'
b'ri'
b' day'
b' I'
b'nt'
b'ad'
b' pl'
b' up'
b' that'
b' They'
b' had'
b' you'
b'ily'
b'itt'
b'ould'
b'el'
b' T'
b'ent'
b' on'
b'es'
b' play'
b' for'
b' L'
b' they'
b' we'
b'my'
b'ittle'
b'ound'
b'un'
b' \n'
b'out'
b' little'
b"'s"
b'ch'
b' mom'
b' happ'
b'ly'
b'ime'
b' there'
b'her'
b' time'
b'all'
b' sm'
b' sh'
b' very'
b' li'
b'ht'
b' wh'
b' ne'
b' re'
b'ome'
b' B'
b'al'
b' want'
b'se'
b' do'


In [None]:
from abc import ABC

class Tokenizer(ABC):
    """Common interface: `encode` maps text to token ids, `decode` maps them back.

    Neither method is marked `@abstractmethod`, so this base class can still be
    instantiated. Calling either method on it raises NotImplementedError.
    """

    def encode(self, string: str) -> list[int]:
        """Convert `string` into a list of token ids (subclasses must override)."""
        raise NotImplementedError()

    def decode(self, indices: list[int]) -> str:
        """Convert token ids back into text (subclasses must override)."""
        raise NotImplementedError()


class BPETokenizer(Tokenizer):
    """BPE tokenizer given a set of merges and a vocabulary."""

    def __init__(self, params: BPETokenizerParams):
        self.params = params

    def encode(self, string: str) -> list[int]:
        """Encode `string` as UTF-8 bytes, then replay every learned merge in order.

        `params.merges` is iterated in insertion order, i.e. the order in which the
        merges were learned.
        """
        indices = list(string.encode("utf-8"))
        # Note: this is a very slow implementation (one full pass per merge).
        for pair, new_index in self.params.merges.items():
            if len(indices) < 2:
                break  # no adjacent pair left, so no remaining merge can apply
            indices = merge(indices, pair, new_index)
        return indices

    def decode(self, indices: list[int], errors: str = "replace") -> str:
        """Concatenate each id's bytes and decode them as UTF-8.

        Args:
            indices: token ids; each must be a key of `params.vocab`.
            errors: `bytes.decode` error handler. The default "replace" keeps arbitrary
                id sequences (e.g. one that splits a multi-byte character) from raising.
                Pass "strict" to get UnicodeDecodeError instead.

        Raises:
            KeyError: if an id is not in the vocabulary.
        """
        byte_chunks = [self.params.vocab[index] for index in indices]
        return b"".join(byte_chunks).decode("utf-8", errors=errors)


tokenizer = BPETokenizer(bpe_params)

sample_text = "I love to play"
indices = tokenizer.encode(sample_text)
print(indices)

reconstructed_string = tokenizer.decode(indices)
print(reconstructed_string)
# Merged tokens are concatenations of the original UTF-8 bytes, so encode -> decode
# must reproduce the input exactly.
assert reconstructed_string == sample_text


[73, 919, 267, 364]
I love to play


In [None]:
# Show the byte string that each BPE token id stands for.
for token_id in indices:
    print(bpe_params.vocab[token_id])

b'I'
b' love'
b' to'
b' play'


In [None]:
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
import jax.numpy as jnp
import numpy as np


# Compare against a production tokenizer loaded from the Hugging Face Hub.
# The trailing space in the prompt becomes its own '▁' token (see the next cell's output).
prompt_text = "Stanford CS336 Course is "
model_id = "google/gemma-3-4b-it"
# model_id = "meta-llama/Llama-3.2-1B"

# NOTE: this rebinds `tokenizer`, which previously held the BPETokenizer built above.
# NOTE(review): `login` is imported but never called here; if the model repo requires
# authentication, log in first — confirm.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# return_tensors="pt" yields a PyTorch tensor with one row (a batch of one sequence);
# the leading id 2 in the output is the '<bos>' special token.
input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
input_ids

tensor([[     2, 153480,  14923, 236800, 236800, 236825,  24435,    563, 236743]])

In [None]:
# Map each id of the single encoded sequence back to its token string. In the output,
# '▁' marks a leading space (e.g. '▁CS') and the digits of "336" are split one per token.
tokenizer.convert_ids_to_tokens(input_ids[0])

['<bos>', 'Stanford', '▁CS', '3', '3', '6', '▁Course', '▁is', '▁']

In [None]:
# Decoding reproduces the prompt; as the output shows, the special '<bos>' token is kept.
print(tokenizer.decode(input_ids[0]))

<bos>Stanford CS336 Course is 


In [None]:
# `tokenizer.vocab` maps token string -> id for the whole model vocabulary (ids run past
# 250,000). Displaying all of it floods the notebook, so show its size and a small sample.
hf_vocab = tokenizer.vocab
print(f"vocab size: {len(hf_vocab)}")
dict(list(hf_vocab.items())[:20])

{'gend': 36848,
 'logrus': 137864,
 '<unused1771>': 257673,
 'ğ': 237209,
 'бран': 68879,
 '▁<!--': 10072,
 '▁लहंगा': 137052,
 '▁brain': 7875,
 '▁días': 16692,
 '▁necessitates': 132117,
 'Elle': 99247,
 'графии': 173090,
 '▁слишком': 84199,
 '▁puntu': 193334,
 '▁Encore': 163217,
 '褓': 252375,
 'Andrea': 112618,
 '▁ఆరోగ': 139748,
 'язку': 179131,
 '괌': 250854,
 '産の': 200561,
 '▁blossomed': 190088,
 'immer': 23651,
 '▁ulang': 110563,
 '▁Gonna': 234042,
 'তি': 2739,
 '▁shortages': 55943,
 'ダブル': 104263,
 '▁മാറ': 218199,
 'Ire': 79812,
 'ologiche': 188319,
 'ОО': 162727,
 '𝑝': 248497,
 '엣': 245745,
 '▁catchment': 105028,
 'ને': 6313,
 'ريكي': 118214,
 '韬': 247077,
 'nosi': 151172,
 '▁Stade': 180520,
 '▁सरी': 216319,
 '▁ഘ': 139780,
 '▁grassland': 97818,
 'ലം': 105901,
 '앓': 247799,
 'mergeddata': 198442,
 '▁Consensus': 145209,
 '▁serez': 151659,
 '▁rejoindre': 149075,
 "_'+": 197429,
 'adig': 74762,
 '▁hert': 159144,
 '▁শাহ': 47846,
 'userdetails': 163982,
 'ℏ': 246210,
 '▁REALLY': 108374,
