Fix the rotary embedding computation in LLaMA #1544
Conversation
Also run the reverse embedding stage in compute_dtype instead of full precision. This is how HF does it, so it helps get the numerics closer to the HF implementation.
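For context on the fix itself: the rotation angles in a rotary embedding depend on absolute token position, so during cached generation the positions of new tokens have to be offset by `start_index`. Below is a minimal standalone sketch of that dependence (illustrative only, not the keras-nlp layer; the function and argument names are made up for the example):

```python
import numpy as np

def rotary_angles(seq_len, head_dim, start_index=0, base=10000.0):
    # Rotation angles depend on absolute position, so during cached
    # generation the new tokens' positions must be offset by start_index.
    # Ignoring the offset rotates every new token as if it sat at position 0.
    positions = np.arange(start_index, start_index + seq_len, dtype="float32")
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2, dtype="float32") / head_dim))
    return np.outer(positions, inv_freq)  # shape [seq_len, head_dim // 2]

# At generation step 10, the single new token should use position 10, not 0.
print(rotary_angles(seq_len=1, head_dim=8, start_index=10)[0, 0])  # 10.0
```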
# If `cache_update_index` is a tensor, RotaryEmbedding expects it
# to have dtype `self.compute_dtype`.
start_index = ops.cast(
    start_index, self.rotary_embedding_layer.compute_dtype
)
Won't this run into the same problems as #5 in https://unsloth.ai/blog/gemma-bugs? float16 and bfloat16 are both bad for an incrementing integer.
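A quick standalone illustration of that concern (plain NumPy, not project code):

```python
import numpy as np

# float16 has an 11-bit significand, so integers above 2048 are no longer
# exactly representable and an incrementing index silently stops advancing.
print(np.float16(2048) + np.float16(1))  # 2048.0 -- the increment is lost
print(np.float16(2049))                  # 2048.0 -- rounds back down
# bfloat16 is even coarser (8-bit significand), so exact integers stop at 256.
```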
  # [batch_shape, seq_len, num_key_value_heads, head_dim]
  # -> [batch_shape, seq_len, num_heads, head_dim]
- key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
- value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
+ key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2)
This should not have an underscore, I think.
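For readers skimming the diff above, here is a minimal runnable sketch of what the repeat does in grouped-query attention, assuming Keras 3's `keras.ops`; the shapes and variable names are illustrative, not the model's:

```python
import numpy as np
from keras import ops

# Toy grouped-query attention shapes: 8 query heads sharing 2 key/value heads.
batch, seq_len, num_key_value_heads, head_dim = 1, 4, 2, 16
num_heads = 8
num_key_value_groups = num_heads // num_key_value_heads  # 4

key = np.zeros((batch, seq_len, num_key_value_heads, head_dim), dtype="float32")
# Repeat each key/value head along the heads axis so the key tensor lines up
# with the query heads before attention.
key = ops.repeat(key, repeats=num_key_value_groups, axis=2)
print(ops.shape(key))  # (1, 4, 8, 16)
```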
* Fix rotary embedding computation in LLaMA. Also run the reverse embedding stage in compute_dtype instead of full-precision; this is how HF does it, so helps get the numerics closer.
* Don't cast start_index; save rope keys
* Remove underscore from num_key_value_heads
The LLaMA backbone ignored the start_index parameter when computing the rotary embeddings, which led to numerical issues during generation. This PR fixes that, along with the reverse embedding layer in both Mistral and LLaMA: run the reverse embedding stage in compute_dtype instead of full precision. This is how HF does it, so it helps get the numerics closer.
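A hedged sketch of the reverse-embedding change described above, assuming Keras 3's `keras.ops`; the function and argument names are illustrative rather than the actual keras-nlp implementation:

```python
from keras import ops

def reverse_embedding(hidden_states, embedding_weights, compute_dtype="bfloat16"):
    # Project hidden states back onto the tied embedding matrix to get logits,
    # doing the matmul in compute_dtype rather than promoting to float32.
    hidden_states = ops.cast(hidden_states, compute_dtype)
    kernel = ops.cast(ops.transpose(embedding_weights), compute_dtype)
    return ops.matmul(hidden_states, kernel)  # [batch, seq_len, vocab_size]
```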