diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py index cf8745de74..b43cd401b2 100644 --- a/swift/plugin/orm.py +++ b/swift/plugin/orm.py @@ -296,12 +296,14 @@ def __call__(self, completions, **kwargs) -> List[float]: class CosineReward(ORM): # https://arxiv.org/abs/2502.03373 def __init__(self, + tokenizer=None, cosine_min_len_value_wrong: float = 0.0, cosine_max_len_value_wrong: float = -0.5, cosine_min_len_value_correct: float = 1.0, cosine_max_len_value_correct: float = 0.5, cosine_max_len: int = 1000, accuracy_orm=None): + self.tokenizer = tokenizer self.min_len_value_wrong = cosine_min_len_value_wrong self.max_len_value_wrong = cosine_max_len_value_wrong self.min_len_value_correct = cosine_min_len_value_correct @@ -326,7 +328,7 @@ def __call__(self, completions, solution, **kwargs) -> List[float]: else: min_value = self.min_len_value_wrong max_value = self.max_len_value_wrong - gen_len = len(content) + gen_len = len(self.tokenizer.encode(content)) reward = self.cosfn(gen_len, self.max_len, min_value, max_value) rewards.append(reward) return rewards diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 1fdf5034d6..25e18e441e 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -71,6 +71,7 @@ def __init__(self, self.args = args self.queue = Queue() self.processing_class = kwargs.get('template').tokenizer + self.tokenizer = kwargs.get('template').tokenizer if hasattr(kwargs.get('template'), 'tokenizer') else None if not isinstance(reward_funcs, list): reward_funcs = [reward_funcs] @@ -83,6 +84,8 @@ def __init__(self, key: getattr(args, key) for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key) } + if reward_func_class.__name__ == 'CosineReward' and 'tokenizer' in reward_func_args: + reward_func_kwargs['tokenizer'] = self.tokenizer reward_funcs[i] = reward_func_class(**reward_func_kwargs) elif not callable(reward_func): raise ValueError(f'reward_function {reward_func} is not implemented in swift.llm.plugin')