FIX Make DoRA work with Conv1D layers #1588

Merged
15 changes: 11 additions & 4 deletions src/peft/tuners/lora/layer.py
@arash2060 commented on Mar 25, 2024:
Lines 399 and 405 also need to be updated. I had to transpose `delta_weight` in `weight_norm = self._get_weight_norm(orig_weights, delta_weight, scaling=1)` and transpose `dora_factor.view(-1, 1)` to make the dimensions match.

The same change is needed for the unsafe merge.
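To make the shape issue concrete, here is a minimal sketch (illustration only, not code from this PR), assuming gpt2's `c_attn` projection, which is a `transformers` `Conv1D` that stores its weight as `(in_features, out_features)` instead of `nn.Linear`'s `(out_features, in_features)`:

```python
import torch
from transformers.pytorch_utils import Conv1D

# gpt2's c_attn is Conv1D(nf=2304, nx=768): its weight has shape (768, 2304),
# i.e. (in_features, out_features), the opposite orientation of nn.Linear.
conv = Conv1D(nf=2304, nx=768)
delta = torch.randn_like(conv.weight)  # a LoRA delta in the same layout

# DoRA keeps one norm (and one magnitude entry) per *output* feature, so the
# column-wise norm has to be taken after flipping the weight into (out, in):
norm_wrong = torch.linalg.norm(conv.weight + delta, dim=1)      # shape (768,)  -> wrong axis
norm_right = torch.linalg.norm((conv.weight + delta).T, dim=1)  # shape (2304,) -> one per output
print(norm_wrong.shape, norm_right.shape)
```

Because the magnitude vector has one entry per output feature, both the weight norm and the `(-1, 1)` broadcast have to follow the `(out, in)` orientation for Conv1D layers.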

Member Author:

Thanks for the feedback. Do you have a code snippet that results in an error without these additional changes?

@arash2060 commented on Mar 25, 2024:

Nothing that I can share, unfortunately. To replicate, it should be enough to initialize a DoRA model on gpt2 and call merge_and_unload() on it.


There you go:

import transformers, peft
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
dora_config = peft.LoraConfig(r=4, use_dora=True)
model = peft.get_peft_model(model, dora_config)
model.merge_and_unload('/tmp/m/')


Unloading and merging model:   5%|▌         | 9/176 [00:00<00:01, 113.93it/s]
Traceback (most recent call last):
  File "...python3.9/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-4-3672d1da59ca>", line 1, in <module>
    model.merge_and_unload('/tmp/m/')
  File "...python3.9/site-packages/peft/tuners/lora/model.py", line 784, in merge_and_unload
    return self._unload_and_optionally_merge(
  File "...python3.9/site-packages/peft/tuners/lora/model.py", line 438, in _unload_and_optionally_merge
    target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
  File "...python3.9/site-packages/peft/tuners/lora/layer.py", line 420, in merge
    weight_norm = self._get_weight_norm(base_layer.weight, delta_weight, scaling=1).detach()
  File "...python3.9/site-packages/peft/tuners/lora/layer.py", line 176, in _get_weight_norm
    weight = weight + scaling * lora_weight
RuntimeError: The size of tensor a (768) must match the size of tensor b (2304) at non-singleton dimension 1

Member Author:

Thanks a lot, I added a fix and tests for merge_and_unload. If you could check again, that would be great.


Works on my end. Thank you!

@@ -172,6 +172,7 @@ def loftq_init(self, adapter_name):
 
 def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
 # calculate L2 norm of weight matrix, column-wise
+weight = transpose(weight, self.fan_in_fan_out)
 weight = weight + scaling * lora_weight
 weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
 return weight_norm
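As context for this hunk, the `transpose` helper comes from `peft.utils` and roughly behaves like the sketch below (the exact implementation is an assumption here); it flips fan_in_fan_out (Conv1D-style) weights into the `(out_features, in_features)` orientation that the rest of the LoRA/DoRA math expects and is a no-op otherwise:

```python
import torch

def transpose(weight: torch.Tensor, fan_in_fan_out: bool) -> torch.Tensor:
    # No-op for nn.Linear-style weights; flips Conv1D-style (in, out) weights to (out, in).
    return weight.T if fan_in_fan_out else weight
```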
@@ -395,13 +396,16 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
 else:
 # handle dora
 # since delta_weight already includes scaling, set it to 1 here
-weight_norm = self._get_weight_norm(orig_weights, delta_weight, scaling=1).detach()
+weight_norm = self._get_weight_norm(
+orig_weights, transpose(delta_weight, self.fan_in_fan_out), scaling=1
+).detach()
 # We need to cache weight_norm because it has to be based on the original weights. We
 # cannot calculate it on the fly based on the merged weights when unmerging because its a
 # different value
 self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
 dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
-orig_weights = dora_factor.view(-1, 1) * (orig_weights + delta_weight)
+dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)
+orig_weights = dora_factor * (orig_weights + delta_weight)
 
 if not torch.isfinite(orig_weights).all():
 raise ValueError(
@@ -416,13 +420,16 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N
 else:
 # handle dora
 # since delta_weight already includes scaling, set it to 1 here
-weight_norm = self._get_weight_norm(base_layer.weight, delta_weight, scaling=1).detach()
+weight_norm = self._get_weight_norm(
+base_layer.weight, transpose(delta_weight, self.fan_in_fan_out), scaling=1
+).detach()
 # We need to cache weight_norm because it has to be based on the original weights. We
 # cannot calculate it on the fly based on the merged weights when unmerging because its a
 # different value
 self._cache_store(f"{active_adapter}-weight_norm", weight_norm)
 dora_factor = self.lora_magnitude_vector[active_adapter] / weight_norm
-new_weight = dora_factor.view(-1, 1) * (base_layer.weight.data + delta_weight)
+dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out)
+new_weight = dora_factor * (base_layer.weight.data + delta_weight)
 base_layer.weight.data = new_weight
 
 self.merged_adapters.append(active_adapter)
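For reference, the merged DoRA weight that these changes compute for a fan_in_fan_out (Conv1D-style) layer can be sketched standalone as follows; the function name and the random inputs are illustrative, not the PR's exact code:

```python
import torch

def merge_dora_conv1d(weight, delta_weight, magnitude):
    """Illustrative DoRA merge for a Conv1D-style weight stored as (in_features, out_features).

    weight:       base weight, shape (in_features, out_features)
    delta_weight: LoRA delta, already scaled, same shape as weight
    magnitude:    DoRA magnitude vector, one entry per output feature, shape (out_features,)
    """
    # Column-wise L2 norm of the combined weight, taken per output feature.
    weight_norm = torch.linalg.norm((weight + delta_weight).T, dim=1)  # (out_features,)
    dora_factor = magnitude / weight_norm                              # (out_features,)
    # With an (in, out) layout the per-output factor scales columns, hence the extra
    # transpose of the usual view(-1, 1) column vector.
    return dora_factor.view(-1, 1).T * (weight + delta_weight)

# gpt2 c_attn shapes: hidden size 768, fused qkv output 3 * 768 = 2304
w = torch.randn(768, 2304)
d = 0.01 * torch.randn(768, 2304)
m = torch.ones(2304)
print(merge_dora_conv1d(w, d, m).shape)  # torch.Size([768, 2304])
```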
16 changes: 16 additions & 0 deletions tests/test_custom_models.py
@@ -986,6 +986,22 @@ def test_double_wrapping_merge_and_unload(self, method):
 
 assert isinstance(unloaded.classifier, nn.Linear)
 
+def test_gpt2_dora_merge_and_unload(self):
+# see https://github.com/huggingface/peft/pull/1588#discussion_r1537914207
+model = AutoModelForCausalLM.from_pretrained("gpt2")
+config = LoraConfig(task_type="CAUSAL_LM", use_dora=True)
+model = get_peft_model(model, config)
+# should not raise an error
+model.merge_and_unload()
+
+def test_gpt2_dora_merge_and_unload_safe_merge(self):
+# see https://github.com/huggingface/peft/pull/1588#discussion_r1537914207
+model = AutoModelForCausalLM.from_pretrained("gpt2")
+config = LoraConfig(task_type="CAUSAL_LM", use_dora=True)
+model = get_peft_model(model, config)
+# should not raise an error
+model.merge_and_unload(safe_merge=True)
+
 
 class TestMultiRankAdapter(unittest.TestCase):
 """Tests related to multirank LoRA adapters"""
59 changes: 55 additions & 4 deletions tests/test_gpu_examples.py
@@ -814,7 +814,7 @@ def test_causal_lm_training_4bit_dora(self):
 with tempfile.TemporaryDirectory() as tmp_dir:
 model = AutoModelForCausalLM.from_pretrained(
 self.causal_lm_model_id,
-load_in_4bit=True,
+quantization_config=BitsAndBytesConfig(load_in_4bit=True),
 device_map="auto",
 )
 
@@ -872,7 +872,7 @@ def test_causal_lm_training_multi_gpu_4bit_dora(self):
 model = AutoModelForCausalLM.from_pretrained(
 self.causal_lm_model_id,
 device_map="auto",
-load_in_4bit=True,
+quantization_config=BitsAndBytesConfig(load_in_4bit=True),
 )
 
 assert set(model.hf_device_map.values()) == set(range(torch.cuda.device_count()))
@@ -931,7 +931,7 @@ def test_causal_lm_training_8bit_dora(self):
 with tempfile.TemporaryDirectory() as tmp_dir:
 model = AutoModelForCausalLM.from_pretrained(
 self.causal_lm_model_id,
-load_in_8bit=True,
+quantization_config=BitsAndBytesConfig(load_in_8bit=True),
 device_map="auto",
 )
 
@@ -989,7 +989,7 @@ def test_causal_lm_training_multi_gpu_8bit_dora(self):
 model = AutoModelForCausalLM.from_pretrained(
 self.causal_lm_model_id,
 device_map="auto",
-load_in_8bit=True,
+quantization_config=BitsAndBytesConfig(load_in_8bit=True),
 )
 
 assert set(model.hf_device_map.values()) == set(range(torch.cuda.device_count()))
@@ -1040,6 +1040,57 @@ def test_causal_lm_training_multi_gpu_8bit_dora(self):
 # assert loss is not None
 assert trainer.state.log_history[-1]["train_loss"] is not None
 
+@pytest.mark.single_gpu_tests
+def test_causal_lm_training_gpt2_dora(self):
+r"""
+Same as test_causal_lm_training_4bit but with DoRA
+"""
+with tempfile.TemporaryDirectory() as tmp_dir:
+model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
+
+tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
+model = prepare_model_for_kbit_training(model)
+
+config = LoraConfig(
+r=16,
+lora_alpha=32,
+lora_dropout=0.05,
+bias="none",
+task_type="CAUSAL_LM",
+use_dora=True,
+)
+
+model = get_peft_model(model, config)
+
+data = load_dataset("ybelkada/english_quotes_copy")
+data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
+
+trainer = Trainer(
+model=model,
+train_dataset=data["train"],
+args=TrainingArguments(
+per_device_train_batch_size=4,
+gradient_accumulation_steps=4,
+warmup_steps=2,
+max_steps=3,
+learning_rate=2e-4,
+fp16=True,
+logging_steps=1,
+output_dir=tmp_dir,
+),
+data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+)
+model.config.use_cache = False
+trainer.train()
+
+model.cpu().save_pretrained(tmp_dir)
+
+assert "adapter_config.json" in os.listdir(tmp_dir)
+assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
+
+# assert loss is not None
+assert trainer.state.log_history[-1]["train_loss"] is not None
+
 
 @require_torch_gpu
 @require_auto_gptq