@@ -131,8 +131,7 @@

>>> prompt = "A robot, 4k photo"
>>> image = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/cat.png"
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
... ).resize((1024, 1024))
>>> controlnet_conditioning_scale = 0.5 # recommended for good generalization
>>> depth_image = get_depth_map(image)
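All of the documentation hunks in this PR make the same change: two adjacent string literals, which Python implicitly concatenates, are collapsed into one literal. A minimal sketch of the equivalence (literals copied from the hunk above):

```python
# Adjacent string literals are joined at parse time, so both forms build
# the identical URL; the single literal is simply easier to scan.
split_url = (
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/kandinsky/cat.png"
)
joined_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
assert split_url == joined_url
```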
@@ -88,9 +88,7 @@ def __init__(self, *args, **kwargs):
>>> prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
>>> video = load_video(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
-... )[
-... :21
-... ] # This example uses only the first 21 frames
+... )[:21]  # This example uses only the first 21 frames

>>> video = pipe(video=video, prompt=prompt).frames[0]
>>> export_to_video(video, "output.mp4", fps=30)
@@ -98,7 +98,7 @@
negative_prompt = "low quality, bad quality"
original_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
mask = np.zeros((768, 768), dtype=np.float32)
@@ -60,8 +60,7 @@
>>> pipe.to("cuda")
>>> init_image = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/frog.png"
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/frog.png"
... )
>>> image = pipe(
@@ -66,8 +66,7 @@
>>> pipe.to("cuda")

>>> init_image = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/cat.png"
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
... )

>>> mask = np.zeros((768, 768), dtype=np.float32)
@@ -88,8 +88,7 @@
>>> pipe_prior.to("cuda")

>>> img1 = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/cat.png"
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
... )

>>> img2 = load_image(
@@ -92,7 +92,7 @@
negative_prompt = "low quality, bad quality"

original_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)

mask = np.zeros((768, 768), dtype=np.float32)
@@ -71,8 +71,7 @@


>>> img = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/cat.png"
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
... ).resize((768, 768))

>>> hint = make_hint(img, depth_estimator).unsqueeze(0).half().to("cuda")
@@ -72,8 +72,7 @@
>>> pipe = pipe.to("cuda")

>>> img = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/cat.png"
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
... ).resize((768, 768))


@@ -56,8 +56,7 @@
>>> pipe.to("cuda")

>>> init_image = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/frog.png"
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/frog.png"
... )

>>> image = pipe(
@@ -62,8 +62,7 @@
>>> pipe.to("cuda")

>>> init_image = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/cat.png"
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
... )

>>> mask = np.zeros((768, 768), dtype=np.float32)
@@ -60,8 +60,7 @@
... )
>>> pipe_prior.to("cuda")
>>> img1 = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/cat.png"
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
... )
>>> img2 = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
@@ -35,8 +35,7 @@

>>> prompt = "red cat, 4k photo"
>>> img = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/cat.png"
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
... )
>>> image_emb, negative_image_emb = pipe_prior(prompt, image=img, strength=0.2).to_tuple()

@@ -73,8 +72,7 @@
>>> pipe_prior.to("cuda")

>>> img1 = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/cat.png"
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
... )

>>> img2 = load_image(
@@ -132,8 +132,7 @@

>>> prompt = "A robot, 4k photo"
>>> image = load_image(
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
... "/kandinsky/cat.png"
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
... ).resize((1024, 1024))
>>> controlnet_conditioning_scale = 0.5 # recommended for good generalization
>>> depth_image = get_depth_map(image)
111 changes: 99 additions & 12 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -283,7 +283,6 @@ def _get_qwen_prompt_embeds(

        return prompt_embeds, encoder_attention_mask

-    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
@@ -295,7 +294,6 @@ def encode_prompt(
        max_sequence_length: int = 1024,
    ):
        r"""
-
        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
@@ -311,17 +309,59 @@
"""
device = device or self._execution_device

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
# [Fix] Loop over prompts to avoid Qwen2VLProcessor batching bugs & IndexError
if isinstance(prompt, list) and len(prompt) > 1:
prompt_embeds_list = []
mask_list = []

# Normalize images to a list matching the prompt length
if isinstance(image, list):
current_images = image
else:
current_images = [image] * len(prompt)

for i, single_prompt in enumerate(prompt):
# Safety: Ensure we have an image for this prompt
single_image = current_images[i] if i < len(current_images) else current_images[0]

pe, pem = self._get_qwen_prompt_embeds(single_prompt, image=single_image, device=device)
prompt_embeds_list.append(pe)
mask_list.append(pem)

# [Fix] Pad embeddings to the maximum length in the batch before stacking
max_len = max([p.shape[1] for p in prompt_embeds_list])

padded_embeds = []
padded_masks = []

for pe, pem in zip(prompt_embeds_list, mask_list):
cur_len = pe.shape[1]
pad_len = max_len - cur_len

if pad_len > 0:
# Pad sequence dim (2nd last dim for embeds, last dim for mask)
pe = torch.nn.functional.pad(pe, (0, 0, 0, pad_len))
pem = torch.nn.functional.pad(pem, (0, pad_len))

padded_embeds.append(pe)
padded_masks.append(pem)

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
prompt_embeds = torch.cat(padded_embeds, dim=0)
prompt_embeds_mask = torch.cat(padded_masks, dim=0)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
else:
# Standard path for single prompt
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

return prompt_embeds, prompt_embeds_mask
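The pad-then-concatenate step above is the core of this batching fix. A standalone sketch of the same pattern, assuming per-prompt embeddings of shape (1, seq_len, dim) and masks of shape (1, seq_len) as returned by the encoder:

```python
import torch
import torch.nn.functional as F

# Two prompts encoded to different sequence lengths (shapes are assumptions).
embeds = [torch.randn(1, 7, 16), torch.randn(1, 11, 16)]
masks = [torch.ones(1, 7, dtype=torch.long), torch.ones(1, 11, dtype=torch.long)]

max_len = max(e.shape[1] for e in embeds)
padded_e, padded_m = [], []
for e, m in zip(embeds, masks):
    pad_len = max_len - e.shape[1]
    # F.pad pads from the last dim backwards: (dim_left, dim_right, seq_left, seq_right),
    # so this appends pad_len zero vectors along the sequence dimension.
    padded_e.append(F.pad(e, (0, 0, 0, pad_len)))
    padded_m.append(F.pad(m, (0, pad_len)))  # zeros mark the padding as non-attended

prompt_embeds = torch.cat(padded_e, dim=0)       # (2, 11, 16)
prompt_embeds_mask = torch.cat(padded_m, dim=0)  # (2, 11)
print(prompt_embeds.shape, prompt_embeds_mask.shape)
```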

@@ -627,7 +667,24 @@ def __call__(
            [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is a list with the generated images.
        """
-        image_size = image[-1].size if isinstance(image, list) else image.size
+        # [Fix] Robustly determine image size (Handles Lists, Tensors, and PIL)
+        if isinstance(image, list):
+            # Grab the first valid image to determine dimensions
+            check_img = image[0]
+            # Handle potential nested lists (e.g. if batching logic gets complex)
+            while isinstance(check_img, (list, tuple)):
+                check_img = check_img[0]
+
+            if isinstance(check_img, torch.Tensor):
+                # Tensor shape is usually (C, H, W) or (B, C, H, W) -> take last two dims
+                image_size = (check_img.shape[-1], check_img.shape[-2])
+            else:
+                image_size = check_img.size
+        elif isinstance(image, torch.Tensor):
+            image_size = (image.shape[-1], image.shape[-2])
+        else:
+            image_size = image.size
+
        calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
        height = height or calculated_height
        width = width or calculated_width
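The new branch replaces a PIL-only `.size` lookup with a probe that also accepts tensors and nested lists. A minimal standalone version of that logic, assuming PIL's `.size` is `(width, height)` and tensors keep spatial dims last as `(..., H, W)`; `probe_size` is a hypothetical helper name:

```python
from typing import Tuple, Union

import torch
from PIL import Image


def probe_size(img: Union[Image.Image, torch.Tensor, list, tuple]) -> Tuple[int, int]:
    # Unwrap (possibly nested) lists down to the first concrete image.
    while isinstance(img, (list, tuple)):
        img = img[0]
    if isinstance(img, torch.Tensor):
        # (..., H, W) -> report (W, H) to match PIL's convention.
        return (img.shape[-1], img.shape[-2])
    return img.size  # PIL already reports (W, H)


print(probe_size(torch.zeros(1, 3, 512, 768)))       # (768, 512)
print(probe_size([[Image.new("RGB", (640, 480))]]))  # (640, 480)
```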
@@ -668,10 +725,12 @@ def __call__(
        if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
            if not isinstance(image, list):
                image = [image]
+
            condition_image_sizes = []
            condition_images = []
            vae_image_sizes = []
            vae_images = []
+
            for img in image:
                image_width, image_height = img.size
                condition_width, condition_height = calculate_dimensions(
@@ -681,8 +740,36 @@
                condition_image_sizes.append((condition_width, condition_height))
                vae_image_sizes.append((vae_width, vae_height))
                condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
+
+                # [5D Fix] Ensure (B, C, F, H, W)
                vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+
+            # [FIX] Handle Batch vs Multi-Condition Ambiguity + Variable Resolutions
+            if isinstance(prompt, list) and len(prompt) > 1 and len(vae_images) == len(prompt):
+                # 1. Find max dims (Height=[-2], Width=[-1])
+                max_h = max(img.shape[-2] for img in vae_images)
+                max_w = max(img.shape[-1] for img in vae_images)
+
+                padded_images = []
+                for img in vae_images:
+                    h, w = img.shape[-2], img.shape[-1]
+                    pad_h = max_h - h
+                    pad_w = max_w - w
+                    if pad_h > 0 or pad_w > 0:
+                        # Pad (left, right, top, bottom)
+                        img = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h))
+                    padded_images.append(img)
+
+                # 2. 1-to-1 Batching
+                batch_tensor = torch.cat(padded_images, dim=0)
+                vae_images = [batch_tensor]
+
+                # 3. [FIX] Update metadata to match padded dims - Rotary Positional Embeddings
+                # We must tell the model that each batch item has exactly 1 condition image with the new padded dimensions.
+                height = max_h
+                width = max_w
+                vae_image_sizes = [(max_w, max_h)]

        has_neg_prompt = negative_prompt is not None or (
            negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
        )
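The resolution-alignment block pads every conditioning latent to the batch's maximum height and width before stacking them 1-to-1 with the prompts. A small sketch under the same `(B, C, F, H, W)` layout assumed by the code above:

```python
import torch
import torch.nn.functional as F

# Two condition images preprocessed to different resolutions (shapes are assumptions).
vae_images = [torch.randn(1, 4, 1, 64, 48), torch.randn(1, 4, 1, 56, 64)]

max_h = max(t.shape[-2] for t in vae_images)
max_w = max(t.shape[-1] for t in vae_images)

padded = []
for t in vae_images:
    pad_w = max_w - t.shape[-1]
    pad_h = max_h - t.shape[-2]
    # For the last two dims, F.pad takes (left, right, top, bottom);
    # padding only right/bottom keeps the original content anchored top-left.
    padded.append(F.pad(t, (0, pad_w, 0, pad_h)))

batch = torch.cat(padded, dim=0)
print(batch.shape)  # torch.Size([2, 4, 1, 64, 64])
```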