diff --git a/swift/tuners/lora.py b/swift/tuners/lora.py
index b36adea5b3..fef4ceb04f 100644
--- a/swift/tuners/lora.py
+++ b/swift/tuners/lora.py
@@ -84,12 +84,32 @@ def __init__(
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
+        use_qa_lora: bool = False,
         **kwargs,
     ):
-        super(ActivationMixin,
-              self).__init__(adapter_name, quant_linear_module, r,
-                             lora_alpha, lora_dropout, **kwargs)
+        from peft.tuners.lora import LoraLayer
+        torch.nn.Module.__init__(self)
+        self.group_size = kwargs.get('group_size', None)
+        self.use_qa_lora = use_qa_lora
+        if self.use_qa_lora:
+            assert self.group_size is not None, 'To use qa_lora you need to pass in the `group_size` param.'
+        LoraLayer.__init__(
+            self,
+            in_features=quant_linear_module.infeatures
+            if not self.use_qa_lora else quant_linear_module.infeatures
+            // self.group_size,
+            out_features=quant_linear_module.outfeatures)
+        self.quant_linear_module = quant_linear_module
+        self.weight = quant_linear_module.qweight
+        init_lora_weights = kwargs.pop('init_lora_weights', True)
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout,
+                          init_lora_weights)
+        self.active_adapter = adapter_name
         super(QuantLinear, self).__init__()
+        if self.use_qa_lora:
+            self.qa_pool = torch.nn.AvgPool1d(
+                self.group_size
+            )  # using pooling layer to conduct sum operation
 
         def call_quant_linear_module(*args, **kwargs):
             return quant_linear_module.forward_origin(*args, **kwargs)
@@ -108,12 +128,16 @@ def forward(self, x: torch.Tensor):
             if not torch.is_autocast_enabled():
                 expected_dtype = result.dtype
                 x = x.to(self.lora_A[self.active_adapter].weight.dtype)
+                if self.use_qa_lora:
+                    x = self.qa_pool(x) * self.group_size
                 output = (
                     self.lora_B[self.active_adapter](
                         self.lora_A[self.active_adapter](self.lora_dropout[
                             self.active_adapter](x))).to(expected_dtype)
                     * self.scaling[self.active_adapter])
             else:
+                if self.use_qa_lora:
+                    x = self.qa_pool(x) * self.group_size
                 output = (
                     self.lora_B[self.active_adapter](
                         self.lora_A[self.active_adapter](
@@ -179,6 +203,13 @@ class LoRAConfig(SwiftConfig):
             'help': 'Bias type. Values can be "none", "all" or "lora_only"'
         })
 
+    use_qa_lora: bool = field(
+        default=False,
+        metadata={
+            'help':
+            'Use [qa-lora](https://github.com/yuhuixu1993/qa-lora) or not'
+        })
+
     def __post_init__(self):
         from .mapping import SwiftTuners
         self.swift_type = SwiftTuners.LORA
@@ -199,7 +230,8 @@ def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str):
             merge_weights=config.merge_weights,
             use_merged_linear=config.use_merged_linear,
             enable_lora=config.enable_lora,
-            fan_in_fan_out=config.fan_in_fan_out)
+            fan_in_fan_out=config.fan_in_fan_out,
+            use_qa_lora=config.use_qa_lora)
 
         def state_dict_callback(state_dict, adapter_name):
             return lora_state_dict(state_dict, adapter_name, config.bias)
@@ -237,8 +269,9 @@ def _dynamic_patch_lora(model: torch.nn.Module,
         modules = {}
         module_keys = [key for key, _ in model.named_modules()]
         assert isinstance(target_modules, (str, list))
-        AutoGPTQQuantLinear = get_auto_gptq_quant_linear(
-            get_quantization_config(model, method='gptq'))
+        auto_gptq_config = get_quantization_config(model, method='gptq')
+        AutoGPTQQuantLinear = get_auto_gptq_quant_linear(auto_gptq_config)
+        use_qa_lora = kwargs.pop('use_qa_lora', False)
 
         for module_key in module_keys:
             if isinstance(target_modules, str):
@@ -292,7 +325,13 @@ def _dynamic_patch_lora(model: torch.nn.Module,
                         **four_bit_kwargs)
                 elif AutoGPTQQuantLinear is not None and isinstance(
                         sub_module, AutoGPTQQuantLinear):
-                    lora_module = QuantLinear('default', sub_module, **kwargs)
+                    lora_module = QuantLinear(
+                        'default',
+                        sub_module,
+                        use_qa_lora=use_qa_lora,
+                        group_size=getattr(auto_gptq_config, 'group_size',
+                                           None),
+                        **kwargs)
                     sub_module.weight = sub_module.qweight
                 elif isinstance(sub_module, torch.nn.Linear):
                     if use_merged_linear:
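The pooling trick this patch adds can be verified in isolation: torch.nn.AvgPool1d(group_size) followed by multiplication with group_size computes a group-wise sum over the input features, which is why the adapter's in_features shrinks to infeatures // group_size when use_qa_lora is set. The snippet below is a standalone sketch, not part of this patch; the tensor shapes and group_size value are illustrative only.

# Standalone sketch: the group-wise reduction QA-LoRA applies before lora_A.
import torch

group_size, in_features = 4, 16
x = torch.randn(2, 3, in_features)  # (batch, seq_len, in_features)

qa_pool = torch.nn.AvgPool1d(group_size)  # same layer the patch stores as self.qa_pool
x_grouped = qa_pool(x) * group_size       # shape: (2, 3, in_features // group_size)

# Average pooling times the kernel size equals a sum over each group of
# `group_size` consecutive input features.
expected = x.view(2, 3, in_features // group_size, group_size).sum(-1)
assert x_grouped.shape[-1] == in_features // group_size
assert torch.allclose(x_grouped, expected, atol=1e-6)

With the config change above, enabling this path only requires use_qa_lora=True on LoRAConfig for a GPTQ-quantized model; group_size is not user-supplied but read from the model's GPTQ quantization config via get_quantization_config in _dynamic_patch_lora.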