
Add attention mask input support for flash backend #13479

Open

zhtmike wants to merge 8 commits into huggingface:main from zhtmike:flash-attn-mask
Conversation

@zhtmike (Contributor) commented Apr 15, 2026

What does this PR do?

This PR adds support for attention mask input when using the attention backend with set_attention_backend("flash"). With this change, QwenImagePipeline can run with the flash backend w/ or w/o Ulysses SP.

For FlashAttention 2, it is not feasible to use _wrapped_flash_attn_forward directly when a mask is applied. To maintain compatibility with the current interface, we introduce an additional branch for FlashAttention to handle attention masks.

# forward pass
- w/o mask: _wrapped_flash_attn_forward()
- w/ mask (new): _pack_qkv() -> _wrapped_flash_attn_varlen_forward() -> unpack()
# backward pass
- w/o mask: stored tensor -> _wrapped_flash_attn_backward()
- w/ mask (new): stored packed tensor -> _wrapped_flash_attn_varlen_backward() -> unpack()
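The masked path above boils down to flattening batched Q/K/V into flash-attn's varlen layout. A minimal pure-Python sketch of the mask bookkeeping (cumulative sequence lengths and kept-token indices; the helper name and return layout are illustrative, not the PR's actual code):

```python
def pack_metadata_from_mask(mask):
    """Derive flash-attn varlen metadata from a per-row boolean key mask.

    mask: list of B rows, each a list of S 0/1 flags (1 = real token).
    Returns (seqlens, cu_seqlens, indices), where cu_seqlens has length B + 1
    and indices are the flat positions of kept tokens in the (B*S) layout.
    """
    seqlens = [sum(row) for row in mask]      # valid tokens per sequence
    cu_seqlens = [0]
    for n in seqlens:                         # prefix sums: varlen offsets
        cu_seqlens.append(cu_seqlens[-1] + n)
    width = len(mask[0])
    indices = [
        b * width + s
        for b, row in enumerate(mask)
        for s, keep in enumerate(row)
        if keep
    ]
    return seqlens, cu_seqlens, indices

# Two sequences of length 4; the second has one padded key.
seqlens, cu, idx = pack_metadata_from_mask([[1, 1, 1, 1], [1, 1, 1, 0]])
assert (seqlens, cu, idx) == ([4, 3], [0, 4, 7], [0, 1, 2, 3, 4, 5, 6])
```

The packed tensors are then gathered with `indices`, and `cu_seqlens` tells the varlen kernel where each sequence starts and ends.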

I haven't tested with ring attention, so it is left unimplemented.

Fixes # (issue)

Before submitting

Who can review?

@sayakpaul

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions github-actions bot added models tests size/L PR with diff > 200 LOC labels Apr 15, 2026
@zhtmike (Contributor, Author) commented Apr 15, 2026

Code snippet showing that it works:

import os

import torch
import torch.distributed as dist

from diffusers import ContextParallelConfig, QwenImagePipeline

if dist.is_available():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    world_size = dist.get_world_size()
    torch.cuda.set_device(device)
else:
    rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = 1

model_id = os.path.expanduser("~/models/Qwen/Qwen-Image")

pipe = QwenImagePipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)
pipe.to(device)

pipe.transformer.set_attention_backend("flash")   # <--------- here 
if world_size > 1:
    from diffusers import QwenImageTransformer2DModel
    assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
    pipe.transformer.enable_parallelism(config=ContextParallelConfig(
        ulysses_degree=world_size))

pipe.set_progress_bar_config(disable=rank != 0)

positive_magic = {
    "en": ", Ultra HD, 4K, cinematic composition.",  # for english prompt
    "zh": ", 超清,4K,电影级构图.",  # for chinese prompt
}
prompts = [
    "A coffee shop entrance features a chalkboard sign reading "
    '"Qwen Coffee 😊 $2 per cup," with a neon light beside it '
    'displaying "通义千问". Next to it hangs a poster showing a '
    "beautiful Chinese woman, and beneath the poster is written "
    '"π≈3.1415926-53589793-23846264-33832795-02384197". '
    "Ultra HD, 4K, cinematic composition",
    "A cute cat with long hair sitting on a sofa, Ultra HD, 4K, cinematic composition."
]

inputs = {
    "prompt": [p + positive_magic["en"] for p in prompts],
    "generator": torch.Generator(device="cpu").manual_seed(0),
    "true_cfg_scale": 4.0,
    "negative_prompt": " ",
    "num_inference_steps": 50,
    "num_images_per_prompt": 1,
    "height": 1024,
    "width": 1024,
}

with torch.inference_mode():
    output = pipe(**inputs)
    for i, output_image in enumerate(output.images):
        if world_size > 1:
            save_path = f"output_image_ulysses{world_size}_{i}.png"
        else:
            save_path = f"output_image_{i}.png"
        if rank == 0:
            output_image.save(save_path)
            print(f"image saved at {save_path}")

if dist.is_initialized():
    dist.destroy_process_group()

Produces the following images (attachments): output_image_0, output_image_1

@zhtmike (Contributor, Author) commented Apr 16, 2026

Hi @sayakpaul, the PR is ready for review. Please take a look when you have time.

@sayakpaul (Member) left a comment

Thanks a lot for the PR! I left some comments, LMK what you think.

Should it be propagated to FA3, too, perhaps in a different PR?

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
    from flash_attn.flash_attn_interface import (
@sayakpaul (Member):

WDYT of constraining the changes only to FLASH_HUB?

AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(

This way, people won't have to build the flash attention wheel locally.

@sayakpaul (Member):

We will be deprecating the non-Hub variants for FLASH and FLASH_3 soonish anyway.

@zhtmike (Contributor, Author):

OK, let me move to the hub version of FlashAttention 2 then.

@zhtmike (Contributor, Author):

Done. Moved to FLASH_HUB.



@dataclass
class _VarlenPackedInputs:
@sayakpaul (Member):

Does it apply to all varlen attention kernels, though? Or does it only come into play during CP?

We do have VARLEN implementations of a few backends already:

AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(

@sayakpaul (Member):

Any reason to use dataclasses for this? Won't it be better to apply the transformations inline for readability?

@zhtmike (Contributor, Author):

> We do have VARLEN implementations of a few backends already:

Yes, but it seems all the VARLEN implementations only work in the non-CP case, so applying the transformations inline is fine for non-CP.

> Any reason to use dataclasses for this? Won't it be better to apply the transformations inline for readability?

Let me drop the dataclasses; it was just my personal preference.

@zhtmike (Contributor, Author):

Dropped.

return packed_out.view(self.batch_size, self.seq_len_q, *packed_out.shape[1:])


def _pack_qkv(
@sayakpaul (Member):

Why do we need this function if we decided to do the transformations inline?

@zhtmike (Contributor, Author):

Dropped.


if attn_mask is not None:
    if return_lse:
        raise NotImplementedError("`return_lse=True` with `attn_mask` is not yet supported for flash-attn 2.")
@sayakpaul (Member):

That means Ring isn't supported yet, right?

@zhtmike (Contributor, Author):

Yes. Not supported yet.

Comment thread on tests/others/test_flash_attention.py (Outdated)
import torch
import torch.nn.functional as F

from diffusers.models.attention_dispatch import (
@sayakpaul (Member):

I think we can do the testing in #13182.

@zhtmike (Contributor, Author):

Dropped. Shall we wait for #13182 first?

@zhtmike (Contributor, Author) commented Apr 16, 2026

> Thanks a lot for the PR! I left some comments, LMK what you think.
>
> Should it be propagated to FA3, too, perhaps in a different PR?

Agreed. I will implement FA3 as well, in a different PR.

@sayakpaul (Member) commented:

@claude could you do a review?

@github-actions (bot) commented Apr 16, 2026

Claude finished @sayakpaul's task in 3m 25s.


Review of PR #13479: Add attention mask input support for flash backend

  • Read repository review guidelines (AGENTS.md, models.md, modular.md)
  • Read the PR diff
  • Read relevant context in attention_dispatch.py
  • Post review feedback

Overall: The approach is correct — using flash_attn_varlen_func / _wrapped_flash_attn_varlen_{forward,backward} for masked attention is the standard way to handle padding masks in FlashAttention 2. The Q stays fully packed (flatten), K/V are selectively packed by mask indices, which is correct for key-side padding masks. The backward pass properly dispatches between the varlen and regular paths via ctx.is_varlen_masked. Ring attention is correctly guarded.

I have a few comments:


1. Missing None-check on varlen_kernel_fn (potential runtime error)

In _flash_attention_hub (non-CP masked path, ~line 2591), varlen_kernel_fn is used without a None guard:

varlen_func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].varlen_kernel_fn
out_packed = varlen_func(...)  # AttributeError if varlen_kernel_fn is None

The forward_op path correctly checks wrapped_varlen_forward_fn is None and raises a descriptive RuntimeError. The same pattern should be applied here:

varlen_func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].varlen_kernel_fn
if varlen_func is None:
    raise RuntimeError(
        "Flash attention hub kernels must expose `flash_attn_varlen_func` for masked attention."
    )

In practice, _maybe_download_kernel_for_backend should populate it, but a guard is safer and consistent with the forward_op style.



2. Dead code: unreachable return_lse handling (very minor)

In _flash_attention_hub (~lines 2605-2606):

if return_lse:
    out_packed, lse, *_ = out_packed

This is unreachable because return_lse=True already raises NotImplementedError at line 2575-2578. It's harmless but could confuse readers into thinking the return_lse path is partially supported. Consider removing it.


3. Code duplication between _flash_attention_hub_forward_op and _flash_attention_hub (nit)

The mask handling logic (~30 lines: normalize mask → compute cu_seqlens → compute indices → pack Q/K/V → call varlen → reshape output) is duplicated between:

  • _flash_attention_hub_forward_op (lines ~1266-1336, CP path with ctx saving)
  • _flash_attention_hub (lines ~2574-2608, non-CP path)

I recognize this follows the existing codebase pattern where the forward_op (using _wrapped_* internals for backward support) and the main function (using the public API) are kept separate. So this is fine structurally — just noting it for awareness.


4. Ulysses + mask interaction looks correct

I traced through the TemplatedUlyssesAttention.forward path: after the all-to-all, Q/K/V have shape (B, S_FULL, H_LOCAL, D). The full-sized mask (B, S_KV) matches the now-full KV sequence length, so _normalize_attn_mask and the varlen packing work correctly. The backward path also works because the packed tensors and indices are saved to ctx.
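A small numpy sketch of that shape trace, assuming the usual Ulysses layout (the emulated collective here just slices and concatenates; real Ulysses uses an all-to-all):

```python
import numpy as np

world, B, S_FULL, H, D = 2, 1, 4, 4, 2
S_LOCAL, H_LOCAL = S_FULL // world, H // world
rng = np.random.default_rng(0)

full = rng.standard_normal((B, S_FULL, H, D))
# Before the all-to-all, each rank holds all heads for its sequence shard:
# shape (B, S_LOCAL, H, D).
shards = [full[:, r * S_LOCAL:(r + 1) * S_LOCAL] for r in range(world)]

def ulysses_all_to_all(shards, rank):
    # Emulate the collective: a rank keeps its head slice of every sequence
    # shard and concatenates them back into the full sequence, giving
    # shape (B, S_FULL, H_LOCAL, D).
    pieces = [s[:, :, rank * H_LOCAL:(rank + 1) * H_LOCAL] for s in shards]
    return np.concatenate(pieces, axis=1)

out0 = ulysses_all_to_all(shards, 0)
assert out0.shape == (B, S_FULL, H_LOCAL, D)
# A full-sized key mask of shape (B, S_FULL) now lines up with axis 1:
mask = np.ones((B, S_FULL), dtype=bool)
assert mask.shape[1] == out0.shape[1]
assert np.allclose(out0, full[:, :, :H_LOCAL])
```

This is why the full-sized mask can be applied unchanged after the all-to-all: the KV sequence dimension is full-length again, and only the head dimension is sharded.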


5. Constraint addition looks good

Adding _check_attn_mask_or_causal to the FLASH_HUB constraints is correct — FlashAttention varlen doesn't support both attn_mask and is_causal=True simultaneously.
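As a simplified stand-in for what such a constraint check might look like (the name mirrors the PR's `_check_attn_mask_or_causal`, but this body is illustrative, not the actual diffusers code):

```python
def check_attn_mask_or_causal(attn_mask, is_causal):
    # Illustrative constraint: a padding mask and causal masking cannot be
    # requested together, since the varlen kernels handle only one of them.
    if attn_mask is not None and is_causal:
        raise ValueError(
            "`attn_mask` cannot be combined with `is_causal=True` for the flash backend."
        )

check_attn_mask_or_causal(None, True)     # causal only: fine
check_attn_mask_or_causal("mask", False)  # mask only: fine
```

Registering a check like this per backend lets the dispatcher reject unsupported argument combinations up front instead of failing inside the kernel.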


Summary

The PR is well-structured and the core logic is correct. The main actionable item is #1 (add a None-check on varlen_kernel_fn). Item #2 (dead code) is optional cleanup. The rest looks good.
