DoRA uses lots of GPU VRAM due to fp32 upcasting #1692
Thanks for investigating this issue. Could you provide a bit more information so that we can reproduce this issue (ideally with a smaller model, so that we can create a unit test based on this)? For instance, what dtype do you use, bf16?
What was the initial dtype for you?
If you want to consider contributing this to PEFT, let us know.
@BenjaminBossan Thanks! I'm using bf16. Any broadly Llama-like model should work, e.g. I've run tests with Meta-Llama-3-8B and Mistral-7B-v0.1 (the original, non-MoE Mistral). Here's the Triton code I wrote. When I benchmarked it, it reduced total training time by 20% and reduced GPU VRAM consumption by 1,974 MB compared to the PyTorch implementation. (This is my first time writing Triton code, so I'm sure I did something dumb somewhere, but the tests pass and training seems stable.)
Thanks for the info. I still couldn't quite replicate the issue; would it be possible for you to share a code snippet? We already have some code that should take care of setting the right dtype here: peft/src/peft/tuners/lora/bnb.py, lines 468 to 478 in 77b7238.
But apparently, you're not hitting that code. You mention QLoRA, but the code you quote is from the normal LoRA layers, so I'm a bit confused.
That sounds pretty sweet. I have never worked with Triton, so I'm not sure if we can just add this code and it'll work for everyone or if more work is necessary. If you're willing to contribute this but are unsure yourself, maybe I can find someone at HF with more expertise to take a look.
@rationalism Gentle ping, do you have updates?
@BenjaminBossan Hi! Sorry, have been busy with work, but will look at this later this week.
That would be great, thanks a lot!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
System Info
peft 0.10.0, transformers 4.40.1, Python 3.10 on Ubuntu 22.04
Who can help?
No response
Information
Tasks
examples folder
Reproduction
Doing language model fine-tuning using QLoRA with DoRA, e.g. fine-tuning Meta-Llama-3-70B with https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py, with target_modules set to include all linear layers, uses much more GPU VRAM than training with ordinary LoRA.
Expected behavior
Fine-tuning a QLoRA language model using DoRA, with adapters applied to all linear layers, takes up much more GPU VRAM than ordinary LoRA and OOMed my machine. I think the issue is this line:
peft/src/peft/tuners/lora/layer.py
Line 238 in 608a90d
It looks like magnitude is in fp32, so the input vector x is upcast to fp32 when it gets returned as result_dora. If both MLP and attention layers are added to target_modules, that fp32 output then causes the next DoRA module (in the MLP layer) to receive an fp32 vector as input. This in turn causes the dequantized weight matrix to get upcast to fp32:
peft/src/peft/tuners/lora/layer.py
Line 229 in 608a90d
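The promotion described above can be reproduced in isolation: under PyTorch's type-promotion rules, combining a bf16 tensor with an fp32 tensor yields an fp32 result, so an fp32 magnitude vector is enough to turn the whole layer output fp32. A minimal sketch (not PEFT code):

```python
import torch

# A bf16 activation, as produced by a bf16 training run.
x = torch.randn(4, 8, dtype=torch.bfloat16)

# DoRA's magnitude vector is stored in fp32.
magnitude = torch.randn(8, dtype=torch.float32)

# Broadcasting bf16 against fp32 promotes the result to fp32,
# so the layer's output (and the next DoRA layer's input) is fp32.
result = x * magnitude
print(result.dtype)  # torch.float32
```

Once one layer returns fp32, every downstream operation that touches its output inherits the wider dtype, which is why the problem compounds when all linear layers are targeted.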
which means the algebra in _get_weight_norm is done in fp32:
peft/src/peft/tuners/lora/layer.py
Line 176 in 608a90d
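The memory impact of doing that algebra in fp32 is easy to estimate: materializing the dequantized weight of a single Llama-sized linear layer in fp32 doubles its footprint relative to bf16. A rough back-of-the-envelope sketch (layer dimensions are illustrative, not a measurement of PEFT itself):

```python
import torch

# Dimensions of a typical Llama-3-8B MLP projection (illustrative).
out_features, in_features = 14336, 4096

def weight_mib(dtype: torch.dtype) -> float:
    """MiB needed to materialize the dequantized weight in this dtype."""
    elem_size = torch.tensor([], dtype=dtype).element_size()
    return out_features * in_features * elem_size / 2**20

bf16_mb = weight_mib(torch.bfloat16)
fp32_mb = weight_mib(torch.float32)
print(f"bf16: {bf16_mb:.0f} MiB, fp32: {fp32_mb:.0f} MiB")  # bf16: 112 MiB, fp32: 224 MiB
```

Multiply that extra 112 MiB by every targeted linear layer and it is plausible that a run which fits in bf16 OOMs once the norm computation silently switches to fp32.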
which OOMs my machine. Adding a cast back to x.dtype here:
peft/src/peft/tuners/lora/layer.py
Line 253 in 608a90d
fixes the problem. (I also wrote a custom Triton kernel for _get_weight_norm(), but that's probably not necessary for most purposes.)
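The proposed fix amounts to casting the DoRA result back to the input dtype before returning it, so fp32 never leaks into the next layer. A hedged sketch of the idea; the variable names (mag_norm_scale, base_result, lora_result) are illustrative and not necessarily the exact PEFT internals:

```python
import torch

def dora_forward_sketch(x, magnitude, weight_norm, base_result, lora_result, scaling):
    # Illustrative DoRA-style combination: rescale the base and LoRA paths
    # by the ratio of the learned magnitude to the weight norm.
    mag_norm_scale = (magnitude / weight_norm).view(1, -1)
    result_dora = mag_norm_scale * base_result + mag_norm_scale * lora_result * scaling
    # The fix: cast back to the input dtype so the fp32 magnitude/norm
    # do not propagate an fp32 output into the next DoRA module.
    return result_dora.to(x.dtype)

x = torch.randn(2, 8, dtype=torch.bfloat16)
out = dora_forward_sketch(
    x,
    magnitude=torch.randn(8, dtype=torch.float32),
    weight_norm=torch.rand(8, dtype=torch.float32) + 0.5,
    base_result=torch.randn(2, 8, dtype=torch.bfloat16),
    lora_result=torch.randn(2, 8, dtype=torch.bfloat16),
    scaling=2.0,
)
print(out.dtype)  # torch.bfloat16
```

With the cast in place, the intermediate math can still run at higher precision where the stored parameters demand it, but each layer's output stays in the training dtype.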