Training with fp16 precision gives nan in Longt5 #17978

Closed
Leonard907 opened this issue Jul 1, 2022 · 14 comments

@Leonard907

System Info

  • transformers version: 4.10.0.dev0
  • Platform: Linux-3.10.0-1160.62.1.el7.x86_64-x86_64-with-glibc2.17
  • Python version: 3.8.13
  • PyTorch version (GPU?): 1.9.0+cu111 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes

Who can help?

@patrickvonplaten

Information

  • [ ] The official example scripts
  • [x] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [x] My own task or dataset (give details below)

Reproduction

I'm currently running the scrolls_benchmark. I'm interested in the performance of the LongT5 model on SCROLLS, so I changed the model name to google/long-t5-tglobal-base and ran training with fp16 enabled (if I run with fp32, I get CUDA OOM errors). However, the output loss is always nan. I searched for fixes and found this post: t5-fp16-fixed. I then looked in the transformers repo and found that the modeling_longt5 file doesn't seem to incorporate the clamp_value change. Could this be why fp16 is not working for LongT5? And if so, could it be fixed with an approach similar to what you did for T5? Thank you very much!

FYI: you probably noticed that the transformers version is 4.10.0, which does not include LongT5. I manually added the LongT5 files in a forked scrolls repo here: longt5_folder. It works properly under a small parameter setting.
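
For reference, here is a minimal sketch of the kind of setup that triggers the problem, assuming a single GPU; this is not the actual SCROLLS training script, and the toy document and summary below are placeholders:

```python
# Minimal sketch (placeholder inputs, not the SCROLLS pipeline): run a single
# forward pass of google/long-t5-tglobal-base under fp16 autocast and inspect
# the loss, which is where the nan shows up during training.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-tglobal-base").cuda()

inputs = tokenizer("summarize: " + "some long document text " * 512,
                   return_tensors="pt", truncation=True).to("cuda")
labels = tokenizer("a short summary", return_tensors="pt").input_ids.to("cuda")

# torch.cuda.amp.autocast is roughly what Trainer does when fp16=True is set
with torch.cuda.amp.autocast():
    loss = model(**inputs, labels=labels).loss
print(loss)  # expected to show the nan described above
```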

Expected behavior

The LongT5 model should not produce a nan loss when trained with fp16.

@Leonard907 Leonard907 added the bug label Jul 1, 2022
@ydshieh
Collaborator

ydshieh commented Jul 1, 2022

This is because -1e10 is used as the attention mask value. A fix similar to what was done in #17306 should work.
Would you like to open a PR for LongT5, @Leonard907?
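
For illustration, the fix would look roughly like this; a sketch rather than the exact patch from #17306, and mask_scores is a hypothetical helper, not a function in modeling_longt5.py:

```python
# Sketch: replace the hard-coded -1e10 mask value with the smallest value
# representable in the current dtype, so masked attention logits stay finite
# under fp16 instead of overflowing and turning into nan downstream.
import torch

def mask_scores(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # scores: raw attention logits; mask: 1 = keep token, 0 = mask it out
    min_value = torch.finfo(scores.dtype).min  # ~-65504 for fp16, instead of -1e10
    return scores.masked_fill(mask == 0, min_value)
```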

@ydshieh ydshieh self-assigned this Jul 1, 2022
@Leonard907
Author

Leonard907 commented Jul 2, 2022

@ydshieh Sorry for the late reply. I ran a few experiments and found that fixing both the clamp value and the attention mask seems to work (I should mention that at first I only fixed the attention mask and it still gave nan; after I also added the clamp value there was no nan). In more detail, lines like longt5 704 need to be set to torch.finfo(dtype).min. For the clamp value part, my changes are identical to modeling_t5.py, since I found the forward function to be similar. I have not read the code fully, so my modifications are hacky (i.e. they reference float16 directly). Shall I publish a branch and open a PR for further review? This is my first time doing this, so if there is a guide on how to publish changes and open a PR I would appreciate it. Thanks!
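
For reference, the clamp_value pattern in modeling_t5.py looks roughly like the helper below; this is a simplified standalone sketch, and clamp_fp16_inf is not an actual function name in transformers:

```python
import torch

def clamp_fp16_inf(hidden_states: torch.Tensor) -> torch.Tensor:
    # Roughly the clamp_value pattern from modeling_t5.py: when running in fp16
    # and an activation has overflowed to inf, clamp the tensor just below the
    # fp16 maximum so later layers do not propagate nan.
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states
```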

@ZJaume

ZJaume commented Jul 4, 2022

I also had NaN loss with google/mt5-small on PyTorch.

@ydshieh
Collaborator

ydshieh commented Jul 4, 2022

@Leonard907 @ZJaume

Could you provide code snippets (a training script) that reproduce this issue? I am a bit surprised that torch.finfo(dtype).min is not enough to avoid it, but what has been done in T5 seems reasonable!

@patrickvonplaten
Contributor

In general, T5 just doesn't work well with fp16 since it was trained in bfloat16, and the model seems to require quite large activation values which are not representable in fp16. See: #14189

Note that LongT5, MT5, T5, and ByT5 all use more or less the same architecture.
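
For context, the dtype ranges make the mismatch easy to see; a quick illustration, not tied to any particular model:

```python
# bfloat16 keeps float32's exponent range, while float16 tops out around 6.5e4,
# so activations that were fine during bf16 pretraining can overflow under fp16.
import torch

print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38 (same exponent range as float32)
```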

@Leonard907
Author

@ydshieh It is a bit hard to provide a training script, as I'm experimenting on the SCROLLS dataset; there is a lot of setup and I also made some modifications myself. One could probably reference this discussion stancld-longt5 and use that script with bf16 changed to fp16, but I haven't tried it.

@patrickvonplaten Thank you for pointing to more issues on this. I have trained the model with the clamp_value and attention_mask fixes and the results are poor. Without fp16 I can actually get better results even with fewer parameters. This is just a guess and I need to experiment more, but I thought it might be worth knowing.

FYI: how can I open a PR? I tried making a branch from main, but when I tried to push it I found that I had no permission to do so. As a workaround, I uploaded my changed file to one of my own repositories here: longt5_fix. My fixes are hardcoded, so they would need to be generalized.

@ydshieh
Collaborator

ydshieh commented Jul 5, 2022

Please see https://huggingface.co/docs/transformers/contributing#start-contributing-pull-requests

(In short: create a branch, push it to your fork, and open a PR from there.)

@patrickvonplaten
Contributor

I'm actually against the clamp hack here. We've seen multiple times that it is not sufficient to correct the fp16 problem in T5, so I doubt it will be sufficient here. Moreover, adding clamps at multiple positions in the code slows down the forward pass and makes the code less readable. @patil-suraj @stas00, what do you think?

@Leonard907
Author

Based on my observations, the clamp_value fix produces much worse results than just using fp32 under the same configuration. And following the further comment from @patrickvonplaten, I realized that this fix also makes training take more time (I need around 1 hour to generate predictions on the test set with the fix, while with fp32 I only need around 20-30 minutes).

@ydshieh
Collaborator

ydshieh commented Jul 7, 2022

(Not promoting clamp_value here.) I actually think that if we want to add clamp_value, it should be applied only when the model is in training mode. (If we do so, this also solves the slowdown issue.)

(In inference/eval mode, large values almost always imply something is wrong in the model.)
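
Something along these lines; a rough sketch where ClampedFFN is a hypothetical module, not actual transformers code:

```python
import torch
import torch.nn as nn

class ClampedFFN(nn.Module):
    # Illustrates the suggestion above: clamp only when self.training is True,
    # so inference keeps the plain (faster) forward pass.
    def __init__(self, d_model: int = 512, d_ff: int = 2048):
        super().__init__()
        self.wi = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.wo(torch.relu(self.wi(hidden_states)))
        if self.training and hidden_states.dtype == torch.float16:
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
        return hidden_states
```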

@stas00
Contributor

stas00 commented Jul 7, 2022

Indeed, the clamping works around the overflow, but it's not really a good mathematical approach. I suppose the model learns to deal with this type of penalty, but it doesn't make training smooth.

  1. The OP in [T5/MT5] resolve inf/nan under amp (mixed precision) #10956 suggests penalizing large activations, as was done in the original code (see the generic sketch below). It's applied more consistently than clamping and is probably easier for the model to learn.
  2. A group of people then developed a weight scaling approach (in addition to the normal mixed-precision loss scaling): [T5/MT5] resolve inf/nan under amp (mixed precision) #10956 (comment)

I also proposed removing the clamping in T5: #10956 (comment)
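
A generic sketch of what penalizing large activations could look like; the penalty form and weight here are illustrative, not the exact scheme discussed in #10956:

```python
import torch

def loss_with_activation_penalty(loss: torch.Tensor,
                                 activations: torch.Tensor,
                                 weight: float = 1e-4) -> torch.Tensor:
    # Add a small L2 penalty on the activations so the model is discouraged
    # from producing values near the fp16 limit in the first place.
    return loss + weight * activations.float().pow(2).mean()
```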

@whiteRa2bit

I have the same problem when training with bf16 set to True.

@github-actions

github-actions bot commented Aug 8, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@patrickvonplaten patrickvonplaten mentioned this issue Aug 18, 2022
@vnica2

vnica2 commented Mar 29, 2024

> (quoting the original issue report above)

It appears this is the case for Mistral as well.
