[CB] Add paged_attention kernel#44379

Merged
remi-or merged 6 commits into main from cb-paged-attention
Mar 9, 2026
Conversation

@remi-or (Collaborator) commented Mar 1, 2026

Summary

⚠️ This PR is in draft, waiting for #44227 to be merged

This PR adds support for the flash_attention_with_kvcache kernel in continuous batching. It is very efficient for decode-only batches, and hence especially useful for long generations.
For now, it must be explicitly enabled by the user, because there is some divergence in generation. Once this is fixed / confirmed to be expected, it will be turned on by default.
This PR also fixes a few things:

  • a bug with async batching + forking
  • the reset of static tensors happens right before they are filled, with the future padded size
  • better cleanup for CUDA graphs and shared memory pool
  • better handling of EOS tokens: we look into model AND generation config, and support multiple EOS tokens
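The EOS handling described in the last bullet can be sketched as follows. This is a hypothetical simplification, not the PR's actual code: the real implementation reads attributes off the model config and `GenerationConfig` objects, while here the two sources are plain arguments.

```python
def collect_eos_token_ids(model_eos, generation_eos):
    """Merge EOS ids from the model config and the generation config into one set.

    Each argument may be None, a single int, or a list of ints (a simplified
    stand-in for the actual config objects). Negative ids (e.g. -1, used to
    request infinite generation) are treated as invalid and skipped.
    """
    eos_ids = set()
    for source in (model_eos, generation_eos):
        if source is None:
            continue
        ids = [source] if isinstance(source, int) else list(source)
        eos_ids.update(i for i in ids if i >= 0)
    return eos_ids

# Both configs are consulted, and multiple EOS tokens are supported:
print(collect_eos_token_ids(2, [2, 32000]))  # {2, 32000}
# eos_token_id=-1 yields an empty set, so generation never stops on EOS:
print(collect_eos_token_ids(-1, None))       # set()
```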

Performance

| Arguments | Throughput on main | Throughput with PR | Delta |
|---|---|---|---|
| `--samples 10` | 539.00 | 498.81 | -7.46% |
| `--samples 20 --num-blocks 20` | 133.34 | 651.52 | +388.59% |
| `--samples 50` | 1420.52 | 1426.69 | +0.43% |
| `--samples 100` | 2398.94 | 2403.40 | +0.19% |
| `--samples 100 --attn flash_attention_2` | 2058.21 | 2053.22 | -0.24% |
| `--samples 100 --attn sdpa` | 867.49 | 838.25 | -3.37% |
| `--samples 500 --use-async` | 5698.34 | 5777.05 | +1.38% |
| `--samples 500 --add-prefix --compile` | 6564.75 | 6832.18 | +4.07% |
| `--samples 50 --num-return-sequences 8 --do-sample` | 689.13 | 700.83 | +1.70% |
| `--samples 100 --num-return-sequences 4 --do-sample` | 1284.28 | 1331.73 | +3.70% |
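The Delta column is the relative throughput change of the PR branch over main. A quick sketch of the arithmetic, using values from the table above (the table's own percentages were presumably computed from unrounded measurements, so the last digit can differ):

```python
def delta_pct(main_tps, pr_tps):
    """Relative throughput change of the PR branch vs. main, in percent."""
    return (pr_tps - main_tps) / main_tps * 100

# First two rows of the table:
print(f"{delta_pct(539.00, 498.81):+.2f}%")   # -7.46%
print(f"{delta_pct(133.34, 651.52):+.2f}%")   # +388.62% (table shows +388.59%)
```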

Performance on long generations:

| Generation length | tok/s on main | tok/s with PR | tok/s with vLLM | Ratio (PR / vLLM) |
|---|---|---|---|---|
| 4K | 1521 | 2489 | 3559 | 69.9% |
| 8K | 1044 | 2029 | 2646 | 76.7% |
| 16K | 605 | 1262 | 1509 | 83.6% |

Tests

All tests, including the ones added for this feature, pass.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@remi-or remi-or requested a review from ArthurZucker March 2, 2026 12:37
@remi-or remi-or force-pushed the cb-paged-attention branch from f7cbee7 to 79c4729 Compare March 3, 2026 18:01
@ArthurZucker (Collaborator) left a comment:

Great work!
Next is LLM eval with paged vs. no paged? (model perf testing)

read_index: List of tensors indicating which cache positions to read from, one per attention group.
logits_indices: Tensor indicating which positions in the output should be used for next-token prediction.
cache: The [`PagedAttentionCache`] instance managing the KV cache.
block_table: Block table for paged KV cache. If provided, uses `flash_attn_with_kvcache` for fused attention + cache update.
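To make the `block_table` argument from the docstring above more concrete, here is a minimal sketch (not the library's implementation; block size and indices are made up) of how a block table maps each sequence's logical token positions onto physical KV-cache blocks. This non-contiguous mapping is what `flash_attn_with_kvcache` consumes to read a paged cache:

```python
BLOCK_SIZE = 4  # tokens per physical cache block (real kernels use e.g. 16 or 256)

# block_table[seq][logical_block] = physical block index in the shared KV pool.
# Sequence 0 owns physical blocks 7 and 2; sequence 1 owns block 5.
block_table = [
    [7, 2],
    [5],
]

def physical_slot(seq, token_pos):
    """Translate a logical token position into (physical_block, offset)."""
    logical_block, offset = divmod(token_pos, BLOCK_SIZE)
    return block_table[seq][logical_block], offset

print(physical_slot(0, 1))  # (7, 1): token 1 of seq 0 lives in block 7, slot 1
print(physical_slot(0, 5))  # (2, 1): token 5 spills into the second block
```

Because blocks are allocated on demand from a shared pool, sequences can grow without reserving contiguous memory up front.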
Collaborator:

It would help a LOT if you can show what it should look like, etc. I remember trying to figure it out from code only and it can be super annoying!

The better the doc, the happier I am.

Collaborator (author):

I added documentation in the docstring, but no drawing / concrete example just yet. The reason is we are adding support for compile soon, and that will involve a dedicated wrapper. I think the cleanest would be to add a concrete example / ASCII drawing there. Leaving a TODO for this; the current documentation should make things understandable at least.

elif isinstance(self.eos_token_id, int):
    if self.eos_token_id >= 0:
        self._eos_token_ids.add(self.eos_token_id)
# If there are multiple EOS token IDs, add them to the set only if they are valid, i.e. non-negative
Collaborator:

never saw that happening

Collaborator (author):

this handles the case where we set `eos_token_id=-1` for infinite generation

@vasqu (Contributor) left a comment:

Looks fine on my side, I think we can refactor the paged fa into the general fa modeling as we duplicate a lot of effort on processing kwargs.

attn_output = attn_output[0]
# Reshape output from [batch_size, 1, num_heads, head_dim] to [batch_size, num_heads, head_dim]
attn_output = attn_output.squeeze(1)
return attn_output, None
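The squeeze in the snippet above relies on decode batches having exactly one query token per sequence. A numpy stand-in (assumed analogous to the torch tensors the PR actually uses) shows the shape change:

```python
import numpy as np

batch_size, num_heads, head_dim = 3, 8, 64

# Decode-only batches have a single query token per sequence, so seq_len == 1.
attn_output = np.zeros((batch_size, 1, num_heads, head_dim))

# Drop the singleton sequence dimension: [B, 1, H, D] -> [B, H, D].
attn_output = attn_output.squeeze(1)
print(attn_output.shape)  # (3, 8, 64)
```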
Contributor:

General comment on maybe refactoring this. Not a must for this PR, and maybe I'll tackle it in a different PR. At the moment, we manually do a lot of the kwargs processing that can already be done in the base fa forward here

We should enter the right branch automatically due to the fa kwargs being available. This would leverage

def _process_flash_attention_kwargs(
which you should not need to handle here

It would probably require some additional logic to add the kv cache branch and some kwargs processing (force int dtype etc). Wdyt?

Collaborator (author):

Agreed, it would make things easier in general. Though since the behavior is quite different when switching from one function to the other, we would need to document this quite carefully!

Contributor:

Out of scope for this PR, as discussed. Yeah, agreed it might behave differently, so it needs some docs then.

Comment on lines +141 to +143
# This function is not implemented but should never be called because block table is not used on NPU
def npu_flash_attn_with_kvcache():
raise NotImplementedError("npu_flash_attn_with_kvcache is not implemented")

@remi-or remi-or added this pull request to the merge queue Mar 9, 2026
Merged via the queue into main with commit a08aa52 Mar 9, 2026
29 checks passed
@remi-or remi-or deleted the cb-paged-attention branch March 9, 2026 22:16
