Fix mixed precision in TF models #9163
Conversation
@@ -56,14 +57,20 @@ def mish(x):

 def gelu_fast(x):
     x = tf.convert_to_tensor(x)
-    coeff1 = tf.cast(7978845608, x.dtype)
+    coeff1 = tf.cast(0.7978845608, x.dtype)
wow was that wrong the whole time before?
Yep! I was as surprised as you 😄
LGTM, thanks for fixing this!
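For context, the corrected constant is sqrt(2/pi) ≈ 0.7978845608, used by the tanh approximation of GELU. A minimal sketch of the fixed function (mirroring the diff above; the exact body in `activations_tf.py` may differ slightly) shows why both constants are cast to `x.dtype`, which keeps the op safe under a float16 mixed-precision policy:

```python
import tensorflow as tf

def gelu_fast(x):
    # tanh approximation of GELU; casting the constants to x.dtype avoids
    # dtype mismatches when x is float16 under a mixed-precision policy.
    x = tf.convert_to_tensor(x)
    coeff1 = tf.cast(0.7978845608, x.dtype)  # sqrt(2 / pi), previously missing the leading "0."
    coeff2 = tf.cast(0.044715, x.dtype)
    return 0.5 * x * (1.0 + tf.tanh(x * coeff1 * (1.0 + coeff2 * x * x)))
```

With the old integer constant `7978845608`, the tanh saturates for any nonzero input and the function degenerates to ReLU-like behavior, which is why the bug was easy to miss.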
src/transformers/activations_tf.py
Outdated
if version.parse(tf.version.VERSION) >= version.parse("2.4"):
    gelu = tf.keras.activations.gelu
else:
    gelu = tf.keras.layers.Activation(_gelu)
Nice to be able to use the TF version now!
What does this PR do?
This PR aims to fix the mixed-precision issues that appear when tf.keras.mixed_precision.experimental.set_policy() is set to something other than tf.float32. Along the same lines, it also fixes some TFLite quantization issues. Before this PR can proceed further, PR #9418 has to be merged.
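The core pattern behind these fixes: under a mixed-precision policy, layer inputs arrive as float16, so any Python constant mixed into the computation must be cast to the input's dtype rather than left to default to float32. A minimal sketch (the `scaled` function is a hypothetical illustration, not code from this PR):

```python
import tensorflow as tf

def scaled(x):
    # Casting the constant to x.dtype keeps the result in the input's
    # precision, e.g. float16 under a "mixed_float16" policy.
    x = tf.convert_to_tensor(x)
    return tf.cast(0.5, x.dtype) * x

# float16 input, as a mixed-precision layer would receive it:
x16 = tf.constant([2.0], dtype=tf.float16)
y = scaled(x16)
```

Writing `0.5 * x` directly would work here because TF promotes Python scalars to the tensor's dtype, but constants created with `tf.constant` or passed through float32 intermediates do not get that treatment, hence the explicit casts throughout `activations_tf.py`.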
Fixes #7052