28 changes: 22 additions & 6 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -2,6 +2,7 @@
 # Part of the implementation is borrowed from huggingface/trl.
 import inspect
 from collections import defaultdict
+from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union
 from unittest.mock import patch

@@ -140,6 +141,21 @@ def __init__(self,
                 self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
         self.log_completions = args.log_completions

+    @staticmethod
+    @contextmanager
+    def _template_context(template):
+        # max_length has already been enforced for the prompt and completion, so it is not needed here.
+        max_length = template.max_length
+        mode = template.mode
+        if mode in {'vllm', 'pt', 'lmdeploy'}:
+            template.set_mode('train')
+        template.max_length = None
+        try:
+            yield
+        finally:
+            template.set_mode(mode)
+            template.max_length = max_length
+
     def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device

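The new helper is a standard save/restore context manager: it snapshots the template's mode and max_length, switches into 'train' mode with the length cap removed for encoding, and restores both in a finally block. Below is a minimal, self-contained sketch of the same pattern; the Template class here is a stand-in with only the two attributes the helper touches, not swift's implementation.

from contextlib import contextmanager

class Template:
    # Stand-in: only mode and max_length matter to the helper.
    def __init__(self, mode='pt', max_length=2048):
        self.mode = mode
        self.max_length = max_length

    def set_mode(self, mode):
        self.mode = mode

@contextmanager
def template_context(template):
    # Snapshot the state, relax it for encoding, and always restore it.
    mode, max_length = template.mode, template.max_length
    if mode in {'vllm', 'pt', 'lmdeploy'}:
        template.set_mode('train')
    template.max_length = None
    try:
        yield
    finally:
        # Runs even if the body raises, so no state can leak.
        template.set_mode(mode)
        template.max_length = max_length

template = Template()
with template_context(template):
    assert template.mode == 'train' and template.max_length is None
assert template.mode == 'pt' and template.max_length == 2048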
@@ -191,10 +207,9 @@ def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:
             InferRequest.remove_response(messages)
             messages.append({'role': 'assistant', 'content': output.choices[0].message.content})

-        self.template.set_mode('train')
-        batched_inputs = [self.template.encode(infer_request) for infer_request in inputs]
-        outputs = to_device(self.template.data_collator(batched_inputs), self.model.device)
-        self.template.set_mode('pt')  # recover
+        with self._template_context(self.template):
+            batched_inputs = [self.template.encode(infer_request) for infer_request in inputs]
+            outputs = to_device(self.template.data_collator(batched_inputs), self.model.device)

         # we only need to compute the logits for the completion tokens
         labels = outputs.pop('labels')
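The removed lines restored the mode with a second set_mode call after encoding, so an exception raised inside encode or the data collator skipped the recovery line and left the template stuck in 'train' mode. A small sketch of that failure mode; the Template and encode bodies are stand-ins, not the swift code.

class Template:
    # Stand-in: encode() fails to show what happens to the saved state.
    def __init__(self):
        self.mode = 'pt'

    def set_mode(self, mode):
        self.mode = mode

    def encode(self, request):
        raise ValueError('encoding failed')

template = Template()
try:
    template.set_mode('train')
    template.encode(object())    # raises here
    template.set_mode('pt')      # the old "recover" line is never reached
except ValueError:
    pass
print(template.mode)             # prints 'train': the mode has leaked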
@@ -217,8 +232,9 @@ def _prepare_inputs(self, inputs) -> Dict[str, Union[torch.Tensor, Any]]:

         for i, (reward_func, reward_template) in enumerate(zip(self.reward_funcs, self.reward_templates)):
             if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
-                batched_inputs = [reward_template.encode(infer_request) for infer_request in inputs]
-                reward_inputs = to_device(reward_template.data_collator(batched_inputs), reward_func.device)
+                with self._template_context(reward_template):
+                    batched_inputs = [reward_template.encode(infer_request) for infer_request in inputs]
+                    reward_inputs = to_device(reward_template.data_collator(batched_inputs), reward_func.device)

                 with torch.inference_mode():
                     rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]
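Because _template_context is a staticmethod taking the template as an argument, the reward branch reuses it as-is, and each reward_template is restored after its encoding pass. A hedged sketch of the scoring step's overall shape; aside from torch, every name here is a stand-in mirroring the diff rather than the swift API.

import torch

def score_rewards(inputs, reward_funcs, reward_templates, template_context):
    # Shape sketch only: encode each request under the context manager,
    # collate, then read the first logit of each sequence as the reward.
    rewards = torch.zeros(len(inputs), len(reward_funcs))
    for i, (func, tmpl) in enumerate(zip(reward_funcs, reward_templates)):
        with template_context(tmpl):
            batch = [tmpl.encode(request) for request in inputs]
            reward_inputs = tmpl.data_collator(batch)  # assumed to return a dict of tensors
        with torch.inference_mode():
            rewards[:, i] = func(**reward_inputs).logits[:, 0]
    return rewards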