From 6e61370180406d2cda96283268583f078b30cc7c Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 30 Apr 2026 07:04:56 -0700 Subject: [PATCH 1/4] Address ernie-image review findings #13577 --- .../modular_pipelines/ernie_image/decoders.py | 4 +-- .../ernie_image/modular_blocks_ernie_image.py | 21 +++++++----- .../ernie_image/pipeline_ernie_image.py | 32 +++++++++---------- 3 files changed, 30 insertions(+), 27 deletions(-) diff --git a/src/diffusers/modular_pipelines/ernie_image/decoders.py b/src/diffusers/modular_pipelines/ernie_image/decoders.py index fb65e80f112f..7cd6d25cc443 100644 --- a/src/diffusers/modular_pipelines/ernie_image/decoders.py +++ b/src/diffusers/modular_pipelines/ernie_image/decoders.py @@ -75,9 +75,7 @@ def __call__(self, components: ErnieImageModularPipeline, state: PipelineState) latents = block_state.latents bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(device=device, dtype=latents.dtype) - bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( - device=device, dtype=latents.dtype - ) + bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(device=device, dtype=latents.dtype) latents = latents * bn_std + bn_mean latents = components.pachifier.unpack_latents(latents) diff --git a/src/diffusers/modular_pipelines/ernie_image/modular_blocks_ernie_image.py b/src/diffusers/modular_pipelines/ernie_image/modular_blocks_ernie_image.py index e8d4c23a87b8..db27b897215e 100644 --- a/src/diffusers/modular_pipelines/ernie_image/modular_blocks_ernie_image.py +++ b/src/diffusers/modular_pipelines/ernie_image/modular_blocks_ernie_image.py @@ -13,7 +13,7 @@ # limitations under the License. from ...utils import logging -from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline import ConditionalPipelineBlocks, SequentialPipelineBlocks from ..modular_pipeline_utils import OutputParam from .before_denoise import ( ErnieImagePrepareLatentsStep, @@ -29,11 +29,11 @@ # auto_docstring -class ErnieImageAutoPromptEnhancerStep(AutoPipelineBlocks): +class ErnieImageAutoPromptEnhancerStep(ConditionalPipelineBlocks): """ - Auto block that runs the optional prompt enhancer when `use_pe` is provided. - - `ErnieImagePromptEnhancerStep` is used when `use_pe` is set. - - If `use_pe` is not provided, the step is skipped. + Conditional block that runs the optional prompt enhancer when `use_pe` is truthy. + - `ErnieImagePromptEnhancerStep` is used when `use_pe=True`. + - If `use_pe` is `None` or `False`, the step is skipped. Components: pe (`AutoModelForCausalLM`) pe_tokenizer (`AutoTokenizer`) @@ -66,12 +66,17 @@ class ErnieImageAutoPromptEnhancerStep(AutoPipelineBlocks): block_names = ["prompt_enhancer"] block_trigger_inputs = ["use_pe"] + def select_block(self, use_pe=None) -> str | None: + if use_pe: + return "prompt_enhancer" + return None + @property def description(self): return ( - "Auto block that runs the optional prompt enhancer when `use_pe` is provided.\n" - " - `ErnieImagePromptEnhancerStep` is used when `use_pe` is set.\n" - " - If `use_pe` is not provided, the step is skipped." + "Conditional block that runs the optional prompt enhancer when `use_pe` is truthy.\n" + " - `ErnieImagePromptEnhancerStep` is used when `use_pe=True`.\n" + " - If `use_pe` is `None` or `False`, the step is skipped." ) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index e6ea97c30e29..768b7092eb53 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -361,26 +361,26 @@ def __call__( progress_bar.update() if output_type == "latent": - return latents - - # Decode latents to images - # Unnormalize latents using VAE's BN stats - bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device) - bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(device) - latents = latents * bn_std + bn_mean + images = latents + else: + # Decode latents to images + # Unnormalize latents using VAE's BN stats + bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device) + bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(device) + latents = latents * bn_std + bn_mean - # Unpatchify - latents = self._unpatchify_latents(latents) + # Unpatchify + latents = self._unpatchify_latents(latents) - # Decode - images = self.vae.decode(latents, return_dict=False)[0] + # Decode + images = self.vae.decode(latents, return_dict=False)[0] - # Post-process - images = (images.clamp(-1, 1) + 1) / 2 - images = images.cpu().permute(0, 2, 3, 1).float().numpy() + # Post-process + images = (images.clamp(-1, 1) + 1) / 2 + images = images.cpu().permute(0, 2, 3, 1).float().numpy() - if output_type == "pil": - images = [Image.fromarray((img * 255).astype("uint8")) for img in images] + if output_type == "pil": + images = [Image.fromarray((img * 255).astype("uint8")) for img in images] # Offload all models self.maybe_free_model_hooks() From 2b297bf4b54deccb6cd5b82e881f29bca18259d7 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 30 Apr 2026 09:26:11 -0700 Subject: [PATCH 2/4] Use concrete Mistral3Model / Ministral3ForCausalLM types --- .../modular_pipelines/ernie_image/encoders.py | 10 +++++----- .../pipelines/ernie_image/pipeline_ernie_image.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/modular_pipelines/ernie_image/encoders.py b/src/diffusers/modular_pipelines/ernie_image/encoders.py index 24e9622c9422..74d02ffb4dba 100644 --- a/src/diffusers/modular_pipelines/ernie_image/encoders.py +++ b/src/diffusers/modular_pipelines/ernie_image/encoders.py @@ -15,7 +15,7 @@ import json import torch -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer +from transformers import AutoTokenizer, Ministral3ForCausalLM, Mistral3Model from ...configuration_utils import FrozenDict from ...guiders import ClassifierFreeGuidance @@ -38,7 +38,7 @@ def description(self) -> str: @property def expected_components(self) -> list[ComponentSpec]: return [ - ComponentSpec("pe", AutoModelForCausalLM), + ComponentSpec("pe", Ministral3ForCausalLM), ComponentSpec("pe_tokenizer", AutoTokenizer), ] @@ -83,7 +83,7 @@ def intermediate_outputs(self) -> list[OutputParam]: @staticmethod def _enhance_prompt( - pe: AutoModelForCausalLM, + pe: Ministral3ForCausalLM, pe_tokenizer: AutoTokenizer, prompt: str, device: torch.device, @@ -160,7 +160,7 @@ def description(self) -> str: @property def expected_components(self) -> list[ComponentSpec]: return [ - ComponentSpec("text_encoder", AutoModel), + ComponentSpec("text_encoder", Mistral3Model), ComponentSpec("tokenizer", AutoTokenizer), ComponentSpec( "guider", @@ -200,7 +200,7 @@ def intermediate_outputs(self) -> list[OutputParam]: @staticmethod def _encode( - text_encoder: AutoModel, + text_encoder: Mistral3Model, tokenizer: AutoTokenizer, prompt: list[str], device: torch.device, diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index 768b7092eb53..eef424fab1f4 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -21,7 +21,7 @@ import torch from PIL import Image -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer +from transformers import AutoTokenizer, Ministral3ForCausalLM, Mistral3Model from ...models import AutoencoderKLFlux2 from ...models.transformers import ErnieImageTransformer2DModel @@ -51,10 +51,10 @@ def __init__( self, transformer: ErnieImageTransformer2DModel, vae: AutoencoderKLFlux2, - text_encoder: AutoModel, + text_encoder: Mistral3Model, tokenizer: AutoTokenizer, scheduler: FlowMatchEulerDiscreteScheduler, - pe: Optional[AutoModelForCausalLM] = None, + pe: Optional[Ministral3ForCausalLM] = None, pe_tokenizer: Optional[AutoTokenizer] = None, ): super().__init__() From 11767353905af12d96fc24d55c9d2625c3576c0f Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 30 Apr 2026 12:09:22 -0700 Subject: [PATCH 3/4] Cast bn_mean/bn_std to latents dtype + add TODO for hub eps --- .../pipelines/ernie_image/pipeline_ernie_image.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index eef424fab1f4..a906e98266f1 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -365,8 +365,11 @@ def __call__( else: # Decode latents to images # Unnormalize latents using VAE's BN stats - bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device) - bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(device) + # TODO: switch to `self.vae.config.batch_norm_eps` once the hub config is updated to match the trained value (1e-5). + bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device=device, dtype=latents.dtype) + bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to( + device=device, dtype=latents.dtype + ) latents = latents * bn_std + bn_mean # Unpatchify From 26d8bc00e44b82e63d0961af45c3e4c4f9322a44 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 30 Apr 2026 12:17:35 -0700 Subject: [PATCH 4/4] Use VaeImageProcessor.postprocess in standard and modular ernie --- .../modular_pipelines/ernie_image/decoders.py | 22 +++++++------------ .../ernie_image/pipeline_ernie_image.py | 9 +++----- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/src/diffusers/modular_pipelines/ernie_image/decoders.py b/src/diffusers/modular_pipelines/ernie_image/decoders.py index 7cd6d25cc443..d7d056b82584 100644 --- a/src/diffusers/modular_pipelines/ernie_image/decoders.py +++ b/src/diffusers/modular_pipelines/ernie_image/decoders.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np import torch -from PIL import Image from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKLFlux2 from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState @@ -44,6 +43,12 @@ def expected_components(self) -> list[ComponentSpec]: config=FrozenDict({"patch_size": 2}), default_creation_method="from_config", ), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), ] @property @@ -81,18 +86,7 @@ def __call__(self, components: ErnieImageModularPipeline, state: PipelineState) latents = components.pachifier.unpack_latents(latents) images = vae.decode(latents.to(vae.dtype), return_dict=False)[0] - images = (images.clamp(-1, 1) + 1) / 2 - - output_type = block_state.output_type - if output_type == "pt": - block_state.images = images - elif output_type == "np": - block_state.images = images.cpu().permute(0, 2, 3, 1).float().numpy() - elif output_type == "pil": - images_np = images.cpu().permute(0, 2, 3, 1).float().numpy() - block_state.images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images_np] - else: - raise ValueError(f"Unsupported `output_type`: {output_type!r}. Expected one of 'pil', 'np', 'pt'.") + block_state.images = components.image_processor.postprocess(images, output_type=block_state.output_type) self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index a906e98266f1..464a05b2fb74 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -20,9 +20,9 @@ from typing import Callable, List, Optional, Union import torch -from PIL import Image from transformers import AutoTokenizer, Ministral3ForCausalLM, Mistral3Model +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKLFlux2 from ...models.transformers import ErnieImageTransformer2DModel from ...pipelines.pipeline_utils import DiffusionPipeline @@ -68,6 +68,7 @@ def __init__( pe_tokenizer=pe_tokenizer, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @property def guidance_scale(self): @@ -379,11 +380,7 @@ def __call__( images = self.vae.decode(latents, return_dict=False)[0] # Post-process - images = (images.clamp(-1, 1) + 1) / 2 - images = images.cpu().permute(0, 2, 3, 1).float().numpy() - - if output_type == "pil": - images = [Image.fromarray((img * 255).astype("uint8")) for img in images] + images = self.image_processor.postprocess(images, output_type=output_type) # Offload all models self.maybe_free_model_hooks()