
RoPE loses precision for Llama / Gemma + Gemma logits.float() #29285

Merged: 20 commits merged into huggingface:main on Feb 28, 2024

Conversation

@danielhanchen (Contributor)

Tagging @ArthurZucker :)

When I was implementing Gemma for Unsloth, I noticed that when one uses bfloat16, the RoPE embeddings get autocast to bfloat16, even though we require them to be in float32. This causes the positional encodings to lose precision dramatically, especially at very large context lengths.

Below is an image of how HF currently handles RoPE. You can see the loss in precision when using bfloat16. In Unsloth I manually upcast it to float32, and you can see the expected positional encodings.
[image: RoPE embeddings under bfloat16 (current HF path) vs. float32 (Unsloth upcast)]

I couldn't work out why Unsloth's loss did not match that of HF's original Gemma implementation. In float16 this issue does not occur, and HF's and Unsloth's training loss curves are equivalent:
[image: float16 training loss curves, HF vs. Unsloth (equivalent)]

However, when I switched over to bfloat16, HF's and Unsloth's training losses diverge from the start, and Unsloth retains a lower loss as training goes on:
[image: bfloat16 training loss curves, HF vs. Unsloth (diverging)]

If you look at the losses more carefully (same seed), the differences are clearer:
[image: zoomed-in loss comparison, same seed]

The culprit I found was:

inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)

where, if one uses torch.autocast(), freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2) gets computed in bfloat16 rather than float32. I propose we turn off autocast to force float32, i.e.:

with torch.autocast(device_type=position_ids_expanded.device.type, enabled=False):
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)

This stops torch.autocast from automatically downcasting the RoPE embeddings to float16 / bfloat16. My proposed fix gives the following loss curve:
[image: training loss curve with the proposed fix]
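For anyone who wants to reproduce the autocast behaviour described above, here is a minimal sketch (not from the PR; the shapes and the CPU device are illustrative assumptions). Under torch.autocast, a matmul between two float32 tensors runs in the autocast dtype, which is exactly what downcasts freqs:

import torch

# Illustrative shapes: 64 inverse frequencies (head_dim 128), 8192 positions.
inv_freq_expanded = torch.rand(1, 64, 1, dtype=torch.float32)
position_ids_expanded = torch.arange(8192, dtype=torch.float32)[None, None, :]

# Outside autocast the matmul stays in float32.
freqs = inv_freq_expanded @ position_ids_expanded
print(freqs.dtype)  # torch.float32

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # Under autocast, matmuls run in the low-precision dtype even for float32 inputs.
    freqs = inv_freq_expanded @ position_ids_expanded
    print(freqs.dtype)  # torch.bfloat16

    # Locally disabling autocast, as in the proposed fix, restores float32.
    with torch.autocast(device_type="cpu", enabled=False):
        freqs = inv_freq_expanded.float() @ position_ids_expanded.float()
        print(freqs.dtype)  # torch.float32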

Also, in Gemma, a one-liner was missed :) logits = logits.float() must be added to upcast the logits to float32. Although this should be done automatically under torch.autocast, it's best to keep the convention used in Llama, Mistral, and other models. Gemma's implementation seems to have forgotten this one line :)
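To illustrate where the one-liner sits relative to the loss, here is a toy sketch (this is not the Gemma code; the tiny lm_head, shapes and CPU autocast are made-up stand-ins):

import torch
import torch.nn as nn

vocab_size, hidden_size = 32, 16
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # toy stand-in for the LM head
hidden_states = torch.randn(2, 8, hidden_size)
labels = torch.randint(0, vocab_size, (2, 8))

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    logits = lm_head(hidden_states)  # bfloat16 under autocast
    logits = logits.float()          # the missing one-liner: upcast before the loss / softmax
    loss = nn.functional.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))

print(logits.dtype)  # torch.float32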

Llama - Force float32 since bfloat16 loses precision on long contexts
Fix RoPE and logits.float()
@danielhanchen (Contributor, Author)

Forgot to add: I'm not certain whether this will break CUDA Graphs for faster inference - hopefully not.

@ArthurZucker (Collaborator) left a comment

I'll have to check the compile test and everything, but we usually hate this kind of change 🫣 The bug is real; I'll see if I can find a good alternative, as this is pretty much only for training! Great catch 🤗

Review comment on src/transformers/models/llama/modeling_llama.py (outdated, resolved)
@danielhanchen (Contributor, Author)

Sadly I'm unsure if it's just for training :(( For inference I don't remember up to which context length bfloat16 won't be an issue; I think it was up to 4096. However, bfloat16 loses precision even for inference after 4096 tokens of context, and definitely by 8192: bfloat16 essentially thinks the last 4 tokens are all position 8192, i.e. [8192, 8192, 8192, 8192], whereas the correct float32 values are [8188, 8189, 8190, 8191].
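That last point is easy to reproduce in isolation (just the dtype behaviour, nothing model-specific):

import torch

positions = torch.arange(8188, 8192, dtype=torch.float32)
print(positions)                     # tensor([8188., 8189., 8190., 8191.])
print(positions.to(torch.bfloat16))  # tensor([8192., 8192., 8192., 8192.], dtype=torch.bfloat16)
# bfloat16 has only an 8-bit significand, so representable integers between 4096 and 8192
# are spaced 32 apart and neighbouring positions collapse onto the same value.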

@ArthurZucker (Collaborator) left a comment

LGTM, let's do no_grad and autocast; I'll test compile once you have both!
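For reference, a minimal sketch of how the two could be combined, i.e. a rotary-embedding forward under @torch.no_grad() with autocast disabled locally. This is an illustrative module (the name and details are assumptions), not necessarily the exact code that ended up being merged:

import torch
import torch.nn as nn

class RotaryEmbeddingSketch(nn.Module):
    def __init__(self, dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Fall back to "cpu" for device types (like "mps") that autocast may not support.
        device_type = x.device.type if x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            # Keep the outer product, cos and sin in float32 regardless of any enclosing autocast.
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos, sin = emb.cos(), emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

The explicit .float() calls inside the disabled-autocast region are belt and braces; with autocast disabled, float32 inputs already stay in float32.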

@ArthurZucker (Collaborator) left a comment

LGTM, before merging I'll ping @pacman100, @younesbelkada and @fxmarty as this is pretty important! Feel free to comment if you are against these changes!

def forward(self, x, position_ids, seq_len=None):
    # x: [bs, num_attention_heads, seq_len, head_size]
    if self.inv_freq is None:
        self.inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
        )

A collaborator left a suggested change on this snippet.

@gante (Member) commented Feb 27, 2024

@danielhanchen .sin() and .cos() should ideally happen in FP32 as well. Have you noticed any performance changes if you force them to happen in FP32?

@gante mentioned this pull request on Feb 27, 2024
@danielhanchen (Contributor, Author)

@gante Actually, interesting point. I can see torch.autocast does asin and sinh etc. in float32, but it doesn't list sin itself; I'll have to check whether .sin() is done in float32 or float16.
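A quick way to check the dtype behaviour (CPU autocast here for illustration; as far as I can tell, sin/cos are not on the CUDA autocast lists either):

import torch

x = torch.linspace(0, 100, steps=5, dtype=torch.float32)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # sin/cos are not on autocast's cast lists, so they run in the dtype of their input.
    print(x.sin().dtype)                     # torch.float32
    print(x.to(torch.bfloat16).sin().dtype)  # torch.bfloat16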

@danielhanchen (Contributor, Author)

@ArthurZucker I checked everything and it's working! You guys can double check if anything is wrong. You can push the commit whenever. Thank you! :)

@ArthurZucker (Collaborator) left a comment

LGTM! Thanks a mile for this.
Let's make sure you run make style and make fixup for the last CIs

Review comments on src/transformers/models/gemma/modeling_gemma.py and src/transformers/models/llama/modeling_llama.py (outdated, resolved)
danielhanchen and others added 6 commits February 28, 2024 21:30
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This reverts commit b860a22.
@ArthurZucker (Collaborator)

        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)

I doubt that this is the cause of the issue, since we used to do that before already.

@ArthurZucker (Collaborator)

Thanks @danielhanchen 🤗 merging now!

@ArthurZucker ArthurZucker merged commit d3a4b47 into huggingface:main Feb 28, 2024
18 checks passed
ArthurZucker added a commit that referenced this pull request Feb 28, 2024
* Update modeling_llama.py

Llama - Force float32 since bfloat16 loses precision on long contexts

* Update modeling_llama.py

* Update modeling_gemma.py

Fix RoPE and logits.float()

* @torch.no_grad()

* @torch.no_grad()

* Cos, Sin to float32

* cos, sin to float32

* Update src/transformers/models/gemma/modeling_gemma.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/llama/modeling_llama.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Resolve PR conflicts

* Fix RoPE for llama

* Revert "Fix RoPE for llama"

This reverts commit b860a22.

* Fix RoPE for llama

* RoPE device

* Autocast device type

* RoPE

* RoPE isinstance

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@danielhanchen (Contributor, Author)

Yay! :) Great work everyone :)

@suryabhupa

(hello from the Gemma team!) Superb stuff, thanks for this lovely fix :)

ArthurZucker added a commit that referenced this pull request Mar 1, 2024
@danielhanchen (Contributor, Author)

@suryabhupa Thanks :)
Actually, I was going to ask: do you know whether Gemma uses approximate GELU or exact GELU? When comparing Keras to HF, torch.dist gives 4.7943, while with the tanh approximation it gives 0.0057:

[image: GELU comparison between Keras and HF]

@paulcx commented Mar 6, 2024

Did anyone realize that this fix incurs an additional VRAM increase, likely causing OOM for code that could previously be trained?

@danielhanchen (Contributor, Author)

@paulcx Yes, unfortunately that was a trade-off we accepted; the issue is that results become incorrect, especially on longer context lengths (shorter is fine). We tried our best to isolate the changes: for inference this is fine, but for training there will be more VRAM usage. There's always Unsloth for Gemma of course, where we allow 2.5x faster finetuning and 60% VRAM reductions :) And TRL, PEFT, and the other HF libraries are all integrated. Gemma 7b notebook: https://colab.research.google.com/drive/10NbwlsRChbma1v55m8LAPYG15uQv6HLo?usp=sharing

@ArthurZucker (Collaborator)

And though this increases VRAM usage, the precision and quality of the outputs should be improved. This is also more "accurate" with respect to the original implementations.

@paulcx commented Mar 6, 2024

@danielhanchen Thank you for the finding. It would be great if this could be fixed on top of the existing implementation without increasing VRAM, at the very least. Some resource-constrained training runs may no longer be able to train because of this extra VRAM (even with the batch size set to 1).

@ArthurZucker (Collaborator)

How much of a VRAM increase are we talking about?

@paulcx commented Mar 6, 2024

It's difficult to give an exact number, but the growth in VRAM requirements likely depends on the size of the model; for instance, for a 34B model it could be in the tens of GBs?

@danielhanchen (Contributor, Author) commented Mar 6, 2024

@paulcx Wait, 10s of GBs? The PR is only for Llama and Gemma, so I'm assuming CodeLlama 34b is using more VRAM? The RoPE upcasting should only use approximately 8192 * 8192 * 16 bits ≈ 128MB of extra VRAM per layer (times n layers, say 32), and it should be cleared away since we don't need the 8192x8192 matrix afterwards. I don't see where an extra 10GB of VRAM would come from. Are you finetuning Gemma or Llama?

I can see why Gemma might use more VRAM, since logits.float() is necessary; otherwise the softmax runs in torch.float16, which causes incorrect results when training over long periods of time.

For Llama, at most around 128MB of extra VRAM should be used.
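As a quick sanity check of the 128MB figure quoted above (just arithmetic, nothing PR-specific):

# An 8192 x 8192 matrix in a 16-bit dtype (bfloat16 / float16):
bytes_per_element = 2
print(8192 * 8192 * bytes_per_element / 1024**2)  # 128.0 (MiB)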

@paulcx commented Mar 6, 2024

@danielhanchen I must say my estimate is likely quite imprecise, given that I haven't run careful experiments or collected statistics. The only evidence I have for GBs is my own experience with a 34B CodeLlama (which trained successfully on transformers==4.37.2). I know it could be caused by many things, for example multi-GPU training with deepspeed, etc.

[screenshot: CUDA out-of-memory error during 34B CodeLlama training]

@danielhanchen (Contributor, Author) commented Mar 6, 2024

@paulcx OHH, this is a different issue!! It seems a change was made in another PR which allocates a causal mask of size (16384, 16384): https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L940

The triu causes the causal mask to be upcast to float32, using 16384^2 * 4 bytes = 1GB of extra VRAM. Your screenshot shows roughly n^2 * 4 bytes / 1024^3 ≈ 37.25GB, so I'm assuming you're also doing RoPE scaling to a 100K context length, i.e. a (100K, 100K) matrix was being created.
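The figures quoted here check out (again just arithmetic, not code from the PR):

GiB = 1024**3
print(16384 ** 2 * 4 / GiB)    # 1.0    -> float32 (16384, 16384) causal mask
print(100_000 ** 2 * 4 / GiB)  # ~37.25 -> float32 (100K, 100K) mask, matching the screenshot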

So this looks like a separate problem from this PR, unfortunately.

@paulcx commented Mar 6, 2024

Thanks @danielhanchen

@gante (Member) commented Mar 6, 2024

@paulcx your issue is related to this one (#29484) -- let's keep the discussion there! :)

itazap pushed a commit that referenced this pull request May 14, 2024