
@gante (Collaborator) commented Jun 25, 2023

⚠️ do not merge!

This is an experimental PR that shares cos and sin across the decoder layers.

If we look at the profile, a LOT of time is spent in apply_rotary_pos_emb_opt. This PR is an attempt to reduce it.

Learnings

PT version: torch==2.1.0.dev20230621+cu118

  • No significant result changes
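The idea in the PR description, computing cos/sin once per forward pass and sharing the tensors across all decoder layers, can be sketched roughly as follows. This is a minimal NumPy illustration, not the actual transformers implementation; `build_rope_cache` is a hypothetical name, and the real code builds these tensors inside the rotary embedding module.

```python
import numpy as np

def build_rope_cache(seq_len, head_dim, base=10000.0):
    # Hypothetical helper: compute the rotary cos/sin tables once,
    # instead of recomputing them in every decoder layer.
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    t = np.arange(seq_len)
    freqs = np.outer(t, inv_freq)                  # (seq_len, head_dim / 2)
    emb = np.concatenate([freqs, freqs], axis=-1)  # (seq_len, head_dim)
    return np.cos(emb), np.sin(emb)

cos, sin = build_rope_cache(seq_len=128, head_dim=64)
# every decoder layer would then reuse the same (cos, sin) pair
```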

Comment on lines +146 to +151
q_embed[..., half_dim:] += q[..., :half_dim] * sin[..., half_dim:]
q_embed[..., :half_dim] += q[..., half_dim:] * sin[..., :half_dim] * -1

k_embed = (key_states * cos)
k_embed[..., half_dim:] += key_states[..., :half_dim] * sin[..., half_dim:]
k_embed[..., :half_dim] += key_states[..., half_dim:] * sin[..., :half_dim] * -1
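To make explicit that the sliced in-place variant above computes the same thing as the usual `rotate_half` formulation, here is a small NumPy sketch (NumPy stands in for PyTorch; the function names `apply_rotary_reference` and `apply_rotary_inplace` are made up for this comparison):

```python
import numpy as np

def rotate_half(x):
    # [-x2, x1], with x split into halves along the last dim
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

def apply_rotary_reference(q, cos, sin):
    # standard formulation: q * cos + rotate_half(q) * sin
    return q * cos + rotate_half(q) * sin

def apply_rotary_inplace(q, cos, sin):
    # the sliced in-place variant from the diff above
    half = q.shape[-1] // 2
    q_embed = q * cos
    q_embed[..., half:] += q[..., :half] * sin[..., half:]
    q_embed[..., :half] += q[..., half:] * sin[..., :half] * -1
    return q_embed

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 8, 16))          # (batch, heads, seq, head_dim)
angles = rng.standard_normal((1, 1, 8, 16))     # broadcast over batch and heads
cos, sin = np.cos(angles), np.sin(angles)
assert np.allclose(apply_rotary_reference(q, cos, sin),
                   apply_rotary_inplace(q, cos, sin))
```

The two are algebraically identical; the question in this thread is only which one the PyTorch dispatcher executes with less overhead.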
@fxmarty (Owner) commented Jun 26, 2023

I assume there is a lot of overhead from the multiple aten::slice calls (I had the same issue when slicing past_key_values)

@gante (Collaborator, Author)

The alternative, calling rotate_half, is equally bad 😅

It's quite frustrating to know that our attention layers take more time to apply the rotary embedding than the attention itself 😞

@fxmarty (Owner)

Yes, at least it's good to know there's a strong bottleneck there. Maybe a better PyTorch-based implementation exists (not sure how TGI handles it).

@gante (Collaborator, Author) commented Jun 27, 2023

Using the plotting facilities from #12 (and using the plots in that PR as a reference for the performance on main):

batch size sweep

[plot: llama_sweep_e746c78_batch]

prompt length sweep

[plot: llama_sweep_e746c78_length]

performance conclusions

  • No major performance changes confirmed (I'd attribute the slightly faster runs to a lower room temperature; the transformers baseline runs are also slightly faster)

@fxmarty (Owner) commented Jun 28, 2023

@gante maybe this can be useful for apple-to-apple comparison NVIDIA/cutlass#430 (comment)

@gante (Collaborator, Author) commented Jun 28, 2023

@fxmarty TIL, that's an interesting way to make comparisons!
