Skip to content

[misc] attention hot-path cleanup + denoising loop hoists#1272

Merged
SolitaryThinker merged 7 commits into
hao-ai-lab:mainfrom
FoundationResearch:feature/kernel/vsa-attention-hot-path
May 10, 2026
Merged

[misc] attention hot-path cleanup + denoising loop hoists#1272
SolitaryThinker merged 7 commits into
hao-ai-lab:mainfrom
FoundationResearch:feature/kernel/vsa-attention-hot-path

Conversation

@alexzms
Copy link
Copy Markdown
Collaborator

@alexzms alexzms commented May 2, 2026

Summary

Small cleanup pass on the attention backends + denoising loop. Numerics bit-exact across 8 Wan2.1 runs, modest e2e improvement (within run-to-run noise on a 14s inference).

  • bsa_attn: drop .item() host/device sync in _flash_attn_single_* (use Python max() over the cu_seqlens list)
  • sla: drop redundant .contiguous() after softmax/elu/relu feature maps (elementwise, layout already preserved)
  • flash_attn: cache mask conversions on metadata so all attention layers in one denoising step share one transformation
  • video_sparse_attn: reuse padded buffer in tile() across layers/steps; precompute combined index for untile() (one fancy index, not two)
  • denoising: hoist guidance_expand tensor / use_meanflow getattr / V2V zero-pad allocation out of the per-step loop
  • conditioning: drop the unreachable post-return block in ConditioningStage.forward (CFG is handled inside DenoisingStage)

Test plan

  • pytest fastvideo/tests/ — to run before merging
  • Wan2.1-T2V-1.3B e2e: hash bit-exact (sha1 unchanged across 8 runs)

@mergify mergify Bot added type: misc Cleanup, config, dependencies scope: inference Inference pipeline, serving, CLI scope: attention Attention backends (VSA, STA, Flash, etc.) labels May 2, 2026
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 2, 2026

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 PR merge requirements

Waiting for

  • check-success=fastcheck-passed
This rule is failing.
  • check-success=fastcheck-passed
  • #approved-reviews-by>=1
  • check-success=full-suite-passed
  • check-success~=pre-commit
  • title~=(?i)^\[(feat|feature|bugfix|fix|refactor|perf|ci|doc|docs|misc|chore|kernel|new.?model|skill|skills|infra)\]

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces several performance optimizations across the attention backends and the denoising pipeline, including hoisting constant calculations out of loops, reducing host-device synchronizations, and implementing mask caching. Feedback highlights that the use of a class-level shared buffer in VideoSparseAttentionImpl is not thread-safe and relies on a fragile assumption regarding zero-padding that could lead to numerical discrepancies. Additionally, the mask caching optimization for Flash Attention appears currently unreachable due to an incomplete backend implementation and a lack of integration within the denoising stage.

Comment thread fastvideo/attention/backends/video_sparse_attn.py Outdated
Comment thread fastvideo/attention/backends/flash_attn.py Outdated
Comment thread fastvideo/attention/backends/video_sparse_attn.py Outdated
alexzms added a commit to FoundationResearch/FastVideo that referenced this pull request May 2, 2026
Addresses review comments on PR hao-ai-lab#1272:

- Class-level shared `_tile_buf` was not thread-safe — concurrent
  inference requests in the same process would clobber each other.
- Reusing the buffer across metadata instances also made the "pad
  positions stay zero" invariant fragile: if non_pad_index ever
  shifted between calls of the same shape, stale non-zero data would
  leak into what are now padding positions.

Both concerns dissolve by scoping the buffer to a single metadata
object (built fresh per denoising step, lifetime ends with the step).
Numerics bit-exact across the Wan2.1 + VSA bench.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@alexzms alexzms marked this pull request as ready for review May 2, 2026 00:51
@alexzms alexzms force-pushed the feature/kernel/vsa-attention-hot-path branch from 6b8e8c2 to 027d2d8 Compare May 2, 2026 02:37
alexzms added a commit to FoundationResearch/FastVideo that referenced this pull request May 2, 2026
Addresses review comments on PR hao-ai-lab#1272:

- Class-level shared `_tile_buf` was not thread-safe — concurrent
  inference requests in the same process would clobber each other.
- Reusing the buffer across metadata instances also made the "pad
  positions stay zero" invariant fragile: if non_pad_index ever
  shifted between calls of the same shape, stale non-zero data would
  leak into what are now padding positions.

Both concerns dissolve by scoping the buffer to a single metadata
object (built fresh per denoising step, lifetime ends with the step).
Numerics bit-exact across the Wan2.1 + VSA bench.
@alexzms alexzms added the ready PR is ready to merge label May 3, 2026
Comment thread fastvideo/attention/backends/sla.py
Copy link
Copy Markdown
Collaborator

@Satyam-53 Satyam-53 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactoring and performance changes look good to me except for one of the code files. PTAL at the comment for conditioning.py file.

Comment thread fastvideo/attention/backends/flash_attn.py
Comment thread fastvideo/attention/backends/bsa_attn.py
Comment thread fastvideo/pipelines/stages/conditioning.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good to me.

Comment thread fastvideo/attention/backends/video_sparse_attn.py
@alexzms
Copy link
Copy Markdown
Collaborator Author

alexzms commented May 8, 2026

Thanks for reviewing!

alexzms added 2 commits May 8, 2026 23:20
- bsa_attn: replace .item() with python max() for max_seqlen_k
  (avoids host/device sync per active-block flash call)
- sla: drop redundant .contiguous() after softmax/elu/relu feature maps
- flash_attn: cache mask conversions on metadata so 30+ attention
  layers in one denoising step share one transformation
- video_sparse_attn: reuse padded buffer in tile() across layers/steps,
  precompute combined index for untile() (one fancy index, not two)
- denoising: hoist guidance_expand tensor / use_meanflow getattr /
  V2V zero-pad allocation out of the per-step loop
- conditioning: drop unreachable post-return block in
  ConditioningStage.forward (CFG is handled in DenoisingStage)

Numerics bit-exact across 8 Wan2.1 inference runs. Modest perf
improvement (~1-4% e2e on Wan2.1-1.3B, within run-to-run noise).
Addresses review comments on PR hao-ai-lab#1272:

- Class-level shared `_tile_buf` was not thread-safe — concurrent
  inference requests in the same process would clobber each other.
- Reusing the buffer across metadata instances also made the "pad
  positions stay zero" invariant fragile: if non_pad_index ever
  shifted between calls of the same shape, stale non-zero data would
  leak into what are now padding positions.

Both concerns dissolve by scoping the buffer to a single metadata
object (built fresh per denoising step, lifetime ends with the step).
Numerics bit-exact across the Wan2.1 + VSA bench.
@alexzms alexzms force-pushed the feature/kernel/vsa-attention-hot-path branch from 027d2d8 to 960d5e9 Compare May 8, 2026 23:20
@alexzms
Copy link
Copy Markdown
Collaborator Author

alexzms commented May 8, 2026

Thanks @Satyam-53!

Pushed an update — comment now points to the actual CFG sites (denoising.py:364-394, :706, :930; all separate-forward passes, fastvideo never used the batched-cat pattern, so the old code in forward() was diffusers-style residue rather than the active path).

Re removing the class: only forward() is no-op — verify_input/verify_output still validate CFG fields and handle the Matrix-Game "no prompt embeds → disable CFG" fallback, which isn't replicated elsewhere at this point in the pipeline. So I'd keep the class. Happy to consolidate later as a follow-up if you think it's worth doing.

alexzms added a commit to FoundationResearch/FastVideo that referenced this pull request May 8, 2026
Review feedback (hao-ai-lab#1272): point to the actual CFG sites in
DenoisingStage and explain why the class is kept around.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Satyam-53
Copy link
Copy Markdown
Collaborator

Thanks @Satyam-53!

Pushed an update — comment now points to the actual CFG sites (denoising.py:364-394, :706, :930; all separate-forward passes, fastvideo never used the batched-cat pattern, so the old code in forward() was diffusers-style residue rather than the active path).

Re removing the class: only forward() is no-op — verify_input/verify_output still validate CFG fields and handle the Matrix-Game "no prompt embeds → disable CFG" fallback, which isn't replicated elsewhere at this point in the pipeline. So I'd keep the class. Happy to consolidate later as a follow-up if you think it's worth doing.

looks good to me now. We can keep the class for now.

@Satyam-53 Satyam-53 self-requested a review May 9, 2026 00:11
Copy link
Copy Markdown
Collaborator

@Satyam-53 Satyam-53 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have marked it for the approval, maybe please take an approval from @SolitaryThinker as well before merging?

Review feedback (hao-ai-lab#1272): point to the actual CFG sites in
DenoisingStage and explain why the class is kept around.
@alexzms alexzms force-pushed the feature/kernel/vsa-attention-hot-path branch from f23ef87 to 657207c Compare May 9, 2026 00:13
@SolitaryThinker
Copy link
Copy Markdown
Collaborator

Hi Alex, this is from my agent setup's /review output. Could you double check to see if this makes sense?


Issue 1: _mask_cache is paid-for-but-unused in current usage

In fastvideo/attention/backends/flash_attn.py, the new _mask_cache: dict | None = None field on FlashAttnMetadata only delivers savings if a single FlashAttnMetadata instance is reused across multiple attn_impl.forward() calls. Every current producer of FlashAttnMetadata builds a fresh instance per attention block call and uses it exactly once:

  • fastvideo/models/dits/hunyuanvideo15.py:702-709FlashAttnMetadataBuilder().build(...) inside the per-block forward, scoped to one set_forward_context block over a single self.attn(...) call.
  • fastvideo/models/dits/hyworld/hyworld.py:176-192 and :196-212 — same pattern; the prope path even builds a second metadata for the same block.

DistributedAttention.forward calls self.attn_impl.forward(...) exactly once per call, and within a single forward the cross-attn vs self-attn paths set disjoint cache keys. So in current usage the cache is never hit, even within a single call.

Net effect: the field, the lazy-init guard (if attn_metadata._mask_cache is None: attn_metadata._mask_cache = {}), and three if … not in cache lookups run on every forward but never return a cached value. Functionally correct, but the bit-exactness on Wan2.1 is also consistent with "this code path was never exercised by Wan2.1" (Wan uses VSA, not FlashAttn-with-mask). The hot-path saving the description advertises will only materialize after a corresponding caller change in hunyuanvideo15.py / hyworld.py to reuse one metadata across blocks.

Worth either (a) doing that caller change in this PR so the optimization is actually load-bearing, or (b) trimming the cache and keeping just the helper-extraction.

There is also a latent fragility: the ("kpad", key_len) key does not include a fingerprint of attn_mask. If a future caller does reuse one metadata across blocks AND mutates attn_metadata.attn_mask in-place between calls, the cached key_padding_mask will be stale. Not a bug today, but worth a comment in the dataclass that the cache assumes attn_mask is immutable for the metadata's lifetime.


Issue 2: tile_buf reuse establishes an implicit "consume before next call" contract

In fastvideo/attention/backends/video_sparse_attn.py, tile() now returns the shared attn_metadata.tile_buf rather than a fresh allocation. The PR's correctness argument — pad positions stay zero across reuses because non_pad_index is fixed on the metadata — holds, since non_pad_index is set once in build() from an lru_cache'd helper.

The remaining concern is aliasing on the caller side. Both VSA call paths I checked are safe today:

  • DistributedAttention_VSA.forward (fastvideo/attention/layer.py:204-207): qkvg = preprocess_qkv(...)chunk(4)attn_impl.forward(q, k, v, gate_compress, ...), where forward immediately runs .transpose(1, 2).contiguous() on each chunk, materializing a copy before returning.
  • fastvideo/models/dits/ltx2.py:1084-1088: same pattern.

Both consume the buf and copy it before any other layer's preprocess_qkv runs, so the reuse is safe under sequential single-stream execution. But this is now an unstated invariant of the backend: any future caller that holds the tile() return value across another VSA layer's preprocess_qkv call (or that runs them concurrently on different CUDA streams) will silently get corrupted data.

Consider documenting on tile() (or preprocess_qkv) that the returned tensor is only valid until the next VSA-layer call on the same metadata. Same comment applies more weakly to the _mask_cache field — worth noting the cache assumes attn_metadata.attn_mask is not mutated for the metadata's lifetime.


No correctness blockers — just a gap between claim and reality on Issue 1, and an implicit contract worth documenting on Issue 2.

mergify Bot and others added 3 commits May 9, 2026 19:21
… contract

Review feedback (hao-ai-lab#1272): the per-metadata _mask_cache field never
hits in current usage because every caller (hunyuanvideo15.py:702,
hyworld.py:177, :197) builds a fresh FlashAttnMetadata per attention
block call.  Trim the field, lazy-init guard, and three lookup
branches; keep the _key_padding_mask_from_attn_mask helper extraction
which still simplifies the call site.  Making the cache load-bearing
would require sharing one metadata across blocks at the caller side
and is left to a follow-up.

Also document on VideoSparseAttentionBackend.tile() / preprocess_qkv()
that the returned tensor aliases attn_metadata.tile_buf and is only
valid until the next VSA-layer call on the same metadata.  Both
existing call sites copy via .transpose(...).contiguous() inside
forward() so the contract holds today; the docstring fixes the
implicit-invariant gap surfaced in review.
@alexzms
Copy link
Copy Markdown
Collaborator Author

alexzms commented May 10, 2026

Thanks @SolitaryThinker — yep, both points hold up. Just pushed 0d2691d1 to address them:

Issue 1 (_mask_cache) — agreed, every current caller (hunyuanvideo15.py:702, hyworld.py:177 & :197) builds a fresh FlashAttnMetadata per attention block, so the cache never hits. Dropped the field, lazy-init guard, and the three lookup branches; kept _key_padding_mask_from_attn_mask as a top-level helper since the extraction stands on its own. Making the cache load-bearing would need a caller-side refactor to share one metadata across blocks — happy to do that as a follow-up rather than expand scope here.

Issue 2 (tile_buf contract) — added a docstring on tile() and preprocess_qkv() spelling out that the returned tensor aliases attn_metadata.tile_buf and is only valid until the next VSA-layer call on the same metadata, with a note that the current call sites copy via .transpose(...).contiguous() inside forward() so the contract holds today.

@SolitaryThinker SolitaryThinker merged commit 636d3b7 into hao-ai-lab:main May 10, 2026
15 of 18 checks passed
@SolitaryThinker SolitaryThinker deleted the feature/kernel/vsa-attention-hot-path branch May 10, 2026 18:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready PR is ready to merge scope: attention Attention backends (VSA, STA, Flash, etc.) scope: inference Inference pipeline, serving, CLI type: misc Cleanup, config, dependencies

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants