[misc] attention hot-path cleanup + denoising loop hoists #1272
Code Review
This pull request introduces several performance optimizations across the attention backends and the denoising pipeline, including hoisting constant calculations out of loops, reducing host-device synchronizations, and implementing mask caching. Feedback highlights that the use of a class-level shared buffer in VideoSparseAttentionImpl is not thread-safe and relies on a fragile assumption regarding zero-padding that could lead to numerical discrepancies. Additionally, the mask caching optimization for Flash Attention appears currently unreachable due to an incomplete backend implementation and a lack of integration within the denoising stage.
Addresses review comments on PR hao-ai-lab#1272:

- Class-level shared `_tile_buf` was not thread-safe: concurrent inference requests in the same process would clobber each other.
- Reusing the buffer across metadata instances also made the "pad positions stay zero" invariant fragile: if non_pad_index ever shifted between calls of the same shape, stale non-zero data would leak into what are now padding positions.

Both concerns dissolve by scoping the buffer to a single metadata object (built fresh per denoising step, lifetime ends with the step). Numerics bit-exact across the Wan2.1 + VSA bench.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
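To make the stale-data concern concrete, here is a minimal pure-Python sketch with lists standing in for tensors. `SharedBufferTiler`, `StepMetadata`, and both `tile` functions are hypothetical names for illustration only, not the code in this PR; only `_tile_buf`, `tile_buf`, and `non_pad_index` are taken from the discussion.

```python
from dataclasses import dataclass, field


class SharedBufferTiler:
    """Anti-pattern (hypothetical): one buffer stored on the class, shared by all users."""

    _tile_buf = None  # class-level state, visible to every instance and thread

    def tile(self, values, non_pad_index, padded_len):
        cls = type(self)
        if cls._tile_buf is None or len(cls._tile_buf) != padded_len:
            cls._tile_buf = [0.0] * padded_len
        # Only non-pad slots are written; pad slots are *assumed* to still be zero.
        for src, dst in enumerate(non_pad_index):
            cls._tile_buf[dst] = values[src]
        return cls._tile_buf


@dataclass
class StepMetadata:
    """Hypothetical per-step metadata that owns a freshly zeroed buffer."""

    non_pad_index: list
    padded_len: int
    tile_buf: list = field(init=False)

    def __post_init__(self):
        self.tile_buf = [0.0] * self.padded_len


def tile(values, meta):
    """Write into the buffer owned by this metadata object only."""
    for src, dst in enumerate(meta.non_pad_index):
        meta.tile_buf[dst] = values[src]
    return meta.tile_buf


# Same padded shape on both calls, but the non-pad positions shift.
shared = SharedBufferTiler()
first = list(shared.tile([1.0, 2.0], [0, 1], 4))
leaky = list(shared.tile([3.0, 4.0], [2, 3], 4))  # slots 0 and 1 keep stale data

# A fresh metadata object starts from zeros, so pad slots really are zero.
clean = list(tile([3.0, 4.0], StepMetadata([2, 3], 4)))

print(first, leaky, clean)
```

In this toy setup the second shared call returns stale values in what should be padding, while the per-metadata version does not. The thread-safety concern is analogous: two requests writing into the same class attribute would interleave their writes.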
6b8e8c2 to 027d2d8
Satyam-53 left a comment
The refactoring and performance changes look good to me, except for one file. PTAL at my comment on the conditioning.py file.
Thanks for reviewing!
- bsa_attn: replace .item() with Python max() for max_seqlen_k (avoids host/device sync per active-block flash call)
- sla: drop redundant .contiguous() after softmax/elu/relu feature maps
- flash_attn: cache mask conversions on metadata so 30+ attention layers in one denoising step share one transformation
- video_sparse_attn: reuse padded buffer in tile() across layers/steps, precompute combined index for untile() (one fancy index, not two)
- denoising: hoist guidance_expand tensor / use_meanflow getattr / V2V zero-pad allocation out of the per-step loop
- conditioning: drop unreachable post-return block in ConditioningStage.forward (CFG is handled in DenoisingStage)

Numerics bit-exact across 8 Wan2.1 inference runs. Modest perf improvement (~1-4% e2e on Wan2.1-1.3B, within run-to-run noise).
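As a rough illustration of the bsa_attn item: when the cumulative sequence lengths are already available on the host as a plain Python list, the maximum segment length can be computed without touching the device. The helper name `max_seqlen_from_cu_seqlens` is hypothetical, and the sketch assumes the list is non-decreasing and starts at 0.

```python
def max_seqlen_from_cu_seqlens(cu_seqlens):
    """Longest segment length from host-side cumulative lengths.

    cu_seqlens is assumed to be a plain Python list such as [0, 3, 8, 10],
    where segment i spans cu_seqlens[i] .. cu_seqlens[i + 1].
    Reducing a device tensor and calling .item() on the result would
    instead make the host wait for the device on every call.
    """
    if len(cu_seqlens) < 2:
        raise ValueError("need at least one segment")
    return max(end - start for start, end in zip(cu_seqlens, cu_seqlens[1:]))


max_seqlen_k = max_seqlen_from_cu_seqlens([0, 3, 8, 10])
print(max_seqlen_k)  # segments have lengths 3, 5, 2
```

This only helps when the list already lives on the host; converting a device tensor to a list just to call `max()` would reintroduce the sync.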
027d2d8 to 960d5e9
Thanks @Satyam-53! Pushed an update: comment now points to the actual CFG sites ( Re removing the class: only
Review feedback (hao-ai-lab#1272): point to the actual CFG sites in DenoisingStage and explain why the class is kept around.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Looks good to me now. We can keep the class for now.
Satyam-53 left a comment
I have marked it as approved. Could you please also get an approval from @SolitaryThinker before merging?
f23ef87 to 657207c
Hi Alex, this is from my agent setup's /review output. Could you double check to see if this makes sense? Issue 1:
… contract

Review feedback (hao-ai-lab#1272): the per-metadata _mask_cache field never hits in current usage because every caller (hunyuanvideo15.py:702, hyworld.py:177, :197) builds a fresh FlashAttnMetadata per attention block call. Trim the field, lazy-init guard, and three lookup branches; keep the _key_padding_mask_from_attn_mask helper extraction, which still simplifies the call site. Making the cache load-bearing would require sharing one metadata across blocks at the caller side and is left to a follow-up.

Also document on VideoSparseAttentionBackend.tile() / preprocess_qkv() that the returned tensor aliases attn_metadata.tile_buf and is only valid until the next VSA-layer call on the same metadata. Both existing call sites copy via .transpose(...).contiguous() inside forward() so the contract holds today; the docstring fixes the implicit-invariant gap surfaced in review.
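A toy sketch of why a per-metadata cache cannot hit when each attention block builds its own metadata. `FakeMetadata` and `MaskConverter` are hypothetical stand-ins, and the "conversion" here is a placeholder reduction, not the real mask logic.

```python
class FakeMetadata:
    """Hypothetical stand-in for a per-call attention metadata object."""

    def __init__(self, attn_mask):
        self.attn_mask = attn_mask
        self.cached_key_padding_mask = None  # cache lives and dies with this object


class MaskConverter:
    """Counts how often the (placeholder) mask conversion actually runs."""

    def __init__(self):
        self.conversions = 0

    def key_padding_mask(self, meta):
        if meta.cached_key_padding_mask is None:
            self.conversions += 1
            # Placeholder conversion: a key position is kept if any query attends to it.
            meta.cached_key_padding_mask = [any(col) for col in zip(*meta.attn_mask)]
        return meta.cached_key_padding_mask


attn_mask = [[True, True, False], [True, True, False]]
NUM_BLOCKS = 30

# Pattern described in the commit message: a fresh metadata object per block call.
fresh = MaskConverter()
for _ in range(NUM_BLOCKS):
    fresh.key_padding_mask(FakeMetadata(attn_mask))

# What a load-bearing cache would need: one metadata object shared across blocks.
shared = MaskConverter()
shared_meta = FakeMetadata(attn_mask)
for _ in range(NUM_BLOCKS):
    shared.key_padding_mask(shared_meta)

print(fresh.conversions, shared.conversions)
```

With fresh metadata every call is a miss, so the cache field adds state without saving work; sharing one object across blocks is the caller-side change the commit message defers to a follow-up.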
Thanks @SolitaryThinker! Yep, both points hold up. Just pushed Issue 1 ( Issue 2 (
Summary
Small cleanup pass on the attention backends + denoising loop. Numerics bit-exact across 8 Wan2.1 runs, modest e2e improvement (within run-to-run noise on a 14s inference).
- bsa_attn: drop `.item()` host/device sync in `_flash_attn_single_*` (use Python `max()` over the cu_seqlens list)
- sla: drop redundant `.contiguous()` after `softmax`/`elu`/`relu` feature maps (elementwise, layout already preserved)
- flash_attn: cache mask conversions on metadata so all attention layers in one denoising step share one transformation
- video_sparse_attn: reuse padded buffer in `tile()` across layers/steps; precompute combined index for `untile()` (one fancy index, not two)
- denoising: hoist `guidance_expand` tensor / `use_meanflow` `getattr` / V2V zero-pad allocation out of the per-step loop
- conditioning: drop the unreachable post-return block in `ConditioningStage.forward` (CFG is handled inside `DenoisingStage`)

Test plan
- `pytest fastvideo/tests/` (to run before merging)
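For readers unfamiliar with the "one fancy index, not two" item in the summary, the idea can be sketched in pure Python: two consecutive gathers collapse into one if the indices are composed once up front. All function names and `reverse_tile_index` are hypothetical; only `non_pad_index` appears in the discussion, and lists stand in for tensors.

```python
def gather(seq, index):
    """List analogue of fancy indexing: out[k] = seq[index[k]]."""
    return [seq[i] for i in index]


def untile_two_step(tiled, reverse_tile_index, non_pad_index):
    """Undo the tile ordering, then drop padding: two gathers per call."""
    return gather(gather(tiled, reverse_tile_index), non_pad_index)


def precompute_combined_index(reverse_tile_index, non_pad_index):
    """Compose the two indices once: combined[k] = reverse_tile_index[non_pad_index[k]]."""
    return gather(reverse_tile_index, non_pad_index)


def untile_one_step(tiled, combined_index):
    """Single gather using the precomputed composition."""
    return gather(tiled, combined_index)


tiled = ["c", "a", "pad", "b"]
reverse_tile_index = [1, 3, 0, 2]  # restores the order a, b, c, pad
non_pad_index = [0, 1, 2]          # keeps the first three (non-pad) entries

combined = precompute_combined_index(reverse_tile_index, non_pad_index)
two_step = untile_two_step(tiled, reverse_tile_index, non_pad_index)
one_step = untile_one_step(tiled, combined)
print(combined, two_step, one_step)
```

The composition is exact for pure gathers, which is consistent with the bit-exact numerics reported above; the saving is one intermediate gather and its temporary per call.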