From e091386ad02ab7fe3e126b567f58cc9b86a41e6b Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 12 Feb 2025 12:51:40 +0800
Subject: [PATCH] fix grpo nan

---
 swift/trainers/rlhf_trainer/grpo_trainer.py | 28 ++++++++++++++++++++++------
 1 file changed, 22 insertions(+), 6 deletions(-)

diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 586b134f7c..5eabe3b51e 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -2,6 +2,7 @@
 # Part of the implementation is borrowed from huggingface/trl.
 import inspect
 from collections import defaultdict
+from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union
 from unittest.mock import patch
 
@@ -140,6 +141,21 @@ def __init__(self,
                 self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
         self.log_completions = args.log_completions
 
+    @staticmethod
+    @contextmanager
+    def _template_context(template):
+        # The max_length for the prompt and completion has already been restricted, so there is no need to apply max_length here.
+        max_length = template.max_length
+        mode = template.mode
+        if mode in {'vllm', 'pt', 'lmdeploy'}:
+            template.set_mode('train')
+        template.max_length = None
+        try:
+            yield
+        finally:
+            template.set_mode(mode)
+            template.max_length = max_length
+
     def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device
 
@@ -191,10 +207,9 @@ def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:
             InferRequest.remove_response(messages)
             messages.append({'role': 'assistant', 'content': output.choices[0].message.content})
 
-        self.template.set_mode('train')
-        batched_inputs = [self.template.encode(infer_request) for infer_request in inputs]
-        outputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
-        self.template.set_mode('pt')  # recover
+        with self._template_context(self.template):
+            batched_inputs = [self.template.encode(infer_request) for infer_request in inputs]
+            outputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
 
         # we only need to compute the logits for the completion tokens
         labels = outputs.pop('labels')
@@ -217,8 +232,9 @@ def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:
         for i, (reward_func, reward_template) in enumerate(zip(self.reward_funcs, self.reward_templates)):
             if isinstance(reward_func, nn.Module):
                 # Module instead of PretrainedModel for compat with compiled models
-                batched_inputs = [reward_template.encode(infer_request) for infer_request in inputs]
-                reward_inputs = to_device(reward_template.data_collator(batched_inputs), reward_func.device)
+                with self._template_context(reward_template):
+                    batched_inputs = [reward_template.encode(infer_request) for infer_request in inputs]
+                    reward_inputs = to_device(reward_template.data_collator(batched_inputs), reward_func.device)
                 with torch.inference_mode():
                     rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]
 
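
For reference, below is a minimal, standalone sketch of the save/restore pattern this patch introduces with _template_context: the template's mode and max_length are stashed, temporarily overridden for training-style encoding, and restored in a finally block even if encoding raises. DummyTemplate, its attribute values, and template_context are hypothetical stand-ins used only for illustration; they are not swift's actual Template API.

from contextlib import contextmanager


class DummyTemplate:
    """Hypothetical stand-in for swift's Template; only the attributes the patch touches."""

    def __init__(self):
        self.mode = 'vllm'      # inference mode previously set for generation
        self.max_length = 2048  # truncation limit used when encoding prompts

    def set_mode(self, mode):
        self.mode = mode


@contextmanager
def template_context(template):
    # Save the current state, switch to training-style encoding without truncation,
    # and guarantee restoration even if encoding raises.
    saved_mode, saved_max_length = template.mode, template.max_length
    if saved_mode in {'vllm', 'pt', 'lmdeploy'}:
        template.set_mode('train')
    template.max_length = None
    try:
        yield template
    finally:
        template.set_mode(saved_mode)
        template.max_length = saved_max_length


if __name__ == '__main__':
    tpl = DummyTemplate()
    with template_context(tpl):
        assert tpl.mode == 'train' and tpl.max_length is None
    # Both attributes are restored afterwards; the replaced code hard-coded
    # set_mode('pt') as the recovery mode and skipped recovery if encode() raised.
    assert tpl.mode == 'vllm' and tpl.max_length == 2048

The finally block is what distinguishes this from the removed inline sequence, which switched to 'train', encoded, and then unconditionally called set_mode('pt') regardless of the original mode and only on the success path.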