In [1]:
import dataclasses
import glob
import json
import logging
import sys
import time
from pathlib import Path

import tokenizers
from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tqdm import tqdm

import bpeasy
from bpeasy.tokenizer import BPEasyTokenizer

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)


@dataclasses.dataclass
class TrainBPETokenizerArgs:
    """Settings shared by the tokenizer training and encoding helpers.

    Construction fails with an ``AssertionError`` (carrying the offending
    path) when ``dataset`` does not point at an existing directory.
    """

    # Directory that is scanned for ``*.jsonl`` files.
    dataset: str = "./data"
    # Target vocabulary size handed to both trainers.
    vocab_size: int = 500
    # Upper bound on token length handed to both trainers.
    max_sentencepiece_length: int = 64
    # Pre-tokenization split pattern handed to both trainers.
    regex_pattern: str = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

    def __post_init__(self):
        """Check up front that the dataset directory exists."""
        dataset_dir = Path(self.dataset)
        assert dataset_dir.is_dir(), dataset_dir


def jsonl_content_iterator(
    args: TrainBPETokenizerArgs,
):
    """
    Yield the ``"text"`` field of every record in the dataset's JSONL files.

    Every ``*.jsonl`` file directly inside ``args.dataset`` is read in sorted
    path order, so repeated passes (training, then encoding) and separate runs
    all see the documents in the same order. Blank lines are skipped.

    Args:
        args: Configuration whose ``dataset`` attribute is the directory to scan.

    Yields:
        The ``"text"`` value of each JSON object, one per non-blank line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"text"`` key.
    """
    # glob returns matches in arbitrary order; sort so iteration is reproducible.
    for file_name in sorted(glob.glob(f"{args.dataset}/*.jsonl")):
        with open(file_name, "r", encoding="utf-8") as f:
            for line in f:
                # Tolerate empty lines (e.g. a trailing newline at end of file).
                if not line.strip():
                    continue
                yield json.loads(line)["text"]


def train_huggingface(args: TrainBPETokenizerArgs):
    """Train a Hugging Face ``tokenizers`` BPE tokenizer on the dataset texts.

    The tokenizer uses byte fallback, is seeded with the 256 ``<0xNN>`` byte
    tokens, splits text with ``args.regex_pattern`` and then applies
    byte-level pre-tokenization.

    Args:
        args: Supplies ``vocab_size``, ``max_sentencepiece_length``,
            ``regex_pattern`` and the dataset location.

    Returns:
        The trained ``Tokenizer``.

    Raises:
        AssertionError: If the installed ``tokenizers`` is older than 0.14.
    """
    # should be at least 0.14.0 to train with char limit.
    # Compare numerically: as plain strings "0.9.4" >= "0.14.0" is True and
    # "0.100.0" >= "0.14.0" is False, both of which are wrong.
    installed = tuple(int(part) for part in tokenizers.__version__.split(".")[:2])
    assert installed >= (0, 14), tokenizers.__version__

    tokenizer = Tokenizer(BPE(byte_fallback=True))
    trainer = BpeTrainer(
        vocab_size=args.vocab_size,
        special_tokens=[f"<0x{i:02X}>" for i in range(256)],  # seed sm vocab
        max_token_length=args.max_sentencepiece_length,
        show_progress=False,
    )
    gpt_regex = Regex(args.regex_pattern)

    # Split on the regex first, then map the pieces to byte-level symbols.
    split_pre_tokenizer = pre_tokenizers.Split(
        gpt_regex, behavior="isolated", invert=False
    )
    byte_pre_tokenizer = pre_tokenizers.ByteLevel(
        add_prefix_space=False, use_regex=False
    )
    tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
        [split_pre_tokenizer, byte_pre_tokenizer]
    )
    # Use ByteLevel Decoder
    tokenizer.decoder = decoders.Sequence(
        [decoders.ByteLevel(), decoders.ByteFallback()]
    )
    iterator = jsonl_content_iterator(args)
    # training the tokenizer
    tokenizer.train_from_iterator(iterator, trainer)

    return tokenizer


def train_bpeasy(args: TrainBPETokenizerArgs):
    """Train a BPE vocabulary with ``bpeasy`` and wrap it in a ``BPEasyTokenizer``.

    Args:
        args: Supplies ``regex_pattern``, ``max_sentencepiece_length``,
            ``vocab_size`` and the dataset location.

    Returns:
        A ``BPEasyTokenizer`` built from the trained vocabulary, with no
        special tokens and no padding of the vocabulary size.
    """
    # Train directly from a fresh pass over the dataset texts.
    trained_vocab = bpeasy.train_bpe(
        jsonl_content_iterator(args),
        args.regex_pattern,
        args.max_sentencepiece_length,
        args.vocab_size,
    )

    tokenizer = BPEasyTokenizer(
        trained_vocab,
        args.regex_pattern,
        special_tokens=[],
        fill_to_nearest_multiple_of_eight=False,
    )
    return tokenizer


def encode(tokenizer, args) -> float:
    """Encode every dataset text and return the total number of tokens.

    Args:
        tokenizer: Any object whose ``encode(text)`` result supports ``len()``.
        args: Configuration passed through to ``jsonl_content_iterator``.

    Returns:
        The sum of the encoded lengths over all texts (0 for an empty dataset).
    """
    total_tokens = 0
    for document in jsonl_content_iterator(args):
        total_tokens += len(tokenizer.encode(document))
    return total_tokens


def get_mean_std_dev(times: list[float]) -> tuple[float, float]:
    """Return the mean and population standard deviation of ``times``.

    The population form (divide by ``n``) is used so that a single
    measurement yields a standard deviation of ``0.0`` instead of failing.

    Args:
        times: Non-empty list of measurements.

    Returns:
        ``(mean, std_dev)``.

    Raises:
        ZeroDivisionError: If ``times`` is empty.
    """
    count = len(times)
    avg_time = sum(times) / count
    # Variance is the *mean* squared deviation; the standard deviation is its root.
    variance = sum((t - avg_time) ** 2 for t in times) / count
    std_dev = variance ** 0.5
    return avg_time, std_dev


# Number of times each train/encode measurement is repeated; must be >= 1.
NUM_ITERATIONS = 1
args = TrainBPETokenizerArgs()

times_train_huggingface = []
times_encode_huggingface = []
times_train_bpeasy = []
times_encode_bpeasy = []
lengths_huggingface = []
lengths_bpeasy = []


for _ in range(NUM_ITERATIONS):
    # perf_counter is monotonic and high-resolution, unlike the wall clock.
    # Each timer is stopped immediately after the measured call so that
    # nothing else (sorting/printing the vocabulary) leaks into the timing.
    time_now = time.perf_counter()
    hf_tokenizer = train_huggingface(args)
    times_train_huggingface.append(time.perf_counter() - time_now)

    time_now = time.perf_counter()
    lengths_huggingface.append(encode(hf_tokenizer, args))
    times_encode_huggingface.append(time.perf_counter() - time_now)

    time_now = time.perf_counter()
    bpeasy_tokenizer = train_bpeasy(args)
    times_train_bpeasy.append(time.perf_counter() - time_now)

    time_now = time.perf_counter()
    lengths_bpeasy.append(encode(bpeasy_tokenizer, args))
    times_encode_bpeasy.append(time.perf_counter() - time_now)

# Show the learned vocabularies (from index 255 onwards) of the last trained
# tokenizers, outside of any timed region.
print(sorted(hf_tokenizer.get_vocab().items(), key=lambda x: x[1])[255:])
print(sorted(bpeasy_tokenizer.vocab.items(), key=lambda x: x[1])[255:])

# Report mean and spread of every metric for both implementations.
for metric, hf_values, bpeasy_values in (
    ("train time", times_train_huggingface, times_train_bpeasy),
    ("encode time", times_encode_huggingface, times_encode_bpeasy),
    ("length", lengths_huggingface, lengths_bpeasy),
):
    m_hf, std_hf = get_mean_std_dev(hf_values)
    m_bpeasy, std_bpeasy = get_mean_std_dev(bpeasy_values)

    print(f"huggingface {metric} {m_hf} +/- {std_hf}")
    print(f"bpeasy {metric} {m_bpeasy} +/- {std_bpeasy}")

[('<0xFF>', 255), ('!', 256), ('"', 257), ('#', 258), ('$', 259), ('%', 260), ('&', 261), ("'", 262), ('(', 263), (')', 264), ('*', 265), ('+', 266), (',', 267), ('-', 268), ('.', 269), ('/', 270), ('0', 271), ('1', 272), ('2', 273), ('3', 274), ('4', 275), ('5', 276), ('6', 277), ('7', 278), ('8', 279), ('9', 280), (':', 281), (';', 282), ('<', 283), ('=', 284), ('>', 285), ('?', 286), ('@', 287), ('A', 288), ('B', 289), ('C', 290), ('D', 291), ('E', 292), ('F', 293), ('G', 294), ('H', 295), ('I', 296), ('J', 297), ('K', 298), ('L', 299), ('M', 300), ('N', 301), ('O', 302), ('P', 303), ('Q', 304), ('R', 305), ('S', 306), ('T', 307), ('U', 308), ('V', 309), ('W', 310), ('X', 311), ('Y', 312), ('Z', 313), ('[', 314), ('\\', 315), (']', 316), ('_', 317), ('`', 318), ('a', 319), ('b', 320), ('c', 321), ('d', 322), ('e', 323), ('f', 324), ('g', 325), ('h', 326), ('i', 327), ('j', 328), ('k', 329), ('l', 330), ('m', 331), ('n', 332), ('o', 333), ('p', 334), ('q', 335), ('r', 336), ('s', 337

In [26]:
# Find the first short text that the Hugging Face tokenizer encodes into fewer
# tokens than BPEasy, and show both tokenizations side by side.
iterator = jsonl_content_iterator(args)

for text in iterator:
    # Cheap length filter first, so long texts are never encoded at all.
    if len(text) >= 100:
        continue
    # Encode once per tokenizer and reuse the results below.
    hf_encoding = hf_tokenizer.encode(text)
    bpeasy_ids = bpeasy_tokenizer.encode(text)
    if len(hf_encoding) < len(bpeasy_ids):
        print(text)
        print(hf_encoding.tokens)
        print([bpeasy_tokenizer.decode([t]) for t in bpeasy_ids])
        break

Deployed from e27d7a207f. You are on web.3. UTC time is currently 23 Apr 2019 15:02:46 +00:00.
['D', 'e', 'p', 'l', 'o', 'y', 'ed', 'Ġf', 'r', 'om', 'Ġe', '2', '7', 'd', '7', 'a', '2', '0', '7', 'f', '.', 'Ġ', 'Y', 'ou', 'Ġa', 're', 'Ġ', 'on', 'Ġw', 'e', 'b', '.', '3', '.', 'Ġ', 'U', 'T', 'C', 'Ġt', 'i', 'm', 'e', 'Ġis', 'Ġc', 'u', 'r', 're', 'n', 't', 'ly', 'Ġ', '2', '3', 'Ġ', 'A', 'p', 'r', 'Ġ', '2', '0', '1', '9', 'Ġ', '1', '5', ':', '0', '2', ':', '4', '6', 'Ġ', '+', '0', '0', ':', '0', '0', '.']
['D', 'e', 'p', 'l', 'o', 'y', 'ed', ' ', 'f', 'r', 'o', 'm', ' ', 'e', '2', '7', 'd', '7', 'a', '2', '0', '7', 'f', '.', ' ', 'Y', 'o', 'u', ' are', ' on', ' ', 'w', 'e', 'b', '.', '3', '.', ' ', 'U', 'T', 'C', ' ', 't', 'i', 'm', 'e', ' is', ' ', 'c', 'u', 'r', 'r', 'ent', 'l', 'y', ' ', '2', '3', ' A', 'p', 'r', ' ', '2', '0', '1', '9', ' ', '1', '5', ':', '0', '2', ':', '4', '6', ' ', '+', '0', '0', ':', '0', '0', '.']


In [20]:
# Retrain the raw BPEasy vocabulary on the dataset.
# (An unfinished, module-less `import` statement used to sit at the top of
# this cell; it was a SyntaxError that prevented the cell from running.)
iterator = jsonl_content_iterator(args)
vocab = bpeasy.train_bpe(
    iterator,
    args.regex_pattern,
    args.max_sentencepiece_length,
    args.vocab_size,
)

['D',
 'e',
 'p',
 'l',
 'o',
 'y',
 'ed',
 'Ġf',
 'r',
 'om',
 'Ġe',
 '2',
 '7',
 'd',
 '7',
 'a',
 '2',
 '0',
 '7',
 'f',
 '.',
 'Ġ',
 'Y',
 'ou',
 'Ġa',
 're',
 'Ġ',
 'on',
 'Ġw',
 'e',
 'b',
 '.',
 '3',
 '.',
 'Ġ',
 'U',
 'T',
 'C',
 'Ġt',
 'i',
 'm',
 'e',
 'Ġis',
 'Ġc',
 'u',
 'r',
 're',
 'n',
 't',
 'ly',
 'Ġ',
 '2',
 '3',
 'Ġ',
 'A',
 'p',
 'r',
 'Ġ',
 '2',
 '0',
 '1',
 '9',
 'Ġ',
 '1',
 '5',
 ':',
 '0',
 '2',
 ':',
 '4',
 '6',
 'Ġ',
 '+',
 '0',
 '0',
 ':',
 '0',
 '0',
 '.']