[Attention] feat: add support for Context Parallelism #1521
@cursor review
    == 0, (
        f"tp_size={self.tensor_parallel_size} must be divisible by"
        f"dcp_size={self.decode_context_parallel_size}."
    )
/gemini review
Code Review
This pull request introduces Context Parallelism (CP) to support larger batch sizes and increase concurrency, especially for models with Grouped-Query Attention (GQA) and Multi-Query Attention (MQA). The implementation is comprehensive, touching configuration, distributed state, attention backends, and KV cache management. The core idea of splitting the context across TP ranks and combining partial attention results using log-sum-exp is well-implemented, particularly in the new flash_attn and mla backends.
My review focuses on ensuring correctness, maintainability, and identifying potential issues. I've found a critical issue regarding an undefined attribute, a high-severity issue with class variable modification that could lead to subtle bugs, and several medium-severity issues related to code duplication and typos in comments and error messages. Overall, this is a solid implementation of a complex feature.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
/gemini review
Code Review
This pull request introduces Context Parallelism (CP) for models using MLA, GQA, and MQA, which is a significant feature enhancement. The changes are extensive, touching distributed state management, attention backends, and CUDA kernels. The implementation correctly reuses GPUs from the Tensor Parallelism group and handles the necessary logic for splitting context KV cache and merging partial attention results using log-sum-exp. The code is well-structured, and the addition of checks and assertions for the new parallelism dimension is good. My main feedback is on an unconventional use of __new__ for initialization in an abstract base class, which could be refactored for better clarity and maintainability. I also have a suggestion to improve the clarity of a complex function by adding a comment.
    def __new__(cls, *args, **kwargs):
        # use __new__ so that all subclasses will call this
        self = super().__new__(cls)
        try:
            from aphrodite.distributed.parallel_state import get_dcp_group
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        self.need_to_return_lse_for_decode = self.dcp_world_size > 1 \
            and self.can_return_lse_for_decode
        return self
Using __new__ for instance initialization is unconventional and can be confusing for developers. While it achieves the goal of ensuring this logic is run for all subclasses without them needing to call super().__init__(), it's generally better to stick to standard Python idioms for maintainability. The __init__ method is the standard place for initialization logic.
A more idiomatic approach would be to move this logic into __init__ and have all subclasses explicitly call super().__init__(). If you want to enforce this initialization without relying on subclasses, a metaclass or a class decorator would be a more explicit and less surprising pattern.
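As a rough illustration of the `__init__`-based alternative, the sketch below moves the same setup into the base class constructor and has the subclass call `super().__init__()` explicitly. The names `AttentionImplBase`, `FlashImpl`, and `get_dcp_info` are hypothetical stand-ins invented for this example and are not taken from the PR; in particular, `get_dcp_info` only mimics a process-group lookup and returns fixed values.

```python
from abc import ABC


def get_dcp_info():
    """Hypothetical stand-in for the process-group lookup.

    Returns (world_size, rank_in_group); fixed values for illustration only.
    """
    return 2, 1


class AttentionImplBase(ABC):
    # Subclasses set this to True when their kernel can return the log-sum-exp.
    can_return_lse_for_decode = False

    def __init__(self):
        # Same setup the __new__ hook performed, now in the conventional place.
        self.dcp_world_size, self.dcp_rank = get_dcp_info()
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )


class FlashImpl(AttentionImplBase):
    can_return_lse_for_decode = True

    def __init__(self, num_heads):
        super().__init__()  # explicit call replaces the implicit __new__ hook
        self.num_heads = num_heads


impl = FlashImpl(num_heads=8)
print(impl.dcp_world_size, impl.dcp_rank, impl.need_to_return_lse_for_decode)
```

The trade-off is that every subclass must remember the `super().__init__()` call, which is exactly what the `__new__` approach was trying to avoid.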
    if rank > cp_target_rank and cp_chunk_seq_len:
        real_cp_chunk_seq_len = cp_chunk_seq_len - 1
    else:
        real_cp_chunk_seq_len = cp_chunk_seq_len
This logic to determine real_cp_chunk_seq_len is a bit subtle. A comment explaining why cp_chunk_seq_len is sometimes reduced by 1 would greatly improve the readability and maintainability of this complex function. It seems related to how context lengths are distributed among CP ranks, but the reasoning isn't immediately obvious.
Adds support for Context Parallelism for models using MLA, GQA, and MQA. You can enable it by using the -cp flag. Currently, CP applies only to the decode phase, not prefill.

Limitations:
- tp_size needs to be divisible by cp_size
- tp_size > num_key_value_heads

This is because CP re-uses the GPUs from TP, and doesn't require more GPUs.
For supporting GQA/MQA models, we first need to separately compute the attention scores for the context and query KV within a sequence and then merge the results. For the query, no collective communication is required among the CP group, and for context, the KV is distributed across different CP ranks.
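The merge step can be illustrated with a small pure-Python sketch of combining partial attention results via log-sum-exp. This is not the PR's kernel code: the helper names `attend` and `merge` are made up for the example, the usual `1/sqrt(d)` score scaling is omitted for brevity, and the two "shards" simply stand in for KV held on two CP ranks.

```python
import math


def attend(q, keys, values):
    """Softmax attention for one query over one KV shard.

    Returns (output vector, log-sum-exp of the raw scores).
    """
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(dim)]
    return out, m + math.log(total)


def merge(partials):
    """Combine per-shard (output, lse) pairs into the full attention output."""
    m = max(lse for _, lse in partials)
    # Each shard is weighted by its share of the global softmax denominator.
    scales = [math.exp(lse - m) for _, lse in partials]
    total = sum(scales)
    dim = len(partials[0][0])
    return [sum(s * out[d] for s, (out, _) in zip(scales, partials)) / total
            for d in range(dim)]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

# Reference: attention over the whole KV on a single rank.
full, _ = attend(q, keys, values)

# Pretend two CP ranks each hold half of the KV cache.
shard_a = attend(q, keys[:2], values[:2])
shard_b = attend(q, keys[2:], values[2:])
merged = merge([shard_a, shard_b])

print(all(abs(a - b) < 1e-9 for a, b in zip(full, merged)))
```

Because each shard only needs to exchange its output vector and a single log-sum-exp value per query, the merged result matches full attention without any rank having to see the other ranks' KV.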
With a random init test model based on Qwen3-0.6B but with 1 kv head rather than 8, and 2x 5060 Ti:
TP=2:
TP=2 CP=2: