enable cpu paged cache #42869

Merged

Cyrilvallez merged 50 commits into huggingface:main from jiqing-feng:cpu_paged on Jan 29, 2026

Conversation

@jiqing-feng
Contributor

@jiqing-feng jiqing-feng commented Dec 15, 2025

CPU can also use the paged cache with eager or sdpa attention:

python continuous_batching_simple.py --attn sdpa
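
For reference, here is a rough sketch of what the script drives under the hood. The generate_batch call and result fields are assumptions based on transformers' continuous batching API, not a copy of the script:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # illustrative model choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype=torch.bfloat16,
    attn_implementation="sdpa",  # eager works on CPU as well
).to("cpu")

generation_config = GenerationConfig(max_new_tokens=32, do_sample=False)
prompts = ["What is paged attention?", "Explain continuous batching."]
inputs = [tokenizer(p).input_ids for p in prompts]

# generate_batch drives the continuous-batching loop that previously crashed on CPU
results = model.generate_batch(inputs=inputs, generation_config=generation_config)
for request in results.values():
    print(tokenizer.decode(request.generated_tokens, skip_special_tokens=True))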

Without this change, the command above fails with:

Error in generation loop: unsupported operand type(s) for -: 'NoneType' and 'int'
Traceback (most recent call last):
  File "/home/jiqing/transformers/src/transformers/generation/continuous_batching/continuous_api.py", line 1017, in _run_generation_loop
    paged_attention_cache = PagedAttentionCache(
                            ^^^^^^^^^^^^^^^^^^^^
  File "/home/jiqing/transformers/src/transformers/generation/continuous_batching/cache.py", line 191, in __init__
    num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens(
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jiqing/transformers/src/transformers/generation/continuous_batching/cache.py", line 481, in infer_num_blocks_and_max_batch_tokens
    num_blocks, max_batch_tokens = self.compute_num_blocks_and_max_batch_tokens(
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jiqing/transformers/src/transformers/generation/continuous_batching/cache.py", line 522, in compute_num_blocks_and_max_batch_tokens
    cache_memory = self.get_available_memory(max_memory_percent)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jiqing/transformers/src/transformers/generation/continuous_batching/cache.py", line 456, in get_available_memory
    available_memory = total - max(allocated, reserved)
                       ~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~
TypeError: unsupported operand type(s) for -: 'NoneType' and 'int'
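
The root cause is that the CUDA memory stats feeding this subtraction are None on CPU. A plausible shape of the fix, assuming a psutil-based fallback for system memory (the PR's actual code may differ):

import psutil
import torch

def get_available_memory(device: torch.device, max_memory_percent: float) -> int:
    # On CUDA, budget what the device has left after PyTorch's own bookkeeping
    if device.type == "cuda":
        total = torch.cuda.get_device_properties(device).total_memory
        allocated = torch.cuda.memory_allocated(device)
        reserved = torch.cuda.memory_reserved(device)
        available = total - max(allocated, reserved)
    else:
        # CPU path: torch.cuda stats are unavailable here, so use free system RAM
        available = psutil.virtual_memory().available
    return int(available * max_memory_percent)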

@remi-or
Collaborator

remi-or commented Dec 15, 2025

Hi @jiqing-feng, thanks for the contribution! Just letting you know that CPU-compatible continuous batching is not a priority right now, so even though this PR is small, it will not be reviewed right away. I am cautious about two things:

  1. How device map "auto" behaves and how it affects the model's repartition
  2. The lack of tests / benchmarks. We have a small template for continuous batching PRs, as in [CB] Easy optimizations for continuous batching #42839; if you can follow it, that would be great.

Will get to review this as soon as I have the bandwidth, thank you!

@jiqing-feng
Contributor Author

  1. device_map="auto" assigns an accelerator such as cuda/xpu if one exists; otherwise it falls back to cpu (sketched below). I reverted this change and only added cpu as an option when cuda is not available, so behavior is unchanged when cuda exists.
  2. OK, I will add tests and benchmarks for it.
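
A minimal sketch of the fallback in point 1 (assumed logic, not the PR's exact code):

import torch

# Keep cuda when present; otherwise fall back to cpu
device = "cuda" if torch.cuda.is_available() else "cpu"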

@jiqing-feng
Contributor Author

jiqing-feng commented Dec 16, 2025

Hi @remi-or. I have updated the tests and examples for CPU; the example and tests now pass on CPU. Please review this PR and let me know your opinion. Thanks!

@jiqing-feng
Contributor Author

Hi @remi-or. Do you have bandwidth to review this PR?

@jiqing-feng
Contributor Author

Hi @SunMarc. We have enabled flash varlen attention for CPU: https://huggingface.co/kernels-community/flash-attn2/tree/main/build.
With this change, CPU can also use paged flash attention in the continuous batching case. The official example continuous_batching_simple.py gains a 1.6x speedup with kernels-community/flash_attention_2 compared to paged|sdpa.
Would you please review this PR? Thanks!
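
For illustration, loading the Hub-hosted kernel could look like this; the attn_implementation string form is assumed from transformers' kernels integration, not taken from this PR:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # illustrative model choice
    attn_implementation="kernels-community/flash-attn2",  # fetches the kernel build from the Hub
)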

@jiqing-feng
Contributor Author

jiqing-feng commented Jan 13, 2026

Regarding the failed tests: tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_num_return_sequences_1 passes on CPU but fails on an NVIDIA A100. The test also fails on the A100 without my changes.

@remi-or
Collaborator

remi-or commented Jan 13, 2026

Regarding the failed tests: tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_num_return_sequences_1 passes on CPU but fails on an NVIDIA A100. The test also fails on the A100 without my changes.

Yes, that test is a bit flaky; I will look into it soon.

I just merged a big PR which caused conflicts, my bad. Could you update your PR so I can review? Thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@remi-or
Collaborator

remi-or commented Jan 26, 2026

Hi @jiqing-feng, I just ran the tests on my end and a lot do not pass. Here are the results:

FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_block_sharing_with_hybrid_model - RuntimeError: No CUDA GPUs are available
FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_continuous_batching_config_combinations_08 - NotImplementedError: Could not run '_flash_attn2_588b404::varlen_fwd' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using c...
FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_continuous_batching_config_combinations_09 - NotImplementedError: Could not run '_flash_attn2_588b404::varlen_fwd' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using c...
FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_continuous_batching_config_combinations_20 - NotImplementedError: Could not run '_flash_attn2_588b404::varlen_fwd' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using c...
FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_continuous_batching_config_combinations_21 - NotImplementedError: Could not run '_flash_attn2_588b404::varlen_fwd' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using c...
FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_continuous_batching_diverse_models_0_TinyLlama_TinyLlama_1_1B_Chat_v1_0 - NotImplementedError: Could not run '_flash_attn2_588b404::varlen_fwd' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using c...
FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_continuous_batching_diverse_models_1_TinyLlama_TinyLlama_1_1B_Chat_v1_0 - NotImplementedError: Could not run '_flash_attn2_588b404::varlen_fwd' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using c...
FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_continuous_batching_diverse_models_4_google_gemma_2_2b_it - NotImplementedError: Could not run '_flash_attn2_588b404::varlen_fwd' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using c...
FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_continuous_batching_diverse_models_5_google_gemma_2_2b_it - NotImplementedError: Could not run '_flash_attn2_588b404::varlen_fwd' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using c...
FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_num_return_sequences_0 - AssertionError: 0 != 2 : Expected 2 results, but got len(results) = 0
FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_num_return_sequences_1 - AssertionError: 0 != 2 : Expected 2 results, but got len(results) = 0
FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_prefix_sharing - RuntimeError: No CUDA GPUs are available

I am surprised because I thought there was some version of flash attention that worked on CPU, but I might be wrong here. If not, please add back the torch_accelerator decorator for those tests, or a skip clause.
For the OSError, you can probably solve that by connecting to HF using a token or the hf CLI. I ran the tests using: CUDA_VISIBLE_DEVICES="" RUN_SLOW=1 pytest tests/generation/test_continuous_batching.py
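
For reference, gating a test on accelerator availability could look like this; require_torch_accelerator comes from transformers.testing_utils, and the test body here is a placeholder:

import unittest

from transformers.testing_utils import require_torch_accelerator

class ContinuousBatchingGenerationTest(unittest.TestCase):
    @require_torch_accelerator  # auto-skips when no CUDA/XPU device is present
    def test_block_sharing_with_hybrid_model(self):
        self.assertTrue(True)  # placeholder body for illustration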

@jiqing-feng
Contributor Author

jiqing-feng commented Jan 27, 2026

Hi @remi-or. Most of the failed tests you listed pass on my side, and some failures like

FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_num_return_sequences_0 - AssertionError: 0 != 2 : Expected 2 results, but got len(results) = 0
FAILED tests/generation/test_continuous_batching.py::ContinuousBatchingGenerationTest::test_num_return_sequences_1 - AssertionError: 0 != 2 : Expected 2 results, but got len(results) = 0

are fixed in my latest changes.

[screenshot: local run showing the tests passing]

Here are my key packages:

torch                     2.10.0+cpu
kernels                   0.12.1
transformers              5.0.1.dev0

@jiqing-feng
Contributor Author

Hi @remi-or. It seems that you didn't correctly load the latest kernels from here: https://huggingface.co/kernels-community/flash-attn2/tree/main/build.

If possible, I'd like to log in to your node to check the environment. My email is jiqing.feng@intel.com.
Please let me know if you need me to check the env. Thanks!

@jiqing-feng
Contributor Author

Hi @remi-or. I have addressed your comment and now check for cuda before using a cuda stream. Please review the new change. Thanks!
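
A hypothetical shape of that guard (the PR's actual code may differ):

import torch

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        out = torch.ones(4, device="cuda") * 2  # device-side work on a side stream
    torch.cuda.current_stream().wait_stream(stream)  # sync before consuming out
else:
    out = torch.ones(4) * 2  # plain CPU path, no stream needed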

@jiqing-feng
Contributor Author

The failed CI jobs are not related to my changes; the main branch also fails.

@jiqing-feng
Contributor Author

Hi @remi-or. I've addressed your comment. Please review the new change. Thanks!

@jiqing-feng
Contributor Author

Hi @remi-or , since this PR has been open for a while, I’m hoping we can wrap it up today. I’ll be online for the next few hours to address your feedback immediately. The failed CI is not related to my changes.

@jiqing-feng jiqing-feng requested a review from remi-or January 29, 2026 13:04
Refactor the initialization of _graphs to simplify the condition for using CUDA graphs.
@remi-or
Collaborator

remi-or commented Jan 29, 2026

Hi, I just modified something related to the _graphs attribute, because perhaps my last comment was unclear. Testing and merging if tests pass.
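
For context, the refactor plausibly looks like this (an assumed shape, not the PR's exact code): the attribute itself encodes whether CUDA graphs are enabled, so a single None check gates graph capture everywhere.

import torch

class DecoderState:
    def __init__(self, use_cuda_graph=None):
        if use_cuda_graph is None:
            use_cuda_graph = torch.cuda.is_available()
        # None means "CUDA graphs disabled"; a dict means "enabled, keyed by shape"
        self._graphs = {} if use_cuda_graph else None

    @property
    def use_cuda_graph(self):
        return self._graphs is not None  # the single condition used everywhere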

Collaborator

@remi-or remi-or left a comment

LGTM! Thanks for all the work you put into this. Please commit the 2 suggestions before merging! One is needed, the other will be very useful.

remi-or and others added 4 commits January 29, 2026 15:55
@jiqing-feng
Copy link
Contributor Author

LGTM! Thanks for all the work you put into this. Please commit the 2 suggestions before merging! One is needed, the other will be very useful.

Hi @remi-or. I have committed your suggestions. Thanks!

@Cyrilvallez Cyrilvallez merged commit 071e178 into huggingface:main Jan 29, 2026
19 of 25 checks passed