diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 5472a8d0da..0a88887859 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -127,6 +127,7 @@ - 🔥freeze_llm: 冻结LLM. 默认为False. 可用于全参和LoRA - 🔥target_modules: 指定lora模块, 默认为`all-linear`, 自动寻找除lm_head外的linear并附加tuner. 该参数不限于LoRA - 🔥target_regex: 指定lora模块的regex表达式. 默认为`None`, 如果该值传入, 则target_modules不生效. 该参数不限于LoRA +- 🔥init_weights: 初始化weights的方法, LoRA可以指定为`true`, `false`, `guassian`, `pissa`, `pissa_niter_[number of iters]`, Bone可以指定为`true`, `false`, `bat`, 默认值`true` - modules_to_save: 在已附加tuner后,原模型参与训练和存储的模块,默认为`[]`. 该参数不限于LoRA #### 全参 @@ -138,7 +139,6 @@ - 🔥lora_rank: 默认为`8` - 🔥lora_alpha: 默认为`32` - lora_dropout: 默认为`0.05` -- 🔥init_lora_weights: 初始化LoRA weights的方法, 可以指定为`true`, `false`, `guassian`, `pissa`, `pissa_niter_[number of iters]`, 默认值`true` - lora_bias: 默认为`'none'`, 可以选择的值: 'none', 'all'. 如果你要将bias全都设置为可训练, 你可以设置为`'all'` - lora_dtype: 指定lora模块的dtype类型. 支持'float16', 'bfloat16', 'float32',不设置默认跟随原模型类型 - 🔥use_dora: 默认为`False`, 是否使用`DoRA` diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 4ff6d2e4bc..59daefb686 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -130,6 +130,7 @@ Other important parameters: - 🔥freeze_llm: Freeze LLM. Default is False. Applicable for full parameters and LoRA. - 🔥target_modules: Specify the LoRA module, default is `all-linear`, automatically finds linear layers except for lm_head and attaches the tuner. This parameter is not limited to LoRA. - 🔥target_regex: Specify a regex expression for the LoRA module. Default is `None`, if this value is provided, target_modules does not take effect. This parameter is not limited to LoRA. +- 🔥init_weights: The method of init tuner weights, For lora the accepted values are `true`, `false`, `guassian`, `pissa`, `pissa_niter_[number of iters]`, for bone are `true`, `false`, `bat`, default is `true` - modules_to_save: After the tuner is attached, the original model's modules used during training and storage, default is `[]`. This parameter is not limited to LoRA. #### Full Arguments @@ -143,7 +144,6 @@ Other important parameters: - 🔥lora_rank: Default is `8`. - 🔥lora_alpha: Default is `32`. - lora_dropout: Default is `0.05`. -- 🔥init_lora_weights: Method to initialize LoRA weights, can be specified as `true`, `false`, `gaussian`, `pissa`, `pissa_niter_[number of iters]`, default is `true`. - lora_bias: Default is `'none'`, selectable values are: 'none', 'all'. If you want to set all biases as trainable, you can set it to `'all'`. - lora_dtype: Specify the dtype of the LoRA module. Supports 'float16', 'bfloat16', 'float32', defaults to the original model type. - 🔥use_dora: Default is `False`, whether to use `DoRA`. 
diff --git a/examples/train/tuners/bone/train.sh b/examples/train/tuners/bone/train.sh
new file mode 100644
index 0000000000..88c220facc
--- /dev/null
+++ b/examples/train/tuners/bone/train.sh
@@ -0,0 +1,17 @@
+# 17.3GiB
+CUDA_VISIBLE_DEVICES=0 \
+swift sft \
+    --model Qwen/Qwen2.5-7B-Instruct \
+    --train_type bone \
+    --label_names labels \
+    --dataset swift/self-cognition#1000 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --learning_rate 1e-4 \
+    --gradient_accumulation_steps 16 \
+    --eval_steps 100 \
+    --save_steps 100 \
+    --save_total_limit 2 \
+    --logging_steps 5 \
+    --model_author swift \
+    --model_name swift-robot
diff --git a/requirements/framework.txt b/requirements/framework.txt
index 91d3d96c31..800ca9449d 100644
--- a/requirements/framework.txt
+++ b/requirements/framework.txt
@@ -14,7 +14,7 @@ nltk
 numpy<2.0
 oss2
 pandas
-peft>=0.11.0,<0.14.0
+peft>=0.11.0,<0.15.0
 pillow
 requests
 rouge
diff --git a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py
index 6f4f2cd499..04f44ca092 100644
--- a/swift/llm/argument/base_args/base_args.py
+++ b/swift/llm/argument/base_args/base_args.py
@@ -22,8 +22,8 @@
 
 
 def get_supported_tuners():
-    return {'lora', 'full', 'longlora', 'adalora', 'llamapro', 'adapter', 'vera', 'boft', 'fourierft', 'reft'} | set(
-        extra_tuners.keys())
+    return {'lora', 'full', 'longlora', 'adalora', 'llamapro', 'adapter', 'vera', 'boft', 'fourierft', 'reft', 'bone'
+            } | set(extra_tuners.keys())
 
 
 @dataclass
diff --git a/swift/llm/argument/tuner_args.py b/swift/llm/argument/tuner_args.py
index 8e4aacb355..b89512092e 100644
--- a/swift/llm/argument/tuner_args.py
+++ b/swift/llm/argument/tuner_args.py
@@ -28,8 +28,7 @@ class TunerArguments:
         lorap_lr_ratio (float): Learning rate ratio for LoRA. Default is None.
         use_rslora (bool): Flag to indicate if RSLora is used. Default is False.
         use_dora (bool): Flag to indicate if Dora is used. Default is False.
-        init_lora_weights (str): Initialization method for LoRA weights. Default is 'true'.
-            Allowed values are 'gaussian', 'pissa', 'pissa_niter_[number of iters]', 'olora', 'loftq', 'true', 'false'.
+        init_weights (str): Initialization method for the weights of the supported tuners. Default is 'true'.
         fourier_n_frequency (int): Number of frequencies for FourierFT. Default is 2000.
         fourier_scaling (float): Scaling factor for FourierFT. Default is 300.0.
@@ -110,8 +109,10 @@ class TunerArguments:
     lorap_lr_ratio: Optional[float] = None
     use_rslora: bool = False
     use_dora: bool = False
-    # Literal['gaussian', 'pissa', 'pissa_niter_[number of iters]', 'olora', 'loftq', 'true', 'false']
-    init_lora_weights: str = 'true'
+    # Lora: Literal['gaussian', 'pissa', 'pissa_niter_[number of iters]', 'olora', 'loftq', 'true', 'false']
+
+    # Bone: Literal['bat', 'true', 'false']
+    init_weights: str = 'true'
 
     # fourierft
     fourier_n_frequency: int = 2000
@@ -181,8 +182,8 @@ class TunerArguments:
     use_liger: bool = False
 
     def __post_init__(self):
-        if isinstance(self.init_lora_weights, str) and self.init_lora_weights.lower() in {'true', 'false'}:
-            self.init_lora_weights = bool(strtobool(self.init_lora_weights))
+        if isinstance(self.init_weights, str) and self.init_weights.lower() in {'true', 'false'}:
+            self.init_weights = bool(strtobool(self.init_weights))
         self._init_multimodal_full()
         if self.target_regex:
             self.target_modules = self.target_regex
diff --git a/swift/llm/infer/utils.py b/swift/llm/infer/utils.py
index 8f654ac088..1975b02374 100644
--- a/swift/llm/infer/utils.py
+++ b/swift/llm/infer/utils.py
@@ -147,6 +147,9 @@ def _prepare_pt_engine(args: InferArguments, pt_engine):
         pt_engine.processor = processor
     else:
         pt_engine.model = Swift.from_pretrained(pt_engine.model, args.ckpt_dir, inference_mode=True)
+        if args.train_type == 'bone':
+            # Bone in `peft==0.14.0` mixes float32 matmul with bfloat16 weights; cast back to the model dtype
+            pt_engine.model.to(pt_engine.model.dtype)
 
 
 def prepare_pt_engine_template(args: InferArguments, load_model: bool = True, **kwargs) -> Tuple[PtEngine, Template]:
diff --git a/swift/llm/train/tuner.py b/swift/llm/train/tuner.py
index 2737191358..a2fe60425a 100644
--- a/swift/llm/train/tuner.py
+++ b/swift/llm/train/tuner.py
@@ -18,17 +18,28 @@
 def apply_liger(model_type: str):
     from liger_kernel.transformers import (apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral,
                                            apply_liger_kernel_to_mixtral, apply_liger_kernel_to_gemma,
-                                           apply_liger_kernel_to_qwen2)
-    if 'llama3' in model_type:
+                                           apply_liger_kernel_to_qwen2, apply_liger_kernel_to_qwen2_vl,
+                                           apply_liger_kernel_to_gemma2, apply_liger_kernel_to_phi3,
+                                           apply_liger_kernel_to_mllama)
+    from swift.llm import ModelType
+    if model_type in (ModelType.llama, ModelType.llama3, ModelType.llama3_1, ModelType.llama3_2):
         apply_liger_kernel_to_llama()
-    elif 'mistral' in model_type:
+    elif model_type in (ModelType.mistral, ):
         apply_liger_kernel_to_mistral()
-    elif 'mixtral' in model_type:
+    elif model_type in (ModelType.mixtral, ):
         apply_liger_kernel_to_mixtral()
-    elif 'gemma' in model_type:
+    elif model_type in (ModelType.gemma, ):
         apply_liger_kernel_to_gemma()
-    elif 'qwen2' in model_type:
+    elif model_type in (ModelType.gemma2, ):
+        apply_liger_kernel_to_gemma2()
+    elif model_type in (ModelType.qwen2, ):
         apply_liger_kernel_to_qwen2()
+    elif model_type in (ModelType.phi3, ):
+        apply_liger_kernel_to_phi3()
+    elif model_type in (ModelType.llama3_2_vision, ):
+        apply_liger_kernel_to_mllama()
+    elif model_type in (ModelType.qwen2_vl, ):
+        apply_liger_kernel_to_qwen2_vl()
     else:
         raise ValueError(f'Unsupported liger model_type: {model_type}')
 
@@ -111,7 +122,7 @@ def prepare_adapter(args: TrainArguments, model):
         'use_rslora': args.use_rslora,
         'use_dora': args.use_dora,
         'lorap_lr_ratio': args.lorap_lr_ratio,
-        'init_lora_weights': args.init_lora_weights,
+        'init_lora_weights': args.init_weights,
     }
 
     if args.train_type in ('lora', 'longlora'):
@@ -224,6 +235,16 @@ def prepare_adapter(args: TrainArguments, model):
         )
         logger.info(f'reft config: {reft_config}')
         model = Swift.prepare_model(model, {'reft': reft_config})
+    elif args.train_type == 'bone':
+        # Import lazily so that older peft versions without Bone still import cleanly
+        from peft import BoneConfig
+        bone_config = BoneConfig(
+            target_modules=target_modules,
+            r=args.reft_rank,
+            init_weights=args.init_weights,
+        )
+        logger.info(f'bone config: {bone_config}')
+        model = Swift.prepare_model(model, bone_config)
     return model
 
 
diff --git a/swift/tuners/lora_layers.py b/swift/tuners/lora_layers.py
index 46647f12d6..6eff1ce934 100644
--- a/swift/tuners/lora_layers.py
+++ b/swift/tuners/lora_layers.py
@@ -134,7 +134,6 @@ def dispatch_bnb_8bit(target: torch.nn.Module, adapter_name: str, module_key: st
         eightbit_kwargs = kwargs.copy()
         eightbit_kwargs.update({
             'has_fp16_weights': target.state.has_fp16_weights,
-            'memory_efficient_backward': target.state.memory_efficient_backward,
             'threshold': target.state.threshold,
             'index': target.index,
         })
@@ -590,7 +589,11 @@ def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
         else:
             raise NotImplementedError(f'Requested bias: {bias}, is not implemented.')
 
-    def inject_adapter(self, model: nn.Module, adapter_name: str):
+    def inject_adapter(self,
+                       model: nn.Module,
+                       adapter_name: str,
+                       autocast_adapter_dtype: bool = True,
+                       low_cpu_mem_usage: bool = False):
         r"""
         Override code:
         1. ModulesToSaveWrapper construction method: add module_key=key argument to offload to cpu
@@ -789,13 +792,15 @@ def _replace_module(self, parent, child_name, new_module, child):
             new_module.state = child.state
             new_module.to(child.weight.device)
 
+        meta = torch.device('meta')
         # dispatch to correct device
         for name, module in new_module.named_modules():
             if (self.prefix in name) or ('ranknum' in name):
                 weight = (
                     child.qweight if hasattr(child, 'qweight') else child.W_q if hasattr(child, 'W_q') else
                     child.weight if hasattr(child, 'weight') else next(child.parameters()))
-                module.to(weight.device)
+                if not any(p.device == meta for p in module.parameters()):
+                    module.to(weight.device)
 
     @staticmethod
     def _create_new_module(lora_config, adapter_name, target, **kwargs):
diff --git a/swift/tuners/peft.py b/swift/tuners/peft.py
index 37eb3ca8cd..3461cc24d3 100644
--- a/swift/tuners/peft.py
+++ b/swift/tuners/peft.py
@@ -12,8 +12,8 @@
 import torch.nn
 import transformers
 from modelscope import snapshot_download
-from peft import (AdaLoraConfig, BOFTConfig, BOFTModel, IA3Config, IA3Model, LoftQConfig, LoHaConfig, LoKrConfig,
-                  LoraModel, OFTConfig, PeftConfig, PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM,
+from peft import (AdaLoraConfig, BOFTConfig, BOFTModel, LoftQConfig, LoHaConfig, LoKrConfig, LoraModel, OFTConfig,
+                  PeftConfig, PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM,
                   PeftModelForSequenceClassification, PeftModelForTokenClassification, PrefixTuningConfig,
                   PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, VeraConfig, VeraModel,
                   get_peft_config, get_peft_model, get_peft_model_state_dict)
@@ -28,6 +28,11 @@
 except ImportError:
     FourierFTModel = None
 
+try:
+    from peft import BoneModel
+except ImportError:
+    BoneModel = None
+
 logger = get_logger()
 
 dispatchers = []
@@ -280,11 +285,12 @@ def hot_patch_peft_module():
     VeraModel._create_and_replace = _create_and_replace_hook
     BOFTModel._create_and_replace_origin = BOFTModel._create_and_replace
     BOFTModel._create_and_replace = _create_and_replace_hook
-    IA3Model._create_and_replace_origin = IA3Model._create_and_replace
-    IA3Model._create_and_replace = _create_and_replace_hook
     if FourierFTModel is not None:
         FourierFTModel._create_and_replace_origin = FourierFTModel._create_and_replace
         FourierFTModel._create_and_replace = _create_and_replace_hook
+    if BoneModel is not None:
+        BoneModel._create_and_replace_origin = BoneModel._create_and_replace
+        BoneModel._create_and_replace = _create_and_replace_hook
 
     # Support type conversion
     def __new_init__(self, model: torch.nn.Module, config: Dict[str, LoraConfig], adapter_name: str):
@@ -367,7 +373,6 @@ def wrap_module(module):
 PromptLearningConfig = wrap_module(PromptLearningConfig)
 LoraConfig = wrap_module(LoraConfig)
 AdaLoraConfig = wrap_module(AdaLoraConfig)
-IA3Config = wrap_module(IA3Config)
 LoHaConfig = wrap_module(LoHaConfig)
 LoKrConfig = wrap_module(LoKrConfig)
 LoftQConfig = wrap_module(LoftQConfig)
diff --git a/swift/ui/llm_train/lora.py b/swift/ui/llm_train/lora.py
index 6d3f87fb14..eeba2c340b 100644
--- a/swift/ui/llm_train/lora.py
+++ b/swift/ui/llm_train/lora.py
@@ -63,7 +63,7 @@ class LoRA(BaseUI):
                 'en': 'The dtype of lora parameters'
             }
         },
-        'init_lora_weights': {
+        'init_weights': {
             'label': {
                 'zh': 'lora初始化方法',
                 'en': 'init lora weights'
@@ -99,4 +99,4 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
                 gr.Textbox(elem_id='lorap_lr_ratio', scale=2)
                 gr.Checkbox(elem_id='use_rslora', scale=2)
                 gr.Checkbox(elem_id='use_dora', scale=2)
-                gr.Textbox(elem_id='init_lora_weights', scale=4)
+                gr.Textbox(elem_id='init_weights', scale=4)
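A hedged end-to-end check of the new Bone path could look like the sketch below; it is not part of the patch. The checkpoint directory is a placeholder, and `--ckpt_dir` is assumed to accept the adapter directory saved by training (it is the field the inference hook above reads via `args.ckpt_dir`).

# Sketch, not from the patch: train a Bone adapter with the example script,
# then reload it for inference, which exercises the dtype workaround added in
# swift/llm/infer/utils.py. The checkpoint path below is a placeholder.
bash examples/train/tuners/bone/train.sh
CUDA_VISIBLE_DEVICES=0 \
swift infer \
    --ckpt_dir output/checkpoint-100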