diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index e6b19377b14f..713a1bec70a5 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2025 HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,49 +12,84 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +import warnings import torch from diffusers import QwenImageTransformer2DModel from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask +from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + BitsAndBytesTesterMixin, + ContextParallelTesterMixin, + LoraHotSwappingForModelTesterMixin, + LoraTesterMixin, + MemoryTesterMixin, + ModelTesterMixin, + TorchAoTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = QwenImageTransformer2DModel - main_input_name = "hidden_states" - # We override the items here because the transformer under consideration is small. - model_split_percents = [0.7, 0.6, 0.6] - - # Skip setting testing with default: AttnProcessor - uses_custom_attn_processor = True - +class QwenImageTransformerTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - return self.prepare_dummy_input() + def model_class(self): + return QwenImageTransformer2DModel @property - def input_shape(self): + def output_shape(self) -> tuple[int, int]: return (16, 16) @property - def output_shape(self): + def input_shape(self) -> tuple[int, int]: return (16, 16) - def prepare_dummy_input(self, height=4, width=4): + @property + def model_split_percents(self) -> list: + return [0.7, 0.6, 0.6] + + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self) -> dict[str, int | list[int]]: + return { + "patch_size": 2, + "in_channels": 16, + "out_channels": 4, + "num_layers": 2, + "attention_head_dim": 16, + "num_attention_heads": 4, + "joint_attention_dim": 16, + "guidance_embeds": False, + "axes_dims_rope": (8, 4, 4), + } + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: batch_size = 1 num_latent_channels = embedding_dim = 16 - sequence_length = 7 + height = width = 4 + sequence_length = 8 vae_scale_factor = 4 - hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + hidden_states = randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long) timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) orig_height = height * 2 * vae_scale_factor @@ -70,89 +104,57 @@ def prepare_dummy_input(self, height=4, width=4): "img_shapes": img_shapes, } - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "patch_size": 2, - "in_channels": 16, - "out_channels": 4, - "num_layers": 2, - "attention_head_dim": 16, - "num_attention_heads": 3, - "joint_attention_dim": 16, - "guidance_embeds": False, - "axes_dims_rope": (8, 4, 4), - } - - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def test_gradient_checkpointing_is_applied(self): - expected_set = {"QwenImageTransformer2DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) +class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin): def test_infers_text_seq_len_from_mask(self): - """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors.""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) - # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid) encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() - encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid + encoder_hidden_states_mask[:, 2:] = 0 rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( inputs["encoder_hidden_states"], encoder_hidden_states_mask ) - # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility) - self.assertIsInstance(rope_text_seq_len, int) - - # Verify per_sample_len is computed correctly (max valid position + 1 = 2) - self.assertIsInstance(per_sample_len, torch.Tensor) - self.assertEqual(int(per_sample_len.max().item()), 2) - - # Verify mask is normalized to bool dtype - self.assertTrue(normalized_mask.dtype == torch.bool) - self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values - - # Verify rope_text_seq_len is at least the sequence length - self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1]) + assert isinstance(rope_text_seq_len, int) + assert isinstance(per_sample_len, torch.Tensor) + assert int(per_sample_len.max().item()) == 2 + assert normalized_mask.dtype == torch.bool + assert normalized_mask.sum().item() == 2 + assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1] - # Test 2: Verify model runs successfully with inferred values inputs["encoder_hidden_states_mask"] = normalized_mask with torch.no_grad(): output = model(**inputs) - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] - # Test 3: Different mask pattern (padding at beginning) encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone() - encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding - encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid + encoder_hidden_states_mask2[:, :3] = 0 + encoder_hidden_states_mask2[:, 3:] = 1 rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask( inputs["encoder_hidden_states"], encoder_hidden_states_mask2 ) - # Max valid position is 6 (last token), so per_sample_len should be 7 - self.assertEqual(int(per_sample_len2.max().item()), 7) - self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values + assert int(per_sample_len2.max().item()) == 8 + assert normalized_mask2.sum().item() == 5 - # Test 4: No mask provided (None case) rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask( inputs["encoder_hidden_states"], None ) - self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1]) - self.assertIsInstance(rope_text_seq_len_none, int) - self.assertIsNone(per_sample_len_none) - self.assertIsNone(normalized_mask_none) + assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1] + assert isinstance(rope_text_seq_len_none, int) + assert per_sample_len_none is None + assert normalized_mask_none is None def test_non_contiguous_attention_mask(self): - """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) - # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc. encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone() - # Pattern: [True, False, True, False, True, False, False] encoder_hidden_states_mask[:, 1] = 0 encoder_hidden_states_mask[:, 3] = 0 encoder_hidden_states_mask[:, 5:] = 0 @@ -160,95 +162,85 @@ def test_non_contiguous_attention_mask(self): inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask( inputs["encoder_hidden_states"], encoder_hidden_states_mask ) - self.assertEqual(int(per_sample_len.max().item()), 5) - self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1]) - self.assertIsInstance(inferred_rope_len, int) - self.assertTrue(normalized_mask.dtype == torch.bool) + assert int(per_sample_len.max().item()) == 5 + assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1] + assert isinstance(inferred_rope_len, int) + assert normalized_mask.dtype == torch.bool inputs["encoder_hidden_states_mask"] = normalized_mask with torch.no_grad(): output = model(**inputs) - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] def test_txt_seq_lens_deprecation(self): - """Test that passing txt_seq_lens raises a deprecation warning.""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) - # Prepare inputs with txt_seq_lens (deprecated parameter) txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]] - # Remove encoder_hidden_states_mask to use the deprecated path inputs_with_deprecated = inputs.copy() inputs_with_deprecated.pop("encoder_hidden_states_mask") inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens - # Test that deprecation warning is raised - with self.assertWarns(FutureWarning) as warning_context: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") with torch.no_grad(): output = model(**inputs_with_deprecated) - # Verify the warning message mentions the deprecation - warning_message = str(warning_context.warning) - self.assertIn("txt_seq_lens", warning_message) - self.assertIn("deprecated", warning_message) - self.assertIn("encoder_hidden_states_mask", warning_message) + future_warnings = [x for x in w if issubclass(x.category, FutureWarning)] + assert len(future_warnings) > 0, "Expected FutureWarning to be raised" - # Verify the model still works correctly despite the deprecation - self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1]) + warning_message = str(future_warnings[0].message) + assert "txt_seq_lens" in warning_message + assert "deprecated" in warning_message + + assert output.sample.shape[1] == inputs["hidden_states"].shape[1] def test_layered_model_with_mask(self): - """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model).""" - # Create layered model config init_dict = { "patch_size": 2, "in_channels": 16, "out_channels": 4, "num_layers": 2, "attention_head_dim": 16, - "num_attention_heads": 3, + "num_attention_heads": 4, "joint_attention_dim": 16, - "axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16) - "use_layer3d_rope": True, # Enable layered RoPE - "use_additional_t_cond": True, # Enable additional time conditioning + "axes_dims_rope": (8, 4, 4), + "use_layer3d_rope": True, + "use_additional_t_cond": True, } model = self.model_class(**init_dict).to(torch_device) - # Verify the model uses QwenEmbedLayer3DRope from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope - self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope) + assert isinstance(model.pos_embed, QwenEmbedLayer3DRope) - # Test single generation with layered structure batch_size = 1 - text_seq_len = 7 + text_seq_len = 8 img_h, img_w = 4, 4 layers = 4 - # For layered model: (layers + 1) because we have N layers + 1 combined image hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device) encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device) - # Create mask with some padding encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device) - encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens + encoder_hidden_states_mask[0, 5:] = 0 timestep = torch.tensor([1.0]).to(torch_device) - # additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding) addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device) - # Layer structure: 4 layers + 1 condition image img_shapes = [ [ - (1, img_h, img_w), # layer 0 - (1, img_h, img_w), # layer 1 - (1, img_h, img_w), # layer 2 - (1, img_h, img_w), # layer 3 - (1, img_h, img_w), # condition image (last one gets special treatment) + (1, img_h, img_w), + (1, img_h, img_w), + (1, img_h, img_w), + (1, img_h, img_w), + (1, img_h, img_w), ] ] @@ -262,37 +254,113 @@ def test_layered_model_with_mask(self): additional_t_cond=addition_t_cond, ) - self.assertEqual(output.sample.shape[1], hidden_states.shape[1]) + assert output.sample.shape[1] == hidden_states.shape[1] + + +class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin): + """Memory optimization tests for QwenImage Transformer.""" + + +class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin): + """Training tests for QwenImage Transformer.""" + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"QwenImageTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin): + """Attention processor tests for QwenImage Transformer.""" -class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = QwenImageTransformer2DModel +class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin): + """Context Parallel inference tests for QwenImage Transformer.""" - def prepare_init_args_and_inputs_for_common(self): - return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common() - def prepare_dummy_input(self, height, width): - return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width) +class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin): + """LoRA adapter tests for QwenImage Transformer.""" - def test_torch_compile_recompilation_and_graph_break(self): - super().test_torch_compile_recompilation_and_graph_break() + +class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin): + """LoRA hot-swapping tests for QwenImage Transformer.""" + + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + batch_size = 1 + num_latent_channels = embedding_dim = 16 + sequence_length = 8 + vae_scale_factor = 4 + + hidden_states = randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + orig_height = height * 2 * vae_scale_factor + orig_width = width * 2 * vae_scale_factor + img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } + + +class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin): + """Torch compile tests for QwenImage Transformer.""" + + @property + def different_shapes_for_compilation(self): + return [(4, 4), (4, 8), (8, 8)] + + def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]: + batch_size = 1 + num_latent_channels = embedding_dim = 16 + sequence_length = 8 + vae_scale_factor = 4 + + hidden_states = randn_tensor( + (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device + ) + encoder_hidden_states = randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ) + encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + orig_height = height * 2 * vae_scale_factor + orig_width = width * 2 * vae_scale_factor + img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "encoder_hidden_states_mask": encoder_hidden_states_mask, + "timestep": timestep, + "img_shapes": img_shapes, + } def test_torch_compile_with_and_without_mask(self): - """Test that torch.compile works with both None mask and padding mask.""" - init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs = self.get_dummy_inputs() model = self.model_class(**init_dict).to(torch_device) model.eval() model.compile(mode="default", fullgraph=True) - # Test 1: Run with None mask (no padding, all tokens are valid) inputs_no_mask = inputs.copy() inputs_no_mask["encoder_hidden_states_mask"] = None - # First run to allow compilation with torch.no_grad(): output_no_mask = model(**inputs_no_mask) - # Second run to verify no recompilation with ( torch._inductor.utils.fresh_inductor_cache(), torch._dynamo.config.patch(error_on_recompile=True), @@ -300,19 +368,15 @@ def test_torch_compile_with_and_without_mask(self): ): output_no_mask_2 = model(**inputs_no_mask) - self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1]) - self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1] + assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1] - # Test 2: Run with all-ones mask (should behave like None) inputs_all_ones = inputs.copy() - # Keep the all-ones mask - self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item()) + assert inputs_all_ones["encoder_hidden_states_mask"].all().item() - # First run to allow compilation with torch.no_grad(): output_all_ones = model(**inputs_all_ones) - # Second run to verify no recompilation with ( torch._inductor.utils.fresh_inductor_cache(), torch._dynamo.config.patch(error_on_recompile=True), @@ -320,21 +384,18 @@ def test_torch_compile_with_and_without_mask(self): ): output_all_ones_2 = model(**inputs_all_ones) - self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1]) - self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1] + assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1] - # Test 3: Run with actual padding mask (has zeros) inputs_with_padding = inputs.copy() mask_with_padding = inputs["encoder_hidden_states_mask"].clone() - mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding + mask_with_padding[:, 4:] = 0 inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding - # First run to allow compilation with torch.no_grad(): output_with_padding = model(**inputs_with_padding) - # Second run to verify no recompilation with ( torch._inductor.utils.fresh_inductor_cache(), torch._dynamo.config.patch(error_on_recompile=True), @@ -342,8 +403,15 @@ def test_torch_compile_with_and_without_mask(self): ): output_with_padding_2 = model(**inputs_with_padding) - self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1]) - self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1]) + assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1] + assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1] + + assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3) + + +class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin): + """BitsAndBytes quantization tests for QwenImage Transformer.""" + - # Verify that outputs are different (mask should affect results) - self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)) +class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin): + """TorchAO quantization tests for QwenImage Transformer."""