Skip to content

Conversation

@wilrop
Copy link
Contributor

@wilrop wilrop commented Apr 3, 2025

What does this PR do?

This PR solves the issue described in #3213. Additionally, it avoids the need for the PR in #3216.

I implemented support for an IterableDataset by overriding the get_train_dataloader and get_eval_dataloader methods from the Trainer class. Now, when GRPO is given an iterable dataset, the batch size is divided by self.num_generations and the data collator takes care of duplicating the samples afterwards.

Fixes #3213

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@wilrop
Copy link
Contributor Author

wilrop commented Apr 3, 2025

One thing I still want to mention, but didn't know how to handle, is that we need to set dispatch_batches=False in the GRPOConfig when using an iterable dataset. If this is not set, we get an error when torch tries to concatenate strings (i.e. the prompts). I think this is probably not a GRPO specific problem though, see e.g. here huggingface/transformers#26548 (comment)

@qgallouedec
Copy link
Member

Thanks for the PR. I don't really understand why we can't use the same approach as for the regular Dataset. Adapting the sampler should be enough, no?

@wilrop
Copy link
Contributor Author

wilrop commented Apr 5, 2025

Thanks for the PR. I don't really understand why we can't use the same approach as for the regular Dataset. Adapting the sampler should be enough, no?

The problem is that torch's IterableDataset does not support custom samplers, see: https://github.com/pytorch/pytorch/blob/1017927c83dd95a4be6074c48e0fb38f0a1bd8f3/torch/utils/data/dataloader.py#L301 with the following comments

# Arg-check dataset related before checking samplers because we want to
# tell users that iterable-style datasets are incompatible with custom
# samplers first, so that they don't learn that this combo doesn't work
# after spending time fixing the custom sampler errors.

I thought that updating the data collator to do the duplication on the fly would be the easiest solution for an iterable dataset instead.

I apologise if your comment was referring to something else. Let me know!

@qgallouedec
Copy link
Member

What about dataset.map? It's maybe the easiest way to address this. I'll try, keep you posted

@qgallouedec
Copy link
Member

I'm getting this error with both "my" approach and the approach of this PR:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/fsx/qgallouedec/trl/3213.py", line 16, in <module>
[rank0]:     trainer.train()
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/trainer.py", line 2237, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/trainer.py", line 2506, in _inner_training_loop
[rank0]:     batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
[rank0]:                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/trainer.py", line 5235, in get_batch_samples
[rank0]:     batch_samples.append(next(epoch_iterator))
[rank0]:                          ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/data_loader.py", line 858, in __iter__
[rank0]:     next_batch, next_batch_info = self._fetch_batches(main_iterator)
[rank0]:                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/data_loader.py", line 814, in _fetch_batches
[rank0]:     batch = concatenate(batches, dim=0)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/utils/operations.py", line 610, in concatenate
[rank0]:     return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/utils/operations.py", line 81, in honor_type
[rank0]:     return type(obj)(generator)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/utils/operations.py", line 610, in <genexpr>
[rank0]:     return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
[rank0]:                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/utils/operations.py", line 612, in concatenate
[rank0]:     return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
[rank0]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/utils/operations.py", line 614, in concatenate
[rank0]:     raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
[rank0]: TypeError: Can only concatenate tensors but got <class 'str'>

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@wilrop
Copy link
Contributor Author

wilrop commented Apr 10, 2025

I'm getting this error with both "my" approach and the approach of this PR:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/fsx/qgallouedec/trl/3213.py", line 16, in <module>
[rank0]:     trainer.train()
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/trainer.py", line 2237, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/trainer.py", line 2506, in _inner_training_loop
[rank0]:     batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
[rank0]:                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/transformers/src/transformers/trainer.py", line 5235, in get_batch_samples
[rank0]:     batch_samples.append(next(epoch_iterator))
[rank0]:                          ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/data_loader.py", line 858, in __iter__
[rank0]:     next_batch, next_batch_info = self._fetch_batches(main_iterator)
[rank0]:                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/data_loader.py", line 814, in _fetch_batches
[rank0]:     batch = concatenate(batches, dim=0)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/utils/operations.py", line 610, in concatenate
[rank0]:     return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/utils/operations.py", line 81, in honor_type
[rank0]:     return type(obj)(generator)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/utils/operations.py", line 610, in <genexpr>
[rank0]:     return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
[rank0]:                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/utils/operations.py", line 612, in concatenate
[rank0]:     return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
[rank0]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/fsx/qgallouedec/miniconda3/envs/trl/lib/python3.12/site-packages/accelerate/utils/operations.py", line 614, in concatenate
[rank0]:     raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
[rank0]: TypeError: Can only concatenate tensors but got <class 'str'>

Yes, this is exactly the error that I referenced in my note here:

One thing I still want to mention, but didn't know how to handle, is that we need to set dispatch_batches=False in the GRPOConfig when using an iterable dataset. If this is not set, we get an error when torch tries to concatenate strings (i.e. the prompts). I think this is probably not a GRPO specific problem though, see e.g. here huggingface/transformers#26548 (comment)

For debugging purposes, you can bypass it by setting dispatch_batches=False

Also, I think the map suggestion could work as well, I just hadn't considered that before.

@wilrop
Copy link
Contributor Author

wilrop commented Apr 17, 2025

Did you have a chance to take another look at this?

@brayan07
Copy link

brayan07 commented Jun 18, 2025

An easier way of doing this might be to create a wrapper around IterableDataset using the same logic used by the RepeatSampler as follows:

def repeat_iterable_dataset(
    dataset: IterableDataset,
    mini_repeat_count: int,
    batch_size: int = 1,
    repeat_count: int = 1,
):
    batched_dataset = dataset.batch(batch_size=batch_size, drop_last_batch=True)
    for _ in range(repeat_count):
        for batch in batched_dataset:
            within_batch_dataset = Dataset.from_dict(batch)
            for record in within_batch_dataset:
                for _ in range(mini_repeat_count):
                    yield record

effective_batch_size = (
    self.args.per_device_train_batch_size
    * self.accelerator.num_processes
    * self.args.gradient_accumulation_steps
)
train_dataset = IterableDataset.from_generator(
    repeat_iterable_dataset,
    gen_kwargs={
        "dataset": input_train_dataset,
        "mini_repeat_count": self.num_generations,
        "batch_size": effective_batch_size // self.num_generations,
        "repeat_count": self.num_iterations,
     }
)

@qgallouedec
Copy link
Member

This seems like a reasonable approach

@mcleish7
Copy link

mcleish7 commented Jul 1, 2025

@brayan07 Can you propose a merge request with the fix integrated please? If not I am happy to attempt it

Edit:
I implemented @brayan07's fix and it gets a similar error to the one seen by @wilrop and @qgallouedec but with int instead of str. Maybe adding in something like:

if isinstance(train_dataset, torch.utils.data.IterableDataset) and args.accelerator_config.dispatch_batches:
    warnings.warn(
        "You passed an IterableDataset, `args.accelerator_config.dispatch_batches` is being set to False"
        "See https://github.com/huggingface/transformers/issues/26548#issuecomment-1885798533"
    )
    args.accelerator_config.dispatch_batches = False

Just before super().__init__ is called could be a good stop gap?

@marcandrelarochelle
Copy link

Any updates on when this would be merged / fixed? Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Critical issue in GRPO with iterable datasets

6 participants