In [1]:
import os
import pickle
import regex as re
import multiprocessing
from copy import deepcopy
from typing import BinaryIO, List, Tuple, Iterable
from collections import defaultdict

# Byte-Pair Encoding (BPE) Tokenizer

## The Unicode Standard

### Problem

In [None]:
chr(0)

In [None]:
print(chr(0))

In [None]:
"this is a test" + chr(0) + "string"

In [None]:
print("this is a test" + chr(0) + "string")

## Unicode Encodings

In [None]:
test_string = "hello! こんにちは!"

In [None]:
utf8_encoded = test_string.encode("utf-8")

In [None]:
print(utf8_encoded)

In [None]:
# Get the byte values for the encoded string (integers from 0 to 255).
list(utf8_encoded)

In [None]:
# One byte does not necessarily correspond to one Unicode character!
print(len(test_string))

In [None]:
print(len(utf8_encoded))

In [None]:
print(utf8_encoded.decode("utf-8"))

### Problem

#### a

In [None]:
list(test_string.encode("utf-8"))

In [None]:
list(test_string.encode("utf-8")).__len__()

In [None]:
list(test_string.encode("utf-16"))

In [None]:
list(test_string.encode("utf-16")).__len__()

In [None]:
list(test_string.encode("utf-32"))

In [None]:
list(test_string.encode("utf-32")).__len__()

#### b

In [None]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes): 
    return "".join([bytes([b]).decode("utf-8") for b in bytestring])

In [None]:
"hello 你好".encode("utf-8")

In [None]:
for b in "hello 你好".encode("utf-8"):
    print(b)

In [None]:
decode_utf8_bytes_to_str_wrong("hello 你好".encode("utf-8"))

In [None]:
bytes([104]).decode("utf-8")

In [None]:
bytes([228, 189, 160]).decode("utf-8")

In [None]:
bytes([228]).decode("utf-8")

#### c

In [None]:
bytes([228, 189]).decode("utf-8")

## Subword Tokenization

In [None]:
list("the".encode("utf-8"))

## BPE Tokenizer Training

In [None]:
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [None]:
re.findall(PAT, "some text that i'll pre-tokenize")

In [None]:
for i in re.finditer(PAT, "some text that i'll pre-tokenize"):
    print(i)

### Problem (train_bpe)

In [None]:
def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
    mini_chunk_size: int = 4096,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.

    Each interior boundary starts at a uniform guess and is moved forward to the start of
    the next occurrence of `split_special_token` (or to EOF if there is none).
    May return fewer chunks if the boundaries end up overlapping.

    Args:
        file: binary file object opened for reading; it is seeked freely.
        desired_num_chunks: number of roughly equal-sized chunks to aim for (must be >= 1).
        split_special_token: delimiter the interior boundaries are aligned to.
        mini_chunk_size: number of bytes read per forward probe; must be at least
            len(split_special_token).

    Returns:
        Sorted, de-duplicated byte offsets; the first is 0 and the last is the file size.

    Raises:
        ValueError: if mini_chunk_size is shorter than split_special_token.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"
    if mini_chunk_size < len(split_special_token):
        raise ValueError("mini_chunk_size must be at least as long as split_special_token")

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    # Consecutive probes overlap by len(token) - 1 bytes so that a token straddling
    # two probes is still found instead of being skipped.
    overlap = len(split_special_token) - 1

    for bi in range(1, len(chunk_boundaries) - 1):
        position = chunk_boundaries[bi]  # Start at boundary guess
        while True:
            file.seek(position)
            mini_chunk = file.read(mini_chunk_size)

            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = position + found_at
                break

            # max(..., 1) guarantees progress on the short reads near EOF.
            position += max(len(mini_chunk) - overlap, 1)

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

In [None]:
with open("/data1/ds/scratch/assignment1-basics/data/TinyStoriesV2-GPT4-valid.txt", "rb") as f:
    num_processes = 4
    boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")

    # The following is a serial implementation, but you can parallelize this
    # by sending each start/end pair to a set of processes.
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        f.seek(start)
        chunk = f.read(end - start).decode("utf-8", errors="ignore")
        # Run pre-tokenization on your chunk and store the counts for each pre-token

In [None]:
print(chunk)

In [None]:
eos_pat = re.escape("<|endoftext|>")

In [None]:
splits = re.split("|".join([eos_pat]), chunk)
splits

In [None]:
test_doc = splits[1].strip()

In [None]:
str_tokens = re.findall(PAT, test_doc)
byte_tokens = [s.encode("utf-8") for s in str_tokens]

In [None]:
byte_tokens

In [None]:
special_tokens_bytes = [
    token.encode("utf-8") for token in ["<|endoftext|>", "<|startoftext|>"]
]

In [None]:
special_tokens_bytes

In [None]:
byte_tokens

In [None]:
def train_bpe_naive(
    input_path: str,
    vocab_size: int,
    special_tokens: List[str]
):
    """Reference (slow) byte-level BPE trainer that recounts every pair after each merge.

    Args:
        input_path: path to a UTF-8 text corpus.
        vocab_size: target vocabulary size, including the 256 byte tokens and the special tokens.
        special_tokens: strings added to the vocabulary verbatim (ids right after the byte
            tokens). The corpus is split on them, so no merge ever spans a special token.

    Returns:
        (idx2bytes, merges): token id -> token bytes, and the merges in creation order.
        Ties between equally frequent pairs go to the lexicographically greatest pair.

    Raises:
        ValueError: if vocab_size is smaller than 256 + len(special_tokens).
    """
    GPT2_REGEX_PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    idx2bytes = {i: bytes([i]) for i in range(256)}
    idx = 256

    # ------ 1. register special tokens
    special_token_patterns = []
    for special_token in special_tokens:
        idx2bytes[idx] = special_token.encode("utf-8")
        idx += 1
        special_token_patterns.append(re.escape(special_token))

    if len(idx2bytes) > vocab_size:
        raise ValueError(f"desired vocabulary size of {vocab_size} is smaller than initial vocabulary size of {len(idx2bytes)}")

    # ------ 2. read text
    # newline="" disables "\r\n" -> "\n" translation so the counted bytes are the file's bytes
    # (the same bytes the binary-mode `pretokenize` sees).
    with open(input_path, "r", encoding="utf-8", newline="") as f:
        text = f.read()

    # ------ 3. count pre-tokens, each as a tuple of single-byte tokens
    # An empty alternation would match the empty string at every position, which can cut the
    # text between characters, so only split when there is at least one special token.
    if special_token_patterns:
        splits = re.split("|".join(special_token_patterns), text)
    else:
        splits = [text]

    byte_tokens_cnt = defaultdict(int)  # {(b'l', b'o', b'w'): 3, (b'l', b'o', b'b'): 4, ...}
    for split in splits:
        for s in re.findall(GPT2_REGEX_PAT, split):
            byte_tokens_cnt[tuple(bytes([b]) for b in s.encode("utf-8"))] += 1

    # ------ 4. perform merges
    merges = []
    while len(idx2bytes) < vocab_size:
        # ------ 4.1 count adjacent pairs, weighted by pre-token frequency
        pair_cnt = defaultdict(int)  # {(b'l', b'o'): 7, (b'o', b'w'): 3, ...}
        for word, cnt in byte_tokens_cnt.items():
            for pair in zip(word, word[1:]):
                pair_cnt[pair] += cnt

        if not pair_cnt:
            break

        # ------ 4.2 most frequent pair; ties -> lexicographically greatest
        max_cnt = max(pair_cnt.values())
        l, r = max(pair for pair, cnt in pair_cnt.items() if cnt == max_cnt)
        merged_token = l + r
        merges.append((l, r))

        # ------ 4.3 rewrite every pre-token, merging (l, r) greedily left to right
        new_byte_tokens_cnt = defaultdict(int)
        for word, cnt in byte_tokens_cnt.items():
            merged = []
            i = 0
            while i < len(word):
                if i + 1 < len(word) and word[i] == l and word[i + 1] == r:
                    merged.append(merged_token)
                    i += 2
                else:
                    merged.append(word[i])
                    i += 1
            new_byte_tokens_cnt[tuple(merged)] += cnt
        byte_tokens_cnt = new_byte_tokens_cnt

        # ------ 4.4 register the new token
        idx2bytes[idx] = merged_token
        idx += 1

    return idx2bytes, merges

In [None]:
vocab1, merges1 = train_bpe_naive(
    "/data1/ds/scratch/assignment1-basics/data/simple_test.txt",
    10000,
    ["<|endoftext|>"]
)

In [None]:
def pretokenize(input_path, start, end, special_token_patterns):
    """Count GPT-2 pre-tokens in the byte range [start, end) of `input_path`.

    Designed to run in a worker process (it reopens the file itself).

    Args:
        input_path: path to the corpus file.
        start, end: byte offsets of the chunk; presumably produced by `find_chunk_boundaries`
            so they do not cut a UTF-8 character in half — confirm for other callers.
        special_token_patterns: already regex-escaped special tokens; the chunk is split on
            them so no pre-token spans a special token.

    Returns:
        defaultdict(int) mapping a pre-token (tuple of single-byte `bytes`) to its count.
    """
    GPT2_REGEX_PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    with open(input_path, "rb") as f:
        f.seek(start)
        # errors="ignore" silently drops invalid UTF-8 bytes.
        chunk = f.read(end - start).decode("utf-8", errors="ignore")

    # An empty alternation would match the empty string at every position, which can cut the
    # chunk between characters, so only split when there is at least one special token.
    if special_token_patterns:
        splits = re.split("|".join(special_token_patterns), chunk)
    else:
        splits = [chunk]

    byte_tokens_cnt = defaultdict(int)  # {(b'l', b'o', b'w'): 3, (b'l', b'o', b'b'): 4, ...}
    for split in splits:
        for s in re.findall(GPT2_REGEX_PAT, split):
            byte_tokens_cnt[tuple(bytes([b]) for b in s.encode("utf-8"))] += 1

    return byte_tokens_cnt

In [None]:
def train_bpe_improved(
    input_path: str,
    vocab_size: int,
    special_tokens: List[str],
    num_processes: int = 4,
):
    """Byte-level BPE trainer with parallel pre-tokenization and incremental pair counts.

    It does not recount every pair after each merge. Instead it keeps `pair_cnt` plus an
    inverted index `pair2keys` (pair -> pre-tokens containing it), and rewrites only the
    pre-tokens that contain the merged pair.

    Args:
        input_path: path to a UTF-8 text corpus.
        vocab_size: target vocabulary size, including the 256 byte tokens and the special tokens.
        special_tokens: strings added to the vocabulary verbatim. The corpus is split on them, so
            no merge crosses one. The first one (if any) is also used to align chunk boundaries.
        num_processes: number of chunks to aim for when splitting the file for pre-tokenization.

    Returns:
        (idx2bytes, merges): token id -> token bytes, and the merges in creation order.
        Ties between equally frequent pairs go to the lexicographically greatest pair.

    Raises:
        ValueError: if vocab_size is smaller than 256 + len(special_tokens).
    """
    idx2bytes = {i: bytes([i]) for i in range(256)}
    idx = 256

    # ------ 1. register special tokens
    special_token_patterns = []
    for special_token in special_tokens:
        idx2bytes[idx] = special_token.encode("utf-8")
        idx += 1
        special_token_patterns.append(re.escape(special_token))

    if len(idx2bytes) > vocab_size:
        raise ValueError(f"desired vocabulary size of {vocab_size} is smaller than initial vocabulary size of {len(idx2bytes)}")

    # ------ 2. split the file into independently countable chunks and pre-tokenize in parallel
    with open(input_path, "rb") as f:
        if special_tokens:
            # Cut right before a special token that `pretokenize` removes anyway.
            boundaries = find_chunk_boundaries(f, num_processes, special_tokens[0].encode("utf-8"))
        else:
            # Without a delimiter there is no safe cut point, so use a single chunk.
            f.seek(0, os.SEEK_END)
            boundaries = [0, f.tell()]

    args = [
        (input_path, start, end, special_token_patterns)
        for start, end in zip(boundaries[:-1], boundaries[1:])
    ]
    with multiprocessing.Pool() as pool:
        results = pool.starmap(pretokenize, args)

    # ------ 3. merge the per-chunk pre-token counts
    byte_tokens_cnt = defaultdict(int)
    for partial_counts in results:
        for word, cnt in partial_counts.items():
            byte_tokens_cnt[word] += cnt

    # ------ 4. initial pair counts and inverted index
    pair_cnt = defaultdict(int)   # {(b'l', b'o'): 7, (b'o', b'w'): 3, ...}
    pair2keys = defaultdict(set)  # {(b'o', b'w'): {(b'l', b'o', b'w', b'e', b'r'), ...}, ...}
    for word, cnt in byte_tokens_cnt.items():
        for pair in zip(word, word[1:]):
            pair_cnt[pair] += cnt
            pair2keys[pair].add(word)

    # ------ 5. perform merges, updating only the affected pre-tokens
    merges = []
    while len(idx2bytes) < vocab_size:
        # Empty corpus / only single-byte pre-tokens: nothing to merge (max() would raise).
        if not pair_cnt:
            break
        max_cnt = max(pair_cnt.values())
        if max_cnt <= 0:
            break

        # ------ 5.1 most frequent pair; ties -> lexicographically greatest
        l, r = max(pair for pair, cnt in pair_cnt.items() if cnt == max_cnt)
        merged_token = l + r
        merges.append((l, r))

        # ------ 5.2 rewrite each pre-token containing (l, r)
        # Snapshot the keys: pair2keys[(l, r)] is modified while the words are rewritten.
        for prev_key in list(pair2keys[(l, r)]):
            key_cnt = byte_tokens_cnt.pop(prev_key)

            new_key = []
            i = 0
            while i < len(prev_key):
                if i + 1 < len(prev_key) and prev_key[i] == l and prev_key[i + 1] == r:
                    new_key.append(merged_token)
                    i += 2
                else:
                    new_key.append(prev_key[i])
                    i += 1
            new_key = tuple(new_key)
            byte_tokens_cnt[new_key] += key_cnt

            # ------ 5.3 swap the old word's pair contributions for the new word's
            for pair in zip(prev_key, prev_key[1:]):
                pair_cnt[pair] -= key_cnt
                pair2keys[pair].discard(prev_key)

            for pair in zip(new_key, new_key[1:]):
                pair_cnt[pair] += key_cnt
                pair2keys[pair].add(new_key)

        # The merged pair no longer occurs anywhere; drop it so later max() scans skip it.
        pair_cnt.pop((l, r), None)
        pair2keys.pop((l, r), None)

        # ------ 5.4 register the new token
        idx2bytes[idx] = merged_token
        idx += 1

    return idx2bytes, merges

In [None]:
vocab2, merges2 = train_bpe_improved(
    "/data1/ds/scratch/assignment1-basics/data/simple_test.txt",
    10000,
    ["<|endoftext|>"]
)

In [None]:
def run_train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
    **kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Given the path to an input corpus, train a byte-level BPE tokenizer and
    return its vocabulary and merges.

    Args:
        input_path (str | os.PathLike): Path to BPE tokenizer training data.
        vocab_size (int): Total number of items in the tokenizer's vocabulary (including special tokens).
        special_tokens (list[str]): A list of string special tokens to be added to the tokenizer vocabulary.
            Each one is added once (ids right after the 256 byte tokens) unless its bytes are
            already in the vocabulary. The corpus is split on every occurrence of them before
            pre-tokenization, so they are removed from the training text and no merge crosses one.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
            vocab:
                The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
                to bytes (token bytes)
            merges:
                BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
                representing that <token1> was merged with <token2>.
                Merges are ordered by order of creation; ties between equally frequent
                pairs go to the lexicographically greatest pair.
    """
    # GPT-2 pre-tokenization regex, kept local so the function does not depend on notebook state.
    gpt2_pat = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

    # Step 1: Initialize Vocabulary
    vocab: dict[int, bytes] = {i: bytes([i]) for i in range(256)}
    next_id = 256
    for token in special_tokens:
        token_bytes = token.encode("utf-8")
        if token_bytes not in vocab.values():
            vocab[next_id] = token_bytes
            next_id += 1

    # Step 2: Pre-tokenization
    # newline="" keeps "\r\n" intact so the counted bytes are the file's bytes.
    with open(input_path, "r", encoding="utf-8", newline="") as f:
        text = f.read()

    # An empty alternation would match the empty string at every position, so only split
    # when there is at least one special token.
    if special_tokens:
        chunks = re.split("|".join(map(re.escape, special_tokens)), text)
    else:
        chunks = [text]

    # key e.g. (b'H', b'e', b'l', b'l', b'o') -> number of occurrences
    pre_tokens_cnt: dict[tuple[bytes, ...], int] = defaultdict(int)
    for chunk in chunks:
        for m in re.finditer(gpt2_pat, chunk):
            pre_tokens_cnt[tuple(bytes([b]) for b in m.group(0).encode("utf-8"))] += 1

    # Step 3: Compute BPE Merges
    merges = []
    while len(vocab) < vocab_size:
        # Count all adjacent pairs, weighted by pre-token frequency
        pair_counts = defaultdict(int)
        for token, cnt in pre_tokens_cnt.items():
            for pair in zip(token, token[1:]):
                pair_counts[pair] += cnt

        if not pair_counts:
            break  # No more pairs to merge

        max_count = max(pair_counts.values())
        a, b = max(pair for pair, cnt in pair_counts.items() if cnt == max_count)

        new_token = a + b
        vocab[next_id] = new_token
        next_id += 1

        # Collect rewrites first: the dict must not change size while it is being iterated.
        changes = []
        for token, cnt in pre_tokens_cnt.items():
            merged = []
            changed = False
            i = 0
            while i < len(token):
                if i + 1 < len(token) and token[i] == a and token[i + 1] == b:
                    merged.append(new_token)
                    changed = True
                    i += 2
                else:
                    merged.append(token[i])
                    i += 1
            if changed:
                changes.append((token, tuple(merged), cnt))

        # Apply the rewrites
        for old_token, new_pre_token, cnt in changes:
            pre_tokens_cnt[new_pre_token] += cnt
            del pre_tokens_cnt[old_token]

        merges.append((a, b))

    return vocab, merges

In [None]:
vocab3, merges3 = run_train_bpe(
    "/data1/ds/scratch/assignment1-basics/data/simple_test.txt",
    10000,
    ["<|endoftext|>"]
)

In [None]:
merges1 == merges2

In [None]:
merges2 == merges3

In [None]:
vocab1 == vocab2

In [None]:
vocab3 == vocab2

### Problem (train_bpe_tinystories)

In [None]:
vocab_ts, merges_ts = train_bpe_improved(
    "/data1/ds/scratch/assignment1-basics/data/TinyStoriesV2-GPT4-train.txt",
    10000,
    ["<|endoftext|>"]
)

In [None]:
def save_with_pickle(vocab_data, merges_data, filename):
    """Pickle a trained tokenizer to `filename` as {"vocab": ..., "merges": ...}.

    The file is opened in binary write mode (any existing file is overwritten), and a
    confirmation line is printed afterwards. Only unpickle such files from trusted sources.
    """
    payload = dict(vocab=vocab_data, merges=merges_data)

    with open(filename, "wb") as out_file:
        pickle.dump(payload, out_file)

    print(f"Data saved to {filename} using pickle.")

In [None]:
save_with_pickle(vocab_ts, merges_ts, "./TinyStoriesV2-results.pkl")

### Problem (train_bpe_expts_owt)

In [None]:
vocab_owt, merges_owt = train_bpe_improved(
    "/data1/ds/scratch/assignment1-basics/data/owt_train.txt",
    32000,
    ["<|endoftext|>"]
)

In [None]:
save_with_pickle(vocab_owt, merges_owt, "./owt-results.pkl")

## BPE Tokenizer: Encoding and Decoding

### Encoding Text

### Decoding Text

In [None]:
vocab3[256]

In [None]:
merges3

#### Problem (tokenizer)

In [105]:
class Tokenizer:
    """Byte-level BPE tokenizer built from a vocabulary and an ordered list of merges.

    Attributes:
        vocab: token id -> token bytes. This is a copy of the mapping passed in, plus any
            special tokens that were missing from it.
        merges: merge rules (left bytes, right bytes); earlier rules have higher priority.
        special_tokens: set of special-token strings that are never split.
        token2idx: token bytes -> token id, the inverse of `vocab`. If several ids share the
            same bytes, the one iterated last wins.
        merge_ranks: (left, right) -> index of its first occurrence in `merges`.
    """

    def __init__(self, vocab, merges, special_tokens=None):
        # Copy so that registering special tokens never mutates the caller's dict.
        self.vocab = dict(vocab)
        self.merges = list(merges)

        # dict.fromkeys removes duplicates while keeping the caller's order, so the ids given
        # to new special tokens are deterministic (iterating a set of str is not, across runs).
        ordered_special_tokens = list(dict.fromkeys(special_tokens or []))
        self.special_tokens = set(ordered_special_tokens)

        self.token2idx = {token_bytes: token_id for token_id, token_bytes in self.vocab.items()}

        # Lower rank = learned earlier = applied first when encoding.
        self.merge_ranks = {}
        for rank, pair in enumerate(self.merges):
            self.merge_ranks.setdefault(tuple(pair), rank)

        self.__add_special_tokens(ordered_special_tokens)

    def __add_special_tokens(self, tokens=None):
        """Append the special tokens whose bytes are not yet in the vocabulary.

        New ids start right after the largest existing id, so existing entries are never
        overwritten, even when the ids are not contiguous.
        """
        if tokens is None:
            tokens = sorted(self.special_tokens)
        next_id = max(self.vocab, default=-1) + 1
        for token in tokens:
            encoded_token = token.encode("utf-8")
            if encoded_token not in self.token2idx:
                self.vocab[next_id] = encoded_token
                self.token2idx[encoded_token] = next_id
                next_id += 1

    @classmethod
    def from_files(cls, vocab_filepath, merges_filepath, filepath, special_tokens=None):
        """Build a tokenizer from a pickle holding {"vocab": ..., "merges": ...}.

        `vocab_filepath` and `merges_filepath` must be left empty/None; everything is read
        from `filepath` (e.g. a file written by `save_with_pickle`).
        """
        assert vocab_filepath is None or len(vocab_filepath) == 0, f"Your input vocab_filepath is {vocab_filepath}. You should leave vocab_filepath empty and provide filepath instead"
        assert merges_filepath is None or len(merges_filepath) == 0, f"Your input merges_filepath is {merges_filepath}. You should leave merges_filepath empty and provide filepath instead"

        # NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
        with open(filepath, 'rb') as file:
            data = pickle.load(file)

        return cls(data["vocab"], data["merges"], special_tokens)

    def encode(self, text: str) -> list[int]:
        """Encode `text` to token ids, emitting each special token as a single id."""
        if self.special_tokens:
            # Longest first: regex alternation takes the first alternative that matches, so
            # "<|endoftext|><|endoftext|>" must be tried before "<|endoftext|>".
            ordered = sorted(self.special_tokens, key=lambda t: (-len(t), t))
            escaped = [re.escape(token) for token in ordered]
            # The capturing group keeps the special tokens themselves in the split result.
            parts = re.split(f"({'|'.join(escaped)})", text)
        else:
            parts = [text]

        tokens = []
        for part in parts:
            if not part:
                continue

            if part in self.special_tokens:
                tokens.append(self.token2idx[part.encode("utf-8")])
            else:
                tokens.extend(self.encode_helper(part))

        return tokens

    def encode_helper(self, text):
        """Encode text that contains no special tokens: GPT-2 pre-tokenize, then BPE each piece."""
        GPT2_REGEX_PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        str_tokens = re.findall(GPT2_REGEX_PAT, text)

        # Repeated pre-tokens (" the", ",", ...) are encoded only once per call.
        cache = {}
        ids = []
        for token in str_tokens:
            if token not in cache:
                cache[token] = [self.token2idx[piece] for piece in self._bpe(token.encode("utf-8"))]
            ids.extend(cache[token])
        return ids

    def _bpe(self, token_bytes: bytes) -> list[bytes]:
        """Apply the merges to one pre-token and return the resulting byte pieces.

        Each round merges (left to right) every occurrence of the adjacent pair with the
        lowest merge rank. Pairs are compared as pairs, never by their concatenated bytes,
        because different pairs can concatenate to the same bytes.
        """
        pieces = [bytes([b]) for b in token_bytes]
        ranks = self.merge_ranks
        while len(pieces) > 1:
            best_rank = min(
                (ranks[pair] for pair in zip(pieces, pieces[1:]) if pair in ranks),
                default=None,
            )
            if best_rank is None:
                break
            left, right = self.merges[best_rank]
            merged = []
            i = 0
            while i < len(pieces):
                if i + 1 < len(pieces) and pieces[i] == left and pieces[i + 1] == right:
                    merged.append(left + right)
                    i += 2
                else:
                    merged.append(pieces[i])
                    i += 1
            pieces = merged
        return pieces

    def encode_iterable(self, iterable: Iterable[str]):
        """Lazily yield token ids for each string of `iterable`, in order."""
        for string in iterable:
            yield from self.encode(string)

    def decode(self, ids: list[int]) -> str:
        """Decode token ids to text.

        Unknown ids and invalid UTF-8 sequences both become U+FFFD.
        """
        replacement = "\uFFFD".encode("utf-8")
        # b"".join is linear; repeated bytes += is quadratic.
        data = b"".join(self.vocab.get(_id, replacement) for _id in ids)
        return data.decode("utf-8", errors="replace")

In [79]:
def gpt2_bytes_to_unicode() -> dict[int, str]:
    """Map every byte value 0..255 to a printable unicode character (GPT-2's byte encoding).

    The 188 bytes that already render as visible, non-whitespace characters map to
    themselves. The remaining 68 bytes (control characters, whitespace, DEL, NBSP, soft hyphen)
    are given chr(256), chr(257), ... in increasing byte order.
    See https://www.ssec.wisc.edu/~tomw/java/unicode.html.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    mapping = {byte: chr(byte) for byte in printable}

    # Shift every byte not covered above into the range starting at 256.
    already_visible = set(printable)
    shifted = 0
    for byte in range(2**8):
        if byte in already_visible:
            continue
        mapping[byte] = chr(2**8 + shifted)
        shifted += 1

    return mapping

def get_tokenizer_from_vocab_merges_path(
    vocab_path: str | os.PathLike,
    merges_path: str | os.PathLike,
    special_tokens: list[str] | None = None,
):
    """Build a `Tokenizer` from GPT-2 style `vocab.json` / `merges.txt` files.

    GPT-2 stores tokens as printable unicode strings (see `gpt2_bytes_to_unicode`). They are
    mapped back to raw bytes so that the tokenizer works on bytes directly.

    Args:
        vocab_path: JSON file mapping token string -> token id.
        merges_path: text file with one "left right" merge per line.
        special_tokens: strings to add to the vocabulary (after the largest id) when their
            bytes are not present yet; also passed on to `Tokenizer`.

    Returns:
        A `Tokenizer` built from the decoded vocabulary and merges.
    """
    import json

    gpt2_byte_decoder = {v: k for k, v in gpt2_bytes_to_unicode().items()}

    def to_bytes(token: str) -> bytes:
        # Undo GPT-2's printable remapping, one character per original byte.
        return bytes([gpt2_byte_decoder[ch] for ch in token])

    # Both files contain non-ASCII characters (e.g. "Ġ"), so decode explicitly as UTF-8
    # instead of relying on the platform's default encoding.
    with open(vocab_path, encoding="utf-8") as vocab_f:
        gpt2_vocab = json.load(vocab_f)

    gpt2_bpe_merges = []
    with open(merges_path, encoding="utf-8") as f:
        for line_no, line in enumerate(f):
            cleaned_line = line.rstrip()
            # merges.txt files may start with a "#version: ..." header, which is not a merge.
            if line_no == 0 and cleaned_line.startswith("#version"):
                continue
            parts = cleaned_line.split(" ")
            if cleaned_line and len(parts) == 2:
                gpt2_bpe_merges.append(tuple(parts))

    vocab = {
        gpt2_vocab_index: to_bytes(gpt2_vocab_item)
        for gpt2_vocab_item, gpt2_vocab_index in gpt2_vocab.items()
    }

    # If any of the special tokens don't exist in the vocab, append them after the largest id.
    if special_tokens:
        existing = set(vocab.values())
        next_id = max(vocab, default=-1) + 1
        for special_token in special_tokens:
            byte_encoded_special_token = special_token.encode("utf-8")
            if byte_encoded_special_token not in existing:
                vocab[next_id] = byte_encoded_special_token
                existing.add(byte_encoded_special_token)
                next_id += 1

    merges = [
        (to_bytes(merge_token_1), to_bytes(merge_token_2))
        for merge_token_1, merge_token_2 in gpt2_bpe_merges
    ]
    return Tokenizer(vocab, merges, special_tokens)

In [106]:
VOCAB_PATH = "/data1/ds/scratch/assignment1-basics/tests/fixtures/gpt2_vocab.json"
MERGES_PATH = "/data1/ds/scratch/assignment1-basics/tests/fixtures/gpt2_merges.txt"

In [107]:
tokenizer = get_tokenizer_from_vocab_merges_path(VOCAB_PATH, MERGES_PATH, special_tokens=["<|endoftext|>", "<|endoftext|><|endoftext|>"])

In [108]:
test_string = "Hello, how <|endoftext|><|endoftext|> are you?<|endoftext|>"

ids = tokenizer.encode(test_string)
tokenized_string = [tokenizer.decode([x]) for x in ids]

In [109]:
tokenized_string

['Hello',
 ',',
 ' how',
 ' ',
 '<|endoftext|><|endoftext|>',
 ' are',
 ' you',
 '?',
 '<|endoftext|>']

In [110]:
tokenized_string.count("<|endoftext|>")

1

In [111]:
tokenized_string.count("<|endoftext|><|endoftext|>")

1

In [113]:
tokenizer.decode(ids) == test_string

True