
'with_format' is extremely slow when used together with 'interleave_datasets' or 'shuffle' on IterableDatasets #6637

Open

tobycrisford opened this issue Feb 1, 2024 · 1 comment

Describe the bug

If you:

  1. Interleave two iterable datasets together with the interleave_datasets function, or shuffle an iterable dataset
  2. Set the output format to torch tensors with .with_format('torch')

Then iterating through the dataset becomes over 100x slower than it is without the torch formatting.

Steps to reproduce the bug

import datasets
import torch
from tqdm import tqdm

rand_a = torch.randn(3,224,224)
rand_b = torch.randn(3,224,224)

a = torch.stack([rand_a] * 1000)
b = torch.stack([rand_b] * 1000)

features = datasets.Features({"tensor": datasets.Array3D(shape=(3,224,224), dtype="float32")})
ds_a = datasets.Dataset.from_dict({"tensor": a}, features=features).to_iterable_dataset()
ds_b = datasets.Dataset.from_dict({"tensor": b}, features=features).to_iterable_dataset()

# Iterating through either dataset with torch formatting is really fast (2000it/s on my machine)
for example in tqdm(ds_a.with_format('torch')):
    pass

# Iterating through either dataset shuffled is also pretty fast (100it/s on my machine)
for example in tqdm(ds_a.shuffle()):
    pass

# Iterating through this interleaved dataset is pretty fast (200it/s on my machine)
ds_fast = datasets.interleave_datasets([ds_a, ds_b])
for example in tqdm(ds_fast):
    pass

# Iterating through either dataset with torch formatting *after shuffling* is really slow... (<2it/s on my machine)
for example in tqdm(ds_a.shuffle().with_format('torch')):
    pass

# Iterating through this torch formatted interleaved dataset is also really slow (<2it/s on my machine)...
ds_slow = datasets.interleave_datasets([ds_a, ds_b]).with_format('torch')
for example in tqdm(ds_slow):
    pass

# Even doing this is way faster!! (70it/s on my machine)
for example in tqdm(ds_fast):
    test = torch.tensor(example['tensor'])

Expected behavior

Applying torch formatting to the interleaved dataset shouldn't increase the time taken to iterate through the dataset by very much, since even explicitly converting every example is over 70x faster than calling .with_format('torch').
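
Until this is fixed, a possible stopgap (my suggestion, not from the thread, building on the last timing above) is to leave the dataset unformatted and do the conversion in a DataLoader collate function, so training code still receives tensors:

import torch
from torch.utils.data import DataLoader

# Hypothetical workaround: skip .with_format('torch') on the interleaved
# dataset and convert each example to a tensor in collate instead.
# Reuses ds_a and ds_b from the reproduction script above.
ds = datasets.interleave_datasets([ds_a, ds_b])

def collate(examples):
    return torch.stack([torch.as_tensor(ex["tensor"]) for ex in examples])

loader = DataLoader(ds, batch_size=8, collate_fn=collate)
for batch in loader:
    pass  # batch is a (8, 3, 224, 224) float tensor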

Environment info

  • datasets version: 2.16.1
  • Platform: Linux-6.5.0-15-generic-x86_64-with-glibc2.38
  • Python version: 3.11.6
  • huggingface_hub version: 0.20.3
  • PyArrow version: 15.0.0
  • Pandas version: 2.2.0
  • fsspec version: 2023.10.0
@lhoestq (Member) commented Feb 5, 2024

The "torch" formatting is usually fast because we do zero-copy conversion from the Arrow data on your disk to Torch tensors. However IterableDataset shuffling seems to do data copies that slow down the pipeline, and it shuffles python objects instead of Arrow data.

To fix this we need to implement BufferShuffledExamplesIterable.iter_arrow(), the same as the regular BufferShuffledExamplesIterable.__iter__() except that it yields Arrow tables:

class BufferShuffledExamplesIterable(_BaseExamplesIterable):
    def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generator: np.random.Generator):
        super().__init__()
        self.ex_iterable = ex_iterable
        self.buffer_size = buffer_size
        self.generator = generator
        # TODO(QL): implement iter_arrow
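
A minimal sketch of what that method might look like (my sketch, not the library's implementation; it assumes the wrapped iterable exposes an iter_arrow() that yields (key, pa.Table) pairs, mirroring how __iter__ yields (key, example) pairs):

    # Hypothetical sketch of the missing method, inside
    # BufferShuffledExamplesIterable; not the library's actual code.
    def iter_arrow(self):
        from copy import deepcopy

        rng = deepcopy(self.generator)
        buffer = []  # holds (key, single-row pa.Table) pairs
        for key, pa_table in self.ex_iterable.iter_arrow():
            for i in range(len(pa_table)):
                row = pa_table.slice(i, 1)  # zero-copy one-row slice
                if len(buffer) == self.buffer_size:
                    j = rng.integers(0, self.buffer_size)
                    yield buffer[j]          # emit a random buffered row
                    buffer[j] = (key, row)
                else:
                    buffer.append((key, row))
        rng.shuffle(buffer)  # drain the remaining rows in random order
        yield from buffer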
