diff --git a/applications/ChatGPT/chatgpt/trainer/ppo.py b/applications/ChatGPT/chatgpt/trainer/ppo.py
index 789e0c2f8f1e..dacab4784039 100644
--- a/applications/ChatGPT/chatgpt/trainer/ppo.py
+++ b/applications/ChatGPT/chatgpt/trainer/ppo.py
@@ -63,6 +63,7 @@ def __init__(self,
                  **generate_kwargs) -> None:
         experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
         replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
+        generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
         super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
                          sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
         self.actor = actor
@@ -73,7 +74,6 @@ def __init__(self,
 
         self.actor_optim = actor_optim
         self.critic_optim = critic_optim
-        self._set_default_generate_kwargs(generate_kwargs, actor)
 
     def training_step(self, experience: Experience) -> Dict[str, float]:
         self.actor.train()
@@ -102,11 +102,22 @@ def training_step(self, experience: Experience) -> Dict[str, float]:
 
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
 
 
-    def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
-        origin_model = self.strategy._unwrap_actor(actor)
-        # use huggingface models method directly
-        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
-            generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
-        if 'update_model_kwargs_fn' not in generate_kwargs:
-            generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
+def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> dict:
+    """Return a copy of ``generate_kwargs`` with default generation hooks filled in.
+
+    The caller's dict is not modified. ``prepare_inputs_fn`` defaults to the unwrapped
+    actor model's ``prepare_inputs_for_generation`` when that attribute exists, and
+    ``update_model_kwargs_fn`` defaults to the ``update_model_kwargs_fn`` name visible
+    in this module. Keys already present in ``generate_kwargs`` are kept as given.
+    """
+    origin_model = strategy._unwrap_actor(actor)
+    new_kwargs = {**generate_kwargs}
+    # use huggingface models method directly
+    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
+        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
+
+    if 'update_model_kwargs_fn' not in generate_kwargs:
+        new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
+
+    return new_kwargs