diff --git a/swift/tuners/peft.py b/swift/tuners/peft.py
index 380f42ed58..12a8aa810a 100644
--- a/swift/tuners/peft.py
+++ b/swift/tuners/peft.py
@@ -77,40 +77,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional
         return self
 
 
-def _get_target(*args, **kwargs):
-    target = None
-    if 'target' in kwargs:
-        target = kwargs['target']
-    else:
-        for arg in args:
-            if isinstance(arg, torch.nn.Module):
-                target = arg
-                break
-    return target
-
-
-def _create_and_replace_hook(self, *args, **kwargs):
-    target = _get_target(*args, **kwargs)
-    if target and target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
-        return
-
-    return self._create_and_replace_origin(*args, **kwargs)
-
-
-def _create_and_replace_hook2(self, *args, **kwargs):
-    target = _get_target(*args, **kwargs)
-
+def _create_and_replace_hook(self, peft_config, adapter_name, target, *args, **kwargs):
     all_supported_names = ('linear', )
     all_supported_types = (torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D)
+    target_modules = getattr(peft_config, 'target_modules', None)
+    if target is None:
+        return
 
-    is_multimodal = getattr(self.model, 'is_multimodal', False)
-
-    if is_multimodal and target is not None and (not any(
+    if isinstance(target_modules, str) and not any(
             [name in target.__class__.__name__.lower()
-             for name in all_supported_names]) and not any([isinstance(target, type) for type in all_supported_types])):
+             for name in all_supported_names]) and not any([isinstance(target, type_) for type_ in all_supported_types]):
         return
 
-    return _create_and_replace_hook(self, *args, **kwargs)
+    if target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
+        return
+
+    return self._create_and_replace_origin(peft_config, adapter_name, target, *args, **kwargs)
 
 
 def _convert_dtype(target: torch.nn.Module, adapter_name: str, lora_dtype: str):
@@ -296,28 +278,24 @@ def keep_device_forward(self, *args, **kwargs):
 
 def hot_patch_peft_module():
     from peft.tuners.lora import LoraLayer
+    if hasattr(LoraModel, '_create_and_replace_origin'):
+        return
 
     # Fix Lora does not support NonDynamicallyQuantizableLinear
     LoraModel._create_and_replace_origin = LoraModel._create_and_replace
     LoraModel._create_and_replace = _create_and_replace_hook
     VeraModel._create_and_replace_origin = VeraModel._create_and_replace
-    VeraModel._create_and_replace = _create_and_replace_hook2
+    VeraModel._create_and_replace = _create_and_replace_hook
     BOFTModel._create_and_replace_origin = BOFTModel._create_and_replace
-    BOFTModel._create_and_replace = _create_and_replace_hook2
+    BOFTModel._create_and_replace = _create_and_replace_hook
     IA3Model._create_and_replace_origin = IA3Model._create_and_replace
-    IA3Model._create_and_replace = _create_and_replace_hook2
+    IA3Model._create_and_replace = _create_and_replace_hook
     if FourierFTModel is not None:
         FourierFTModel._create_and_replace_origin = FourierFTModel._create_and_replace
-        FourierFTModel._create_and_replace = _create_and_replace_hook2
+        FourierFTModel._create_and_replace = _create_and_replace_hook
 
     # Support type conversion
     def init(self, model: torch.nn.Module, config: Dict[str, LoraConfig], adapter_name):
-        if isinstance(config, dict):
-            for _config in config.values():  # There is a target_modules as a string.
-                if isinstance(getattr(_config, 'target_modules', None), str):
-                    # Make sure the regex can find all linear in the module.
-                    LoraModel._create_and_replace = _create_and_replace_hook2
-                    break
         self.__init_origin__(model, config, adapter_name)
 
         if isinstance(self.active_adapter, list):
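For reference, here is a minimal sketch of the monkey-patch idiom the hunks above rely on: the original `_create_and_replace` is saved as `_create_and_replace_origin`, a guard hook is installed in its place, and a `hasattr` check keeps the patch idempotent. `DummyModel`, `_hook`, and `hot_patch_dummy_module` are illustrative stand-ins, not part of the actual patch or of peft.

```python
import torch


class DummyModel:
    """Illustrative stand-in for LoraModel/VeraModel/BOFTModel/IA3Model."""

    def _create_and_replace(self, peft_config, adapter_name, target, *args, **kwargs):
        print(f'replacing {target.__class__.__name__} for adapter {adapter_name!r}')


def _hook(self, peft_config, adapter_name, target, *args, **kwargs):
    # Skip targets the tuner cannot handle, otherwise delegate to the saved original.
    if target is None or target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
        return
    return self._create_and_replace_origin(peft_config, adapter_name, target, *args, **kwargs)


def hot_patch_dummy_module():
    # Idempotency guard, mirroring the new `hasattr(LoraModel, ...)` check above;
    # without it, a second call would save the hook itself as the "origin".
    if hasattr(DummyModel, '_create_and_replace_origin'):
        return
    DummyModel._create_and_replace_origin = DummyModel._create_and_replace
    DummyModel._create_and_replace = _hook


hot_patch_dummy_module()
hot_patch_dummy_module()  # second call is a no-op thanks to the guard
DummyModel()._create_and_replace(None, 'default', torch.nn.Linear(4, 4))
```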