Support RoPE position info in batch prefill/decode kernels #69

Merged: 1 commit into flashinfer-ai:main from qk-rope-info on Feb 1, 2024

Conversation

MasterJH5574 (Collaborator)

This PR adds q/k position information to the batch prefill/decode kernels. More specifically, the kernels now accept two additional arrays:

  • `q_rope_position` with shape `(total_q_len,)`, denoting the in-sequence position of each token in the input q.
  • `k_rope_pos_offset` with shape `(num_sequence,)`, denoting the start position of each sequence in k.

These two arrays make it possible to compute RoPE on the fly in multi-level cases.
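
For illustration, here is a minimal host-side sketch of how these two arrays could be constructed for a ragged batch. The request lengths, variable names, and NumPy usage below are assumptions for the example; this is not the FlashInfer API.

```python
import numpy as np

# Hypothetical batch sizes (illustrative only, not taken from the PR).
q_len = np.array([3, 1, 5])        # new query tokens appended per request
cached_len = np.array([7, 0, 2])   # tokens already present in each request's k

# q_rope_position, shape (total_q_len,): the absolute in-sequence position
# of every token in the concatenated input q.
q_rope_position = np.concatenate(
    [c + np.arange(n) for c, n in zip(cached_len, q_len)]
)

# k_rope_pos_offset, shape (num_sequence,): the start position of the k
# segment covered by this kernel call for each sequence. A segment starting
# at position 0 gives all zeros; in multi-level cases it would be the
# segment's offset within the full sequence.
k_rope_pos_offset = np.zeros(len(q_len), dtype=np.int64)

print(q_rope_position)    # [7 8 9 0 2 3 4 5 6]
print(k_rope_pos_offset)  # [0 0 0]
```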

Tests `test_batch_prefill` and `test_batch_decode` pass. Performance has not been validated yet; per discussion with Zihao, this change is unlikely to cause a significant perf regression.
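
For context on how such position values enter the attention computation, below is a hedged sketch of standard rotary embedding applied to one head vector at a known absolute position. The interleaved-pair layout and the `apply_rope` helper are illustrative assumptions, not FlashInfer's in-kernel implementation.

```python
import numpy as np

def apply_rope(x: np.ndarray, pos: int, theta: float = 1e4) -> np.ndarray:
    """Rotate an even-dimensional head vector by angles pos * theta**(-2i/d)."""
    d = x.shape[-1]
    freqs = theta ** (-np.arange(0, d, 2) / d)   # (d/2,) inverse frequencies
    cos, sin = np.cos(pos * freqs), np.sin(pos * freqs)
    x1, x2 = x[0::2], x[1::2]                    # interleaved even/odd pairs
    out = np.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out

# A query token t would use pos = q_rope_position[t]; key j of sequence s
# would use pos = k_rope_pos_offset[s] + j (names as in the sketch above).
q = np.random.randn(128)
print(apply_rope(q, pos=7).shape)  # (128,)
```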

yzh119 (Collaborator) commented on Jan 21, 2024

I'll merge this into the mainline after #75 gets merged.

MasterJH5574 force-pushed the qk-rope-info branch 2 times, most recently from 5b189f5 to 47686ef on January 29, 2024 at 18:49
yzh119 (Collaborator) commented on Jan 31, 2024

Sorry about the new conflicts, I'll take care of them tmr.

yzh119 (Collaborator) left a comment:

LGTM, thank you @MasterJH5574 !

yzh119 merged commit a389ed4 into flashinfer-ai:main on Feb 1, 2024
yzh119 added a commit that referenced this pull request on Feb 16, 2024:
This PR fixes #113, which was caused by #69 changing the `BatchPrefillWithPagedKVCacheWrapperDispatched` signature without `flashinfer_decl.h` being updated accordingly.

It also fixes some minor formatting issues in #111.