diff --git a/benchmarks/benchmark_qwen_mask_performance.py b/benchmarks/benchmark_qwen_mask_performance.py
new file mode 100644
index 000000000000..0a9dbe6daf86
--- /dev/null
+++ b/benchmarks/benchmark_qwen_mask_performance.py
@@ -0,0 +1,295 @@
+"""
+Performance benchmark for the QwenImage attention mask implementation.
+
+This benchmark measures:
+1. Latency impact of mask processing
+2. Memory overhead
+3. Throughput comparison
+4. CFG batching performance
+
+Run with: python benchmarks/benchmark_qwen_mask_performance.py
+"""
+
+import gc
+import time
+from typing import Dict
+
+import pandas as pd
+import torch
+import torch.utils.benchmark as benchmark
+
+from diffusers import QwenImageTransformer2DModel
+
+
+def flush():
+    """Clean up GPU memory."""
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+
+def get_model():
+    """Create a small QwenImage model for benchmarking."""
+    model = QwenImageTransformer2DModel(
+        patch_size=2,
+        in_channels=16,
+        out_channels=4,
+        num_layers=2,
+        attention_head_dim=16,
+        num_attention_heads=3,
+        joint_attention_dim=16,
+        guidance_embeds=False,
+        axes_dims_rope=(8, 4, 4),  # Match the small model dimensions
+    )
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
+
+    model = model.to(device).to(dtype).eval()
+    return model, device, dtype
+
+
+def create_inputs_no_mask(batch_size, device, dtype, height=512, width=512, text_seq_len=256):
+    """Create inputs without a mask (baseline)."""
+    vae_scale_factor = 16
+    patch_size = 2
+
+    latent_height = height // vae_scale_factor // patch_size
+    latent_width = width // vae_scale_factor // patch_size
+    num_latent_pixels = latent_height * latent_width
+
+    hidden_states = torch.randn(batch_size, num_latent_pixels, 16, device=device, dtype=dtype)
+    encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16, device=device, dtype=dtype)
+    timestep = torch.tensor([1.0], device=device, dtype=dtype).expand(batch_size)
+
+    img_shapes = [(1, latent_height, latent_width)] * batch_size
+    txt_seq_lens = [text_seq_len] * batch_size
+
+    return {
+        "hidden_states": hidden_states,
+        "encoder_hidden_states": encoder_hidden_states,
+        "timestep": timestep,
+        "img_shapes": img_shapes,
+        "txt_seq_lens": txt_seq_lens,
+    }
+
+
+def create_inputs_with_mask_full(batch_size, device, dtype, height=512, width=512, text_seq_len=256):
+    """Create inputs with an all-ones mask (no actual padding)."""
+    inputs = create_inputs_no_mask(batch_size, device, dtype, height, width, text_seq_len)
+    inputs["encoder_hidden_states_mask"] = torch.ones(
+        batch_size, text_seq_len, dtype=torch.long, device=device
+    )
+    return inputs
+
+
+def create_inputs_with_padding(batch_size, device, dtype, height=512, width=512, text_seq_len=256):
+    """Create inputs with variable-length sequences (realistic CFG scenario; assumes batch_size == 2)."""
+    vae_scale_factor = 16
+    patch_size = 2
+
+    latent_height = height // vae_scale_factor // patch_size
+    latent_width = width // vae_scale_factor // patch_size
+    num_latent_pixels = latent_height * latent_width
+
+    hidden_states = torch.randn(batch_size, num_latent_pixels, 16, device=device, dtype=dtype)
+    encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16, device=device, dtype=dtype)
+
+    # Variable lengths: first is full, second is ~10% (simulates CFG with an empty unconditional prompt)
+    actual_lengths = [text_seq_len, max(1, text_seq_len // 10)]
+    encoder_hidden_states_mask = torch.zeros(batch_size, text_seq_len, dtype=torch.long, device=device)
+    for i, length in enumerate(actual_lengths):
+        encoder_hidden_states_mask[i, :length] = 1
+
+    # Zero out padding
+    mask_expanded = encoder_hidden_states_mask.unsqueeze(-1).to(dtype)
+    encoder_hidden_states = encoder_hidden_states * mask_expanded
+
+    timestep = torch.tensor([1.0], device=device, dtype=dtype).expand(batch_size)
+
+    img_shapes = [(1, latent_height, latent_width)] * batch_size
+
+    return {
+        "hidden_states": hidden_states,
+        "encoder_hidden_states": encoder_hidden_states,
+        "encoder_hidden_states_mask": encoder_hidden_states_mask,
+        "timestep": timestep,
+        "img_shapes": img_shapes,
+        "txt_seq_lens": actual_lengths,
+    }
+
+
+def measure_latency(model, inputs, num_warmup=5, num_runs=100):
+    """Measure average latency with proper warmup."""
+    # Warmup
+    with torch.no_grad():
+        for _ in range(num_warmup):
+            _ = model(**inputs)
+
+    # Measure
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
+    times = []
+    with torch.no_grad():
+        for _ in range(num_runs):
+            start = time.perf_counter()
+            _ = model(**inputs)
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            end = time.perf_counter()
+            times.append(end - start)
+
+    return {
+        "mean_ms": sum(times) / len(times) * 1000,
+        "std_ms": (sum((t - sum(times) / len(times)) ** 2 for t in times) / len(times)) ** 0.5 * 1000,
+        "min_ms": min(times) * 1000,
+        "max_ms": max(times) * 1000,
+    }
+
+
+def measure_memory(model, inputs):
+    """Measure peak memory usage."""
+    flush()
+
+    if not torch.cuda.is_available():
+        return {"peak_memory_mb": 0}
+
+    with torch.no_grad():
+        # Warmup
+        _ = model(**inputs)
+
+    flush()
+
+    with torch.no_grad():
+        _ = model(**inputs)
+        peak_memory = torch.cuda.max_memory_allocated() / (1024 ** 2)  # MB
+
+    return {"peak_memory_mb": peak_memory}
+
+
+def benchmark_throughput(model, inputs, duration_seconds=10):
+    """Measure throughput (iterations per second)."""
+    num_iterations = 0
+    start_time = time.perf_counter()
+
+    with torch.no_grad():
+        while time.perf_counter() - start_time < duration_seconds:
+            _ = model(**inputs)
+            num_iterations += 1
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+
+    elapsed = time.perf_counter() - start_time
+    return {"iterations_per_sec": num_iterations / elapsed}
+
+
+def run_benchmark_suite():
+    """Run the complete benchmark suite."""
+    print("=" * 80)
+    print("QwenImage Attention Mask Performance Benchmark")
+    print("=" * 80)
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"\nDevice: {device}")
+    if torch.cuda.is_available():
+        print(f"GPU: {torch.cuda.get_device_name()}")
+        print(f"CUDA Version: {torch.version.cuda}")
+    print()
+
+    results = []
+
+    # Configuration (smaller for faster benchmarking)
+    batch_size = 2
+    height = 256  # Smaller resolution for faster benchmarking
+    width = 256
+    text_seq_len = 64  # Shorter sequences for faster benchmarking
+
+    scenarios = [
+        ("Baseline (no mask)", lambda m, d, dt: create_inputs_no_mask(batch_size, d, dt, height, width, text_seq_len)),
+        ("Mask all-ones (no padding)", lambda m, d, dt: create_inputs_with_mask_full(batch_size, d, dt, height, width, text_seq_len)),
+        ("Mask with padding (CFG)", lambda m, d, dt: create_inputs_with_padding(batch_size, d, dt, height, width, text_seq_len)),
+    ]
+
+    for scenario_name, input_fn in scenarios:
+        print(f"\nBenchmarking: {scenario_name}")
+        print("-" * 80)
+
+        flush()
+        model, device, dtype = get_model()
+        inputs = input_fn(model, device, dtype)
+
+        # Latency
+        print("  Measuring latency...")
+        latency = measure_latency(model, inputs, num_warmup=5, num_runs=50)
+
+        # Memory
+        print("  Measuring memory...")
+        memory = measure_memory(model, inputs)
+
+        # Throughput
+        print("  Measuring throughput...")
+        throughput = benchmark_throughput(model, inputs, duration_seconds=10)
+
+        result = {
+            "Scenario": scenario_name,
+            "Batch Size": batch_size,
+            "Latency (ms)": f"{latency['mean_ms']:.2f} ± {latency['std_ms']:.2f}",
+            "Latency Mean (ms)": latency["mean_ms"],
+            "Latency Std (ms)": latency["std_ms"],
+            "Min Latency (ms)": latency["min_ms"],
+            "Max Latency (ms)": latency["max_ms"],
+            "Peak Memory (MB)": memory["peak_memory_mb"],
+            "Throughput (iter/s)": throughput["iterations_per_sec"],
+        }
+
+        results.append(result)
+        print(f"  Mean latency: {latency['mean_ms']:.2f} ms (± {latency['std_ms']:.2f})")
+        print(f"  Peak memory: {memory['peak_memory_mb']:.1f} MB")
+        print(f"  Throughput: {throughput['iterations_per_sec']:.2f} iter/s")
+
+        del model
+        flush()
+
+    # Create DataFrame and save
+    df = pd.DataFrame(results)
+
+    print("\n" + "=" * 80)
+    print("BENCHMARK RESULTS SUMMARY")
+    print("=" * 80)
+    print(df[["Scenario", "Latency (ms)", "Peak Memory (MB)", "Throughput (iter/s)"]].to_string(index=False))
+
+    # Calculate overhead (requires all three scenarios)
+    if len(results) >= 3:
+        baseline_latency = results[0]["Latency Mean (ms)"]
+        mask_no_padding_latency = results[1]["Latency Mean (ms)"]
+        mask_with_padding_latency = results[2]["Latency Mean (ms)"]
+
+        overhead_no_padding = ((mask_no_padding_latency / baseline_latency) - 1) * 100
+        overhead_with_padding = ((mask_with_padding_latency / baseline_latency) - 1) * 100
+
+        print("\n" + "=" * 80)
+        print("PERFORMANCE OVERHEAD ANALYSIS")
+        print("=" * 80)
+        print(f"Mask overhead (no padding): {overhead_no_padding:+.2f}%")
+        print(f"Mask overhead (with padding): {overhead_with_padding:+.2f}%")
+
+        if abs(overhead_no_padding) < 5:
+            print("Negligible overhead for mask processing")
+        elif overhead_no_padding < 0:
+            print("Actually faster with mask (optimization opportunity)")
+        else:
+            print(f"WARNING: {overhead_no_padding:.1f}% overhead when using masks")
+
+    # Save to CSV
+    csv_filename = "qwenimage.csv"
+    df.to_csv(csv_filename, index=False)
+    print(f"\nResults saved to: {csv_filename}")
+
+    return df
+
+
+if __name__ == "__main__":
+    df = run_benchmark_suite()
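Side note on the benchmark script above: it imports `torch.utils.benchmark as benchmark` but only uses manual `time.perf_counter()` timing. A minimal sketch, not part of the patch, of how the same latency measurement could be expressed with `benchmark.Timer` (assuming the `model` and `inputs` produced by `get_model()` / `create_inputs_no_mask()` above):

```python
import torch
import torch.utils.benchmark as benchmark


def measure_latency_with_timer(model, inputs, num_runs=50):
    # Sketch only: Timer takes care of CUDA synchronization when timing GPU code.
    timer = benchmark.Timer(
        stmt="with torch.no_grad(): model(**inputs)",
        globals={"model": model, "inputs": inputs, "torch": torch},
    )
    measurement = timer.timeit(num_runs)
    return {"mean_ms": measurement.mean * 1000, "median_ms": measurement.median * 1000}
```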
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index c0fa031b9faf..f23e20b1d855 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -330,6 +330,31 @@ def __call__(
         joint_key = torch.cat([txt_key, img_key], dim=1)
         joint_value = torch.cat([txt_value, img_value], dim=1)
 
+        # Convert encoder_hidden_states_mask to a pairwise attention mask if provided.
+        if encoder_hidden_states_mask is not None and attention_mask is None:
+            batch_size = hidden_states.shape[0]
+            image_seq_len = hidden_states.shape[1]
+            text_seq_len = encoder_hidden_states.shape[1]
+
+            if encoder_hidden_states_mask.shape[0] != batch_size:
+                raise ValueError(
+                    f"encoder_hidden_states_mask batch size ({encoder_hidden_states_mask.shape[0]}) "
+                    f"must match hidden_states batch size ({batch_size})"
+                )
+            if encoder_hidden_states_mask.shape[1] != text_seq_len:
+                raise ValueError(
+                    f"encoder_hidden_states_mask sequence length ({encoder_hidden_states_mask.shape[1]}) "
+                    f"must match encoder_hidden_states sequence length ({text_seq_len})"
+                )
+
+            text_attention_mask = encoder_hidden_states_mask.bool()
+            image_attention_mask = torch.ones(
+                (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device
+            )
+
+            joint_attention_mask_1d = torch.cat([text_attention_mask, image_attention_mask], dim=1)
+            attention_mask = joint_attention_mask_1d[:, None, None, :] * joint_attention_mask_1d[:, None, :, None]
+
         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
             joint_query,
@@ -630,7 +655,15 @@ def forward(
             else self.time_text_embed(timestep, guidance, hidden_states)
         )
 
-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        # Use the padded sequence length for RoPE when a mask is present.
+        # The attention mask will handle excluding padding tokens.
+        if encoder_hidden_states_mask is not None:
+            txt_seq_lens_for_rope = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
+        else:
+            txt_seq_lens_for_rope = (
+                txt_seq_lens if txt_seq_lens is not None else [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
+            )
+        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens_for_rope, device=hidden_states.device)
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
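To make the processor change above easier to follow: the per-token text mask is joined with an all-ones image mask and expanded into a pairwise attention mask by an outer product, so padded text tokens can neither attend nor be attended to. A small standalone illustration with toy shapes (not part of the patch):

```python
import torch

# Toy example: 3 text tokens (last one is padding), 4 image tokens, batch size 1.
encoder_hidden_states_mask = torch.tensor([[1, 1, 0]])       # [B, text_seq_len]
image_attention_mask = torch.ones((1, 4), dtype=torch.bool)  # image tokens are never padded

joint_1d = torch.cat([encoder_hidden_states_mask.bool(), image_attention_mask], dim=1)  # [B, 7]
# Outer product -> [B, 1, 7, 7]: position (i, j) is True only when both tokens are real.
attention_mask = joint_1d[:, None, None, :] * joint_1d[:, None, :, None]
print(attention_mask.shape)  # torch.Size([1, 1, 7, 7])
```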
diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index b24fa90503ef..352037aa0534 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -91,6 +91,124 @@ def test_gradient_checkpointing_is_applied(self):
         expected_set = {"QwenImageTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
+    def test_attention_mask_with_padding(self):
+        """Test that encoder_hidden_states_mask properly handles padded sequences."""
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device).eval()
+
+        batch_size = 2
+        height = width = 4
+        num_latent_channels = embedding_dim = 16
+        text_seq_len = 7
+        vae_scale_factor = 4
+
+        # Create inputs with padding
+        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, text_seq_len, embedding_dim)).to(torch_device)
+
+        # First sample: 5 real tokens, 2 padding
+        # Second sample: 3 real tokens, 4 padding
+        encoder_hidden_states_mask = torch.tensor(
+            [[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]], dtype=torch.long
+        ).to(torch_device)
+
+        # Zero out padding in embeddings
+        encoder_hidden_states = encoder_hidden_states * encoder_hidden_states_mask.unsqueeze(-1).float()
+
+        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+        orig_height = height * 2 * vae_scale_factor
+        orig_width = width * 2 * vae_scale_factor
+        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
+        txt_seq_lens = encoder_hidden_states_mask.sum(dim=1).tolist()
+
+        inputs_with_mask = {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": txt_seq_lens,
+        }
+
+        # Run with proper mask
+        with torch.no_grad():
+            output_with_mask = model(**inputs_with_mask).sample
+
+        # Run with all-ones mask (treating padding as real tokens)
+        inputs_without_mask = {
+            "hidden_states": hidden_states.clone(),
+            "encoder_hidden_states": encoder_hidden_states.clone(),
+            "encoder_hidden_states_mask": torch.ones_like(encoder_hidden_states_mask),
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": [text_seq_len] * batch_size,
+        }
+
+        with torch.no_grad():
+            output_without_mask = model(**inputs_without_mask).sample
+
+        # Outputs should differ when the mask is applied correctly
+        diff = (output_with_mask - output_without_mask).abs().mean().item()
+        assert diff > 1e-5, f"Mask appears to be ignored (diff={diff})"
+
+    def test_attention_mask_padding_isolation(self):
+        """Test that changing padding content doesn't affect output when the mask is used."""
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device).eval()
+
+        batch_size = 2
+        height = width = 4
+        num_latent_channels = embedding_dim = 16
+        text_seq_len = 7
+        vae_scale_factor = 4
+
+        # Create inputs
+        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, text_seq_len, embedding_dim)).to(torch_device)
+        encoder_hidden_states_mask = torch.tensor(
+            [[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]], dtype=torch.long
+        ).to(torch_device)
+
+        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+        orig_height = height * 2 * vae_scale_factor
+        orig_width = width * 2 * vae_scale_factor
+        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
+        txt_seq_lens = encoder_hidden_states_mask.sum(dim=1).tolist()
+
+        inputs1 = {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": txt_seq_lens,
+        }
+
+        with torch.no_grad():
+            output1 = model(**inputs1).sample
+
+        # Modify padding content with large noise
+        encoder_hidden_states2 = encoder_hidden_states.clone()
+        mask = encoder_hidden_states_mask.unsqueeze(-1).float()
+        noise = torch.randn_like(encoder_hidden_states2) * 10.0
+        encoder_hidden_states2 = encoder_hidden_states2 + noise * (1 - mask)
+
+        inputs2 = {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states2,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": txt_seq_lens,
+        }
+
+        with torch.no_grad():
+            output2 = model(**inputs2).sample
+
+        # Outputs should be nearly identical (padding is masked out)
+        diff = (output1 - output2).abs().mean().item()
+        assert diff < 1e-4, f"Padding content affected output (diff={diff})"
+
 
 class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = QwenImageTransformer2DModel
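For completeness, a minimal forward-pass sketch exercising the new `encoder_hidden_states_mask` argument, using the tiny configuration from the benchmark's `get_model()` and shapes mirroring the tests above (an illustration, not part of the patch):

```python
import torch

from diffusers import QwenImageTransformer2DModel

# Tiny configuration copied from get_model() in the benchmark above.
model = QwenImageTransformer2DModel(
    patch_size=2, in_channels=16, out_channels=4, num_layers=2,
    attention_head_dim=16, num_attention_heads=3, joint_attention_dim=16,
    guidance_embeds=False, axes_dims_rope=(8, 4, 4),
).eval()

batch_size, text_seq_len = 2, 7
hidden_states = torch.randn(batch_size, 16, 16)  # 4x4 latent grid, 16 channels
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16)
# First sample: 5 real tokens, 2 padding; second sample: 3 real tokens, 4 padding.
mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0]])

with torch.no_grad():
    sample = model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        encoder_hidden_states_mask=mask,
        timestep=torch.tensor([1.0, 1.0]),
        img_shapes=[(1, 4, 4)] * batch_size,
        txt_seq_lens=mask.sum(dim=1).tolist(),
    ).sample

print(sample.shape)  # [batch, image tokens, patch_size**2 * out_channels]
```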