bugfix: fix chrono-edit context parallel #12660
Fixes #12661: a crash in ChronoEdit when running with context parallelism.
We need to disable the splitting of encoder_hidden_states because the image_encoder always produces 257 tokens for image_embed. After concatenation with the 512 text tokens, encoder_hidden_states therefore always has 769 tokens (512 + 257), which is not divisible by the number of devices in the CP group.
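For concreteness, here is a minimal sketch of the failing divisibility condition, using the token counts from the description above (variable names are illustrative, not taken from the diffusers code):

```python
text_tokens = 512    # tokens from the text encoder
image_tokens = 257   # image_encoder always emits 257 tokens for image_embed
total_tokens = text_tokens + image_tokens  # 769 after concatenation

cp_world_size = 2    # e.g. two GPUs in the context-parallel group
# 769 is odd, so an even split across the CP ranks is impossible and the
# pipeline crashes when it tries to shard encoder_hidden_states:
assert total_tokens % cp_world_size != 0
```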
Since the key/value in cross-attention depend solely on encoder_hidden_states (text or image), the (q_chunk * k) * v computation can be parallelized independently, so there is no need to pass parallel_config for cross-attention at all. This change removes redundant all-to-all communications, (3 + 1) × 2 = 8 per block for the two cross-attention operations (text and image), improving ChronoEdit's performance under context parallelism. With this optimization alone, I measured a nearly 1.85× speedup on two L20 GPUs, without relying on other optimizations such as torch.compile.
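To illustrate why no all-to-all is needed, here is a minimal sketch (not the diffusers implementation; the function and parameter names are hypothetical): because K and V are projected entirely from the replicated encoder_hidden_states, each rank can compute its output shard from its local query chunk alone.

```python
import torch

def cross_attention_on_rank(q_chunk, encoder_hidden_states, to_k, to_v):
    # q_chunk: this rank's shard of the query sequence (CP splits queries).
    # encoder_hidden_states: the full, unsplit text+image context, so the
    # projected K and V are complete on every rank.
    k = to_k(encoder_hidden_states)
    v = to_v(encoder_hidden_states)
    scale = q_chunk.shape[-1] ** -0.5
    # softmax((q_chunk @ k^T) * scale) @ v only reads this rank's query rows,
    # and its result is exactly this rank's output shard, so no all-to-all
    # (neither to gather k/v nor to redistribute the output) is required.
    attn = torch.softmax(q_chunk @ k.transpose(-2, -1) * scale, dim=-1)
    return attn @ v
```

Each rank's rows of the attention output depend only on its own queries, which is what makes skipping parallel_config for cross-attention safe.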
@sayakpaul @yiyixuxu @DN6
Reproduce