
[Resumable IterableDataset] Add IterableDataset state_dict #6658

Merged
merged 29 commits into from
Jun 3, 2024

Conversation

lhoestq
Member

@lhoestq lhoestq commented Feb 11, 2024

A simple implementation of a mechanism to resume an IterableDataset.
This is WIP and untested.

Example:

from datasets import Dataset, concatenate_datasets


ds = Dataset.from_dict({"a": range(5)}).to_iterable_dataset(num_shards=3)
ds = concatenate_datasets([ds] * 2)

print(f"{ds.state_dict()=}")
for i, example in enumerate(ds):
    print(example)
    if i == 6:
        state_dict = ds.state_dict()
        print("checkpoint")
ds.load_state_dict(state_dict)
print(f"resuming from checkpoint {ds.state_dict()=}")
for example in ds:
    print(example)

returns

ds.state_dict()={'ex_iterable_idx': 0, 'ex_iterables': [{'shard_idx': 0, 'shard_example_idx': 0}, {'shard_idx': 0, 'shard_example_idx': 0}]}
{'a': 0}
{'a': 1}
{'a': 2}
{'a': 3}
{'a': 4}
{'a': 0}
{'a': 1}
checkpoint
{'a': 2}
{'a': 3}
{'a': 4}
resuming from checkpoint ds.state_dict()={'ex_iterable_idx': 1, 'ex_iterables': [{'shard_idx': 3, 'shard_example_idx': 0}, {'shard_idx': 0, 'shard_example_idx': 2}]}
{'a': 2}
{'a': 3}
{'a': 4}
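The resumption logic above can be pictured with a toy, self-contained sketch (plain Python, not the actual datasets implementation): each shard iterator tracks a shard index and an in-shard example index, `state_dict()` snapshots them, and iteration re-enters from the recorded position.

```python
import copy

class ResumableShardedIterable:
    """Toy sketch of a resumable sharded iterable (not the datasets internals)."""

    def __init__(self, shards):
        self.shards = shards  # list of lists of examples
        self._state_dict = {"shard_idx": 0, "shard_example_idx": 0}

    def state_dict(self):
        # Return a copy so callers get a snapshot, not the live state
        return copy.deepcopy(self._state_dict)

    def load_state_dict(self, state_dict):
        self._state_dict.update(copy.deepcopy(state_dict))

    def __iter__(self):
        # Resume from the recorded position
        while self._state_dict["shard_idx"] < len(self.shards):
            shard = self.shards[self._state_dict["shard_idx"]]
            while self._state_dict["shard_example_idx"] < len(shard):
                example = shard[self._state_dict["shard_example_idx"]]
                self._state_dict["shard_example_idx"] += 1
                yield example
            self._state_dict["shard_idx"] += 1
            self._state_dict["shard_example_idx"] = 0

ds = ResumableShardedIterable([[0, 1], [2, 3], [4]])
seen = []
for i, x in enumerate(ds):
    seen.append(x)
    if i == 2:
        ckpt = ds.state_dict()  # {'shard_idx': 1, 'shard_example_idx': 1}
        break

ds.load_state_dict(ckpt)
resumed = list(ds)
print(seen, resumed)  # [0, 1, 2] [3, 4]
```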

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@bwanglzu

would be nice to have this feature in the new dataset release!

@lhoestq
Member Author

lhoestq commented Apr 11, 2024

Before finalising this I'd like to make sure the philosophy makes sense for other libraries, like accelerate for example.

cc @muellerzr I'd love your feedback on this one
cc @LysandreJik also if you think other people should take a look

@muellerzr muellerzr left a comment

Overall I think this looks like a very nice API decision, and super easy for us to bring into Accelerate as part of load_state. Will be nice to not have to use skip_batches if a user is using an IterableDataset.

One design question though: what's the logic behind self._state_dict rather than having it all be state_dict?

Truly private attributes don't exist in Python, so what's the aim in doing that here and having state_dict be a passthrough to it? (If this is a common design pattern over in datasets, that's okay.)

@lhoestq
Member Author

lhoestq commented Apr 15, 2024

One design question though: what's the logic behind self._state_dict rather than having it all be state_dict?

The _state_dict is the internal object that is updated in-place while you iterate on the dataset.

We need to copy it every time the user accesses it.

Otherwise we would get

state_dict = ds.state_dict()
for x in ds:
    assert ds.state_dict() == state_dict  # and actually `assert ds.state_dict() is state_dict`

The state is updated in-place since it's made of dictionaries that are shared with the steps in the IterableDataset pipeline.

@muellerzr

What do you think of making it a full property with a docstring explicitly stating users shouldn’t call/modify it directly?

I can imagine some exploratory users getting curious

@lhoestq
Member Author

lhoestq commented Apr 15, 2024

I don't think users read docstrings of properties that often. What about explaining the logic in the .state_dict() docstring? This also feels aligned with the way .state_dict() and .load_state_dict() work in pytorch (you should use load_state_dict to load a modified copy of the state dict).

@muellerzr

Sure, I can agree with that!

@muellerzr

Just a small note mentioning that it returns a copy of the state dict should be enough imo

@samsja

samsja commented May 7, 2024

Looking forward as well to this PR being merged!

@fyubang

fyubang commented May 20, 2024

I don't think users read docstrings of properties that often. What about explaining the logic in the .state_dict() docstring ? This also feels aligned with the way .state_dict() and .load_state_dict() works in pytorch (you should use load_state_dict to load a modified copy of the state dict)

Hi, I'm experimenting with LLM pretraining using your code. I found that the time to resume an iterable dataset can be reduced to 5% of the original (my streaming pipeline includes tokenization), but I'm not sure if I'm using it correctly. Could you help me check it? Thanks.

import json
import os

class CustomTrainer(Trainer):
    def _save_rng_state(self, output_dir):
        super()._save_rng_state(output_dir)
        if self.args.should_save:
            # Save the iterable dataset state alongside the RNG state
            with open(os.path.join(output_dir, 'iterable_data_state_dict.json'), 'w', encoding='utf-8') as fo:
                json.dump(self.train_dataset.state_dict(), fo, ensure_ascii=False)

dataset = <An IterableDataset constructed by (interleave, map(tokenization))>
last_ckpt_iterable_data_state_dict_file_path = os.path.join(training_args.resume_from_checkpoint, 'iterable_data_state_dict.json')
if os.path.exists(last_ckpt_iterable_data_state_dict_file_path) and finetuning_args.load_iterable_state_dict:
    if not training_args.ignore_data_skip:
        raise ValueError(f'Found `iterable_data_state_dict_file_path`: `{last_ckpt_iterable_data_state_dict_file_path}`. Please set `ignore_data_skip=True` to skip tokenization.')
    with open(last_ckpt_iterable_data_state_dict_file_path) as f:
        last_ckpt_iterable_data_state_dict = json.load(f)
    dataset.load_state_dict(last_ckpt_iterable_data_state_dict)
    logger.info(f'Loading `iterable_data_state_dict` from {last_ckpt_iterable_data_state_dict_file_path}')

@lhoestq
Member Author

lhoestq commented May 21, 2024

it sounds good to me :)

@uygnef

uygnef commented May 24, 2024

@lhoestq Hi, if I set prefetch, does this dataset work well?

@lhoestq
Member Author

lhoestq commented May 24, 2024

It does work well if you prefetch and then resume from a state, but you might lose the samples that were in the prefetch buffer of the DataLoader (which could be acceptable in some circumstances).

Fortunately we're about to ship an integration with the new StatefulDataLoader from torchdata which can help on this matter :)
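The prefetch caveat can be illustrated with a toy, pure-Python sketch (not the actual DataLoader or datasets code; names like `numbered_stream` are made up): a prefetching loader pulls items ahead of the consumer, so a checkpoint of the dataset's own state records a position that is already past whatever is sitting in the buffer.

```python
from collections import deque

def numbered_stream(n, state):
    # Toy source that advances a shared state dict as it yields
    while state["idx"] < n:
        value = state["idx"]
        state["idx"] += 1
        yield value

state = {"idx": 0}
stream = numbered_stream(6, state)
buffer = deque()
consumed = []

# A prefetching loader pulls 2 items ahead of the consumer
for _ in range(2):
    buffer.append(next(stream))

# Consumer takes one item; the loader refills the buffer
consumed.append(buffer.popleft())
buffer.append(next(stream))

# Checkpointing the *dataset* state now records a position that is
# already past the items sitting in the buffer
ckpt = dict(state)  # {'idx': 3}, while samples 1 and 2 are still buffered

# Resuming from ckpt restarts at 3: the buffered samples are lost
resumed = list(numbered_stream(6, dict(ckpt)))
print(consumed + resumed)  # [0, 3, 4, 5] -- samples 1 and 2 were skipped
```

A loader-level checkpoint (as in torchdata's StatefulDataLoader) can account for the buffer instead, which is why that integration helps here.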

@uygnef

uygnef commented May 24, 2024

Yeah, what I meant is that prefetching might drop a few data entries. Really looking forward to the new StatefulDataLoader :)

@lhoestq lhoestq force-pushed the iterable-dataset-state-dict branch from 6254769 to c323af0 Compare May 30, 2024 10:46
@lhoestq lhoestq marked this pull request as ready for review May 30, 2024 12:35
@lhoestq lhoestq merged commit 43fd659 into main Jun 3, 2024
12 checks passed
@lhoestq lhoestq deleted the iterable-dataset-state-dict branch June 3, 2024 19:15
github-actions bot commented Jun 3, 2024

Show benchmarks

Benchmark results (new vs. old) for PyArrow==8.0.0 and PyArrow==latest across benchmark_array_xd.json, benchmark_getitem_100B.json, benchmark_indices_mapping.json, benchmark_iterating.json, and benchmark_map_filter.json.

7 participants