In [None]:
import glob
import os
from statistics import fmean

import pyrootutils
from tokenizers import Tokenizer

In [None]:
# Resolve the project root by looking for the ".project-root" indicator,
# starting the search from the current working directory
# (os.path.abspath("") resolves to the cwd).
# The former chained `path` alias was dropped: it was never read and was
# immediately overwritten by the loop variable of the loading cell.
PROJECT_ROOT = pyrootutils.find_root(
    search_from=os.path.abspath(""), indicator=".project-root"
)

In [None]:
# Run label -> output directory of that run, relative to PROJECT_ROOT.
tkz_dict = {
    "LM-merge-1": "outputs/2024-05-06-125114_93f0",
    "LM-nomerge-1": "outputs/2024-05-06-125115_9621",
    "LM-merge-2": "outputs/2024-05-06-125411_bfc9",
    "LM-nomerge-2": "outputs/2024-05-06-125411_90aa",
    "LM-merge-3": "outputs/2024-05-06-125412_5c63",
    "LM-nomerge-3": "outputs/2024-05-06-125411_10c9",
    "LM-merge-4": "outputs/2024-05-06-125411_e5c5",
    "LM-nomerge-4": "outputs/2024-05-06-125411_e6e7",
    "LM-merge-5": "outputs/2024-05-06-125411_1d73",
    "BPE-nomerge-retrain": "outputs/bpe-tokenizers/2024-05-10-114128_b1a1",
    "BPE-nomerge-noretrain": "outputs/bpe-tokenizers/2024-05-10-115123_88af",
    "BPE-merge-retrain": "outputs/bpe-tokenizers/2024-05-10-120039_8a27",
    "BPE-merge-noretrain": "outputs/bpe-tokenizers/2024-05-10-123454_58e0",
}

TOKENIZER_SUFFIX = "-tokenizer.json"


def _latest_tokenizer_file(run_dir):
    """Return the highest-numbered ``<N>-tokenizer.json`` file in ``run_dir``.

    Args:
        run_dir: directory to search (``str`` or path-like).

    Returns:
        The file path as a string, or ``None`` when no file matches.
    """
    # Escape the directory so glob metacharacters in the path are taken literally.
    pattern = os.path.join(glob.escape(str(run_dir)), "[0-9]*" + TOKENIZER_SUFFIX)
    best_index, best_path = -1, None
    for candidate in glob.glob(pattern):
        prefix = os.path.basename(candidate)[: -len(TOKENIZER_SUFFIX)]
        # The glob only constrains the first character, so reject names such
        # as "1a-tokenizer.json" whose prefix is not a plain integer.
        if prefix.isdecimal() and int(prefix) > best_index:
            best_index, best_path = int(prefix), candidate
    return best_path


merge_toks = []
nomerge_toks = []

for tkz, rel_path in tkz_dict.items():
    # Only the LM runs are collected here; skip the rest before touching the disk.
    if not tkz.startswith("LM-"):
        continue
    # Pick the file by its numeric prefix rather than by counting matches, so
    # gaps in the numbering cannot select a wrong or missing file.
    tok_path = _latest_tokenizer_file(PROJECT_ROOT / rel_path)
    if tok_path is None:
        # Report the skip: a silently missing run would skew the statistics.
        print(f"warning: no tokenizer files found for {tkz} in {rel_path}")
        continue
    tok = Tokenizer.from_file(tok_path)
    if tkz.startswith("LM-merge"):
        merge_toks.append(tok)
    elif tkz.startswith("LM-nomerge"):
        nomerge_toks.append(tok)

In [None]:
def _mean_token_length(tokens):
    """Return the mean string length of ``tokens``, or ``nan`` if there are none."""
    lengths = [len(token) for token in tokens]
    # fmean raises StatisticsError on empty input; an empty intersection is a
    # legitimate outcome, so report nan instead of aborting the cell.
    return fmean(lengths) if lengths else float("nan")


def _vocab_overlap_stats(toks, label):
    """Compute vocabulary-overlap statistics for one group of tokenizers.

    Args:
        toks: list of tokenizers; each must provide ``get_vocab()`` and
            ``get_vocab_size()``.
        label: group name, used only in the error message.

    Returns:
        A tuple ``(vocabs, mean_size, shared, shared_mean_len, pooled_mean_len)``:
        the per-tokenizer vocabularies as sets of token strings, the mean of
        ``get_vocab_size()``, the set of tokens present in every vocabulary,
        the mean token length within that shared set, and the mean token
        length over all vocabularies pooled (a token is counted once per
        vocabulary that contains it).

    Raises:
        ValueError: if ``toks`` is empty.
    """
    if not toks:
        raise ValueError(
            f"no {label} tokenizers were loaded; cannot compute vocabulary statistics"
        )
    vocabs = [set(tok.get_vocab().keys()) for tok in toks]
    mean_size = fmean([tok.get_vocab_size() for tok in toks])
    shared = set.intersection(*vocabs)
    shared_mean_len = _mean_token_length(shared)
    pooled_mean_len = _mean_token_length(
        token for vocab in vocabs for token in vocab
    )
    return vocabs, mean_size, shared, shared_mean_len, pooled_mean_len


(
    merge_vocabs,
    merge_lens,
    merge_vocab_int,
    merge_vocab_int_lens,
    merge_vocab_total_len,
) = _vocab_overlap_stats(merge_toks, "LM-merge")
(
    nomerge_vocabs,
    nomerge_lens,
    nomerge_vocab_int,
    nomerge_vocab_int_lens,
    nomerge_vocab_total_len,
) = _vocab_overlap_stats(nomerge_toks, "LM-nomerge")

# Shared-vocabulary size, mean vocabulary size, and their ratio.
print(
    "merge vocab", len(merge_vocab_int), merge_lens, len(merge_vocab_int) / merge_lens
)
print(
    "no merge vocab",
    len(nomerge_vocab_int),
    nomerge_lens,
    len(nomerge_vocab_int) / nomerge_lens,
)

# Mean token length within the shared vocabulary.
print("merge vocab int len:", merge_vocab_int_lens)
print("nomerge vocab int len:", nomerge_vocab_int_lens)

# Mean token length over all vocabularies pooled together.
print("merge vocab total len:", merge_vocab_total_len)
print("nomerge vocab total len:", nomerge_vocab_total_len)

In [None]:
# Show a bounded, sorted preview instead of dumping the whole vocabulary:
# printing the raw set floods the output and its order changes between runs.
VOCAB_PREVIEW_SIZE = 50

if nomerge_vocabs:
    first_nomerge_vocab = sorted(nomerge_vocabs[0])
    print(
        f"{len(first_nomerge_vocab)} tokens; "
        f"first {VOCAB_PREVIEW_SIZE} in sorted order:"
    )
    print(first_nomerge_vocab[:VOCAB_PREVIEW_SIZE])
else:
    # Guard the empty case so the cell does not fail with an IndexError.
    print("no LM-nomerge vocabularies loaded")