Skip to content

Fix the attention mask in ulysses SP for QwenImage#13278

Merged
dg845 merged 6 commits intohuggingface:mainfrom
zhtmike:fix_sp
Mar 24, 2026
Merged

Fix the attention mask in ulysses SP for QwenImage#13278
dg845 merged 6 commits intohuggingface:mainfrom
zhtmike:fix_sp

Conversation

@zhtmike
Copy link
Contributor

@zhtmike zhtmike commented Mar 17, 2026

What does this PR do?

Fix issue #13277.

QwenImagePipeline cannot run with Ulysses SP together with batch prompt inputs. It is related to the mask is not correctly broadcasted.
We need to broadcast the attention mask from [B, S] to [B, H, S_q, S_kv] or simply [B, 1, 1, S_kv] before feeding into SDPA.

Before Fix:

We have the error when running the code snippet mentioned in the issue.

RuntimeError: The expanded size of the tensor (4222) must match the existing size (2) at non-singleton dimension 2.  Target sizes: [2, 12, 4222, 4222].  Tensor sizes: [2, 4222]

After Fix:

The images are correctly produced.
output_image_ulysses2_0
output_image_ulysses2_1

Fixes # (issue)

Before submitting

Who can review?

@sayakpaul

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Comment on lines +816 to +817
if attn_mask is not None and attn_mask.dim() == 2:
attn_mask = attn_mask[:, None, None, :]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this Qwen specific?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't test with other models. But I think models with a 2D masks input should have the similar problem

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possible to check out one other? And also run the

class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. From a quick scan, most models seem to handle the mask shape correctly in their own implementations. So I’ve limited the modification to QwenImage only.

Should I run any test cases?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Maybe we could add a similar test to

class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
?

I will give you a ping once it's refactored to follow the latest pattern.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry disregard my suggestion on using the CUDNN backend.

Yes, native attention x Ulysses is perfect for single prompt input. Currently batch inputs have some problem.

Is it the case just for Qwen or the same happens for Flux, as well? Also, the test under consideration -- does it not use a single prompt?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it the case just for Qwen or the same happens for Flux, as well?

So far, I have only found that Qwen has this problem. Other models, such as Z-Image, HunyuanImage
expand the attention mask in a similar way before entering the attention block. For Flux, I tested with the main branch, and it works fine with both CP and batch inputs.

Also, the test under consideration -- does it not use a single prompt?

I am wondering whether we should add a batch input test if possible. At the beginning, I think we should first ensure that all unit tests pass without modifying them.

The background of this bug is that we are working on the training engine based on the Diffusers backend, using QwenImage as the first example. Therefore, we may need a combination of batch inputs (for high throughput) as well as Ulysses SP support. This is why we encountered this bug during the forward process.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed and thanks so much for the context!

I am wondering whether we should add a batch input test if possible. At the beginning, I think we should first ensure that all unit tests pass without modifying them.

Would you like to take a crack at this? We'll be quick to review.

I think first we need to ensure that the test_context_parallel_inference() test is xfailed when ring attention is enabled with the SDPA. #13182 is adding a test suite for CP-backends and attention backends.

And then a test for batched inputs.

Then let's revisit this PR?

WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure NP, I will add a UT test for batch input

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @sayakpaul , I have added a PR #13312 for this, could you take a look?

@zhtmike zhtmike changed the title Fix the attention mask in ulysses SP Fix the attention mask in ulysses SP for QwenImage Mar 17, 2026
@sayakpaul sayakpaul requested a review from kashif March 17, 2026 10:13
@sayakpaul
Copy link
Member

@naykun if you want to take a look

batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
joint_attention_mask = joint_attention_mask[:, None, None, :]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this okay for non-CP?

Copy link
Contributor Author

@zhtmike zhtmike Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. The image is same w/o this change.

@zhtmike
Copy link
Contributor Author

zhtmike commented Mar 24, 2026

Following up on #13312: we can drop the xfail after fixing the QwenImage mask.

running pytest tests/models/transformers/test_models_transformer_qwenimage.py:

=========================== short test summary info ============================
FAILED tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerLoRAHotSwap::test_hotswapping_compiled_model_linear[11-11]
FAILED tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerLoRAHotSwap::test_hotswapping_compiled_model_linear[7-13]
FAILED tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerLoRAHotSwap::test_hotswapping_compiled_model_linear[13-7]
FAILED tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerLoRAHotSwap::test_hotswapping_compiled_model_both_linear_and_other[11-11]
FAILED tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerLoRAHotSwap::test_hotswapping_compiled_model_both_linear_and_other[7-13]
FAILED tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerLoRAHotSwap::test_hotswapping_compiled_model_both_linear_and_other[13-7]
FAILED tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerLoRAHotSwap::test_enable_lora_hotswap_called_after_adapter_added_warning
====== 7 failed, 69 passed, 54 skipped, 36 warnings in 182.34s (0:03:02) =======

The error of LoRAHotSwap seems unrelated.

@sayakpaul sayakpaul requested a review from dg845 March 24, 2026 04:20
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul
Copy link
Member

I ran the tests on my end and they are passing (CP tests). Hotswapping test failures are not needed.

Copy link
Collaborator

@dg845 dg845 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

@dg845
Copy link
Collaborator

dg845 commented Mar 24, 2026

Merging as the CI is green.

@dg845 dg845 merged commit afdda57 into huggingface:main Mar 24, 2026
11 checks passed
@zhtmike zhtmike deleted the fix_sp branch March 24, 2026 09:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants