In [34]:
import regex as re
from tqdm import tqdm
import numpy as np
from datasets import load_dataset
import torch
import json
from simple_tokenizer import excluded_characters_pat, unallowed_characters, get_tok_strs


def get_str_counts(dataset, known_toks=None):
    """Tally how often each token string occurs across ``dataset``.

    Args:
        dataset: Iterable of records, each indexable by ``'text'``.
        known_toks: Forwarded unchanged to ``get_tok_strs``; its effect is
            determined by that helper and is not visible here.

    Returns:
        A dict mapping each string yielded by ``get_tok_strs`` to the number
        of times it was seen, in first-seen order.

    Raises:
        AssertionError: If a document that passed the exclusion filter still
            matches ``unallowed_characters``.
    """
    tallies = {}
    for record in tqdm(dataset):
        text = record['text']

        # Skip any document containing an explicitly excluded character
        # (original author's unverified estimate: roughly 8% of docs).
        if excluded_characters_pat.search(text) is not None:
            continue

        # Sanity check: nothing unexpected should survive the filter above.
        assert not unallowed_characters.search(text)

        # NOTE: an earlier version also dropped whole documents containing
        # any token outside known_toks; that filter is currently disabled.
        for piece in get_tok_strs(text, known_toks=known_toks):
            tallies[piece] = tallies.get(piece, 0) + 1

    return tallies


# Total vocabulary size, including the special tokens below.
VOCAB_SIZE = 8192
# Special tokens that occupy the first entries of the vocabulary.
SPECIAL_TOKS = ['[bos]', '[unk]']

# Count token strings over the first 5% of the training split.
dataset = load_dataset('noanabeshima/TinyStoriesV2', split='train[:5%]')
str_counts = get_str_counts(dataset)

# Parallel arrays: strs[i] was seen counts[i] times.
strs = np.array(list(str_counts.keys()))
counts = torch.tensor(list(str_counts.values()))

# Keep the most frequent strings, leaving room for the special tokens.
# torch.topk raises when k exceeds the number of elements, so clamp k for
# corpora that yield fewer distinct strings than the vocabulary budget.
num_common = min(VOCAB_SIZE - len(SPECIAL_TOKS), counts.numel())
topk = counts.topk(k=num_common)
common_strs = strs[topk.indices]

# Final vocabulary: special tokens first, then strings by descending count.
toks = np.array(SPECIAL_TOKS + common_strs.tolist())

with open('./simple_tokenizer/tokens.json', 'w') as f:
    json.dump(toks.tolist(), f)

100%|██████████| 135885/135885 [00:14<00:00, 9542.58it/s]
