# Experiment 2 - test the naive train_bpe
* Use the test_train_bpe test logic to check the naive BPE training code against the reference vocab and merges

In [1]:

import json
import os
import sys
from typing import BinaryIO, Dict, List, Set, Tuple

import pandas as pd
from transformers import AutoTokenizer
sys.path.append('assignment1-basics')
from cs336_basics.tokenizer.train_naive import train_bpe

from tests.common import gpt2_bytes_to_unicode

In [2]:
## Adapter
def run_train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
    **kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Train a BPE tokenizer on a corpus and return its vocabulary and merges.

    Thin adapter that hands the three named arguments to ``train_bpe`` and
    returns whatever it produces. Any extra keyword arguments are accepted
    for call-site compatibility but are NOT forwarded.

    Args:
        input_path (str | os.PathLike): Path to the BPE tokenizer training data.
        vocab_size (int): Total number of entries in the final vocabulary,
            special tokens included.
        special_tokens (list[str]): String special tokens to add to the
            vocabulary. These strings are never split into multiple tokens and
            are always kept as a single token. If they occur in the
            `input_path` data, they are treated as any other string.
        **kwargs: Ignored.

    Returns:
        tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
            vocab:
                Mapping from token ID (int) to the token's bytes.
            merges:
                List of (<token1>, <token2>) byte pairs, each meaning that
                <token1> was merged with <token2>, in order of creation.
    """
    trained = train_bpe(
        vocab_size=vocab_size,
        special_tokens=special_tokens,
        input_path=input_path,
    )
    return trained

# 1. Load Testcase

In [3]:
# Directory (relative path) containing the test corpus and the reference vocab/merges files.
FIXTURES_PATH = "assignment1-basics/tests/fixtures"

In [4]:
# Training corpus used for the BPE run below.
input_path = os.path.join(FIXTURES_PATH, "corpus.en")

In [5]:
reference_vocab_path = os.path.join(FIXTURES_PATH, "train-bpe-reference-vocab.json")
reference_merges_path = os.path.join(FIXTURES_PATH, "train-bpe-reference-merges.txt")

# The reference files spell each token as a string of characters produced by
# gpt2_bytes_to_unicode(); inverting that table maps every character back to
# the raw byte value it stands for.
gpt2_byte_decoder = {v: k for k, v in gpt2_bytes_to_unicode().items()}

# Read with an explicit UTF-8 encoding: without it, open() falls back to the
# platform's locale encoding, which can fail or mis-decode any non-ASCII
# characters in the reference files (e.g. on Windows).
with open(reference_merges_path, encoding="utf-8") as f:
    # One merge per line: "<token1> <token2>".
    gpt2_reference_merges = [tuple(line.rstrip().split(" ")) for line in f]

# Decode both halves of every merge from characters back to bytes.
reference_merges = [
    (
        bytes([gpt2_byte_decoder[token] for token in merge_token_1]),
        bytes([gpt2_byte_decoder[token] for token in merge_token_2]),
    )
    for merge_token_1, merge_token_2 in gpt2_reference_merges
]

with open(reference_vocab_path, encoding="utf-8") as f:
    # JSON object mapping token string -> token id.
    gpt2_reference_vocab = json.load(f)

# Flip to id -> bytes so it has the same shape as the vocab returned by training.
reference_vocab = {
    gpt2_vocab_index: bytes([gpt2_byte_decoder[token] for token in gpt2_vocab_item])
    for gpt2_vocab_item, gpt2_vocab_index in gpt2_reference_vocab.items()
}

# 2. Test

In [6]:
# Train on the fixture corpus. vocab_size is the total vocabulary size including
# the special token; the recorded run reports 257 starting entries and 243 merges,
# which together give the 500 requested entries.
vocab, merges = run_train_bpe(
    input_path=input_path,
    vocab_size=500,
    special_tokens=["<|endoftext|>"],
)

CUR: 257, NUM MERGES 243


In [7]:
# Same set of token ids on both sides?
print(set(vocab) == set(reference_vocab))
# Same set of token byte strings on both sides (regardless of which id each one got)?
print({*vocab.values()} == {*reference_vocab.values()})

True
True


In [8]:
trained_tokens = set(vocab.values())
reference_tokens = set(reference_vocab.values())

# First the tokens only we produced, then the tokens only the reference has.
for vocab_diff in (trained_tokens - reference_tokens, reference_tokens - trained_tokens):
    print(len(vocab_diff), vocab_diff)

0 set()
0 set()


In [9]:
# Number of entries in the trained vocabulary.
print(len(vocab))

500


In [10]:
# Display the trained vocabulary (token id -> token bytes) for visual inspection.
vocab

{0: b'<|endoftext|>',
 1: b'\x00',
 2: b'\x01',
 3: b'\x02',
 4: b'\x03',
 5: b'\x04',
 6: b'\x05',
 7: b'\x06',
 8: b'\x07',
 9: b'\x08',
 10: b'\t',
 11: b'\n',
 12: b'\x0b',
 13: b'\x0c',
 14: b'\r',
 15: b'\x0e',
 16: b'\x0f',
 17: b'\x10',
 18: b'\x11',
 19: b'\x12',
 20: b'\x13',
 21: b'\x14',
 22: b'\x15',
 23: b'\x16',
 24: b'\x17',
 25: b'\x18',
 26: b'\x19',
 27: b'\x1a',
 28: b'\x1b',
 29: b'\x1c',
 30: b'\x1d',
 31: b'\x1e',
 32: b'\x1f',
 33: b' ',
 34: b'!',
 35: b'"',
 36: b'#',
 37: b'$',
 38: b'%',
 39: b'&',
 40: b"'",
 41: b'(',
 42: b')',
 43: b'*',
 44: b'+',
 45: b',',
 46: b'-',
 47: b'.',
 48: b'/',
 49: b'0',
 50: b'1',
 51: b'2',
 52: b'3',
 53: b'4',
 54: b'5',
 55: b'6',
 56: b'7',
 57: b'8',
 58: b'9',
 59: b':',
 60: b';',
 61: b'<',
 62: b'=',
 63: b'>',
 64: b'?',
 65: b'@',
 66: b'A',
 67: b'B',
 68: b'C',
 69: b'D',
 70: b'E',
 71: b'F',
 72: b'G',
 73: b'H',
 74: b'I',
 75: b'J',
 76: b'K',
 77: b'L',
 78: b'M',
 79: b'N',
 80: b'O',
 81: b'P',
 82: b

In [11]:
# Display the reference vocabulary (token id -> token bytes).
# Its ids are assigned in a different order than the trained vocab's
# (id 1 is b'!' here but b'\x00' in `vocab`), which is why the checks above
# compare the sets of ids and of token bytes rather than the id -> bytes mapping.
reference_vocab

{0: b'<|endoftext|>',
 1: b'!',
 2: b'"',
 3: b'#',
 4: b'$',
 5: b'%',
 6: b'&',
 7: b"'",
 8: b'(',
 9: b')',
 10: b'*',
 11: b'+',
 12: b',',
 13: b'-',
 14: b'.',
 15: b'/',
 16: b'0',
 17: b'1',
 18: b'2',
 19: b'3',
 20: b'4',
 21: b'5',
 22: b'6',
 23: b'7',
 24: b'8',
 25: b'9',
 26: b':',
 27: b';',
 28: b'<',
 29: b'=',
 30: b'>',
 31: b'?',
 32: b'@',
 33: b'A',
 34: b'B',
 35: b'C',
 36: b'D',
 37: b'E',
 38: b'F',
 39: b'G',
 40: b'H',
 41: b'I',
 42: b'J',
 43: b'K',
 44: b'L',
 45: b'M',
 46: b'N',
 47: b'O',
 48: b'P',
 49: b'Q',
 50: b'R',
 51: b'S',
 52: b'T',
 53: b'U',
 54: b'V',
 55: b'W',
 56: b'X',
 57: b'Y',
 58: b'Z',
 59: b'[',
 60: b'\\',
 61: b']',
 62: b'^',
 63: b'_',
 64: b'`',
 65: b'a',
 66: b'b',
 67: b'c',
 68: b'd',
 69: b'e',
 70: b'f',
 71: b'g',
 72: b'h',
 73: b'i',
 74: b'j',
 75: b'k',
 76: b'l',
 77: b'm',
 78: b'n',
 79: b'o',
 80: b'p',
 81: b'q',
 82: b'r',
 83: b's',
 84: b't',
 85: b'u',
 86: b'v',
 87: b'w',
 88: b'x',
 89: b'y',
 90: b'

In [12]:
# Number of learned merges, then the first and last three for a quick
# side-by-side look against the reference merges.
# NOTE(review): this only eyeballs the two ends of the list; consider asserting
# merges == reference_merges to check the full order.
print(len(merges))
merges[:3], merges[-3:]

243


([(b' ', b't'), (b' ', b'a'), (b'h', b'e')],
 [(b'f', b'ore'), (b' s', b'it'), (b' ', b'ver')])

In [13]:
# Number of reference merges, then the first and last three, in the same
# layout as the preview of the trained merges.
print(len(reference_merges))
reference_merges[:3], reference_merges[-3:]

243


([(b' ', b't'), (b' ', b'a'), (b'h', b'e')],
 [(b'f', b'ore'), (b' s', b'it'), (b' ', b'ver')])