Clip floating point constants to bf16 range to avoid inf conversion #20605
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thanks for opening a clean PR. I still have the same comment :-)
Also make sure you run make style on your branch to pass the quality tests on your PR.
src/transformers/modeling_utils.py (Outdated)
if os.environ.get("XLA_USE_BF16") == '1': | ||
return torch.bfloat16 | ||
if os.environ.get("XLA_DOWNCAST_BF16") == '1': |
As said before, we have a constant ENV_VARS_TRUE_VALUES in utils that you should reuse here, to catch any spelling the user might use when setting this environment variable. You should then test os.environ.get(xxx, "0").upper() in ENV_VARS_TRUE_VALUES.
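For illustration, a minimal sketch of the suggested pattern (the wrapper function _xla_bf16_enabled is hypothetical; ENV_VARS_TRUE_VALUES is assumed to be the truthy-string set, e.g. {"1", "ON", "YES", "TRUE"}, exported from transformers.utils):

```python
import os

from transformers.utils import ENV_VARS_TRUE_VALUES  # assumed: {"1", "ON", "YES", "TRUE"}


def _xla_bf16_enabled() -> bool:
    """Hypothetical helper: accept any truthy spelling ("1", "true", "YES", ...) of the flags."""
    return (
        os.environ.get("XLA_USE_BF16", "0").upper() in ENV_VARS_TRUE_VALUES
        or os.environ.get("XLA_DOWNCAST_BF16", "0").upper() in ENV_VARS_TRUE_VALUES
    )
```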
Thanks for iterating!
When running the HuggingFace BERT (any size) fine-tuning tutorial with transformers version >= 4.21.0 and XLA_USE_BF16=1 or XLA_DOWNCAST_BF16=1, I see NaNs in the loss after the first step.
What does this PR do?
This PR addresses an issue where the model code passes a constant that is out of bfloat16 range when XLA_USE_BF16=1 or XLA_DOWNCAST_BF16=1 is set, so the downcast converts it to -inf.
The NaNs likely come from the transformers library change #17306, which replaced many lines that used to use -float("inf") (or other large negative constants) with torch.finfo(dtype).min. For torch.float32 the minimum is -3.4028234663852886e+38, which is more negative than the bfloat16 minimum of -3.3895313892515355e+38. So the problem is that torch.finfo(torch.float32).min = -3.4028234663852886e+38 gets converted to -inf under the bf16 downcast. When the original encoder_extended_attention_mask is 1, encoder_extended_attention_mask then becomes (1.0 - 1.0) * -inf, which is NaN (via the IEEE rule 0.0 * inf = NaN).
This PR ensures the constant used is torch.finfo(torch.bfloat16).min = -3.3895313892515355e+38 rather than -inf, so the results no longer contain NaNs.
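As a self-contained illustration of the failure mode described above (not the PR's code), the float32 minimum overflows to -inf when cast to bfloat16, and the additive-mask arithmetic then yields NaN:

```python
import torch

fp32_min = torch.tensor(torch.finfo(torch.float32).min)  # -3.4028234663852886e+38
print(torch.finfo(torch.bfloat16).min)                    # -3.3895313892515355e+38

# float32's min is more negative than bfloat16 can represent, so the cast overflows.
print(fp32_min.to(torch.bfloat16))                        # tensor(-inf, dtype=torch.bfloat16)

# With an attention-mask value of 1.0, the additive mask becomes (1.0 - 1.0) * -inf = NaN.
mask = torch.tensor(1.0, dtype=torch.bfloat16)
print((1.0 - mask) * fp32_min.to(torch.bfloat16))         # tensor(nan, dtype=torch.bfloat16)
```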
The following lines check for the XLA_USE_BF16 or XLA_DOWNCAST_BF16 environment variables and set the dtype accordingly:
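A sketch of the kind of check involved, based on the diff snippet quoted in the review above (the helper name and surrounding structure are assumptions, not the PR's exact code):

```python
import os

import torch


def xla_effective_dtype(dtype: torch.dtype) -> torch.dtype:
    """Illustrative helper: the dtype XLA will actually compute in for a given tensor dtype."""
    if os.environ.get("XLA_USE_BF16") == "1" and dtype == torch.float32:
        # XLA_USE_BF16 casts all float32 tensors to bfloat16.
        return torch.bfloat16
    if os.environ.get("XLA_DOWNCAST_BF16") == "1" and dtype == torch.float32:
        # XLA_DOWNCAST_BF16 downcasts float32 -> bfloat16 (and float64 -> float32).
        return torch.bfloat16
    return dtype


# The additive attention-mask constant is then still finite after XLA's bf16 downcast.
mask_value = torch.finfo(xla_effective_dtype(torch.float32)).min  # -3.3895e+38 under bf16, not -inf
```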
Referencing related issues: aws-neuron/aws-neuron-sdk#593 and pytorch/xla#4152
Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@sgugger