## 2.1 The Unicode Standard

In [1]:
ord("!")

33

In [8]:
# list(range(ord("!"), ord("~") + 1))
list(range(ord("¡"), ord("¬") + 1))
# list(range(ord("®"), ord("ÿ") + 1))

[161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172]

In [10]:
# ord("¬")
# ord("¡")
# ord("®")
ord("ÿ")

255

In [5]:
ord('牛')

29275

In [7]:
chr(29275)

'牛'

In [1]:
chr(0)

'\x00'

In [2]:
chr(0).__repr__()

"'\\x00'"

In [3]:
"this is a test" + chr(0) + "string"

'this is a test\x00string'

In [4]:
print("this is a test" + chr(0) + "string")

this is a test string


## 2.2 Unicode Encodings

In [1]:
test_string = "hello! こんにちは!"
utf8_encoded = test_string.encode("utf-8")
print(utf8_encoded)

b'hello! \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!'


In [2]:
print(type(utf8_encoded))

<class 'bytes'>


In [4]:
# Get the byte values for the encoded string (integers from 0 to 255).
list(utf8_encoded)

[104,
 101,
 108,
 108,
 111,
 33,
 32,
 227,
 129,
 147,
 227,
 130,
 147,
 227,
 129,
 171,
 227,
 129,
 161,
 227,
 129,
 175,
 33]

In [5]:
# One byte does not necessarily correspond to one Unicode character!
print(len(test_string))

13


In [6]:
print(len(utf8_encoded))

23


In [7]:
print(utf8_encoded.decode("utf-8"))

hello! こんにちは!


In [19]:
"hello".encode("utf-8")

b'hello'

In [8]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])

decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))

'hello'

In [11]:
decode_utf8_bytes_to_str_wrong("hello! こんにちは!".encode("utf-8"))

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe3 in position 0: unexpected end of data

In [29]:
# Convert a hex byte like "C0" into an 8-bit binary string
byte = int("2f", 16) #int("c0", 16)
print(byte)
binary = f"{byte:08b}"
print(binary)

47
00101111


In [30]:
value = int(binary, 2)
print(value)

47


In [31]:
b = bytes([value])
print(b)

b'/'


In [32]:
b.decode("utf-8")

'/'

## 2.3 Subword Tokenization & 2.4 BPE Tokenizer Training

In [36]:
text_string = "the"
text_encoded = text_string.encode("utf-8")
print(text_encoded)
print(type(text_encoded))
list(text_encoded)

b'the'
<class 'bytes'>


[116, 104, 101]

In [41]:
for i in list(text_encoded):
    print(f"{i:08b}")

01110100
01101000
01100101


In [5]:
import regex as re

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

re.findall(PAT, "some text that i'll pre-tokenize")

['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']

In [43]:
re.findall(PAT, "some text that i'll pre-tokenize !!!")

['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize', ' !!!']

In [44]:
re.findall(PAT, "some text that i'll pre-tokenize!!!  ")

['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize', '!!!', '  ']

In [45]:
re.findall(PAT, "some text that i'll pre-tokenize ")

['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize', ' ']

In [49]:
re.findall(PAT, "some text that i'll pre-tokenize \n\n")

['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize', ' \n\n']

In [54]:
for m in re.finditer(r"\d+", "a1 b22 c333"):
    print(m.group(), m.start(), m.end())

1 1 2
22 4 6
333 8 11


In [56]:
for m in re.finditer(PAT, "some text that i'll pre-tokenize"):
    print(m.group())

some
 text
 that
 i
'll
 pre
-
tokenize


In [59]:
# If regex contains capturing group, using finditer() to return full match info
pattern = r"A(\d+)" 
text = "A1 A22 A333"

print(re.findall(pattern, text))
print([m.group() for m in re.finditer(pattern, text)])

['1', '22', '333']
['A1', 'A22', 'A333']


In [62]:
for m in re.finditer(pattern, text):
    m_encoded = m.group().encode("utf-8")
    print(m.group(), list(m_encoded))

A1 [65, 49]
A22 [65, 50, 50]
A333 [65, 51, 51, 51]


In [63]:
max([("A", "B"), ("A", "C"), ("B", "ZZ"), ("BA", "A")])

('BA', 'A')

In [38]:
pattern = r"\p{L}+"

text = """
low low low low low
lower lower widest widest widest
newest newest newest newest newest newest
"""

In [39]:
vocab = {}
for m in re.finditer(pattern, text):
    k = m.group()
    if k in vocab:
        vocab[k] += 1
    else:
        vocab[k] = 1
vocab

{'low': 5, 'lower': 2, 'widest': 3, 'newest': 6}

In [40]:
bytes(tuple("low".encode("utf-8")))

b'low'

In [41]:
freq_table = {}
for word, count in vocab.items():
    freq_table[tuple(word.encode("utf-8"))] = count

freq_table

{(108, 111, 119): 5,
 (108, 111, 119, 101, 114): 2,
 (119, 105, 100, 101, 115, 116): 3,
 (110, 101, 119, 101, 115, 116): 6}

In [44]:
# combine above 2 steps together:
vocab = {}
for m in re.finditer(pattern, text):
    word = m.group()
    # k = tuple(word.encode("utf-8"))
    k = tuple(bytes([c]) for c in word.encode("utf-8"))
    if k in vocab:
        vocab[k] += 1
    else:
        vocab[k] = 1
vocab

{(b'l', b'o', b'w'): 5,
 (b'l', b'o', b'w', b'e', b'r'): 2,
 (b'w', b'i', b'd', b'e', b's', b't'): 3,
 (b'n', b'e', b'w', b'e', b's', b't'): 6}

In [45]:
# count pair frequency
from collections import Counter

def get_pair_freqs(vocab: dict[tuple[bytes, ...], int]) -> dict[tuple[bytes, bytes], int]:
    """
    vocab: { (b'l', b'o', b'w'): 5, ... }
    returns: { (b'l', b'o'): 7, (b'o', b'w'): 7, ... }
    """
    pair_freqs = Counter()
    for t, f in vocab.items():
        for i in range(len(t) - 1):
            pair = (t[i], t[i+1])
            pair_freqs[pair] += f
    return pair_freqs


In [46]:
pair_freqs = get_pair_freqs(vocab)
pair_freqs

Counter({(b'e', b's'): 9,
         (b's', b't'): 9,
         (b'w', b'e'): 8,
         (b'l', b'o'): 7,
         (b'o', b'w'): 7,
         (b'n', b'e'): 6,
         (b'e', b'w'): 6,
         (b'w', b'i'): 3,
         (b'i', b'd'): 3,
         (b'd', b'e'): 3,
         (b'e', b'r'): 2})

In [47]:
[(kv[1], kv[0]) for kv in pair_freqs.items()]

[(7, (b'l', b'o')),
 (7, (b'o', b'w')),
 (8, (b'w', b'e')),
 (2, (b'e', b'r')),
 (3, (b'w', b'i')),
 (3, (b'i', b'd')),
 (3, (b'd', b'e')),
 (9, (b'e', b's')),
 (9, (b's', b't')),
 (6, (b'n', b'e')),
 (6, (b'e', b'w'))]

In [48]:
max([(kv[1], kv[0]) for kv in pair_freqs.items()])

(9, (b's', b't'))

In [49]:
def choose_best_pair(pair_freqs: dict[tuple[bytes, bytes], int]) -> tuple[bytes, bytes]:
    if not pair_freqs:
        return None
    # max by (frequency, pair) → frequency first, then lexicographically greatest pair
    best_pair = max(pair_freqs.items(), key=lambda kv: (kv[1], kv[0]))[0]
    return best_pair

In [50]:
best_pair = choose_best_pair(pair_freqs)
print(best_pair)

(b's', b't')


In [51]:
vocab

{(b'l', b'o', b'w'): 5,
 (b'l', b'o', b'w', b'e', b'r'): 2,
 (b'w', b'i', b'd', b'e', b's', b't'): 3,
 (b'n', b'e', b'w', b'e', b's', b't'): 6}

In [52]:
a, b = best_pair
print(a)
print(b)

b's'
b't'


In [53]:
type(a)

bytes

In [57]:
new_vocab = Counter()
new_symbol = a + b

for symbols, freq in vocab.items():
    symbol_list: list[bytes] = []
    i = 0
    while i < len(symbols):
        if (i < len(symbols) - 1) and (a == symbols[i] and b == symbols[i+1]):
            symbol_list.append(new_symbol)
            i += 2
        else:
            symbol_list.append(symbols[i])
            i += 1
    
    new_vocab[tuple(symbol_list)] += freq

dict(new_vocab)

{(b'l', b'o', b'w'): 5,
 (b'l', b'o', b'w', b'e', b'r'): 2,
 (b'w', b'i', b'd', b'e', b'st'): 3,
 (b'n', b'e', b'w', b'e', b'st'): 6}

In [58]:
def merge_vocab_once(vocab: dict[tuple[bytes, ...], int],
                     pair: tuple[bytes, bytes]) -> dict[tuple[bytes, ...], int]:
    """
    Merge all occurrences of `pair` in every word.
    """
    from collections import Counter
    a, b = pair
    new_symbol = a + b
    new_vocab = Counter()

    for symbols, freq in vocab.items():
        i = 0
        symbol_list: list[bytes] = []
        while i < len(symbols):
            if (i < len(symbols) - 1) and (a == symbols[i] and b == symbols[i+1]):
                symbol_list.append(new_symbol)
                i += 2
            else:
                symbol_list.append(symbols[i])
                i += 1
        new_vocab[tuple(symbol_list)] += freq

    return dict(new_vocab)

In [59]:
new_vocab = merge_vocab_once(vocab, best_pair)
new_vocab

{(b'l', b'o', b'w'): 5,
 (b'l', b'o', b'w', b'e', b'r'): 2,
 (b'w', b'i', b'd', b'e', b'st'): 3,
 (b'n', b'e', b'w', b'e', b'st'): 6}

In [60]:
## Put all together
pair_freqs = get_pair_freqs(vocab)
best_pair = choose_best_pair(pair_freqs)
if best_pair is not None:
    vocab = merge_vocab_once(vocab, best_pair)

vocab    

{(b'l', b'o', b'w'): 5,
 (b'l', b'o', b'w', b'e', b'r'): 2,
 (b'w', b'i', b'd', b'e', b'st'): 3,
 (b'n', b'e', b'w', b'e', b'st'): 6}

In [61]:
## second round of pair + merge
pair_freqs = get_pair_freqs(vocab)
best_pair = choose_best_pair(pair_freqs)
if best_pair is not None:
    vocab = merge_vocab_once(vocab, best_pair)

In [62]:
vocab

{(b'l', b'o', b'w'): 5,
 (b'l', b'o', b'w', b'e', b'r'): 2,
 (b'w', b'i', b'd', b'est'): 3,
 (b'n', b'e', b'w', b'est'): 6}

### Sanity Check with the 6 rounds results

In [67]:
pattern = r"\p{L}+"

text = """
low low low low low
lower lower widest widest widest
newest newest newest newest newest newest
"""

vocab = {}
for m in re.finditer(pattern, text):
    k = m.group()
    if k in vocab:
        vocab[k] += 1
    else:
        vocab[k] = 1

print("Original Vocabulary:")
print(vocab)

rounds = 6
i = 0
while i < rounds:
    pair_freqs = get_pair_freqs(vocab)
    best_pair = choose_best_pair(pair_freqs)
    if best_pair is not None:
        vocab = merge_vocab_once(vocab, best_pair)
    i += 1
    print(vocab)

print("6 round merges, New Vocabulary:")
print(vocab)

Original Vocabulary:
{'low': 5, 'lower': 2, 'widest': 3, 'newest': 6}
{('l', 'o', 'w'): 5, ('l', 'o', 'w', 'e', 'r'): 2, ('w', 'i', 'd', 'e', 'st'): 3, ('n', 'e', 'w', 'e', 'st'): 6}
{('l', 'o', 'w'): 5, ('l', 'o', 'w', 'e', 'r'): 2, ('w', 'i', 'd', 'est'): 3, ('n', 'e', 'w', 'est'): 6}
{('l', 'ow'): 5, ('l', 'ow', 'e', 'r'): 2, ('w', 'i', 'd', 'est'): 3, ('n', 'e', 'w', 'est'): 6}
{('low',): 5, ('low', 'e', 'r'): 2, ('w', 'i', 'd', 'est'): 3, ('n', 'e', 'w', 'est'): 6}
{('low',): 5, ('low', 'e', 'r'): 2, ('w', 'i', 'd', 'est'): 3, ('n', 'e', 'west'): 6}
{('low',): 5, ('low', 'e', 'r'): 2, ('w', 'i', 'd', 'est'): 3, ('ne', 'west'): 6}
6 round merges, New Vocabulary:
{('low',): 5, ('low', 'e', 'r'): 2, ('w', 'i', 'd', 'est'): 3, ('ne', 'west'): 6}


## 2.5 Experimenting with BPE Tokenizer Training

In [24]:
# safest way to load txt data samples (treat it as a newline-separated text file)
def load_data(path, max_samples=None):
    """Yield up to max_samples stories from TinyStoriesV2 GPT-4 dataset."""
    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if max_samples and i >= max_samples:
                break
            # line = line.rstrip("\n")
            line = line.strip()
            if not line:
                continue
            yield line

In [25]:
path = "data/TinyStoriesV2-GPT4-train.txt"
samples = list(load_data(path, max_samples=20))

print(len(samples))
print(samples[0])

18
Once upon a time there was a little boy named Ben. Ben loved to explore the world around him. He saw many amazing things, like beautiful vases that were on display in a store. One day, Ben was walking through the store when he came across a very special vase. When Ben saw it he was amazed!


In [26]:
print(samples)

['Once upon a time there was a little boy named Ben. Ben loved to explore the world around him. He saw many amazing things, like beautiful vases that were on display in a store. One day, Ben was walking through the store when he came across a very special vase. When Ben saw it he was amazed!', 'He said, “Wow, that is a really amazing vase! Can I buy it?”', 'The shopkeeper smiled and said, “Of course you can. You can take it home and show all your friends how amazing it is!”', "So Ben took the vase home and he was so proud of it! He called his friends over and showed them the amazing vase. All his friends thought the vase was beautiful and couldn't believe how lucky Ben was.", "And that's how Ben found an amazing vase in the store!", '<|endoftext|>', 'Once upon a time, there was a reliable otter named Ollie. He lived in a river with his family. They all loved to play and swim together.', 'One day, Ollie\'s mom said, "Ollie, hurry and get some fish for dinner!" Ollie swam fast to catch f

### load in chunks and pretokenize

In [1]:
from cs336_basics.pretokenization_example import find_chunk_boundaries

In [4]:
path = "data/TinyStoriesV2-GPT4-valid.txt"

In [3]:
with open(path, "rb") as f:
    num_processes = 4
    boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")

    # The following is a serial implementation, but you can parallelize this
    # by sending each start/end pair to a set of processes.
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        f.seek(start)
        chunk = f.read(end - start).decode("utf-8", errors="ignore")
        # Run pre-tokenization on your chunk and store the counts for each pre-token

In [4]:
chunk



In [5]:
import re
pre_tokenizer_re = re.compile(r"[A-Za-z]+(?:'[A-Za-z]+)?|\d+|[^\sA-Za-z\d]")

special_tokens = ["<|endoftext|>"] 
# define a list for special tokens, this is good for scale
# special_tokens = ["<|endoftext|>",  "<|pad|>", "<|bos|>"]

split_pattern = re.compile(
    "|".join(re.escape(tok) for tok in special_tokens)
)
split_pattern

re.compile(r'<\|endoftext\|>', re.UNICODE)

In [6]:
def pretokenize(text):
    # return text.split # split text with any whitespace (space, \n, \t, etc.) 
    tokens = [m.group() for m in re.finditer(pre_tokenizer_re, text)]
    return tokens

In [13]:
def bpe_special_tokens(text, special_tokens):
    # 1. split on special tokens
    split_pattern = re.compile(
        "|".join(re.escape(tok) for tok in special_tokens)
    )
    docs = split_pattern.split(text)
    for doc in docs:
        # print(doc)
        if not doc:
            continue
        # 2. pre-tokenize each document independently
        tokens = pretokenize(doc)
        for token in tokens:
            # 3. convert each token to byte sequence
            yield token.encode("utf-8")

for tok in bpe_special_tokens(chunk, ["<|endoftext|>"]):
    print(tok)

### run pair_freqs and merge on each doc

In [15]:
docs = split_pattern.split(chunk)
doc = docs[1]
doc

'\nOnce upon a time, there was a big purple cat named Tom. Tom lived in a little house with his best friend, a small girl named Sue. Sue loved to play with Tom in their yard, where there was a big patch of grass.\nOne sunny day, Sue and Tom went outside to play. They saw a big box on the ground. Sue wondered what was inside. "Let\'s weigh it," said Tom. They tried to lift the box, but it was too heavy. They knew something big was inside.\nSue and Tom decided to open the box. Inside, they found a big purple ball. They were so happy! They played with the ball all day long in their grass patch. And from that day on, Sue and Tom had even more fun playing together.\n'

In [33]:
from collections import defaultdict

count = defaultdict(int)
for m in re.finditer(pre_tokenizer_re, doc):
    token = m.group()
    tok_seq = tuple(bytes([c]) for c in token.encode("utf-8"))
    count[tok_seq] += 1

count = dict(count)

In [34]:
count

{(b'O', b'n', b'c', b'e'): 1,
 (b'u', b'p', b'o', b'n'): 1,
 (b'a',): 7,
 (b't', b'i', b'm', b'e'): 1,
 (b',',): 8,
 (b't', b'h', b'e', b'r', b'e'): 2,
 (b'w', b'a', b's'): 5,
 (b'b', b'i', b'g'): 5,
 (b'p', b'u', b'r', b'p', b'l', b'e'): 2,
 (b'c', b'a', b't'): 1,
 (b'n', b'a', b'm', b'e', b'd'): 2,
 (b'T', b'o', b'm'): 7,
 (b'.',): 13,
 (b'l', b'i', b'v', b'e', b'd'): 1,
 (b'i', b'n'): 3,
 (b'l', b'i', b't', b't', b'l', b'e'): 1,
 (b'h', b'o', b'u', b's', b'e'): 1,
 (b'w', b'i', b't', b'h'): 3,
 (b'h', b'i', b's'): 1,
 (b'b', b'e', b's', b't'): 1,
 (b'f', b'r', b'i', b'e', b'n', b'd'): 1,
 (b's', b'm', b'a', b'l', b'l'): 1,
 (b'g', b'i', b'r', b'l'): 1,
 (b'S', b'u', b'e'): 6,
 (b'l', b'o', b'v', b'e', b'd'): 1,
 (b't', b'o'): 4,
 (b'p', b'l', b'a', b'y'): 2,
 (b't', b'h', b'e', b'i', b'r'): 2,
 (b'y', b'a', b'r', b'd'): 1,
 (b'w', b'h', b'e', b'r', b'e'): 1,
 (b'p', b'a', b't', b'c', b'h'): 2,
 (b'o', b'f'): 1,
 (b'g', b'r', b'a', b's', b's'): 2,
 (b'O', b'n', b'e'): 1,
 (b's', b'u'

In [35]:
pair_freqs = defaultdict(int) # dict[tuple[bytes, bytes], int]

for seq, freq in count.items():
    if len(seq) < 2:
        continue
    for i in range(len(seq) - 1):
        pair = (seq[i], seq[i+1])
        pair_freqs[pair] += freq

pair_freqs = dict(pair_freqs)

In [57]:
# pair_freqs
def get_pair_freqs(count: dict[tuple[bytes, ...], int]) -> dict[tuple[bytes, bytes], int]:
    pair_freqs = defaultdict(int)
    for seq, freq in count.items():
        if len(seq) < 2:
            continue
        for i in range(len(seq) - 1):
            pair = (seq[i], seq[i+1])
            pair_freqs[pair] += freq
    return dict(pair_freqs)

In [50]:
## append best_pair to merges
# best_pair = max([(kv[1], kv[0]) for kv in pair_freqs.items()])[1]
best_pair = max(pair_freqs.items(), key=lambda kv: (kv[1], kv[0]))[0]
best_pair

(b'h', b'e')

In [53]:
a, b = best_pair
new_symbol = a + b
new_symbol

b'he'

In [59]:
new_count = defaultdict(int)

for seq, freq in count.items():
    if len(seq) < 2:
        continue
    new_seq = []
    i = 0
    while i < len(seq):
        if (i < len(seq) - 1) and (seq[i] == a and seq[i+1] == b):
            new_seq.append(new_symbol)
            i += 2
        else:
            new_seq.append(seq[i])
            i += 1
    new_count[tuple(new_seq)] += freq

# dict(new_count)

In [60]:
def apply_merge(count: dict[tuple[bytes, ...], int], best_pair: tuple[bytes, bytes]) -> dict[tuple[bytes, ...], int]:
    a, b = best_pair
    new_symbol = a + b
    new_count = defaultdict(int)
    for seq, freq in count.items():
        if len(seq) < 2:
            continue
        new_seq = []
        i = 0
        while i < len(seq):
            if (i < len(seq) - 1) and (seq[i] == a and seq[i+1] == b):
                new_seq.append(new_symbol)
                i += 2
            else:
                new_seq.append(seq[i])
                i += 1
        new_count[tuple(new_seq)] += freq
    return dict(new_count)

In [70]:
vocab_size = 1000

## 1. initialize vocabulary with single-byte tokens
vocab: dict[int: bytes] = {i: bytes([i]) for i in range(256)}
# vocab

# Add special tokens (do not affect training; just occupy vocab ids)
# Put them right after the 256 bytes.
for j, tok in enumerate(special_tokens):
    vocab[256 + j] = tok.encode("utf-8")

next_id = 256 + len(special_tokens)
num_merges = vocab_size - next_id
if num_merges < 0:
    print("original vocab")


merges = []
for _ in range(num_merges):
    pair_freqs = get_pair_freqs(count)
    if not pair_freqs:
        break
    best_pair = max(pair_freqs.items(), key=lambda kv: (kv[1], kv[0]))[0]
    a, b = best_pair
    merges.append(a + b)
    vocab[next_id] = a + b
    count = apply_merge(count, best_pair)
    next_id += 1

    if next_id >= vocab_size:
        break

In [25]:
# merges

In [24]:
# vocab

### test gpt-2 bytes_to_unicode function
mapping between every possible byte (an integer from 0 to 255) to a printable unicode string character representation.This function is taken from the GPT-2 code

In [11]:
bs = (
    list(range(ord("!"), ord("~") + 1))
    + list(range(ord("¡"), ord("¬") + 1))
    + list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
    if b not in bs:
        bs.append(b)
        cs.append(2**8 + n)
        n += 1
characters = [chr(n) for n in cs]

In [18]:
for i in zip(bs, characters):
    if i[0] < 10:
        print(i)

(0, 'Ā')
(1, 'ā')
(2, 'Ă')
(3, 'ă')
(4, 'Ą')
(5, 'ą')
(6, 'Ć')
(7, 'ć')
(8, 'Ĉ')
(9, 'ĉ')


In [23]:
# dict(zip(bs, characters))

### multiprocessor parallel for pretokenization

In [8]:
import regex as re
from collections import defaultdict
from pathlib import Path
from multiprocessing import Pool
from cs336_basics.pretokenization_example import find_chunk_boundaries

In [9]:
_GPT2_PRETOKENIZE_PATTERN = re.compile(
    r"'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
)

In [10]:
def _iter_pretokens_with_specials(text: str, special_tokens: list[str]) -> list[str]:
    if not special_tokens:
        return _GPT2_PRETOKENIZE_PATTERN.findall(text)

    else:
        pretokens: list[str] = []
        # re.escape() to avoid special tokens being interpreted as regular expression patterns.
        specials_pat = re.compile("|".join(re.escape(s) for s in special_tokens))
        docs = specials_pat.split(text)
        for doc in docs:
            if not doc:
                continue
            pretokens.extend(_GPT2_PRETOKENIZE_PATTERN.findall(doc))
        return pretokens

In [12]:
def _process_chunk(args: tuple[int, int, str, list[str]]) -> dict[tuple[bytes, ...], int]:
    """
    process a single chunk and return word frequencies
    """
    start, end, input_path, special_tokens = args

    # read the chunk
    with open(input_path, "rb") as f:
        f.seek(start)
        chunk_bytes = f.read(end - start)
        chunk = chunk_bytes.decode("utf-8", errors="ignore")

    # pretokenize the chunk
    pretokens = _iter_pretokens_with_specials(chunk, special_tokens)

    # build word frequencies for this chunk
    word_freqs: dict[tuple[bytes, ...], int] = defaultdict(int)
    # Pre-compute for O(1) lookup and avoid repeated encoding
    special_token_set = set(special_tokens)
    special_token_bytes = {s: s.encode("utf-8") for s in special_tokens}
    for tok in pretokens:
        if tok in special_token_set:
            word = (special_token_bytes[tok],)
        else:
            b = tok.encode("utf-8")
            word = tuple(bytes([c]) for c in b)
        word_freqs[word] += 1

    return dict(word_freqs)

In [13]:
path = "data/TinyStoriesV2-GPT4-valid.txt"

In [14]:
num_processes = 4
word_freqs = defaultdict(int)
input_path_str = str(path)
special_tokens = ["<|endoftext|>"]

with open(path, "rb") as f:
    boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")

chunk_args = [
    (start, end, input_path_str, special_tokens)
    for start, end in zip(boundaries[:-1], boundaries[1:])
]

In [16]:
# print(chunk_args)

In [18]:
with Pool(processes=num_processes) as pool:
    chunk_results = pool.map(_process_chunk, chunk_args)

for chunk_word_freqs in chunk_results:
    for word, freq in chunk_word_freqs.items():
        word_freqs[word] += freq

In [20]:
# word_freqs

## 2.6 BPE Training on TinyStories 

In [1]:
from cs336_basics.bpe import train_bpe, gpt2_bytes_to_unicode
from pathlib import Path
import time
import psutil
import json
import os
import cProfile
import pstats

In [4]:
# safest way to load txt data samples (treat it as a newline-separated text file)
def load_data(path, max_samples=None):
    """Yield up to max_samples stories from TinyStoriesV2 GPT-4 dataset."""
    with open(path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if max_samples and i >= max_samples:
                break
            # line = line.rstrip("\n")
            line = line.strip()
            if not line:
                continue
            yield line

In [4]:
path = "data/TinyStoriesV2-GPT4-train.txt"
samples = list(load_data(path, max_samples=20))

print(len(samples))
print(samples[0])

18
Once upon a time there was a little boy named Ben. Ben loved to explore the world around him. He saw many amazing things, like beautiful vases that were on display in a store. One day, Ben was walking through the store when he came across a very special vase. When Ben saw it he was amazed!


In [14]:
def save_training_samples(text, output_path):
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(text, f)

In [35]:
input_path = Path("data/TinyStoriesV2-GPT4-train.txt")
samples_ts = list(load_data(input_path, 4000))
len(samples_ts)

3723

In [24]:
samples_output_path = "data/TinyStoriesV2-GPT4-train-samples.txt"
save_training_samples(samples_ts, samples_output_path)

In [25]:
# input_path = Path("data/TinyStoriesV2-GPT4-train.txt")
# input_path = Path("data/TinyStoriesV2-GPT4-valid.txt")
input_path = samples_output_path
vocab_size = 10_000
special_tokens = ["<|endoftext|>"]

In [26]:
def bytes_to_gpt2_string(byte_token: bytes) -> str:
    byte_to_unicode = gpt2_bytes_to_unicode()
    return "".join(byte_to_unicode[b] for b in byte_token)

In [27]:
def get_memory_usage() -> float:
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 * 1024)

In [28]:
initial_memory_mb = get_memory_usage()

# Profile the training
profiler = cProfile.Profile()
profiler.enable()

print("Starting BPE training ...")
start_time = time.time()

vocab, merges = train_bpe(input_path, vocab_size, special_tokens)

profiler.disable()

end_time = time.time()
training_time = end_time - start_time
training_time_minutes = training_time / 60
training_time_hours = training_time / 3600

peak_memory_mb = get_memory_usage()
memory_used_mb = peak_memory_mb -  initial_memory_mb

# Print results
print(f"\nTraining completed!")
print(f"Time: {training_time:.2f} seconds ({training_time_minutes:.2f} minutes, {training_time_hours:.4f} hours)")
print(f"Memory used: {memory_used_mb:.2f} MB ({memory_used_mb / 1024:.2f} GB)")
print(f"Vocabulary size: {len(vocab)}")
print(f"Number of merges: {len(merges)}")

Starting BPE training ...
Debug: using sequential processing for small file: data/TinyStoriesV2-GPT4-train-samples.txt

Training completed!
Time: 28.77 seconds (0.48 minutes, 0.0080 hours)
Memory used: 6.32 MB (0.01 GB)
Vocabulary size: 6034
Number of merges: 5777


In [29]:
# Print statistics sorted by cumulative time
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
stats.print_stats(20)  # Top 20 functions

# Or sort by total time
print("\n" + "="*80)
stats.sort_stats('tottime')
stats.print_stats(20)

         95647689 function calls (95647594 primitive calls) in 27.861 seconds

   Ordered by: cumulative time
   List reduced from 340 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      5/4    0.000    0.000   27.858    6.964 /home/chenchenpan/Projects/cs336-assignment1-basics/.venv/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3665(run_code)
      5/4    0.000    0.000   25.669    6.417 {built-in method builtins.exec}
       27    4.479    0.166   19.887    0.737 {built-in method time.sleep}
       26    0.125    0.005   19.856    0.764 /usr/lib/python3.12/asyncio/base_events.py:1910(_run_once)
 12022539    7.856    0.000   11.245    0.000 /home/chenchenpan/Projects/cs336-assignment1-basics/cs336_basics/bpe.py:109(_apply_merge)
      2/1    0.003    0.001    9.486    9.486 /tmp/ipykernel_3000687/2402502309.py:1(<module>)
     5778    7.762    0.001    8.945    0.002 /home/chenchenpan/Projects/cs336-assignment1-b

<pstats.Stats at 0x7f2b242093a0>

In [30]:
longest_tok_bytes = max(vocab.values(), key=len)
longest_tok_length = len(longest_tok_bytes)
longest_tok_id = [id for id, tok in vocab.items() if tok == longest_tok_bytes][0]
longest_tok_str = longest_tok_bytes.decode("utf-8", errors="replace")
longest_tok_gpt2_str = bytes_to_gpt2_string(longest_tok_bytes)

In [31]:
print(f"\nLongest token:")
print(f"  Token Id: {longest_tok_id}")
print(f"  Length: {longest_tok_length} bytes")
print(f"  UTF-8 string: {repr(longest_tok_str)}")
print(f"  GPT-2 representation: {longest_tok_gpt2_str}")


Longest token:
  Token Id: 5727
  Length: 17 bytes
  UTF-8 string: ' enthusiastically'
  GPT-2 representation: Ġenthusiastically


In [32]:
def save_vocab(vocab: dict[int, bytes], output_path: Path):
    byte_to_unicode = gpt2_bytes_to_unicode()
    gpt2_vocab = {}
    for tok_id, tok_bytes in vocab.items():
        gpt2_string = "".join(byte_to_unicode[b] for b in tok_bytes)
        gpt2_vocab[gpt2_string] = tok_id

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(gpt2_vocab, f, indent=4, ensure_ascii=False)

In [33]:
def save_merges(merges: list[tuple[bytes, bytes]], output_path: Path):
    byte_to_unicode = gpt2_bytes_to_unicode()
    with open(output_path, "w", encoding="utf-8") as f:
        for pair in merges:
            tok1_str = "".join(byte_to_unicode[b] for b in pair[0])
            tok2_str = "".join(byte_to_unicode[b] for b in pair[1])
            f.write(f"{tok1_str} {tok2_str}\n")

In [34]:
# Save results
save_vocab(vocab, Path("tinystories_vocab.json"))
save_merges(merges, Path("tinystories_merges.txt"))
print("\n✓ Saved vocab and merges to disk!")


✓ Saved vocab and merges to disk!


### train BPE on OpenWebText

In [43]:
# sample some training data
input_path = Path("data/owt_train.txt")
max_samples = 40_000_000
samples_owt = list(load_data(input_path, max_samples))
print(len(samples_owt))

samples_output_path = "data/owt_train-samples.txt"
save_training_samples(samples_ts, samples_output_path)

20000000


In [44]:
initial_memory_mb = get_memory_usage()

print("Starting BPE training ...")
start_time = time.time()

vocab_owt, merges_owt = train_bpe(
    # input_path=Path("data/owt_valid.txt"),
    input_path=samples_output_path,
    vocab_size=32_000,
    special_tokens = ["<|endoftext|>"]
)

end_time = time.time()
training_time = end_time - start_time
training_time_minutes = training_time / 60
training_time_hours = training_time / 3600

peak_memory_mb = get_memory_usage()
memory_used_mb = peak_memory_mb - initial_memory_mb

# Print results
print(f"\nTraining completed!")
print(f"Time: {training_time:.2f} seconds ({training_time_minutes:.2f} minutes, {training_time_hours:.4f} hours)")
print(f"Memory used: {memory_used_mb:.2f} MB ({memory_used_mb / 1024:.2f} GB)")
print(f"Vocabulary size: {len(vocab_owt)}")
print(f"Number of merges: {len(merges_owt)}")

save_vocab(vocab_owt, Path("owt_vocab.json"))
save_merges(merges_owt, Path("owt_merges.txt"))
print("\n✓ Saved vocab and merges to disk!")

Starting BPE training ...
Debug: using sequential processing for small file: data/owt_train-samples.txt

Training completed!
Time: 27.14 seconds (0.45 minutes, 0.0075 hours)
Memory used: 0.00 MB (0.00 GB)
Vocabulary size: 6034
Number of merges: 5777

✓ Saved vocab and merges to disk!


## BPE Tokenizer Experiments

In [12]:
from cs336_basics.bpe import BPETokenizer
from pathlib import Path
import time
import psutil
import json
import os
import numpy as np

### encode tinystories with its own tokenizer

In [2]:
tokenizer_ts = BPETokenizer.from_files(
    vocab_filepath="tinystories_vocab.json",
    merges_filepath="tinystories_merges.txt"
)

In [5]:
path = "data/TinyStoriesV2-GPT4-valid.txt"
samples = list(load_data(path, max_samples=10))

In [6]:
all_ids = []
for sample in samples:
    ids = tokenizer_ts.encode(sample)
    all_ids.append(ids)

In [7]:
# calculate the average bytes/token ratio
ratio = []
for s in samples:
    ids = tokenizer_ts.encode(s)
    bytes_per_token = len(s.encode("utf-8")) / max(len(ids) ,1)
    ratio.append(bytes_per_token)

avg_ratio = sum(ratio) / len(ratio)
print(avg_ratio)


3.3088844382895015


In [8]:
# calculate the throughput
start = time.time()
total_bytes = 0
for s in samples:
    _ = tokenizer_ts.encode(s)
    total_bytes += len(s.encode("utf-8"))
elapsed = time.time() - start

throughput = total_bytes / elapsed  # bytes per second
print("bytes/sec:", throughput)


bytes/sec: 748921.3221963577


In [9]:
# for 825GB of text (Pile dataset), we need
seconds = (825*1024**3) / throughput
hours = seconds / (60*60)
print(f"need {round(hours, 2)} hours for tokenizing Pile dataset")

need 328.56 hours for tokenizing Pile dataset


### encode owt with its own tokenizer

In [10]:
tokenizer_owt = BPETokenizer.from_files(
    vocab_filepath="owt_vocab.json",
    merges_filepath="owt_merges.txt"
)

path = "data/owt_valid.txt"
samples = list(load_data(path, max_samples=10))

all_ids = []
ratio = []
for sample in samples:
    ids = tokenizer_owt.encode(sample)
    all_ids.append(ids)

    bytes_per_token = len(sample.encode("utf-8")) / max(len(ids) ,1)
    ratio.append(bytes_per_token)

avg_ratio = sum(ratio) / len(ratio)
# print(all_ids[0])
print(avg_ratio)

2.910303254833445


In [11]:
all_ids = []
ratio = []
for sample in samples:
    ids = tokenizer_ts.encode(sample)
    all_ids.append(ids)

    bytes_per_token = len(sample.encode("utf-8")) / max(len(ids) ,1)
    ratio.append(bytes_per_token)

avg_ratio = sum(ratio) / len(ratio)
# print(all_ids[0])
print(avg_ratio)

2.910303254833445


### Save encoded results

In [14]:
with open("data/TinyStoriesV2-GPT4-valid.txt", encoding="utf-8") as f:
    ids = list(tokenizer_ts.encode_iterable(f))

arr = np.array(ids, dtype=np.uint16)
np.save("TinyStoriesV2-GPT4-valid_ids.npy", arr)

In [16]:
with open("data/owt_valid.txt", encoding="utf-8") as f:
    ids = list(tokenizer_owt.encode_iterable(f))

arr = np.array(ids, dtype=np.uint16)
np.save("owt_valid_ids.npy", arr)