src/diffusers/models/transformers/transformer_qwenimage.py: 76 additions & 16 deletions

@@ -160,23 +160,31 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         super().__init__()
         self.theta = theta
         self.axes_dim = axes_dim
-        pos_index = torch.arange(1024)
-        neg_index = torch.arange(1024).flip(0) * -1 - 1
-        self.pos_freqs = torch.cat(
-            [
-                self.rope_params(pos_index, self.axes_dim[0], self.theta),
-                self.rope_params(pos_index, self.axes_dim[1], self.theta),
-                self.rope_params(pos_index, self.axes_dim[2], self.theta),
-            ],
-            dim=1,
+        # Initialize with default size 1024, but allow dynamic expansion
+        self._current_max_len = 1024
+        pos_index = torch.arange(self._current_max_len)
+        neg_index = torch.arange(self._current_max_len).flip(0) * -1 - 1
+        self.register_buffer(
+            "pos_freqs",
+            torch.cat(
+                [
+                    self.rope_params(pos_index, self.axes_dim[0], self.theta),
+                    self.rope_params(pos_index, self.axes_dim[1], self.theta),
+                    self.rope_params(pos_index, self.axes_dim[2], self.theta),
+                ],
+                dim=1,
+            ),
         )
-        self.neg_freqs = torch.cat(
-            [
-                self.rope_params(neg_index, self.axes_dim[0], self.theta),
-                self.rope_params(neg_index, self.axes_dim[1], self.theta),
-                self.rope_params(neg_index, self.axes_dim[2], self.theta),
-            ],
-            dim=1,
+        self.register_buffer(
+            "neg_freqs",
+            torch.cat(
+                [
+                    self.rope_params(neg_index, self.axes_dim[0], self.theta),
+                    self.rope_params(neg_index, self.axes_dim[1], self.theta),
+                    self.rope_params(neg_index, self.axes_dim[2], self.theta),
+                ],
+                dim=1,
+            ),
         )
         self.rope_cache = {}

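For context on the switch from plain tensor attributes to register_buffer: buffers travel with the module on .to(device) calls and take no gradients, which is exactly what fixed, precomputed RoPE tables want. A minimal, self-contained sketch of that behavior (the RopeTable module is hypothetical, not from this PR):

import torch

class RopeTable(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical frequency table, standing in for pos_freqs above.
        self.register_buffer("freqs", torch.randn(1024, 64))

table = RopeTable()
table.to("cpu")            # buffers follow the module across devices, unlike plain attributes
print(table.freqs.device)  # -> cpu

Note that register_buffer is persistent by default, so the tables will also appear in the module's state_dict unless persistent=False is passed.
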
@@ -193,6 +201,53 @@ def rope_params(self, index, dim, theta=10000):
         freqs = torch.polar(torch.ones_like(freqs), freqs)
         return freqs

+    def _expand_pos_freqs_if_needed(self, required_len):
+        """Expand pos_freqs and neg_freqs if the required length exceeds the current size."""
+        if required_len <= self._current_max_len:
+            return
+
+        # Round the new size up to the next multiple of 512 for efficiency
+        new_max_len = ((required_len + 511) // 512) * 512
+
+        # Warn about potential quality degradation for long prompts
+        if required_len > 512:
+            logger.warning(
+                "The QwenImage model was trained on prompts of up to 512 tokens. "
+                f"The current prompt requires {required_len} tokens, which may lead to unpredictable behavior. "
+                "Consider using shorter prompts for better results."
+            )
+
+        # Generate expanded indices
+        pos_index = torch.arange(new_max_len, device=self.pos_freqs.device)
+        neg_index = torch.arange(new_max_len, device=self.neg_freqs.device).flip(0) * -1 - 1
+
+        # Generate expanded frequency embeddings
+        new_pos_freqs = torch.cat(
+            [
+                self.rope_params(pos_index, self.axes_dim[0], self.theta),
+                self.rope_params(pos_index, self.axes_dim[1], self.theta),
+                self.rope_params(pos_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        ).to(device=self.pos_freqs.device, dtype=self.pos_freqs.dtype)
+
+        new_neg_freqs = torch.cat(
+            [
+                self.rope_params(neg_index, self.axes_dim[0], self.theta),
+                self.rope_params(neg_index, self.axes_dim[1], self.theta),
+                self.rope_params(neg_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        ).to(device=self.neg_freqs.device, dtype=self.neg_freqs.dtype)
+
+        # Re-register the buffers at the new size
+        self.register_buffer("pos_freqs", new_pos_freqs)
+        self.register_buffer("neg_freqs", new_neg_freqs)
+        self._current_max_len = new_max_len
+
+        # Clear the RoPE cache since the cached frequencies no longer match
+        self.rope_cache = {}

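The growth policy above rounds the required length up to the next multiple of 512, so a series of slightly longer prompts does not force a re-computation on every call. A standalone sketch of the arithmetic (the helper name is illustrative, not part of the PR):

def round_up_to_multiple(n: int, step: int = 512) -> int:
    """Smallest multiple of `step` that is >= n."""
    return ((n + step - 1) // step) * step

assert round_up_to_multiple(1025) == 1536
assert round_up_to_multiple(1536) == 1536
assert round_up_to_multiple(2000) == 2048
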
     def forward(self, video_fhw, txt_seq_lens, device):
         """
         Args:
             video_fhw: [frame, height, width], a list of 3 integers representing the shape of the video
@@ -232,6 +287,11 @@ def forward(self, video_fhw, txt_seq_lens, device):
         max_vid_index = max(height, width)

         max_len = max(txt_seq_lens)
+
+        # Expand pos_freqs if needed to accommodate max_vid_index + max_len
+        required_len = max_vid_index + max_len
+        self._expand_pos_freqs_if_needed(required_len)
+
         txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]

         return vid_freqs, txt_freqs
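
Assuming the embedding class defined in this file (QwenEmbedRope in diffusers) and the [frame, height, width] convention from the docstring above, usage after this change might look like the following rough sketch; the theta, axes_dim, and scale_rope values are illustrative, not taken from the PR:

import torch

rope = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)  # illustrative config

# One image (a single-frame "video") of 64x64 latent patches and a 1200-token prompt:
vid_freqs, txt_freqs = rope(video_fhw=[1, 64, 64], txt_seq_lens=[1200], device="cpu")

# required_len = max(64, 64) + 1200 = 1264 > 1024, so before this PR the slice
# self.pos_freqs[64 : 64 + 1200] came up short and downstream shapes mismatched;
# now the tables expand to 1536 (the next multiple of 512) and a warning is logged.
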
tests/pipelines/qwenimage/test_qwenimage.py: 77 additions & 1 deletion

@@ -24,7 +24,7 @@
     QwenImagePipeline,
     QwenImageTransformer2DModel,
 )
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from diffusers.utils.testing_utils import CaptureLogger, enable_full_determinism, torch_device

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -234,3 +234,79 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
             expected_diff_max,
             "VAE tiling should not affect the inference results",
         )

+    def test_long_prompt_no_error(self):
+        """Long prompts should not cause dimension mismatch errors (issue #12083)."""
+        device = torch_device
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+
+        # A long prompt that approaches the limit without triggering the expansion warning,
+        # so it exercises the original issue's fix in isolation.
+        phrase = "A beautiful, detailed, high-resolution, photorealistic image showing "
+        long_prompt = phrase * 40  # ~800 tokens, well within limits
+
+        # Verify the token count so the test documents its own assumptions
+        tokenizer = components["tokenizer"]
+        token_count = len(tokenizer.encode(long_prompt))
+        required_len = 32 + token_count  # max(height, width) + text tokens
+        # Large enough to exercise the fix, small enough to avoid the expansion warning
+        self.assertGreater(token_count, 500, f"Test prompt should be substantial (got {token_count} tokens)")
+        self.assertLess(required_len, 1024, f"Test should stay within limits (got {required_len})")
+
+        inputs = {
+            "prompt": long_prompt,
+            "generator": torch.Generator(device=device).manual_seed(0),
+            "num_inference_steps": 2,
+            "guidance_scale": 3.0,
+            "true_cfg_scale": 1.0,
+            "height": 32,  # small size for a fast test
+            "width": 32,  # small size for a fast test
+            "max_sequence_length": 1024,  # allow the longest permitted sequence
+            "output_type": "pt",
+        }
+
+        # Must not raise a RuntimeError about tensor dimension mismatch
+        _ = pipe(**inputs)
+
+    def test_long_prompt_warning(self):
+        """Long prompts should trigger a warning about the 512-token training limitation."""
+        from diffusers.utils import logging
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+
+        # The warning fires when required_len = max(height, width) + text tokens exceeds
+        # _current_max_len (1024). With height = width = 32 we need more than 992 tokens.
+        phrase = "A detailed photorealistic image showing many beautiful elements and complex artistic creative features with intricate designs."
+        long_prompt = phrase * 58  # ~1045 tokens, ensuring required_len > 1024
+
+        # Verify we actually exceed the threshold (for test robustness)
+        tokenizer = components["tokenizer"]
+        token_count = len(tokenizer.encode(long_prompt))
+        required_len = 32 + token_count  # max(height, width) + text tokens
+        self.assertGreater(required_len, 1024, f"Test prompt must exceed threshold (got {required_len})")
+
+        # Capture the transformer's logging output
+        logger = logging.get_logger("diffusers.models.transformers.transformer_qwenimage")
+        logger.setLevel(logging.WARNING)
+
+        with CaptureLogger(logger) as cap_logger:
+            _ = pipe(
+                prompt=long_prompt,
+                generator=torch.Generator(device=torch_device).manual_seed(0),
+                num_inference_steps=2,
+                guidance_scale=3.0,
+                true_cfg_scale=1.0,
+                height=32,  # small size for a fast test
+                width=32,  # small size for a fast test
+                max_sequence_length=1024,  # allow a long sequence
+                output_type="pt",
+            )
+
+        # The warning should mention the 512-token training limitation
+        self.assertIn("512 tokens", cap_logger.out)
+        self.assertIn("unpredictable behavior", cap_logger.out)