diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index 5bdeba9608..c0509a77f8 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -122,6 +122,14 @@ class RLHFArguments:
         default="lora",
         metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
     )
+    ppo_use_separate_value_model: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Use a separate value model which does not share parameters with policy."}
+    )
+    value_model: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the directory containing the checkpoints of the value model."}
+    )
 
 
 @dataclass
diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py
index acd78b0ef5..befeac7ed3 100644
--- a/src/llmtuner/train/ppo/trainer.py
+++ b/src/llmtuner/train/ppo/trainer.py
@@ -3,6 +3,7 @@
 import math
 import torch
 from tqdm import tqdm
+from types import MethodType
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
@@ -296,7 +297,16 @@ def batched_forward_pass(
             attention_mask = input_kwargs["attention_mask"]
 
             with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
-                logits, _, values = model(**input_kwargs)
+                unwrapped_model = self.accelerator.unwrap_model(model)
+                if "value" in unwrapped_model.pretrained_model.peft_config:
+                    # this model has a separate value model and policy model
+                    unwrapped_model.pretrained_model.set_adapter("value")
+                    _, _, values = model(**input_kwargs)
+                    unwrapped_model.pretrained_model.set_adapter("default")
+                    logits, _, _ = model(**input_kwargs)
+                else:
+                    # this model has a shared value model and policy model
+                    logits, _, values = model(**input_kwargs)
 
             if values.size(0) != input_ids.size(0): # adapt to chatglm2
                 values = torch.transpose(values, 0, 1)
diff --git a/src/llmtuner/train/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py
index 88d5e49d3f..0307ff095b 100644
--- a/src/llmtuner/train/ppo/workflow.py
+++ b/src/llmtuner/train/ppo/workflow.py
@@ -1,7 +1,10 @@
 # Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
 
 import math
+import os
+from peft import TaskType, LoraConfig
 from trl import PPOConfig
+import torch
 from torch.optim import AdamW
 from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorWithPadding
@@ -29,6 +32,22 @@ def run_ppo(
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
+    if finetuning_args.ppo_use_separate_value_model:
+        if finetuning_args.value_model is not None:
+            model.pretrained_model.load_adapter(finetuning_args.value_model, "value", is_trainable=True)
+            state_dict = torch.load(os.path.join(finetuning_args.value_model, "pytorch_model.bin"), map_location="cpu")
+            model.load_state_dict(state_dict, strict=False)
+        else:
+            lora_config = LoraConfig(
+                task_type=TaskType.CAUSAL_LM,
+                inference_mode=False,
+                r=finetuning_args.lora_rank,
+                lora_alpha=finetuning_args.lora_alpha,
+                lora_dropout=finetuning_args.lora_dropout,
+                target_modules=finetuning_args.lora_target,
+                modules_to_save=finetuning_args.additional_target
+            )
+            model.pretrained_model.add_adapter("value", lora_config)
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
     tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training