Support iterable datasets in GRPO #3226
base: main
Conversation
One thing I still want to mention, but didn't know how to handle, is that we need to set `accelerator_config.dispatch_batches` to `False`.
Thanks for the PR. I don't really understand why we can't use the same approach as for the regular Dataset. Adapting the sampler should be enough, no?
The problem is that torch's `DataLoader` does not allow a custom sampler to be used with an `IterableDataset`.
I thought that updating the data collator to do the duplication on the fly would be the easiest solution for an iterable dataset instead. I apologise if your comment was referring to something else. Let me know!
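A minimal sketch of what collator-side duplication could look like, assuming the collator is allowed to return a plain list of feature dicts (the `RepeatCollator` name is illustrative, not the PR's actual code):

from dataclasses import dataclass


@dataclass
class RepeatCollator:
    """Repeats each incoming feature `num_generations` times, consecutively."""

    num_generations: int

    def __call__(self, features: list[dict]) -> list[dict]:
        # [a, b] with num_generations=3 -> [a, a, a, b, b, b]
        return [feature for feature in features for _ in range(self.num_generations)]

The dataloader then only needs to request `per_device_train_batch_size // num_generations` raw samples per step for the final batch to come out at the expected size.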
What about using `map` instead?
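A sketch of that alternative, assuming each example should simply be repeated `num_generations` times in a row before training (the helper name and the bare `num_generations` variable are placeholders):

def repeat_examples(batch: dict, num_generations: int) -> dict:
    # Batched map: repeat every value num_generations times, column by column.
    return {
        column: [value for value in values for _ in range(num_generations)]
        for column, values in batch.items()
    }

train_dataset = train_dataset.map(
    repeat_examples,
    batched=True,
    fn_kwargs={"num_generations": num_generations},
)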
I'm getting this error with both "my" approach and the approach of this PR:
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Yes, this is exactly the error that I referenced in my note here:
For debugging purposes, you can bypass it by setting `accelerator_config.dispatch_batches` to `False`. Also, I think the `map` suggestion could work as well; I just hadn't considered that before.
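For reference, that bypass can be expressed through the training arguments; a sketch (the `output_dir` value is just a placeholder):

from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-iterable",  # placeholder
    # Stop accelerate from dispatching batches from the main process, which is
    # the part that conflicts with iterable datasets here.
    accelerator_config={"dispatch_batches": False},
)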
Did you have a chance to take another look at this?
An easier way of doing this might be to create a wrapper around IterableDataset using the same logic used by the RepeatSampler, as follows:

from datasets import Dataset, IterableDataset


def repeat_iterable_dataset(
    dataset: IterableDataset,
    mini_repeat_count: int,
    batch_size: int = 1,
    repeat_count: int = 1,
):
    # Group the stream into fixed-size batches, dropping any incomplete batch.
    batched_dataset = dataset.batch(batch_size=batch_size, drop_last_batch=True)
    for _ in range(repeat_count):
        for batch in batched_dataset:
            # Turn the columnar batch back into individual records.
            within_batch_dataset = Dataset.from_dict(batch)
            for record in within_batch_dataset:
                # Yield each record mini_repeat_count times in a row,
                # mirroring what RepeatSampler does with indices.
                for _ in range(mini_repeat_count):
                    yield record


effective_batch_size = (
    self.args.per_device_train_batch_size
    * self.accelerator.num_processes
    * self.args.gradient_accumulation_steps
)
train_dataset = IterableDataset.from_generator(
    repeat_iterable_dataset,
    gen_kwargs={
        "dataset": input_train_dataset,
        "mini_repeat_count": self.num_generations,
        "batch_size": effective_batch_size // self.num_generations,
        "repeat_count": self.num_iterations,
    },
)
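A quick sanity check of the repetition pattern, reusing the `repeat_iterable_dataset` generator above with toy values (assumes a recent `datasets` version that provides `IterableDataset.batch`):

from datasets import Dataset, IterableDataset

# Toy dataset with four records: x = 0, 1, 2, 3.
toy = Dataset.from_dict({"x": [0, 1, 2, 3]}).to_iterable_dataset()

repeated = IterableDataset.from_generator(
    repeat_iterable_dataset,  # the generator defined above
    gen_kwargs={
        "dataset": toy,
        "mini_repeat_count": 2,  # e.g. num_generations = 2
        "batch_size": 2,         # effective_batch_size // num_generations
        "repeat_count": 1,       # e.g. num_iterations = 1
    },
)
print([record["x"] for record in repeated])
# Expected: [0, 0, 1, 1, 2, 2, 3, 3] -- each record is repeated
# mini_repeat_count times within its batch, mirroring RepeatSampler.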
This seems like a reasonable approach.
@brayan07 Can you propose a merge request with the fix integrated, please? If not, I am happy to attempt it.

Edit:

# Warn and fall back to non-dispatched batches when an IterableDataset is used.
if isinstance(train_dataset, torch.utils.data.IterableDataset) and args.accelerator_config.dispatch_batches:
    warnings.warn(
        "You passed an IterableDataset, `args.accelerator_config.dispatch_batches` is being set to False. "
        "See https://github.com/huggingface/transformers/issues/26548#issuecomment-1885798533"
    )
    args.accelerator_config.dispatch_batches = False

Just before
Any updates on when this would be merged / fixed? Thanks!
What does this PR do?
This PR solves the issue described in #3213. Additionally, it avoids the need for the PR in #3216.
I implemented support for an `IterableDataset` by overriding the `get_train_dataloader` and `get_eval_dataloader` methods from the `Trainer` class. Now, when GRPO is given an iterable dataset, the batch size is divided by `self.num_generations` and the data collator takes care of duplicating the samples afterwards.

Fixes #3213
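A rough sketch of the shape of such an override, simplified and not the PR's actual code (the subclass name is illustrative, and the eval side is omitted):

from datasets import IterableDataset
from torch.utils.data import DataLoader
from trl import GRPOTrainer


class IterableGRPOTrainer(GRPOTrainer):  # illustrative name
    def get_train_dataloader(self) -> DataLoader:
        if not isinstance(self.train_dataset, IterableDataset):
            return super().get_train_dataloader()
        # Request 1/num_generations of the batch from the raw stream; the data
        # collator is then responsible for duplicating each sample on the fly.
        dataloader = DataLoader(
            self.train_dataset,
            batch_size=self._train_batch_size // self.num_generations,
            collate_fn=self.data_collator,
        )
        return self.accelerator.prepare(dataloader)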
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.