In [16]:

import torch
import regex as re
import heapq
import random
import os

from concurrent.futures import ProcessPoolExecutor
from typing import BinaryIO

In [3]:
train_text_path = "/root/workspace/cs336/assignment1/data/TinyStoriesV2-GPT4-train.txt"
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
merge_file_path = "/root/workspace/cs336/assignment1/mergeslist.txt"
merge_ops = 10000
token_EOF = "<|endoftext|>"
num_processes = 4
random.seed(42)

In [4]:


def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time

    for bi in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[bi]
        file.seek(initial_position)  # Start at boundary guess
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk

            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            # Find the special token in the mini chunk
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = initial_position + found_at
                break
            initial_position += mini_chunk_size

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

def pre_tokenize(filepath, bound_st, bound_ed, pattern):
    with open(filepath, "rb") as f:
        f.seek(bound_st)
        chunk = f.read(bound_ed - bound_st).decode("utf-8", errors="ignore")
        chunk_set = chunk.split(token_EOF)
        corpus_weights = {}
        for small_chunk in chunk_set:
            splited_text = re.findall(pattern, small_chunk)
            for words in splited_text:
                data_u8 = words.encode("utf-8")
                corpus_weights[data_u8] = corpus_weights.get(data_u8, 0) + 1
    return corpus_weights


        

In [5]:
with open(train_text_path, "rb") as f:
    boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")

In [6]:
parellel_params = [(train_text_path, start, end, PAT) for start, end in zip(boundaries[:-1], boundaries[1:])]

In [7]:
with ProcessPoolExecutor(max_workers=num_processes) as ex:
    results = list(ex.map(pre_tokenize, *zip(*parellel_params)))

In [81]:
word_weights = {}   #{words: (word_now,frequency)}
dict_of_pair = {}   #{(ch1, ch2): frequency}, true frequency
pair_to_words = {}  #{(ch1, ch2): set(words)}
tokens = {i:(i,) for i in range(256)}  #{token_id: [bytestring]}


In [82]:
#初始化 word_weights, dict_of_pair, pair_to_words
for dic in results:
    for k, v in dic.items():
        word_weights[k] = word_weights.get(k, (k,0))
        word_weights[k] = (k, word_weights[k][1] + v)
for k,v in word_weights.items():
    for i in range(len(k)-1):
        ch1 = k[i]
        ch2 = k[i+1]
        pair = (ch1, ch2)
        dict_of_pair[pair] = dict_of_pair.get(pair, 0) + v[1]
        pair_to_words.setdefault(pair, set()).add(k)

In [83]:
#初始化优先队列
pair_freq_heap = [(-freq, pair) for pair, freq in dict_of_pair.items()]
heapq.heapify(pair_freq_heap)


(-63482199, (32, 116))

In [84]:
valid_merge = 0
token_id = 256
while(valid_merge < merge_ops):
    neg_freq, pair = heapq.heappop(pair_freq_heap)
    freq = -neg_freq
    if dict_of_pair.get(pair, 0) != freq:
        continue
    
    idx_now = token_id
    token_id += 1
    tokens[idx_now] = pair
    
    for word_id in pair_to_words[pair]:
        new_word = []
        word = word_weights[word_id][0]
        i = 0
        while i < len(word):
            if i + 1 < len(word) and word[i] == pair[0] and word[i + 1] == pair[1]:
                new_word.append(idx_now)
                i += 2           
                dict_of_pair[pair] -= word_weights[word_id][1]
            else:
                new_word.append(word[i])
                i += 1
        
        for i in range(len(new_word)-1):
            if new_word[i] == idx_now:
                if i + 1 < len(new_word):
                    if new_word[i+1] == idx_now:
                        new_pair_post = (idx_now, idx_now)
                        old_pair_post = (pair[1],pair[0])
                        dict_of_pair[new_pair_post] = dict_of_pair.get(new_pair_post, 0) + word_weights[word_id][1]
                        dict_of_pair[old_pair_post] = dict_of_pair.get(old_pair_post, 0) - word_weights[word_id][1]
                        heapq.heappush(pair_freq_heap, (-dict_of_pair[new_pair_post], new_pair_post))
                        heapq.heappush(pair_freq_heap, (-dict_of_pair[old_pair_post], old_pair_post))   
                        pair_to_words.setdefault(new_pair_post, set()).add(word_id)
                    else:
                        new_pair_post = (idx_now, new_word[i+1])
                        old_pair_post = (pair[1], new_word[i+1])
                        dict_of_pair[new_pair_post] = dict_of_pair.get(new_pair_post, 0) + word_weights[word_id][1]
                        dict_of_pair[old_pair_post] = dict_of_pair.get(old_pair_post, 0) - word_weights[word_id][1]
                        heapq.heappush(pair_freq_heap, (-dict_of_pair[new_pair_post], new_pair_post))
                        heapq.heappush(pair_freq_heap, (-dict_of_pair[old_pair_post], old_pair_post))
                        pair_to_words.setdefault(new_pair_post, set()).add(word_id)
                if i > 0:
                    if new_word[i-1] == idx_now:
                        pass
                    else:
                        new_pair_pre = (new_word[i-1], idx_now)
                        old_pair_pre = (new_word[i-1], pair[0])
                        dict_of_pair[new_pair_pre] = dict_of_pair.get(new_pair_pre, 0) + word_weights[word_id][1]
                        dict_of_pair[old_pair_pre] = dict_of_pair.get(old_pair_pre, 0) - word_weights[word_id][1]
                        heapq.heappush(pair_freq_heap, (-dict_of_pair[new_pair_pre], new_pair_pre))
                        heapq.heappush(pair_freq_heap, (-dict_of_pair[old_pair_pre], old_pair_pre))
                        pair_to_words.setdefault(new_pair_pre, set()).add(word_id)

        word_weights[word_id] = (new_word, word_weights[word_id][1])
        
    valid_merge += 1
    print(f"Token {idx_now} , pair {pair}", end="\r")

Token 10255 , pair (2520, 482)))

In [79]:
tokens

{0: (0,),
 1: (1,),
 2: (2,),
 3: (3,),
 4: (4,),
 5: (5,),
 6: (6,),
 7: (7,),
 8: (8,),
 9: (9,),
 10: (10,),
 11: (11,),
 12: (12,),
 13: (13,),
 14: (14,),
 15: (15,),
 16: (16,),
 17: (17,),
 18: (18,),
 19: (19,),
 20: (20,),
 21: (21,),
 22: (22,),
 23: (23,),
 24: (24,),
 25: (25,),
 26: (26,),
 27: (27,),
 28: (28,),
 29: (29,),
 30: (30,),
 31: (31,),
 32: (32,),
 33: (33,),
 34: (34,),
 35: (35,),
 36: (36,),
 37: (37,),
 38: (38,),
 39: (39,),
 40: (40,),
 41: (41,),
 42: (42,),
 43: (43,),
 44: (44,),
 45: (45,),
 46: (46,),
 47: (47,),
 48: (48,),
 49: (49,),
 50: (50,),
 51: (51,),
 52: (52,),
 53: (53,),
 54: (54,),
 55: (55,),
 56: (56,),
 57: (57,),
 58: (58,),
 59: (59,),
 60: (60,),
 61: (61,),
 62: (62,),
 63: (63,),
 64: (64,),
 65: (65,),
 66: (66,),
 67: (67,),
 68: (68,),
 69: (69,),
 70: (70,),
 71: (71,),
 72: (72,),
 73: (73,),
 74: (74,),
 75: (75,),
 76: (76,),
 77: (77,),
 78: (78,),
 79: (79,),
 80: (80,),
 81: (81,),
 82: (82,),
 83: (83,),
 84: (84,),
