swift/tuners/peft.py: 54 changes (16 additions, 38 deletions)
@@ -77,40 +77,22 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, subfolder: Optional
     return self


-def _get_target(*args, **kwargs):
-    target = None
-    if 'target' in kwargs:
-        target = kwargs['target']
-    else:
-        for arg in args:
-            if isinstance(arg, torch.nn.Module):
-                target = arg
-                break
-    return target
-
-
-def _create_and_replace_hook(self, *args, **kwargs):
-    target = _get_target(*args, **kwargs)
-    if target and target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
-        return
-
-    return self._create_and_replace_origin(*args, **kwargs)
-
-
-def _create_and_replace_hook2(self, *args, **kwargs):
-    target = _get_target(*args, **kwargs)
+def _create_and_replace_hook(self, peft_config, adapter_name, target, *args, **kwargs):
     all_supported_names = ('linear', )
     all_supported_types = (torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D)
+    target_modules = getattr(peft_config, 'target_modules', None)
+    if target is None:
+        return

-    is_multimodal = getattr(self.model, 'is_multimodal', False)
-
-    if is_multimodal and target is not None and (not any(
+    if isinstance(target_modules, str) and not any(
             [name in target.__class__.__name__.lower()
-             for name in all_supported_names]) and not any([isinstance(target, type) for type in all_supported_types])):
+             for name in all_supported_names]) and not any([isinstance(target, type_) for type_ in all_supported_types]):
         return

-    return _create_and_replace_hook(self, *args, **kwargs)
+    if target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
+        return
+
+    return self._create_and_replace_origin(peft_config, adapter_name, target, *args, **kwargs)


 def _convert_dtype(target: torch.nn.Module, adapter_name: str, lora_dtype: str):
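For reference, the merged hook above reduces to one filter applied before delegating to `_create_and_replace_origin`. A minimal standalone sketch of that rule (the `should_wrap` helper is illustrative and not part of the repository; `transformers.pytorch_utils.Conv1D` is omitted so only `torch` is needed):

```python
import torch.nn as nn

ALL_SUPPORTED_NAMES = ('linear', )
ALL_SUPPORTED_TYPES = (nn.Embedding, nn.Conv2d)  # Conv1D from transformers omitted in this sketch


def should_wrap(target: nn.Module, target_modules) -> bool:
    """True if the hook would forward `target` to _create_and_replace_origin."""
    if target is None:
        return False
    # NonDynamicallyQuantizableLinear (e.g. nn.MultiheadAttention.out_proj)
    # is never wrapped; LoRA does not support it.
    if target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
        return False
    # A regex target_modules may match layers no tuner supports, so only
    # linear-like class names or explicitly supported types pass the filter.
    if isinstance(target_modules, str):
        name_ok = any(name in target.__class__.__name__.lower() for name in ALL_SUPPORTED_NAMES)
        type_ok = isinstance(target, ALL_SUPPORTED_TYPES)
        return name_ok or type_ok
    # An explicit list of module names is passed through unchanged.
    return True


print(should_wrap(nn.Linear(4, 4), '.*proj'))    # True: class name contains 'linear'
print(should_wrap(nn.LayerNorm(4), '.*'))        # False: matched by regex but unsupported
print(should_wrap(nn.Embedding(10, 4), '.*'))    # True: explicitly supported type
print(should_wrap(nn.LayerNorm(4), ['q_proj']))  # True: explicit lists are not filtered
```

The string check matters because a regex `target_modules` can match layers that no LoRA-style tuner can wrap, whereas an explicit module list is trusted as given.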
@@ -296,28 +278,24 @@ def keep_device_forward(self, *args, **kwargs):

 def hot_patch_peft_module():
     from peft.tuners.lora import LoraLayer
     if hasattr('LoraModel', '_create_and_replace_origin'):
         return

     # Fix Lora does not support NonDynamicallyQuantizableLinear
     LoraModel._create_and_replace_origin = LoraModel._create_and_replace
     LoraModel._create_and_replace = _create_and_replace_hook
     VeraModel._create_and_replace_origin = VeraModel._create_and_replace
-    VeraModel._create_and_replace = _create_and_replace_hook2
+    VeraModel._create_and_replace = _create_and_replace_hook
     BOFTModel._create_and_replace_origin = BOFTModel._create_and_replace
-    BOFTModel._create_and_replace = _create_and_replace_hook2
+    BOFTModel._create_and_replace = _create_and_replace_hook
     IA3Model._create_and_replace_origin = IA3Model._create_and_replace
-    IA3Model._create_and_replace = _create_and_replace_hook2
+    IA3Model._create_and_replace = _create_and_replace_hook
     if FourierFTModel is not None:
         FourierFTModel._create_and_replace_origin = FourierFTModel._create_and_replace
-        FourierFTModel._create_and_replace = _create_and_replace_hook2
+        FourierFTModel._create_and_replace = _create_and_replace_hook

     # Support type conversion
     def init(self, model: torch.nn.Module, config: Dict[str, LoraConfig], adapter_name):
-        if isinstance(config, dict):
-            for _config in config.values():  # There is a target_modules as a string.
-                if isinstance(getattr(_config, 'target_modules', None), str):
-                    # Make sure the regex can find all linear in the module.
-                    LoraModel._create_and_replace = _create_and_replace_hook2
-                    break

         self.__init_origin__(model, config, adapter_name)
         if isinstance(self.active_adapter, list):
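For reference, the patch pattern in `hot_patch_peft_module` is plain method swapping on the tuner classes: the original `_create_and_replace` is kept under `_create_and_replace_origin` and a shared guard is installed in its place, so LoraModel, VeraModel, BOFTModel, IA3Model and FourierFTModel all run through the same filter. A minimal sketch with an illustrative `DemoTuner` class (not a PEFT class):

```python
import torch.nn as nn


class DemoTuner:
    def _create_and_replace(self, peft_config, adapter_name, target, *args, **kwargs):
        print(f'wrapping {target.__class__.__name__} for adapter {adapter_name!r}')


def _guarded_create_and_replace(self, peft_config, adapter_name, target, *args, **kwargs):
    # Skip layers the tuners cannot handle, then fall through to the original method.
    if target is None or target.__class__.__name__ == 'NonDynamicallyQuantizableLinear':
        return
    return self._create_and_replace_origin(peft_config, adapter_name, target, *args, **kwargs)


# Patch once, the same way hot_patch_peft_module patches the PEFT tuner classes.
DemoTuner._create_and_replace_origin = DemoTuner._create_and_replace
DemoTuner._create_and_replace = _guarded_create_and_replace

DemoTuner()._create_and_replace(None, 'default', nn.Linear(2, 2))                       # wrapped
DemoTuner()._create_and_replace(None, 'default', nn.MultiheadAttention(4, 2).out_proj)  # skipped
```

Because the shared hook now receives `peft_config` directly, it can inspect `target_modules` itself, which is why a second hook and the per-config swap inside `init` are no longer needed.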