diff --git a/examples/lisa_single_gpu/sft.sh b/examples/lisa_single_gpu/sft.sh
new file mode 100644
index 000000000..be82cc163
--- /dev/null
+++ b/examples/lisa_single_gpu/sft.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python ../../src/train_bash.py \
+    --stage sft \
+    --do_train \
+    --model_name_or_path meta-llama/Llama-2-7b-hf \
+    --dataset alpaca_gpt4_en,glaive_toolcall \
+    --dataset_dir ../../data \
+    --template default \
+    --finetuning_type full \
+    --use_lisa \
+    --lisa_activated_layers 2 \
+    --lisa_interval_steps 5 \
+    --output_dir ../../saves/LLaMA2-7B/lisa/sft \
+    --overwrite_cache \
+    --overwrite_output_dir \
+    --cutoff_len 1024 \
+    --preprocessing_num_workers 16 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --warmup_steps 20 \
+    --save_steps 100 \
+    --eval_steps 100 \
+    --evaluation_strategy steps \
+    --load_best_model_at_end \
+    --learning_rate 5e-5 \
+    --num_train_epochs 3.0 \
+    --max_samples 3000 \
+    --val_size 0.1 \
+    --plot_loss \
+    --fp16
+
diff --git a/src/llmtuner/extras/callbacks.py b/src/llmtuner/extras/callbacks.py
index 6e347c3c4..fd3e36ee8 100644
--- a/src/llmtuner/extras/callbacks.py
+++ b/src/llmtuner/extras/callbacks.py
@@ -2,23 +2,96 @@
 import os
 import time
 from datetime import timedelta
+from functools import reduce
 from typing import TYPE_CHECKING
 
+import numpy as np
 from transformers import TrainerCallback
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
 
 from .constants import LOG_FILE_NAME
 from .logging import get_logger
-from .misc import fix_valuehead_checkpoint
-
+from .misc import fix_valuehead_checkpoint, count_parameters
+from ..hparams import FinetuningArguments
 
 if TYPE_CHECKING:
     from transformers import TrainerControl, TrainerState, TrainingArguments
 
-
 logger = get_logger(__name__)
 
 
+class LisaTrainCallback(TrainerCallback):
+    def __init__(self, finetuning_args: "FinetuningArguments", trainer):
+        super().__init__()
+        self.trainer = trainer
+        self.layers_attribute = self.attention_layer_auto_detect(finetuning_args.lisa_attention_name)
+        self.step_interval = finetuning_args.lisa_interval_steps
+        self.lisa_activated_layers = finetuning_args.lisa_activated_layers
+        self.total_layers = len(self.get_layers())
+        self.lisa_verbose = finetuning_args.lisa_verbose
+        self.trained_layers = set()
+        if self.lisa_activated_layers > self.total_layers:
+            raise ValueError(
+                f"lisa_activated_layers ({self.lisa_activated_layers}) exceeds the number of "
+                f"decoder layers ({self.total_layers}), please check your arguments.")
+        logger.info(
+            f"LISA will activate {self.lisa_activated_layers}/{self.total_layers} layers "
+            f"({self.lisa_activated_layers * 100 / self.total_layers:.1f}%) every {self.step_interval} steps"
+        )
+
+    def attention_layer_auto_detect(self, lisa_attention_name):
+        class_to_layers_map = {
+            'LlamaForCausalLM': 'model.layers',
+            'Qwen2ForCausalLM': 'model.layers',
+            'MistralForCausalLM': 'model.layers',
+            'MixtralForCausalLM': 'model.layers',
+            'GemmaForCausalLM': 'model.layers',
+            'GPT2LMHeadModel': 'transformer.h',
+        }
+        _atten_val = lisa_attention_name
+        model_class_name = self.trainer.model.__class__.__name__
+        if _atten_val is None:
+            # determine how to access the decoder layers based on the model class
+            if model_class_name in class_to_layers_map:
+                _atten_val = class_to_layers_map[model_class_name]
+
+        return _atten_val
+
+    def on_step_begin(self, args, state, control, **kwargs):
+        if state.global_step % self.step_interval == 0:
+            self.switch_active_layers()
+
+    def freeze_all_layers(self):
+        layers = self.get_layers()
+        for layer in layers:
+            for param in layer.parameters():
+                param.requires_grad = False
+
+    def get_layers(self):
+        return reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
+
+    def switch_active_layers(self):
+        # disable gradients for all decoder layers, then re-enable a random subset
+        self.freeze_all_layers()
+        layers = self.get_layers()
+        active_layers_indices = np.random.choice(range(self.total_layers), self.lisa_activated_layers,
+                                                 replace=False)
+        self.trained_layers.update(active_layers_indices)
+        for idx in active_layers_indices:
+            for param in layers[idx].parameters():
+                param.requires_grad = True
+        if self.lisa_verbose:
+            trainable_params, all_param = count_parameters(self.trainer.model)
+            logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+                trainable_params, all_param, 100 * trainable_params / all_param
+            ))
+            logger.info(
+                f"LISA will activate layers {','.join(map(str, sorted(active_layers_indices)))} for the next {self.step_interval} steps. "
+                f"{len(self.trained_layers)}/{self.total_layers} layers "
+                f"({len(self.trained_layers) * 100 / self.total_layers:.1f}%) "
+                f"have been trained so far: {','.join(map(str, sorted(self.trained_layers)))}")
+
+
 class FixValueHeadModelCallback(TrainerCallback):
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
@@ -107,7 +180,7 @@ def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control:
         self.max_steps = 0
 
     def on_predict(
-        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
+            self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs
     ):
         r"""
         Event called after a successful prediction.
@@ -153,7 +226,7 @@ def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "Tra
             f.write(json.dumps(logs) + "\n")
 
     def on_prediction_step(
-        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
+            self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
     ):
         r"""
         Event called after a prediction step.
diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index 177a9f8a5..678c30996 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -204,7 +204,45 @@ class GaloreArguments:
 
 
 @dataclass
-class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments):
+class LisaArguments:
+    r"""
+    paper: https://arxiv.org/abs/2403.17919
+    ref: https://github.com/OptimalScale/LMFlow
+    Arguments pertaining to the LISA algorithm:
+    - the bottom embedding layer and the top linear head are always updated;
+    - only a few intermediate self-attention (decoder) layers, e.g. 2-4, are randomly selected for updates.
+    """
+    use_lisa: bool = field(
+        default=False,
+        metadata={
+            "help": "whether or not to use the LISA training strategy."
+        }
+    )
+    lisa_activated_layers: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "the number of decoder layers activated (unfrozen) at a time in LISA."
+        }
+    )
+    lisa_interval_steps: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "the number of steps in each freezing interval of LISA, i.e. "
+                    "the set of unfrozen layers is randomly re-sampled every `lisa_interval_steps` steps."
+        }
+    )
+    lisa_attention_name: str = field(
+        default="model.layers",
+        metadata={"help": "attribute path to the model's decoder layers, e.g. `model.layers`; auto-detected if set to None."}
+    )
+    lisa_verbose: bool = field(
+        default=False,
+        metadata={"help": "log trainable parameter statistics whenever LISA switches the active layers."},
+    )
+
+
+@dataclass
+class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, LisaArguments):
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
@@ -261,6 +299,12 @@ def split_arg(arg):
         if self.use_galore and self.finetuning_type == "lora":
             raise ValueError("Cannot use LoRA with GaLore together.")
 
+        if self.use_lisa:
+            if self.finetuning_type != 'full':
+                raise ValueError("`use_lisa` requires `finetuning_type` to be `full`.")
+            if self.lisa_interval_steps is None or self.lisa_activated_layers is None:
+                raise ValueError("`use_lisa` requires both `lisa_interval_steps` and `lisa_activated_layers` to be set.")
+
     def save_to_json(self, json_path: str):
         r"""Saves the content of this instance in JSON format inside `json_path`."""
         json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
diff --git a/src/llmtuner/train/sft/workflow.py b/src/llmtuner/train/sft/workflow.py
index 9ab78850b..a38c3538f 100644
--- a/src/llmtuner/train/sft/workflow.py
+++ b/src/llmtuner/train/sft/workflow.py
@@ -5,6 +5,7 @@
 from transformers import DataCollatorForSeq2Seq
 
 from ...data import get_dataset, split_dataset
+from ...extras.callbacks import LisaTrainCallback
 from ...extras.constants import IGNORE_INDEX
 from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
@@ -13,7 +14,6 @@
 from .metric import ComputeMetrics
 from .trainer import CustomSeq2SeqTrainer
 
-
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
 
@@ -60,6 +60,10 @@ def run_sft(
         **split_dataset(dataset, data_args, training_args),
     )
 
+    # register the LISA callback after the trainer is built, since it needs a reference to the trainer
+    if finetuning_args.use_lisa:
+        trainer.add_callback(LisaTrainCallback(finetuning_args, trainer))
+
     # Keyword arguments for `model.generate`
     gen_kwargs = generating_args.to_dict()
     gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
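
For illustration only (not part of the patch), here is a minimal standalone sketch of the layer-switching idea behind `LisaTrainCallback`, run against a plain Hugging Face GPT-2 model outside of llmtuner; the model choice and the `n_active` value are arbitrary assumptions made for the demo:

# minimal LISA-style layer-switching sketch (standalone demo, not llmtuner code)
import numpy as np
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
layers = model.transformer.h      # decoder blocks; embeddings and lm_head stay trainable
n_active = 2                      # analogous to --lisa_activated_layers


def switch_active_layers():
    # freeze every decoder block first, mirroring freeze_all_layers()
    for layer in layers:
        for param in layer.parameters():
            param.requires_grad = False
    # then unfreeze a random subset, mirroring switch_active_layers()
    picked = np.random.choice(len(layers), n_active, replace=False)
    for idx in picked:
        for param in layers[idx].parameters():
            param.requires_grad = True
    return sorted(picked)


print("active layers:", switch_active_layers())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable} / {total} ({100 * trainable / total:.1f}%)")

In the patch above, the same logic is driven by `on_step_begin`, so the active subset is re-sampled every `lisa_interval_steps` optimizer steps, while the embedding and head parameters keep their gradients enabled throughout because only the layers reached via `lisa_attention_name` are ever frozen.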