From 138e3d20c8d0856427db766e1e8241cdfe6ddbc3 Mon Sep 17 00:00:00 2001
From: yaozhewei
Date: Sat, 1 Jul 2023 01:59:30 +0500
Subject: [PATCH] make training more stable

---
 .../step3_rlhf_finetuning/ppo_trainer.py      | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py
index 2e730aed3..da088bb85 100644
--- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py
+++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py
@@ -58,7 +58,7 @@ def __init__(self, rlhf_engine, args):
             args.end_of_conversation_token)['input_ids'][-1]
 
         # Those value can be changed
-        self.kl_ctl = 0.02
+        self.kl_ctl = 0.1
         self.clip_reward_value = 5
         self.cliprange = 0.2
         self.cliprange_value = 0.2
@@ -70,17 +70,20 @@ def _generate_sequence(self, prompts, mask):
         max_min_length = self.max_answer_seq_len + prompts.shape[1]
 
         with torch.no_grad():
-            seq = self.actor_model.module.generate(prompts,
-                                                   attention_mask=mask,
-                                                   max_length=max_min_length,
-                                                   min_length=max_min_length)
+            seq = self.actor_model.module.generate(
+                prompts,
+                attention_mask=mask,
+                max_length=max_min_length,
+                pad_token_id=self.tokenizer.pad_token_id,
+                # min_length=max_min_length
+            )
 
         # Filter out seq with no answers (or very short). This happens when users directly use the pre-training ckpt without supervised finetuning
         # NOTE: this will causes each GPU has different number of examples
         batch_size = seq.shape[0]
         prompt_length = prompts.shape[1]
-        ans = seq[:, prompt_length:]
         self.prompt_length = prompt_length
+        ans = seq[:, prompt_length:]
         valid_ans_len = (ans != self.tokenizer.pad_token_id).sum(dim=-1)
         out_seq = []
         for i in range(batch_size):
@@ -100,7 +103,6 @@ def generate_experience(self, prompts, mask):
 
         pad_token_id = self.tokenizer.pad_token_id
         attention_mask = seq.not_equal(pad_token_id).long()
-
         with torch.no_grad():
             output = self.actor_model(seq, attention_mask=attention_mask)
             output_ref = self.ref_model(seq, attention_mask=attention_mask)
@@ -131,7 +133,7 @@ def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
         kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
         rewards = kl_divergence_estimate
         start = prompts.shape[1] - 1
-        ends = start + action_mask[:, start:].sum(1)
+        ends = start + action_mask[:, start:].sum(1) + 1
         reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                                   self.clip_reward_value)
         batch_size = log_probs.shape[0]
@@ -159,6 +161,12 @@ def train_rlhf(self, inputs):
         old_rewards = self.compute_rewards(prompts, log_probs,
                                            ref_log_probs, reward_score,
                                            action_mask)
+        ends = start + action_mask[:, start:].sum(1) + 1
+        # we need to zero out the reward and value after the end of the conversation
+        # otherwise the advantage/return will be wrong
+        for i in range(old_rewards.shape[0]):
+            old_rewards[i, ends[i]:] = 0
+            old_values[i, ends[i]:] = 0
         advantages, returns = self.get_advantages_and_returns(
             old_values, old_rewards, start)
 
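Below is a minimal, self-contained sketch (not part of the patch; the tensor shapes and values are invented for illustration) of what the `+ 1` on `ends` and the new zeroing loop in `train_rlhf` do. It assumes the same shifted `action_mask = attention_mask[:, 1:]` convention the trainer already uses.

# Illustrative only: toy shapes/values, not the DeepSpeed-Chat runtime data.
import torch

prompt_len = 3
# attention_mask over prompt + generated answer; trailing zeros are padding.
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1, 1, 0],   # answer fills 3 of 4 answer slots
    [1, 1, 1, 1, 1, 0, 0],   # answer fills 2 of 4 answer slots
])
action_mask = attention_mask[:, 1:]      # shifted mask, as in the trainer

start = prompt_len - 1
# Without the "+ 1", `ends` points at the last valid step instead of one past
# it, so `rewards[j, start:ends[j]][-1]` in compute_rewards would add the
# clipped reward-model score to the second-to-last generated token.
ends = start + action_mask[:, start:].sum(1) + 1
print(ends)                              # tensor([6, 5])

old_rewards = torch.randn_like(action_mask, dtype=torch.float)
old_values = torch.randn_like(action_mask, dtype=torch.float)

# Zero out reward/value on padded steps after each conversation ends so they
# cannot leak into the advantage/return recursion.
for i in range(old_rewards.shape[0]):
    old_rewards[i, ends[i]:] = 0
    old_values[i, ends[i]:] = 0

With the extra `+ 1`, the terminal reward lands on the final generated token, and the zeroing keeps padded positions out of `get_advantages_and_returns`, which is what the patch relies on for more stable PPO updates.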