32 changes: 31 additions & 1 deletion src/diffusers/models/transformers/transformer_qwenimage.py
@@ -18,6 +18,7 @@

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

@@ -330,6 +331,19 @@ def __call__(
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)

if attention_mask is None and encoder_hidden_states_mask is not None:
# The joint sequence is [text, image].
seq_len_img = hidden_states.shape[1]
img_mask = torch.ones(
encoder_hidden_states_mask.shape[0], seq_len_img, device=encoder_hidden_states_mask.device
)
attention_mask = torch.cat([encoder_hidden_states_mask, img_mask], dim=1)

# Convert the mask to the format expected by SDPA
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=joint_query.dtype)
attention_mask = (1.0 - attention_mask) * torch.finfo(joint_query.dtype).min

# Compute joint attention
joint_hidden_states = dispatch_attention_fn(
joint_query,
@@ -600,6 +614,16 @@ def forward(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if dist.is_initialized():
world_size = dist.get_world_size()
if world_size > 1 and encoder_hidden_states is not None:
seq_len = encoder_hidden_states.shape[1]
pad_len = (world_size - seq_len % world_size) % world_size
if pad_len > 0:
encoder_hidden_states = F.pad(encoder_hidden_states, (0, 0, 0, pad_len))
if encoder_hidden_states_mask is not None:
encoder_hidden_states_mask = F.pad(encoder_hidden_states_mask, (0, pad_len), value=0)

if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -630,7 +654,13 @@ def forward(
else self.time_text_embed(timestep, guidance, hidden_states)
)

image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
        # Use the (possibly padded) encoder_hidden_states sequence length to generate the rotary embeddings
if encoder_hidden_states is not None:
recalculated_txt_seq_lens = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
else:
recalculated_txt_seq_lens = txt_seq_lens

image_rotary_emb = self.pos_embed(img_shapes, recalculated_txt_seq_lens, device=hidden_states.device)

for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
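
The transformer change above has two independent pieces: the text sequence is padded so its length is divisible by the world size, and the binary [text, image] key mask is converted into the additive bias that torch.nn.functional.scaled_dot_product_attention expects. A minimal, self-contained sketch of both steps (illustrative shapes and world size; not the diffusers implementation):

import torch
import torch.nn.functional as F

batch, txt_len, img_len, dim, heads = 2, 7, 16, 8, 4  # assumed toy sizes
world_size = 3  # assumed for illustration

encoder_hidden_states = torch.randn(batch, txt_len, dim)
encoder_hidden_states_mask = torch.ones(batch, txt_len)
encoder_hidden_states_mask[:, -2:] = 0  # pretend the last two text tokens are padding

# (1) Pad the sequence dimension to a multiple of world_size; padded positions stay masked out.
pad_len = (world_size - txt_len % world_size) % world_size  # 7 -> pad by 2 to reach 9
encoder_hidden_states = F.pad(encoder_hidden_states, (0, 0, 0, pad_len))
encoder_hidden_states_mask = F.pad(encoder_hidden_states_mask, (0, pad_len), value=0)

# (2) Build the joint [text, image] mask and convert it to an additive SDPA bias:
# 1 -> 0.0 (attend), 0 -> most negative finite value (ignore).
img_mask = torch.ones(batch, img_len)
attention_mask = torch.cat([encoder_hidden_states_mask, img_mask], dim=1)
dtype = torch.float32
bias = (1.0 - attention_mask[:, None, None, :].to(dtype)) * torch.finfo(dtype).min

q = torch.randn(batch, heads, txt_len + pad_len + img_len, dim)  # (B, heads, S, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)  # torch.Size([2, 4, 25, 8])

Because padded text positions carry the minimum-value bias, they contribute nothing to the attention output, which is why a padded and an unpadded run are expected to produce the same image.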
39 changes: 39 additions & 0 deletions tests/pipelines/qwenimage/test_qwenimage.py
@@ -13,6 +13,7 @@
# limitations under the License.

import unittest
from unittest.mock import patch

import numpy as np
import torch
@@ -234,3 +235,41 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
expected_diff_max,
"VAE tiling should not affect the inference results",
)

def test_context_parallelism_padding_fix(self):
"""
Compare pipeline outputs: baseline (normal single-process) vs
simulated multi-process (mocked torch.distributed). This verifies
padding logic does not change the generated image.
"""
device = "cpu"

components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)

inputs = self.get_dummy_inputs(device)

        # Baseline run (no distributed environment)
baseline_image = pipe(**inputs).images[0]

# Re-initialize inputs to get a fresh generator with the same seed for a fair comparison
inputs = self.get_dummy_inputs(device)

        # Simulate a distributed environment (world_size = 3) so the padding branch runs.
        # NOTE: the patch target must match where `dist` is imported in the transformer module.
with (
patch("diffusers.models.transformers.transformer_qwenimage.dist.is_initialized", return_value=True),
patch("diffusers.models.transformers.transformer_qwenimage.dist.get_world_size", return_value=3),
):
padded_image = pipe(**inputs).images[0]

# shape check
self.assertEqual(baseline_image.shape, padded_image.shape)

# Additional check: verify padding didn't introduce extreme values
self.assertTrue(torch.isfinite(padded_image).all())

# Verify numerical equivalence
self.assertTrue(torch.allclose(baseline_image, padded_image, atol=1e-2, rtol=1e-2))
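
The mocking pattern in the test relies on how unittest.mock.patch resolves its dotted string target: the module path is imported and the named attribute is replaced for the duration of the with block. A standalone sketch of the same pattern, with a hypothetical helper standing in for the transformer's world-size check (assumes a torch build with distributed support):

from unittest.mock import patch

import torch.distributed as dist

def world_size_seen_by_caller() -> int:
    # Hypothetical stand-in for the check in forward(): read the world size
    # through the module-level `dist` name, defaulting to 1.
    return dist.get_world_size() if dist.is_initialized() else 1

with (
    patch(f"{__name__}.dist.is_initialized", return_value=True),
    patch(f"{__name__}.dist.get_world_size", return_value=3),
):
    assert world_size_seen_by_caller() == 3  # the padding branch would run here

assert world_size_seen_by_caller() == 1  # no process group is initialized outside the patch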