# Experiment 2 - test the naive `train_bpe`
* Use the `test_train_bpe` test cases to check the BPE training code.

In [1]:

import json
import os
import sys
from typing import BinaryIO, Dict, List, Set, Tuple

import pandas as pd
from transformers import AutoTokenizer
sys.path.append('assignment1-basics')
from cs336_basics.tokenizer.train_naive import train_bpe

from tests.common import gpt2_bytes_to_unicode

In [2]:
## Adapter
def run_train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
    **kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Given the path to an input corpus, train a BPE tokenizer and
    output its vocabulary and merges.

    Args:
        input_path (str | os.PathLike): Path to BPE tokenizer training data.
        vocab_size (int): Total number of items in the tokenizer's vocabulary (including special tokens).
        special_tokens (list[str]): A list of string special tokens to be added to the tokenizer vocabulary.
            These strings will never be split into multiple tokens, and will always be
            kept as a single token. If these special tokens occur in the `input_path`,
            they are treated as any other string.
        **kwargs: Extra trainer options (e.g. ``num_processes``). Each option is
            forwarded to ``train_bpe`` only if its signature accepts it (as a named
            keyword parameter or through its own ``**kwargs``). Any other option is
            dropped with a ``UserWarning`` instead of being discarded silently.

    Returns:
        tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
            vocab:
                The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
                to bytes (token bytes)
            merges:
                BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
                representing that <token1> was merged with <token2>.
                Merges are ordered by order of creation.
    """
    import inspect
    import warnings

    # The trainer behind this adapter may not support every option the test
    # harness passes, so forward what it accepts and report the rest.
    params = inspect.signature(train_bpe).parameters
    takes_var_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    keyword_names = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    forwarded = {
        key: value
        for key, value in kwargs.items()
        if takes_var_kwargs or key in keyword_names
    }
    ignored = sorted(kwargs.keys() - forwarded.keys())
    if ignored:
        warnings.warn(
            f"run_train_bpe: train_bpe does not accept {ignored}; ignoring them.",
            stacklevel=2,
        )
    return train_bpe(
        input_path=input_path,
        vocab_size=vocab_size,
        special_tokens=special_tokens,
        **forwarded,
    )

# 1. Load Test Case

In [3]:
# Directory with the corpora and reference files used by the checks below.
FIXTURES_PATH = "assignment1-basics/tests/fixtures"

In [4]:
# English training corpus taken from the test fixtures.
input_path = os.path.join(FIXTURES_PATH, "corpus.en")

In [5]:
reference_vocab_path = os.path.join(FIXTURES_PATH, "train-bpe-reference-vocab.json")
reference_merges_path = os.path.join(FIXTURES_PATH, "train-bpe-reference-merges.txt")

# The reference files spell tokens in GPT-2's printable-unicode alphabet (one
# visible character per byte); invert that mapping to recover the raw bytes.
gpt2_byte_decoder = {v: k for k, v in gpt2_bytes_to_unicode().items()}

# Open as UTF-8 explicitly: the stand-in characters (e.g. "Ġ") are non-ASCII,
# so relying on the platform default encoding (e.g. cp1252 on Windows) can
# fail or mis-decode them.
with open(reference_merges_path, encoding="utf-8") as f:
    # One "<token1> <token2>" pair per line. Skip blank lines (e.g. a trailing
    # one) so they cannot break the 2-tuple unpacking below.
    gpt2_reference_merges = [tuple(line.rstrip().split(" ")) for line in f if line.strip()]
reference_merges = [
    (
        bytes([gpt2_byte_decoder[char] for char in merge_token_1]),
        bytes([gpt2_byte_decoder[char] for char in merge_token_2]),
    )
    for merge_token_1, merge_token_2 in gpt2_reference_merges
]

with open(reference_vocab_path, encoding="utf-8") as f:
    gpt2_reference_vocab = json.load(f)
# The JSON maps token string -> id; flip it to id -> token bytes, like our vocab.
reference_vocab = {
    gpt2_vocab_index: bytes([gpt2_byte_decoder[char] for char in gpt2_vocab_item])
    for gpt2_vocab_item, gpt2_vocab_index in gpt2_reference_vocab.items()
}

# 2. Check Vocab

In [6]:
# Train the naive BPE on the English corpus: 500-entry vocab, one special token.
vocab, merges = run_train_bpe(
    special_tokens=["<|endoftext|>"],
    vocab_size=500,
    input_path=input_path,
)

iron cement is a rea


In [7]:
# Same set of token IDs, and same set of token byte strings, as the reference?
print(set(vocab) == set(reference_vocab))
print(set(reference_vocab.values()) == set(vocab.values()))

True
True


In [8]:
# Tokens we learned that the reference lacks, and vice versa (both should be empty).
# Separate names keep each result inspectable instead of overwriting the first.
extra_tokens = set(vocab.values()) - set(reference_vocab.values())
print(len(extra_tokens), extra_tokens)
missing_tokens = set(reference_vocab.values()) - set(vocab.values())
print(len(missing_tokens), missing_tokens)

0 set()
0 set()


In [9]:
print(len(vocab))  # total entries, special tokens included

500


In [10]:
# vocab

In [11]:
# reference_vocab

# 3. Check Merges

In [12]:
# How many merges were learned, plus the first and last three as a spot check.
print(len(merges))
(merges[:3], merges[-3:])

243


([(b' ', b't'), (b' ', b'a'), (b'h', b'e')],
 [(b'f', b'ore'), (b' s', b'it'), (b' ', b'ver')])

In [13]:
# Same summary for the reference merges, for side-by-side comparison.
print(len(reference_merges))
(reference_merges[:3], reference_merges[-3:])

243


([(b' ', b't'), (b' ', b'a'), (b'h', b'e')],
 [(b'f', b'ore'), (b' s', b'it'), (b' ', b'ver')])

In [14]:
# Scratch check of tie-breaking order. As raw bytes, a leading space (0x20)
# sorts *before* every letter, while its GPT-2 printable stand-in "Ġ" (U+0120)
# sorts *after* all ASCII letters. Comparing merge candidates as bytes and as
# GPT-2 strings can therefore disagree on ties.
print(b" y" > b"u")
print(b"y" > b"wh")
print("Ġy" > "Ġwh", " y" > " wh")

print(b" y" > b"u", b" y" > b" wh")
# Decoded to str, ASCII text compares exactly like its bytes.
print(" y" > "u")

False
True
True True
False True
False


In [15]:
# Ordering difference: print every position where the trained merge differs
# from the reference (zip(strict=True) also raises if the lengths differ).
for trained_pair, reference_pair in zip(merges, reference_merges, strict=True):
    if trained_pair == reference_pair:
        continue
    print(trained_pair, reference_pair)

In [16]:
# Merges present in only one of the two lists, in either direction
# (a one-sided difference would hide reference merges we never learned).
set(merges) ^ set(reference_merges)

set()

In [17]:
# Exact match, including merge order.
print(reference_merges == merges)

True


# 4. Check Special Tokens (`test_train_bpe_special_tokens`)

In [18]:
# Retrain on the TinyStories sample fixture to check special-token handling.
# NOTE: num_processes only takes effect if run_train_bpe / train_bpe pass it on.
input_path = os.path.join(FIXTURES_PATH, "tinystories_sample_5M.txt")
vocab, merges = run_train_bpe(
    special_tokens=["<|endoftext|>"],
    vocab_size=1000,
    input_path=input_path,
    num_processes=1,
)

u don't have to be s
<|endoftext|>
?? <|endoftext|>

Once upon a time, i
<|endoftext|>
?? <|endoftext|>



Tom and Lily were
<|endoftext|>
?? <|endoftext|>


Once upon a time t
<|endoftext|>
?? <|endoftext|>

One morning, a cat 
<|endoftext|>
?? <|endoftext|>



Lily and Tom were
<|endoftext|>
?? <|endoftext|>


Once upon a time, 
<|endoftext|>
?? <|endoftext|>



Lily and Max were
<|endoftext|>
?? <|endoftext|>

Once upon a time, t
<|endoftext|>
?? <|endoftext|>

Once upon a time, t
<|endoftext|>
?? <|endoftext|>

Once upon a time, i
<|endoftext|>
?? <|endoftext|>

Once upon a time, i
<|endoftext|>
?? <|endoftext|>

Once upon a time, t
<|endoftext|>
?? <|endoftext|>


Emily was very exc
<|endoftext|>
?? <|endoftext|>

Once upon a time, t
<|endoftext|>
?? <|endoftext|>


Once upon a time, 
<|endoftext|>
?? <|endoftext|>


It was a sunny day
<|endoftext|>
?? <|endoftext|>

Once upon a time, t
<|endoftext|>
?? <|endoftext|>

One day, a little c
<|endoftext|>
?? <|endoftext|>

Once upon a

In [19]:
# Every learned token except the special token itself.
vocabs_without_specials = list(filter(lambda token: token != b"<|endoftext|>", vocab.values()))

In [20]:
# Print any non-special token containing "<|", i.e. a fragment of the special
# token that leaked into the learned vocabulary (expect no output).
for token_bytes in filter(lambda candidate: b"<|" in candidate, vocabs_without_specials):
    print(token_bytes)

In [21]:
# Number of distinct token byte strings; compare with vocab_size=1000 to spot duplicates.
len({*vocab.values()})

1000

In [22]:
import pickle

# Expected-output snapshot stored next to the fixtures directory. Derive its
# location from FIXTURES_PATH instead of hard-coding the repo layout again.
SNAPSHOT_PATH = os.path.join(
    os.path.dirname(os.path.normpath(FIXTURES_PATH)),
    "_snapshots",
    "test_train_bpe_special_tokens.pkl",
)
# NOTE: pickle.load can execute arbitrary code; only load trusted, repo-local snapshots.
with open(SNAPSHOT_PATH, "rb") as f:
    expected_data = pickle.load(f)

In [23]:
# Tokens only in our vocab, then tokens only in the snapshot (both should be empty).
for left, right in (
    (vocab.values(), expected_data['vocab_values']),
    (expected_data['vocab_values'], vocab.values()),
):
    print(set(left) - set(right))

set()
set()


In [24]:
# Byte-order check: "\n" (0x0A) is smaller than " " (0x20), so this is False.
b' lo' < b'\n'

False

In [25]:
# Compare against the snapshot's merges under a new name, so that
# `reference_merges` (the corpus.en reference loaded earlier) is not clobbered
# and re-running earlier comparison cells still uses the right reference.
expected_merges = expected_data['merges']
for trained_pair, expected_pair in zip(merges, expected_merges, strict=True):
    if trained_pair != expected_pair:
        print(trained_pair, expected_pair)