
Support DataLoader with num_workers > 0 in streaming mode #4375

Merged · 24 commits into master from parallel-torch-iterable-dataset · Jun 10, 2022

Conversation

@lhoestq (Member) commented May 19, 2022

Issue

It's currently not possible to properly stream a dataset using multiple torch.utils.data.DataLoader workers.

Solution in this PR

I fixed the underlying issues to enable passing an IterableDataset to a torch.utils.data.DataLoader with num_workers > 0.

I also had to shard the IterableDataset to give each worker a shard, otherwise data would be duplicated. This is implemented in TorchIterableDataset.__iter__ and uses the new IterableDataset._iter_shard(shard_idx) method.
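
For example, with this change a streaming dataset can be used roughly like this (a sketch: the dataset name, batch size, and worker count are placeholders, and the usual with_format("torch") wrapper is assumed):

    from datasets import load_dataset
    from torch.utils.data import DataLoader

    # Stream a dataset and wrap it for PyTorch: each DataLoader worker
    # then iterates over its own subset of the dataset's shards.
    ds = load_dataset("c4", "en", split="train", streaming=True)
    ds = ds.with_format("torch")

    dataloader = DataLoader(ds, batch_size=8, num_workers=4)
    for batch in dataloader:
        ...  # no duplicated examples across the 4 workers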

I also had to make a few changes to the patching that enables streaming in dataset scripts:

  • the patches are now always applied, not just in streaming mode; they're applied when a builder is instantiated
  • I improved the patching to also handle renamed modules and attributes (e.g. pandas vs. pd)
  • I grouped all the pathlib.Path patches into a class xPath, so that Path outside of dataset scripts stays unchanged; other than that, I didn't change the content of the extended Path methods for streaming
  • I fixed a bug in the pd.read_csv patch: the file wasn't opened in "rb" mode, which caused some datasets to fail in streaming mode, and compression inference was missing (see the sketch after this list)
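
As a rough sketch of the fixed pd.read_csv patch (xopen and infer_compression are assumed helper names, not the exact implementation):

    import pandas as pd

    def xpandas_read_csv(filepath_or_buffer, **kwargs):
        # Open the possibly-remote file in binary mode ("rb") and infer the
        # compression from the extension, since pandas can't infer it from a
        # streamed file object that carries no filename.
        if kwargs.get("compression") is None:
            kwargs["compression"] = infer_compression(filepath_or_buffer)  # hypothetical helper
        return pd.read_csv(xopen(filepath_or_buffer, "rb"), **kwargs)  # xopen: streaming-aware open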

A few details regarding fsspec in multiprocessing

From fsspec/filesystem_spec#963 (comment):

Non-async instances might be safe in the forked child, if they hold no open files/sockets etc.; I'm not sure any implementations pass this test!
If any async instance has been created, the newly forked processes must:

  1. discard references to locks, threads and event loops and make new ones
  2. not use any async fsspec instances from the parent process
  3. clear all class instance caches

Therefore, in a DataLoader worker, I clear the references to the loop and thread (1). We should already be fine for (2) and (3) since we don't use fsspec class instances from the parent process.
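
Concretely, the cleanup for (1) can look like this minimal sketch, which drops fsspec's module-level event loop and IO thread so that fresh ones are created in the worker (the function name is an assumption):

    import fsspec.asyn

    def _reset_fsspec_loop():  # hypothetical name
        # Discard the loop and thread inherited from the parent process;
        # fsspec lazily creates new ones in this worker on first use.
        fsspec.asyn.iothread[0] = None
        fsspec.asyn.loop[0] = None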

Fix #3950
Fix #3951

TODO:

  • fix tests

@HuggingFaceDocBuilderDev commented May 19, 2022

The documentation is not available anymore as the PR was closed or merged.

@lhoestq lhoestq marked this pull request as ready for review June 7, 2022 11:56
@lhoestq lhoestq requested a review from mariosasko June 7, 2022 11:56
@lhoestq (Member, Author) commented Jun 7, 2022

Alright, this is finally ready for review! It's quite long, I'm sorry, but it's not easy to disentangle everything ^^'

The main additions are in

  • src/datasets/formatting/dataset_wrappers/torch_iterable_dataset.py
  • src/datasets/iterable_dataset.py
  • src/datasets/utils/patching.py



def xpathrglob(path, pattern, **kwargs):
    """Rglob function for argument of type :obj:`~pathlib.Path` that supports both local paths and remote URLs."""

class xPath(type(Path())):
@lhoestq (Member, Author):

Many changes in this file are just about moving functions inside this class.

For example, I moved xpathrglob to xPath.rglob.
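
For illustration, a hypothetical before/after (the chained zip:// URL is a placeholder):

    from pathlib import Path

    # Before: a standalone function taking a Path-like argument
    files = list(xpathrglob(Path("zip://::https://host/data.zip"), "*.csv"))

    # After: a method on xPath, which extends pathlib.Path
    files = list(xPath("zip://::https://host/data.zip").rglob("*.csv"))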

Comment on lines +31 to +49
if worker_info.id == 0 and self.n_shards < worker_info.num_workers:
    logger.warning(
        f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={self.n_shards}). "
        f"Stopping dataloader workers [{self.n_shards}...{worker_info.num_workers - 1}]."
    )
    logger.warning(
        f"To parallelize data loading, we give each process some shards (or data sources) to process. "
        f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={self.n_shards}. "
        f"To enable more parallelism, please split the dataset in more files than {self.n_shards}."
    )
# split workload
shards_indices = list(range(worker_info.id, self.n_shards, worker_info.num_workers))
if shards_indices:
    logger.debug(
        f"dataloader worker#{worker_info.id}: Starting to iterate over {len(shards_indices)}/{self.n_shards} shards."
    )
    for shard_idx in shards_indices:
        for key, example in self._iter_shard(shard_idx):
            yield self._apply_feature_types(example)
@lhoestq (Member, Author):

This is where we shard the iterable dataset when it's passed to a DataLoader worker.
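
The shard-to-worker assignment above is a plain round-robin split; the same logic as a standalone sketch:

    def assign_shards(n_shards: int, num_workers: int) -> list[list[int]]:
        # Worker i gets shards i, i + num_workers, i + 2 * num_workers, ...
        # so every shard is read by exactly one worker.
        return [list(range(worker_id, n_shards, num_workers)) for worker_id in range(num_workers)]

    # With 5 shards and 2 workers:
    assert assign_shards(5, 2) == [[0, 2, 4], [1, 3]]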

    ex_iterable = self._ex_iterable.shuffle_data_sources(self._effective_generator())
else:
    ex_iterable = self._ex_iterable
yield from ex_iterable.shard_data_sources(shard_idx)
@lhoestq (Member, Author):

This is what is called when iterating in a DataLoader worker. The idea is to iterate over only one shard out of the self.n_shards available.

# xml.etree.ElementTree
for submodule in ["ElementTree", "ET"]:
    patch_submodule(module, f"{submodule}.parse", wrap_auth(xet_parse)).start()
patch_submodule(module, "pathlib.Path", xPath).start()
@lhoestq (Member, Author) commented Jun 7, 2022:

Now we can just pass the source object to patch in the module, and it will be patched even if the attribute alone is imported, or even if a parent module has been imported and renamed (see test_patching.py for a list of all supported cases; I probably have to add a docstring to patch_submodule as well).
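
For instance, all of these import styles inside a (hypothetical) dataset script now resolve to the patched object:

    import pandas                         # plain module import
    import pandas as pd                   # renamed module
    from pandas import read_csv           # attribute imported alone
    from pandas import read_csv as load   # renamed attribute

    # patch_submodule(script_module, "pandas.read_csv", xpandas_read_csv).start()
    # then patches pandas.read_csv, pd.read_csv, read_csv, and load within script_module.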

@mariosasko (Collaborator) left a comment:

This is 🔥. Just two comments:

(Also, thanks for the review instructions/code clarifications.)

src/datasets/iterable_dataset.py (review comment outdated, resolved)
src/datasets/utils/patching.py (review comment outdated, resolved)
@lhoestq (Member, Author) commented Jun 9, 2022

Added some comments and an error when lists have different lengths for sharding :)
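
A minimal sketch of that check, assuming the data sources are the list-valued kwargs of the examples generator (names are assumptions):

    def check_shardable(gen_kwargs: dict) -> None:
        # All list-valued kwargs must have the same length, otherwise
        # splitting them into per-worker shards would be ambiguous.
        lengths = {key: len(value) for key, value in gen_kwargs.items() if isinstance(value, list)}
        if len(set(lengths.values())) > 1:
            raise RuntimeError(f"Sharding is ambiguous: lists in gen_kwargs have different lengths: {lengths}")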

@mariosasko (Collaborator) commented:

Let's resolve the merge conflict and the CI error (if it's related to the changes), and I can review the PR again.

@lhoestq (Member, Author) commented Jun 10, 2022

Feel free to review again :) The CI failure is unrelated to this PR and will be fixed by #4472 (the Hub now returns 401 instead of 404 for unauthenticated requests to non-existing repos).

@mariosasko (Collaborator) left a comment:

All looks good now! Thanks!

src/datasets/iterable_dataset.py (review comment outdated, resolved)
@lhoestq (Member, Author) commented Jun 10, 2022

CI failures are unrelated to this PR - merging :)

(The CI failures are a mix of pip install failures and Hub failures.)

@lhoestq lhoestq merged commit ab7d304 into master Jun 10, 2022
@lhoestq lhoestq deleted the parallel-torch-iterable-dataset branch June 10, 2022 20:47
@justheuristic commented:

@lhoestq you're our hero :)
