
[Attention] feat: add support for Context Parallelism#1521

Merged: AlpinDale merged 9 commits into main from context_parallel on Sep 20, 2025

@AlpinDale

Member

Adds support for Context Parallelism (CP) for models using MLA, GQA, and MQA. You can enable it with the -cp flag. Currently, CP is applied only to the decode phase, not to prefill.

Limitations:

  • tp_size must be divisible by cp_size
  • tp_size must be greater than num_key_value_heads

This is because CP re-uses the GPUs from TP, and doesn't require more GPUs.
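
The two limitations above can be expressed as a small validation helper. This is only a sketch: the function name `check_cp_config` and its parameters are hypothetical and are not taken from the PR.

```python
def check_cp_config(tp_size: int, cp_size: int, num_key_value_heads: int) -> None:
    """Hypothetical helper mirroring the two limitations listed in the PR text."""
    # CP reuses the TP GPUs, so the TP group must split evenly into CP groups.
    if tp_size % cp_size != 0:
        raise ValueError(
            f"tp_size={tp_size} must be divisible by cp_size={cp_size}."
        )
    # The PR lists tp_size > num_key_value_heads as a requirement.
    if tp_size <= num_key_value_heads:
        raise ValueError(
            f"tp_size={tp_size} must be greater than "
            f"num_key_value_heads={num_key_value_heads}."
        )


# The benchmark setup described in this PR: TP=2, CP=2, one KV head.
check_cp_config(tp_size=2, cp_size=2, num_key_value_heads=1)
```

With these checks, a configuration such as TP=4 with CP=3, or TP=2 with 8 KV heads, would be rejected before any GPU work starts.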

To support GQA/MQA models, we first need to compute the attention for the context KV and for the query KV of a sequence separately, and then merge the two results. The query part requires no collective communication within the CP group; the context KV is distributed across the CP ranks.

With a randomly initialized test model based on Qwen3-0.6B, but with 1 KV head instead of 8, on 2x 5060 Ti:
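
The "compute separately, then merge" step can be illustrated with a log-sum-exp (LSE) merge, which the review comments in this thread also mention. The following is a minimal pure-Python sketch for a single query vector, not the PR's kernel code: the function names are hypothetical, the 1/sqrt(d) scaling is omitted, and the shards stand in for KV held by different CP ranks.

```python
import math


def partial_attention(q, keys, values):
    """Softmax attention of one query over one KV shard.

    Returns the shard-local output and the log-sum-exp of its scores.
    """
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]
    return out, m + math.log(denom)


def merge_partials(partials):
    """Combine (output, lse) pairs from several shards into one output."""
    m = max(lse for _, lse in partials)
    # exp(lse - m) is proportional to each shard's softmax denominator.
    scales = [math.exp(lse - m) for _, lse in partials]
    total = sum(scales)
    dim = len(partials[0][0])
    return [
        sum(s * out[d] for s, (out, _) in zip(scales, partials)) / total
        for d in range(dim)
    ]


q = [1.0, 0.5]
keys = [[0.2, 0.1], [1.0, -0.3], [-0.5, 0.8], [0.4, 0.4]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]

# Reference: attention over all keys at once.
full, _ = partial_attention(q, keys, values)

# Two shards, as if two ranks each held half of the KV.
merged = merge_partials([
    partial_attention(q, keys[:2], values[:2]),
    partial_attention(q, keys[2:], values[2:]),
])
```

Here `merged` matches `full` up to floating-point error, which is why each rank only has to exchange its partial output and one LSE value per query rather than its KV entries.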

TP=2:

GPU KV cache size: 949,296 tokens (12.67 GiB, 115.9x concurrency)

TP=2 CP=2:

GPU KV cache size: 1,896,256 tokens (12.66 GiB, 231.5x concurrency)
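
A quick arithmetic check on the two log lines above. It assumes that the GiB figure refers to the same memory pool in both runs, which the logs themselves do not state.

```python
GIB = 2**30

# Figures copied from the two log lines above.
tp_only = {"tokens": 949_296, "gib": 12.67}
tp_cp = {"tokens": 1_896_256, "gib": 12.66}

# Token capacity gained by enabling CP=2 on the same two GPUs.
capacity_ratio = tp_cp["tokens"] / tp_only["tokens"]

# Cache bytes consumed per token in each configuration.
bytes_per_token_tp = tp_only["gib"] * GIB / tp_only["tokens"]
bytes_per_token_cp = tp_cp["gib"] * GIB / tp_cp["tokens"]

print(f"capacity ratio: {capacity_ratio:.4f}")
print(f"bytes/token TP=2:      {bytes_per_token_tp:.0f}")
print(f"bytes/token TP=2 CP=2: {bytes_per_token_cp:.0f}")
```

The token capacity roughly doubles while the cache size stays essentially the same, so under the stated assumption the per-token cache footprint is about halved.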

@AlpinDale

Member Author

@cursor review

== 0, (
f"tp_size={self.tensor_parallel_size} must be divisible by"
f"dcp_size={self.decode_context_parallel_size}."
)

Bug: Missing Space in Assertion Message

The assertion message for tensor_parallel_size divisibility by decode_context_parallel_size is missing a space. It currently displays as "must be divisible bydcp_size" instead of "must be divisible by dcp_size", affecting readability.
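
The missing space comes from Python joining adjacent string literals with nothing in between. A small reproduction with made-up values (4 and 3 are illustrative only):

```python
tp_size, dcp_size = 4, 3

# Adjacent f-strings are concatenated as-is, so the words run together.
buggy = (
    f"tp_size={tp_size} must be divisible by"
    f"dcp_size={dcp_size}."
)

# A trailing space on the first literal fixes the message.
fixed = (
    f"tp_size={tp_size} must be divisible by "
    f"dcp_size={dcp_size}."
)

print(buggy)
print(fixed)
```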

@AlpinDale

Member Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces Context Parallelism (CP) to support larger batch sizes and increase concurrency, especially for models with Grouped-Query Attention (GQA) and Multi-Query Attention (MQA). The implementation is comprehensive, touching configuration, distributed state, attention backends, and KV cache management. The core idea of splitting the context across TP ranks and combining partial attention results using log-sum-exp is well-implemented, particularly in the new flash_attn and mla backends.

My review focuses on ensuring correctness, maintainability, and identifying potential issues. I've found a critical issue regarding an undefined attribute, a high-severity issue with class variable modification that could lead to subtle bugs, and several medium-severity issues related to code duplication and typos in comments and error messages. Overall, this is a solid implementation of a complex feature.

Comment thread aphrodite/v1/attention/backends/flashinfer.py
Comment thread aphrodite/v1/attention/backends/mla/flashattn_mla.py
Comment thread aphrodite/v1/attention/backends/flash_attn.py
Comment thread aphrodite/v1/attention/backends/mla/common.py
Comment thread aphrodite/v1/attention/backends/mla/common.py Outdated
Comment thread aphrodite/v1/attention/backends/mla/common.py Outdated
Comment thread aphrodite/v1/attention/backends/mla/common.py Outdated
AlpinDale and others added 4 commits September 18, 2025 11:29
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@AlpinDale

Member Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces Context Parallelism (CP) for models using MLA, GQA, and MQA, which is a significant feature enhancement. The changes are extensive, touching distributed state management, attention backends, and CUDA kernels. The implementation correctly reuses GPUs from the Tensor Parallelism group and handles the necessary logic for splitting context KV cache and merging partial attention results using log-sum-exp. The code is well-structured, and the addition of checks and assertions for the new parallelism dimension is good. My main feedback is on an unconventional use of __new__ for initialization in an abstract base class, which could be refactored for better clarity and maintainability. I also have a suggestion to improve the clarity of a complex function by adding a comment.

Comment on lines +267 to +280
def __new__(cls, *args, **kwargs):
    # use __new__ so that all subclasses will call this
    self = super().__new__(cls)
    try:
        from aphrodite.distributed.parallel_state import get_dcp_group
        self.dcp_world_size = get_dcp_group().world_size
        self.dcp_rank = get_dcp_group().rank_in_group
    except AssertionError:
        # DCP might not be initialized in testing
        self.dcp_world_size = 1
        self.dcp_rank = 0
    self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
        and self.can_return_lse_for_decode
    return self

Contributor

medium

Using __new__ for instance initialization is unconventional and can be confusing for developers. While it achieves the goal of ensuring this logic is run for all subclasses without them needing to call super().__init__(), it's generally better to stick to standard Python idioms for maintainability. The __init__ method is the standard place for initialization logic.

A more idiomatic approach would be to move this logic into __init__ and have all subclasses explicitly call super().__init__(). If you want to enforce this initialization without relying on subclasses, a metaclass or a class decorator would be a more explicit and less surprising pattern.
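
A rough sketch of the `__init__`-based alternative described above. All names here are hypothetical: `lookup_dcp_state` stands in for querying the real DCP group and always reports an uninitialized group, as in a test environment.

```python
from abc import ABC


def lookup_dcp_state():
    """Hypothetical stand-in for the DCP group lookup; returns (world_size, rank)."""
    return 1, 0


class AttentionImplBase(ABC):
    # Subclasses override this when their decode path can return the LSE.
    can_return_lse_for_decode = False

    def __init__(self):
        self.dcp_world_size, self.dcp_rank = lookup_dcp_state()
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )


class ExampleImpl(AttentionImplBase):
    can_return_lse_for_decode = True

    def __init__(self, num_heads):
        # The explicit call is what the __new__ approach was avoiding.
        super().__init__()
        self.num_heads = num_heads


impl = ExampleImpl(num_heads=8)
```

The trade-off is the one the comment names: this form is conventional, but a subclass that forgets `super().__init__()` silently ends up without the DCP attributes.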

Comment on lines +881 to +884
if rank > cp_target_rank and cp_chunk_seq_len:
    real_cp_chunk_seq_len = cp_chunk_seq_len - 1
else:
    real_cp_chunk_seq_len = cp_chunk_seq_len

Contributor

medium

This logic to determine real_cp_chunk_seq_len is a bit subtle. A comment explaining why cp_chunk_seq_len is sometimes reduced by 1 would greatly improve the readability and maintainability of this complex function. It seems related to how context lengths are distributed among CP ranks, but the reasoning isn't immediately obvious.
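
One plausible reading of the subtraction, offered only as a sketch: if tokens are assigned to CP ranks round-robin, `cp_chunk_seq_len` is the ceiling of the sequence length divided by the CP world size, and `cp_target_rank` is the rank holding the last token, then every rank after the target rank holds one token fewer. None of these definitions appear in the quoted hunk, so they are assumptions, and both helper functions below are hypothetical.

```python
def per_rank_lengths(seq_len: int, cp_world_size: int) -> list[int]:
    """Per-rank token counts using the same branch as the quoted hunk."""
    # Assumed definitions (not shown in the quoted diff):
    cp_chunk_seq_len = -(-seq_len // cp_world_size)  # ceiling division
    cp_target_rank = (seq_len - 1) % cp_world_size   # rank of the last token
    lengths = []
    for rank in range(cp_world_size):
        if rank > cp_target_rank and cp_chunk_seq_len:
            lengths.append(cp_chunk_seq_len - 1)
        else:
            lengths.append(cp_chunk_seq_len)
    return lengths


def round_robin_lengths(seq_len: int, cp_world_size: int) -> list[int]:
    """Reference: count tokens when token i is stored on rank i % cp_world_size."""
    counts = [0] * cp_world_size
    for token_idx in range(seq_len):
        counts[token_idx % cp_world_size] += 1
    return counts


# Five tokens over two ranks: rank 0 holds tokens 0, 2, 4 and rank 1 holds 1, 3.
example = per_rank_lengths(5, 2)
```

Under these assumptions the two functions agree for every sequence length, including zero, where the `and cp_chunk_seq_len` guard keeps the count from going negative.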

@AlpinDale merged commit 42b4b4a into main on Sep 20, 2025
0 of 4 checks passed