# Imports

In [2]:
import lance
import pyarrow as pa

from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm  # for progress tracking

import warnings
# Suppress all Python warnings to keep the notebook output uncluttered.
# NOTE(review): this hides *every* warning, including potentially useful ones;
# consider filtering by category or module instead of ignoring everything.
warnings.simplefilter('ignore')

# Tokenizer and Dataset

For the sake of simplicity, we are using the `gpt2` tokenizer and the `wikitext-103-raw-v1` dataset.

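Note that we load the dataset in `streaming` mode. Regardless of its size, the dataset is not downloaded all at once; instead, samples are fetched on demand while we write. We also shuffle the stream with a fixed seed (`1337`) for reproducibility. For streaming datasets this is an approximate, buffer-based shuffle rather than a full permutation.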

We also define a function to tokenize each sample and return the `input_ids`.

In [3]:
# Pretrained GPT-2 tokenizer, downloaded from the Hugging Face Hub
tokenizer = AutoTokenizer.from_pretrained('gpt2')

# Stream the train split (nothing is downloaded up front) and shuffle it
# with a fixed seed so the sample order is reproducible
dataset = (
    load_dataset('wikitext', 'wikitext-103-raw-v1', streaming=True)['train']
    .shuffle(seed=1337)
)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

In [4]:
def tokenize(sample, field='text'):
    """Tokenize ``sample[field]`` with the module-level ``tokenizer``.

    Returns the ``input_ids`` entry of the tokenizer output (a list of token
    ids when ``sample[field]`` is a single string).
    """
    encoded = tokenizer(sample[field])
    return encoded['input_ids']

# Process Samples

This is the most important function here. We iterate over the samples in the 🤗 dataset and skip empty ones. Each remaining sample is tokenized, and its token ids are yielded as a pyarrow `RecordBatch` (one token per row). We stop once `num_samples` samples have been processed.

In [5]:
def _to_record_batch(token_ids):
    """Wrap a flat list of token ids in a single-column ``input_ids`` RecordBatch."""
    # An explicit int64 type keeps every batch consistent with the int64 schema,
    # even when ``token_ids`` is empty (inference would otherwise yield ``null``).
    return pa.RecordBatch.from_arrays(
        [pa.array(token_ids, type=pa.int64())],
        names=["input_ids"]
    )


def process_samples(dataset, num_samples=100_000, field='text', samples_per_batch=1):
    """Tokenize samples from ``dataset`` and yield them as pyarrow RecordBatches.

    Each output row holds a single token id in the ``input_ids`` column, so the
    tokens of consecutive samples are laid out back to back.

    Args:
        dataset: Iterable of dict-like samples (e.g. a streaming 🤗 dataset).
        num_samples: Number of non-empty samples to tokenize before stopping.
        field: Key of the text column in each sample.
        samples_per_batch: Number of samples packed into each yielded
            RecordBatch. The default of 1 yields one batch per sample; larger
            values produce fewer, bigger batches, which are cheaper to write.

    Yields:
        pa.RecordBatch with a single int64 ``input_ids`` column.

    Raises:
        ValueError: if ``samples_per_batch`` is smaller than 1 (raised on the
            first ``next()`` call, since this is a generator).
    """
    if samples_per_batch < 1:
        raise ValueError(f"samples_per_batch must be >= 1, got {samples_per_batch}")

    current_sample = 0
    pending_tokens = []   # token ids of samples not yet emitted
    pending_samples = 0   # number of samples covered by pending_tokens

    # The context manager closes the progress bar when the generator finishes,
    # is closed early by its consumer, or is interrupted by an exception.
    with tqdm(total=num_samples) as prog_bar:
        for sample in dataset:
            # We want to stop at num_samples number of samples
            if current_sample >= num_samples:
                break

            # wikitext has some empty strings so we skip them and don't count this sample
            if not sample[field]:
                continue

            # Tokenize the current sample and buffer its tokens
            pending_tokens.extend(tokenize(sample, field))
            pending_samples += 1

            # Increment the counter and update progress bar
            current_sample += 1
            prog_bar.update(1)

            # Emit a batch once enough samples have been buffered
            if pending_samples >= samples_per_batch:
                yield _to_record_batch(pending_tokens)
                pending_tokens = []
                pending_samples = 0

        # Flush the final, partially filled batch (only when samples_per_batch > 1)
        if pending_samples:
            yield _to_record_batch(pending_tokens)

# Writing the dataset to disk

Now that our processing function is ready, we define a schema that tells pyarrow what data to expect in the table: a single `int64` column named `input_ids`. We then create a `RecordBatchReader` from that schema and the generator returned by `process_samples`, which yields the RecordBatches.

Finally, we pass that reader to `lance.write_dataset`, which streams the record batches to disk in the fast and efficient Lance format.

That's it!

In [6]:
# One int64 column: process_samples puts a single token id in each row
schema = pa.schema([("input_ids", pa.int64())])

# RecordBatches are pulled lazily from the generator while lance writes, so
# samples are tokenized on the fly rather than all up front.
# NOTE: the reader (and its generator) can be consumed only once; re-run this
# cell before re-running the write cell.
batch_stream = process_samples(dataset, num_samples=100_000, field='text')
reader = pa.RecordBatchReader.from_batches(schema, batch_stream)

In [7]:
# Write the dataset to disk.
# mode="overwrite" keeps this cell re-runnable: the default mode ("create")
# raises if wikitext_100K.lance already exists, which breaks Restart & Run All.
# NOTE: `reader` is single-use, so re-run the previous cell before this one.
lance.write_dataset(
    reader,
    "wikitext_100K.lance",
    schema,
    mode="overwrite",
)

100%|██████████| 100000/100000 [02:00<00:00, 830.09it/s]


<lance.dataset.LanceDataset at 0x7fa0b74b4c40>

## Sanity check

Let's load our newly created dataset and see how many tokens we have in our dataset.

Note that each row holds a single token, so the row count is the total number of tokens in the dataset, not the number of samples taken from the original 🤗 dataset.

In [8]:
# Re-open the written dataset from disk. Each row holds one token id,
# so the row count equals the total number of tokens.
ds = lance.dataset('wikitext_100K.lance')
n_tokens = ds.count_rows()
print(f"Total tokens in the dataset of 100K wikitext samples: {n_tokens:,d}")

Total tokens in the dataset of 100K wikitext samples: 10,007,854
