add SP support for _flash_3_varlen_hub backend #13809
| if attn_mask is not None: | ||
| attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) | ||
|
|
||
| (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( | ||
| _prepare_for_flash_attn_or_sage_varlen( | ||
| batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device | ||
| ) | ||
| ) | ||
|
|
||
| key_valid, value_valid = [], [] | ||
| for b in range(batch_size): | ||
| valid_len = seqlens_k[b] | ||
| key_valid.append(key[b, :valid_len]) | ||
| value_valid.append(value[b, :valid_len]) |
It seems this should come under if _parallel_config is None and attn_mask is not None:?
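For context on the loop under discussion, here is a minimal sketch of what varlen packing does, written in plain Python rather than torch. The helper name pack_valid_tokens and the list-based "tensors" are hypothetical stand-ins, not the diffusers helpers, and the sketch assumes right-padded sequences.

```python
from itertools import accumulate


def pack_valid_tokens(padded, valid_lens):
    """Drop right-padding and concatenate the valid tokens of each batch row.

    Returns the packed token list plus cumulative sequence lengths
    (the usual cu_seqlens layout: one boundary per sequence, starting at 0).
    """
    packed = []
    for row, valid_len in zip(padded, valid_lens):
        packed.extend(row[:valid_len])
    cu_seqlens = [0] + list(accumulate(valid_lens))
    return packed, cu_seqlens


# Two sequences padded to length 3, with 2 and 1 valid tokens respectively.
padded_keys = [["k00", "k01", "pad"], ["k10", "pad", "pad"]]
packed, cu_seqlens = pack_valid_tokens(padded_keys, [2, 1])
```

With no mask, every row is fully valid and the slicing changes nothing, which may be part of why the reviewer suggests guarding the loop.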
| max_seqlen_k=max_seqlen_k, | ||
| softmax_scale=scale, | ||
| causal=is_causal, | ||
| return_attn_probs=return_lse, |
This seems like an extra argument?
| return_attn_probs=return_lse, | ||
| ) | ||
| if return_lse: | ||
| out, lse, *_ = out |
Why do we need to initialize lse = None above?
@askserge could you do a review?
🤗 Serge says:
This PR adds sequence parallel (context parallel) support for the _flash_3_varlen_hub attention backend, following the same pattern established by the existing _flash_varlen_hub (flash-attn2) and _flash_3_hub implementations. The code is well-structured and closely mirrors the existing patterns.
Correctness
- Potential bug: indices_k used before assignment on the no-mask path. In _flash_attention_3_varlen_hub_forward_op, when attn_mask is None, the variable indices_k is never assigned, yet line 1721 reads ctx.indices_k = indices_k if attn_mask is not None else None. This is actually fine because the conditional guards it. And if _save_ctx is False, indices_k is never referenced at all on the no-mask path, so there's no issue. This matches the flash-attn2 varlen pattern exactly.
- Positional argument fragility in the wrapped_forward_fn call. The non-varlen _flash_attention_3_hub_forward_op uses keyword arguments for causal, window_size_left, etc., but the new varlen forward op passes everything positionally (lines 1677–1712). This makes the code harder to read and more fragile if the upstream _flash_attn_forward signature changes. Consider using keyword arguments for at least the trailing parameters, consistent with the non-varlen version.
- return_lse handling change in the non-SP path. The original code always unpacked out, lse, *_ from the function call. The new code passes return_attn_probs=return_lse and conditionally unpacks. This is a behavioral change for the non-SP path: if return_lse=False, the function now returns a single tensor instead of a tuple. This should be verified to work correctly with the flash_attn_varlen_func API. The flash-attn2 varlen hub uses return_attn_probs=return_lse similarly, so this is likely correct.
Tests
- The new backend is properly added to ContextParallelAttentionBackendsTesterMixin and the ring_degree skip logic.
- The _FLASH_3_VARLEN_HUB is added to the hub kernels set in tests/models/testing_utils/utils.py.
Minor Issues
- Bug in existing test code (pre-existing). Line 413 (deleted): attention_backend in ("flash_varlen_hub"). Using in with a parenthesized string (not a tuple) means this is just attention_backend in "flash_varlen_hub", which is a substring test rather than membership in a collection of backend names. The fix on line 417 uses a proper tuple, ("flash_varlen_hub", "_flash_3_varlen_hub"), which is correct. The original single-element check was buggy, so it is good that it's fixed now.
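The difference between the two checks can be reproduced in isolation. This is a small standalone sketch, not the test file itself; the value "flash" is a made-up input chosen to expose the substring behavior.

```python
backend = "flash"  # hypothetical value, not a real backend name

# Parentheses without a trailing comma do not create a tuple,
# so this is a str-in-str substring test.
in_parenthesized_string = backend in ("flash_varlen_hub")

# A real tuple compares against whole elements.
in_tuple = backend in ("flash_varlen_hub", "_flash_3_varlen_hub")

# The new backend name is not a substring of the old one, so the
# string form would also have missed it.
new_backend_in_string = "_flash_3_varlen_hub" in ("flash_varlen_hub")
new_backend_in_tuple = "_flash_3_varlen_hub" in ("flash_varlen_hub", "_flash_3_varlen_hub")
```

A single-element tuple needs the trailing comma, as in ("flash_varlen_hub",).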
Overall the implementation follows established patterns well and looks correct.
17 LLM turns · 18 tool calls · 92.5s · 309778 in / 4124 out tokens
| value_packed = value.flatten(0, 1) | ||
| seqlens_k = None | ||
|
|
||
| out_packed, softmax_lse, *_ = wrapped_forward_fn( |
Nit: The non-varlen _flash_attention_3_hub_forward_op uses keyword arguments for the trailing parameters (causal=is_causal, window_size_left=window_size[0], etc.), but here everything is passed positionally with no inline comments explaining what each None corresponds to. This makes the code harder to audit and fragile if the upstream signature changes.
Consider either:
- Using keyword arguments for at least the trailing parameters (like the non-varlen version does), or
- Adding inline comments for the positional None values (like the non-varlen version does with # k_new, v_new, # cu_seqlens_q/k/k_new, etc.)
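To illustrate the fragility concern, here is a minimal sketch with two hypothetical signatures, forward_v1 and forward_v2. Neither is the real _flash_attn_forward signature; the parameter names are borrowed from the discussion purely for illustration.

```python
def forward_v1(q, softcap=0.0, num_splits=1):
    """Hypothetical original signature."""
    return {"softcap": softcap, "num_splits": num_splits}


def forward_v2(q, softcap=0.0, return_softmax=False, num_splits=1):
    """Hypothetical newer signature with a parameter inserted in the middle."""
    return {
        "softcap": softcap,
        "return_softmax": return_softmax,
        "num_splits": num_splits,
    }


# The same positional call means different things under the two signatures.
positional_v1 = forward_v1("q", 0.0, 4)  # 4 binds to num_splits
positional_v2 = forward_v2("q", 0.0, 4)  # 4 silently binds to return_softmax

# The keyword call keeps its meaning regardless of parameter order.
keyword_v2 = forward_v2("q", softcap=0.0, num_splits=4)
```

The positional call against forward_v2 raises no error, which is what makes this kind of drift hard to spot in review.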
| window_size[1], | ||
| 0, | ||
| softcap, | ||
| True, |
What does True correspond to here? Looking at the non-varlen version, the parameters after softcap are num_splits, pack_gqa, sm_margin — but here there are two extra positional args (True and None) before num_splits. This likely corresponds to return_softmax=True and perhaps gen_=None or similar. Please add an inline comment to clarify, or use keyword arguments.
| max_seqlen_k=max_seqlen_k, | ||
| softmax_scale=scale, | ||
| causal=is_causal, | ||
| return_attn_probs=return_lse, |
Note: The original code always unpacked out, lse, *_ = func(...). Now with return_attn_probs=return_lse, when return_lse=False the return value may be different (single tensor vs tuple). Make sure flash_attn_varlen_func from flash-attn3 returns a single tensor (not a tuple) when return_attn_probs=False. The flash-attn2 varlen hub uses the same pattern, so this is likely fine, but worth verifying.
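The pattern in question can be sketched with a stand-in function. Here fake_varlen_attn is hypothetical and only mimics the assumed behavior (a tuple when return_attn_probs=True, a bare output otherwise); whether the flash-attn3 function actually behaves this way is exactly what the note asks to verify.

```python
def fake_varlen_attn(q, return_attn_probs=False):
    """Hypothetical stand-in: returns (out, lse) or just out."""
    out = [2 * x for x in q]
    lse = [0.0 for _ in q]
    return (out, lse) if return_attn_probs else out


def attention(q, return_lse=False):
    result = fake_varlen_attn(q, return_attn_probs=return_lse)
    # Initialize lse so the return statement works on both branches.
    lse = None
    if return_lse:
        out, lse, *_ = result
    else:
        out = result
    return (out, lse) if return_lse else out


only_out = attention([1, 2], return_lse=False)
out_and_lse = attention([1, 2], return_lse=True)
```

If the real function always returned a tuple, the else branch would hand a tuple back to the caller, so a quick check of the return type under both flag values is enough to settle the question.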
What does this PR do?
A follow-up to #13479. I have added _flash_3_varlen_hub support for the SP forward and backward passes.
Tested with the QwenImage pipeline; the resulting image is as expected.
Tested QwenImage training with SP; no errors occurred.
The unit tests for Flux and QwenImage pass.
Fixes # (issue)
Who can review?
@sayakpaul
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.