From f3a532f56b4aa7d4200f24d93fade4b2c9042736 Mon Sep 17 00:00:00 2001
From: tsingcwang
Date: Fri, 8 Sep 2023 17:06:58 +0800
Subject: [PATCH] add Fine-tuning method: AdaLoRA

---
 src/llmtuner/hparams/finetuning_args.py | 12 +++++++--
 src/llmtuner/tuner/core/adapter.py      | 36 ++++++++++++++++++-------
 2 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index bda0adafc..91197e3ad 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -8,7 +8,7 @@ class FinetuningArguments:
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
-    finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
+    finetuning_type: Optional[Literal["lora", "adalora", "freeze", "full", "none"]] = field(
         default="lora",
         metadata={"help": "Which fine-tuning method to use."}
     )
@@ -37,6 +37,14 @@ class FinetuningArguments:
                   Qwen choices: [\"mlp\", \"attn\"], \
                   LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
     )
+    target_r: Optional[int] = field(
+        default=8,
+        metadata={"help": "The target average rank of incremental matrix for AdaLoRA fine-tuning."}
+    )
+    init_r: Optional[int] = field(
+        default=12,
+        metadata={"help": "The initial rank for each incremental matrix for AdaLoRA fine-tuning."}
+    )
     lora_rank: Optional[int] = field(
         default=8,
         metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
@@ -82,7 +90,7 @@ def __post_init__(self):
 
         self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
 
-        assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
+        assert self.finetuning_type in ["lora", "adalora", "freeze", "full", "none"], "Invalid fine-tuning method."
 
     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""
diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py
index 5db568763..edb4f29ce 100644
--- a/src/llmtuner/tuner/core/adapter.py
+++ b/src/llmtuner/tuner/core/adapter.py
@@ -6,6 +6,7 @@
     PeftModel,
     TaskType,
     LoraConfig,
+    AdaLoraConfig,
     get_peft_model
 )
 from peft.utils import CONFIG_NAME, WEIGHTS_NAME
@@ -55,8 +56,11 @@ def init_adapter(
         if model_args.checkpoint_dir is not None:
             assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
 
-    if finetuning_args.finetuning_type == "lora":
-        logger.info("Fine-tuning method: LoRA")
+    if finetuning_args.finetuning_type == "lora" or finetuning_args.finetuning_type == "adalora":
+        if finetuning_args.finetuning_type == "lora":
+            logger.info("Fine-tuning method: LoRA")
+        if finetuning_args.finetuning_type == "adalora":
+            logger.info("Fine-tuning method: AdaLoRA")
         latest_checkpoint = None
 
         if model_args.checkpoint_dir is not None:
@@ -81,14 +85,26 @@
                 model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
 
         if is_trainable and latest_checkpoint is None: # create new lora weights while training
-            lora_config = LoraConfig(
-                task_type=TaskType.CAUSAL_LM,
-                inference_mode=False,
-                r=finetuning_args.lora_rank,
-                lora_alpha=finetuning_args.lora_alpha,
-                lora_dropout=finetuning_args.lora_dropout,
-                target_modules=finetuning_args.lora_target
-            )
+            if finetuning_args.finetuning_type == "lora":
+                lora_config = LoraConfig(
+                    task_type=TaskType.CAUSAL_LM,
+                    inference_mode=False,
+                    r=finetuning_args.lora_rank,
+                    lora_alpha=finetuning_args.lora_alpha,
+                    lora_dropout=finetuning_args.lora_dropout,
+                    target_modules=finetuning_args.lora_target
+                )
+            if finetuning_args.finetuning_type == "adalora":
+                lora_config = AdaLoraConfig(
+                    task_type=TaskType.CAUSAL_LM,
+                    inference_mode=False,
+                    target_r=finetuning_args.target_r,
+                    init_r=finetuning_args.init_r,
+                    r=finetuning_args.lora_rank,
+                    lora_alpha=finetuning_args.lora_alpha,
+                    lora_dropout=finetuning_args.lora_dropout,
+                    target_modules=finetuning_args.lora_target
+                )
             model = get_peft_model(model, lora_config)
 
         if model_args.checkpoint_dir is not None:
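
Note on the two new hyperparameters: the sketch below shows how target_r and init_r map onto peft's AdaLoraConfig, independent of the trainer code touched by this patch. It is a minimal illustration only; the base model (facebook/opt-125m) and the target_modules, lora_alpha, lora_dropout, and total_step values are placeholders chosen for the example, not values taken from the patch.

from transformers import AutoModelForCausalLM
from peft import AdaLoraConfig, TaskType, get_peft_model

# Any causal LM whose attention projections are nn.Linear will do; opt-125m is a small stand-in.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

adalora_config = AdaLoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    init_r=12,         # initial rank of every incremental matrix (the new init_r argument, default 12)
    target_r=8,        # target average rank after AdaLoRA's budget allocator prunes ranks (the new target_r argument, default 8)
    lora_alpha=32,     # placeholder; the patch passes finetuning_args.lora_alpha here
    lora_dropout=0.1,  # placeholder; the patch passes finetuning_args.lora_dropout here
    target_modules=["q_proj", "v_proj"],  # placeholder; the patch passes finetuning_args.lora_target here
    total_step=1000,   # training steps for the rank-budget schedule; some peft releases require this to be set
)

model = get_peft_model(model, adalora_config)  # wraps the base model with trainable AdaLoRA (SVD-form) adapters
model.print_trainable_parameters()

AdaLoraConfig inherits the usual LoRA fields and additionally exposes the budget-schedule knobs tinit, tfinal, and deltaT; the patch leaves those at their peft defaults and only surfaces target_r and init_r through FinetuningArguments.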