[bugfix] fix grpo generate #9183
Conversation
Code Review
This pull request updates the GRPOTrainer to set remove_unused_columns to True for the inference template, which prevents potential crashes in the TransformersEngine. The reviewer suggests extending this change to the vLLM engine path in RolloutTrainerMixin to maintain consistency and prevent similar issues.
```python
infer_template = copy(self.template)
infer_template.padding_free = False
infer_template.sequence_parallel_size = 1
infer_template.remove_unused_columns = True
```
This fix correctly addresses potential crashes in TransformersEngine when remove_unused_columns=False is set in the training configuration. By explicitly setting it to True for the inference template, you ensure that the data_collator filters out non-model inputs before generation.
However, the same issue likely affects the vLLM engine path. In RolloutTrainerMixin._prepare_vllm_engine (line 202 of rollout_mixin.py), a similar template copy is created but remove_unused_columns is not set to True. To maintain consistency and prevent similar issues when using vLLM, you should consider applying this fix there as well.
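A minimal sketch of what mirroring the fix in the vLLM path could look like. The `Template` dataclass here is a hypothetical stand-in for the real swift template object, and `prepare_infer_template` is an illustrative helper, not the actual `RolloutTrainerMixin._prepare_vllm_engine`:

```python
from copy import copy
from dataclasses import dataclass


# Hypothetical stand-in for the swift template; only the attributes
# touched by the diff above are modeled.
@dataclass
class Template:
    padding_free: bool = True
    sequence_parallel_size: int = 2
    remove_unused_columns: bool = False


def prepare_infer_template(template: Template) -> Template:
    # Same recipe as the TransformersEngine fix; the review suggests
    # applying it in the vLLM engine path as well, so the data collator
    # also drops non-model columns before vLLM generation.
    infer_template = copy(template)
    infer_template.padding_free = False
    infer_template.sequence_parallel_size = 1
    infer_template.remove_unused_columns = True
    return infer_template


train_template = Template()
infer_template = prepare_infer_template(train_template)
```

Copying first keeps the training template untouched, so a training run with `remove_unused_columns=False` still sees its extra columns.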
/gemini review
Code Review
This pull request introduces column filtering in the GRPOTrainer: an allow-list of extra keys is defined, and all other entries are pruned from `_extra_kwargs` during batch preparation and model-input construction. It also enables remove_unused_columns on the inference template. The review feedback recommends using a set for O(1) key lookups and making the filtering conditional on the template's configuration to avoid unnecessary processing.
```python
self._filtered_keys = [
    'prompt_id', 'request_id', 'response_token_ids', 'finish_reason', 'is_truncated', 'add_eos'
]
```
Defining a set version of _filtered_keys would allow for O(1) lookup performance in the training loop. Keeping the list version is still useful for compatibility with list concatenations (like in _prepare_model_inputs).
Suggested change:

```diff
 self._filtered_keys = [
     'prompt_id', 'request_id', 'response_token_ids', 'finish_reason', 'is_truncated', 'add_eos'
 ]
+self._filtered_keys_set = set(self._filtered_keys)
```
```python
for encoded_inputs in batch_encoded_inputs:
    extra_kwargs = encoded_inputs.get('_extra_kwargs') or {}
    for k in list(extra_kwargs.keys()):
        if k not in self._filtered_keys:
            extra_kwargs.pop(k)
```
This filtering logic is only necessary when template.remove_unused_columns is False. If it's True (the default), _extra_kwargs won't be merged into the collated batch by the template's data collator. Adding this check avoids unnecessary iterations over the batch samples. Additionally, using the set version of filtered keys improves lookup performance.
Suggested change:

```diff
-for encoded_inputs in batch_encoded_inputs:
-    extra_kwargs = encoded_inputs.get('_extra_kwargs') or {}
-    for k in list(extra_kwargs.keys()):
-        if k not in self._filtered_keys:
-            extra_kwargs.pop(k)
+if not template.remove_unused_columns:
+    for encoded_inputs in batch_encoded_inputs:
+        extra_kwargs = encoded_inputs.get('_extra_kwargs') or {}
+        for k in list(extra_kwargs.keys()):
+            if k not in self._filtered_keys_set:
+                extra_kwargs.pop(k)
```
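The conditional filtering can be exercised in isolation. The sketch below uses plain dicts in place of the real template and encoded inputs, and a module-level set in place of `self._filtered_keys_set`; the `debug_info` key is a made-up example of an extra kwarg that should be pruned:

```python
# Allow-list of extra kwargs to keep (mirrors the trainer's list).
FILTERED_KEYS_SET = {
    'prompt_id', 'request_id', 'response_token_ids', 'finish_reason', 'is_truncated', 'add_eos'
}


def prune_extra_kwargs(batch_encoded_inputs, remove_unused_columns):
    # When the template already removes unused columns, _extra_kwargs never
    # reaches the collated batch, so the whole pass can be skipped.
    if remove_unused_columns:
        return batch_encoded_inputs
    for encoded_inputs in batch_encoded_inputs:
        extra_kwargs = encoded_inputs.get('_extra_kwargs') or {}
        # Snapshot the keys so popping during iteration is safe.
        for k in list(extra_kwargs.keys()):
            if k not in FILTERED_KEYS_SET:
                extra_kwargs.pop(k)
    return batch_encoded_inputs


# Hypothetical sample: 'debug_info' is not on the allow-list and is dropped.
batch = [{'_extra_kwargs': {'prompt_id': 7, 'debug_info': 'x'}}]
prune_extra_kwargs(batch, remove_unused_columns=False)

# With remove_unused_columns=True the pass is a no-op.
batch_skipped = [{'_extra_kwargs': {'debug_info': 'x'}}]
prune_extra_kwargs(batch_skipped, remove_unused_columns=True)
```

The `list(extra_kwargs.keys())` copy matters: popping from a dict while iterating its live key view would raise a RuntimeError.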