
Clamping hidden state values to allow FP16 #19229

Merged - 7 commits merged into huggingface:main on Oct 4, 2022

Conversation

SSamDav
Contributor

@SSamDav SSamDav commented Sep 28, 2022

What does this PR do?

Fixes # (issue)
Following the discussion in #9295 and the solution proposed in #9487.

Implements the solution that enables FP16 for LongT5 models.

Who can review:
@patrickvonplaten, @patil-suraj

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Sep 28, 2022

The documentation is not available anymore as the PR was closed or merged.

@SSamDav
Contributor Author

SSamDav commented Sep 28, 2022

Hi, it seems that I have a test error; however, I didn't change the code that is failing.
Does anyone know how I can get these tests to pass?

@ydshieh
Collaborator

ydshieh commented Sep 29, 2022

Thanks @SSamDav.

And sorry, I forgot to mention: to avoid the test failure you saw yesterday, you can rebase the working branch on main (if it still appears in the new run).

Collaborator

@ydshieh ydshieh left a comment


Thank you @SSamDav 🚀 .

The fix follows exactly what has been done for T5 👍
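
For reference, the clamp used in T5's modeling code (which this fix mirrors for LongT5) looks roughly like the following standalone sketch - the helper name here is made up for illustration:

import torch

def clamp_fp16_inf(hidden_states: torch.Tensor) -> torch.Tensor:
    # Clamp inf values produced under fp16 so that later additions and
    # normalizations do not turn them into NaN.
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states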

@ydshieh ydshieh requested a review from patil-suraj September 29, 2022 10:05
@patrickvonplaten
Contributor

patrickvonplaten commented Sep 29, 2022

Ok for me if we know that it helps for inference. In T5 it didn't really help for training at all in the end, so if this was implemented to enable training in fp16, I'm not sure it's a good idea (I think @patil-suraj won't have time to look into it btw).

Also cc @ArthurZucker here just FYI

@ydshieh
Collaborator

ydshieh commented Sep 30, 2022

Since @patrickvonplaten prefers to have an issue that this PR would solve, I think we are not going to merge this PR at the moment. Let's see if such issues are reported for LongT5 in the future. We can then investigate and decide whether to re-open/merge this PR (or apply a different fix). WDYT? cc @younesbelkada @ArthurZucker .


And regarding nan: my thought is that it most likely comes from sequences whose attention scores are all -inf after adding the mask, which become nan after the softmax, like what we observed in OPT or Bloom recently, where we provided a fix as close as possible to where the nan appears. (However, in the case of T5, a clamp is done in T5LayerFF, which is not attention-related.)
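
A minimal standalone illustration of that failure mode (not from this PR):

import torch

# A position that is fully masked out gets -inf for every attention score;
# softmax over an all -inf row produces NaN instead of a probability distribution.
scores = torch.full((1, 4), float("-inf"), dtype=torch.float16)
print(torch.softmax(scores, dim=-1))  # tensor([[nan, nan, nan, nan]], dtype=torch.float16)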

@SSamDav
Contributor Author

SSamDav commented Sep 30, 2022

In my tests, when I run a finetuned version of google/long-t5-tglobal-base in FP16, I get NaN in the forward step. I can check whether the values come from LongT5LayerFF.

@ydshieh
Collaborator

ydshieh commented Sep 30, 2022

In my tests, when I run a finetuned version of google/long-t5-tglobal-base in FP16, I get NaN in the forward step. I can check whether the values come from LongT5LayerFF.

Hi, so this is running inference, right? Is that finetuned checkpoint uploaded to the Hub?

@SSamDav
Contributor Author

SSamDav commented Sep 30, 2022

Hi, so this is running inference, right?

Yes

Is that finetuned checkpoint uploaded to the Hub?

No, it was trained on confidential data.

@ydshieh
Collaborator

ydshieh commented Sep 30, 2022

Is that finetuned checkpoint uploaded to the Hub?

No, it was trained on confidential data.

Got it. However, it would be really nice to have a publicly available checkpoint (finetuned on another dataset) that shows the issue and the effect of the fix. I understand that it may not be easy to obtain such a checkpoint, and it is potentially time-consuming.

@patrickvonplaten Any further comment?

@younesbelkada
Contributor

Hi!
I second what @ydshieh said: the root cause of this is probably happening inside the attention score computation, as has been observed for BLOOM and OPT - maybe it's worth investigating a bit before merging!
As a simple test, we could try to reproduce what has been done in #18057 & #17437 and see if this fixes the initial issue.
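
If it helps, a rough sketch of the clamping pattern from those PRs (the helper name and placement are illustrative only, not the actual modeling code):

import torch

def clamp_attention_scores(attn_weights: torch.Tensor) -> torch.Tensor:
    # Replace -inf scores (fully masked rows) with the most negative finite value
    # of the dtype, so the softmax stays finite instead of producing NaN.
    min_value = torch.tensor(torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype)
    return torch.max(attn_weights, min_value)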

Contributor

@younesbelkada younesbelkada left a comment


Thanks a lot @SSamDav for the fix!
I managed to reproduce the initial issue with the following snippet:

import torch
from transformers import AutoTokenizer, LongT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")

model = LongT5ForConditionalGeneration.from_pretrained("google/long-t5-tglobal-base", torch_dtype=torch.float16).to(0)
inputs = tokenizer(100 * "studies have shown that owning a dog is good for you ", return_tensors="pt")
input_ids = inputs.input_ids

outputs = model.encoder(input_ids.to(0))
print(outputs.last_hidden_state.isnan().any())

However, it seems that the root cause of this issue is in the LongT5LayerFF layer, exactly at this line. Adding the previous hidden states to the GeLU-ed hidden states (forwarded_states) causes overflow issues.

I tried adding scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min)) here, but this does not seem to help, as the overflow comes after the attention layer. I propose using these changes for now, as they definitely help get inference in fp16 working. Maybe we should still add the line scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min)), but I can see that the attention scores are cast to fp32 before the softmax, so maybe it's not necessary - cc @ydshieh 🙏
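
For illustration, a minimal sketch of that overflow in isolation (not from the PR):

import torch

# float16 can only represent values up to ~65504; the residual addition in the
# feed-forward block can exceed that and overflow to inf, which later becomes NaN.
hidden_states = torch.tensor([60000.0], dtype=torch.float16)
forwarded_states = torch.tensor([10000.0], dtype=torch.float16)
out = hidden_states + forwarded_states
print(out)        # tensor([inf], dtype=torch.float16)
print(out - out)  # tensor([nan], dtype=torch.float16)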

I think that we should also add a small slow test reproducing the behavior of the snippet above!

@slow 
def test_fp16_inference(self):
    tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
    model = LongT5ForConditionalGeneration.from_pretrained("google/long-t5-tglobal-base", torch_dtype=torch.float16).to(0)
    inputs = tokenizer(100 * "studies have shown that owning a dog is good for you ", return_tensors="pt")
    input_ids = inputs.input_ids
    outputs = model.encoder(input_ids.to(0))
    self.assertFalse(outputs.last_hidden_state.isnan().any())

I also propose changing the comments to explicitly specify that this helps for fp16 inference, not training, as mentioned by @patrickvonplaten.

Review comments on src/transformers/models/longt5/modeling_longt5.py (outdated, resolved)
SSamDav and others added 4 commits on October 3, 2022, co-authored by Younes Belkada <49240599+younesbelkada@users.noreply.github.com>.
@patrickvonplaten
Contributor

Hey @SSamDav,

If the PR as it is now solves your problem for inference, it's good for me to merge! I don't think it'll fix problems with fine-tuning, though.

@SSamDav
Contributor Author

SSamDav commented Oct 4, 2022

Hey @patrickvonplaten, sounds good - thanks for the help!

@patrickvonplaten patrickvonplaten merged commit 971da2e into huggingface:main Oct 4, 2022
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Oct 18, 2022
* Clamping hidden state values to allow FP16

* Reformating

* Adding missing if condition

* Update src/transformers/models/longt5/modeling_longt5.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/models/longt5/modeling_longt5.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/models/longt5/modeling_longt5.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Formating file

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>