
[bugfix] fix grpo generate #9183

Merged
Jintao-Huang merged 4 commits into modelscope:main from Jintao-Huang:fix_grpo_generate
Apr 22, 2026
Conversation

@Jintao-Huang
Collaborator

No description provided.

@hjh0119
Collaborator

hjh0119 commented Apr 22, 2026

#9131

@gemini-code-assist (Bot)
Contributor

Code Review

This pull request updates the GRPOTrainer to set remove_unused_columns to True for the inference template, which prevents potential crashes in the TransformersEngine. The reviewer suggests extending this change to the vLLM engine path in RolloutTrainerMixin to maintain consistency and prevent similar issues.

infer_template = copy(self.template)
infer_template.padding_free = False
infer_template.sequence_parallel_size = 1
infer_template.remove_unused_columns = True

@gemini-code-assist (Bot)
Contributor

medium

This fix correctly addresses potential crashes in TransformersEngine when remove_unused_columns=False is set in the training configuration. By explicitly setting it to True for the inference template, you ensure that the data_collator filters out non-model inputs before generation.

However, the same issue likely affects the vLLM engine path. In RolloutTrainerMixin._prepare_vllm_engine (line 202 of rollout_mixin.py), a similar template copy is created but remove_unused_columns is not set to True. To maintain consistency and prevent similar issues when using vLLM, you should consider applying this fix there as well.
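
As an illustration only, a minimal sketch of what the suggested vLLM-side change could look like is shown below. It mirrors the Transformers-path snippet above; the variable name vllm_template and the surrounding context inside RolloutTrainerMixin._prepare_vllm_engine are assumptions, not the actual diff.

# hypothetical sketch inside RolloutTrainerMixin._prepare_vllm_engine (assumed context, not the actual diff)
vllm_template = copy(self.template)
vllm_template.padding_free = False
vllm_template.sequence_parallel_size = 1
# same fix as the Transformers path: drop non-model columns before generation
vllm_template.remove_unused_columns = True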

@Jintao-Huang
Collaborator Author

/gemini review

@gemini-code-assist (Bot)
Contributor

Code Review

This pull request introduces column filtering in the GRPOTrainer by defining a set of keys to be excluded and applying this filter during batch preparation and model input generation. It also enables remove_unused_columns in the inference template. The review feedback recommends optimizing lookup performance by using a set for the filtered keys and making the filtering logic conditional on the template's configuration to avoid unnecessary processing.

Comment on lines +166 to +168
self._filtered_keys = [
    'prompt_id', 'request_id', 'response_token_ids', 'finish_reason', 'is_truncated', 'add_eos'
]

@gemini-code-assist (Bot)
Contributor

medium

Defining a set version of _filtered_keys would allow for O(1) lookup performance in the training loop. Keeping the list version is still useful for compatibility with list concatenations (like in _prepare_model_inputs).

Suggested change
-self._filtered_keys = [
-    'prompt_id', 'request_id', 'response_token_ids', 'finish_reason', 'is_truncated', 'add_eos'
-]
+self._filtered_keys = [
+    'prompt_id', 'request_id', 'response_token_ids', 'finish_reason', 'is_truncated', 'add_eos'
+]
+self._filtered_keys_set = set(self._filtered_keys)
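
For illustration, a hypothetical sketch of how the two versions could be used side by side (extra_kwargs and base_model_keys are assumed names here, not code from this PR):

# hypothetical usage sketch, not part of this PR
kept_kwargs = {k: v for k, v in extra_kwargs.items() if k in self._filtered_keys_set}  # O(1) membership tests
all_columns = base_model_keys + self._filtered_keys  # the list version still supports concatenation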

Comment on lines +887 to +891
for encoded_inputs in batch_encoded_inputs:
    extra_kwargs = encoded_inputs.get('_extra_kwargs') or {}
    for k in list(extra_kwargs.keys()):
        if k not in self._filtered_keys:
            extra_kwargs.pop(k)

@gemini-code-assist (Bot)
Contributor

medium

This filtering logic is only necessary when template.remove_unused_columns is False. If it's True (the default), _extra_kwargs won't be merged into the collated batch by the template's data collator. Adding this check avoids unnecessary iterations over the batch samples. Additionally, using the set version of filtered keys improves lookup performance.

Suggested change
-for encoded_inputs in batch_encoded_inputs:
-    extra_kwargs = encoded_inputs.get('_extra_kwargs') or {}
-    for k in list(extra_kwargs.keys()):
-        if k not in self._filtered_keys:
-            extra_kwargs.pop(k)
+if not template.remove_unused_columns:
+    for encoded_inputs in batch_encoded_inputs:
+        extra_kwargs = encoded_inputs.get('_extra_kwargs') or {}
+        for k in list(extra_kwargs.keys()):
+            if k not in self._filtered_keys_set:
+                extra_kwargs.pop(k)

@Jintao-Huang merged commit f9aad77 into modelscope:main Apr 22, 2026
3 checks passed
Jintao-Huang added a commit that referenced this pull request Apr 23, 2026