9 changes: 9 additions & 0 deletions docs/source/Instruction/GRPO/DeveloperGuide/多轮训练.md
@@ -163,6 +163,15 @@ swift rollout \

The full trajectory data can be accessed via `trajectory_inputs` in `kwargs`; for a concrete implementation, see the [MultiTurnThinkingTips class](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py).
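
A minimal, hedged sketch of reading the trajectory from a custom scheduler hook; the `step` signature and the `trajectory_inputs` keyword are modeled on the linked example, while the class name and the tip-injection logic are illustrative assumptions:

```python
class TipsScheduler:  # stands in for a MultiTurnScheduler subclass
    def step(self, infer_request, response_choice, current_turn, **kwargs):
        # Accumulated inputs for every turn of this rollout so far.
        trajectory_inputs = kwargs.get('trajectory_inputs', [])
        for turn_inputs in trajectory_inputs:
            ...  # inspect earlier turns, e.g. to decide whether to inject a tip
        return {'infer_request': infer_request}
```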

### Multimodal Data Override
In multimodal multi-turn interaction scenarios, you may need to dynamically add, delete, or modify multimodal data during the conversation and ensure these changes are synchronized to the trainer.

Implementation: use `rollout_infos` to override the original dataset's multimodal content by specifying the corresponding keys.

Currently supported override keys: `images`, `audios`, `videos`.

For details, see the [DeepEyes Scheduler](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/deepeyes/deepeyes_plugin.py#L403-L404).
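
A minimal sketch modeled on the DeepEyes plugin linked above: the scheduler appends a newly produced image to the request and mirrors the updated list into the override keys. `parse_response` and the exact dict-based return format (including the `rollout_infos` key) are assumptions for illustration:

```python
class DeepEyesLikeScheduler:  # stands in for a MultiTurnScheduler subclass
    def step(self, infer_request, response_choice, current_turn):
        extra_info = {}
        query, cropped_img = self.parse_response(response_choice)  # hypothetical helper
        infer_request.messages.append({'role': 'user', 'content': query})
        if cropped_img:
            infer_request.images.append(cropped_img)
            # Any supported key ('images', 'audios', 'videos') placed in rollout_infos
            # replaces the corresponding multimodal data on the trainer side.
            extra_info['images'] = infer_request.images
        return {'infer_request': infer_request, 'rollout_infos': extra_info}
```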

### Returning response token IDs
In the default multi-turn interaction workflow, the scheduler first returns the model-generated text to the trainer, and the trainer re-encodes it into token IDs for subsequent training. To avoid this duplicated encoding, the scheduler can return `response_token_ids` directly, skipping the extra encode on the trainer side.
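
A minimal, hedged sketch of a scheduler step that hands token IDs straight back to the trainer; the dict-based return format and the `token_ids` attribute on the response choice are assumptions based on the description above:

```python
class TokenIdScheduler:  # stands in for a MultiTurnScheduler subclass
    def step(self, infer_request, response_choice, current_turn):
        infer_request.messages.append({'role': 'assistant', 'content': response_choice.message.content})
        return {
            'infer_request': infer_request,
            # Hand back the already-generated token IDs so the trainer skips re-encoding.
            'response_token_ids': response_choice.token_ids,  # attribute name is an assumption
        }
```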

9 changes: 9 additions & 0 deletions docs/source_en/Instruction/GRPO/DeveloperGuide/multi_turn.md
@@ -172,6 +172,15 @@ The complete trajectory can be accessed via `trajectory_inputs` in `kwargs`.

For a concrete implementation, see the [MultiTurnThinkingTips class](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py).
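
A minimal, hedged sketch of reading the trajectory from a custom scheduler hook; the `step` signature and the `trajectory_inputs` keyword are modeled on the linked example, while the class name and the tip-injection logic are illustrative assumptions:

```python
class TipsScheduler:  # stands in for a MultiTurnScheduler subclass
    def step(self, infer_request, response_choice, current_turn, **kwargs):
        # Accumulated inputs for every turn of this rollout so far.
        trajectory_inputs = kwargs.get('trajectory_inputs', [])
        for turn_inputs in trajectory_inputs:
            ...  # inspect earlier turns, e.g. to decide whether to inject a tip
        return {'infer_request': infer_request}
```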

### Multimodal Data Override
In multimodal, multi-turn interactions, you may need to dynamically add, delete, or modify multimodal data during the conversation and ensure these changes are synchronized to the trainer.

Implementation: Use `rollout_infos` to override the original multimodal content in the dataset by specifying the corresponding keys.

Supported override keys: `images`, `audios`, `videos`.

For details, see [DeepEyes Scheduler](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/deepeyes/deepeyes_plugin.py#L403-L404).
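
A minimal sketch modeled on the DeepEyes plugin linked above: the scheduler appends a newly produced image to the request and mirrors the updated list into the override keys. `parse_response` and the exact dict-based return format (including the `rollout_infos` key) are assumptions for illustration:

```python
class DeepEyesLikeScheduler:  # stands in for a MultiTurnScheduler subclass
    def step(self, infer_request, response_choice, current_turn):
        extra_info = {}
        query, cropped_img = self.parse_response(response_choice)  # hypothetical helper
        infer_request.messages.append({'role': 'user', 'content': query})
        if cropped_img:
            infer_request.images.append(cropped_img)
            # Any supported key ('images', 'audios', 'videos') placed in rollout_infos
            # replaces the corresponding multimodal data on the trainer side.
            extra_info['images'] = infer_request.images
        return {'infer_request': infer_request, 'rollout_infos': extra_info}
```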

### Returning response token IDs

In the default workflow, the scheduler returns the generated text and the trainer re-encodes it into token IDs for training.
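
A minimal, hedged sketch of a scheduler step that hands token IDs straight back to the trainer; the dict-based return format and the `token_ids` attribute on the response choice are assumptions based on the description above:

```python
class TokenIdScheduler:  # stands in for a MultiTurnScheduler subclass
    def step(self, infer_request, response_choice, current_turn):
        infer_request.messages.append({'role': 'assistant', 'content': response_choice.message.content})
        return {
            'infer_request': infer_request,
            # Hand back the already-generated token IDs so the trainer skips re-encoding.
            'response_token_ids': response_choice.token_ids,  # attribute name is an assumption
        }
```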
1 change: 1 addition & 0 deletions examples/train/grpo/plugin/deepeyes/deepeyes_plugin.py
@@ -400,6 +400,7 @@ def step(self, infer_request, response_choice, current_turn):
infer_request.messages.append({'role': 'user', 'content': query})
if cropped_img:
infer_request.images.append(cropped_img)
# override the images: extra_info is surfaced to the trainer as rollout_infos
extra_info['images'] = infer_request.images

# Return dictionary format according to new MultiTurnScheduler interface
55 changes: 28 additions & 27 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -1613,12 +1613,7 @@ def _compute_loss_chunked(self, model, inputs: DataType):
end_idx = min(start_idx + new_chunk_size, batch_size)

if start_idx < batch_size:
# Create chunk inputs
for key, value in inputs.items():
if isinstance(value, torch.Tensor):
chunk_inputs[key] = value[start_idx:end_idx]
else:
chunk_inputs[key] = value
chunk_inputs = self.get_chunked_inputs(inputs, start_idx, end_idx)

# Compute loss and metrics for this chunk (without updating global metrics)
chunk_loss, chunk_metrics_data = self._compute_loss_and_metrics(model, chunk_inputs)
@@ -1862,26 +1857,6 @@ def _get_per_token_logps_and_entropies_chunked(self,
``False``.
"""

def get_chunked_inputs(inputs, start_idx, end_idx):
chunk_inputs = {}
if not self.is_multimodal:
# for LLM, slice the inputs
for key, val in inputs.items():
if isinstance(val, torch.Tensor):
chunk_inputs[key] = val[start_idx:end_idx]
else:
chunk_inputs[key] = val
else:
# for MLLM, re-encode to get mm-related inputs
origin_data = inputs['_origin_data'][start_idx:end_idx]
template = self.template
with self._template_context(template):
chunk_inputs = [template.encode(data) for data in origin_data]
chunk_inputs = to_device(template.data_collator(chunk_inputs), self.model.device)
chunk_inputs['logits_to_keep'] = inputs['logits_to_keep']
chunk_inputs.pop('labels', None)
return chunk_inputs

batch_size = inputs['input_ids'].shape[0]
mode = 'train' if self.model.training else 'eval'
chunk_size = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size
@@ -1901,7 +1876,7 @@ def get_chunked_inputs(inputs, start_idx, end_idx):
end_idx = min(start_idx + new_chunk_size, batch_size)

if start_idx < end_idx:
chunk_inputs = get_chunked_inputs(inputs, start_idx, end_idx)
chunk_inputs = self.get_chunked_inputs(inputs, start_idx, end_idx)

chunk_logps, chunk_entropies = self._get_per_token_logps_and_entropies_single(
model, chunk_inputs, compute_entropy)
@@ -2594,6 +2569,14 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out
input_data['finish_reason'] = choice.finish_reason
input_data['is_truncated'] = choice.finish_reason == 'length'

# Step 5: override multi-modal data from rollout_infos
if output.rollout_infos:
multi_modal_keys = ['images', 'videos', 'audios']
for key in multi_modal_keys:
Review comment (Contributor, severity: medium) on lines +2574 to +2575: For better maintainability and to avoid magic strings, consider defining `['images', 'videos', 'audios']` as a class-level or module-level constant. This makes it easier to manage and reuse these keys if they are needed elsewhere in the class.

if key in output.rollout_infos:
input_data[key] = output.rollout_infos[key]
logger.info_once(f'Overriding multi-modal data from rollout_infos for key: {key}')

return input_data

if not self.dynamic_num_samples:
@@ -2892,3 +2875,21 @@ def _get_last_indices(self, request_ids: List[str]) -> torch.Tensor:
for i, rid in enumerate(request_ids):
seen[rid] = i
return torch.tensor(list(seen.values()), dtype=torch.long, device=self.accelerator.device)

def get_chunked_inputs(self, inputs, start_idx, end_idx):
chunk_inputs = {}
# slice tensor inputs along the batch dimension; non-tensor values are shared across chunks
for key, val in inputs.items():
if isinstance(val, torch.Tensor):
chunk_inputs[key] = val[start_idx:end_idx]
else:
chunk_inputs[key] = val
if self.is_multimodal:
# for MLLM, re-encode to get mm-related inputs
origin_data = inputs['_origin_data'][start_idx:end_idx]
template = self.template
with self._template_context(template):
encoded_data = [template.encode(data) for data in origin_data]
chunk_inputs.update(to_device(template.data_collator(encoded_data), self.model.device))
chunk_inputs.pop('labels', None)
return chunk_inputs