In [9]:

import regex as re
special_tokens = ['<|endoftext|>']
vocab_size = 500
# with open("/home/code/cs336/assignment1/tests/fixtures/tinystories_sample_5M.txt", 'r', encoding='utf-8') as f:
#     data = f.read()
#     # 正则表达式进行分词

In [10]:
import os
from typing import BinaryIO


def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time

    for bi in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[bi]
        file.seek(initial_position)  # Start at boundary guess
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk

            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            # Find the special token in the mini chunk
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = initial_position + found_at
                break
            initial_position += mini_chunk_size

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

In [22]:
chunks = []
with open("/home/code/cs336/assignment1/tests/fixtures/tinystories_sample_5M.txt", "rb") as f:
    num_processes = 4
    boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")
    for i, (start, end) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        f.seek(start)
        chunk = f.read(end - start).decode("utf-8", errors="ignore")
        
        # 如果不是第一个 chunk 且它以 <|endoftext|> 开头，去掉开头的 <|endoftext|>
        if i > 0:
            chunk = chunk[len("<|endoftext|>\r\n"):]

        chunks.append(chunk)

# 查看前两个块
chunks[:2]




In [16]:
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
tokens = []
token_number = {}

In [17]:
for part in parts[:2]:
        # 将每个字符都转换为 UTF-8 字节表示，处理所有字符（包括引号）
    part_tokens = re.findall(PAT, part)
    for token in part_tokens:
        #token_bytes = list(token.encode('utf-8'))  # 将 token 转换为 UTF-8 字节
        tokens.append(token)
tokens

['u',
 ' don',
 "'t",
 ' have',
 ' to',
 ' be',
 ' scared',
 ' of',
 ' the',
 ' loud',
 ' dog',
 ',',
 ' I',
 "'ll",
 ' protect',
 ' you',
 '".',
 ' The',
 ' mole',
 ' felt',
 ' so',
 ' safe',
 ' with',
 ' the',
 ' little',
 ' girl',
 '.',
 ' She',
 ' was',
 ' very',
 ' kind',
 ' and',
 ' the',
 ' mole',
 ' soon',
 ' came',
 ' to',
 ' trust',
 ' her',
 '.',
 ' He',
 ' leaned',
 ' against',
 ' her',
 ' and',
 ' she',
 ' kept',
 ' him',
 ' safe',
 '.',
 ' The',
 ' mole',
 ' had',
 ' found',
 ' his',
 ' best',
 ' friend',
 '.',
 '\r',
 '\n',
 '<|',
 'endoftext',
 '|>',
 '\r',
 '\n',
 'Once',
 ' upon',
 ' a',
 ' time',
 ',',
 ' in',
 ' a',
 ' warm',
 ' and',
 ' sunny',
 ' place',
 ',',
 ' there',
 ' was',
 ' a',
 ' big',
 ' pit',
 '.',
 ' A',
 ' little',
 ' boy',
 ' named',
 ' Tom',
 ' liked',
 ' to',
 ' play',
 ' near',
 ' the',
 ' pit',
 '.',
 ' One',
 ' day',
 ',',
 ' Tom',
 ' lost',
 ' his',
 ' red',
 ' ball',
 '.',
 ' He',
 ' was',
 ' very',
 ' sad',
 '.',
 '\r',
 '\n',
 'Tom',
 ' ask

In [None]:

    # 计算每个 token 的频率
for token in tokens:
    token_tuple = tuple(token)  # 转换为字节元组
    if token_tuple in token_number:
        token_number[token_tuple] += 1
    else:
        token_number[token_tuple] = 1
token_number

{(117,): 1,
 (32, 100, 111, 110): 604,
 (39, 116): 2657,
 (32, 104, 97, 118, 101): 2503,
 (32, 116, 111): 35102,
 (32, 98, 101): 2965,
 (32, 115, 99, 97, 114, 101, 100): 1432,
 (32, 111, 102): 5851,
 (32, 116, 104, 101): 48886,
 (32, 108, 111, 117, 100): 462,
 (32, 100, 111, 103): 3146,
 (44,): 55123,
 (32, 73): 4349,
 (39, 108, 108): 105,
 (32, 112, 114, 111, 116, 101, 99, 116): 51,
 (32, 121, 111, 117): 6596,
 (34, 46): 194,
 (32, 84, 104, 101): 10882,
 (32, 109, 111, 108, 101): 52,
 (32, 102, 101, 108, 116): 2059,
 (32, 115, 111): 4243,
 (32, 115, 97, 102, 101): 624,
 (32, 119, 105, 116, 104): 9935,
 (32, 108, 105, 116, 116, 108, 101): 5588,
 (32, 103, 105, 114, 108): 3001,
 (46,): 98136,
 (32, 83, 104, 101): 9031,
 (32, 119, 97, 115): 25403,
 (32, 118, 101, 114, 121): 5877,
 (32, 107, 105, 110, 100): 744,
 (32, 97, 110, 100): 45801,
 (32, 115, 111, 111, 110): 426,
 (32, 99, 97, 109, 101): 1895,
 (32, 116, 114, 117, 115, 116): 38,
 (32, 104, 101, 114): 9518,
 (32, 72, 101): 11496,
 

In [None]:
# 初始化词汇表
vocab = {}
current_id = 0

    # 添加所有单字节（0-255）的字节
for i in range(256):
    vocab[current_id] = bytes([i])
    current_id += 1
    # 添加特殊标记到词汇表
for token_str in special_tokens:
    vocab[current_id] = token_str.encode("utf-8")
    current_id += 1
vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [None]:

    # 频率统计
freq_dict = {}
merges = []

    # 将每个 token 转换为元组形式并计算频率
new_token_number = {}
for token, count in token_number.items():
    token_bytes_tuple = tuple(token)  # 将字符串转为字节元组
    new_token_number[token_bytes_tuple] = count
new_token_number

{(117,): 1,
 (32, 100, 111, 110): 604,
 (39, 116): 2657,
 (32, 104, 97, 118, 101): 2503,
 (32, 116, 111): 35102,
 (32, 98, 101): 2965,
 (32, 115, 99, 97, 114, 101, 100): 1432,
 (32, 111, 102): 5851,
 (32, 116, 104, 101): 48886,
 (32, 108, 111, 117, 100): 462,
 (32, 100, 111, 103): 3146,
 (44,): 55123,
 (32, 73): 4349,
 (39, 108, 108): 105,
 (32, 112, 114, 111, 116, 101, 99, 116): 51,
 (32, 121, 111, 117): 6596,
 (34, 46): 194,
 (32, 84, 104, 101): 10882,
 (32, 109, 111, 108, 101): 52,
 (32, 102, 101, 108, 116): 2059,
 (32, 115, 111): 4243,
 (32, 115, 97, 102, 101): 624,
 (32, 119, 105, 116, 104): 9935,
 (32, 108, 105, 116, 116, 108, 101): 5588,
 (32, 103, 105, 114, 108): 3001,
 (46,): 98136,
 (32, 83, 104, 101): 9031,
 (32, 119, 97, 115): 25403,
 (32, 118, 101, 114, 121): 5877,
 (32, 107, 105, 110, 100): 744,
 (32, 97, 110, 100): 45801,
 (32, 115, 111, 111, 110): 426,
 (32, 99, 97, 109, 101): 1895,
 (32, 116, 114, 117, 115, 116): 38,
 (32, 104, 101, 114): 9518,
 (32, 72, 101): 11496,
 

In [None]:
from collections import Counter
freq_dict = Counter()  # 重新统计频率
for token in new_token_number:
    number = len(token)
    for i in range(number - 1):
        pair = (token[i], token[i + 1])  # 生成字节对
        freq_dict[pair] += new_token_number[token]  # 累加频率

    # 找到频率最高的字符对
max_pair = max(freq_dict.items(), key=lambda x: (x[1], x[0][0], x[0][1]))
max_pair

((32, 116), 148857)

In [None]:
pair = (vocab[max_pair[0][0]], vocab[max_pair[0][1]])
pair

(b' ', b't')

In [None]:

        # 将最高频的字符对添加到 merges 列表
merges.append(pair)

        # 将合并后的字符对添加到 vocab
merged_str = chr(max_pair[0][0]) + chr(max_pair[0][1])  # 合并后的字符
vocab[current_id] = bytes([ord(c) for c in merged_str])
current_id += 1
vocab[current_id - 1]

b' t'

In [None]:
        # 更新 token，将频率最高的字符对替换为新合并的 token
new_tokens = {}
for token in new_token_number:
    updated_token = []
    i = 0
    while i < len(token) - 1:
        pair = (token[i], token[i + 1])
        if pair == max_pair[0]:
            updated_token.append(current_id - 1)  # 替换为合并后的 token
            i += 2  # 跳过已经合并的字符对
        else:
            updated_token.append(token[i])
            i += 1

            # 处理最后一个字符
    if i < len(token):
        updated_token.append(token[i])

    new_tokens[tuple(updated_token)] = new_token_number[token]

        # 更新 token_number 为新的 token
new_token_number = new_tokens
new_token_number

{(117,): 1,
 (32, 100, 111, 110): 604,
 (39, 116): 2657,
 (32, 104, 97, 118, 101): 2503,
 (258, 111): 35102,
 (32, 98, 101): 2965,
 (32, 115, 99, 97, 114, 101, 100): 1432,
 (32, 111, 102): 5851,
 (258, 257): 48886,
 (32, 108, 111, 117, 100): 462,
 (32, 100, 111, 103): 3146,
 (44,): 55123,
 (32, 73): 4349,
 (39, 108, 108): 105,
 (32, 112, 114, 111, 116, 101, 99, 116): 51,
 (32, 121, 111, 117): 6596,
 (34, 46): 194,
 (32, 84, 257): 10882,
 (32, 109, 111, 108, 101): 52,
 (32, 102, 101, 108, 116): 2059,
 (32, 115, 111): 4243,
 (32, 115, 97, 102, 101): 624,
 (32, 119, 105, 116, 104): 9935,
 (32, 108, 105, 116, 116, 108, 101): 5588,
 (32, 103, 105, 114, 108): 3001,
 (46,): 98136,
 (32, 83, 257): 9031,
 (32, 119, 97, 115): 25403,
 (32, 118, 101, 114, 121): 5877,
 (32, 107, 105, 110, 100): 744,
 (32, 97, 110, 100): 45801,
 (32, 115, 111, 111, 110): 426,
 (32, 99, 97, 109, 101): 1895,
 (258, 114, 117, 115, 116): 38,
 (32, 257, 114): 9518,
 (32, 72, 101): 11496,
 (32, 108, 101, 97, 110, 101, 100

In [None]:
merges

[(b'h', b'e')]

In [None]:
vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [None]:
vocab[0]+vocab[1]

b'\x00\x01'