diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 3e196bd321..cc0a333bea 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -707,8 +707,6 @@ def _infer_multi_turn(self, inputs_slice, request_config) -> List[List[Dict[str, outputs.append(_choices) assert len(outputs) == prompt_lens assert all([len(o) == self.args.tensor_parallel_size for o in outputs]) - if isinstance(outputs[0][0], list): - outputs = [output[0] for output in outputs] return outputs def async_infer(self, inputs, inputs_slice, distributed_idx): @@ -832,6 +830,8 @@ def _generate_and_score_completions( self.model.train() if is_multimodal: self.template.register_post_encode_hook(models) + if isinstance(outputs[0][0], list): + outputs = [output[0] for output in outputs] # Slice to keep only the local part of the data process_slice = slice(