16 changes: 11 additions & 5 deletions swift/tuners/adapter.py
@@ -131,6 +131,7 @@ def _feed_forward_chunk(self, attention_output):
setattr(module, config.method_name,
types.MethodType(_forward, module))
adapter_module = AdapterModule(config.dim, adapter_name,
module_key,
config.adapter_length,
ACT2CLS[config.act_layer])
setattr(module, f'adapter_{adapter_name}', adapter_module)
@@ -152,13 +153,17 @@ def mark_trainable_callback(model):
mark_trainable_callback)

@staticmethod
def activate_adapter(module: torch.nn.Module, adapter_name: str,
activate: bool):
modules: List[torch.nn.Module] = find_sub_module(
module, f'adapter_{adapter_name}')
def activate_adapter(module: torch.nn.Module,
adapter_name: str,
activate: bool,
offload: str = None):
modules = find_sub_module(module, f'adapter_{adapter_name}')
for _module in modules:
_module: ActivationMixin
_module: nn.Module
_module.set_activation(adapter_name, activate)
SwiftAdapter.save_memory(_module, adapter_name, _module.module_key,
activate, offload)


class AdapterModule(nn.Module, ActivationMixin):
@@ -177,11 +182,12 @@ def __init__(
self,
dim,
adapter_name,
module_key,
adapter_length=None,
act_layer=nn.GELU,
):
super(AdapterModule, self).__init__()
super(nn.Module, self).__init__()
super(nn.Module, self).__init__(module_key)
self.dim = dim
self.adapter_name = adapter_name
self.adapter_length = adapter_length
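Note: the adapter.py changes above thread a module_key into AdapterModule and route (de)activation through SwiftAdapter.save_memory with an optional offload target. save_memory itself is not part of this diff; the following is only a minimal sketch of its presumed semantics (restore weights when an adapter is activated, park them elsewhere when it is deactivated), and the accepted offload values ('cpu', 'meta') are assumptions.

import torch
import torch.nn as nn

def save_memory_sketch(module: nn.Module,
                       activate: bool,
                       offload: str = None) -> None:
    # Presumed behavior: bring adapter weights back for an active adapter,
    # otherwise move them to the requested offload target.
    if activate:
        module.to('cuda' if torch.cuda.is_available() else 'cpu')
    elif offload == 'cpu':
        module.to('cpu')    # keep the weights in host memory
    elif offload == 'meta':
        module.to('meta')   # drop the weights; they must be reloaded before reuse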
10 changes: 6 additions & 4 deletions swift/tuners/base.py
@@ -432,7 +432,9 @@ def save_pretrained(self,
def base_model(self):
return self.model

def set_active_adapters(self, adapter_names: Union[List[str], str]):
def set_active_adapters(self,
adapter_names: Union[List[str], str],
offload=None):
if not adapter_names:
return

@@ -444,7 +446,7 @@ def set_active_adapters(self, adapter_names: Union[List[str], str]):
self.activate_adapter(adapter_name)

for adapter_name in (set(self.adapters.keys()) - adapter_names):
self.deactivate_adapter(adapter_name)
self.deactivate_adapter(adapter_name, offload)

def activate_adapter(self, adapter_name):
if adapter_name not in self.adapters:
@@ -456,15 +458,15 @@ def activate_adapter(self, adapter_name):
SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\
.activate_adapter(self.base_model, adapter_name, True)

def deactivate_adapter(self, adapter_name):
def deactivate_adapter(self, adapter_name, offload=None):
if adapter_name not in self.adapters:
logger.warning(
f'{adapter_name} not in adapters: {self.adapters.keys()}')
return

from .mapping import SWIFT_MAPPING
SWIFT_MAPPING[self.adapters[adapter_name].config.swift_type][1]\
.activate_adapter(self.base_model, adapter_name, False)
.activate_adapter(self.base_model, adapter_name, False, offload=offload)

def get_trainable_parameters(self):
"""
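Note: with the base.py changes above, set_active_adapters and deactivate_adapter take an offload argument and forward it to the tuner-specific activate_adapter(..., False, offload=offload). A minimal usage sketch, assuming the wrapped model is a SwiftModel and that 'cpu' is a valid offload target:

from typing import List

def switch_active_adapters(model, keep: List[str], offload: str = 'cpu') -> None:
    # Adapters listed in `keep` stay active; every other registered adapter
    # is deactivated and its weights go through the save_memory path with the
    # given offload target.
    model.set_active_adapters(keep, offload=offload)

def release_adapter(model, name: str, offload: str = 'cpu') -> None:
    # Explicitly deactivate a single adapter and offload its weights.
    model.deactivate_adapter(name, offload)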
10 changes: 7 additions & 3 deletions swift/tuners/lora.py
@@ -5,6 +5,7 @@

import torch
from packaging import version
from peft.tuners.lora import LoraLayer

from swift import LoraConfig
from .lora_layers import * # noqa
@@ -69,12 +70,15 @@ def mark_trainable_callback(model):
mark_trainable_callback)

@staticmethod
def activate_adapter(module: torch.nn.Module, adapter_name: str,
activate: bool):
set_adapter(module, adapter_name, activate)
def activate_adapter(module: torch.nn.Module,
adapter_name: str,
activate: bool,
offload: str = None):
set_adapter(module, adapter_name, activate, offload)
for sub_module in module.modules():
if isinstance(sub_module, (LoraLayer, LoRALayer)):
sub_module.set_activation(adapter_name, activate)
sub_module.save_memory(adapter_name, activate, offload)

@staticmethod
def unpatch_lora(model, config: LoRAConfig, adapter_name: str):
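Note: for LoRA, activate_adapter now passes offload both to set_adapter and to each matching sub-module's save_memory, and the LoraLayer import moves here from lora_layers.py so peft's own layers are covered by the isinstance check. A usage sketch, assuming the tuner class defined in this file is named LoRA and that 'cpu' is a valid offload target:

def toggle_lora(model, enable: str, disable: str, offload: str = 'cpu') -> None:
    # Activate one LoRA adapter and deactivate another, offloading the
    # deactivated adapter's weights. LoRA here refers to the SwiftAdapter
    # subclass assumed to live in swift/tuners/lora.py.
    from swift.tuners.lora import LoRA
    LoRA.activate_adapter(model, enable, True)
    LoRA.activate_adapter(model, disable, False, offload=offload)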
96 changes: 72 additions & 24 deletions swift/tuners/lora_layers.py
@@ -16,7 +16,6 @@
from peft.tuners.lora import Conv2d as _Conv2d
from peft.tuners.lora import Embedding as _Embedding
from peft.tuners.lora import Linear as _Linear
from peft.tuners.lora import LoraLayer
from peft.tuners.lora import LoraModel as _LoraModel
from peft.tuners.lora.tp_layer import LoraParallelLinear as _LoraParallelLinear
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -25,7 +24,7 @@
from transformers import Conv1D

from swift import get_logger
from .utils import ActivationMixin, ModulesToSaveWrapper
from .utils import ActivationMixin, ModulesToSaveWrapper, SwiftAdapter

logger = get_logger()

@@ -52,7 +51,7 @@ def active_adapters(self):
def active_adapter(self) -> str:
return self.get_activated_adapters()

def set_adapter(self, adapter_names):
def set_adapter(self, adapter_names, offload=None):
if isinstance(adapter_names, str):
adapter_names = [adapter_names]

@@ -63,9 +62,28 @@ def set_adapter(self, adapter_names):
if key in adapter_names:
self.set_activation(key, True)
layer.requires_grad_(True)
SwiftAdapter.save_memory(layer, key, self.module_key, True)
else:
self.set_activation(key, False)
layer.requires_grad_(False)
SwiftAdapter.save_memory(
layer, key, self.module_key, False, offload=offload)

def save_memory(self, adapter_name, activate, offload=None):
for layer_name in self.adapter_layer_names:
module_dict = getattr(self, layer_name)
for key, layer in module_dict.items():
if key == adapter_name:
if activate:
SwiftAdapter.save_memory(layer, layer_name + '.' + key,
self.module_key, True)
else:
SwiftAdapter.save_memory(
layer,
layer_name + '.' + key,
self.module_key,
False,
offload=offload)

def merge(self, *args, **kwargs):
if not self.unique_thread:
@@ -85,9 +103,10 @@ class Linear8bitLt(LoRAActivationMixin, _Linear8bitLt):
def __init__(
self,
*args,
module_key: str,
**kwargs,
):
super(Linear8bitLt, self).__init__()
super(Linear8bitLt, self).__init__(module_key)
self.set_activation(args[1], True)
super(ActivationMixin, self).__init__(*args, **kwargs)

@@ -100,9 +119,10 @@ class Linear4bit(LoRAActivationMixin, _Linear4bit):
def __init__(
self,
*args,
module_key: str,
**kwargs,
):
super(Linear4bit, self).__init__()
super(Linear4bit, self).__init__(module_key)
self.set_activation(args[1], True)
super(ActivationMixin, self).__init__(*args, **kwargs)

@@ -117,9 +137,10 @@ def __init__(
*args,
use_qa_lora=False,
group_size=None,
module_key: str,
**kwargs,
):
super(QuantLinear, self).__init__()
super(QuantLinear, self).__init__(module_key)
self.set_activation(args[1], True)
super(ActivationMixin, self).__init__(*args, **kwargs)
self.group_size = group_size
@@ -166,33 +187,34 @@ class Embedding(LoRAActivationMixin, _Embedding):
def __init__(
self,
*args,
module_key: str,
**kwargs,
) -> None:
super(Embedding, self).__init__()
super(Embedding, self).__init__(module_key)
self.set_activation(args[1], True)
super(ActivationMixin, self).__init__(*args, **kwargs)


class Linear(LoRAActivationMixin, _Linear):

def __init__(self, *args, **kwargs):
super(Linear, self).__init__()
def __init__(self, *args, module_key: str, **kwargs):
super(Linear, self).__init__(module_key)
self.set_activation(args[1], True)
super(ActivationMixin, self).__init__(*args, **kwargs)


class Conv2d(LoRAActivationMixin, _Conv2d):

def __init__(self, *args, **kwargs):
super(Conv2d, self).__init__()
def __init__(self, *args, module_key: str, **kwargs):
super(Conv2d, self).__init__(module_key)
self.set_activation(args[1], True)
super(ActivationMixin, self).__init__(*args, **kwargs)


class LoraParallelLinear(LoRAActivationMixin, _LoraParallelLinear):

def __init__(self, *args, **kwargs):
super(LoraParallelLinear, self).__init__()
def __init__(self, *args, module_key: str, **kwargs):
super(LoraParallelLinear, self).__init__(module_key)
self.set_activation(args[1], True)
super(ActivationMixin, self).__init__(*args, **kwargs)

@@ -249,7 +271,8 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):
parent, target, target_name = _get_submodules(model, key)

if not isinstance(target, ModulesToSaveWrapper):
new_module = ModulesToSaveWrapper(target, adapter_name)
new_module = ModulesToSaveWrapper(
target, adapter_name, module_key=key)
setattr(parent, target_name, new_module)
else:
target.update(adapter_name)
@@ -384,8 +407,12 @@ def _create_and_replace(
)
self._convert_dtype(target, lora_config.lora_dtype)
else:
new_module = self._create_new_module(lora_config, adapter_name,
target, **kwargs)
new_module = self._create_new_module(
lora_config,
adapter_name,
target,
current_key=current_key,
**kwargs)
if new_module is not None:
if adapter_name != self.active_adapter:
# adding an additional adapter: it is not automatically trainable
@@ -395,6 +422,7 @@

@staticmethod
def _create_new_module(lora_config, adapter_name, target, **kwargs):
current_key = kwargs.pop('current_key')
gptq_quantization_config = kwargs.get('gptq_quantization_config', None)
AutoGPTQQuantLinear = get_auto_gptq_quant_linear(
gptq_quantization_config)
@@ -422,7 +450,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
'threshold': target.state.threshold,
'index': target.index,
})
new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs)
new_module = Linear8bitLt(
target,
adapter_name,
module_key=current_key,
**eightbit_kwargs)
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(
target_base_layer, bnb.nn.Linear4bit):
fourbit_kwargs = kwargs.copy()
@@ -434,19 +466,26 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
'quant_type':
target_base_layer.weight.quant_type,
})
new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
new_module = Linear4bit(
target, adapter_name, module_key=current_key, **fourbit_kwargs)
elif AutoGPTQQuantLinear is not None and isinstance(
target_base_layer, AutoGPTQQuantLinear):
new_module = QuantLinear(target, adapter_name, **kwargs)
new_module = QuantLinear(
target, adapter_name, module_key=current_key, **kwargs)
target.qweight = target_base_layer.qweight
elif isinstance(target_base_layer, torch.nn.Embedding):
embedding_kwargs = kwargs.copy()
embedding_kwargs.pop('fan_in_fan_out', None)
embedding_kwargs.update(lora_config.loftq_config)
new_module = Embedding(target, adapter_name, **embedding_kwargs)
new_module = Embedding(
target,
adapter_name,
module_key=current_key,
**embedding_kwargs)
elif isinstance(target_base_layer, torch.nn.Conv2d):
kwargs.update(lora_config.loftq_config)
new_module = Conv2d(target, adapter_name, **kwargs)
new_module = Conv2d(
target, adapter_name, module_key=current_key, **kwargs)
elif lora_config.use_merged_linear:
new_module = MergedLinear(
adapter_name,
@@ -461,7 +500,8 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
'Setting fan_in_fan_out to False.')
kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = False
kwargs.update(lora_config.loftq_config)
new_module = Linear(target, adapter_name, **kwargs)
new_module = Linear(
target, adapter_name, module_key=current_key, **kwargs)
elif megatron_core and isinstance(
target_base_layer, # noqa
( # noqa
@@ -486,6 +526,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
new_module = LoraParallelLinear(
base_layer=target,
adapter_name=adapter_name,
module_key=current_key,
backend=megatron_core.tensor_parallel,
**megatron_kwargs)
elif isinstance(target_base_layer, Conv1D):
@@ -496,7 +537,11 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):
kwargs['fan_in_fan_out'] = lora_config.fan_in_fan_out = True
kwargs.update(lora_config.loftq_config)
new_module = Linear(
target, adapter_name, is_target_conv_1d_layer=True, **kwargs)
target,
adapter_name,
module_key=current_key,
is_target_conv_1d_layer=True,
**kwargs)
else:
logger.debug(
f'Target module {target} is not supported. Currently, only the following modules are supported: '
@@ -512,12 +557,13 @@ class LoRALayer(ActivationMixin):
def __init__(
self,
adapter_name: str,
module_key: str,
r: int,
lora_alpha: int,
lora_dropout: float,
merge_weights: bool,
):
super().__init__()
super().__init__(module_key)
self.adapter_name = adapter_name
self.r = r
self.lora_alpha = lora_alpha
@@ -537,6 +583,7 @@ class MergedLinear(nn.Linear, LoRALayer):
# LoRA implemented in a dense layer
def __init__(self,
adapter_name: str,
module_key: str,
base_layer: nn.Linear,
r: int = 0,
lora_alpha: int = 1,
@@ -558,6 +605,7 @@ def __init__(self,
LoRALayer.__init__(
self,
adapter_name,
module_key,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
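Note: most of the lora_layers.py changes plumb a module_key (passed as current_key from _create_and_replace down to _create_new_module) into every wrapped LoRA layer, so that set_adapter and the new save_memory method can offload or restore an individual adapter's tensors keyed by their dotted module path. A rough sketch of that offload step at the layer level, assuming the standard peft lora_A/lora_B ModuleDicts and 'cpu' as the offload target:

import torch.nn as nn

def offload_lora_adapter(model: nn.Module, adapter_name: str, offload: str = 'cpu') -> None:
    # Walk the wrapped model and move the named adapter's LoRA parameters to
    # the offload target; the dotted names yielded here correspond to the
    # module_key threaded through the constructors above. Only lora_A/lora_B
    # are handled, purely for illustration.
    for module_key, sub in model.named_modules():
        for attr in ('lora_A', 'lora_B'):
            layers = getattr(sub, attr, None)
            if isinstance(layers, nn.ModuleDict) and adapter_name in layers:
                layers[adapter_name].to(offload)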