# Byte-Pair Encoding (BPE) Tokenizer

## The Unicode Standard

## Problem (unicode1): Understanding Unicode (1 point)



a) What Unicode character does `chr(0)` return?

`chr(0)` returns the character at Unicode code point 0 (U+0000), the **null character**, often written `'\0'`. It is a control character, named "NULL" in Unicode, that C and many C-based APIs use to mark the end of a string. It has no visible glyph.

In [1]:
string = chr(0)

string

'\x00'

b) How does this character’s string representation (`__repr__()`) differ from its printed representation?

`__repr__()` returns an unambiguous, source-like representation of the object that escapes non-printable characters. For the null character, `chr(0).__repr__()` returns `'\x00'` (quotes included), which shows its hexadecimal escape code.

In contrast, the printed representation (print()) attempts to display the character as-is. Since the null character has no visual representation, printing it results in no visible output.

In [2]:
string.__repr__()

"'\\x00'"

c) What happens when this character occurs in text?

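The null character stays in the string like any other character: the string's repr shows it as `\x00`, while `print` writes it to the output, where it has no visible glyph. Depending on the frontend, it shows up as nothing at all or as a blank.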

In [3]:
"this is a test" + chr(0) + "string"


'this is a test\x00string'

In [4]:
print("this is a test" + chr(0) + "string")

this is a test string
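To make the difference concrete, here is a small sketch (the variable names are illustrative and not part of the assignment) that checks the null character is still present in the string even though printing shows no glyph for it:

```python
# Illustrative names; not part of the original notebook.
text = "this is a test" + chr(0) + "string"

# The null character is a real element of the string.
null_index = text.index("\x00")
length = len(text)

# repr() escapes the character; encoding keeps it as the single byte 0x00.
escaped = repr(text)
encoded = text.encode("utf-8")

print(null_index, length)
print(escaped)
print(encoded)
```

The index and length confirm that concatenation did not drop or alter the character; only its rendering differs between `repr` and `print`.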


## Problem (unicode2): Unicode Encodings (3 points)

a) What are some reasons to prefer training our tokenizer on UTF-8 encoded bytes, rather than UTF-16 or UTF-32? It may be helpful to compare the output of these encodings for various input strings.

UTF-8 is often preferred for training tokenizers for several reasons:

1. **Space Efficiency**: UTF-8 is a variable-length encoding that uses one byte for ASCII characters and two to four bytes for everything else. UTF-16 uses two or four bytes per character and UTF-32 always uses four, so ASCII-heavy text is two to four times longer in those encodings. Shorter byte sequences mean less work for the tokenizer. (Some scripts, such as CJK, take three bytes in UTF-8 versus two in UTF-16, so the advantage depends on the data.)

2. **Compatibility**: UTF-8 is backward compatible with ASCII, meaning that any valid ASCII text is also valid UTF-8 text. This compatibility makes it easier to work with existing systems and libraries that expect ASCII input.

3. **Widespread Adoption**: UTF-8 is the most widely used encoding on the web and in many programming languages, making it easier to find libraries and tools that support it.

4. **Simplicity**: UTF-8 is defined as a plain byte sequence with no byte-order variants, whereas UTF-16 and UTF-32 each come in little-endian and big-endian forms. UTF-16 is also variable-length (it uses surrogate pairs), so it does not offer fixed-width simplicity either. In both UTF-16 and UTF-32, ASCII characters are padded with zero bytes, so a byte-level tokenizer would spend merges on bytes that carry no information.

5. **Avoiding BOM**: UTF-16 and UTF-32 often include a Byte Order Mark (BOM) to indicate endianness, which can complicate text processing; Python's generic `utf-16` and `utf-32` codecs prepend one when encoding. UTF-8 does not require a BOM, simplifying the handling of text files.
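The question suggests comparing the encodings on a few inputs. A minimal sketch of such a comparison follows; the helper name `encoded_sizes` and the sample strings are illustrative choices, and the explicit little-endian codecs are used so that no BOM is counted:

```python
import codecs


def encoded_sizes(text: str) -> dict[str, int]:
    """Return the encoded length in bytes for each encoding (BOM-free codecs)."""
    return {enc: len(text.encode(enc)) for enc in ("utf-8", "utf-16-le", "utf-32-le")}


for sample in ["hello", "héllo", "こんにちは", "🙂"]:
    print(repr(sample), encoded_sizes(sample))

# The generic "utf-16" codec prepends a byte order mark in native byte order.
with_bom = "hello".encode("utf-16")
has_bom = with_bom[:2] in (codecs.BOM_UTF16_LE, codecs.BOM_UTF16_BE)
print(f"utf-16 output starts with a BOM: {has_bom}")

# ASCII text in UTF-16 is half zero bytes.
zero_bytes = "hello".encode("utf-16-le").count(0)
print(f"zero bytes in UTF-16 'hello': {zero_bytes}")
```

For English-heavy corpora the UTF-8 column is the smallest, which is the main practical argument for byte-level BPE on UTF-8.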


b) Consider the following (incorrect) function, which is intended to decode a UTF-8 byte string into a Unicode string. Why is this function incorrect? Provide an example of an input byte string that yields incorrect results.

In [5]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])

In [6]:
decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))

'hello'

The function decode_utf8_bytes_to_str_wrong is incorrect because it attempts to decode each byte in the input byte string individually, rather than decoding the entire byte string as a single UTF-8 encoded sequence. UTF-8 is a variable-length encoding, meaning that some characters are represented by multiple bytes. Decoding each byte separately will fail for multi-byte characters, as the individual bytes do not represent valid UTF-8 characters on their own.

In [7]:
eur_sign = b'\xe2\x82\xac'

print(eur_sign.decode("utf-8"))

€


In [8]:
try:
    decode_utf8_bytes_to_str_wrong(eur_sign)
except UnicodeDecodeError as e:
    print(f"UnicodeDecodeError: {e}")

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe2 in position 0: unexpected end of data


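The function therefore only works for pure-ASCII input. Any input that contains a multi-byte character, such as `b'\xe2\x82\xac'` (`€`), raises a `UnicodeDecodeError`, because a lead byte or continuation byte on its own is not a complete UTF-8 sequence.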
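For contrast, a sketch of two correct approaches: decoding the whole byte string at once, and, when bytes really do arrive in pieces, using the standard library's incremental decoder, which buffers incomplete sequences. The function names here are illustrative and not part of the assignment.

```python
import codecs


def decode_utf8_bytes_to_str(bytestring: bytes) -> str:
    """Decode the whole byte string in one call so multi-byte sequences stay intact."""
    return bytestring.decode("utf-8")


def decode_utf8_stream(byte_chunks) -> str:
    """Decode bytes that arrive in arbitrary pieces, buffering partial sequences."""
    decoder = codecs.getincrementaldecoder("utf-8")()
    parts = [decoder.decode(chunk) for chunk in byte_chunks]
    # final=True raises if the input ends in the middle of a character.
    parts.append(decoder.decode(b"", final=True))
    return "".join(parts)


eur_sign = b"\xe2\x82\xac"
print(decode_utf8_bytes_to_str(eur_sign))
# Even one byte at a time works, because the decoder keeps state between calls.
print(decode_utf8_stream([bytes([b]) for b in eur_sign]))
```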



## BPE Tokenizer Training

In [9]:
from cs336_basics.bpe.utils import find_chunk_boundaries

In [10]:
# Get number of CPU cores
import multiprocessing
num_cores = multiprocessing.cpu_count()
print(f"Number of CPU cores: {num_cores}")

Number of CPU cores: 8


In [11]:
with open('./data/owt_train.txt', "rb") as f:
    boundaries = find_chunk_boundaries(
        f, num_cores, "<|endoftext|>".encode("utf-8")
    )

print(f"Chunk boundaries: {boundaries}")


Chunk boundaries: [0, 1490070394, 2980128270, 4470223005, 5960269363, 7450321387, 8940385101, 10430448049, 11920511059]
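`find_chunk_boundaries` comes from the project's own `cs336_basics.bpe.utils` module, and its source is not shown here. As a rough sketch of what such a helper could do, the following starts from evenly spaced offsets and moves each interior offset forward to the next occurrence of the split token, so no document is cut in half. The name `sketch_chunk_boundaries`, the `mini_chunk_size` parameter, and the search strategy are all assumptions.

```python
import io


def sketch_chunk_boundaries(f, desired_num_chunks, split_token, mini_chunk_size=4096):
    """Hypothetical sketch: return sorted byte offsets that start at a split token."""
    f.seek(0, io.SEEK_END)
    file_size = f.tell()
    chunk_size = file_size // desired_num_chunks

    # Initial guesses: evenly spaced offsets, with the last one pinned to EOF.
    boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    boundaries[-1] = file_size

    overlap = len(split_token) - 1
    for i in range(1, len(boundaries) - 1):
        pos = boundaries[i]  # file offset of the next block to read
        f.seek(pos)
        tail = b""  # end of the previous window, kept to catch straddling tokens
        while True:
            block = f.read(mini_chunk_size)
            if not block:
                boundaries[i] = file_size  # no token found before EOF
                break
            window = tail + block
            found = window.find(split_token)
            if found != -1:
                boundaries[i] = pos - len(tail) + found
                break
            pos += len(block)
            tail = window[-overlap:] if overlap > 0 else b""

    # Several guesses may collapse onto the same token; drop duplicates.
    return sorted(set(boundaries))


token = "<|endoftext|>".encode("utf-8")
data = b"aaa" + token + b"bbbb" + token + b"cc"
demo_boundaries = sketch_chunk_boundaries(io.BytesIO(data), 2, token, mini_chunk_size=4)
print(demo_boundaries)
```

Because every interior boundary lands on the start of a split token, each chunk can be decoded and pre-tokenized independently, which is what makes per-core processing safe.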


In [12]:
with open('./data/owt_train.txt', "rb") as f:
    start, end = boundaries[0], boundaries[1]
    f.seek(start)
    chunk = f.read(end - start)
print(f"Chunk size: {len(chunk)} bytes")


Chunk size: 1490070394 bytes


In [13]:
chunk_str = chunk.decode("utf-8", errors="ignore")

print(f"Chunk string length: {len(chunk_str)} characters, first 100 characters: {chunk_str[:100]}")


Chunk string length: 1476958212 characters, first 100 characters: What wouldn't you do to save someone you love?

When They Come Calling is a modern ghost story, a su


In [14]:
from pathlib import Path

# Load the small test corpus as text, with the encoding made explicit
with open(Path("tests/fixtures/corpus.en"), encoding="utf-8") as f:
    chunk_str = f.read()

chunk_str



In [15]:
from cs336_basics.bpe.normalization import normalize_text

In [31]:
normalized_chunk = normalize_text(chunk_str)
normalized_chunk



In [32]:
from cs336_basics.bpe.tokenization import count_word_frequencies, get_alphabet

word_counts = count_word_frequencies(normalized_chunk)

print(f"Number of unique subwords: {len(word_counts)}")
print(f"Most common subwords: {sorted(word_counts.items(), key=lambda x: x[1], reverse=True)[:10]}")

Number of unique subwords: 4475
Most common subwords: [(b'the', 1357), (b',', 1326), (b'.', 1039), (b'of', 646), (b'and', 629), (b'a', 490), (b'to', 483), (b';', 480), (b'in', 456), (b'is', 339)]
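`count_word_frequencies` is a project function whose implementation is not shown. Judging only from the output above (byte-string keys, punctuation split from words, no leading spaces), it could be approximated as follows. The regular expression and the name `count_pretokens` are assumptions, not the author's code.

```python
import re
from collections import Counter

# Assumed pattern: runs of word characters, or runs of non-space punctuation.
PRETOKEN_PATTERN = re.compile(r"\w+|[^\w\s]+")


def count_pretokens(text: str) -> Counter:
    """Count pre-tokens, keyed by their UTF-8 bytes."""
    return Counter(
        match.group().encode("utf-8") for match in PRETOKEN_PATTERN.finditer(text)
    )


counts = count_pretokens("the cat, the dog... the end.")
print(counts.most_common(3))
```

Keying the counts by `bytes` rather than `str` keeps the later splitting step a simple per-byte operation.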


In [34]:
list(word_counts.keys())

[b'iron',
 b'cement',
 b'is',
 b'a',
 b'ready',
 b'for',
 b'use',
 b'paste',
 b'which',
 b'laid',
 b'as',
 b'fillet',
 b'by',
 b'putty',
 b'knife',
 b'or',
 b'finger',
 b'in',
 b'the',
 b'mould',
 b'edges',
 b'(',
 b'corners',
 b')',
 b'of',
 b'steel',
 b'ingot',
 b'.',
 b'protects',
 b'against',
 b'hot',
 b',',
 b'abrasive',
 b'casting',
 b'process',
 b'fire',
 b'restant',
 b'repair',
 b'places',
 b'ovens',
 b'open',
 b'fireplaces',
 b'etc',
 b'construction',
 b'and',
 b'highways',
 b'...',
 b'an',
 b'announcement',
 b'must',
 b'be',
 b'commercial',
 b'character',
 b'goods',
 b'services',
 b'advancement',
 b'through',
 b'P',
 b'O',
 b'Box',
 b'system',
 b'NOT',
 b'ALLOWED',
 b'deliveries',
 b'spam',
 b'other',
 b'improper',
 b'information',
 b'deleted',
 b'translator',
 b'Internet',
 b'Toolbar',
 b'MS',
 b'Explorer',
 b'it',
 b'allows',
 b'you',
 b'to',
 b'translate',
 b'real',
 b'time',
 b'any',
 b'web',
 b'pasge',
 b'from',
 b'one',
 b'language',
 b'another',
 b'only',
 b'have',
 ... (output truncated)

In [40]:
from cs336_basics.bpe.tokenization import split_words

splits = split_words(word_counts)

for word in list(splits.keys()):  # Show only the words that contain '('
    if '(' in word.decode('utf-8'):
        print(f"Word: '{word}', Splits: {splits[word]}")

Word: 'b'('', Splits: [b'(']


In [41]:
from cs336_basics.bpe.tokenization import compute_pair_freqs

pair_freqs = compute_pair_freqs(splits, word_counts)
print(f"Number of unique pairs: {len(pair_freqs)}")
print(f"Most common pairs: {sorted(pair_freqs.items(), key=lambda x: x[1], reverse=True)[:10]}")

Number of unique pairs: 985
Most common pairs: [((b't', b'h'), 2764), ((b'h', b'e'), 2168), ((b'i', b'n'), 1890), ((b'e', b'r'), 1564), ((b'r', b'e'), 1490), ((b'a', b'n'), 1488), ((b'o', b'r'), 1258), ((b'o', b'n'), 1214), ((b'n', b'd'), 1117), ((b'a', b't'), 1110)]
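The two project helpers used so far, `split_words` and `compute_pair_freqs`, are not shown. A minimal sketch of the data structures they appear to produce, based on the outputs above, is given here; the function names and the toy counts are illustrative assumptions.

```python
from collections import Counter


def split_into_bytes(word_counts):
    """Map each word to its list of single-byte tokens."""
    return {word: [bytes([b]) for b in word] for word in word_counts}


def count_pairs(splits, word_counts):
    """Count adjacent token pairs, weighting each by its word's frequency."""
    pair_freqs = Counter()
    for word, pieces in splits.items():
        for pair in zip(pieces, pieces[1:]):
            pair_freqs[pair] += word_counts[word]
    return pair_freqs


toy_counts = {b"the": 3, b"then": 2, b"(": 1}
toy_splits = split_into_bytes(toy_counts)
toy_pairs = count_pairs(toy_splits, toy_counts)

print(toy_splits[b"("])
print(toy_pairs.most_common())
```

A single-byte word such as `b'('` splits into one token and contributes no pairs, which is consistent with the `Splits: [b'(']` line printed earlier.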


In [42]:
from cs336_basics.bpe.tokenization import merge_pair

In [43]:
from copy import deepcopy

# Work on a copy so the original splits stay untouched
clone_splits = deepcopy(splits)

# Pick the most frequent adjacent pair
pair_to_merge = max(pair_freqs, key=pair_freqs.get)
print(f"Merging pair: {pair_to_merge} with frequency: {pair_freqs[pair_to_merge]}")

first_word, second_word = pair_to_merge
merge_pair(first_word, second_word, clone_splits, word_counts)

# Show every word whose bytes contain the merged token
new_word = first_word + second_word
for word in clone_splits:
    if new_word in word:
        print(f"Word '{word}' after merging: {clone_splits[word]}")

Merging pair: (b't', b'h') with frequency: 2764
Word 'b'the'' after merging: [b'th', b'e']
Word 'b'through'' after merging: [b'th', b'r', b'o', b'u', b'g', b'h']
Word 'b'other'' after merging: [b'o', b'th', b'e', b'r']
Word 'b'another'' after merging: [b'a', b'n', b'o', b'th', b'e', b'r']
Word 'b'this'' after merging: [b'th', b'i', b's']
Word 'b'there'' after merging: [b'th', b'e', b'r', b'e']
Word 'b'that'' after merging: [b'th', b'a', b't']
Word 'b'their'' after merging: [b'th', b'e', b'i', b'r']
Word 'b'than'' after merging: [b'th', b'a', b'n']
Word 'b'with'' after merging: [b'w', b'i', b'th']
Word 'b'both'' after merging: [b'b', b'o', b'th']
Word 'b'ninth'' after merging: [b'n', b'i', b'n', b'th']
Word 'b'method'' after merging: [b'm', b'e', b'th', b'o', b'd']
Word 'b'rather'' after merging: [b'r', b'a', b'th', b'e', b'r']
Word 'b'without'' after merging: [b'w', b'i', b'th', b'o', b'u', b't']
Word 'b'thesaurus'' after merging: [b'th', b'e', b's', b'a', b'u', b'r', b'u', b's']
... (output truncated)
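The merge step itself can be sketched as a left-to-right scan over each word's token list. This is an illustrative stand-in for the project's `merge_pair`, whose source is not shown; the name `merge_pair_in_splits` and its three-argument signature are assumptions.

```python
def merge_pair_in_splits(first, second, splits):
    """Replace every adjacent (first, second) occurrence with the merged token."""
    merged = first + second
    for word, pieces in splits.items():
        out = []
        i = 0
        while i < len(pieces):
            if i + 1 < len(pieces) and pieces[i] == first and pieces[i + 1] == second:
                out.append(merged)
                i += 2  # skip both halves of the merged pair
            else:
                out.append(pieces[i])
                i += 1
        splits[word] = out  # reassigning an existing key is safe while iterating
    return splits


toy_splits = {
    b"the": [b"t", b"h", b"e"],
    b"with": [b"w", b"i", b"t", b"h"],
    b"aaa": [b"a", b"a", b"a"],
}
merge_pair_in_splits(b"t", b"h", toy_splits)
merge_pair_in_splits(b"a", b"a", toy_splits)
print(toy_splits)
```

Scanning left to right and skipping two positions after a merge avoids overlapping merges, so `b'aaa'` becomes `[b'aa', b'a']` rather than being merged twice.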

In [48]:
from cs336_basics.bpe.tokenization import train_bpe_tokenizer


vocab, merges = train_bpe_tokenizer(
    splits,
    word_counts,
    vocab_size=500,
)
print(f"Number of merges: {len(merges)}")
print(f"Vocabulary size after training: {len(vocab)}")

Number of merges: 500
Vocabulary size after training: 500
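Putting the pieces together, a complete training loop can be sketched in a few lines. This is not the project's `train_bpe_tokenizer`. It assumes the vocabulary starts from all 256 single bytes, that special tokens are ignored, and that ties between equally frequent pairs go to the lexicographically greater pair; the name `train_bpe_sketch` is hypothetical.

```python
from collections import Counter


def train_bpe_sketch(word_counts, vocab_size):
    """Toy BPE trainer: returns (vocab, merges) under the assumptions stated above."""
    vocab = {i: bytes([i]) for i in range(256)}
    splits = {word: [bytes([b]) for b in word] for word in word_counts}
    merges = []

    while len(vocab) < vocab_size:
        # Recount adjacent pairs from scratch (simple, not efficient).
        pair_freqs = Counter()
        for word, pieces in splits.items():
            for pair in zip(pieces, pieces[1:]):
                pair_freqs[pair] += word_counts[word]
        if not pair_freqs:
            break  # nothing left to merge

        # Highest count wins; ties go to the lexicographically greater pair.
        best = max(pair_freqs, key=lambda p: (pair_freqs[p], p))
        merged = best[0] + best[1]

        for word, pieces in splits.items():
            out, i = [], 0
            while i < len(pieces):
                if i + 1 < len(pieces) and (pieces[i], pieces[i + 1]) == best:
                    out.append(merged)
                    i += 2
                else:
                    out.append(pieces[i])
                    i += 1
            splits[word] = out

        merges.append(best)
        vocab[len(vocab)] = merged

    return vocab, merges


toy_counts = {b"low": 5, b"lower": 2, b"newest": 6, b"widest": 3}
toy_vocab, toy_merges = train_bpe_sketch(toy_counts, vocab_size=258)
print(toy_merges)
```

Under these assumptions each merge adds exactly one vocabulary entry, so the number of merges equals `vocab_size - 256`.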


In [53]:
vocabs_without_specials = [
    word for word in vocab.values() if word != b"<|endoftext|>"
]
for word_bytes in vocabs_without_specials:
    assert b"<|" not in word_bytes

## End-to-End Training of a BPE Tokenizer

In [25]:
from cs336_basics.bpe import train_bpe


In [26]:
vocab, merges = train_bpe(
    input_path=Path('.').parent / "tests/fixtures/corpus.en",
    vocab_size=500,
    special_tokens=["<|endoftext|>"],
)

In [56]:
from tests.common import FIXTURES_PATH, gpt2_bytes_to_unicode
import json

In [57]:
input_path = FIXTURES_PATH / "corpus.en"
vocab, merges = train_bpe(
    input_path=input_path,
    vocab_size=500,
    special_tokens=["<|endoftext|>"],
)


In [59]:
# Path to the reference tokenizer vocab and merges
reference_vocab_path = FIXTURES_PATH / "train-bpe-reference-vocab.json"
reference_merges_path = FIXTURES_PATH / "train-bpe-reference-merges.txt"

# Compare the learned merges to the expected output merges
gpt2_byte_decoder = {v: k for k, v in gpt2_bytes_to_unicode().items()}


with open(reference_merges_path, encoding="utf-8") as f:
    gpt2_reference_merges = [tuple(line.rstrip().split(" ")) for line in f]
    reference_merges = [
        (
            bytes([gpt2_byte_decoder[token] for token in merge_token_1]),
            bytes([gpt2_byte_decoder[token] for token in merge_token_2]),
        )
        for merge_token_1, merge_token_2 in gpt2_reference_merges
    ]
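The reference files store tokens in GPT-2's printable form, where every byte is mapped to a visible Unicode character (a space becomes `Ġ`, for example). The helper `gpt2_bytes_to_unicode` comes from `tests.common` and is not shown. The sketch below re-creates the widely known GPT-2 mapping on the assumption that the helper behaves the same way, and shows how one reference token is turned back into bytes.

```python
def bytes_to_unicode_sketch():
    """Map each of the 256 byte values to a printable Unicode character."""
    # Bytes that are already printable keep their own code point.
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = printable[:]
    code_points = printable[:]
    # Remaining bytes (controls, space, etc.) are shifted to 256 and above.
    shift = 0
    for b in range(256):
        if b not in printable:
            byte_values.append(b)
            code_points.append(256 + shift)
            shift += 1
    return dict(zip(byte_values, (chr(cp) for cp in code_points)))


byte_encoder = bytes_to_unicode_sketch()
byte_decoder = {ch: b for b, ch in byte_encoder.items()}


def token_to_bytes(token: str) -> bytes:
    """Convert a GPT-2 style printable token back to raw bytes."""
    return bytes([byte_decoder[ch] for ch in token])


print(byte_encoder[ord(" ")])
print(token_to_bytes("Ġthe"))
```

The mapping is a bijection over all 256 byte values, which is why inverting it, as the cell above does with `gpt2_byte_decoder`, recovers the exact bytes of each merge.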

In [None]:

assert merges == reference_merges


In [61]:

# Compare the vocab to the expected output vocab
with open(reference_vocab_path, encoding="utf-8") as f:
    gpt2_reference_vocab = json.load(f)
    reference_vocab = {
        gpt2_vocab_index: bytes(
            [gpt2_byte_decoder[token] for token in gpt2_vocab_item]
        )
        for gpt2_vocab_item, gpt2_vocab_index in gpt2_reference_vocab.items()
    }
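The notebook builds `reference_vocab` but does not show the comparison itself. One plausible check, assuming token IDs may be assigned in a different order than in the reference, is to compare the two vocabularies as sets of byte strings. The helper name `vocab_differences` and the toy dictionaries are illustrative only.

```python
def vocab_differences(learned, reference):
    """Return (tokens only in learned, tokens only in reference), ignoring IDs."""
    learned_tokens = set(learned.values())
    reference_tokens = set(reference.values())
    return learned_tokens - reference_tokens, reference_tokens - learned_tokens


# Toy vocabularies with the same tokens under different IDs.
toy_learned = {0: b"<|endoftext|>", 1: b"t", 2: b"h", 3: b"th"}
toy_reference = {0: b"t", 1: b"h", 2: b"<|endoftext|>", 3: b"th"}

only_learned, only_reference = vocab_differences(toy_learned, toy_reference)
print(only_learned, only_reference)
```

Returning the two difference sets, rather than a bare boolean, makes a mismatch easier to diagnose than a failed `assert` on whole dictionaries.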

In [69]:
# Compare as bytes: decoding arbitrary byte tokens as UTF-8 can raise UnicodeDecodeError
[word for word in vocab.values() if word == b"<|endoftext|>"]

[b'<|endoftext|>']