Support merge lora module for 4bit and 8bit linear #851
@@ -549,7 +549,7 @@ def _prepare_adapter_config(peft_config, model_config):
 
     def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False):
         if merge:
-            if getattr(self.model, "is_loaded_in_8bit", False) or getattr(self.model, "is_loaded_in_4bit", False):
+            if getattr(self.model, "is_loaded_in_8bit", False):
                 raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode")
             if getattr(self.model, "quantization_method", None) == "gptq":
                 raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
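With the guard above, merge_and_unload() no longer refuses models loaded in 4-bit. A minimal usage sketch, not part of the PR: the model id, target modules, and LoRA hyperparameters are placeholders, and it assumes bitsandbytes plus a CUDA device.

# Sketch: calling merge_and_unload() on a 4-bit quantized base model after this change.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",   # placeholder model id
    load_in_4bit=True,     # bitsandbytes 4-bit quantization
    torch_dtype=torch.float16,
    device_map="auto",
)
config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
model = get_peft_model(base, config)

# ... fine-tune the LoRA adapters here ...

# Previously this raised for 4-bit models; with this PR the 4-bit path is allowed.
merged_model = model.merge_and_unload()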
@@ -573,6 +573,17 @@ def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False):
                         padding=target.padding,
                         dilation=target.dilation,
                     )
+                elif is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
+                    bias = target.bias is not None
+                    new_module = bnb.nn.Linear4bit(
+                        target.in_features,
+                        target.out_features,
+                        bias=bias,
+                        compute_dtype=target.compute_dtype,
+                        compress_statistics=target.weight.compress_statistics,
+                        quant_type=target.weight.quant_type,
+                        device=target.weight.device,
+                    )
                 else:
                     bias = target.bias is not None
                     if getattr(target, "is_target_conv_1d_layer", False):
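For context on the branch above: the unload path rebuilds a bare bnb.nn.Linear4bit with the same quantization settings as the wrapped layer, and the merged floating-point weight is then installed as a Params4bit. A standalone sketch of that construction, assuming a recent bitsandbytes and a CUDA device; `target` and `w_merged` are illustrative stand-ins, not objects from the PR.

# Sketch: rebuilding a bare 4-bit linear that mirrors an existing Linear4bit, then
# installing a merged fp16 weight into it. All shapes and settings are illustrative.
import torch
import bitsandbytes as bnb

# Stand-in for the LoRA-wrapped target layer.
target = bnb.nn.Linear4bit(16, 32, bias=True, compute_dtype=torch.float16, quant_type="nf4").to("cuda")
# Stand-in for "base weight + LoRA delta" in full precision, shape (out_features, in_features).
w_merged = torch.randn(32, 16, dtype=torch.float16)

new_module = bnb.nn.Linear4bit(
    target.in_features,
    target.out_features,
    bias=target.bias is not None,
    compute_dtype=target.compute_dtype,
    compress_statistics=target.weight.compress_statistics,
    quant_type=target.weight.quant_type,
)
# Assigning a Params4bit on CPU and moving the module to CUDA triggers 4-bit quantization.
new_module.weight = bnb.nn.Params4bit(
    w_merged,
    requires_grad=False,
    compress_statistics=target.weight.compress_statistics,
    quant_type=target.weight.quant_type,
)
new_module = new_module.to("cuda")

x = torch.randn(4, 16, dtype=torch.float16, device="cuda")
print(new_module(x).shape)  # torch.Size([4, 32])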
@@ -1193,8 +1204,49 @@ def __init__(
             self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
             self.active_adapter = adapter_name
 
+        def merge(self):
+            if self.active_adapter not in self.lora_A.keys():
+                return
+            if self.merged:
+                warnings.warn("Already merged. Nothing to do.")
+                return
+            if self.r[self.active_adapter] > 0:
+                warnings.warn(
+                    "Merge lora module to 4-bit linear may get different generations due to rounding errors."
+                )
+                # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
+                kwargs = self.weight.__dict__
+                lora_data = self.get_delta_weight(self.active_adapter)
+                w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data
Review thread on this line:

Reviewer: I am not sure but I think that you need to specify the …
Suggested change: …

Author: Thanks for your comment. Actually, there is no need to pass … I have added a comment that refers to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930 before the 4-bit merge.

Reviewer: Thanks a lot!
+                self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)
+                self.merged = True
+
+        def unmerge(self):
+            if self.active_adapter not in self.lora_A.keys():
+                return
+            if not self.merged:
+                warnings.warn("Already unmerged. Nothing to do.")
+                return
+            if self.r[self.active_adapter] > 0:
+                warnings.warn(
+                    "Unmerge lora module to 4-bit linear may get different generations due to rounding errors."
+                )
+                kwargs = self.weight.__dict__
+                lora_data = self.get_delta_weight(self.active_adapter)
+                w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) - lora_data
+                self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)
+                self.merged = False
+
+        def get_delta_weight(self, adapter):
+            return (
+                transpose(
+                    self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
+                    False,
+                )
+                * self.scaling[adapter]
+            )
+
         def forward(self, x: torch.Tensor) -> torch.Tensor:
             # note: logic differs from default Linear because merging is not supported
             result = super().forward(x)
 
             if (
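The heart of merge() and unmerge() above is a dequantize, add (or subtract) the delta, requantize round trip, where the delta is lora_B @ lora_A scaled by lora_alpha / r (what get_delta_weight returns). A standalone sketch of that arithmetic with the bitsandbytes functional API; shapes, rank, and scaling are illustrative, and the last lines show the rounding error the warnings refer to.

# Sketch: merging a LoRA delta into a 4-bit weight by dequantizing, adding, requantizing.
# Requires bitsandbytes and a CUDA device; all shapes and hyperparameters are illustrative.
import torch
import bitsandbytes as bnb

out_features, in_features, r, lora_alpha = 64, 64, 8, 16
scaling = lora_alpha / r  # same scaling PEFT applies to the delta

w = torch.randn(out_features, in_features, dtype=torch.float16, device="cuda")
lora_A = torch.randn(r, in_features, dtype=torch.float16, device="cuda") * 0.01
lora_B = torch.randn(out_features, r, dtype=torch.float16, device="cuda") * 0.01

# Quantize the base weight to NF4, roughly what load_in_4bit does per linear layer.
w_4bit, quant_state = bnb.functional.quantize_4bit(w, quant_type="nf4")

# merge(): dequantize, add the scaled delta, requantize.
delta = (lora_B @ lora_A) * scaling
w_merged = bnb.functional.dequantize_4bit(w_4bit, quant_state) + delta
w_merged_4bit, merged_state = bnb.functional.quantize_4bit(w_merged, quant_type="nf4")

# The round trip is lossy, hence "may get different generations due to rounding errors".
roundtrip = bnb.functional.dequantize_4bit(w_merged_4bit, merged_state)
print((roundtrip - w_merged).abs().max())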
Review comment:

A reference for merging of 4-bit weights that was shared on Twitter by Tim Dettmers: https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
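In the spirit of that reference (and not its actual code), a hedged sketch of the dequantization step such a merge-into-full-precision route relies on: every bnb.nn.Linear4bit is swapped for an fp16 nn.Linear holding the dequantized weight. The helper names dequantize_linear4bit and dequantize_model are illustrative, not from the gist or from PEFT.

# Sketch (not the gist's code): replace each bnb.nn.Linear4bit with an fp16 nn.Linear
# holding the dequantized weight, so LoRA deltas can be merged in full precision.
import torch
import torch.nn as nn
import bitsandbytes as bnb

def dequantize_linear4bit(module: bnb.nn.Linear4bit) -> nn.Linear:
    # dequantize_4bit restores the original (out_features, in_features) fp tensor.
    w = bnb.functional.dequantize_4bit(module.weight.data, module.weight.quant_state)
    new = nn.Linear(module.in_features, module.out_features, bias=module.bias is not None)
    new.weight = nn.Parameter(w.to(torch.float16), requires_grad=False)
    if module.bias is not None:
        new.bias = nn.Parameter(module.bias.data.to(torch.float16), requires_grad=False)
    return new

def dequantize_model(model: nn.Module) -> nn.Module:
    # Recursively swap every 4-bit linear for its dequantized fp16 counterpart.
    for name, child in model.named_children():
        if isinstance(child, bnb.nn.Linear4bit):
            setattr(model, name, dequantize_linear4bit(child))
        else:
            dequantize_model(child)
    return model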