RoPE loses precision for Llama / Gemma + Gemma logits.float() #29285
Llama - Force float32 since bfloat16 loses precision on long contexts
Fix RoPE and logits.float()
Forgot to add: I'm not certain if this will break CUDAGraphs for faster inference - hopefully not!
I'll have to check the compile test and everything, but we usually hate these kinds of changes 🫣 The bug is real; I'll see if I can find a good alternative, as this is pretty much only for training! Great catch 🤗
Sadly I'm unsure if it's just for training :(( For inference, I don't remember up to which context length bfloat16 won't be an issue - I think it was up to 4096. However, bfloat16 loses precision even for inference past 4096 context lengths, and definitely by 8192: bfloat16 essentially thinks the last 4 tokens are all position 8192, i.e. [8192, 8192, 8192, 8192], whilst the correct float32 is [8188, 8189, 8190, 8191].
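For reference, this rounding is easy to reproduce directly in PyTorch (a standalone check, not code from this PR):

```python
import torch

# bfloat16 has only 7 mantissa bits, so representable values in [4096, 8192]
# are 32 apart; positions 8188..8191 all round to 8192.
positions = torch.arange(8188, 8192, dtype=torch.float32)
print(positions)                      # tensor([8188., 8189., 8190., 8191.])
print(positions.to(torch.bfloat16))  # tensor([8192., 8192., 8192., 8192.], dtype=torch.bfloat16)
```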
LGTM, let's do no_grad and autocast; I'll test compile once you have both!
LGTM, before merging I'll ping @pacman100, @younesbelkada and @fxmarty as this is pretty important! Feel free to comment if you are against these changes!
```python
def forward(self, x, position_ids, seq_len=None):
    # x: [bs, num_attention_heads, seq_len, head_size]
    if self.inv_freq is None:
        self.inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
        )
```
@gante Actually an interesting point - I can see torch.autocast does arc sin and sinh etc. in float32, but it doesn't list sin itself - I'll have to check if .sin() is done in float32 or float16.
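A quick way to check this empirically (a standalone snippet assuming a CUDA device is available, not code from this PR): ops that are not on autocast's cast lists simply run in their input dtype.

```python
import torch

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    x = torch.randn(4, device="cuda")  # float32 input
    y = torch.randn(4, 4, device="cuda") @ torch.randn(4, 4, device="cuda")  # matmul runs in bfloat16
    print(x.sin().dtype)  # torch.float32 - sin keeps its input dtype
    print(y.sin().dtype)  # torch.bfloat16 - the matmul already downcast its output
```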
@ArthurZucker I checked everything and it's working! You guys can double check if anything is wrong. You can push the commit whenever. Thank you! :)
LGTM! Thanks a mile for this. Let's make sure you run make style and make fixup for the last CIs.
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This reverts commit b860a22.
```python
self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
```

I doubt that this is the cause of the issue, since we used to do that before already.
Thanks @danielhanchen 🤗 merging now!
* Update modeling_llama.py: Llama - Force float32 since bfloat16 loses precision on long contexts
* Update modeling_llama.py
* Update modeling_gemma.py: Fix RoPE and logits.float()
* @torch.no_grad()
* @torch.no_grad()
* Cos, Sin to float32
* cos, sin to float32
* Update src/transformers/models/gemma/modeling_gemma.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* Update src/transformers/models/llama/modeling_llama.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* Resolve PR conflicts
* Fix RoPE for llama
* Revert "Fix RoPE for llama" (reverts commit b860a22)
* Fix RoPE for llama
* RoPE device
* Autocast device type
* RoPE
* RoPE isinstance

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Yay! :) Great work everyone :)
(hello from the Gemma team!) Superb stuff, thanks for this lovely fix :)
@suryabhupa Thanks :)
Did anyone realize that this fix incurs an additional VRAM increase, likely causing OOM for code that could previously be trained?
@paulcx Yes, unfortunately that was a trade-off we accepted - the issue is that results become incorrect, especially on longer context lengths (shorter ones are fine). We tried our best to isolate the changes - for inference this is fine, but for training there will be more VRAM usage. There's always Unsloth for Gemma of course, where we allow 2.5x faster finetuning and 60% VRAM reductions :) And TRL, PEFT and all HF libraries are integrated. Gemma 7b notebook: https://colab.research.google.com/drive/10NbwlsRChbma1v55m8LAPYG15uQv6HLo?usp=sharing
And though this increases VRAM usage, the precision and quality of the outputs should be improved. This is also more "accurate" with respect to the original implementations.
@danielhanchen Thank you for your finding. It would be great if the fix could be applied on top of the existing code without increasing VRAM usage. As it stands, some resource-constrained training runs may no longer fit in memory because of this (even with the batch size set to 1).
How much of a VRAM increase are we talking about?
It's difficult to give an exact number, but the growth in VRAM requirements could depend on the size of the model; for instance, for a 34B model, it could be in the tens of GBs?
@paulcx Wait, 10s of GBs? The PR is only for Llama and Gemma, so I'm assuming CodeLlama 34b is using more VRAM? The RoPE upcasting should only use approximately 8192 * 8192 * 16 bits of extra VRAM per layer - so for, say, 32 layers, about 128MB per layer - and it should be cleared away since we don't need the 8192x8192 matrix afterwards (quick arithmetic check below). I don't see where an extra 10GB of VRAM is coming from - are you finetuning Gemma or Llama? I can see why Gemma might be using more VRAM, since its logits are now upcast to float32 via logits.float(). For Llama, only at most 128MB of extra VRAM should be used.
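For concreteness, the back-of-the-envelope arithmetic behind the 128MB figure (plain Python, just checking the numbers quoted above):

```python
# An 8192 x 8192 buffer at 16 bits (2 bytes) per element:
print(8192 * 8192 * 2 / 2**20)  # 128.0 MiB
```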
@danielhanchen I must say that my estimate is likely quite imprecise, given that I have not run careful experiments or gathered statistics. The only evidence I have to support GBs is from my own experience with a 34B CodeLlama (which trained successfully on transformers==4.37.2). I know it could be caused by many things, for example multi-GPU training with DeepSpeed, etc.
@paulcx OHH, this is a different issue!! It seems like a change was made in another PR which allocates a causal mask of size (16384, 16384): https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L940 The triu causes the causal mask to be upcast to float32, using 16384^2 * 4 bytes = 1GB of extra VRAM (see the check below). So this looks like a separate problem from this PR, unfortunately.
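The same style of sanity check for the causal-mask figure (again just arithmetic, not library code):

```python
# A (16384, 16384) float32 causal mask at 4 bytes per element:
print(16384 * 16384 * 4 / 2**30)  # 1.0 GiB
```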
Thanks @danielhanchen
Tagging @ArthurZucker :)
When I was implementing Gemma for Unsloth, I noticed that when one uses bfloat16, the RoPE embeddings get autocast to bfloat16, when we require them to be in float32. This causes the positional encodings to lose precision dramatically, especially for very large context lengths.
Below I pasted an image of how HF currently handles RoPE. You can see the loss in precision when using bfloat16. I manually autocasted it to float32 in Unsloth, and you can see the expected positional encodings. (The snippet below reproduces the effect without the image.)
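A standalone reproduction of the effect, assuming a standard RoPE setup with base 10000 and head_dim 128 (these numbers are illustrative, not taken from the PR):

```python
import torch

dim, base = 128, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
pos = torch.arange(8192, dtype=torch.float32)

# Same outer product of positions and inverse frequencies, in float32 vs bfloat16.
freqs_fp32 = pos[:, None] * inv_freq[None, :]
freqs_bf16 = (pos.to(torch.bfloat16)[:, None] * inv_freq.to(torch.bfloat16)[None, :]).float()

# The phase error grows with position, so the rotary cos factors diverge at long contexts.
err = (freqs_fp32.cos() - freqs_bf16.cos()).abs()
print(err[:16].max().item())   # small near position 0
print(err[-16:].max().item())  # large near position 8191
```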
I couldn't work out why Unsloth's training loss could not match that of HF's original Gemma implementation. With float16, this issue does not occur: HF's and Unsloth's training loss curves are equivalent.
However, when I switched over to bfloat16, HF's and Unsloth's training losses diverge at the start, and Unsloth always retains a lower loss as training goes on.
If you look at the losses more carefully (same seed), you can see the differences more closely.
The culprit I found was the rotary embedding forward, where, if one uses torch.autocast(), freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) gets done in bfloat16 and not float32. I propose we turn off autocast to force float32, i.e. wrap the RoPE computation in a disabled-autocast region (sketched below). This ensures torch.autocast does not automatically downcast the RoPE embeddings to float16 / bfloat16. My proposed fix shows the following loss curve.
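A minimal sketch of the idea (module and variable names are illustrative and may differ from the exact merged code; the key part is the autocast(enabled=False) region keeping the matmul and sin/cos in float32):

```python
import torch

class RotaryEmbeddingSketch(torch.nn.Module):
    def __init__(self, dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Disable autocast so the matmul and sin/cos stay in float32; fall back to
        # "cpu" for device types (like "mps") where this autocast region is unsupported.
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos, sin = emb.cos(), emb.sin()
        # Cast back to the activation dtype only at the very end.
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
```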
Also, in Gemma, a one-liner was missed: logits = logits.float() must be placed to upcast the logits to float32. Although it should be done automatically in torch.autocast, it's best to keep the convention as done in Llama, Mistral and other models. Gemma's implementation seems to maybe have forgotten this one line :)
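For illustration, the Llama / Mistral convention referred to above looks roughly like this (a self-contained fragment with stand-in tensors, not Gemma's actual forward):

```python
import torch
import torch.nn.functional as F

# Stand-ins for the model's own lm_head / hidden_states / labels.
vocab_size = 32
lm_head = torch.nn.Linear(16, vocab_size, dtype=torch.bfloat16)
hidden_states = torch.randn(1, 8, 16, dtype=torch.bfloat16)
labels = torch.randint(0, vocab_size, (1, 8))

logits = lm_head(hidden_states)
logits = logits.float()  # the one-liner: upcast before computing the loss

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)
loss = F.cross_entropy(shift_logits, shift_labels)
print(loss.dtype)  # torch.float32
```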