In [38]:
import os

from tqdm.notebook import tqdm

from pyserini.pyclass import autoclass

from datasets import load_dataset, Dataset
from datasets.utils.py_utils import convert_file_size_to_int
from transformers import PreTrainedTokenizerFast
from tokenizers import (
    normalizers,
    pre_tokenizers,
    decoders,
    Tokenizer,
)
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer

In [None]:
# Bind the Java class `io.anserini.index.IndexCollection` to a Python name via
# pyserini's `autoclass`, so that its `main` entry point can be called from
# this notebook.
JIndexCollection = autoclass('io.anserini.index.IndexCollection')

In [31]:
# Load the "train" split of the "imdb" dataset; the bare `dset` expression
# displays its features and row count.
dset = load_dataset("imdb", split="train")
dset

Found cached dataset imdb (/mnt/1da05489-3812-4f15-a6e5-c8d3c57df39e/cache/huggingface/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})

In [32]:
# Attach an "id" column whose value is the row position (0 .. len(dset) - 1).
# NOTE(review): this rebinds `dset` to its own transformed value, so running
# the cell a second time presumably fails because "id" already exists - confirm.
dset = dset.add_column("id", range(len(dset)))
dset

Dataset({
    features: ['text', 'label', 'id'],
    num_rows: 25000
})

In [33]:
# Rename the "text" column to "contents".
# NOTE(review): presumably the field name expected by the JsonCollection used
# for indexing further down - confirm against the Anserini documentation.
dset = dset.rename_column("text", "contents")
dset

Dataset({
    features: ['contents', 'label', 'id'],
    num_rows: 25000
})

In [34]:
# Keep only the "id" and "contents" columns; "label" is dropped here.
dset = dset.select_columns(["id", "contents"])
dset

Dataset({
    features: ['contents', 'id'],
    num_rows: 25000
})

In [35]:
# Export the dataset as JSONL shards under `shard_dir`, using a size budget of
# '10MB' per shard to decide how many shards to write.
shard_dir = "../shards/imdb"
max_shard_size = convert_file_size_to_int("10MB")
dataset_nbytes = dset.data.nbytes

# Integer part of the size ratio plus one, so there is always at least one shard.
num_shards = 1 + int(dataset_nbytes / max_shard_size)
print(f"Sharding into {num_shards} JSONL files.")

os.makedirs(shard_dir, exist_ok=True)
for shard_index in tqdm(range(num_shards)):
    # Contiguous sharding: each shard is one consecutive slice of the rows.
    shard = dset.shard(num_shards=num_shards, index=shard_index, contiguous=True)
    shard_path = f"{shard_dir}/docs-{shard_index:03d}.jsonl"
    shard.to_json(shard_path, orient="records", lines=True)

Sharding into 4 JSONL files.


  0%|          | 0/4 [00:00<?, ?it/s]

Creating json from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/7 [00:00<?, ?ba/s]

In [39]:
# Command-line style arguments for IndexCollection: read the JSONL shards from
# `shard_dir` and write the index to "../indexes/imdb".
# The thread count ("28") is hard-coded; adjust it to the machine running this.
indexing_args = [
    "-input", shard_dir,
    "-index", "../indexes/imdb",
    "-collection", "JsonCollection",
    "-threads", "28",
    "-language", "en",
    "-storePositions","-storeDocvectors","-storeContents",
]

In [40]:
# Run the Java indexer's `main` with the current `indexing_args` list.
JIndexCollection.main(indexing_args)

2023-02-22 21:17:10,852 INFO  [main] index.IndexCollection (IndexCollection.java:391) - Setting log level to INFO
2023-02-22 21:17:10,852 INFO  [main] index.IndexCollection (IndexCollection.java:394) - Starting indexer...
2023-02-22 21:17:10,852 INFO  [main] index.IndexCollection (IndexCollection.java:396) - DocumentCollection path: ../shards/imdb
2023-02-22 21:17:10,852 INFO  [main] index.IndexCollection (IndexCollection.java:397) - CollectionClass: JsonCollection
2023-02-22 21:17:10,852 INFO  [main] index.IndexCollection (IndexCollection.java:398) - Generator: DefaultLuceneDocumentGenerator
2023-02-22 21:17:10,852 INFO  [main] index.IndexCollection (IndexCollection.java:399) - Threads: 28
2023-02-22 21:17:10,853 INFO  [main] index.IndexCollection (IndexCollection.java:400) - Language: en
2023-02-22 21:17:10,853 INFO  [main] index.IndexCollection (IndexCollection.java:401) - Stemmer: porter
2023-02-22 21:17:10,853 INFO  [main] index.IndexCollection (IndexCollection.java:402) - Keep st

In [41]:
# Display the dataset summary again before training the tokenizer on it.
dset

Dataset({
    features: ['contents', 'id'],
    num_rows: 25000
})

In [49]:
def batch_iterator(dataset, text_column_name, batch_size):
    """Yield the values of one column of `dataset` in consecutive batches.

    Args:
        dataset: Object supporting `len()`, `select(indices)` and column
            access by name on the selected subset (e.g. a `datasets.Dataset`).
        text_column_name: Name of the column whose values are yielded.
        batch_size: Maximum number of rows per batch.

    Yields:
        The `text_column_name` values of each batch, in dataset order. The
        final batch is shorter when `len(dataset)` is not a multiple of
        `batch_size`.
    """
    num_rows = len(dataset)
    for start in range(0, num_rows, batch_size):
        # Clamp the upper bound so the last batch never asks for row indices
        # past the end of the dataset.
        stop = min(start + batch_size, num_rows)
        yield dataset.select(range(start, stop))[text_column_name]

In [44]:
# Target vocabulary size for the BPE tokenizer trained below.
VOCAB_SIZE = 25_000

# Normalization: NFKD Unicode decomposition followed by accent stripping.
unicode_normalizer = normalizers.NFKD()
accent_stripper_normalizer = normalizers.StripAccents()

# Byte-level pre-tokenization and the matching byte-level decoder.
bytelevel_pretokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True, use_regex=True)

bytelevel_decoder = decoders.ByteLevel(add_prefix_space=True, use_regex=True)

# Empty BPE model; its vocabulary and merges are learned by the trainer.
tokenizer = Tokenizer(BPE())

tokenizer.normalizer = normalizers.Sequence([unicode_normalizer, accent_stripper_normalizer])
tokenizer.pre_tokenizer = bytelevel_pretokenizer
tokenizer.decoder = bytelevel_decoder

# Seed the trainer with the full byte-level alphabet. Without it, byte symbols
# that never occur in the training text get no vocabulary entry, and since the
# BPE model has no unknown token they would be dropped when encoding new text
# such as queries.
trainer = BpeTrainer(
    vocab_size=VOCAB_SIZE,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    show_progress=True,
)

In [50]:
%%time
# Train the tokenizer on the "contents" column, fed in batches of 100 rows.
tokenizer.train_from_iterator(batch_iterator(dset, "contents", 100), trainer=trainer)




CPU times: user 50.2 s, sys: 8.88 s, total: 59.1 s
Wall time: 3.38 s


In [55]:
# Wrap the trained `tokenizers` object in a transformers tokenizer, which
# provides `push_to_hub`, then upload it to the "spacerini/bpe-imdb-25k" repo.
# NOTE(review): the upload presumably needs Hugging Face credentials with write
# access to that repo - confirm before re-running.
tokenizer_model = PreTrainedTokenizerFast(tokenizer_object=tokenizer, vocab_size=VOCAB_SIZE) #Wrap the tokenizer into a transformers tokenizer object to get access to push_to_hub
tokenizer_model.push_to_hub("spacerini/bpe-imdb-25k")

CommitInfo(commit_url='https://huggingface.co/spacerini/bpe-imdb-25k/commit/47c190adc9e9d7e3664145a14281d3eb5dd0ffc4', commit_message='Upload tokenizer', commit_description='', oid='47c190adc9e9d7e3664145a14281d3eb5dd0ffc4', pr_url=None, pr_revision=None, pr_num=None)

In [56]:
# Arguments for a second index over the same shards, written to
# "../indexes/bpe-imdb-25k". This run passes the uploaded tokenizer via
# "-analyzeWithHuggingFaceTokenizer" in place of the "-language" option.
# The thread count ("28") is hard-coded; adjust it to the machine running this.
indexing_args = [
    "-input", shard_dir,
    "-index", "../indexes/bpe-imdb-25k",
    "-collection", "JsonCollection",
    "-threads", "28",
    "-analyzeWithHuggingFaceTokenizer", "spacerini/bpe-imdb-25k",
    "-storePositions","-storeDocvectors","-storeContents",
]

In [57]:
# Run the Java indexer's `main` again, now with the tokenizer-based `indexing_args`.
JIndexCollection.main(indexing_args)

2023-02-22 21:30:15,415 INFO  [main] index.IndexCollection (IndexCollection.java:391) - Setting log level to INFO
2023-02-22 21:30:15,415 INFO  [main] index.IndexCollection (IndexCollection.java:394) - Starting indexer...
2023-02-22 21:30:15,415 INFO  [main] index.IndexCollection (IndexCollection.java:396) - DocumentCollection path: ../shards/imdb
2023-02-22 21:30:15,415 INFO  [main] index.IndexCollection (IndexCollection.java:397) - CollectionClass: JsonCollection
2023-02-22 21:30:15,415 INFO  [main] index.IndexCollection (IndexCollection.java:398) - Generator: DefaultLuceneDocumentGenerator
2023-02-22 21:30:15,415 INFO  [main] index.IndexCollection (IndexCollection.java:399) - Threads: 28
2023-02-22 21:30:15,416 INFO  [main] index.IndexCollection (IndexCollection.java:400) - Language: en
2023-02-22 21:30:15,416 INFO  [main] index.IndexCollection (IndexCollection.java:401) - Stemmer: porter
2023-02-22 21:30:15,416 INFO  [main] index.IndexCollection (IndexCollection.java:402) - Keep st