From d91aa18fa7ad1b775910414e4376c31ffd960350 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 1 Oct 2025 11:38:25 +0530
Subject: [PATCH 1/9] up

---
 .../modular_pipelines/qwenimage/__init__.py   |  12 +-
 .../qwenimage/before_denoise.py               |  46 +++-
 .../modular_pipelines/qwenimage/encoders.py   | 223 +++++++++++++++++-
 .../qwenimage/modular_blocks.py               | 175 +++++++++++++-
 .../qwenimage/modular_pipeline.py             |  10 +
 5 files changed, 462 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/modular_pipelines/qwenimage/__init__.py b/src/diffusers/modular_pipelines/qwenimage/__init__.py
index 81cf515730ef..c0f78366229a 100644
--- a/src/diffusers/modular_pipelines/qwenimage/__init__.py
+++ b/src/diffusers/modular_pipelines/qwenimage/__init__.py
@@ -29,11 +29,14 @@
         "EDIT_AUTO_BLOCKS",
         "EDIT_BLOCKS",
         "EDIT_INPAINT_BLOCKS",
+        "EDIT_PLUS_AUTO_BLOCKS",
+        "EDIT_PLUS_BLOCKS",
         "IMAGE2IMAGE_BLOCKS",
         "INPAINT_BLOCKS",
         "TEXT2IMAGE_BLOCKS",
         "QwenImageAutoBlocks",
         "QwenImageEditAutoBlocks",
+        "QwenImageEditPlusAutoBlocks",
     ]
     _import_structure["modular_pipeline"] = ["QwenImageEditModularPipeline", "QwenImageModularPipeline"]
@@ -54,13 +57,20 @@
         EDIT_AUTO_BLOCKS,
         EDIT_BLOCKS,
         EDIT_INPAINT_BLOCKS,
+        EDIT_PLUS_AUTO_BLOCKS,
+        EDIT_PLUS_BLOCKS,
         IMAGE2IMAGE_BLOCKS,
         INPAINT_BLOCKS,
         TEXT2IMAGE_BLOCKS,
         QwenImageAutoBlocks,
         QwenImageEditAutoBlocks,
+        QwenImageEditPlusAutoBlocks,
+    )
+    from .modular_pipeline import (
+        QwenImageEditModularPipeline,
+        QwenImageEditPlusModularPipeline,
+        QwenImageModularPipeline,
     )
-    from .modular_pipeline import QwenImageEditModularPipeline, QwenImageModularPipeline
 else:
     import sys

diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
index 606236cfe91b..2fcbb42a32ef 100644
--- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
+++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
@@ -571,7 +571,7 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
 
     @property
     def description(self) -> str:
-        return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be place after prepare_latents step"
+        return "Step that prepares the RoPE inputs for the denoising process. This is used in QwenImage Edit. Should be placed after the prepare_latents step."
 
     @property
     def inputs(self) -> List[InputParam]:
@@ -641,6 +641,50 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
     return components, state
 
 
+class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
+    model_name = "qwenimage"
+    # TODO: Is there a better way to handle this name? It's used in
+    # `QwenImageEditPlusResizeDynamicStep` as well. We can later
+    # keep these things as a module-level constant.
+    _image_size_output_name = "image_sizes"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        inputs_list = super().inputs
+        return inputs_list + [
+            InputParam(name=self._image_size_output_name, required=True),
+        ]
+
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        vae_image_sizes = getattr(block_state, self._image_size_output_name)
+        height, width = block_state.image_height, block_state.image_width
+
+        # for edit, image size can be different from the target size (height/width)
+        block_state.img_shapes = [
+            [
+                (1, height // components.vae_scale_factor // 2, width // components.vae_scale_factor // 2),
+                *[
+                    (1, vae_height // components.vae_scale_factor // 2, vae_width // components.vae_scale_factor // 2)
+                    for vae_width, vae_height in vae_image_sizes
+                ],
+            ]
+        ] * block_state.batch_size
+
+        block_state.txt_seq_lens = (
+            block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
+        )
+        block_state.negative_txt_seq_lens = (
+            block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
+            if block_state.negative_prompt_embeds_mask is not None
+            else None
+        )
+
+        self.set_block_state(state, block_state)
+
+        return components, state
+
+
 ## ControlNet inputs for denoiser
 class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
     model_name = "qwenimage"
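A quick sanity check on the `img_shapes` arithmetic in the new RoPE step: the VAE downsamples by `vae_scale_factor` and the latents are then packed 2x2, so each image contributes a `(1, H // vae_scale_factor // 2, W // vae_scale_factor // 2)` grid entry. A minimal standalone sketch, assuming the usual QwenImage `vae_scale_factor` of 8 (the helper name below is illustrative, not part of this diff):

    # each entry is (frames, packed_height, packed_width), as consumed by RoPE
    def packed_grid(height: int, width: int, vae_scale_factor: int = 8) -> tuple:
        return (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2)

    # one entry for the target canvas plus one per reference image, repeated per batch item
    vae_image_sizes = [(1024, 1024), (832, 1248)]  # (width, height) pairs
    img_shapes = [[packed_grid(1024, 1024)] + [packed_grid(h, w) for w, h in vae_image_sizes]] * 2

    print(img_shapes[0])  # [(1, 64, 64), (1, 64, 64), (1, 78, 52)]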
diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py
index 2ab83a03ee55..a49137eae4e4 100644
--- a/src/diffusers/modular_pipelines/qwenimage/encoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import PIL
 import torch
@@ -128,6 +128,63 @@ def get_qwen_prompt_embeds_edit(
     return prompt_embeds, encoder_attention_mask
 
 
+def get_qwen_prompt_embeds_edit_plus(
+    text_encoder,
+    processor,
+    prompt: Union[str, List[str]] = None,
+    image: Optional[Union[torch.Tensor, List[PIL.Image.Image], [PIL.Image.Image]]] = None,
+    prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+    img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+    prompt_template_encode_start_idx: int = 64,
+    device: Optional[torch.device] = None,
+):
+    prompt = [prompt] if isinstance(prompt, str) else prompt
+    if isinstance(image, list):
+        base_img_prompt = ""
+        for i, img in enumerate(image):
+            base_img_prompt += img_template_encode.format(i + 1)
+    elif image is not None:
+        base_img_prompt = img_template_encode.format(1)
+    else:
+        base_img_prompt = ""
+
+    template = prompt_template_encode
+
+    drop_idx = prompt_template_encode_start_idx
+    txt = [template.format(base_img_prompt + e) for e in prompt]
+
+    model_inputs = processor(
+        text=txt,
+        images=image,
+        padding=True,
+        return_tensors="pt",
+    ).to(device)
+
+    outputs = text_encoder(
+        input_ids=model_inputs.input_ids,
+        attention_mask=model_inputs.attention_mask,
+        pixel_values=model_inputs.pixel_values,
+        image_grid_thw=model_inputs.image_grid_thw,
+        output_hidden_states=True,
+    )
+
+    hidden_states = outputs.hidden_states[-1]
+    split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+    max_seq_len = max([e.size(0) for e in split_hidden_states])
+    prompt_embeds = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+    )
+    encoder_attention_mask = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+    )
+
+    prompt_embeds = prompt_embeds.to(device=device)
+
+    return prompt_embeds, encoder_attention_mask
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
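The string assembly in `get_qwen_prompt_embeds_edit_plus` is the part that differs from the single-image Edit encoder: every reference image contributes a numbered `Picture {k}:` vision segment, and the concatenation is substituted into the system template together with the user prompt. A standalone sketch of just that composition (values are illustrative):

    img_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
    prompts = ["put the cat from Picture 1 on the sofa from Picture 2"]

    # one numbered vision placeholder per reference image, exactly as in the loop above
    base_img_prompt = "".join(img_template.format(i + 1) for i in range(2))
    # `system_template` stands in for the long prompt_template_encode default above
    system_template = "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
    txt = [system_template.format(base_img_prompt + p) for p in prompts]

The first `prompt_template_encode_start_idx` (64) tokens of each encoded sequence are then dropped via `drop_idx`, so the returned embeddings begin at the image and prompt content rather than the fixed system preamble.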
+ """ + if not isinstance(input_name, str) or not isinstance(output_name, str): + raise ValueError( + f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}" + ) + self.condition_image_size = 384 * 384 + self.vae_image_size = 1024 * 1024 + self._image_input_name = input_name + self._resized_image_output_name = output_name + self._resized_image_vae_output_name = vae_image_output_name + self._image_size_output_name = "image_sizes" + super().__init__() + + @property + def description(self) -> str: + return f"Image Resize step that resize the {self._image_input_name} to the target areas of {self.condition_image_size} and {self.vae_image_size} while maintaining the aspect ratio." + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images" + ), + OutputParam( + name=self._resized_image_vae_output_name, + type_hint=List[PIL.Image.Image], + description="The resized images to be used by the VAE encoder.", + ), + OutputParam( + name=self._image_size_output_name, + type_hint=List[Tuple[int, int]], + description="Sizes of images fed to the VAE encoder. To be used with RoPE.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + images = getattr(block_state, self._image_input_name) + + if not is_valid_image_imagelist(images): + raise ValueError(f"Images must be image or list of images but are {type(images)}") + + if ( + not isinstance(images, torch.Tensor) + and isinstance(images, PIL.Image.Image) + and not isinstance(images, list) + ): + images = [images] + + # TODO: revisit this when the inputs are `torch.Tensor`s + image_width, image_height = images[-1].size + condition_images = [] + vae_image_sizes = [] + vae_images = [] + for img in images: + image_width, image_height = img.size + condition_width, condition_height, _ = calculate_dimensions( + self.condition_image_size, image_width / image_height + ) + vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, image_width / image_height) + vae_image_sizes.append((vae_width, vae_height)) + condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) + vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) + + setattr(block_state, self._resized_image_output_name, condition_images) + setattr(block_state, self._resized_image_vae_output_name, vae_images) + setattr(block_state, self._image_size_output_name, vae_image_sizes) + self.set_block_state(state, block_state) + return components, state + + class QwenImageTextEncoderStep(ModularPipelineBlocks): model_name = "qwenimage" @@ -511,6 +664,74 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep): + model_name = "qwenimage" + + @property + def description(self) -> str: + return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation.\n" + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec( + name="prompt_template_encode", + default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the 
image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + ), + ConfigSpec( + name="img_template_encode", + default='img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"', + ), + ConfigSpec(name="prompt_template_encode_start_idx", default=64), + ] + + @staticmethod + def check_inputs(prompt, negative_prompt): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + negative_prompt is not None + and not isinstance(negative_prompt, str) + and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(block_state.prompt, block_state.negative_prompt) + + device = components._execution_device + + block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus( + components.text_encoder, + components.processor, + prompt=block_state.prompt, + image=block_state.resized_image, + prompt_template_encode=components.config.prompt_template_encode, + img_template_encode=components.config.img_template_encode, + prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + device=device, + ) + + if components.requires_unconditional_embeds: + negative_prompt = block_state.negative_prompt or " " + block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit( + components.text_encoder, + components.processor, + prompt=negative_prompt, + image=block_state.resized_image, + prompt_template_encode=components.config.prompt_template_encode, + prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + device=device, + ) + + self.set_block_state(state, block_state) + return components, state + + class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks): model_name = "qwenimage" diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py index 9126766cc202..d94aeafe5744 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py @@ -18,6 +18,7 @@ from .before_denoise import ( QwenImageControlNetBeforeDenoiserStep, QwenImageCreateMaskLatentsStep, + QwenImageEditPlusRoPEInputsStep, QwenImageEditRoPEInputsStep, QwenImagePrepareLatentsStep, QwenImagePrepareLatentsWithStrengthStep, @@ -37,6 +38,8 @@ ) from .encoders import ( QwenImageControlNetVaeEncoderStep, + QwenImageEditPlusResizeDynamicStep, + QwenImageEditPlusTextEncoderStep, QwenImageEditResizeDynamicStep, QwenImageEditTextEncoderStep, QwenImageInpaintProcessImagesInputStep, @@ -872,7 +875,175 @@ def description(self): ) -# 3. all block presets supported in QwenImage & QwenImage-Edit +#################### QwenImage Edit Plus ##################### + +# 3. 
QwenImage-Edit Plus + +## 3.1 QwenImage-Edit Plus / edit + +#### QwenImage-Edit Plus vl encoder: take both image and text prompts +QwenImageEditPlusVLEncoderBlocks = InsertableDict( + [ + ("resize", QwenImageEditPlusResizeDynamicStep()), + ("encode", QwenImageEditPlusTextEncoderStep()), + ] +) + + +class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageEditPlusVLEncoderBlocks.values() + block_names = QwenImageEditPlusVLEncoderBlocks.keys() + + @property + def description(self) -> str: + return "QwenImage-Edit Plus VL encoder step that encode the image an text prompts together." + + +#### QwenImage-Edit Plus vae encoder +QwenImageEditPlusVaeEncoderBlocks = InsertableDict( + [ + ("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step + ("preprocess", QwenImageProcessImagesInputStep()), # resized_image -> processed_image + ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents + ] +) + + +class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageEditPlusVaeEncoderBlocks.values() + block_names = QwenImageEditPlusVaeEncoderBlocks.keys() + + @property + def description(self) -> str: + return "Vae encoder step that encode the image inputs into their latent representations." + + +#### QwenImage Edit Plus presets +EDIT_PLUS_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageEditPlusVLEncoderStep()), + ("vae_encoder", QwenImageEditPlusVaeEncoderStep()), + ("input", QwenImageEditInputStep()), + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsStep()), + ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()), + ("denoise", QwenImageEditDenoiseStep()), + ("decode", QwenImageDecodeStep()), + ] +) + + +## 3.2 QwenImage-Edit Plus/auto before denoise +# compose the steps into a BeforeDenoiseStep for edit tasks before combining into an auto step + +#### QwenImage-Edit/edit before denoise +QwenImageEditPlusBeforeDenoiseBlocks = InsertableDict( + [ + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsStep()), + # Different from QwenImage Edit. + ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()), + ] +) + + +class QwenImageEditPlusBeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageEditPlusBeforeDenoiseBlocks.values() + block_names = QwenImageEditPlusBeforeDenoiseBlocks.keys() + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task." + + +# auto before_denoise step for edit tasks +class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks): + model_name = "qwenimage-edit-plus" + block_classes = [QwenImageEditPlusBeforeDenoiseStep] + block_names = ["edit"] + block_trigger_inputs = ["image_latents"] + + @property + def description(self): + return ( + "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n" + + "This is an auto pipeline block that works for edit (img2img) task.\n" + + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n" + + " - if `image_latents` is not provided, step will be skipped." 
+ ) + + +## 3.3 QwenImage-Edit Plus/auto encoders + + +class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [ + QwenImageEditPlusVaeEncoderStep, + ] + block_names = ["edit"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "Vae encoder step that encode the image inputs into their latent representations. \n" + " This is an auto pipeline block that works for edit task.\n" + + " - `QwenImageEditPlusVaeEncoderStep` (edit) is used when `image` is provided.\n" + + " - if `image` is not provided, step will be skipped." + ) + + +## 3.4 QwenImage-Edit/auto blocks & presets + + +class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageEditAutoInputStep, + QwenImageEditPlusAutoBeforeDenoiseStep, + QwenImageEditAutoDenoiseStep, + ] + block_names = ["input", "before_denoise", "denoise"] + + @property + def description(self): + return ( + "Core step that performs the denoising process. \n" + + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n" + + " - `QwenImageEditPlusAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n" + + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n" + + "This step support edit (img2img) workflow for QwenImage Edit Plus:\n" + + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n" + ) + + +EDIT_PLUS_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageEditPlusVLEncoderStep()), + ("vae_encoder", QwenImageEditPlusAutoVaeEncoderStep()), + ("denoise", QwenImageEditPlusCoreDenoiseStep()), + ("decode", QwenImageAutoDecodeStep()), + ] +) + + +class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks): + model_name = "qwenimage-edit-plus" + block_classes = EDIT_PLUS_AUTO_BLOCKS.values() + block_names = EDIT_PLUS_AUTO_BLOCKS.keys() + + @property + def description(self): + return ( + "Auto Modular pipeline for edit (img2img) and edit tasks using QwenImage-Edit Plus.\n" + + "- for edit (img2img) generation, you need to provide `image`\n" + ) + + +# 3. all block presets supported in QwenImage, QwenImage-Edit, QwenImage-Edit Plus ALL_BLOCKS = { @@ -880,8 +1051,10 @@ def description(self): "img2img": IMAGE2IMAGE_BLOCKS, "edit": EDIT_BLOCKS, "edit_inpaint": EDIT_INPAINT_BLOCKS, + "edit_plus": EDIT_PLUS_BLOCKS, "inpaint": INPAINT_BLOCKS, "controlnet": CONTROLNET_BLOCKS, "auto": AUTO_BLOCKS, "edit_auto": EDIT_AUTO_BLOCKS, + "edit_plus_auto": EDIT_PLUS_AUTO_BLOCKS, } diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py index 7200169923a5..d9e30864f660 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py @@ -196,3 +196,13 @@ def requires_unconditional_embeds(self): requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 return requires_unconditional_embeds + + +class QwenImageEditPlusModularPipeline(QwenImageEditModularPipeline): + """ + A ModularPipeline for QwenImage-Edit Plus. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. 
+ """ + + default_blocks_name = "QwenImageEditPlusAutoBlocks" From c56f200dbc042c0f58dee28591770e1c5042539e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 1 Oct 2025 16:48:59 +0530 Subject: [PATCH 2/9] up --- src/diffusers/__init__.py | 4 ++++ .../models/transformers/transformer_qwenimage.py | 1 + src/diffusers/modular_pipelines/__init__.py | 4 ++++ src/diffusers/modular_pipelines/modular_pipeline.py | 4 +++- src/diffusers/modular_pipelines/qwenimage/__init__.py | 2 +- src/diffusers/modular_pipelines/qwenimage/denoise.py | 1 + src/diffusers/modular_pipelines/qwenimage/encoders.py | 10 +++++----- src/diffusers/pipelines/auto_pipeline.py | 2 ++ 8 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8867250deda8..bf980dbffa3d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -389,8 +389,10 @@ "FluxModularPipeline", "QwenImageAutoBlocks", "QwenImageEditAutoBlocks", + "QwenImageEditPlusAutoBlocks", "QwenImageEditModularPipeline", "QwenImageModularPipeline", + "QwenImageEditPlusModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", "WanAutoBlocks", @@ -1051,8 +1053,10 @@ FluxModularPipeline, QwenImageAutoBlocks, QwenImageEditAutoBlocks, + QwenImageEditPlusAutoBlocks, QwenImageEditModularPipeline, QwenImageModularPipeline, + QwenImageEditPlusModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, WanAutoBlocks, diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 05379270c13b..5ef460984fc3 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -134,6 +134,7 @@ def apply_rotary_emb_qwen( return out else: + print(f"{x.shape=}, {freqs_cis.shape=}") x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) freqs_cis = freqs_cis.unsqueeze(1) x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 65c22b349b1c..c91d2edbcd6a 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -52,6 +52,8 @@ "QwenImageModularPipeline", "QwenImageEditModularPipeline", "QwenImageEditAutoBlocks", + "QwenImageEditPlusModularPipeline", + "QwenImageEditPlusAutoBlocks", ] _import_structure["components_manager"] = ["ComponentsManager"] @@ -79,6 +81,8 @@ QwenImageEditAutoBlocks, QwenImageEditModularPipeline, QwenImageModularPipeline, + QwenImageEditPlusModularPipeline, + QwenImageEditPlusAutoBlocks, ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import WanAutoBlocks, WanModularPipeline diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 037c9e323c6b..852eb45c1276 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -59,6 +59,7 @@ ("flux", "FluxModularPipeline"), ("qwenimage", "QwenImageModularPipeline"), ("qwenimage-edit", "QwenImageEditModularPipeline"), + ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline") ] ) @@ -1628,7 +1629,8 @@ def from_pretrained( blocks = ModularPipelineBlocks.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs ) - except EnvironmentError: + except 
EnvironmentError as e: + logger.debug(f"EnvironmentError: {e}") blocks = None cache_dir = kwargs.pop("cache_dir", None) diff --git a/src/diffusers/modular_pipelines/qwenimage/__init__.py b/src/diffusers/modular_pipelines/qwenimage/__init__.py index c0f78366229a..5f1be40187b8 100644 --- a/src/diffusers/modular_pipelines/qwenimage/__init__.py +++ b/src/diffusers/modular_pipelines/qwenimage/__init__.py @@ -38,7 +38,7 @@ "QwenImageEditAutoBlocks", "QwenImageEditPlusAutoBlocks", ] - _import_structure["modular_pipeline"] = ["QwenImageEditModularPipeline", "QwenImageModularPipeline"] + _import_structure["modular_pipeline"] = ["QwenImageEditModularPipeline", "QwenImageModularPipeline", "QwenImageEditPlusModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index d0704ee6e071..69712fe1a846 100644 --- a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -343,6 +343,7 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields} # YiYi TODO: add cache context + print(f"{block_state.img_shapes=}") guider_state_batch.noise_pred = components.transformer( hidden_states=block_state.latent_model_input, timestep=block_state.timestep / 1000, diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index a49137eae4e4..68b338f61173 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -132,7 +132,7 @@ def get_qwen_prompt_embeds_edit_plus( text_encoder, processor, prompt: Union[str, List[str]] = None, - image: Optional[Union[torch.Tensor, List[PIL.Image.Image], [PIL.Image.Image]]] = None, + image: Optional[Union[torch.Tensor, List[PIL.Image.Image], PIL.Image.Image]] = None, prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>", prompt_template_encode_start_idx: int = 64, @@ -371,7 +371,7 @@ def intermediate_outputs(self) -> List[OutputParam]: ), OutputParam( name=self._resized_image_vae_output_name, - type_hint=List[PIL.Image.Image], + type_hint=torch.Tensor, description="The resized images to be used by the VAE encoder.", ), OutputParam( @@ -409,8 +409,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): ) vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, image_width / image_height) vae_image_sizes.append((vae_width, vae_height)) - condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) - vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) + condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width)) + vae_images.append(components.image_resize_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) setattr(block_state, self._resized_image_output_name, condition_images) setattr(block_state, self._resized_image_vae_output_name, vae_images) @@ -718,7 +718,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): if components.requires_unconditional_embeds: negative_prompt = block_state.negative_prompt or " " - block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit( + block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus( components.text_encoder, components.processor, prompt=negative_prompt, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index d265bfdcaf3d..f9f5752f45f9 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -98,6 +98,7 @@ QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline, + QwenImageEditPlusPipeline, ) from .sana import SanaPipeline from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline @@ -186,6 +187,7 @@ ("flux-kontext", FluxKontextPipeline), ("qwenimage", QwenImageImg2ImgPipeline), ("qwenimage-edit", QwenImageEditPipeline), + ("qwenimage-edit-plus", QwenImageEditPlusPipeline), ] ) From 9e8b7250188dae8ac50f7a17c916599a5abe1fae Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 1 Oct 2025 20:35:46 +0530 Subject: [PATCH 3/9] up --- src/diffusers/__init__.py | 8 +- .../transformers/transformer_qwenimage.py | 1 - src/diffusers/modular_pipelines/__init__.py | 4 +- .../modular_pipelines/modular_pipeline.py | 2 +- .../modular_pipelines/qwenimage/__init__.py | 6 +- .../modular_pipelines/qwenimage/denoise.py | 1 - .../modular_pipelines/qwenimage/encoders.py | 85 ++++++++++++------- .../qwenimage/modular_blocks.py | 3 +- src/diffusers/pipelines/auto_pipeline.py | 2 +- 9 files changed, 71 insertions(+), 41 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index bf980dbffa3d..686e8d99dabf 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -389,10 +389,10 @@ "FluxModularPipeline", "QwenImageAutoBlocks", "QwenImageEditAutoBlocks", - "QwenImageEditPlusAutoBlocks", "QwenImageEditModularPipeline", - 
"QwenImageModularPipeline", + "QwenImageEditPlusAutoBlocks", "QwenImageEditPlusModularPipeline", + "QwenImageModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", "WanAutoBlocks", @@ -1053,10 +1053,10 @@ FluxModularPipeline, QwenImageAutoBlocks, QwenImageEditAutoBlocks, - QwenImageEditPlusAutoBlocks, QwenImageEditModularPipeline, - QwenImageModularPipeline, + QwenImageEditPlusAutoBlocks, QwenImageEditPlusModularPipeline, + QwenImageModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, WanAutoBlocks, diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 5ef460984fc3..05379270c13b 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -134,7 +134,6 @@ def apply_rotary_emb_qwen( return out else: - print(f"{x.shape=}, {freqs_cis.shape=}") x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) freqs_cis = freqs_cis.unsqueeze(1) x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index c91d2edbcd6a..2e590594af71 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -80,9 +80,9 @@ QwenImageAutoBlocks, QwenImageEditAutoBlocks, QwenImageEditModularPipeline, - QwenImageModularPipeline, - QwenImageEditPlusModularPipeline, QwenImageEditPlusAutoBlocks, + QwenImageEditPlusModularPipeline, + QwenImageModularPipeline, ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import WanAutoBlocks, WanModularPipeline diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 852eb45c1276..e543bf0bb3af 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -59,7 +59,7 @@ ("flux", "FluxModularPipeline"), ("qwenimage", "QwenImageModularPipeline"), ("qwenimage-edit", "QwenImageEditModularPipeline"), - ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline") + ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"), ] ) diff --git a/src/diffusers/modular_pipelines/qwenimage/__init__.py b/src/diffusers/modular_pipelines/qwenimage/__init__.py index 5f1be40187b8..ae4ec4799fbc 100644 --- a/src/diffusers/modular_pipelines/qwenimage/__init__.py +++ b/src/diffusers/modular_pipelines/qwenimage/__init__.py @@ -38,7 +38,11 @@ "QwenImageEditAutoBlocks", "QwenImageEditPlusAutoBlocks", ] - _import_structure["modular_pipeline"] = ["QwenImageEditModularPipeline", "QwenImageModularPipeline", "QwenImageEditPlusModularPipeline"] + _import_structure["modular_pipeline"] = [ + "QwenImageEditModularPipeline", + "QwenImageEditPlusModularPipeline", + "QwenImageModularPipeline", + ] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index 69712fe1a846..d0704ee6e071 100644 --- a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -343,7 +343,6 @@ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields} # YiYi TODO: add cache context - 
print(f"{block_state.img_shapes=}") guider_state_batch.noise_pred = components.transformer( hidden_states=block_state.latent_model_input, timestep=block_state.timestep / 1000, diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 68b338f61173..f59afd4c2389 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -330,7 +330,7 @@ def __init__( self, input_name: str = "image", output_name: str = "resized_image", - vae_image_output_name: str = "resize_vae_image", + vae_image_output_name: str = "vae_image", ): """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio. @@ -343,9 +343,9 @@ def __init__( pipeline state. Defaults to "image". output_name (str, optional): Name of the resized image field to write back to the pipeline state. Defaults to "resized_image". - vae_image_output_name (str, optional): Name of the resized image field + vae_image_output_name (str, optional): Name of the image field to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage Edit Plus - resizes the input image(s) differently for the VL and the VAE. + processes the input image(s) differently for the VL and the VAE. """ if not isinstance(input_name, str) or not isinstance(output_name, str): raise ValueError( @@ -355,14 +355,10 @@ def __init__( self.vae_image_size = 1024 * 1024 self._image_input_name = input_name self._resized_image_output_name = output_name - self._resized_image_vae_output_name = vae_image_output_name + self._vae_image_output_name = vae_image_output_name self._image_size_output_name = "image_sizes" super().__init__() - @property - def description(self) -> str: - return f"Image Resize step that resize the {self._image_input_name} to the target areas of {self.condition_image_size} and {self.vae_image_size} while maintaining the aspect ratio." 
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
                 name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
             ),
             OutputParam(
-                name=self._resized_image_vae_output_name,
-                type_hint=torch.Tensor,
-                description="The resized images to be used by the VAE encoder.",
+                name=self._vae_image_output_name,
+                type_hint=List[PIL.Image.Image],
+                description="The images to be processed which will be used by the VAE encoder.",
             ),
             OutputParam(
                 name=self._image_size_output_name,
@@ -393,8 +389,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         ):
             images = [images]
 
-        # TODO: revisit this when the inputs are `torch.Tensor`s
-        image_width, image_height = images[-1].size
+        # TODO (sayakpaul): revisit this when the inputs are `torch.Tensor`s
         condition_images = []
         vae_image_sizes = []
         vae_images = []
@@ -409,8 +404,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
             )
             vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, image_width / image_height)
             vae_image_sizes.append((vae_width, vae_height))
             condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width))
-            vae_images.append(components.image_resize_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+            vae_images.append(img)
 
         setattr(block_state, self._resized_image_output_name, condition_images)
-        setattr(block_state, self._resized_image_vae_output_name, vae_images)
+        setattr(block_state, self._vae_image_output_name, vae_images)
         setattr(block_state, self._image_size_output_name, vae_image_sizes)
         self.set_block_state(state, block_state)
         return components, state
@@ -713,14 +704,16 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
 
         if components.requires_unconditional_embeds:
             negative_prompt = block_state.negative_prompt or " "
-            block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus(
-                components.text_encoder,
-                components.processor,
-                prompt=negative_prompt,
-                image=block_state.resized_image,
-                prompt_template_encode=components.config.prompt_template_encode,
-                prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
-                device=device,
+            block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = (
+                get_qwen_prompt_embeds_edit_plus(
+                    components.text_encoder,
+                    components.processor,
+                    prompt=negative_prompt,
+                    image=block_state.resized_image,
+                    prompt_template_encode=components.config.prompt_template_encode,
+                    prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+                    device=device,
+                )
             )
 
         self.set_block_state(state, block_state)
@@ -833,12 +826,7 @@ def expected_components(self) -> List[ComponentSpec]:
 
     @property
     def inputs(self) -> List[InputParam]:
-        return [
-            InputParam("resized_image"),
-            InputParam("image"),
-            InputParam("height"),
-            InputParam("width"),
-        ]
+        return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]
 
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
@@ -882,6 +870,41 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
     return components, state
 
 
+class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
+    model_name = "qwenimage-edit-plus"
+
+    @property
+    def description(self) -> str:
+        return "Image Preprocess step for QwenImage Edit Plus. Unlike QwenImage Edit, QwenImage Edit Plus doesn't use the same resized image for further preprocessing."
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+        block_state = self.get_block_state(state)
+
+        if block_state.vae_image is None and block_state.image is None:
+            raise ValueError("`vae_image` and `image` cannot be None at the same time")
+
+        if block_state.vae_image is None:
+            image = block_state.image
+            self.check_inputs(
+                height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+            )
+            height = block_state.height or components.default_height
+            width = block_state.width or components.default_width
+        else:
+            width, height = block_state.vae_image[0].size
+            image = block_state.vae_image
+
+        block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width)
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
     model_name = "qwenimage"
 

diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
index d94aeafe5744..4169be37facc 100644
--- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
@@ -38,6 +38,7 @@
 )
 from .encoders import (
     QwenImageControlNetVaeEncoderStep,
+    QwenImageEditPlusProcessImagesInputStep,
     QwenImageEditPlusResizeDynamicStep,
     QwenImageEditPlusTextEncoderStep,
     QwenImageEditResizeDynamicStep,
@@ -904,7 +905,7 @@ def description(self) -> str:
 QwenImageEditPlusVaeEncoderBlocks = InsertableDict(
     [
         ("resize", QwenImageEditPlusResizeDynamicStep()),  # edit plus has a different resize step
-        ("preprocess", QwenImageProcessImagesInputStep()),  # resized_image -> processed_image
+        ("preprocess", QwenImageEditPlusProcessImagesInputStep()),  # vae_image -> processed_image
         ("encode", QwenImageVaeEncoderDynamicStep()),  # processed_image -> image_latents
     ]
 )

diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index f9f5752f45f9..8a32d4c367a3 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -95,10 +95,10 @@
     QwenImageControlNetPipeline,
     QwenImageEditInpaintPipeline,
     QwenImageEditPipeline,
+    QwenImageEditPlusPipeline,
     QwenImageImg2ImgPipeline,
     QwenImageInpaintPipeline,
     QwenImagePipeline,
-    QwenImageEditPlusPipeline,
 )
 from .sana import SanaPipeline
 from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline
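Patches 1-3 route every resize target through `calculate_dimensions(target_area, aspect_ratio)`, whose body never appears in this series. One plausible implementation of the aspect-preserving area fit it performs; the multiple-of-32 rounding and the third return value are assumptions for illustration, not the actual diffusers helper:

    import math

    def calculate_dimensions(target_area: int, ratio: float, mod: int = 32):
        # pick width x height ~= target_area at the requested aspect ratio,
        # snapped to model-friendly multiples
        width = math.sqrt(target_area * ratio)
        height = width / ratio
        width = round(width / mod) * mod
        height = round(height / mod) * mod
        return int(width), int(height), None  # callers here unpack and ignore a third value

    print(calculate_dimensions(384 * 384, 4 / 3))    # condition/VL target, e.g. (448, 320, None)
    print(calculate_dimensions(1024 * 1024, 4 / 3))  # VAE target, e.g. (1184, 896, None)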
model_name = "qwenimage" - @property - def description(self) -> str: - return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation.\n" - @property def expected_configs(self) -> List[ConfigSpec]: return [ @@ -680,18 +676,6 @@ def expected_configs(self) -> List[ConfigSpec]: ConfigSpec(name="prompt_template_encode_start_idx", default=64), ] - @staticmethod - def check_inputs(prompt, negative_prompt): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if ( - negative_prompt is not None - and not isinstance(negative_prompt, str) - and not isinstance(negative_prompt, list) - ): - raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState): block_state = self.get_block_state(state) @@ -720,6 +704,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): prompt=negative_prompt, image=block_state.resized_image, prompt_template_encode=components.config.prompt_template_encode, + img_template_encode=components.config.img_template_encode, prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, device=device, ) From 5af00520edc71f15e2653c6872487597d19ffb6c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 1 Oct 2025 21:35:18 +0530 Subject: [PATCH 5/9] up --- .../dummy_torch_and_transformers_objects.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index bb8fea8c8a8b..cf8037796488 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -77,6 +77,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class QwenImageEditPlusAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class QwenImageEditPlusModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class QwenImageModularPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 83d4a860097a04a778b29871c60e0812435e59e2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 2 Oct 2025 16:21:48 +0530 Subject: [PATCH 6/9] up --- .../qwenimage/before_denoise.py | 2 +- .../modular_pipelines/qwenimage/encoders.py | 17 +++++++++++++---- .../qwenimage/pipeline_qwenimage_edit_plus.py | 7 +++++++ 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index 2fcbb42a32ef..e49ea6402272 100644 --- 
a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -203,7 +203,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - block_state.latents = components.pachifier.pack_latents(block_state.latents) self.set_block_state(state, block_state) - + torch.save({"latents": block_state.latents}, "latents_mod.pt") return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 42342a543ffa..1397a98961a9 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -159,7 +159,6 @@ def get_qwen_prompt_embeds_edit_plus( padding=True, return_tensors="pt", ).to(device) - outputs = text_encoder( input_ids=model_inputs.input_ids, attention_mask=model_inputs.attention_mask, @@ -181,7 +180,6 @@ def get_qwen_prompt_embeds_edit_plus( ) prompt_embeds = prompt_embeds.to(device=device) - return prompt_embeds, encoder_attention_mask @@ -671,7 +669,7 @@ def expected_configs(self) -> List[ConfigSpec]: ), ConfigSpec( name="img_template_encode", - default='img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"', + default="Picture {}: <|vision_start|><|image_pad|><|vision_end|>", ), ConfigSpec(name="prompt_template_encode_start_idx", default=64), ] @@ -694,6 +692,10 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, device=device, ) + torch.save( + {"prompt_embeds": block_state.prompt_embeds, "prompt_embeds_mask": block_state.prompt_embeds_mask}, + "prompt_embeds_mod.pt", + ) if components.requires_unconditional_embeds: negative_prompt = block_state.negative_prompt or " " @@ -709,6 +711,13 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): device=device, ) ) + torch.save( + { + "neg_prompt_embeds": block_state.negative_prompt_embeds, + "neg_prompt_embeds_mask": block_state.negative_prompt_embeds_mask, + }, + "neg_prompt_embeds_mod.pt", + ) self.set_block_state(state, block_state) return components, state @@ -971,7 +980,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - dtype=dtype, latent_channels=components.num_channels_latents, ) - + torch.save({"image_latents": image_latents}, "image_latents_mod.pt") setattr(block_state, self._image_latents_output_name, image_latents) self.set_block_state(state, block_state) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index ec203edf166c..44f92153911a 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -706,6 +706,7 @@ def __call__( num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) + torch.save({"prompt_embeds": prompt_embeds, "prompt_embeds_mask": prompt_embeds_mask}, "prompt_embeds.pt") if do_true_cfg: negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( image=condition_images, @@ -716,6 +717,10 @@ def __call__( num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) + torch.save( + {"neg_prompt_embeds": negative_prompt_embeds, "neg_prompt_embeds_mask": negative_prompt_embeds_mask}, + "neg_prompt_embeds.pt", + ) # 4. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 @@ -730,6 +735,7 @@ def __call__( generator, latents, ) + torch.save({"latents": latents, "image_latents": image_latents}, "latents.pt") img_shapes = [ [ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), @@ -830,6 +836,7 @@ def __call__( cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) noise_pred = comb_pred * (cond_norm / noise_norm) + torch.save({"noise_pred": noise_pred}, f"noise_pred_{i}.pt") # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype From 1474ec56095a9e3ad7c613404d6b951dd75fb85c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 2 Oct 2025 16:47:00 +0530 Subject: [PATCH 7/9] remove saves --- .../modular_pipelines/qwenimage/before_denoise.py | 1 - .../modular_pipelines/qwenimage/encoders.py | 12 ------------ .../qwenimage/pipeline_qwenimage_edit_plus.py | 7 ------- 3 files changed, 20 deletions(-) diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index e49ea6402272..efeeff007c46 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -203,7 +203,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - block_state.latents = components.pachifier.pack_latents(block_state.latents) self.set_block_state(state, block_state) - torch.save({"latents": block_state.latents}, "latents_mod.pt") return components, state diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 1397a98961a9..2bbfdf94b08e 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -692,10 +692,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, device=device, ) - torch.save( - {"prompt_embeds": block_state.prompt_embeds, "prompt_embeds_mask": block_state.prompt_embeds_mask}, - "prompt_embeds_mod.pt", - ) if components.requires_unconditional_embeds: negative_prompt = block_state.negative_prompt or " " @@ -711,13 +707,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): device=device, ) ) - torch.save( - { - "neg_prompt_embeds": block_state.negative_prompt_embeds, - "neg_prompt_embeds_mask": block_state.negative_prompt_embeds_mask, - }, - "neg_prompt_embeds_mod.pt", - ) self.set_block_state(state, block_state) return components, state @@ -980,7 +969,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - dtype=dtype, latent_channels=components.num_channels_latents, ) - torch.save({"image_latents": image_latents}, "image_latents_mod.pt") setattr(block_state, self._image_latents_output_name, image_latents) self.set_block_state(state, block_state) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index 44f92153911a..ec203edf166c 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -706,7 +706,6 @@ def __call__( num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) - 
torch.save({"prompt_embeds": prompt_embeds, "prompt_embeds_mask": prompt_embeds_mask}, "prompt_embeds.pt") if do_true_cfg: negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( image=condition_images, @@ -717,10 +716,6 @@ def __call__( num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, ) - torch.save( - {"neg_prompt_embeds": negative_prompt_embeds, "neg_prompt_embeds_mask": negative_prompt_embeds_mask}, - "neg_prompt_embeds.pt", - ) # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 @@ -735,7 +730,6 @@ def __call__( generator, latents, ) - torch.save({"latents": latents, "image_latents": image_latents}, "latents.pt") img_shapes = [ [ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), @@ -836,7 +830,6 @@ def __call__( cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) noise_pred = comb_pred * (cond_norm / noise_norm) - torch.save({"noise_pred": noise_pred}, f"noise_pred_{i}.pt") # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype From c7c215aab0b08e22296e2ffe577efead5340f194 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 5 Oct 2025 16:45:13 +0530 Subject: [PATCH 8/9] move things around a bit. --- .../qwenimage/before_denoise.py | 6 +-- .../modular_pipelines/qwenimage/encoders.py | 41 +++++++++++-------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index efeeff007c46..ce56de172409 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -641,11 +641,11 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep): - model_name = "qwenimage" + model_name = "qwenimage-edit-plus" # TODO: Is there a better way to handle this name? It's used in - # `QwenImageEditPlusResizeDynamicStep` as well. We can later + # `QwenImageEditPlusProcessImagesInputStep` as well. We can later # keep these things as a module-level constant. 
-    _image_size_output_name = "image_sizes"
+    _image_size_output_name = "vae_image_sizes"
 
     @property
     def inputs(self) -> List[InputParam]:
diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py
index 2bbfdf94b08e..37fbef911c60 100644
--- a/src/diffusers/modular_pipelines/qwenimage/encoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -350,28 +350,18 @@ def __init__(
                 f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
             )
         self.condition_image_size = 384 * 384
-        self.vae_image_size = 1024 * 1024
         self._image_input_name = input_name
         self._resized_image_output_name = output_name
         self._vae_image_output_name = vae_image_output_name
-        self._image_size_output_name = "image_sizes"
         super().__init__()
 
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
-        return [
-            OutputParam(
-                name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
-            ),
+        return super().intermediate_outputs + [
             OutputParam(
                 name=self._vae_image_output_name,
                 type_hint=List[PIL.Image.Image],
-                description="The images to be processed which will be used by the VAE encoder.",
-            ),
-            OutputParam(
-                name=self._image_size_output_name,
-                type_hint=List[Tuple[int, int]],
-                description="Sizes of images fed to the VAE encoder. To be used with RoPE.",
+                description="The images to be processed which will be further used by the VAE encoder.",
             ),
         ]
 
@@ -393,21 +383,17 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         # TODO (sayakpaul): revisit this when the inputs are `torch.Tensor`s
         condition_images = []
-        vae_image_sizes = []
         vae_images = []
         for img in images:
             image_width, image_height = img.size
             condition_width, condition_height, _ = calculate_dimensions(
                 self.condition_image_size, image_width / image_height
             )
-            vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, image_width / image_height)
-            vae_image_sizes.append((vae_width, vae_height))
             condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width))
             vae_images.append(img)
 
         setattr(block_state, self._resized_image_output_name, condition_images)
         setattr(block_state, self._vae_image_output_name, vae_images)
-        setattr(block_state, self._image_size_output_name, vae_image_sizes)
 
         self.set_block_state(state, block_state)
         return components, state
@@ -859,6 +845,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
 
 class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
     model_name = "qwenimage-edit-plus"
+    vae_image_size = 1024 * 1024
 
     @property
     def description(self) -> str:
@@ -868,6 +855,12 @@ def description(self) -> str:
     def inputs(self) -> List[InputParam]:
         return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]
 
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return super().intermediate_outputs + [
+            OutputParam(name="vae_image_sizes", type_hint=List[Tuple[int, int]]),
+        ]
+
     @torch.no_grad()
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         block_state = self.get_block_state(state)
@@ -882,11 +875,23 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
             )
             height = block_state.height or components.default_height
             width = block_state.width or components.default_width
+            block_state.processed_image = components.image_processor.preprocess(
+                image=image, height=height, width=width
+            )
         else:
-            width, height = block_state.vae_image[0].size
+            vae_image_sizes = []
             image = block_state.vae_image
+            for img in image:
+                width, height = img.size
+                vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height)
+                vae_image_sizes.append((vae_width, vae_height))
 
-        block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width)
+            block_state.vae_image_sizes = vae_image_sizes
+
+            width, height = block_state.vae_image[0].size
+            block_state.processed_image = components.image_processor.preprocess(
+                image=image, height=vae_height, width=vae_width
+            )
 
         self.set_block_state(state, block_state)
         return components, state
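Note: the net effect of [PATCH 8/9] is that QwenImageEditPlusProcessImagesInputStep now derives one VAE target size per reference image (each budgeted roughly vae_image_size = 1024 * 1024 pixels while keeping its own aspect ratio) and publishes them as vae_image_sizes, which QwenImageEditPlusRoPEInputsStep consumes to build per-image img_shapes. A self-contained sketch of that size arithmetic; the rounding to multiples of 32 is an assumption about the diffusers helper, not taken from this patch:

    import math

    def calculate_dimensions(target_area: int, ratio: float):
        # Pick (width, height) covering roughly `target_area` pixels at the
        # given width/height ratio, snapped to multiples of 32.
        width = math.sqrt(target_area * ratio)
        height = width / ratio
        return round(width / 32) * 32, round(height / 32) * 32, None

    vae_image_sizes = []
    for w, h in [(1536, 1024), (512, 512)]:  # two reference images
        vae_width, vae_height, _ = calculate_dimensions(1024 * 1024, w / h)
        vae_image_sizes.append((vae_width, vae_height))
    print(vae_image_sizes)  # [(1248, 832), (1024, 1024)]
    # With vae_scale_factor=8, (1248, 832) contributes the RoPE entry
    # (1, 832 // 8 // 2, 1248 // 8 // 2) == (1, 52, 78) to img_shapes.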
- _image_size_output_name = "vae_image_sizes" - - @property - def inputs(self) -> List[InputParam]: - inputs_list = super().inputs - return inputs_list + [ - InputParam(name=self._image_size_output_name, required=True), - ] - - def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - vae_image_sizes = getattr(block_state, self._image_size_output_name) - height, width = block_state.image_height, block_state.image_width - - # for edit, image size can be different from the target size (height/width) - block_state.img_shapes = [ - [ - (1, height // components.vae_scale_factor // 2, width // components.vae_scale_factor // 2), - *[ - (1, vae_height // components.vae_scale_factor // 2, vae_width // components.vae_scale_factor // 2) - for vae_width, vae_height in vae_image_sizes - ], - ] - ] * block_state.batch_size - - block_state.txt_seq_lens = ( - block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None - ) - block_state.negative_txt_seq_lens = ( - block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() - if block_state.negative_prompt_embeds_mask is not None - else None - ) - - self.set_block_state(state, block_state) - - return components, state - - ## ControlNet inputs for denoiser class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks): model_name = "qwenimage" diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 37fbef911c60..04fb3fdc947b 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
 
 import PIL
 import torch
@@ -855,12 +855,6 @@ def description(self) -> str:
     def inputs(self) -> List[InputParam]:
         return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]
 
-    @property
-    def intermediate_outputs(self) -> List[OutputParam]:
-        return super().intermediate_outputs + [
-            OutputParam(name="vae_image_sizes", type_hint=List[Tuple[int, int]]),
-        ]
-
     @torch.no_grad()
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         block_state = self.get_block_state(state)
@@ -879,18 +873,11 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
                 image=image, height=height, width=width
             )
         else:
-            vae_image_sizes = []
+            width, height = block_state.vae_image[0].size
             image = block_state.vae_image
-            for img in image:
-                width, height = img.size
-                vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height)
-                vae_image_sizes.append((vae_width, vae_height))
 
-            block_state.vae_image_sizes = vae_image_sizes
-
-            width, height = block_state.vae_image[0].size
             block_state.processed_image = components.image_processor.preprocess(
-                image=image, height=vae_height, width=vae_width
+                image=image, height=height, width=width
             )
 
         self.set_block_state(state, block_state)
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
index 4169be37facc..83bfcb3da4fd 100644
--- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
@@ -18,7 +18,6 @@
 from .before_denoise import (
     QwenImageControlNetBeforeDenoiserStep,
     QwenImageCreateMaskLatentsStep,
-    QwenImageEditPlusRoPEInputsStep,
     QwenImageEditRoPEInputsStep,
     QwenImagePrepareLatentsStep,
     QwenImagePrepareLatentsWithStrengthStep,
@@ -929,41 +928,17 @@ def description(self) -> str:
         ("input", QwenImageEditInputStep()),
         ("prepare_latents", QwenImagePrepareLatentsStep()),
         ("set_timesteps", QwenImageSetTimestepsStep()),
-        ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
+        ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
         ("denoise", QwenImageEditDenoiseStep()),
         ("decode", QwenImageDecodeStep()),
     ]
 )
 
 
-## 3.2 QwenImage-Edit Plus/auto before denoise
-# compose the steps into a BeforeDenoiseStep for edit tasks before combining into an auto step
-
-#### QwenImage-Edit/edit before denoise
-QwenImageEditPlusBeforeDenoiseBlocks = InsertableDict(
-    [
-        ("prepare_latents", QwenImagePrepareLatentsStep()),
-        ("set_timesteps", QwenImageSetTimestepsStep()),
-        # Different from QwenImage Edit.
-        ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
-    ]
-)
-
-
-class QwenImageEditPlusBeforeDenoiseStep(SequentialPipelineBlocks):
-    model_name = "qwenimage"
-    block_classes = QwenImageEditPlusBeforeDenoiseBlocks.values()
-    block_names = QwenImageEditPlusBeforeDenoiseBlocks.keys()
-
-    @property
-    def description(self):
-        return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task."
-
-
 # auto before_denoise step for edit tasks
 class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
     model_name = "qwenimage-edit-plus"
-    block_classes = [QwenImageEditPlusBeforeDenoiseStep]
+    block_classes = [QwenImageEditBeforeDenoiseStep]
     block_names = ["edit"]
     block_trigger_inputs = ["image_latents"]
 
@@ -977,7 +952,7 @@ def description(self):
     )
 
 
-## 3.3 QwenImage-Edit Plus/auto encoders
+## 3.2 QwenImage-Edit Plus/auto encoders
 
 
 class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
@@ -997,7 +972,7 @@ def description(self):
     )
 
 
-## 3.4 QwenImage-Edit/auto blocks & presets
+## 3.3 QwenImage-Edit/auto blocks & presets
 
 
 class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
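Note: once this series lands, the new Edit Plus blocks should be drivable like the existing QwenImage modular presets. A rough usage sketch, assuming the current modular-pipelines API (init_pipeline / load_components / output="images") carries over to the new blocks unchanged; the checkpoint id is illustrative:

    import torch
    from diffusers.modular_pipelines.qwenimage import QwenImageEditPlusAutoBlocks
    from diffusers.utils import load_image

    blocks = QwenImageEditPlusAutoBlocks()
    # Illustrative repo id; substitute the actual Edit Plus checkpoint.
    pipe = blocks.init_pipeline("Qwen/Qwen-Image-Edit-2509")
    pipe.load_components(torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # Edit Plus accepts multiple reference images; each keeps its own
    # aspect ratio through the VAE branch.
    images = [load_image("ref_0.png"), load_image("ref_1.png")]
    result = pipe(image=images, prompt="put the object from the second image into the first", output="images")[0]
    result.save("edited.png")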