Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support merge lora module for 4bit and 8bit linear #851

Merged
merged 13 commits into from
Aug 28, 2023
56 changes: 54 additions & 2 deletions src/peft/tuners/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def _prepare_adapter_config(peft_config, model_config):

def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False):
if merge:
if getattr(self.model, "is_loaded_in_8bit", False) or getattr(self.model, "is_loaded_in_4bit", False):
if getattr(self.model, "is_loaded_in_8bit", False):
raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode")
if getattr(self.model, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
Expand All @@ -573,6 +573,17 @@ def _unload_and_optionally_merge(self, merge=True, progressbar: bool = False):
padding=target.padding,
dilation=target.dilation,
)
elif is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
bias = target.bias is not None
new_module = bnb.nn.Linear4bit(
target.in_features,
target.out_features,
bias=bias,
compute_dtype=target.compute_dtype,
compress_statistics=target.weight.compress_statistics,
quant_type=target.weight.quant_type,
device=target.weight.device,
)
else:
bias = target.bias is not None
if getattr(target, "is_target_conv_1d_layer", False):
Expand Down Expand Up @@ -1193,8 +1204,49 @@ def __init__(
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name

def merge(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A reference for merging of 4 bit weights that was shared on Twitter by Tim Dettmers: https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930

if self.active_adapter not in self.lora_A.keys():
return
if self.merged:
warnings.warn("Already merged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
warnings.warn(
"Merge lora module to 4-bit linear may get different generations due to rounding errors."
)
# Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
kwargs = self.weight.__dict__
lora_data = self.get_delta_weight(self.active_adapter)
w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure but I think that you need to specify the quant_type here and below

Suggested change
w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data
w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state, quant_type=self.weight.quant_type) + lora_data

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comment. Actually, there is no need to pass quant_type because it is already in quant_state.

I have added a comment that refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930 before 4bit merge

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot!

self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)
self.merged = True

def unmerge(self):
if self.active_adapter not in self.lora_A.keys():
return
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
warnings.warn(
"Unmerge lora module to 4-bit linear may get different generations due to rounding errors."
)
kwargs = self.weight.__dict__
lora_data = self.get_delta_weight(self.active_adapter)
w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) - lora_data
self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device)
self.merged = False

def get_delta_weight(self, adapter):
return (
transpose(
self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
False,
)
* self.scaling[adapter]
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# note: logic differs from default Linear because merging is not supported
result = super().forward(x)

if (
Expand Down
36 changes: 36 additions & 0 deletions tests/test_common_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@


if is_bnb_available():
import bitsandbytes as bnb

from peft.tuners.lora import Linear8bitLt

if is_bnb_4bit_available():
Expand Down Expand Up @@ -356,3 +358,37 @@ def test_modules_to_save_grad(self):
self.assertTrue(modules_to_save.weight.requires_grad is True)
self.assertTrue(original_module.weight.grad is None)
self.assertTrue(modules_to_save.weight.grad is not None)

@require_torch_gpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_4bit_merge_lora(self):
torch.manual_seed(3000)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_compute_type=torch.float32,
)
model = AutoModelForCausalLM.from_pretrained(
"facebook/opt-125m",
quantization_config=bnb_config,
torch_dtype=torch.float32,
)
config = LoraConfig(
r=8,
init_lora_weights=False,
)
model = get_peft_model(model, config)

random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
with torch.inference_mode():
out_before_merge = model.generate(random_input, max_new_tokens=1)

model.merge_and_unload("default")
with torch.inference_mode():
out_after_merge = model.generate(random_input, max_new_tokens=1)

self.assertTrue(torch.equal(out_before_merge, out_after_merge))
self.assertTrue(isinstance(model, PeftModel))
self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear4bit))
self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear4bit))
Loading