
Save and resume the state of a DataLoader #5454

Open
lhoestq opened this issue Jan 23, 2023 · 13 comments
Labels
enhancement (New feature or request), generic discussion (Generic discussion on the library)

Comments

@lhoestq
Member

lhoestq commented Jan 23, 2023

It would be nice, when using datasets with a PyTorch DataLoader, to be able to resume training from a DataLoader state (e.g. to resume a training run that crashed).

What I have in mind (but lmk if you have other ideas or comments):

For map-style datasets, this requires a PyTorch Sampler whose state can be saved and reloaded per node and per worker.

For iterable datasets, this requires saving the state of the dataset iterator, which includes:

  • the current shard index and the row position within that shard
  • the epoch number
  • the RNG state
  • the shuffle buffer

Right now you can already resume data loading for an iterable dataset with IterableDataset.skip, but it takes a lot of time because it re-iterates over all the past data until it reaches the resuming point.
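
For context, here is a minimal sketch of that workaround as it exists today (the dataset name and the number of examples already seen are placeholders):

from datasets import load_dataset

ds = load_dataset("my_dataset", streaming=True, split="train")  # placeholder dataset name
num_examples_seen = 1_000_000  # restored from your own checkpointing logic

# Works today, but .skip() still iterates over every skipped example under the hood,
# so resuming deep into a large dataset can take a very long time
resumed_ds = ds.skip(num_examples_seen)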

cc @stas00 @sgugger

lhoestq added the enhancement and generic discussion labels on Jan 23, 2023
@thomasw21
Contributor

Something that'd be nice to have is a "manual update of state". One of the learnings from training LLMs is that the ability to skip some batches whenever we notice a huge loss spike would be handy.

@stas00
Contributor

stas00 commented Jan 24, 2023

Your outline spec is very sound and clear, @lhoestq - thank you!

@thomasw21, indeed that would be a wonderful extra feature. In Megatron-Deepspeed we manually drained the dataloader over the range we wanted to skip. I wasn't very satisfied with the way we did it, since its behavior would change if you did multiple range skips. I think it should remember all the ranges it skipped rather than just the last one, since otherwise the data is inconsistent (but we should probably discuss this in a separate issue so as not to derail this much bigger one).

@yqy2001

yqy2001 commented Jan 25, 2024

Hi there! I think this is a critical issue and I have an urgent need for it, as I'm trying to train on a very large-scale dataset using datasets. It is impossible to resume a long-running (e.g. month-long) experiment by re-iterating over all the data already seen, which could take several days.

@stas00 @thomasw21 @lhoestq Any updates on this problem now that a year has passed?

@dancingpipi

Any update?

@lhoestq
Member Author

lhoestq commented Feb 2, 2024

No update so far. I wonder if someone has implemented a resumable PyTorch Sampler somewhere.

Regarding resuming a streaming dataset, we'd first like to have an efficient way to skip shards automatically, but this is not implemented yet.
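
To make the Sampler idea concrete, here is a minimal single-process sketch of what a resumable sampler could look like (this is an assumption about the design, not an existing API, and the bookkeeping would need more care with multiple DataLoader workers):

import torch
from torch.utils.data import Sampler

class ResumableRandomSampler(Sampler):
    """Illustrative sketch: a seeded random sampler that can save and restore its position."""

    def __init__(self, data_source, seed=42):
        self.data_source = data_source
        self.seed = seed
        self.epoch = 0
        self.position = 0  # number of indices already yielded in the current epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)  # deterministic order per epoch
        indices = torch.randperm(len(self.data_source), generator=g).tolist()
        # Skip the indices consumed before the checkpoint, then keep counting
        for idx in indices[self.position:]:
            self.position += 1
            yield idx
        self.position = 0
        self.epoch += 1

    def __len__(self):
        return len(self.data_source) - self.position

    def state_dict(self):
        return {"epoch": self.epoch, "position": self.position, "seed": self.seed}

    def load_state_dict(self, state):
        self.epoch = state["epoch"]
        self.position = state["position"]
        self.seed = state["seed"]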

@lhoestq
Member Author

lhoestq commented Feb 19, 2024

I opened a draft here for IterableDataset: #6658

"""Requires https://github.com/huggingface/datasets/pull/6658 (WIP)"""
from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset(..., streaming=True)
# ds = ds.map(tokenize)
# ds = ds.shuffle(seed=42, buffer_size=1000)

# Init the dataset state_dict, or load it from a checkpoint
dataset_state_dict = ds.state_dict()

# Resumable training loop (batch_size and save_steps are defined elsewhere in the script)
ds.load_state_dict(dataset_state_dict)
dataloader = DataLoader(ds, batch_size=batch_size)
for step, batch in enumerate(dataloader):
    ...
    if step % save_steps == 0:
        # Snapshot the dataset state so training can resume from this step
        dataset_state_dict = ds.state_dict()
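
As a side note, here is a minimal sketch of persisting that state dict to disk so a crashed run can pick it back up (the file name is hypothetical; torch.save is used only because it accepts arbitrary picklable objects):

import torch

# Save the dataset state whenever you write a model checkpoint (illustrative)
torch.save(dataset_state_dict, "dataset_state.pt")  # hypothetical path

# On resume, load it back before calling ds.load_state_dict(dataset_state_dict)
dataset_state_dict = torch.load("dataset_state.pt")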

@jiawen-steven-liu

Hi @lhoestq - can you provide more information on how to save and restore vanilla DataLoader states with map-style datasets?

@lhoestq
Member Author

lhoestq commented Feb 21, 2024

For now, the easiest approach is probably to use the vanilla DataLoader only for batching and multiprocessing, and to implement the resuming logic using a Dataset (it has .select() to skip examples) and a dataset_state_dict:

from datasets import load_dataset
from torch.utils.data import DataLoader

ds = load_dataset(...)
# ds = ds.map(tokenize)
# ds = ds.shuffle(seed=42)

# Init the dataset state_dict, or load it from a checkpoint
dataset_state_dict = {"step": 0}

# Resumable training loop (batch_size and save_steps are defined elsewhere in the script)
# Note: skipping with .select() assumes the dataset order is deterministic (e.g. the seeded shuffle above)
start_step = dataset_state_dict["step"]
dataloader = DataLoader(ds.select(range(start_step * batch_size, len(ds))), batch_size=batch_size)
for step, batch in enumerate(dataloader, start=start_step):
    ...
    if step % save_steps == 0:
        dataset_state_dict = {"step": step}

@xgbj

xgbj commented Mar 19, 2024

Hello, I found a similar implementation online that seems to solve this problem: https://github.com/facebookresearch/vissl/blob/main/vissl/data/data_helper.py#L93
It looks like we can use set_start_iter in StatefulDistributedSampler to implement the stateful resume behavior we want.
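
If it helps, here is a rough sketch of how that sampler might be wired up (the constructor arguments and variable names are assumptions on my part, so please check the vissl source):

from torch.utils.data import DataLoader
from vissl.data.data_helper import StatefulDistributedSampler  # from the linked file

# ds, batch_size, epoch and start_iter come from the training script
sampler = StatefulDistributedSampler(ds, batch_size=batch_size)  # assumed signature
sampler.set_epoch(epoch)            # same contract as DistributedSampler
sampler.set_start_iter(start_iter)  # skip the first `start_iter` iterations of this epoch
dataloader = DataLoader(ds, batch_size=batch_size, sampler=sampler)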

@andrewkho

Hi y'all, @lhoestq I wanted to flag that we currently have a StatefulDataLoader in pytorch/data/torchdata that has state_dict/load_state_dict methods, which will call a dataset's state_dict/load_state_dict methods but also handle multiprocessing under the hood. Any chance we can collaborate on this and try to get them to work well together? Please have a look here for some basic examples: https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader#saving-and-loading-state
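
For reference, a minimal sketch of what the combination could look like, based on the linked README (num_workers and the in-memory checkpoint variable are illustrative, and exactly how the datasets state round-trips this way still needs to be verified):

from torchdata.stateful_dataloader import StatefulDataLoader

# `ds` is an IterableDataset exposing state_dict()/load_state_dict() (e.g. from #6658)
dataloader = StatefulDataLoader(ds, batch_size=batch_size, num_workers=2)

for step, batch in enumerate(dataloader):
    ...
    if step % save_steps == 0:
        checkpoint = dataloader.state_dict()  # captures dataset and worker progress

# Later, to resume:
dataloader = StatefulDataLoader(ds, batch_size=batch_size, num_workers=2)
dataloader.load_state_dict(checkpoint)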

@lhoestq
Member Author

lhoestq commented Apr 30, 2024

Fantastic! This will help push our IterableDataset state_dict implementation at #6658 :) I'll check if there is anything missing to make them work together, and add tests and some docs referring to the StatefulDataLoader :)

@lhoestq
Member Author

lhoestq commented Apr 30, 2024

Ah, I just saw this disclaimer in the torchdata README, and it feels like people should not rely on it. Should the StatefulDataLoader live elsewhere, @andrewkho?

⚠️ As of July 2023, we have paused active development on TorchData and have paused new releases. We have learnt a lot from building it and hearing from users, but also believe we need to re-evaluate the technical design and approach given how much the industry has changed since we began the project. During the rest of 2023 we will be re-evaluating our plans in this space. Please reach out if you have suggestions or comments (please use pytorch/data#1196 for feedback).

@andrewkho

@lhoestq Good find - we are in the midst of updating this disclaimer as we're re-starting development and regular releases, though our approach will be to iterate on DL V1 (i.e. StatefulDataLoader) instead of continuing development on datapipes + DL V2. Let's discuss on a call at some point to figure out the best path forward!
