diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 50001470a46d..59883a92c962 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -449,6 +449,8 @@ "Flux2KleinModularPipeline", "Flux2ModularPipeline", "FluxAutoBlocks", + "ErnieImageAutoBlocks", + "ErnieImageModularPipeline", "FluxKontextAutoBlocks", "FluxKontextModularPipeline", "FluxModularPipeline", @@ -1234,6 +1236,8 @@ Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline, + ErnieImageAutoBlocks, + ErnieImageModularPipeline, FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index c4891d1c0f7d..41a9be92ec9c 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -88,6 +88,10 @@ "QwenImageLayeredModularPipeline", "QwenImageLayeredAutoBlocks", ] + _import_structure["ernie_image"] = [ + "ErnieImageAutoBlocks", + "ErnieImageModularPipeline", + ] _import_structure["ltx"] = [ "LTXAutoBlocks", "LTXModularPipeline", @@ -106,6 +110,7 @@ from ..utils.dummy_pt_objects import * # noqa F403 else: from .components_manager import ComponentsManager + from .ernie_image import ErnieImageAutoBlocks, ErnieImageModularPipeline from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline from .flux2 import ( Flux2AutoBlocks, diff --git a/src/diffusers/modular_pipelines/ernie_image/__init__.py b/src/diffusers/modular_pipelines/ernie_image/__init__.py new file mode 100644 index 000000000000..68ed723c590c --- /dev/null +++ b/src/diffusers/modular_pipelines/ernie_image/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_ernie_image"] = ["ErnieImageAutoBlocks"] + _import_structure["modular_pipeline"] = ["ErnieImageModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_ernie_image import ErnieImageAutoBlocks + from .modular_pipeline import ErnieImageModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/ernie_image/before_denoise.py b/src/diffusers/modular_pipelines/ernie_image/before_denoise.py new file mode 100644 index 000000000000..1c13c50f2db3 --- /dev/null +++ b/src/diffusers/modular_pipelines/ernie_image/before_denoise.py @@ -0,0 +1,269 @@ +# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...models import ErnieImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ErnieImageModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _pad_text( + text_hiddens: list[torch.Tensor], device: torch.device, dtype: torch.dtype, text_in_dim: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Pad a list of variable-length text hidden states to a common length and return (padded, lengths).""" + batch_size = len(text_hiddens) + if batch_size == 0: + return ( + torch.zeros((0, 0, text_in_dim), device=device, dtype=dtype), + torch.zeros((0,), device=device, dtype=torch.long), + ) + normalized = [t.squeeze(1).to(device).to(dtype) if t.dim() == 3 else t.to(device).to(dtype) for t in text_hiddens] + lengths = torch.tensor([t.shape[0] for t in normalized], device=device, dtype=torch.long) + max_length = int(lengths.max().item()) + padded = torch.zeros((batch_size, max_length, text_in_dim), device=device, dtype=dtype) + for i, t in enumerate(normalized): + padded[i, : t.shape[0], :] = t + return padded, lengths + + +class ErnieImageTextInputStep(ModularPipelineBlocks): + model_name = "ernie-image" + + @property + def description(self) -> str: + return ( + "Input processing step that pads the variable-length text hidden states to a common length and " + "produces `text_bth` / `text_lens` tensors consumed by the denoiser." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", ErnieImageTransformer2DModel)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "prompt_embeds", + required=True, + type_hint=list, + description="List of per-prompt text embeddings from the text encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=list, + description="List of per-prompt negative text embeddings from the text encoder step.", + ), + InputParam( + "num_images_per_prompt", + type_hint=int, + default=1, + description="Number of images to generate per prompt.", + ), + InputParam( + "batch_size", + type_hint=int, + default=None, + description="Prompt batch size. Resolved from `prompt_embeds` when not provided.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("batch_size", type_hint=int, description="The number of prompts in the batch."), + OutputParam( + "text_bth", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Padded text hidden states of shape (B, T_max, H) fed into the transformer.", + ), + OutputParam( + "text_lens", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Actual per-prompt text lengths used to build the transformer attention mask.", + ), + OutputParam( + "negative_text_bth", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Padded negative text hidden states, when classifier-free guidance is enabled.", + ), + OutputParam( + "negative_text_lens", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="Actual per-prompt negative text lengths, when classifier-free guidance is enabled.", + ), + ] + + @torch.no_grad() + def __call__(self, components: ErnieImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + dtype = components.transformer.dtype + text_in_dim = components.text_in_dim + + prompt_embeds = block_state.prompt_embeds + block_state.batch_size = getattr(block_state, "batch_size", None) or len(prompt_embeds) + + text_bth, text_lens = _pad_text(prompt_embeds, device, dtype, text_in_dim) + block_state.text_bth = text_bth + block_state.text_lens = text_lens + + negative_prompt_embeds = getattr(block_state, "negative_prompt_embeds", None) + if negative_prompt_embeds is not None: + negative_text_bth, negative_text_lens = _pad_text(negative_prompt_embeds, device, dtype, text_in_dim) + block_state.negative_text_bth = negative_text_bth + block_state.negative_text_lens = negative_text_lens + else: + block_state.negative_text_bth = None + block_state.negative_text_lens = None + + self.set_block_state(state, block_state) + return components, state + + +class ErnieImageSetTimestepsStep(ModularPipelineBlocks): + model_name = "ernie-image" + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference using a linear sigma schedule." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "num_inference_steps", + type_hint=int, + default=50, + description="Number of denoising steps.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference."), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps."), + ] + + @torch.no_grad() + def __call__(self, components: ErnieImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + num_inference_steps = block_state.num_inference_steps + + sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] + components.scheduler.set_timesteps(sigmas=sigmas, device=device) + + block_state.timesteps = components.scheduler.timesteps + block_state.num_inference_steps = num_inference_steps + + self.set_block_state(state, block_state) + return components, state + + +class ErnieImagePrepareLatentsStep(ModularPipelineBlocks): + model_name = "ernie-image" + + @property + def description(self) -> str: + return "Prepare random noise latents for the ErnieImage text-to-image denoising process." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", ErnieImageTransformer2DModel)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height", type_hint=int, description="The height in pixels of the generated image."), + InputParam("width", type_hint=int, description="The width in pixels of the generated image."), + InputParam( + "latents", + type_hint=torch.Tensor, + description="Pre-generated noisy latents. If provided, skips noise sampling.", + ), + InputParam( + "num_images_per_prompt", + type_hint=int, + default=1, + description="Number of images to generate per prompt.", + ), + InputParam("generator", description="Torch generator for deterministic noise sampling."), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Prompt batch size resolved by the text input step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("latents", type_hint=torch.Tensor, description="The initial noise latents to denoise."), + OutputParam("height", type_hint=int, description="The resolved image height in pixels."), + OutputParam("width", type_hint=int, description="The resolved image width in pixels."), + ] + + @staticmethod + def _check_inputs(components: ErnieImageModularPipeline, height: int, width: int) -> None: + vae_scale_factor = components.vae_scale_factor + if height % vae_scale_factor != 0 or width % vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` must be divisible by {vae_scale_factor}, got {height} and {width}." + ) + + @torch.no_grad() + def __call__(self, components: ErnieImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + dtype = components.transformer.dtype + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + self._check_inputs(components, height, width) + + total_batch_size = block_state.batch_size * block_state.num_images_per_prompt + latent_h = height // components.vae_scale_factor + latent_w = width // components.vae_scale_factor + num_channels_latents = components.num_channels_latents + + shape = (total_batch_size, num_channels_latents, latent_h, latent_w) + if block_state.latents is None: + block_state.latents = randn_tensor(shape, generator=block_state.generator, device=device, dtype=dtype) + else: + block_state.latents = block_state.latents.to(device=device, dtype=dtype) + + block_state.height = height + block_state.width = width + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/ernie_image/decoders.py b/src/diffusers/modular_pipelines/ernie_image/decoders.py new file mode 100644 index 000000000000..fb65e80f112f --- /dev/null +++ b/src/diffusers/modular_pipelines/ernie_image/decoders.py @@ -0,0 +1,100 @@ +# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from PIL import Image + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKLFlux2 +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ErnieImageModularPipeline, ErnieImagePachifier + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ErnieImageVaeDecoderStep(ModularPipelineBlocks): + model_name = "ernie-image" + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images (unpachify, BN denormalization, VAE decode)." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLFlux2), + ComponentSpec( + "pachifier", + ErnieImagePachifier, + config=FrozenDict({"patch_size": 2}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to decode into images.", + ), + InputParam( + "output_type", + type_hint=str, + default="pil", + description="Output format: 'pil', 'np', or 'pt'.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("images", type_hint=list, description="The generated images.")] + + @torch.no_grad() + def __call__(self, components: ErnieImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae = components.vae + device = block_state.latents.device + + latents = block_state.latents + bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(device=device, dtype=latents.dtype) + bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( + device=device, dtype=latents.dtype + ) + latents = latents * bn_std + bn_mean + + latents = components.pachifier.unpack_latents(latents) + + images = vae.decode(latents.to(vae.dtype), return_dict=False)[0] + images = (images.clamp(-1, 1) + 1) / 2 + + output_type = block_state.output_type + if output_type == "pt": + block_state.images = images + elif output_type == "np": + block_state.images = images.cpu().permute(0, 2, 3, 1).float().numpy() + elif output_type == "pil": + images_np = images.cpu().permute(0, 2, 3, 1).float().numpy() + block_state.images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images_np] + else: + raise ValueError(f"Unsupported `output_type`: {output_type!r}. Expected one of 'pil', 'np', 'pt'.") + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/ernie_image/denoise.py b/src/diffusers/modular_pipelines/ernie_image/denoise.py new file mode 100644 index 000000000000..d3ab9c78f60a --- /dev/null +++ b/src/diffusers/modular_pipelines/ernie_image/denoise.py @@ -0,0 +1,242 @@ +# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import ErnieImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ErnieImageModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ErnieImageLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "ernie-image" + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that prepares the latent model input and timestep tensor. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `ErnieImageDenoiseLoopWrapper`)." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", ErnieImageTransformer2DModel)] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise.", + ), + ] + + @torch.no_grad() + def __call__(self, components: ErnieImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents = block_state.latents + block_state.latent_model_input = latents.to(components.transformer.dtype) + block_state.timestep = t.expand(latents.shape[0]).to(components.transformer.dtype) + return components, block_state + + +class ErnieImageLoopDenoiser(ModularPipelineBlocks): + model_name = "ernie-image" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", ErnieImageTransformer2DModel), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that runs the ErnieImage transformer with classifier-free guidance via " + "the configured guider." + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "text_bth", + required=True, + type_hint=torch.Tensor, + description="Padded text hidden states fed into the transformer.", + ), + InputParam( + "text_lens", + required=True, + type_hint=torch.Tensor, + description="Per-prompt text lengths used by the transformer attention mask.", + ), + InputParam( + "negative_text_bth", + type_hint=torch.Tensor, + description="Padded negative text hidden states for classifier-free guidance.", + ), + InputParam( + "negative_text_lens", + type_hint=torch.Tensor, + description="Per-prompt negative text lengths for classifier-free guidance.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="Total number of denoising steps. Used by the guider for step-aware scheduling.", + ), + ] + + @torch.no_grad() + def __call__(self, components: ErnieImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + guider_inputs = { + "text_bth": ( + getattr(block_state, "text_bth", None), + getattr(block_state, "negative_text_bth", None), + ), + "text_lens": ( + getattr(block_state, "text_lens", None), + getattr(block_state, "negative_text_lens", None), + ), + } + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {name: getattr(guider_state_batch, name) for name in guider_inputs.keys()} + noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + timestep=block_state.timestep, + return_dict=False, + **cond_kwargs, + )[0] + guider_state_batch.noise_pred = noise_pred + components.guider.cleanup_models(components.transformer) + + block_state.noise_pred = components.guider(guider_state)[0] + return components, block_state + + +class ErnieImageLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "ernie-image" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step within the denoising loop that updates the latents using the scheduler step." + + @torch.no_grad() + def __call__(self, components: ErnieImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, t, block_state.latents, return_dict=False + )[0] + if block_state.latents.dtype != latents_dtype and torch.backends.mps.is_available(): + block_state.latents = block_state.latents.to(latents_dtype) + return components, block_state + + +class ErnieImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "ernie-image" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoises the latents over `timesteps`. " + "The specific steps within each iteration can be customized with `sub_blocks` attribute." + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", ErnieImageTransformer2DModel), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for inference.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of denoising steps.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents.")] + + @torch.no_grad() + def __call__(self, components: ErnieImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + progress_bar.update() + self.set_block_state(state, block_state) + return components, state + + +class ErnieImageDenoiseStep(ErnieImageDenoiseLoopWrapper): + block_classes = [ + ErnieImageLoopBeforeDenoiser, + ErnieImageLoopDenoiser, + ErnieImageLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents. At each iteration it runs:\n" + " - `ErnieImageLoopBeforeDenoiser`\n" + " - `ErnieImageLoopDenoiser`\n" + " - `ErnieImageLoopAfterDenoiser`" + ) diff --git a/src/diffusers/modular_pipelines/ernie_image/encoders.py b/src/diffusers/modular_pipelines/ernie_image/encoders.py new file mode 100644 index 000000000000..a1e65b53d90b --- /dev/null +++ b/src/diffusers/modular_pipelines/ernie_image/encoders.py @@ -0,0 +1,286 @@ +# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import torch +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import ErnieImageModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ErnieImagePromptEnhancerStep(ModularPipelineBlocks): + model_name = "ernie-image" + + @property + def description(self) -> str: + return ( + "Prompt enhancer step that rewrites the input prompt using a causal language model (PE). " + "If `use_pe` is False or the PE components are not loaded, the step is a no-op." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("pe", AutoModelForCausalLM), + ComponentSpec("pe_tokenizer", AutoTokenizer), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt", required=True, description="The prompt or prompts to guide image generation."), + InputParam("height", type_hint=int, description="The height in pixels of the generated image."), + InputParam("width", type_hint=int, description="The width in pixels of the generated image."), + InputParam( + "use_pe", + type_hint=bool, + default=True, + description="Whether to use the prompt enhancer to rewrite the prompt before encoding.", + ), + InputParam( + "pe_system_prompt", + type_hint=str, + default=None, + description="Optional system prompt passed to the prompt enhancer.", + ), + InputParam( + "pe_temperature", + type_hint=float, + default=0.6, + description="Sampling temperature used when generating with the prompt enhancer.", + ), + InputParam( + "pe_top_p", + type_hint=float, + default=0.95, + description="Nucleus sampling `top_p` used when generating with the prompt enhancer.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt", + type_hint=list, + description="The prompt list after optional prompt-enhancer rewriting.", + ), + OutputParam( + "revised_prompts", + type_hint=list, + description="The prompts returned by the prompt enhancer when it ran, else None.", + ), + ] + + @staticmethod + def _enhance_prompt( + pe: AutoModelForCausalLM, + pe_tokenizer: AutoTokenizer, + prompt: str, + device: torch.device, + width: int, + height: int, + system_prompt: str | None, + temperature: float, + top_p: float, + ) -> str: + user_content = json.dumps({"prompt": prompt, "width": width, "height": height}, ensure_ascii=False) + messages = [] + if system_prompt is not None: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": user_content}) + + input_text = pe_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) + inputs = pe_tokenizer(input_text, return_tensors="pt").to(device) + output_ids = pe.generate( + **inputs, + max_new_tokens=pe_tokenizer.model_max_length, + do_sample=temperature != 1.0 or top_p != 1.0, + temperature=temperature, + top_p=top_p, + pad_token_id=pe_tokenizer.pad_token_id, + eos_token_id=pe_tokenizer.eos_token_id, + ) + generated_ids = output_ids[0][inputs["input_ids"].shape[1] :] + return pe_tokenizer.decode(generated_ids, skip_special_tokens=True).strip() + + @torch.no_grad() + def __call__(self, components: ErnieImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + prompt = block_state.prompt + if isinstance(prompt, str): + prompt = [prompt] + + pe = getattr(components, "pe", None) + pe_tokenizer = getattr(components, "pe_tokenizer", None) + if not block_state.use_pe or pe is None or pe_tokenizer is None: + block_state.prompt = prompt + block_state.revised_prompts = None + self.set_block_state(state, block_state) + return components, state + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + + revised = [ + self._enhance_prompt( + pe=pe, + pe_tokenizer=pe_tokenizer, + prompt=p, + device=device, + width=width, + height=height, + system_prompt=block_state.pe_system_prompt, + temperature=block_state.pe_temperature, + top_p=block_state.pe_top_p, + ) + for p in prompt + ] + + block_state.prompt = revised + block_state.revised_prompts = list(revised) + + self.set_block_state(state, block_state) + return components, state + + +class ErnieImageTextEncoderStep(ModularPipelineBlocks): + model_name = "ernie-image" + + @property + def description(self) -> str: + return ( + "Text encoder step that encodes prompts into variable-length hidden states for the ErnieImage transformer." + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", AutoModel), + ComponentSpec("tokenizer", AutoTokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt", description="The prompt or prompts to guide image generation."), + InputParam("negative_prompt", description="The prompt or prompts to avoid during image generation."), + InputParam( + "num_images_per_prompt", + type_hint=int, + default=1, + description="Number of images to generate per prompt.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=list, + kwargs_type="denoiser_input_fields", + description="List of per-prompt text embeddings of shape (T, H) used as conditioning for the transformer.", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=list, + kwargs_type="denoiser_input_fields", + description="List of per-prompt negative text embeddings used for classifier-free guidance.", + ), + ] + + @staticmethod + def _encode( + text_encoder: AutoModel, + tokenizer: AutoTokenizer, + prompt: list[str], + device: torch.device, + num_images_per_prompt: int, + ) -> list[torch.Tensor]: + text_hiddens = [] + for p in prompt: + ids = tokenizer(p, add_special_tokens=True, truncation=True, padding=False)["input_ids"] + if len(ids) == 0: + ids = [tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 0] + input_ids = torch.tensor([ids], device=device) + outputs = text_encoder(input_ids=input_ids, output_hidden_states=True) + # Second-to-last hidden state matches ErnieImage training + hidden = outputs.hidden_states[-2][0] + for _ in range(num_images_per_prompt): + text_hiddens.append(hidden) + return text_hiddens + + @torch.no_grad() + def __call__(self, components: ErnieImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + prompt = block_state.prompt + if prompt is None: + prompt = [""] + if isinstance(prompt, str): + prompt = [prompt] + num_images_per_prompt = block_state.num_images_per_prompt + + block_state.prompt_embeds = self._encode( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + + if components.requires_unconditional_embeds: + negative_prompt = block_state.negative_prompt + if negative_prompt is None: + negative_prompt = "" + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * len(prompt) + if len(negative_prompt) != len(prompt): + raise ValueError( + f"`negative_prompt` must have the same length as `prompt` ({len(prompt)}), " + f"got {len(negative_prompt)}." + ) + block_state.negative_prompt_embeds = self._encode( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + else: + block_state.negative_prompt_embeds = None + + state.set("batch_size", len(prompt)) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/ernie_image/modular_blocks_ernie_image.py b/src/diffusers/modular_pipelines/ernie_image/modular_blocks_ernie_image.py new file mode 100644 index 000000000000..83a8bb5988bc --- /dev/null +++ b/src/diffusers/modular_pipelines/ernie_image/modular_blocks_ernie_image.py @@ -0,0 +1,153 @@ +# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + ErnieImagePrepareLatentsStep, + ErnieImageSetTimestepsStep, + ErnieImageTextInputStep, +) +from .decoders import ErnieImageVaeDecoderStep +from .denoise import ErnieImageDenoiseStep +from .encoders import ErnieImagePromptEnhancerStep, ErnieImageTextEncoderStep + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# auto_docstring +class ErnieImageCoreDenoiseStep(SequentialPipelineBlocks): + """ + Denoise block that takes encoded conditions and runs the denoising process for ErnieImage. + + Components: + transformer (`ErnieImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) + + Inputs: + prompt_embeds (`list`): + List of per-prompt text embeddings from the text encoder step. + negative_prompt_embeds (`list`, *optional*): + List of per-prompt negative text embeddings from the text encoder step. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images to generate per prompt. + batch_size (`int`, *optional*): + Prompt batch size. Resolved from `prompt_embeds` when not provided. + num_inference_steps (`int`, *optional*, defaults to 50): + Number of denoising steps. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents. If provided, skips noise sampling. + generator (`None`, *optional*): + Torch generator for deterministic noise sampling. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "ernie-image" + block_classes = [ + ErnieImageTextInputStep, + ErnieImageSetTimestepsStep, + ErnieImagePrepareLatentsStep, + ErnieImageDenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return "Denoise block that takes encoded conditions and runs the denoising process for ErnieImage." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# auto_docstring +class ErnieImageAutoBlocks(SequentialPipelineBlocks): + """ + Auto modular pipeline for ErnieImage text-to-image generation. Supports an optional prompt enhancer when the `pe` + components are loaded and `use_pe=True`. + + Supported workflows: + - `text2image`: requires `prompt` + + Components: + pe (`AutoModelForCausalLM`) pe_tokenizer (`AutoTokenizer`) text_encoder (`AutoModel`) tokenizer + (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) transformer (`ErnieImageTransformer2DModel`) scheduler + (`FlowMatchEulerDiscreteScheduler`) vae (`AutoencoderKLFlux2`) pachifier (`ErnieImagePachifier`) + + Inputs: + prompt (`None`): + The prompt or prompts to guide image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + use_pe (`bool`, *optional*, defaults to True): + Whether to use the prompt enhancer to rewrite the prompt before encoding. + pe_system_prompt (`str`, *optional*): + Optional system prompt passed to the prompt enhancer. + pe_temperature (`float`, *optional*, defaults to 0.6): + Sampling temperature used when generating with the prompt enhancer. + pe_top_p (`float`, *optional*, defaults to 0.95): + Nucleus sampling `top_p` used when generating with the prompt enhancer. + negative_prompt (`None`, *optional*): + The prompt or prompts to avoid during image generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images to generate per prompt. + batch_size (`int`, *optional*): + Prompt batch size. Resolved from `prompt_embeds` when not provided. + num_inference_steps (`int`, *optional*, defaults to 50): + Number of denoising steps. + latents (`Tensor`, *optional*): + Pre-generated noisy latents. If provided, skips noise sampling. + generator (`None`, *optional*): + Torch generator for deterministic noise sampling. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', or 'pt'. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "ernie-image" + block_classes = [ + ErnieImagePromptEnhancerStep, + ErnieImageTextEncoderStep, + ErnieImageCoreDenoiseStep, + ErnieImageVaeDecoderStep, + ] + block_names = ["prompt_enhancer", "text_encoder", "denoise", "decode"] + _workflow_map = { + "text2image": {"prompt": True}, + } + + @property + def description(self): + return ( + "Auto modular pipeline for ErnieImage text-to-image generation. Supports an optional prompt enhancer " + "when the `pe` components are loaded and `use_pe=True`." + ) + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/src/diffusers/modular_pipelines/ernie_image/modular_pipeline.py b/src/diffusers/modular_pipelines/ernie_image/modular_pipeline.py new file mode 100644 index 000000000000..cf4497fe9138 --- /dev/null +++ b/src/diffusers/modular_pipelines/ernie_image/modular_pipeline.py @@ -0,0 +1,109 @@ +# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ErnieImagePachifier(ConfigMixin): + """ + A class to pack and unpack latents for ErnieImage. + """ + + config_name = "config.json" + + @register_to_config + def __init__(self, patch_size: int = 2): + super().__init__() + + def pack_latents(self, latents: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = latents.shape + patch_size = self.config.patch_size + + if height % patch_size != 0 or width % patch_size != 0: + raise ValueError( + f"Latent height and width must be divisible by {patch_size}, but got {height} and {width}" + ) + + latents = latents.view( + batch_size, num_channels, height // patch_size, patch_size, width // patch_size, patch_size + ) + latents = latents.permute(0, 1, 3, 5, 2, 4) + return latents.reshape( + batch_size, num_channels * patch_size * patch_size, height // patch_size, width // patch_size + ) + + def unpack_latents(self, latents: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = latents.shape + patch_size = self.config.patch_size + + latents = latents.reshape( + batch_size, num_channels // (patch_size * patch_size), patch_size, patch_size, height, width + ) + latents = latents.permute(0, 1, 4, 2, 5, 3) + return latents.reshape( + batch_size, num_channels // (patch_size * patch_size), height * patch_size, width * patch_size + ) + + +class ErnieImageModularPipeline(ModularPipeline): + """ + A ModularPipeline for ErnieImage. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "ErnieImageAutoBlocks" + + @property + def default_height(self): + return 1024 + + @property + def default_width(self): + return 1024 + + @property + def vae_scale_factor(self): + vae_scale_factor = 16 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** len(self.vae.config.block_out_channels) + return vae_scale_factor + + @property + def num_channels_latents(self): + num_channels_latents = 128 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_latents = self.transformer.config.in_channels + return num_channels_latents + + @property + def text_in_dim(self): + text_in_dim = 3584 + if hasattr(self, "transformer") and self.transformer is not None: + text_in_dim = self.transformer.config.text_in_dim + return text_in_dim + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + return requires_unconditional_embeds diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index ace89f0d6f91..b616a353663a 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -133,6 +133,7 @@ def _helios_pyramid_map_fn(config_dict=None): ("helios", _create_default_map_fn("HeliosModularPipeline")), ("helios-pyramid", _helios_pyramid_map_fn), ("ltx", _create_default_map_fn("LTXModularPipeline")), + ("ernie-image", _create_default_map_fn("ErnieImageModularPipeline")), ] ) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 7198b46fb381..0c7dcfb1d61f 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -92,6 +92,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ErnieImageAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class ErnieImageModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class FluxAutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/modular_pipelines/ernie_image/__init__.py b/tests/modular_pipelines/ernie_image/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/ernie_image/test_modular_pipeline_ernie_image.py b/tests/modular_pipelines/ernie_image/test_modular_pipeline_ernie_image.py new file mode 100644 index 000000000000..23be10abc073 --- /dev/null +++ b/tests/modular_pipelines/ernie_image/test_modular_pipeline_ernie_image.py @@ -0,0 +1,60 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from diffusers.modular_pipelines import ErnieImageAutoBlocks, ErnieImageModularPipeline + +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +ERNIE_IMAGE_WORKFLOWS = { + "text2image": [ + ("prompt_enhancer", "ErnieImagePromptEnhancerStep"), + ("text_encoder", "ErnieImageTextEncoderStep"), + ("denoise.input", "ErnieImageTextInputStep"), + ("denoise.set_timesteps", "ErnieImageSetTimestepsStep"), + ("denoise.prepare_latents", "ErnieImagePrepareLatentsStep"), + ("denoise.denoise", "ErnieImageDenoiseStep"), + ("decode", "ErnieImageVaeDecoderStep"), + ], +} + + +class TestErnieImageModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = ErnieImageModularPipeline + pipeline_blocks_class = ErnieImageAutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-ernie-image-modular-pipe" + + params = frozenset(["prompt", "height", "width"]) + batch_params = frozenset(["prompt"]) + optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents"]) + expected_workflow_blocks = ERNIE_IMAGE_WORKFLOWS + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + return { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "use_pe": False, + "output_type": "pt", + } + + @pytest.mark.skip(reason="PE generation is non-deterministic on CPU") + def test_float16_inference(self): + pass