Describe the bug

If you:

1. Interleave two iterable datasets together with the `interleave_datasets` function, or shuffle an iterable dataset
2. Set the output format to torch tensors with `.with_format('torch')`
Then iterating through the dataset becomes over 100x slower than if you don't apply the torch formatting.
Steps to reproduce the bug
```python
import datasets
import torch
from tqdm import tqdm

rand_a = torch.randn(3, 224, 224)
rand_b = torch.randn(3, 224, 224)
a = torch.stack([rand_a] * 1000)
b = torch.stack([rand_b] * 1000)

features = datasets.Features({"tensor": datasets.Array3D(shape=(3, 224, 224), dtype="float32")})
ds_a = datasets.Dataset.from_dict({"tensor": a}, features=features).to_iterable_dataset()
ds_b = datasets.Dataset.from_dict({"tensor": b}, features=features).to_iterable_dataset()

# Iterating through either dataset with torch formatting is really fast (2000 it/s on my machine)
for example in tqdm(ds_a.with_format('torch')):
    pass

# Iterating through either dataset shuffled is also pretty fast (100 it/s on my machine)
for example in tqdm(ds_a.shuffle()):
    pass

# Iterating through this interleaved dataset is pretty fast (200 it/s on my machine)
ds_fast = datasets.interleave_datasets([ds_a, ds_b])
for example in tqdm(ds_fast):
    pass

# Iterating through either dataset with torch formatting *after shuffling* is really slow... (<2 it/s on my machine)
for example in tqdm(ds_a.shuffle().with_format('torch')):
    pass

# Iterating through this torch formatted interleaved dataset is also really slow (<2 it/s on my machine)...
ds_slow = datasets.interleave_datasets([ds_a, ds_b]).with_format('torch')
for example in tqdm(ds_slow):
    pass

# Even doing this is way faster!! (70 it/s on my machine)
for example in tqdm(ds_fast):
    test = torch.tensor(example['tensor'])
```
Expected behavior
Applying torch formatting to the interleaved dataset shouldn't increase the time taken to iterate through the dataset by very much, since even explicitly converting every example is over 70x faster than calling `.with_format('torch')`.
The "torch" formatting is usually fast because we do zero-copy conversion from the Arrow data on your disk to Torch tensors. However IterableDataset shuffling seems to do data copies that slow down the pipeline, and it shuffles python objects instead of Arrow data.
To fix this we need to implement `BufferShuffledExamplesIterable.iter_arrow()` (same as the regular `BufferShuffledExamplesIterable.__iter__()`, but yielding Arrow tables instead of python examples).
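For reference, a buffered shuffle that keeps rows in Arrow format might look roughly like the sketch below. This is a hand-written illustration of the idea, not the actual `datasets` implementation; the function name `iter_arrow_shuffled` and its signature are made up:

```python
# Sketch only (NOT the datasets implementation): buffer rows as zero-copy
# Arrow table slices and yield them in random order, so examples never get
# converted to python objects before shuffling.
import random
from typing import Iterator, Optional

import pyarrow as pa


def iter_arrow_shuffled(
    tables: Iterator[pa.Table], buffer_size: int, seed: Optional[int] = None
) -> Iterator[pa.Table]:
    rng = random.Random(seed)
    buffer: list = []
    for table in tables:
        for i in range(table.num_rows):
            buffer.append(table.slice(i, 1))  # Table.slice() is zero-copy
            if len(buffer) == buffer_size:
                # swap a random buffered row with the newest one and yield it
                j = rng.randrange(buffer_size)
                buffer[j], out = buffer[-1], buffer[j]
                buffer.pop()
                yield out
    rng.shuffle(buffer)  # flush whatever remains once the input is exhausted
    yield from buffer
```

With rows kept as Arrow slices, the downstream "torch" formatting can stay on its fast zero-copy path.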
Environment info

- `datasets` version: 2.16.1
- `huggingface_hub` version: 0.20.3
- `fsspec` version: 2023.10.0