In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer
import numpy as np
import grain

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", add_bos_token=True)
sequence_length = 1024
# https://github.com/huggingface/nanotron/blob/7bc9923285a03069ebffe994379a311aceaea546/src/nanotron/data/processing.py#L47
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: np.concatenate(v) for k, v in examples.items()}
    total_length = len(concatenated_examples[next(iter(examples.keys()))])
    # WARNING: We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= sequence_length + 1:
        total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
    # Split by chunks of sequence_length.
    result = {
        k: [
            t[i : i + sequence_length + 1] for i in range(0, total_length - (sequence_length + 1), sequence_length)
        ]
        for k, t in concatenated_examples.items()
    }
    return result

def tokenize_and_group_texts(texts):
    tokenized_batch = tokenizer(texts, return_attention_mask=False, return_token_type_ids=False, return_tensors="np")
    grouped_batch = group_texts(tokenized_batch)
    return grouped_batch



# ds = load_dataset(*hf_name, split=split, streaming=True)
ds = load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train", num_proc=None, streaming=True)
ds = ds.map(
    tokenize_and_group_texts,
    input_columns=["text"],
    remove_columns=ds.column_names,
    batched=True,
)



In [3]:
hf_ds = ds

In [4]:
next(iter(hf_ds))

Token indices sequence length is longer than the specified maximum sequence length for this model (1048 > 1024). Running this sequence through the model will result in indexing errors


{'input_ids': array([  91, 7680,  278, ..., 2099,  286, 7977], shape=(1025,))}

In [5]:
class HFStreamingDataSource(grain.sources.RandomAccessDataSource):
    def __init__(self, iterable_ds):
        self._ds = iterable_ds
        self._it = None

    def __len__(self) -> int:
        return 10_000_000_000

    def __getitem__(self, record_key: int):
        if self._it is None:
            self._it = iter(self._ds)
        try:
            return next(self._it)
        except StopIteration:
            self._it = iter(self._ds)
            return next(self._it)

In [10]:
source = HFStreamingDataSource(hf_ds)

sampler = grain.samplers.IndexSampler(
    num_records=len(source),
    shuffle=False,
    seed=0,
)

class GetInputAndTarget(grain.transforms.Map):
  def map(self, x: int) -> int:
    return x['input_ids'][:-1], x['input_ids'][1:]

operations = []
operations.append(GetInputAndTarget())
operations.append(grain.transforms.Batch(2))

data_loader = grain.DataLoader(
    data_source=source,
    sampler=sampler,
    operations=operations,
    worker_count=0,
    read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=128),
)

In [11]:
for x in iter(data_loader):
    print(x)

(array([[  91, 7680,  278, ...,  257, 2099,  286],
       [7977, 9102,  326, ...,  340,  257, 4165]], shape=(2, 1024)), array([[ 7680,   278, 14206, ...,  2099,   286,  7977],
       [ 9102,   326, 16316, ...,   257,  4165,  2723]], shape=(2, 1024)))
(array([[ 2723,   286, 10014, ...,   486,   393,  5443],
       [   13,  5045, 31422, ...,   262,  1708,  8950]], shape=(2, 1024)), array([[  286, 10014,  1321, ...,   393,  5443,    13],
       [ 5045, 31422,  7423, ...,  1708,  8950,    25]], shape=(2, 1024)))
(array([[  25, 3594,   11, ...,  286,  616, 3662],
       [ 290,  314, 1138, ...,  284, 1309,  467]], shape=(2, 1024)), array([[3594,   11, 3999, ...,  616, 3662,  290],
       [ 314, 1138,  351, ..., 1309,  467,  286]], shape=(2, 1024)))
(array([[  286,   262,  8584, ...,   477,   938,  1285],
       [   13,   198,  2504, ...,   338,  4045, 10319]], shape=(2, 1024)), array([[  262,  8584,  3259, ...,   938,  1285,    13],
       [  198,  2504,  4753, ...,  4045, 10319, 10330]], sh

KeyboardInterrupt: 