
Distributed support #5369

Merged - lhoestq merged 9 commits into main from distributed-support on Jan 16, 2023

Conversation

@lhoestq (Member) commented Dec 16, 2022

To split your dataset across your training nodes, you can use the new datasets.distributed.split_dataset_by_node function:

import os
from datasets.distributed import split_dataset_by_node

# ds can be a map-style Dataset or an IterableDataset
ds = split_dataset_by_node(ds, rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]))
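
For context, RANK and WORLD_SIZE are environment variables set by PyTorch's distributed launchers. A hypothetical single-node, 4-process launch would look like:

torchrun --nproc_per_node=4 train.py

(train.py is a placeholder for your training script.)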

This works for both map-style datasets and iterable datasets.
The dataset is split for the node at rank rank in a pool of nodes of size world_size.

For map-style datasets:

Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
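
As a toy illustration of the chunking (a minimal sketch; the 8-row dataset and the exact chunk boundaries shown in the comments are assumptions based on the description above):

from datasets import Dataset
from datasets.distributed import split_dataset_by_node

ds = Dataset.from_dict({"idx": list(range(8))})  # toy map-style dataset
for rank in range(4):
    print(rank, split_dataset_by_node(ds, rank=rank, world_size=4)["idx"])
# Expected: each rank gets one contiguous chunk, e.g.
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
# 3 [6, 7]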

For iterable datasets:

If the dataset has a number of shards that is a multiple of world_size (i.e. if dataset.n_shards % world_size == 0),
then the shards are evenly assigned across the nodes, which is the most efficient.
Otherwise, each node keeps 1 example out of world_size, skipping the other examples.
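
A minimal sketch of checking the shard assignment (the parquet file pattern and the shard count of 128 are hypothetical):

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# Hypothetical: 128 parquet files loaded as a streaming (iterable) dataset
ds = load_dataset("parquet", data_files="data/*.parquet", split="train", streaming=True)
print(ds.n_shards)  # 128

# 128 % 4 == 0, so each of the 4 nodes is assigned 32 whole shards
ds_rank0 = split_dataset_by_node(ds, rank=0, world_size=4)
print(ds_rank0.n_shards)  # 32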

This can also be combined with a torch.utils.data.DataLoader if you want each node to use multiple workers to load the data.
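
For instance, a minimal sketch (the batch size and worker count are arbitrary; ds is the iterable dataset from the snippets above):

import os
from torch.utils.data import DataLoader
from datasets.distributed import split_dataset_by_node

ds = split_dataset_by_node(ds, rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]))
# The shards assigned to this node are further distributed among the workers
dataloader = DataLoader(ds, batch_size=32, num_workers=4)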

This also supports shuffling. At each epoch, the iterable dataset shards are reshuffled across all the nodes - you just have to call iterable_ds.set_epoch(epoch_number).
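
A minimal training-loop sketch of this (the seed, batch size, and epoch count are placeholders):

from torch.utils.data import DataLoader

ds = ds.shuffle(seed=42)  # an IterableDataset returned by split_dataset_by_node
dataloader = DataLoader(ds, batch_size=32)

for epoch in range(3):
    ds.set_epoch(epoch)  # reshuffle the shards across all nodes for this epoch
    for batch in dataloader:
        ...  # training step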

TODO:

  • docs for usage in PyTorch
  • unit tests
  • integration tests with torch.distributed.launch

Related to huggingface/transformers#20770
Close #5360

@HuggingFaceDocBuilderDev commented Dec 16, 2022

The documentation is not available anymore as the PR was closed or merged.

@lhoestq (Member, Author) commented Dec 20, 2022

Alright, all the tests are passing - this is ready for review.

@mariosasko (Collaborator) left a comment

One nit.

(Review comment on src/datasets/arrow_dataset.py - resolved)
@github-actions commented

Show benchmarks

[Automated benchmark report: new / old (diff) timings for PyArrow==6.0.0 and PyArrow==latest across benchmark_array_xd.json, benchmark_getitem_100B.json, benchmark_indices_mapping.json, benchmark_iterating.json, and benchmark_map_filter.json.]

@lhoestq (Member, Author) commented Jan 13, 2023

just added a note :)

@mariosasko (Collaborator) left a comment

Thanks, looks all good now!

@lhoestq merged commit 9991c74 into main on Jan 16, 2023
@lhoestq deleted the distributed-support branch on January 16, 2023 at 13:33
@rishabhm12 commented Jun 27, 2023

Hi @lhoestq,
Can you please shed some light on the following statement:

If the dataset has a number of shards that is a multiple of world_size (i.e. if dataset.n_shards % world_size == 0), then the shards are evenly assigned across the nodes, which is the most efficient. Otherwise, each node keeps 1 example out of world_size, skipping the other examples.

Let's assume I have 127 parquet files and world_size is 4; I was not able to fully comprehend the statement above. What does "each node keeps 1 example out of world_size, skipping the other examples" mean?
Thank you!

@lhoestq (Member, Author) commented Jun 27, 2023

If you have 128 parquet files, then dataset.n_shards % world_size == 0. In this case each worker can take care of 32 parquet files.

On the other hand, if you have dataset.n_shards % world_size != 0 (in your case, 127 files), then we can't assign the same number of files to each worker. This is an issue because it may under-utilize your GPUs at the end of training, since some workers would take longer to iterate over the dataset than others.

Therefore, in this case, all the workers iterate over all 127 parquet files, but each worker skips examples so that no two workers end up with duplicates. That's what "each node keeps 1 example out of world_size, skipping the other examples" means, and in your case it implies:

  • rank=0 will read the samples with idx=0, 4, 8 etc.
  • rank=1 will read the samples with idx=1, 5, 9 etc.
  • rank=2 will read the samples with idx=2, 6, 10 etc.
  • rank=3 will read the samples with idx=3, 7, 11 etc.
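
A minimal sketch of this round-robin skipping (the distribute_examples helper is illustrative, not the library's actual implementation):

from itertools import islice

def distribute_examples(examples, rank, world_size):
    # Rank r keeps example indices r, r + world_size, r + 2 * world_size, ...
    yield from islice(examples, rank, None, world_size)

examples = range(12)  # toy stream of example indices
for rank in range(4):
    print(rank, list(distribute_examples(examples, rank, world_size=4)))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 10]
# 3 [3, 7, 11]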

@rishabhm12

Thanks a lot @lhoestq, this helps!

@KatarinaYuan

Hi, in the case above, if we use keep_in_memory=True for a Dataset, then we still need to read the dataset n times if we use DDP on n GPUs (1 node), right? That means we need n times the memory. Is there any way to load the data only once, to save memory?

@lhoestq (Member, Author) commented Jun 29, 2023

Dataset objects are memory-mapped from disk, so they use almost no RAM (only the current batch).

Also, they are perfectly sharded by split_dataset_by_node, so in total the data is read exactly once when using DDP.
You can also achieve the same thing with a PyTorch DistributedSampler for DDP instead of using split_dataset_by_node.
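
For reference, a minimal sketch of the DistributedSampler alternative for a map-style Dataset (ds, the batch size, and the epoch count are placeholders):

import os
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(
    ds,  # a map-style datasets.Dataset
    num_replicas=int(os.environ["WORLD_SIZE"]),
    rank=int(os.environ["RANK"]),
    shuffle=True,
)
dataloader = DataLoader(ds, batch_size=32, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reshuffle the sample assignment each epoch
    for batch in dataloader:
        ...  # training step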

@KatarinaYuan

Hi, please correct me if I'm mistaken about anything:

  1. A Dataset with keep_in_memory=True explicitly pre-loads the data into memory, instead of reading from disk via the memory map for every batch. The former should be faster than the latter.
  2. When using DDP, before passing the Dataset object to split_dataset_by_node or combining it with a DistributedSampler, every process still needs to pre-load the entire dataset into memory (when keep_in_memory=True) and then select its chunk of indices from the loaded data.

Generally, the dilemma I'm facing is:
Suppose we have a dataset of around 120GB, and we want to use DistributedLengthGroupedSampler to optimize batching. When using DDP with keep_in_memory=True, every process loads 120GB, which is not acceptable. For now, I have turned off keep_in_memory and am trying to increase the number of DataLoader workers to get better pipelining.

But is it possible to load the 120GB once onto a node with 4 A100s (which has around 4*120GB of memory) and have each process read from this shared data in memory? Theoretically, that should be faster?

@lhoestq (Member, Author) commented Jun 30, 2023

Feel free to ask your questions on the forum if you don't mind; this way, the discussions may be useful to other people ;)

Successfully merging this pull request may close these issues:

IterableDataset returns duplicated data using PyTorch DDP