
Always run the rotary embedding layer in float32 #1508

Merged · 6 commits · Mar 14, 2024

Conversation

@tirthasheshpatel (Contributor) commented Mar 12, 2024

Follow-up for #1497

This PR refactors the keras_nlp.layers.modelling.rotary_embedding.RotaryEmbedding layer to always compute in float32, since there are significant precision losses in other dtypes. It also updates Gemma to use this layer instead of implementing its own version of RoPE.
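In rough terms, the change looks like the following minimal sketch (not the exact layer code; the helper name `rope_tables` and its arguments are illustrative, the real implementation is keras_nlp.layers.RotaryEmbedding):

    from keras import ops

    def rope_tables(seq_len, head_dim, max_wavelength=10_000.0, compute_dtype="bfloat16"):
        # Build the RoPE sin/cos tables in float32, regardless of the layer's
        # compute dtype, then cast the result back so downstream computation
        # keeps the model's dtype policy.
        positions = ops.arange(seq_len, dtype="float32")
        freq_exponents = (2.0 / head_dim) * ops.arange(head_dim // 2, dtype="float32")
        inverse_freq = 1.0 / ops.power(max_wavelength, freq_exponents)
        freq = ops.einsum("i,j->ij", positions, inverse_freq)
        embedding = ops.concatenate((freq, freq), axis=-1)
        cos = ops.cast(ops.cos(embedding), compute_dtype)
        sin = ops.cast(ops.sin(embedding), compute_dtype)
        return cos, sin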

This PR isn't ready yet. TODO:

  • Make sure the models (Gemma/Mistral) generate the same output with the presets.
  • Make sure the presets run in around 16GB RAM with bfloat16.
  • Add tests for the RotaryEmbedding layer to check that no precision is lost with float16 and bfloat16 dtypes.

Colab showing the equivalence of Gemma's embedding and the rotary embedding in KerasNLP: https://colab.research.google.com/drive/1BNNlxN7Y7yAzJl0UeWdG9TZ6RpfJjCBS?usp=sharing

@github-actions bot added the Gemma (Gemma model specific issues) label on Mar 12, 2024
@tirthasheshpatel marked this pull request as ready for review on March 13, 2024 at 18:57
@mattdangerw (Member) left a comment

Code looks good!

We probably should test this to make sure numerics are as close to our reference jax implementation as they were before, and that this does not negatively impact performance.

    @@ -87,28 +88,20 @@ def build(self, inputs_shape):
                (None, None, self.num_query_heads, self.head_dim)
            )
            self.softmax = keras.layers.Softmax(dtype="float32")

            self.rope_layer = RotaryEmbedding(
                max_wavelength=10000.0, dtype=self.dtype_policy
Member:

Nit, but let's start using the new 10_000.0 style for large numbers; it's quite readable.

Contributor Author:

Sounds good, I love _ to separate large numbers!

        freq_exponents = (2.0 / x_shape[-1]) * ops.arange(
            x_shape[-1] // 2, dtype="float32"
        x = self.rope_layer(x, start_index=start_index)
        x = ops.reshape(
Member:

Can we drop a comment explaining this?


        tensor = ops.cast(tensor, dtype=inverse_freq.dtype)
        freq = ops.einsum("i,j->ij", tensor, inverse_freq)
        embedding = ops.concatenate((freq, freq), axis=-1)
Member:

There is a weird bug with concatenate on the jax runtime. See the note below in the code you are removing.

        # Avoid `ops.concatenate` for now, to avoid an obscure bug with XLA
        # compilation on jax. We should be able to remove this once the
        # following PR is in all jax releases we care about:
        # https://github.com/openxla/xla/pull/7875
        output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
        return ops.reshape(output, x_shape)

This may be fixed in recent versions of jax, but we should probably check that this is not an issue with the current version of jax on Colab and Kaggle with a GPU runtime. We might need to persist this fix.
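For reference, the stack-based rotation in isolation looks roughly like the sketch below (the helper name `apply_rotary` is illustrative, and `cos`/`sin` are assumed to already broadcast against the half-sized feature dimension; it is not the exact code in the layer):

    from keras import ops

    def apply_rotary(x, cos, sin):
        # Rotate interleaved feature pairs. Mathematically equivalent to the
        # rotate-half + concatenate formulation, but assembled with ops.stack
        # to sidestep the jax/XLA concatenate issue referenced above
        # (https://github.com/openxla/xla/pull/7875).
        # Assumption: cos/sin already broadcast against x1/x2 below.
        x_shape = ops.shape(x)
        x = ops.reshape(x, (*x_shape[:-1], x_shape[-1] // 2, 2))
        x1, x2 = x[..., 0], x[..., 1]
        output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
        return ops.reshape(output, x_shape)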

@tirthasheshpatel (Contributor Author) commented Mar 13, 2024:

Checked inference; it works with the JAX version on Colab with an A100.

Contributor Author:

RotaryEmbedding has always used concatenate, and it worked even before the fix was shipped in JAX (e.g. Mistral worked). I am not sure if this bug is an obscure TPU thing; it might be good to confirm with the original author.

Member:

Yeah, the concatenate issue was funky. It should only affect jax 0.4.23, not 0.4.24. And I believe it only showed up at certain shapes (e.g. batch size 2 but not batch size 1), and only during fine-tuning.

To be safe, maybe let's propagate the fix for now and drop it after 0.4.24 is on Colab.

@tirthasheshpatel (Contributor Author) commented:

> Code looks good!
>
> We probably should test this to make sure numerics are as close to our reference jax implementation as they were before, and that this does not negatively impact performance.

Already done here: https://colab.research.google.com/drive/1BNNlxN7Y7yAzJl0UeWdG9TZ6RpfJjCBS?usp=sharing
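
Roughly, such a check can look like the sketch below (shapes and tolerance are illustrative, not taken from the colab):

    import numpy as np
    from keras import ops
    from keras_nlp.layers import RotaryEmbedding

    x = np.random.uniform(size=(1, 8, 4, 64)).astype("float32")

    # Reference output under a full float32 policy.
    ref = RotaryEmbedding(max_wavelength=10_000.0, dtype="float32")(x)

    # Reduced-precision policy; the sin/cos tables are still built in float32.
    low = RotaryEmbedding(max_wavelength=10_000.0, dtype="bfloat16")(x)

    # Loose tolerance, since bfloat16 storage still rounds the final values.
    np.testing.assert_allclose(
        ops.convert_to_numpy(ops.cast(low, "float32")),
        ops.convert_to_numpy(ref),
        atol=1e-2,
    )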

@mattdangerw (Member) left a comment

lgtm! thank you!

@mattdangerw merged commit f127901 into keras-team:master on Mar 14, 2024
10 checks passed
@grasskin mentioned this pull request on Mar 21, 2024
abuelnasr0 pushed a commit to abuelnasr0/keras-nlp that referenced this pull request on Apr 2, 2024
* Always run the rotary embedding layer in float32

* Fix the int32 issue with TensorFlow

* Only run sin/cos embedding compute step in float32

* Avoid start_index from downcasting automatically

* Use stack instead of concatenate