Always run the rotary embedding layer in float32 #1508
Conversation
Code looks good!
We should probably test this to make sure the numerics are as close to our reference JAX implementation as they were before, and that this does not negatively impact performance.
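A minimal sketch of such a numerics check, comparing the layer under a low-precision dtype against a float32 reference (the input shape and tolerances here are illustrative assumptions, not values from this PR):

```python
import numpy as np
import keras
from keras_nlp.layers import RotaryEmbedding

# Illustrative input: (batch, seq_len, num_heads, head_dim).
x = np.random.uniform(size=(2, 16, 8, 64)).astype("float32")

ref_layer = RotaryEmbedding(max_wavelength=10_000.0, dtype="float32")
test_layer = RotaryEmbedding(max_wavelength=10_000.0, dtype="bfloat16")

ref_out = keras.ops.convert_to_numpy(ref_layer(x))
# Cast up before converting, since numpy has no native bfloat16.
test_out = keras.ops.convert_to_numpy(
    keras.ops.cast(test_layer(x), "float32")
)

# Tolerance is a guess; tighten it to whatever the reference allowed before.
np.testing.assert_allclose(ref_out, test_out, atol=1e-2, rtol=1e-2)
```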
@@ -87,28 +88,20 @@ def build(self, inputs_shape):
            (None, None, self.num_query_heads, self.head_dim)
        )
        self.softmax = keras.layers.Softmax(dtype="float32")

        self.rope_layer = RotaryEmbedding(
            max_wavelength=10000.0, dtype=self.dtype_policy
Nit, but let's start using the new 10_000.0 style for large numbers; it's quite readable.
Sounds good, I love _ to separate large numbers!
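For reference, Python treats underscores in numeric literals purely as visual separators (PEP 515), so the two spellings are identical:

```python
# Underscores in numeric literals are ignored by the parser;
# they only aid readability.
assert 10_000.0 == 10000.0
assert 1_000_000 == 1000000
```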
        freq_exponents = (2.0 / x_shape[-1]) * ops.arange(
            x_shape[-1] // 2, dtype="float32"
        x = self.rope_layer(x, start_index=start_index)
        x = ops.reshape(
Can we drop a comment explaining this?
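A sketch of what that explanatory comment could look like, assuming the reshape converts the rotary output from a half-split layout to the interleaved layout Gemma's original implementation used (the split/stack body is an illustration, not necessarily the PR's exact code):

```python
x = self.rope_layer(x, start_index=start_index)
# Assumed purpose: RotaryEmbedding rotates the two halves of the feature
# axis, while the original Gemma implementation interleaves even/odd
# features. Splitting the halves and stacking them pairwise converts the
# half-split layout back to the interleaved one, keeping outputs
# numerically equivalent to the reference implementation.
x = ops.reshape(
    ops.stack(ops.split(x, 2, axis=-1), axis=-1),
    ops.shape(x),
)
```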
        tensor = ops.cast(tensor, dtype=inverse_freq.dtype)
        freq = ops.einsum("i,j->ij", tensor, inverse_freq)
        embedding = ops.concatenate((freq, freq), axis=-1)
There is a weird bug with concatenate on the jax runtime. See the note below in the code you are removing.
# Avoid `ops.concatenate` for now, to avoid an obscure bug with XLA
# compilation on jax. We should be able to remove this once the
# following PR is in all jax releases we care about:
# https://github.com/openxla/xla/pull/7875
output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
return ops.reshape(output, x_shape)
This may be fixed in recent versions of jax, but we should probably check that this is not an issue with the current version of jax on Colab and Kaggle with a GPU runtime. We might need to keep this fix.
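For context, the general stack-plus-reshape trick can reproduce a last-axis concatenate exactly, which is why this kind of workaround can be a drop-in replacement; a minimal NumPy sketch of the equivalence (shapes are illustrative):

```python
import numpy as np

x = np.random.uniform(size=(2, 8)).astype("float32")
x1, x2 = np.split(x, 2, axis=-1)  # two (2, 4) halves

# rotate_half via concatenate: [-x2, x1] along the last axis.
via_concat = np.concatenate((-x2, x1), axis=-1)

# The same result via stack + reshape: stacking on a new second-to-last
# axis and flattening lays the -x2 block out before the x1 block.
via_stack = np.stack((-x2, x1), axis=-2).reshape(x.shape)

np.testing.assert_array_equal(via_concat, via_stack)
```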
Checked inference; it works with the JAX version on Colab with an A100.
RotaryEmbedding has always used concatenate, and it worked even before the fix was shipped in JAX (e.g. Mistral worked). I am not sure if this bug is an obscure TPU thing. Might be good to confirm with the original author.
Yeah, the concatenate issue was funky. It should only occur with jax 0.4.23, not 0.4.24. And I believe it was only showing up at certain shapes (e.g. batch size 2 but not batch size 1), and only during fine-tuning.
To be safe, maybe let's propagate the fix for now and ditch it after 0.4.24 is on Colab.
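If the workaround were gated on the JAX version instead, a minimal sketch could look like this (hypothetical flag, not code from this PR):

```python
# Hypothetical gate: keep the stack-based workaround only on jax 0.4.23,
# where the XLA concatenate bug is believed to show up.
try:
    import jax

    _USE_STACK_WORKAROUND = jax.__version__.startswith("0.4.23")
except ImportError:
    _USE_STACK_WORKAROUND = False
```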
Already done here: https://colab.research.google.com/drive/1BNNlxN7Y7yAzJl0UeWdG9TZ6RpfJjCBS?usp=sharing
lgtm! thank you!
* Always run the rotary embedding layer in float32
* Fix the int32 issue with TensorFlow
* Only run sin/cos embedding compute step in float32
* Avoid start_index from downcasting automatically
* Use stack instead of concatenate
Follow-up for #1497
This PR refactors the keras_nlp.layers.modelling.rotary_embedding.RotaryEmbedding layer to always compute in float32 dtype, since there are significant precision losses in other dtypes. It also updates Gemma to use this layer instead of implementing its own version of RoPE.

This PR isn't ready yet. TODO:

* … bfloat16.
* Add tests for the RotaryEmbedding layer to check no precision is lost with float16 and bfloat16 dtypes.

Colab showing the equivalence of Gemma's embedding and the rotary embedding in KerasNLP: https://colab.research.google.com/drive/1BNNlxN7Y7yAzJl0UeWdG9TZ6RpfJjCBS?usp=sharing
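A minimal sketch of the compute-in-float32 pattern the PR describes: cast inputs up, do the sin/cos work in float32, and cast the result back to the incoming compute dtype. The function name and shapes are illustrative assumptions, not the PR's exact code:

```python
import keras
from keras import ops


def apply_rope_float32(x, positions, max_wavelength=10_000.0):
    """Illustrative RoPE application doing the trig math in float32.

    `x` is assumed to be (batch, seq_len, num_heads, head_dim) and
    `positions` a (seq_len,) tensor of token positions.
    """
    # Remember the incoming dtype so we can cast back at the end.
    compute_dtype = keras.backend.standardize_dtype(x.dtype)
    x = ops.cast(x, "float32")
    positions = ops.cast(positions, "float32")

    head_dim = ops.shape(x)[-1]
    freq_exponents = (2.0 / head_dim) * ops.arange(
        head_dim // 2, dtype="float32"
    )
    timescale = max_wavelength**freq_exponents
    # Shape (seq_len, 1, head_dim // 2), broadcasting over the heads axis.
    radians = (positions[..., None] / timescale[None, :])[..., None, :]
    sin, cos = ops.sin(radians), ops.cos(radians)

    x1, x2 = ops.split(x, 2, axis=-1)
    # Stack + reshape instead of concatenate (see the XLA note above).
    output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    output = ops.reshape(output, ops.shape(x))
    # Cast back so downstream layers see the original compute dtype.
    return ops.cast(output, compute_dtype)
```

Under a `mixed_bfloat16` policy, q and k would arrive as bfloat16; the cast up and back keeps the trig math exact while leaving the rest of attention untouched.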