From baf42dbd26c5013f83e64633ca200c5c70cedc28 Mon Sep 17 00:00:00 2001
From: Ratish1
Date: Wed, 5 Nov 2025 18:25:11 +0400
Subject: [PATCH] fix(qwenimage): Correct context parallelism padding

---
 .../transformers/transformer_qwenimage.py   | 32 +++++++++++++++-
 tests/pipelines/qwenimage/test_qwenimage.py | 39 +++++++++++++++++++
 2 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index c0fa031b9faf..992abddb72a9 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -18,6 +18,7 @@
 
 import numpy as np
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -330,6 +331,19 @@ def __call__(
         joint_key = torch.cat([txt_key, img_key], dim=1)
         joint_value = torch.cat([txt_value, img_value], dim=1)
 
+        if attention_mask is None and encoder_hidden_states_mask is not None:
+            # The joint sequence is [text, image].
+            seq_len_img = hidden_states.shape[1]
+            img_mask = torch.ones(
+                encoder_hidden_states_mask.shape[0], seq_len_img, device=encoder_hidden_states_mask.device
+            )
+            attention_mask = torch.cat([encoder_hidden_states_mask, img_mask], dim=1)
+
+            # Convert the 0/1 mask to the additive bias format expected by SDPA
+            attention_mask = attention_mask[:, None, None, :]
+            attention_mask = attention_mask.to(dtype=joint_query.dtype)
+            attention_mask = (1.0 - attention_mask) * torch.finfo(joint_query.dtype).min
+
         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
             joint_query,
@@ -600,6 +614,16 @@ def forward(
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise
             a `tuple` where the first element is the sample tensor.
         """
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+            if world_size > 1 and encoder_hidden_states is not None:
+                seq_len = encoder_hidden_states.shape[1]
+                pad_len = (world_size - seq_len % world_size) % world_size
+                if pad_len > 0:
+                    encoder_hidden_states = F.pad(encoder_hidden_states, (0, 0, 0, pad_len))
+                    if encoder_hidden_states_mask is not None:
+                        encoder_hidden_states_mask = F.pad(encoder_hidden_states_mask, (0, pad_len), value=0)
+
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -630,7 +654,13 @@ def forward(
             else self.time_text_embed(timestep, guidance, hidden_states)
         )
 
-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        # Use the padded encoder hidden states to recompute the text sequence lengths for the rotary embeddings
+        if encoder_hidden_states is not None:
+            recalculated_txt_seq_lens = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
+        else:
+            recalculated_txt_seq_lens = txt_seq_lens
+
+        image_rotary_emb = self.pos_embed(img_shapes, recalculated_txt_seq_lens, device=hidden_states.device)
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
diff --git a/tests/pipelines/qwenimage/test_qwenimage.py b/tests/pipelines/qwenimage/test_qwenimage.py
index 8ebfe7d08bc1..6a574a66d442 100644
--- a/tests/pipelines/qwenimage/test_qwenimage.py
+++ b/tests/pipelines/qwenimage/test_qwenimage.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import unittest
+from unittest.mock import patch
 
 import numpy as np
 import torch
@@ -234,3 +235,41 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
             expected_diff_max,
             "VAE tiling should not affect the inference results",
         )
+
+    def test_context_parallelism_padding_fix(self):
+        """
+        Compare pipeline outputs: a baseline single-process run vs. a run with a
+        simulated multi-process environment (mocked torch.distributed). This
+        verifies that the padding logic does not change the generated image.
+        """
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+
+        # Baseline run (no distributed environment)
+        baseline_image = pipe(**inputs).images[0]
+
+        # Re-initialize inputs to get a fresh generator with the same seed for a fair comparison
+        inputs = self.get_dummy_inputs(device)
+
+        # Simulate a distributed environment (world_size = 3) so the padding branch runs
+        # NOTE: the patch target must match where `dist` is imported in the transformer module.
+        with (
+            patch("diffusers.models.transformers.transformer_qwenimage.dist.is_initialized", return_value=True),
+            patch("diffusers.models.transformers.transformer_qwenimage.dist.get_world_size", return_value=3),
+        ):
+            padded_image = pipe(**inputs).images[0]
+
+        # Shape check
+        self.assertEqual(baseline_image.shape, padded_image.shape)
+
+        # Additional check: verify that padding did not introduce NaNs or infinities
+        self.assertTrue(torch.isfinite(padded_image).all())
+
+        # Verify numerical equivalence within a small tolerance
+        self.assertTrue(torch.allclose(baseline_image, padded_image, atol=1e-2, rtol=1e-2))
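
The padding rule added in QwenImageTransformer2DModel.forward rounds the text sequence length up to the next multiple of the context-parallel world size and masks out the padded positions so they cannot influence attention. Below is a minimal standalone sketch of that rule; the helper name pad_text_sequence and the example shapes are illustrative, not part of the patch.

import torch
import torch.nn.functional as F


def pad_text_sequence(encoder_hidden_states, encoder_hidden_states_mask, world_size):
    # Round the text sequence length up to a multiple of world_size so it can be
    # split evenly across context-parallel ranks; padded positions are zero-filled
    # and their mask entries are set to 0.
    seq_len = encoder_hidden_states.shape[1]
    pad_len = (world_size - seq_len % world_size) % world_size  # 0 when already divisible
    if pad_len > 0:
        # F.pad pads trailing dimensions first: (0, 0) leaves the hidden dim untouched,
        # (0, pad_len) appends pad_len positions along the sequence dim.
        encoder_hidden_states = F.pad(encoder_hidden_states, (0, 0, 0, pad_len))
        encoder_hidden_states_mask = F.pad(encoder_hidden_states_mask, (0, pad_len), value=0)
    return encoder_hidden_states, encoder_hidden_states_mask


# Example: seq_len=10 with world_size=3 pads to 12 (pad_len=2).
states, mask = pad_text_sequence(torch.randn(2, 10, 16), torch.ones(2, 10), world_size=3)
assert states.shape == (2, 12, 16) and mask.shape == (2, 12)
assert mask[:, -2:].sum() == 0  # the appended positions are masked out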