Clamping hidden state values to allow FP16 #19229
Conversation
The documentation is not available anymore as the PR was closed or merged.
Hi, it seems that I have a test error; however, I didn't change the code that is failing.
Thanks @SSamDav. And sorry, I forgot to mention that, in order to avoid the test failure you saw yesterday, you can rebase the working branch on
Thank you @SSamDav 🚀.
The fix follows exactly what has been done for T5.
👍
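For context, a minimal sketch of that T5-style clamping pattern (my paraphrase of the approach from #9487; the helper name clamp_fp16_hidden_states is made up for illustration and this is not the exact diff of this PR):

import torch

def clamp_fp16_hidden_states(hidden_states: torch.Tensor) -> torch.Tensor:
    # only relevant in fp16: pull values that would overflow to +/-inf back into range
    if hidden_states.dtype == torch.float16:
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states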
Ok for me if we know that it helps for inference. In T5 it didn't really help for training at all in the end, so if this was implemented to enable training in fp16, I'm not sure it's a good idea (I think @patil-suraj won't have time to look into it, btw). Also cc @ArthurZucker here just FYI.
Since @patrickvonplaten prefers to have an issue that this PR would solve, I think we are not going to merge this PR at this moment. Let's see if there will be such issues reported for
And regarding
In my tests, when I run a finetuned version of the
Hi, so it is running the inference, right? Is that finetuned checkpoint uploaded to the Hub?
Yes
No, it was trained on confidential data.
Got it. However, it would be really nice if we had a publicly available checkpoint (on another dataset) that can show the issue and the effect of the fix. I understand that it may not be easy to obtain another such checkpoint, and potentially time consuming. @patrickvonplaten Any further comment?
Hi!
Thanks a lot @SSamDav for the fix!
I managed to reproduce the initial issue with the following snippet:
import torch
from transformers import AutoTokenizer, LongT5ForConditionalGeneration
tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
model = LongT5ForConditionalGeneration.from_pretrained("google/long-t5-tglobal-base", torch_dtype=torch.float16).to(0)
inputs = tokenizer(100 * "studies have shown that owning a dog is good for you ", return_tensors="pt")
input_ids = inputs.input_ids
outputs = model.encoder(input_ids.to(0))
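# a True result on the next line means the fp16 encoder produced NaNs (the issue this PR addresses)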
print(outputs.last_hidden_state.isnan().any())
However, it seems that the root cause of this issue is happening in the LongT5LayerFF layer, exactly at this line. It seems that adding the previous hidden states to the GeLU-ed hidden states (forwarded_states) causes overflow issues. I tried adding scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min)) here, but this seems not to help, as the overflow comes after the attention layer. I propose to use these changes for now, as it definitely helps to get inference in fp16 working. Maybe we should still add the line scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min)), but I can see that the attention scores are cast to fp32 before the softmax, so maybe it's not necessary - cc @ydshieh 🙏
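For illustration, a rough sketch of where such a clamp would sit relative to the fp32 softmax cast in a generic attention forward (a simplified pattern, not the actual LongT5 code; the function name is made up):

import torch
import torch.nn.functional as F

def scores_to_attention_weights(scores: torch.Tensor) -> torch.Tensor:
    # the clamp discussed above: keep very negative scores representable in fp16
    scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min))
    # the softmax itself is computed in fp32 and cast back, which already limits overflow here
    return F.softmax(scores.float(), dim=-1).type_as(scores)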
I think that we should also add a small slow test reproducing the behavior of the snippet above!
@slow
def test_fp16_inference(self):
    tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
    model = LongT5ForConditionalGeneration.from_pretrained("google/long-t5-tglobal-base", torch_dtype=torch.float16).to(0)
    inputs = tokenizer(100 * "studies have shown that owning a dog is good for you ", return_tensors="pt")
    input_ids = inputs.input_ids
    outputs = model.encoder(input_ids.to(0))
    self.assertFalse(outputs.last_hidden_state.isnan().any())
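(As with other tests marked @slow, this would only run when the slow suite is enabled, e.g. with RUN_SLOW=1 if I recall the Transformers test setup correctly, rather than on every PR.)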
I also propose to change the comments and explicitly specify that this helps for fp16 inference, not training, as mentioned by @patrickvonplaten.
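For example, the comment wording could look something like this (hypothetical phrasing, not the final text in the PR):

# clamp inf values to allow fp16 inference - note that this does not make fp16 training stable
if hidden_states.dtype == torch.float16:
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)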
Hey @SSamDav, if the PR as it is now solves your problem for inference, it's good for me to merge! I don't think it'll fix problems with fine-tuning, though.
Hey @patrickvonplaten, good, thanks for the help!
* Clamping hidden state values to allow FP16
* Reformating
* Adding missing if condition
* Update src/transformers/models/longt5/modeling_longt5.py
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* Update src/transformers/models/longt5/modeling_longt5.py
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* Update src/transformers/models/longt5/modeling_longt5.py
  Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* Formating file

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
What does this PR do?
Fixes # (issue)
Following the discussion in #9295 and the solution proposed by #9487.
Implements the solution that enables FP16 for LongT5 models.
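For reference, a minimal end-to-end fp16 inference sketch once the clamping is in place (assumes a CUDA device; the checkpoint name and prompt are taken from the reproduction snippet above, and the generation settings are illustrative):

import torch
from transformers import AutoTokenizer, LongT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
model = LongT5ForConditionalGeneration.from_pretrained("google/long-t5-tglobal-base", torch_dtype=torch.float16).to(0)

# tokenize a long input and move it to the same GPU as the model
inputs = tokenizer(100 * "studies have shown that owning a dog is good for you ", return_tensors="pt").to(0)
output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))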
Who can review:
@patrickvonplaten, @patil-suraj