[chatgpt] fix trainer generate kwargs (#3166)
ver217 committed Mar 17, 2023
1 parent c474fda commit 1e58d31
Showing 1 changed file with 12 additions and 8 deletions.
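
Why the fix is needed: the old code set the generate-kwargs defaults only after `super().__init__` had already consumed `**generate_kwargs`, so the base trainer never saw `prepare_inputs_fn` or `update_model_kwargs_fn`. Because `**` expansion hands the callee its own collected dict, mutating the caller's dict afterwards has no effect. A minimal sketch of that pitfall; `Base`, `BuggyChild`, and `FixedChild` are illustrative names, not repo code:

```python
# Minimal sketch of the kwargs-expansion pitfall this commit fixes.
# Base, BuggyChild, and FixedChild are hypothetical, not from the repo.

class Base:
    def __init__(self, **generate_kwargs) -> None:
        # ** collection builds a fresh dict for the callee,
        # detached from whatever dict the caller expanded.
        self.generate_kwargs = generate_kwargs

class BuggyChild(Base):
    def __init__(self, **generate_kwargs) -> None:
        super().__init__(**generate_kwargs)                 # base snapshots the kwargs here
        generate_kwargs['update_model_kwargs_fn'] = print   # too late: snapshot unchanged

class FixedChild(Base):
    def __init__(self, **generate_kwargs) -> None:
        generate_kwargs.setdefault('update_model_kwargs_fn', print)  # defaults first
        super().__init__(**generate_kwargs)

assert 'update_model_kwargs_fn' not in BuggyChild().generate_kwargs
assert 'update_model_kwargs_fn' in FixedChild().generate_kwargs
```

The commit applies exactly this ordering: defaults are resolved by a standalone `_set_default_generate_kwargs` before `super().__init__` receives them.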
applications/ChatGPT/chatgpt/trainer/ppo.py (20 changes: 12 additions & 8 deletions)
@@ -63,6 +63,7 @@ def __init__(self,
                  **generate_kwargs) -> None:
         experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
         replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
+        generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
         super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
                          sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
         self.actor = actor
@@ -73,7 +74,6 @@ def __init__(self,
 
         self.actor_optim = actor_optim
         self.critic_optim = critic_optim
-        self._set_default_generate_kwargs(generate_kwargs, actor)
 
     def training_step(self, experience: Experience) -> Dict[str, float]:
         self.actor.train()
@@ -102,11 +102,15 @@ def training_step(self, experience: Experience) -> Dict[str, float]:
 
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
 
-    def _set_default_generate_kwargs(self, generate_kwargs: dict, actor: Actor) -> None:
-        origin_model = self.strategy._unwrap_actor(actor)
-        # use huggingface models method directly
-        if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
-            generate_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
+
+def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> dict:
+    origin_model = strategy._unwrap_actor(actor)
+    new_kwargs = {**generate_kwargs}
+    # use huggingface models method directly
+    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
+        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
 
-        if 'update_model_kwargs_fn' not in generate_kwargs:
-            generate_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
+    if 'update_model_kwargs_fn' not in generate_kwargs:
+        new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
+
+    return new_kwargs
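As a quick sanity check of the new module-level helper, a hedged usage sketch; `DummyStrategy` and `DummyActor` below are hypothetical stand-ins for the repo's real `Strategy` and `Actor` classes:

```python
# Hypothetical stand-ins, just to exercise the call pattern after the fix.

def update_model_kwargs_fn(model_kwargs, **kwargs):
    return model_kwargs  # placeholder for the repo's default

class DummyActor:            # no prepare_inputs_for_generation attribute
    pass

class DummyStrategy:
    def _unwrap_actor(self, actor):
        return actor         # real strategies unwrap DDP/ZeRO wrappers here

def _set_default_generate_kwargs(strategy, generate_kwargs, actor) -> dict:
    origin_model = strategy._unwrap_actor(actor)
    new_kwargs = {**generate_kwargs}                 # copy, never mutate the input
    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
    if 'update_model_kwargs_fn' not in generate_kwargs:
        new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
    return new_kwargs

kwargs = _set_default_generate_kwargs(DummyStrategy(), {'max_length': 128}, DummyActor())
assert kwargs['max_length'] == 128                   # user kwargs preserved
assert 'update_model_kwargs_fn' in kwargs            # default filled in
assert 'prepare_inputs_fn' not in kwargs             # actor lacks the HF hook
```

Returning a fresh dict keeps the helper side-effect-free, and moving it out of the class is what allows it to run before `super().__init__` in the constructor.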
