### 2.1
- printable vs no appearance: '\x00' vs '\x80'
- `some encoding -> binary -> another encoding` is a recipe for errors. See https://en.wikipedia.org/wiki/Specials_(Unicode_block)#:~:text=The%20replacement%20character%20%EF%BF%BD%20(often,of%20data%20to%20correct%20symbols. for examples.

In [None]:
# Demo: code point 2049 (U+0801) needs three bytes in UTF-8.
# Show the character itself, ...
print(chr(2049))
# ... its UTF-8 byte values (str.encode() defaults to UTF-8), ...
print(list(chr(2049).encode()))
# ... and the same bytes in binary, which exposes the 1110xxxx / 10xxxxxx
# lead/continuation-byte prefixes.
print([bin(n) for n in list(chr(2049).encode())])

In [None]:
def gpt2_bytes_to_unicode() -> dict[int, str]:
    r"""
    Build the GPT-2 table that maps each byte value (0-255) to a printable
    one-character unicode string. The scheme is taken from the GPT-2 code.

    Bytes that already render visibly keep their own character, e.g. `d[33]`
    is `!`. Every other byte (control characters, whitespace, and a few other
    non-rendering values such as `chr(0)`, whose repr is '\x00') is assigned,
    in increasing byte order, the next unused code point starting at 256:

    * `d[0]`  is `chr(256)`, i.e. `Ā`
    * `d[32]` (the space) is `chr(256 + 32)`, i.e. `Ġ`

    Having a printable stand-in for every byte simplifies the BPE
    implementation and makes serialized merges easier to inspect by hand.

    Returns:
        A 256-entry dict. The 188 self-representing bytes come first in the
        iteration order, followed by the 68 remapped bytes.
    """
    # The 188 byte values that are neither whitespace nor control characters
    # and can therefore be used as-is.
    # See https://www.ssec.wisc.edu/~tomw/java/unicode.html.
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    printable_lookup = set(printable)

    byte_values = list(printable)
    code_points = list(printable)

    # The remaining 68 byte values get chr(256 + k), with k counting up from 0
    # in the order the bytes are encountered.
    shifted = 0
    for value in range(256):
        if value in printable_lookup:
            continue
        byte_values.append(value)
        code_points.append(256 + shifted)
        shifted += 1

    return {value: chr(point) for value, point in zip(byte_values, code_points)}

In [None]:
def merge_pretoken_counts(pretoken_counts, pair_counts, max_pair):
    """
    Apply one BPE merge to every pretoken and update the pair counts in place
    of a full recount. (A later cell redefines this function.)

    Args:
        pretoken_counts: dict mapping a tuple of byte strings to its count.
        pair_counts: dict mapping an adjacent (left, right) pair to its count.
        max_pair: the (left, right) pair to merge into `b"".join(max_pair)`.

    Returns:
        (new_pretoken_counts, new_pair_counts, next_max_pair), where
        next_max_pair is whatever `get_max_pair` (defined outside this block)
        returns for the updated pair counts.

    Raises:
        KeyError: if `max_pair` occurs in a pretoken but is not a key of
            `pair_counts` (the `-=` below needs an existing entry).

    NOTE(review): a pretoken whose merge is immediately followed by its last
    element, e.g. (l, o, w) with max_pair (l, o), never reaches the `else`
    branch for that last element, so the old (o, w) pair does not appear to be
    decremented -- verify before relying on these counts.
    """
    # Debug separator printed on every call.
    print("==========")
    new_pretoken_counts = {}
    new_pair_counts = dict(pair_counts)  # work on a copy; the input dict is left untouched
    for byte_tup, byte_tup_count in pretoken_counts.items():
        i = 0
        new_byte_tup = []
        # Scan adjacent pairs left to right; the final element is handled after the loop.
        while i < len(byte_tup)-1:
            cur = (byte_tup[i], byte_tup[i+1])
            if cur == max_pair:
                new_byte_tup.append(b"".join(max_pair))
                new_pair_counts[max_pair] -= byte_tup_count
                # when current pair is max_pair, always affect proceeding pair
                # (i.e. the pair formed with the element just before the merge).
                # At i == 0 there is no such pair and `prev` is None: a `None`
                # key with a negative count is created here and only removed by
                # the `v > 0` filter at the end.
                prev = (byte_tup[i-1], byte_tup[i]) if i > 0 else None
                new_pair_counts[prev] = new_pair_counts.get(prev, 0) - byte_tup_count
                i += 2
            else:
                # when not max_pair, just take the element as is
                new_byte_tup.append(byte_tup[i])
                # need to look into previous two elements: if they formed
                # max_pair, the old pair (second merged element, current
                # element) no longer exists.
                check = (byte_tup[i-2], byte_tup[i-1]) if i > 1 else None
                if check == max_pair:
                    prev = (byte_tup[i-1], byte_tup[i])
                    new_pair_counts[prev] = new_pair_counts.get(prev, 0) - byte_tup_count
                i += 1
        # The loop stops one short of the end unless the last pair was merged;
        # carry the trailing element over unchanged.
        if i == len(byte_tup) - 1:
            new_byte_tup.append(byte_tup[i])
        # update pretoken counts
        new_byte_tup = tuple(new_byte_tup)
        new_pretoken_counts[new_byte_tup] = new_pretoken_counts.get(new_byte_tup, 0) + byte_tup_count

        # Only pretokens that actually changed contribute new pairs.
        if new_byte_tup != byte_tup:
            i = 0
            while i < len(new_byte_tup) - 1:
                pair = (new_byte_tup[i], new_byte_tup[i+1])
                # Tuple membership (element equality, not a substring test):
                # count every new pair that has the merged token on either side.
                if b"".join(max_pair) in pair:
                    new_pair_counts[pair] = new_pair_counts.get(pair, 0) + byte_tup_count
                i += 1
    # Drop exhausted pairs (and the temporary `None` key, whose count is negative).
    new_pair_counts = {k:v for k, v in new_pair_counts.items() if v > 0}
    return new_pretoken_counts, new_pair_counts, get_max_pair(new_pair_counts)

In [None]:
def merge_one_tuple(byte_tup, max_pair):
    """
    Merge every left-to-right, non-overlapping occurrence of `max_pair` in
    `byte_tup` and report which positions of the ORIGINAL tuple were affected.

    Args:
        byte_tup: tuple of byte strings (one pretoken).
        max_pair: (left, right) pair of byte strings to merge into left + right.

    Returns:
        A 4-tuple `(new_tup, ids2rm, ids_prev, ids_post)`:

        * `new_tup`: the tuple after merging.
        * `ids2rm`: sorted start indices `i` of the original adjacent pairs
          `(byte_tup[i], byte_tup[i+1])` that no longer exist after the merge
          (the merged pairs themselves plus their direct neighbours).
        * `ids_prev`: indices `i` for which the new pair
          `(byte_tup[i], merged)` is created.
        * `ids_post`: indices `i` for which the new pair
          `(merged, byte_tup[i+1])` is created.

        When the tuple has a single element, or the merged bytes do not occur
        in the concatenated tuple at all, `(byte_tup, None, None, None)` is
        returned. When the quick substring check passes but no adjacent pair
        actually matches, the three lists are empty.

    Note:
        Two back-to-back merges (e.g. `(a, b, a, b)` with pair `(a, b)`) create
        a `(merged, merged)` pair that cannot be expressed as an index into
        `byte_tup`. The position between them is still listed in `ids2rm`, but
        is left out of `ids_prev` / `ids_post` so that no pair with a consumed
        element is reported; callers that need the `(merged, merged)` count
        must derive it from `new_tup`. Previously this input tripped an
        assertion.
    """
    max_pair_0, max_pair_1 = max_pair  # Unpack once
    merged_token = max_pair_0 + max_pair_1  # Pre-compute joined bytes

    n = len(byte_tup)
    # Cheap rejections: nothing to merge in a single element, and the pair
    # cannot occur if its bytes are not even a substring of the whole pretoken.
    if n == 1 or merged_token not in b"".join(byte_tup):
        return byte_tup, None, None, None

    result = []
    ids = []  # start index (in byte_tup) of every merged occurrence
    i = 0
    while i < n:
        if (i < n - 1 and
                byte_tup[i] == max_pair_0 and
                byte_tup[i + 1] == max_pair_1):
            result.append(merged_token)
            ids.append(i)
            i += 2
        else:
            result.append(byte_tup[i])
            i += 1

    merged_at = set(ids)
    last_pair_start = n - 2

    # Neighbouring pairs that survive as (element, merged) / (merged, element).
    # A neighbour that is itself part of another merge is skipped: the element
    # there has been consumed, so such a pair would not exist in `result`.
    ids_prev = [i - 1 for i in ids if i > 0 and (i - 2) not in merged_at]
    ids_post = [i + 1 for i in ids
                if i < last_pair_start and (i + 2) not in merged_at]

    # Every original pair touching a merge disappears, junctions included.
    ids2rm = set(ids)
    ids2rm.update(i - 1 for i in ids if i > 0)
    ids2rm.update(i + 1 for i in ids if i < last_pair_start)

    return tuple(result), sorted(ids2rm), ids_prev, ids_post

In [None]:
merge_one_tuple((b'l', b'o', b'w'), (b'l', b'o'))

In [None]:
def _merge_pair_in_tuple(byte_tup, left, right, merged_token):
    """Return `byte_tup` with each left-to-right (left, right) occurrence replaced by `merged_token`."""
    result = []
    changed = False
    i = 0
    last = len(byte_tup) - 1
    while i <= last:
        if i < last and byte_tup[i] == left and byte_tup[i + 1] == right:
            result.append(merged_token)
            changed = True
            i += 2
        else:
            result.append(byte_tup[i])
            i += 1
    # Hand back the original object when nothing was merged.
    return tuple(result) if changed else byte_tup


def merge_pretoken_counts(pretoken_counts, pair_counts, max_pair):
    """
    Apply one BPE merge and incrementally update the pair counts.

    Only pretokens that actually contain `max_pair` touch the pair counts, and
    for those only the pairs whose multiplicity differs between the old and the
    new tuple are adjusted. Computing the difference from the two tuples (rather
    than from index bookkeeping) also covers back-to-back merges such as
    (a, b, a, b) -> (ab, ab), where a new (ab, ab) pair appears.

    Args:
        pretoken_counts: dict mapping a tuple of byte strings to its count.
        pair_counts: dict mapping an adjacent (left, right) pair to its count.
            It is copied, not modified.
        max_pair: the (left, right) pair of byte strings to merge.

    Returns:
        (new_pretoken_counts, new_pair_counts, next_max_pair). Pairs whose
        count drops to zero or below are removed from `new_pair_counts`.
        `next_max_pair` is whatever `get_max_pair` (defined outside this
        block) returns for the updated counts.
    """
    left, right = max_pair
    merged_token = left + right  # computed once instead of on every use

    new_pretoken_counts = {}
    new_pair_counts = dict(pair_counts)

    for byte_tup, byte_tup_count in pretoken_counts.items():
        new_byte_tup = _merge_pair_in_tuple(byte_tup, left, right, merged_token)
        new_pretoken_counts[new_byte_tup] = (
            new_pretoken_counts.get(new_byte_tup, 0) + byte_tup_count
        )

        # A merge always shortens the tuple; same length means nothing changed.
        if len(new_byte_tup) == len(byte_tup):
            continue

        # Net change in occurrences of each pair within this one pretoken.
        delta = {}
        for pair in zip(byte_tup, byte_tup[1:]):
            delta[pair] = delta.get(pair, 0) - 1
        for pair in zip(new_byte_tup, new_byte_tup[1:]):
            delta[pair] = delta.get(pair, 0) + 1

        for pair, change in delta.items():
            if change == 0:
                continue  # pair untouched by the merge
            count = new_pair_counts.get(pair, 0) + change * byte_tup_count
            if count > 0:
                new_pair_counts[pair] = count
            else:
                # Tolerate an already-missing key instead of raising KeyError.
                new_pair_counts.pop(pair, None)

    return new_pretoken_counts, new_pair_counts, get_max_pair(new_pair_counts)

In [None]:
# Toy corpus for stepping through BPE merges by hand.
doc = "low low low low low lower lower widest widest widest newest newest newest newest newest newest"
from train_bpe_test import get_max_pair
from collections import Counter
# Whitespace pre-tokenization, then count each distinct word.
pretokens = doc.split()
pretoken_counts = Counter(pretokens)
# Re-key every word as a tuple of single-byte `bytes` objects.
pretoken_counts = {tuple(bytes([b]) for b in k.encode()):v for k,v in pretoken_counts.items()}
print(pretoken_counts)
# NOTE(review): `get_pair_counts` is not imported in this cell; presumably it
# comes from an earlier cell or from `train_bpe_test` -- confirm.
pair_counts = get_pair_counts(pretoken_counts)
max_pair = get_max_pair(pair_counts)
max_pair

In [None]:
ptc, prc = pretoken_counts.copy(), pair_counts.copy()
max_pair = get_max_pair(prc)
print(ptc)
print(prc)

In [None]:
for i in range(1):
    print(max_pair)
    ptc, prc,max_pair = merge_pretoken_counts(ptc, prc, max_pair)
    print()
    print(ptc)
    print(prc)
    print()

In [None]:
merge_one_tuple((b'l', b'o', b'w', b'e', b'r'), (b'o', b'w'))

In [None]:
for i in range(1):
    print(max_pair)
    ptc, prc,max_pair = merge_pretoken_counts(ptc, prc, max_pair)
    print()
    print(ptc)
    print(prc)
    print()

In [None]:
for i in range(1):
    print(max_pair)
    ptc, prc,max_pair = merge_pretoken_counts(ptc, prc, max_pair)
    print()
    print(ptc)
    print(prc)
    print()

In [None]:
for i in range(1):
    print(max_pair)
    ptc, prc,max_pair = merge_pretoken_counts(ptc, prc, max_pair)
    print()
    print(ptc)
    print(prc)
    print()

In [None]:
merge_one_tuple((b'l', b'o', b'w'), (b'o', b'w'))

In [None]:
max_pair

In [None]:
merge_one_tuple((b'l', b'o', b'w'), (b'o', b'w'))

In [None]:
??get_max_pair

In [None]:
max_pair

In [None]:
byte_tup

In [None]:
get_max_pair({(b'low',): 5, (b'low', b'e', b'r'): 2, (b'w', b'i', b'd', b'est'): 3, (b'n', b'e', b'west'): 6})

In [None]:
prc

In [None]:
??get_max_pair

In [None]:
new_prc

In [None]:
new_ptc

In [None]:
new_ptc

In [None]:
new_bt

In [None]:
ptc

In [None]:
# Print every adjacent pair of every pretoken currently in `ptc`.
for bt, bc in ptc.items():
    i = 0
    while i < len(bt) - 1:
        old_pair = (bt[i], bt[i+1])
        print(old_pair)
        i += 1  # advance to the next pair; without this the loop never terminates

In [None]:
('a', 'b') > chr(1)

In [None]:
new_ptc

In [None]:
new_prc

In [None]:
pair_counts

### 2.6

In [None]:
ord('Ġ')

In [None]:
def update_pretoken(pretoken, pair):
    """
    Return a new list in which each left-to-right, non-overlapping occurrence
    of `pair` in `pretoken` is replaced by the concatenation of the pair.

    Args:
        pretoken: sequence of byte strings.
        pair: pair of byte strings to merge.

    Returns:
        The merged sequence as a list; `pretoken` itself is not modified.
    """
    merged = []
    last = len(pretoken) - 1
    pos = 0
    while pos <= last:
        is_match = (
            pos < last
            and pretoken[pos] == pair[0]
            and pretoken[pos + 1] == pair[1]
        )
        if not is_match:
            merged.append(pretoken[pos])
            pos += 1
            continue
        # Both elements are consumed by the merge.
        merged.append(b"".join(pair))
        pos += 2
    return merged

In [None]:
import json
import os  # needed at definition time: the annotations below reference os.PathLike
import sys
sys.path.append("../tests")
from common import gpt2_bytes_to_unicode

def get_tokenizer_from_vocab_merges_path(
    vocab_path: str | os.PathLike,
    merges_path: str | os.PathLike,
    special_tokens: list[str] | None = None,
):
    """
    Load a GPT-2 style vocab/merges pair and convert it to raw bytes.

    Args:
        vocab_path: JSON file mapping GPT-2 token strings to integer ids.
        merges_path: text file with one merge per line, written as two
            space-separated GPT-2 token strings. Lines that do not split into
            exactly two fields are skipped.
        special_tokens: optional strings whose UTF-8 encoding is appended to
            the vocab (with the next free id) when not already present.

    Returns:
        (vocab, merges, special_tokens) where `vocab` maps id -> bytes and
        `merges` is a list of (bytes, bytes) pairs in file order.
    """
    # Inverse of the GPT-2 printable-character table: character -> byte value.
    gpt2_byte_decoder = {v: k for k, v in gpt2_bytes_to_unicode().items()}

    # Both files contain non-ASCII characters (e.g. 'Ġ'), so do not rely on
    # the platform's default encoding.
    with open(vocab_path, encoding="utf-8") as vocab_f:
        gpt2_vocab = json.load(vocab_f)
    gpt2_bpe_merges = []
    with open(merges_path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip().split(" ")
            if len(parts) == 2:
                gpt2_bpe_merges.append(tuple(parts))

    # The GPT-2 tokenizer uses a remapped unicode encoding for bytes. Let's
    # just return the original bytes, so we don't force students to use
    # any particular encoding scheme.
    vocab = {
        gpt2_vocab_index: bytes([gpt2_byte_decoder[token] for token in gpt2_vocab_item])
        for gpt2_vocab_item, gpt2_vocab_index in gpt2_vocab.items()
    }

    # If any of the special tokens don't exist in the vocab, append them to the vocab.
    if special_tokens:
        known_tokens = set(vocab.values())  # built once, kept in sync below
        for special_token in special_tokens:
            byte_encoded_special_token = special_token.encode("utf-8")
            if byte_encoded_special_token not in known_tokens:
                vocab[len(vocab)] = byte_encoded_special_token
                known_tokens.add(byte_encoded_special_token)

    merges = [
        (
            bytes([gpt2_byte_decoder[token] for token in merge_token_1]),
            bytes([gpt2_byte_decoder[token] for token in merge_token_2]),
        )
        for merge_token_1, merge_token_2 in gpt2_bpe_merges
    ]
    # return Tokenizer(vocab, merges, special_tokens)
    return vocab, merges, special_tokens

In [None]:
import regex as re

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

VOCAB_PATH = "/home/azureuser/02-fun/cs336-assignment1-basics/tests/fixtures/gpt2_vocab.json"
MERGES_PATH = "/home/azureuser/02-fun/cs336-assignment1-basics/tests/fixtures/gpt2_merges.txt"
VOCAB, MERGES, sptok = get_tokenizer_from_vocab_merges_path(VOCAB_PATH, MERGES_PATH)

def split_by_special_tokens(special_tokens, text):
    """
    Split `text` on the given special tokens, keeping the tokens themselves.

    Longer tokens are tried first so that a token which is a prefix of another
    cannot shadow it. Because the pattern is a capturing group, the special
    tokens appear as their own list items between the surrounding pieces
    (which may be empty strings).

    Args:
        special_tokens: list of special-token strings; None or empty means
            "do not split".
        text: the string to split.

    Returns:
        List of string pieces; `[text]` when there are no special tokens.
    """
    if special_tokens:
        longest_first = sorted(special_tokens, key=len, reverse=True)
        alternation = "|".join(re.escape(token) for token in longest_first)
        return re.split("(" + alternation + ")", text)
    return [text]

import tiktoken
reference_tokenizer = tiktoken.get_encoding("gpt2")


In [None]:
from typing import Iterable, Iterator
import regex as re
import json

class Tokenizer:
    def __init__(
        self,
        vocab: dict[int, bytes],
        merges: Iterable[tuple[bytes, bytes]],
        special_tokens: list[str] | None = None
    ):
        self.vocab = vocab if vocab else {}
        self.merges = merges if merges else []
        self.special_tokens = special_tokens

    @classmethod
    def from_files(cls, vocab_filepath:str, merges_filepath:str, special_tokens: list[str] | None=None):
        gpt2_byte_decoder = {v: k for k, v in gpt2_bytes_to_unicode().items()}
        with open(vocab_filepath) as vocab_f:
            gpt2_vocab = json.load(vocab_f)
        gpt2_bpe_merges = []
        with open(merges_filepath) as f:
            for line in f:
                cleaned_line = line.rstrip()
                if cleaned_line and len(cleaned_line.split(" ")) == 2:
                    gpt2_bpe_merges.append(tuple(cleaned_line.split(" ")))
        # The GPT-2 tokenizer uses a remapped unicode encoding for bytes. Let's
        # just return the original bytes, so we don't force students to use
        # any particular encoding scheme.
        vocab = {
            gpt2_vocab_index: bytes([gpt2_byte_decoder[token] for token in gpt2_vocab_item])
            for gpt2_vocab_item, gpt2_vocab_index in gpt2_vocab.items()
        }
        # If any of the special tokens don't exist in the vocab, append them to the vocab.
        if special_tokens:
            for special_token in special_tokens:
                byte_encoded_special_token = special_token.encode("utf-8")
                if byte_encoded_special_token not in set(vocab.values()):
                    vocab[len(vocab)] = byte_encoded_special_token

        merges = [
            (
                bytes([gpt2_byte_decoder[token] for token in merge_token_1]),
                bytes([gpt2_byte_decoder[token] for token in merge_token_2]),
            )
            for merge_token_1, merge_token_2 in gpt2_bpe_merges
        ]
        return cls(vocab, merges, special_tokens)

    def encode(self, text: str) -> list[int]:
        vocab_reversed = {v:k for k,v in self.vocab.items()}
        pretokens = re.findall(PAT, text)
        chunks = split_by_special_tokens(self.special_tokens, text)
        tokens = []
        for chunk in chunks:
            if chunk in self.special_tokens:
                tokens.append(chunk.encode())
                continue
            pretokens = re.findall(PAT, chunk)
            for pretoken in pretokens:
                pretoken  = [bytes([b]) for b in pretoken.encode()]
                while len(pretoken) >= 2:
                    pairs = list(zip(pretoken[:-1], pretoken[1:]))
                    try:
                        pid = min([self.merges.index(p) for p in pairs if p in self.merges])
                        pair = self.merges[pid]
                        pretoken = update_pretoken(pretoken, pair)
                    except ValueError:
                        break
                tokens.extend(pretoken)
        return [vocab_reversed[token] for token in tokens]

    def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
        for text in iterable:
            yield from self.encode(text)

    def decode(self, ids: list[int]):
        return b"".join([self.vocab[i] for i in ids]).decode("utf-8", errors="replace")

In [None]:
tokenizer = Tokenizer.from_files(VOCAB_PATH, MERGES_PATH, special_tokens=["<|endoftext|>"])

In [None]:
test_string = "Héllò hôw <|endoftext|><|endoftext|> are ü? 🙃<|endoftext|>"
ids = tokenizer.encode(test_string)

In [None]:
test_string = "Héllò hôw <|endoftext|><|endoftext|> are ü? 🙃<|endoftext|>"
# test_string = "Hello how <|endoftext|><|endoftext|> are u? 🙃<|endoftext|>"
# ids = tokenizer.encode(test_string)
special_tokens = ["<|endoftext|>"]
text = test_string
vocab_reversed = {v:k for k,v in VOCAB.items()}
chunks = split_by_special_tokens(tokenizer.special_tokens, text)
tokens = []
for chunk in chunks:
    if chunk in tokenizer.special_tokens:
        tokens.append(chunk.encode())
        continue
    pretokens = re.findall(PAT, chunk)
    for pretoken in pretokens:
        pretoken  = [bytes([b]) for b in pretoken.encode()]
        while len(pretoken) >= 2:
            pairs = list(zip(pretoken[:-1], pretoken[1:]))
            try:
                pid = min([tokenizer.merges.index(p) for p in pairs if p in tokenizer.merges])
                pair = tokenizer.merges[pid]
                pretoken = update_pretoken(pretoken, pair)
            except ValueError:
                break
        tokens.extend(pretoken)

In [None]:
reference_ids = reference_tokenizer.encode(test_string, allowed_special={"<|endoftext|>"})
print(reference_ids)

In [None]:
print(ids)

In [None]:
reference_tokenizer.decode(reference_ids)

In [None]:
ord(127).decode()

In [None]:
b"\x3c".decode()

In [None]:
ids

In [None]:
"ò".encode()

In [None]:
b'\xc3\xb3'.decode()

In [None]:
ids[:10]

In [None]:
VOCAB[127]

In [None]:
VOCAB[2634]

In [None]:
ord('\xa9')

In [None]:
b'\xc3\xb2'.decode()

In [None]:
for i in ids:
    print(VOCAB[i].decode("utf-8", errors="replace"))

In [None]:
VOCAB[127].decode("utf-8", errors="replace")

In [None]:
tokenizer.decode(ids)

In [None]:

pretoken = re.findall(PAT, test_string)[0]
print(pretoken)

pretoken = [bytes([b]) for b in test_string.encode()]
print(pretoken)
i = 0
while len(pretoken) >= 2:
    pairs = list(zip(pretoken[:-1], pretoken[1:]))
    try:
        pid = min([MERGES.index(p) for p in pairs if p in MERGES])
    except ValueError:
        break
    pair = MERGES[pid]
    pretoken = update_pretoken(pretoken, pair)
    print(pretoken)

# if idx_cur > len(MERGES):
#     print("")


> Good example of why we cannot use vocab to do encoding. This goes into post

In [None]:
def encode(self, text: str) -> list[int]:
    """
    Greedy vocab-lookup encoder (kept as a counter-example; it does not follow
    the BPE merge order).

    Each pre-token produced by `PAT` is walked byte by byte. The current token
    is extended with the next byte as long as the extended bytes are a vocab
    entry; otherwise the current token is emitted and a new one starts at that
    byte.

    Raises:
        KeyError: if an emitted token is not in `self.vocab`.
    """
    token_to_id = {token: idx for idx, token in self.vocab.items()}
    emitted = []
    for piece in re.findall(PAT, text):
        single_bytes = [bytes([value]) for value in piece.encode()]
        current = single_bytes[0]
        for next_byte in single_bytes[1:]:
            candidate = current + next_byte
            if candidate in token_to_id:
                current = candidate
                continue
            emitted.append(current)
            current = next_byte
        # Flush whatever is left of this pre-token.
        emitted.append(current)
    return [token_to_id[token] for token in emitted]

import tiktoken
reference_tokenizer = tiktoken.get_encoding("gpt2")
test_string = "Héllò hôw are ü? 🙃"

reference_ids = reference_tokenizer.encode(test_string)
ids = tokenizer.encode(test_string)
assert ids != reference_ids

print(VOCAB[220], VOCAB[8582])
print(VOCAB[12520])

In [None]:
# Open the sample file
with open("../tests/fixtures/tinystories_sample.txt") as f:
    # Use encode_iterable with the file handle. This returns a generator.
    token_generator = tokenizer.encode_iterable(f)
    
    # Let's iterate through the generator and print the first 20 token IDs
    print("First 20 token IDs:")
    for i, token_id in enumerate(token_generator):
        if i >= 20:
            break
        print(token_id, end=", ")

In [None]:
tokenized_string = reference_tokenizer.encode(test_string, allowed_special={"<|endoftext|>"})

In [None]:
tokenized_string.count("<|endoftext|>")

In [None]:
test_string

In [None]:
tokenizer.decode(tokenizer.encode(test_string))

In [None]:
min([MERGES.index(p) for p in pair_counts if p in MERGES])

In [None]:
min([])

In [None]:
min(pair_counts, key=lambda p: MERGES.index(p))

In [None]:
MERGES.index((b'\xc3', b'\xa9'))

In [None]:
MERGES

In [None]:
def encode(text: str) -> list[int]:
    """
    Debugging variant of the greedy vocab-lookup encoder that works on the
    module-level `VOCAB` and prints its progress.

    Prints the list of pre-tokens, then for each pre-token its single-byte
    list followed by every token as it is emitted.

    Raises:
        KeyError: if an emitted token is not in `VOCAB`.
    """
    token_to_id = {token: idx for idx, token in VOCAB.items()}
    pieces = re.findall(PAT, text)
    print(pieces)
    emitted = []
    for piece in pieces:
        single_bytes = [bytes([value]) for value in piece.encode()]
        print(single_bytes)
        current = single_bytes[0]
        for next_byte in single_bytes[1:]:
            candidate = current + next_byte
            if candidate in token_to_id:
                current = candidate
                continue
            emitted.append(current)
            print(current)
            current = next_byte
        # Flush the last token of this pre-token.
        emitted.append(current)
        print(current)
    return [token_to_id[token] for token in emitted]

In [None]:
(b'Hel', b'l') 

In [None]:
encode("Hello, how are you?")

In [None]:
vr = {v:k for k,v in VOCAB.items()}
vr[b'He']

> could miss merges like (b' ', b'\x0b9\x011')

In [None]:
VOCAB_PATH = "/home/azureuser/02-fun/cs336-assignment1-basics/tests/fixtures/gpt2_vocab.json"
MERGES_PATH = "/home/azureuser/02-fun/cs336-assignment1-basics/tests/fixtures/gpt2_merges.txt"

import tiktoken

def test_ascii_string_matches_tiktoken():
    """
    Encode a plain ASCII sentence with both tiktoken's GPT-2 encoding and the
    local `Tokenizer`, print the reference ids, and return the local tokens
    decoded one id at a time.

    The comparisons against the reference are currently disabled; see the
    commented assertions at the end.
    """
    sample = "Hello, how are you?"

    reference = tiktoken.get_encoding("gpt2")
    # Previously built via get_tokenizer_from_vocab_merges_path(...).
    ours = Tokenizer.from_files(VOCAB_PATH, MERGES_PATH, ["<|endoftext|>"])

    expected_ids = reference.encode(sample)
    print(expected_ids)
    our_ids = ours.encode(sample)

    return [ours.decode([token_id]) for token_id in our_ids]
    # Disabled checks:
    # assert ids == reference_ids
    # assert tokenized_string == ["Hello", ",", " how", " are", " you", "?"]
    # assert tokenizer.decode(ids) == test_string
    # assert reference_tokenizer.decode(reference_ids) == test_string

In [None]:
tokenizer.encode("Hello, how are you?")

In [None]:
test_ascii_string_matches_tiktoken()

In [None]:
tokenizer.vocab

In [None]:
type(token_generator)

In [None]:
tokenizer.decode(tokenizer.encode(" the bananas are green"))

In [None]:
import os
os.listdir("../data/")

In [None]:
special_token = "<|endoftext|>"
with open("../data/TinyStoriesV2-GPT4-valid.txt", "rb") as f:
    doc = f.read().split(special_token.encode())[0]

In [None]:
doc

In [None]:
jkj

In [None]:
import json
json.loads(json.dumps({1:vocab[1].decode()}))

In [None]:
# Base byte-level vocabulary: token id i stands for the single byte i.
vocab = {byte_value: bytes([byte_value]) for byte_value in range(256)}
# for (p0, p1), idx in merges.items():
#     vocab[idx] = vocab[p0] + vocab[p1]


def decode(ids):
    """
    Turn a list of token ids into a Python string.

    The bytes of all tokens are concatenated first and decoded as UTF-8 in one
    go; byte sequences that are not valid UTF-8 become U+FFFD instead of
    raising.
    """
    raw = b"".join([vocab[token_id] for token_id in ids])
    return raw.decode("utf-8", errors="replace")


# 0x80 alone is a stray continuation byte, so this prints the replacement character.
print(decode([128]))

In [None]:
ord(bytes([28]))

In [None]:
bytes([68])

In [None]:
ord('&')

In [None]:
ord(b'&')

In [None]:
ord('D')

### 2.5
- `train_bpe_tinystoires`
    - Current memory usage: 5.24 MB
    - Peak memory usage: 116.74 MB

In [None]:
b'a' in (b'a' + b'\x80')

In [None]:
b'abc'[:2]

In [None]:
list(b'bc')

In [None]:
b'ab'.replace(b'a', b'e')

In [None]:
from typing import Iterable
def _update_byte_tuple(byte_tuple: Iterable[bytes], merge_loc: int):
    """
    Merge the two adjacent elements starting at `merge_loc`.

    Args:
        byte_tuple: the byte strings of one pretoken. Any iterable is
            accepted; it is converted to a tuple first.
        merge_loc: index of the first element of the pair to merge; must
            satisfy `0 <= merge_loc < len(byte_tuple) - 1`.

    Returns:
        (new_byte_tuple, prefix, suffix), all tuples: the merged sequence, the
        elements before the merged pair, and the elements after it.

    Raises:
        AssertionError: if fewer than two elements are given.
        IndexError: if `merge_loc` does not point at a complete pair.
    """
    # Slicing and tuple concatenation below need a real tuple.
    byte_tuple = tuple(byte_tuple)
    assert len(byte_tuple) > 1, "Cannot merge a byte tuple with length less than 2."
    # An out-of-range location would otherwise silently merge fewer than two
    # elements (or insert an empty bytes object for negative values).
    if not 0 <= merge_loc < len(byte_tuple) - 1:
        raise IndexError(
            f"merge_loc {merge_loc} does not address a pair in a tuple of length {len(byte_tuple)}"
        )
    prefix = byte_tuple[:merge_loc]
    suffix = byte_tuple[merge_loc + 2:]
    merged = byte_tuple[merge_loc] + byte_tuple[merge_loc + 1]
    new_byte_tuple = prefix + (merged,) + suffix
    return new_byte_tuple, prefix, suffix

In [None]:
byte_tuple = tuple(bytes([c]) for c in 'xyz'.encode())
# tuple(bytes([b]) for b in pretoken)
_update_byte_tuple(byte_tuple, 1)

In [None]:
byte_tuple

In [None]:
tuple(b'xyz')