
Fix the rotary embedding computation in LLaMA #1544

Merged: 3 commits merged into keras-team:master on Apr 3, 2024

Conversation

tirthasheshpatel (Contributor):

The LLaMA backbone ignored the start_index parameter when computing the rotary embeddings, which led to numerical issues during generation. This PR fixes that and also updates the reverse embedding layer in both Mistral and LLaMA to run in compute_dtype instead of full precision. This is how HF does it, so it brings the numerics closer.
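To make the bug concrete, here is a minimal NumPy sketch of rotary embeddings (illustrative only, not the keras-nlp implementation; the interleaving convention and helper name are assumed). During cached generation, the token at step i must be rotated with the angle for its absolute position start_index + i, so dropping start_index rotates it as if it sat at the start of the sequence.

import numpy as np

def rope(x, start_index=0, max_wavelength=10000):
    # x: [seq_len, head_dim] with an even head_dim.
    seq_len, head_dim = x.shape
    positions = np.arange(start_index, start_index + seq_len, dtype=np.float32)
    inv_freq = 1.0 / max_wavelength ** (np.arange(0, head_dim, 2) / head_dim)
    angles = positions[:, None] * inv_freq[None, :]  # [seq_len, head_dim // 2]
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x2 * cos + x1 * sin
    return out

x = np.random.randn(1, 8).astype("float32")
# Ignoring start_index (the old behavior) rotates the token as if it were
# at position 0 instead of its true absolute position in the sequence.
print(np.allclose(rope(x, start_index=0), rope(x, start_index=5)))  # False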

# If `cache_update_index` is a tensor, RotaryEmbedding expects it
# to have dtype `self.compute_dtype`.
start_index = ops.cast(
    start_index, self.rotary_embedding_layer.compute_dtype
)
Member:
Won't this run into the same problems as #5 in https://unsloth.ai/blog/gemma-bugs? float16 or bfloat16 are both bad for an incrementing integer.
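A quick numeric illustration of that point, using NumPy float16 since NumPy has no native bfloat16 (the specific positions below are just examples):

import numpy as np

# Consecutive positions above 2048 are no longer distinguishable in float16,
# so a cast start_index would stop incrementing correctly at long contexts.
print(np.float16(2049) == np.float16(2048))  # True: 2049 rounds to 2048
print(np.float16(4097))                      # 4096.0
# bfloat16 is worse: only integers up to 256 are exactly representable.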


# [batch_shape, seq_len, num_key_value_heads, head_dim]
# -> [batch_shape, seq_len, num_heads, head_dim]
key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

On an earlier revision of the key line, which used self._num_key_value_groups:
Member:
should not be underscore I think
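For reference, a small shape sketch of what the ops.repeat calls above do for grouped-query attention (toy sizes, not taken from the PR):

from keras import ops

# With 8 query heads and 2 key/value heads, each key/value head is
# repeated num_key_value_groups = 8 // 2 = 4 times along the head axis.
batch, seq_len, num_key_value_heads, head_dim = 2, 5, 2, 16
num_key_value_groups = 4
key = ops.ones((batch, seq_len, num_key_value_heads, head_dim))
key = ops.repeat(key, repeats=num_key_value_groups, axis=2)
print(ops.shape(key))  # (2, 5, 8, 16)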

mattdangerw merged commit 9ac3335 into keras-team:master on Apr 3, 2024
10 checks passed
SamanehSaadat pushed a commit to SamanehSaadat/keras-nlp that referenced this pull request on Apr 10, 2024:
* Fix rotary embedding computation in LLaMA

Also run the reverse embedding stage in compute_dtype instead of full-precision. This is how HF does it, so helps get the numerics closer

* Don't cast start_index; save rope keys

* Remove underscore from num_key_value_heads
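Taken together, the second and third commits amount to something like the sketch below: keep start_index as an integer and pass it straight to the rotary embedding layer, and rotate keys before they are written to the cache. This is a hedged reading of the commit messages, not the exact merged code; the helper name and call signature are assumed.

def apply_rope(rotary_embedding_layer, query, key, cache_update_index=None):
    # Keep the position as an integer; no cast to compute_dtype.
    start_index = 0 if cache_update_index is None else cache_update_index
    query = rotary_embedding_layer(query, start_index=start_index)
    # Rotate the key before it is written to the cache ("save rope keys"),
    # so cached keys never need to be re-rotated on later steps.
    key = rotary_embedding_layer(key, start_index=start_index)
    return query, key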