Experiment: shared cos sin #11
base: main
Conversation
```python
q_embed[..., half_dim:] += q[..., :half_dim] * sin[..., half_dim:]
q_embed[..., :half_dim] += q[..., half_dim:] * sin[..., :half_dim] * -1

k_embed = (key_states * cos)
k_embed[..., half_dim:] += key_states[..., :half_dim] * sin[..., half_dim:]
k_embed[..., :half_dim] += key_states[..., half_dim:] * sin[..., :half_dim] * -1
```
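For context, here is a self-contained sketch of the slicing-based helper reconstructed from the quoted lines; the `q_embed = q * cos` initialization and the `half_dim` definition are assumed, since they fall outside the quoted range.

```python
import torch

def apply_rotary_pos_emb_opt(q, key_states, cos, sin):
    # Sketch reconstructed from the diff above: apply the rotary embedding by
    # updating the upper/lower halves in place instead of calling rotate_half.
    half_dim = q.shape[-1] // 2

    q_embed = q * cos
    q_embed[..., half_dim:] += q[..., :half_dim] * sin[..., half_dim:]
    q_embed[..., :half_dim] += q[..., half_dim:] * sin[..., :half_dim] * -1

    k_embed = key_states * cos
    k_embed[..., half_dim:] += key_states[..., :half_dim] * sin[..., half_dim:]
    k_embed[..., :half_dim] += key_states[..., half_dim:] * sin[..., :half_dim] * -1
    return q_embed, k_embed
```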
I assume there is a lot of overhead from multiple aten::slice calls (had the same issue when slicing past_key_values)
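A minimal way to confirm this, assuming hypothetical tensor shapes (not taken from the PR), is to run the snippet under `torch.profiler` and inspect the operator table:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Hypothetical shapes (batch, heads, seq, head_dim); not taken from the PR.
q = torch.randn(1, 32, 512, 128)
cos = torch.randn(1, 1, 512, 128)
sin = torch.randn(1, 1, 512, 128)
half_dim = q.shape[-1] // 2

with profile(activities=[ProfilerActivity.CPU]) as prof:
    q_embed = q * cos
    q_embed[..., half_dim:] += q[..., :half_dim] * sin[..., half_dim:]
    q_embed[..., :half_dim] += q[..., half_dim:] * sin[..., :half_dim] * -1

# Every `[..., :half_dim]` / `[..., half_dim:]` indexing dispatches an
# aten::slice (on top of the aten::mul / aten::add_ ops), visible in this table.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=15))
```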
The alternative, calling rotate_half, is equally as bad 😅
It's quite frustrating to know that our attention layers take more time to apply the rotary embedding than the attention itself 😞
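For reference, the `rotate_half` path looks roughly like this (simplified from the transformers implementation; `position_ids` handling omitted). It dispatches a similar mix of slice and elementwise ops, plus an extra `torch.cat` allocation per call:

```python
import torch

def rotate_half(x):
    # (x1, x2) -> (-x2, x1): two slices plus a torch.cat that
    # allocates a brand-new tensor on every call.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # Reference formulation: mathematically equivalent to the in-place
    # variant above, with a comparable number of small aten ops.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```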
Yes, at least it's good to know there's a strong bottleneck there. Maybe there exists a better PyTorch-based implementation (not sure how TGI handles it).
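One known PyTorch-level alternative is the complex-multiply formulation used in the original LLaMA code. A rough sketch is below; note that it uses an interleaved channel pairing rather than the half-split layout in the diff above, so it is not a drop-in replacement:

```python
import torch

def precompute_freqs_cis(head_dim, seq_len, theta=10000.0):
    # Complex rotation factors, shape (seq, head_dim // 2).
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
    return torch.polar(torch.ones_like(freqs), freqs)

def apply_rotary_complex(x, freqs_cis):
    # x: (batch, seq, heads, head_dim); freqs_cis: complex, (seq, head_dim // 2).
    # Adjacent channel pairs are viewed as complex numbers and rotated with a
    # single complex multiply, avoiding the explicit half-dim slicing.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_rotated = torch.view_as_real(x_complex * freqs_cis[None, :, None, :]).flatten(-2)
    return x_rotated.type_as(x)
```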
Using the plot facilities from #12 (and using the plots in that PR as a reference for the performance):
- batch size sweep
- prompt length sweep
- performance conclusions
@gante maybe this can be useful for an apples-to-apples comparison: NVIDIA/cutlass#430 (comment)
@fxmarty TIL, that's an interesting way to make comparisons!


This is an experimental PR that shares `cos` and `sin` across the decoder layers. If we look at the profile, a LOT of time is spent on `apply_rotary_pos_emb_opt`. This is an attempt to reduce it.
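A minimal sketch of the idea (class and argument names are hypothetical, not the PR's actual code): compute `cos`/`sin` once per forward pass and hand the same tensors to every decoder layer, instead of recomputing them inside each attention module.

```python
import torch
from torch import nn

class Decoder(nn.Module):
    # Hypothetical structure for illustration only.
    def __init__(self, layers: nn.ModuleList, rotary_emb: nn.Module):
        super().__init__()
        self.layers = layers
        self.rotary_emb = rotary_emb  # module returning (cos, sin) for a sequence length

    def forward(self, hidden_states, position_ids):
        # cos/sin depend only on positions, so compute them once here...
        cos, sin = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
        for layer in self.layers:
            # ...and pass the same tensors to every decoder layer's attention,
            # instead of recomputing them in each layer.
            hidden_states = layer(hidden_states, cos=cos, sin=sin,
                                  position_ids=position_ids)
        return hidden_states
```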
Learnings
PyTorch version: `torch==2.1.0.dev20230621+cu118`