diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8867250deda8..686e8d99dabf 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -390,6 +390,8 @@ "QwenImageAutoBlocks", "QwenImageEditAutoBlocks", "QwenImageEditModularPipeline", + "QwenImageEditPlusAutoBlocks", + "QwenImageEditPlusModularPipeline", "QwenImageModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", @@ -1052,6 +1054,8 @@ QwenImageAutoBlocks, QwenImageEditAutoBlocks, QwenImageEditModularPipeline, + QwenImageEditPlusAutoBlocks, + QwenImageEditPlusModularPipeline, QwenImageModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 65c22b349b1c..2e590594af71 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -52,6 +52,8 @@ "QwenImageModularPipeline", "QwenImageEditModularPipeline", "QwenImageEditAutoBlocks", + "QwenImageEditPlusModularPipeline", + "QwenImageEditPlusAutoBlocks", ] _import_structure["components_manager"] = ["ComponentsManager"] @@ -78,6 +80,8 @@ QwenImageAutoBlocks, QwenImageEditAutoBlocks, QwenImageEditModularPipeline, + QwenImageEditPlusAutoBlocks, + QwenImageEditPlusModularPipeline, QwenImageModularPipeline, ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 037c9e323c6b..e543bf0bb3af 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -59,6 +59,7 @@ ("flux", "FluxModularPipeline"), ("qwenimage", "QwenImageModularPipeline"), ("qwenimage-edit", "QwenImageEditModularPipeline"), + ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"), ] ) @@ -1628,7 +1629,8 @@ def from_pretrained( blocks = ModularPipelineBlocks.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs ) - except EnvironmentError: + except EnvironmentError as e: + logger.debug(f"EnvironmentError: {e}") blocks = None cache_dir = kwargs.pop("cache_dir", None) diff --git a/src/diffusers/modular_pipelines/qwenimage/__init__.py b/src/diffusers/modular_pipelines/qwenimage/__init__.py index 81cf515730ef..ae4ec4799fbc 100644 --- a/src/diffusers/modular_pipelines/qwenimage/__init__.py +++ b/src/diffusers/modular_pipelines/qwenimage/__init__.py @@ -29,13 +29,20 @@ "EDIT_AUTO_BLOCKS", "EDIT_BLOCKS", "EDIT_INPAINT_BLOCKS", + "EDIT_PLUS_AUTO_BLOCKS", + "EDIT_PLUS_BLOCKS", "IMAGE2IMAGE_BLOCKS", "INPAINT_BLOCKS", "TEXT2IMAGE_BLOCKS", "QwenImageAutoBlocks", "QwenImageEditAutoBlocks", + "QwenImageEditPlusAutoBlocks", + ] + _import_structure["modular_pipeline"] = [ + "QwenImageEditModularPipeline", + "QwenImageEditPlusModularPipeline", + "QwenImageModularPipeline", ] - _import_structure["modular_pipeline"] = ["QwenImageEditModularPipeline", "QwenImageModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -54,13 +61,20 @@ EDIT_AUTO_BLOCKS, EDIT_BLOCKS, EDIT_INPAINT_BLOCKS, + EDIT_PLUS_AUTO_BLOCKS, + EDIT_PLUS_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, TEXT2IMAGE_BLOCKS, QwenImageAutoBlocks, QwenImageEditAutoBlocks, + QwenImageEditPlusAutoBlocks, + ) + from .modular_pipeline import ( + QwenImageEditModularPipeline, + QwenImageEditPlusModularPipeline, + QwenImageModularPipeline, ) - from .modular_pipeline import 
QwenImageEditModularPipeline, QwenImageModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index 606236cfe91b..fdec95dc506e 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -203,7 +203,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - block_state.latents = components.pachifier.pack_latents(block_state.latents) self.set_block_state(state, block_state) - return components, state @@ -571,7 +570,7 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks): @property def description(self) -> str: - return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be place after prepare_latents step" + return "Step that prepares the RoPE inputs for the denoising process. This is used in QwenImage Edit. Should be placed after the prepare_latents step" @property def inputs(self) -> List[InputParam]: diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 2ab83a03ee55..04fb3fdc947b 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -128,6 +128,61 @@ def get_qwen_prompt_embeds_edit( return prompt_embeds, encoder_attention_mask +def get_qwen_prompt_embeds_edit_plus( + text_encoder, + processor, + prompt: Union[str, List[str]] = None, + image: Optional[Union[torch.Tensor, List[PIL.Image.Image], PIL.Image.Image]] = None, + prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>", + prompt_template_encode_start_idx: int = 64, + device: Optional[torch.device] = None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + if isinstance(image, list): + base_img_prompt = "" + for i, img in enumerate(image): + base_img_prompt += img_template_encode.format(i + 1) + elif image is not None: + base_img_prompt = img_template_encode.format(1) + else: + base_img_prompt = "" + + template = prompt_template_encode + + drop_idx = prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in prompt] + + model_inputs = processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + outputs = text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(device=device) + return prompt_embeds, encoder_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -266,6 +321,83 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep): + model_name = "qwenimage" + + def __init__( + self, + input_name: str = "image", + output_name: str = "resized_image", + vae_image_output_name: str = "vae_image", + ): + """Create a configurable step for resizing condition images to the target area (384 * 384) while maintaining the aspect ratio. + + This block resizes an input image or a list of input images and exposes the resized result under configurable + input and output names. Use this when you need to wire the resize step to different image fields (e.g., + "image", "control_image"). + + Args: + input_name (str, optional): Name of the image field to read from the + pipeline state. Defaults to "image". + output_name (str, optional): Name of the resized image field to write + back to the pipeline state. Defaults to "resized_image". + vae_image_output_name (str, optional): Name of the image field + to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage Edit Plus + processes the input image(s) differently for the VL encoder and the VAE. 
+ """ + if not isinstance(input_name, str) or not isinstance(output_name, str): + raise ValueError( + f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}" + ) + self.condition_image_size = 384 * 384 + self._image_input_name = input_name + self._resized_image_output_name = output_name + self._vae_image_output_name = vae_image_output_name + super().__init__() + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return super().intermediate_outputs + [ + OutputParam( + name=self._vae_image_output_name, + type_hint=List[PIL.Image.Image], + description="The images to be processed which will be further used by the VAE encoder.", + ), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + images = getattr(block_state, self._image_input_name) + + if not is_valid_image_imagelist(images): + raise ValueError(f"Images must be image or list of images but are {type(images)}") + + if ( + not isinstance(images, torch.Tensor) + and isinstance(images, PIL.Image.Image) + and not isinstance(images, list) + ): + images = [images] + + # TODO (sayakpaul): revisit this when the inputs are `torch.Tensor`s + condition_images = [] + vae_images = [] + for img in images: + image_width, image_height = img.size + condition_width, condition_height, _ = calculate_dimensions( + self.condition_image_size, image_width / image_height + ) + condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width)) + vae_images.append(img) + + setattr(block_state, self._resized_image_output_name, condition_images) + setattr(block_state, self._vae_image_output_name, vae_images) + self.set_block_state(state, block_state) + return components, state + + class QwenImageTextEncoderStep(ModularPipelineBlocks): model_name = "qwenimage" @@ -511,6 +643,61 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep): + model_name = "qwenimage" + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec( + name="prompt_template_encode", + default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + ), + ConfigSpec( + name="img_template_encode", + default="Picture {}: <|vision_start|><|image_pad|><|vision_end|>", + ), + ConfigSpec(name="prompt_template_encode_start_idx", default=64), + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + self.check_inputs(block_state.prompt, block_state.negative_prompt) + + device = components._execution_device + + block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus( + components.text_encoder, + components.processor, + prompt=block_state.prompt, + image=block_state.resized_image, + prompt_template_encode=components.config.prompt_template_encode, + img_template_encode=components.config.img_template_encode, + prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + device=device, + ) + + if components.requires_unconditional_embeds: + negative_prompt = block_state.negative_prompt or " " + block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = ( + get_qwen_prompt_embeds_edit_plus( + components.text_encoder, + components.processor, + prompt=negative_prompt, + image=block_state.resized_image, + prompt_template_encode=components.config.prompt_template_encode, + img_template_encode=components.config.img_template_encode, + prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx, + device=device, + ) + ) + + self.set_block_state(state, block_state) + return components, state + + class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks): model_name = "qwenimage" @@ -612,12 +799,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: - return [ - InputParam("resized_image"), - InputParam("image"), - InputParam("height"), - InputParam("width"), - ] + return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")] @property def intermediate_outputs(self) -> List[OutputParam]: @@ -661,6 +843,47 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): return components, state +class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep): + model_name = "qwenimage-edit-plus" + vae_image_size = 1024 * 1024 + + @property + def description(self) -> str: + return "Image Preprocess step for QwenImage Edit Plus. Unlike QwenImage Edit, QwenImage Edit Plus does not reuse the resized condition image for VAE preprocessing; it preprocesses the separate `vae_image` input instead." 
+ + @property + def inputs(self) -> List[InputParam]: + return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + if block_state.vae_image is None and block_state.image is None: + raise ValueError("`vae_image` and `image` cannot be None at the same time") + + if block_state.vae_image is None: + image = block_state.image + self.check_inputs( + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width + block_state.processed_image = components.image_processor.preprocess( + image=image, height=height, width=width + ) + else: + width, height = block_state.vae_image[0].size + image = block_state.vae_image + + block_state.processed_image = components.image_processor.preprocess( + image=image, height=height, width=width + ) + + self.set_block_state(state, block_state) + return components, state + + class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks): model_name = "qwenimage" @@ -738,7 +961,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - dtype=dtype, latent_channels=components.num_channels_latents, ) - setattr(block_state, self._image_latents_output_name, image_latents) self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py index 9126766cc202..83bfcb3da4fd 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py @@ -37,6 +37,9 @@ ) from .encoders import ( QwenImageControlNetVaeEncoderStep, + QwenImageEditPlusProcessImagesInputStep, + QwenImageEditPlusResizeDynamicStep, + QwenImageEditPlusTextEncoderStep, QwenImageEditResizeDynamicStep, QwenImageEditTextEncoderStep, QwenImageInpaintProcessImagesInputStep, @@ -872,7 +875,151 @@ def description(self): ) -# 3. all block presets supported in QwenImage & QwenImage-Edit +#################### QwenImage Edit Plus ##################### + +# 3. QwenImage-Edit Plus + +## 3.1 QwenImage-Edit Plus / edit + +#### QwenImage-Edit Plus vl encoder: takes both image and text prompts +QwenImageEditPlusVLEncoderBlocks = InsertableDict( + [ + ("resize", QwenImageEditPlusResizeDynamicStep()), + ("encode", QwenImageEditPlusTextEncoderStep()), + ] +) + + +class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageEditPlusVLEncoderBlocks.values() + block_names = QwenImageEditPlusVLEncoderBlocks.keys() + + @property + def description(self) -> str: + return "QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together." 
+ + +#### QwenImage-Edit Plus vae encoder +QwenImageEditPlusVaeEncoderBlocks = InsertableDict( + [ + ("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step + ("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image + ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents + ] +) + + +class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): + model_name = "qwenimage" + block_classes = QwenImageEditPlusVaeEncoderBlocks.values() + block_names = QwenImageEditPlusVaeEncoderBlocks.keys() + + @property + def description(self) -> str: + return "Vae encoder step that encodes the image inputs into their latent representations." + + +#### QwenImage Edit Plus presets +EDIT_PLUS_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageEditPlusVLEncoderStep()), + ("vae_encoder", QwenImageEditPlusVaeEncoderStep()), + ("input", QwenImageEditInputStep()), + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsStep()), + ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), + ("denoise", QwenImageEditDenoiseStep()), + ("decode", QwenImageDecodeStep()), + ] +) + + +# auto before_denoise step for edit tasks +class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks): + model_name = "qwenimage-edit-plus" + block_classes = [QwenImageEditBeforeDenoiseStep] + block_names = ["edit"] + block_trigger_inputs = ["image_latents"] + + @property + def description(self): + return ( + "Before denoise step that prepares the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n" + + "This is an auto pipeline block that works for the edit (img2img) task.\n" + + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided.\n" + + " - if `image_latents` is not provided, step will be skipped." + ) + + +## 3.2 QwenImage-Edit Plus/auto encoders + + +class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks): + block_classes = [ + QwenImageEditPlusVaeEncoderStep, + ] + block_names = ["edit"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ( + "Vae encoder step that encodes the image inputs into their latent representations. \n" + " This is an auto pipeline block that works for the edit task.\n" + + " - `QwenImageEditPlusVaeEncoderStep` (edit) is used when `image` is provided.\n" + + " - if `image` is not provided, step will be skipped." + ) + + +## 3.3 QwenImage-Edit Plus/auto blocks & presets + + +class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): + model_name = "qwenimage-edit-plus" + block_classes = [ + QwenImageEditAutoInputStep, + QwenImageEditPlusAutoBeforeDenoiseStep, + QwenImageEditAutoDenoiseStep, + ] + block_names = ["input", "before_denoise", "denoise"] + + @property + def description(self): + return ( + "Core step that performs the denoising process. 
\n" + + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n" + + " - `QwenImageEditPlusAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n" + + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n" + + "This step support edit (img2img) workflow for QwenImage Edit Plus:\n" + + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n" + ) + + +EDIT_PLUS_AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", QwenImageEditPlusVLEncoderStep()), + ("vae_encoder", QwenImageEditPlusAutoVaeEncoderStep()), + ("denoise", QwenImageEditPlusCoreDenoiseStep()), + ("decode", QwenImageAutoDecodeStep()), + ] +) + + +class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks): + model_name = "qwenimage-edit-plus" + block_classes = EDIT_PLUS_AUTO_BLOCKS.values() + block_names = EDIT_PLUS_AUTO_BLOCKS.keys() + + @property + def description(self): + return ( + "Auto Modular pipeline for edit (img2img) and edit tasks using QwenImage-Edit Plus.\n" + + "- for edit (img2img) generation, you need to provide `image`\n" + ) + + +# 3. all block presets supported in QwenImage, QwenImage-Edit, QwenImage-Edit Plus ALL_BLOCKS = { @@ -880,8 +1027,10 @@ def description(self): "img2img": IMAGE2IMAGE_BLOCKS, "edit": EDIT_BLOCKS, "edit_inpaint": EDIT_INPAINT_BLOCKS, + "edit_plus": EDIT_PLUS_BLOCKS, "inpaint": INPAINT_BLOCKS, "controlnet": CONTROLNET_BLOCKS, "auto": AUTO_BLOCKS, "edit_auto": EDIT_AUTO_BLOCKS, + "edit_plus_auto": EDIT_PLUS_AUTO_BLOCKS, } diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py index 7200169923a5..d9e30864f660 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py @@ -196,3 +196,13 @@ def requires_unconditional_embeds(self): requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 return requires_unconditional_embeds + + +class QwenImageEditPlusModularPipeline(QwenImageEditModularPipeline): + """ + A ModularPipeline for QwenImage-Edit Plus. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. 
+ """ + + default_blocks_name = "QwenImageEditPlusAutoBlocks" diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index d265bfdcaf3d..8a32d4c367a3 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -95,6 +95,7 @@ QwenImageControlNetPipeline, QwenImageEditInpaintPipeline, QwenImageEditPipeline, + QwenImageEditPlusPipeline, QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline, @@ -186,6 +187,7 @@ ("flux-kontext", FluxKontextPipeline), ("qwenimage", QwenImageImg2ImgPipeline), ("qwenimage-edit", QwenImageEditPipeline), + ("qwenimage-edit-plus", QwenImageEditPlusPipeline), ] ) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index bb8fea8c8a8b..cf8037796488 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -77,6 +77,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class QwenImageEditPlusAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class QwenImageEditPlusModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class QwenImageModularPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"]