Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement DoRA #1474

Merged
merged 21 commits into from Feb 27, 2024
Merged

Implement DoRA #1474

merged 21 commits into from Feb 27, 2024

Conversation

BenjaminBossan
Copy link
Member

@BenjaminBossan BenjaminBossan commented Feb 16, 2024

https://arxiv.org/abs/2402.09353

State:

  • Only Linear layer supported, not conv or embedding
  • Quantized layers not supported

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@nbasyl
Copy link

nbasyl commented Feb 19, 2024

@BenjaminBossan Hi Benjamin, may I have your email, I am the author of DoRA, and would like to help with integrating DoRA into PEFT.

@BenjaminBossan
Copy link
Member Author

BenjaminBossan commented Feb 19, 2024

Thanks @nbasyl. You can contact me at benjamin@ + HF domain (to clarify: huggingface.co).

Might make sense to refactor this in the future to avoid a signficant
amount of code duplication.
@BenjaminBossan
Copy link
Member Author

@nbasyl Did you reach out yet? I didn't get any mail so far.

@nbasyl
Copy link

nbasyl commented Feb 20, 2024

@BenjaminBossan Sorry for the late reply, was consulting with my manager regrading this, I just sent out the email, please check!

@RonanKMcGovern
Copy link

Great this is being added, is the PR at a point where I can test it out? Seems like I just add use_dora=True in the Lora config?

(btw, does LoRA alpha become redundant then if use_dora is True)?

@BenjaminBossan
Copy link
Member Author

@RonanKMcGovern There are still a few smaller kinks to work out, so don't test unless you're ready to test again once the next commits roll in.

LoRA alpha should still be relevant when DoRA is being used.

@RonanKMcGovern
Copy link

Just gave this a quick spin and seems to do better on ppl than LoRA with the same r. Thanks for the nice work.

@BenjaminBossan
Copy link
Member Author

Just gave this a quick spin and seems to do better on ppl than LoRA with the same r. Thanks for the nice work.

Thanks, sounds fantastic. If you can share anything like scripts, that would be great.

@aliencaocao
Copy link

Hi @nbasyl do you think the concept of differential LR by Lora+ (https://arxiv.org/abs/2402.12354) can be integrated into DoRa? Both paper came out almost exact same time and they dont seem mutually exclusive.

@BenjaminBossan
Copy link
Member Author

@aliencaocao I skimmed the paper and code and I think there is no technical limitation to combining the two. I don't know if the gains from the two approaches would be additive though.

@RonanKMcGovern
Copy link

https://arxiv.org/abs/2402.12354

very nice, is there a pr for this yet for HF?

@RonanKMcGovern
Copy link

Just gave this a quick spin and seems to do better on ppl than LoRA with the same r. Thanks for the nice work.

Thanks, sounds fantastic. If you can share anything like scripts, that would be great.

I'm using the huggingface trainer and this LoRA config for Qwen models:

from peft import LoraConfig, get_peft_model

# Initialize LoRA configuration
config = LoraConfig(
    r=8, 
    lora_alpha=32, 
    target_modules=[
      "q_proj",
      "k_proj",
      "v_proj",
      "o_proj",
      # "self_attn.rotary_emb.inv_freq",
      "gate_proj",
      "up_proj",
      "down_proj",
      # "input_layernorm.weight",
      # "post_attention_layernorm.weight",
      # "model.norm.weight",
      # "lm_head.weight"
    ],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    use_dora=True
)

@BenjaminBossan
Copy link
Member Author

very nice, is there a pr for this yet for HF?

The code is published, there is a function here that you can simply call on your LoRA model and optimizer class:

https://github.com/nikhil-ghosh-berkeley/loraplus/blob/c8d30388f5d1ec6d8c93ff67330e4d36937d48dc/loraplus.py#L31

So far, the PEFT repo does not contain any code related to training directly, as we want to keep it agnostic with regard to the training method. If there is a big interest though, we could think about adding this function as a helper function for convenience. But that's a different topic than this PR ;)

@BenjaminBossan BenjaminBossan changed the title [WIP] Try to implement DoRA [WIP] Implement DoRA Feb 22, 2024
@BenjaminBossan
Copy link
Member Author

Thanks @RonanKMcGovern. It would be much appreciated if you could try if the last few changes helped with the runtime performance.

@aliencaocao
Copy link

@BenjaminBossan On a RTX 3090 finetuning llava1.6 mistral 7b, I get half the throughput as without dora but with lora.

@sayakpaul
Copy link
Member

And with DoRA?

@aliencaocao
Copy link

What i meant was, with dora is half the speed as without dora but with lora

@sayakpaul
Copy link
Member

Oh okay. On the same rank?

@sayakpaul
Copy link
Member

If not, maybe try with a lower rank which should hopefully compensate for the throughput without disturbing the performance.

@aliencaocao
Copy link

Yes same rank

@nbasyl
Copy link

nbasyl commented Feb 28, 2024

Hi, one way to further reduce the GPU cost (# of trainable parameters) as well as the training time is to only finetune the magnitude component of certain modules while finetuning both the magnitude and directional component for the remaining modules. You can refer to Sec.5.6 of the paper for more details. This feature can be added in another PR to give user more flexibility.

@aliencaocao
Copy link

so is double the cost to be expected here because there is double the trainable params?

@nbasyl
Copy link

nbasyl commented Feb 28, 2024

If the same rank is utilized, the number of trainable parameters is only slightly higher, approximately 0.01%, compared to LoRA. However, the current computation overhead arises from the need to calculate the weight's norm as well as the calculation of the second term in the given formulation (can't reuse the first term (base_result) due to dropout alignment issue).
image

@nbasyl
Copy link

nbasyl commented Feb 28, 2024

Hi, one way to further reduce the GPU cost (# of trainable parameters) as well as the training time is to only finetune the magnitude component of certain modules while finetuning both the magnitude and directional component for the remaining modules. You can refer to Sec.5.6 of the paper for more details. This feature can be added in another PR to give user more flexibility.

@BenjaminBossan, I have already implemented this feature. Once our code is released (should be released soon), you can take a look at it and see if you can integrate this functionality with another pull request (PR). Or we can discuss via email to start working on this new PR earlier.

@BenjaminBossan
Copy link
Member Author

On a RTX 3090 finetuning llava1.6 mistral 7b, I get half the throughput as without dora but with lora.

Thanks for the feedback. On my puny 2060 with bloomz-560m and fp16, I got 15-20% slowdown during training with DoRA enabled, same rank. So there seem to be pretty big differences based on the exact model or settings. @nbasyl do you have some numbers that you could share?

one way to further reduce the GPU cost (# of trainable parameters) as well as the training time is to only finetune the magnitude component of certain modules while finetuning both the magnitude and directional component for the remaining modules. You can refer to Sec.5.6 of the paper for more details.

It would make the whole thing more complex. Are the gains big enough to be worth it?

I'll definitely take a closer look and check what else we can do to improve runtime performance while keeping the code flexible and maintainable.

@RonanKMcGovern
Copy link

BTW, this merge is causing some breaking changes on packages using the code of peft, such as unsloth: unslothai/unsloth#201

TypeError                                 Traceback (most recent call last)
Cell In[23], line 2
      1 # Do model patching and add fast LoRA weights
----> 2 model = FastLanguageModel.get_peft_model(
      3     model,
      4     r = 8,
      5     target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
      6                       "gate_proj", "up_proj", "down_proj",],
      7     lora_alpha = 32,
      8     lora_dropout = 0, # Dropout = 0 is currently optimized
      9     bias = "none",    # Bias = "none" is currently optimized
     10     use_gradient_checkpointing = True,
     11     random_state = 3407,
     12 )

File /usr/local/lib/python3.10/dist-packages/unsloth/models/llama.py:1313, in FastLlamaModel.get_peft_model(model, r, target_modules, lora_alpha, lora_dropout, bias, layers_to_transform, layers_pattern, use_gradient_checkpointing, random_state, max_seq_length, use_rslora, init_lora_weights, loftq_config, **kwargs)
   1310 if not SUPPORTS_RSLORA: del arguments["use_rslora"]
   1312 lora_config = LoraConfig(**arguments)
-> 1313 model = _get_peft_model(model, lora_config)
   1315 model = FastLlamaModel.patch_peft_model(model, use_gradient_checkpointing)
   1316 return model

File /usr/local/lib/python3.10/dist-packages/peft/mapping.py:136, in get_peft_model(model, peft_config, adapter_name, mixed)
    134 if peft_config.is_prompt_learning:
    135     peft_config = _prepare_prompt_learning_config(peft_config, model_config)
--> 136 return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)

File /usr/local/lib/python3.10/dist-packages/peft/peft_model.py:1059, in PeftModelForCausalLM.__init__(self, model, peft_config, adapter_name)
   1058 def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
-> 1059     super().__init__(model, peft_config, adapter_name)
   1060     self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

File /usr/local/lib/python3.10/dist-packages/peft/peft_model.py:126, in PeftModel.__init__(self, model, peft_config, adapter_name)
    124     self._peft_config = None
    125     cls = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type]
--> 126     self.base_model = cls(model, {adapter_name: peft_config}, adapter_name)
    127     self.set_additional_trainable_modules(peft_config, adapter_name)
    129 if getattr(model, "is_gradient_checkpointing", True):

File /usr/local/lib/python3.10/dist-packages/peft/tuners/lora/model.py:111, in LoraModel.__init__(self, model, config, adapter_name)
    110 def __init__(self, model, config, adapter_name) -> None:
--> 111     super().__init__(model, config, adapter_name)

File /usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py:147, in BaseTuner.__init__(self, model, peft_config, adapter_name)
    144         self.peft_config.update(peft_config)
    146 self.active_adapter = adapter_name
--> 147 self.inject_adapter(self.model, adapter_name)
    149 # Copy the peft_config in the injected model.
    150 self.model.peft_config = self.peft_config

File /usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py:302, in BaseTuner.inject_adapter(self, model, adapter_name)
    300     is_target_modules_in_base_model = True
    301     parent, target, target_name = _get_submodules(model, key)
--> 302     self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
    304 if not is_target_modules_in_base_model:
    305     raise ValueError(
    306         f"Target modules {peft_config.target_modules} not found in the base model. "
    307         f"Please check the target modules and try again."
    308     )

File /usr/local/lib/python3.10/dist-packages/peft/tuners/lora/model.py:182, in LoraModel._create_and_replace(self, lora_config, adapter_name, target, target_name, parent, current_key)
    172     target.update_layer(
    173         adapter_name,
    174         r,
   (...)
    179         use_dora=lora_config.use_dora,
    180     )
    181 else:
--> 182     new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
    183     if adapter_name != self.active_adapter:
    184         # adding an additional adapter: it is not automatically trainable
    185         new_module.requires_grad_(False)

File /usr/local/lib/python3.10/dist-packages/peft/tuners/lora/model.py:257, in LoraModel._create_new_module(lora_config, adapter_name, target, **kwargs)
    255 new_module = None
    256 for dispatcher in dispatchers:
--> 257     new_module = dispatcher(target, adapter_name, lora_config=lora_config, **kwargs)
    258     if new_module is not None:  # first match wins
    259         break

File /usr/local/lib/python3.10/dist-packages/peft/tuners/lora/layer.py:858, in dispatch_default(target, adapter_name, lora_config, **kwargs)
    856         kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
    857     kwargs.update(lora_config.loftq_config)
--> 858     new_module = Linear(target, adapter_name, **kwargs)
    859 elif isinstance(target_base_layer, Conv1D):
    860     if not kwargs["fan_in_fan_out"]:

File /usr/local/lib/python3.10/dist-packages/peft/tuners/lora/layer.py:289, in Linear.__init__(self, base_layer, adapter_name, r, lora_alpha, lora_dropout, fan_in_fan_out, is_target_conv_1d_layer, init_lora_weights, use_rslora, use_dora, **kwargs)
    286 self.fan_in_fan_out = fan_in_fan_out
    288 self._active_adapter = adapter_name
--> 289 self.update_layer(
    290     adapter_name,
    291     r,
    292     lora_alpha=lora_alpha,
    293     lora_dropout=lora_dropout,
    294     init_lora_weights=init_lora_weights,
    295     use_rslora=use_rslora,
    296     use_dora=use_dora,
    297 )
    298 self.is_target_conv_1d_layer = is_target_conv_1d_layer

TypeError: LoraLayer_update_layer() got an unexpected keyword argument 'use_dora'

@nbasyl
Copy link

nbasyl commented Feb 29, 2024

On a RTX 3090 finetuning llava1.6 mistral 7b, I get half the throughput as without dora but with lora.

Thanks for the feedback. On my puny 2060 with bloomz-560m and fp16, I got 15-20% slowdown during training with DoRA enabled, same rank. So there seem to be pretty big differences based on the exact model or settings. @nbasyl do you have some numbers that you could share?

one way to further reduce the GPU cost (# of trainable parameters) as well as the training time is to only finetune the magnitude component of certain modules while finetuning both the magnitude and directional component for the remaining modules. You can refer to Sec.5.6 of the paper for more details.

It would make the whole thing more complex. Are the gains big enough to be worth it?

I'll definitely take a closer look and check what else we can do to improve runtime performance while keeping the code flexible and maintainable.

I have tried finetuning LLaMA-7B/13B on 3090/4090/V100/A100 and only also got around 20% slow down. I am assuming that such drastic slowdown for @aliencaocao is probably caused by the optimization problem of deepspeed which is used by the LLaVA code base.

@nbasyl
Copy link

nbasyl commented Feb 29, 2024

On a RTX 3090 finetuning llava1.6 mistral 7b, I get half the throughput as without dora but with lora.

Thanks for the feedback. On my puny 2060 with bloomz-560m and fp16, I got 15-20% slowdown during training with DoRA enabled, same rank. So there seem to be pretty big differences based on the exact model or settings. @nbasyl do you have some numbers that you could share?

one way to further reduce the GPU cost (# of trainable parameters) as well as the training time is to only finetune the magnitude component of certain modules while finetuning both the magnitude and directional component for the remaining modules. You can refer to Sec.5.6 of the paper for more details.

It would make the whole thing more complex. Are the gains big enough to be worth it?

I'll definitely take a closer look and check what else we can do to improve runtime performance while keeping the code flexible and maintainable.

The gain here is not the accuracy improvement, but in the reduction of trainable parameters. You can refer to this table.
image
Besides. for my case, I implemented this feature under the PEFT package framework and didn't modify the code too much, so I think it wouldn't be as complicated as you thought.

@peterjc123
Copy link

peterjc123 commented Feb 29, 2024

According to my experiment, DoRA is ~2x slower than LoRA, and I am able to achieve lower loss with DoRA. But it is a little bit not worth it compared to training with QLoRA on a larger model because that gives a much lower loss, the GPU memory usage is slighter larger and it is only 20%-30% slower.
Some numbers & settings
GPU: 1x Nvidia 3080 (10GB)
LoRA rank: 64
LoRA target modules: All except lm_head & embeddings
LoRA dropout: 0.05
Model: Qwen 1.5 - 1.8B
Max Length: 512
Packing: True
Gradient Checkpointing: True
Batch: 2
Gradient Accumulation Steps: 8
Optim: Adam8Bit
GPU mem usage: ~6800MB
Use Flash Attention2: True
Dtype: BFloat16
LoRA timings: 2.5-2.7 s/Iter
DoRA timings: 5.3-5.5 s/Iter
4-bit QLoRA timings on Qwen 1.5-4B: 6.9-7.1 s/Iter

@BenjaminBossan
Copy link
Member Author

@nbasyl Okay, this suggestion sounds good for the purpose of saving on the number of trainable parameters (even if runtime would probably not change much). I'm not sure yet how to configure this so that it would work with most or all model architectures, we'd probably have to add a new config argument for that. If your code is to be released soon, we can wait for that and add it to PEFT afterwards.

When it comes to VeRA, would DVoRA help with reducing the overhead of calculating the weight norm (since W0 and BA are fixed) or is it the same cost as DoRA?

@peterjc123 Thanks for providing your settings. I tried to replicate (using Qwen1.5-0.5B for memory reasons) and got ~30% (2060) to ~40% (T4) slower training with DoRA activated.

it is a little bit not worth it compared to training with QLoRA on a larger model

I'll investigate if we can make DoRA work with bnb. It could be a bit tricky when it comes to calculating the weight norm, let's see.

@BenjaminBossan
Copy link
Member Author

I created a PR to support DoRA with bnb (QDoRA): #1518

I tested it on a small use case and it worked. If someone wants to give it a spin and report the results, that would be fantastic.

@152334H
Copy link

152334H commented Mar 2, 2024

when launching llava training (using deepspeed --include 'localhost:0,1,3,4') with use_dora=True added here, dora works perfectly on ZeRO3, but mysteriously deadlocks on ZeRO2.

ZeRO2

with a zero 2 config, I get a deadlock:

Loading checkpoint shards: 100%|__________________________________________________________________________________| 4/4 [00:03<00:00,  1.04it/s]
Loading checkpoint shards: 100%|__________________________________________________________________________________| 4/4 [00:03<00:00,  1.09it/s]
Adding LoRA adapters...
Loading checkpoint shards: 100%|__________________________________________________________________________________| 4/4 [00:04<00:00,  1.00s/it]
Loading checkpoint shards: 100%|__________________________________________________________________________________| 4/4 [00:03<00:00,  1.15it/s]
# < --- waiting for 10 minutes here, nothing happens

ZeRO3

with a zero 3 config, I mysteriously do not (??) get a deadlock:

[2024-03-02 12:41:55,286] [INFO] [partition_parameters.py:343:__exit__] finished initializing model - num_params = 687, num_elems = 7.57B
Loading checkpoint shards: 100%|__________________________________________________________________________________| 4/4 [00:07<00:00,  1.87s/it]
Loading checkpoint shards: 100%|__________________________________________________________________________________| 4/4 [00:07<00:00,  1.89s/it]
Loading checkpoint shards: 100%|__________________________________________________________________________________| 4/4 [00:07<00:00,  1.89s/it]
Loading checkpoint shards: 100%|__________________________________________________________________________________| 4/4 [00:07<00:00,  1.96s/it]Adding LoRA adapters...
openai/clip-vit-large-patch14-336 is already loaded, `load_model` called again, skipping.                                                       openai/clip-vit-large-patch14-336 is already loaded, `load_model` called again, skipping.
openai/clip-vit-large-patch14-336 is already loaded, `load_model` called again, skipping.
openai/clip-vit-large-patch14-336 is already loaded, `load_model` called again, skipping.                                                       Formatting inputs...Skip in lazy mode
Formatting inputs...Skip in lazy mode
Parameter Offload: Total persistent parameters: 7418880 in 521 params
# < --- (training run proceeds afterwards here)

@peterjc123
Copy link

peterjc123 commented Mar 3, 2024

I created a PR to support DoRA with bnb (QDoRA): #1518

I tested it on a small use case and it worked. If someone wants to give it a spin and report the results, that would be fantastic.

@BenjaminBossan

With bf16=True, your implementation OOM when batch = 2 and the configuration I posted above. It also throws the following warning.

The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in bfloat16.

And then, I debugged your code and found out that the dtype of weight_norm in apply_dora is torch.float32. So the dtype of the output of the LoRA layer is torch.float32, too.

-> weight_norm = self._get_weight_norm(weight, lora_weight, scaling)
(Pdb) s
--Call--
> peft/src/peft/tuners/lora/layer.py(172)_get_weight_norm()
-> def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
(Pdb) n
> peft/src/peft/tuners/lora/layer.py(174)_get_weight_norm()
-> weight = weight + scaling * lora_weight
(Pdb) p weight.dtype
torch.bfloat16
(Pdb) p scaling
0.25
(Pdb) p lora_weight.dtype
torch.bfloat16
(Pdb) n
> peft/src/peft/tuners/lora/layer.py(175)_get_weight_norm()
-> weight_norm = torch.linalg.norm(weight, dim=1)
(Pdb) p weight.dtype
torch.bfloat16
(Pdb) n
> peft/tuners/lora/layer.py(176)_get_weight_norm()
-> return weight_norm
(Pdb) p weight_norm.dtype
torch.float32

As you can see, torch.linalg.norm returns a float32 tensor regardless of the bfloat16 input. (probably because AMP is effective?)
So I have to change that line to

weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)

After fixing this, it is still super slow, yields about 12.96s/it.
Update: And the sad thing is that the validation loss of the trained model with DoRA+QLoRA is actually worse than that trained with original QLoRA. I guess I'll try something like LoftQ to get better initialized values for the quantized weight matrix first.

@BenjaminBossan
Copy link
Member Author

when launching llava training (using deepspeed --include 'localhost:0,1,3,4') with use_dora=True added here, dora works perfectly on ZeRO3, but mysteriously deadlocks on ZeRO2.

Very strange, typically PEFT has issues with ZeRO3, not ZeRO2. I don't know enough about DS to tell what could be the cause of the deadlock here.

So I have to change that line to

Thanks for investigating, I pushed this change to the PR.

After fixing this, it is still super slow, yields about 12.96s/it.

That's unfortunate, but not unexpected. The issue is that we need to have an additional dequantization step for DoRA, as we have to calculate the weight norm of the quantized weight + LoRA. I couldn't come up with a way to avoid this or somehow cache the results. If you or someone else has an idea, please let me know.

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Mar 14, 2024
Add DoRA (Weight-Decomposed Low-Rank Adaptation).

https://arxiv.org/abs/2402.09353

To use this with LoRA, add use_dora=True to the LoraConfig.

Currently only supports nn.Linear layers, not other types or
quantized linear layers like bnb.
lora_weight = lora_B.weight @ lora_A.weight
magnitude = self.lora_magnitude_vector[active_adapter]
weight = self.get_base_layer().weight
weight_norm = self._get_weight_norm(weight, lora_weight, scaling)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be the following? There are a few other places that we need to do the same, otherwise we get mismatching dimensions.

weight_norm = self._get_weight_norm(transpose(weight, self.fan_in_fan_out), lora_weight, scaling)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out. Indeed, models that use Conv1D like GPT2 wouldn't work right now. I created a PR to fix this: #1588.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

10 participants