53 changes: 46 additions & 7 deletions swift/tuners/lora.py
@@ -84,12 +84,32 @@ def __init__(
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
+        use_qa_lora: bool = False,
         **kwargs,
     ):
-        super(ActivationMixin,
-              self).__init__(adapter_name, quant_linear_module, r,
-                             lora_alpha, lora_dropout, **kwargs)
+        from peft.tuners.lora import LoraLayer
+        torch.nn.Module.__init__(self)
+        self.group_size = kwargs.get('group_size', None)
+        self.use_qa_lora = use_qa_lora
+        if self.use_qa_lora:
+            assert self.group_size is not None, 'To use qa_lora you need to pass in the `group_size` param.'
+        LoraLayer.__init__(
+            self,
+            in_features=quant_linear_module.infeatures
+            if not self.use_qa_lora else quant_linear_module.infeatures
+            // self.group_size,
+            out_features=quant_linear_module.outfeatures)
+        self.quant_linear_module = quant_linear_module
+        self.weight = quant_linear_module.qweight
+        init_lora_weights = kwargs.pop('init_lora_weights', True)
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout,
+                          init_lora_weights)
+        self.active_adapter = adapter_name
         super(QuantLinear, self).__init__()
+        if self.use_qa_lora:
+            self.qa_pool = torch.nn.AvgPool1d(
+                self.group_size
+            )  # average pooling, rescaled by group_size in forward, gives a group-wise sum

         def call_quant_linear_module(*args, **kwargs):
             return quant_linear_module.forward_origin(*args, **kwargs)
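The qa_pool addition above relies on a small identity: average pooling over the feature dimension, scaled back up by group_size, equals a group-wise sum. A standalone sketch with illustrative values, not part of the patch:

import torch

group_size = 4
x = torch.arange(8, dtype=torch.float32).reshape(1, 1, 8)  # (batch, seq, features)
qa_pool = torch.nn.AvgPool1d(group_size)
# mean over each group of 4 features, times 4, is the sum of that group
grouped = qa_pool(x) * group_size
print(grouped)  # tensor([[[ 6., 22.]]]) -> sums of [0, 1, 2, 3] and [4, 5, 6, 7]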
@@ -108,12 +128,16 @@ def forward(self, x: torch.Tensor):
         if not torch.is_autocast_enabled():
             expected_dtype = result.dtype
             x = x.to(self.lora_A[self.active_adapter].weight.dtype)
+            if self.use_qa_lora:
+                x = self.qa_pool(x) * self.group_size
             output = (
                 self.lora_B[self.active_adapter](
                     self.lora_A[self.active_adapter](self.lora_dropout[
                         self.active_adapter](x))).to(expected_dtype)
                 * self.scaling[self.active_adapter])
         else:
+            if self.use_qa_lora:
+                x = self.qa_pool(x) * self.group_size
             output = (
                 self.lora_B[self.active_adapter](
                     self.lora_A[self.active_adapter](
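Because of that pooling step, lora_A in the qa-lora case takes in_features // group_size inputs, matching the LoraLayer.__init__ change in the first hunk. A shape sketch with illustrative sizes, independent of the patched module:

import torch

batch, seq, in_features, out_features = 2, 5, 4096, 4096
group_size, r = 128, 8
x = torch.randn(batch, seq, in_features)
qa_pool = torch.nn.AvgPool1d(group_size)
lora_A = torch.nn.Linear(in_features // group_size, r, bias=False)
lora_B = torch.nn.Linear(r, out_features, bias=False)
x = qa_pool(x) * group_size    # (2, 5, 4096 // 128) == (2, 5, 32)
delta = lora_B(lora_A(x))      # (2, 5, 4096), added to the quantized layer's output
print(delta.shape)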
@@ -179,6 +203,13 @@ class LoRAConfig(SwiftConfig):
             'help': 'Bias type. Values can be "none", "all" or "lora_only"'
         })

+    use_qa_lora: bool = field(
+        default=False,
+        metadata={
+            'help':
+            'Use [qa-lora](https://github.com/yuhuixu1993/qa-lora) or not'
+        })
+
     def __post_init__(self):
         from .mapping import SwiftTuners
         self.swift_type = SwiftTuners.LORA
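With the new field, enabling qa-lora is a one-flag change on the config. A minimal sketch, assuming the usual LoRAConfig fields from this file; the rank, alpha, and target modules below are placeholder values:

from swift.tuners.lora import LoRAConfig

config = LoRAConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=['q_proj', 'k_proj', 'v_proj'],
    use_qa_lora=True)  # only takes effect for GPTQ-quantized linears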
@@ -199,7 +230,8 @@ def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str):
         merge_weights=config.merge_weights,
         use_merged_linear=config.use_merged_linear,
         enable_lora=config.enable_lora,
-        fan_in_fan_out=config.fan_in_fan_out)
+        fan_in_fan_out=config.fan_in_fan_out,
+        use_qa_lora=config.use_qa_lora)

     def state_dict_callback(state_dict, adapter_name):
         return lora_state_dict(state_dict, adapter_name, config.bias)
@@ -237,8 +269,9 @@ def _dynamic_patch_lora(model: torch.nn.Module,
     modules = {}
     module_keys = [key for key, _ in model.named_modules()]
     assert isinstance(target_modules, (str, list))
-    AutoGPTQQuantLinear = get_auto_gptq_quant_linear(
-        get_quantization_config(model, method='gptq'))
+    auto_gptq_config = get_quantization_config(model, method='gptq')
+    AutoGPTQQuantLinear = get_auto_gptq_quant_linear(auto_gptq_config)
+    use_qa_lora = kwargs.pop('use_qa_lora', False)

     for module_key in module_keys:
         if isinstance(target_modules, str):
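group_size is read from the model's GPTQ quantization config rather than supplied by the user, so it always matches how the weights were quantized. A sketch of the attribute access, using transformers' GPTQConfig as an illustrative stand-in for whatever get_quantization_config returns here:

from transformers import GPTQConfig

auto_gptq_config = GPTQConfig(bits=4, group_size=128)
group_size = getattr(auto_gptq_config, 'group_size', None)  # 128, forwarded to QuantLinear below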
@@ -292,7 +325,13 @@ def _dynamic_patch_lora(model: torch.nn.Module,
                     **four_bit_kwargs)
             elif AutoGPTQQuantLinear is not None and isinstance(
                     sub_module, AutoGPTQQuantLinear):
-                lora_module = QuantLinear('default', sub_module, **kwargs)
+                lora_module = QuantLinear(
+                    'default',
+                    sub_module,
+                    use_qa_lora=use_qa_lora,
+                    group_size=getattr(auto_gptq_config, 'group_size',
+                                       None),
+                    **kwargs)
                 sub_module.weight = sub_module.qweight
             elif isinstance(sub_module, torch.nn.Linear):
                 if use_merged_linear:
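Putting the pieces together, the intended flow is: load a GPTQ-quantized model, set use_qa_lora=True on the config, and let _dynamic_patch_lora wrap each matched AutoGPTQ QuantLinear. A hypothetical end-to-end sketch; the checkpoint id and the Swift.prepare_model entry point are assumptions, not part of this diff:

from transformers import AutoModelForCausalLM
from swift import Swift
from swift.tuners.lora import LoRAConfig

model = AutoModelForCausalLM.from_pretrained(
    'TheBloke/Llama-2-7B-GPTQ',  # any AutoGPTQ-quantized checkpoint
    device_map='auto')
config = LoRAConfig(r=8, lora_alpha=32, target_modules=['q_proj', 'v_proj'],
                    use_qa_lora=True)
model = Swift.prepare_model(model, config)  # qa-lora adapters attached to the GPTQ layers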