diff --git a/src/diffusers/modular_pipelines/ernie_image/decoders.py b/src/diffusers/modular_pipelines/ernie_image/decoders.py index fb65e80f112f..d7d056b82584 100644 --- a/src/diffusers/modular_pipelines/ernie_image/decoders.py +++ b/src/diffusers/modular_pipelines/ernie_image/decoders.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np import torch -from PIL import Image from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKLFlux2 from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState @@ -44,6 +43,12 @@ def expected_components(self) -> list[ComponentSpec]: config=FrozenDict({"patch_size": 2}), default_creation_method="from_config", ), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), ] @property @@ -75,26 +80,13 @@ def __call__(self, components: ErnieImageModularPipeline, state: PipelineState) latents = block_state.latents bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(device=device, dtype=latents.dtype) - bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to( - device=device, dtype=latents.dtype - ) + bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(device=device, dtype=latents.dtype) latents = latents * bn_std + bn_mean latents = components.pachifier.unpack_latents(latents) images = vae.decode(latents.to(vae.dtype), return_dict=False)[0] - images = (images.clamp(-1, 1) + 1) / 2 - - output_type = block_state.output_type - if output_type == "pt": - block_state.images = images - elif output_type == "np": - block_state.images = images.cpu().permute(0, 2, 3, 1).float().numpy() - elif output_type == "pil": - images_np = images.cpu().permute(0, 2, 3, 1).float().numpy() - block_state.images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images_np] - else: - raise ValueError(f"Unsupported `output_type`: {output_type!r}. Expected one of 'pil', 'np', 'pt'.") + block_state.images = components.image_processor.postprocess(images, output_type=block_state.output_type) self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/ernie_image/encoders.py b/src/diffusers/modular_pipelines/ernie_image/encoders.py index 24e9622c9422..74d02ffb4dba 100644 --- a/src/diffusers/modular_pipelines/ernie_image/encoders.py +++ b/src/diffusers/modular_pipelines/ernie_image/encoders.py @@ -15,7 +15,7 @@ import json import torch -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer +from transformers import AutoTokenizer, Ministral3ForCausalLM, Mistral3Model from ...configuration_utils import FrozenDict from ...guiders import ClassifierFreeGuidance @@ -38,7 +38,7 @@ def description(self) -> str: @property def expected_components(self) -> list[ComponentSpec]: return [ - ComponentSpec("pe", AutoModelForCausalLM), + ComponentSpec("pe", Ministral3ForCausalLM), ComponentSpec("pe_tokenizer", AutoTokenizer), ] @@ -83,7 +83,7 @@ def intermediate_outputs(self) -> list[OutputParam]: @staticmethod def _enhance_prompt( - pe: AutoModelForCausalLM, + pe: Ministral3ForCausalLM, pe_tokenizer: AutoTokenizer, prompt: str, device: torch.device, @@ -160,7 +160,7 @@ def description(self) -> str: @property def expected_components(self) -> list[ComponentSpec]: return [ - ComponentSpec("text_encoder", AutoModel), + ComponentSpec("text_encoder", Mistral3Model), ComponentSpec("tokenizer", AutoTokenizer), ComponentSpec( "guider", @@ -200,7 +200,7 @@ def intermediate_outputs(self) -> list[OutputParam]: @staticmethod def _encode( - text_encoder: AutoModel, + text_encoder: Mistral3Model, tokenizer: AutoTokenizer, prompt: list[str], device: torch.device, diff --git a/src/diffusers/modular_pipelines/ernie_image/modular_blocks_ernie_image.py b/src/diffusers/modular_pipelines/ernie_image/modular_blocks_ernie_image.py index e8d4c23a87b8..db27b897215e 100644 --- a/src/diffusers/modular_pipelines/ernie_image/modular_blocks_ernie_image.py +++ b/src/diffusers/modular_pipelines/ernie_image/modular_blocks_ernie_image.py @@ -13,7 +13,7 @@ # limitations under the License. from ...utils import logging -from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline import ConditionalPipelineBlocks, SequentialPipelineBlocks from ..modular_pipeline_utils import OutputParam from .before_denoise import ( ErnieImagePrepareLatentsStep, @@ -29,11 +29,11 @@ # auto_docstring -class ErnieImageAutoPromptEnhancerStep(AutoPipelineBlocks): +class ErnieImageAutoPromptEnhancerStep(ConditionalPipelineBlocks): """ - Auto block that runs the optional prompt enhancer when `use_pe` is provided. - - `ErnieImagePromptEnhancerStep` is used when `use_pe` is set. - - If `use_pe` is not provided, the step is skipped. + Conditional block that runs the optional prompt enhancer when `use_pe` is truthy. + - `ErnieImagePromptEnhancerStep` is used when `use_pe=True`. + - If `use_pe` is `None` or `False`, the step is skipped. Components: pe (`AutoModelForCausalLM`) pe_tokenizer (`AutoTokenizer`) @@ -66,12 +66,17 @@ class ErnieImageAutoPromptEnhancerStep(AutoPipelineBlocks): block_names = ["prompt_enhancer"] block_trigger_inputs = ["use_pe"] + def select_block(self, use_pe=None) -> str | None: + if use_pe: + return "prompt_enhancer" + return None + @property def description(self): return ( - "Auto block that runs the optional prompt enhancer when `use_pe` is provided.\n" - " - `ErnieImagePromptEnhancerStep` is used when `use_pe` is set.\n" - " - If `use_pe` is not provided, the step is skipped." + "Conditional block that runs the optional prompt enhancer when `use_pe` is truthy.\n" + " - `ErnieImagePromptEnhancerStep` is used when `use_pe=True`.\n" + " - If `use_pe` is `None` or `False`, the step is skipped." ) diff --git a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py index e6ea97c30e29..464a05b2fb74 100644 --- a/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py +++ b/src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py @@ -20,9 +20,9 @@ from typing import Callable, List, Optional, Union import torch -from PIL import Image -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer +from transformers import AutoTokenizer, Ministral3ForCausalLM, Mistral3Model +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKLFlux2 from ...models.transformers import ErnieImageTransformer2DModel from ...pipelines.pipeline_utils import DiffusionPipeline @@ -51,10 +51,10 @@ def __init__( self, transformer: ErnieImageTransformer2DModel, vae: AutoencoderKLFlux2, - text_encoder: AutoModel, + text_encoder: Mistral3Model, tokenizer: AutoTokenizer, scheduler: FlowMatchEulerDiscreteScheduler, - pe: Optional[AutoModelForCausalLM] = None, + pe: Optional[Ministral3ForCausalLM] = None, pe_tokenizer: Optional[AutoTokenizer] = None, ): super().__init__() @@ -68,6 +68,7 @@ def __init__( pe_tokenizer=pe_tokenizer, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @property def guidance_scale(self): @@ -361,26 +362,25 @@ def __call__( progress_bar.update() if output_type == "latent": - return latents - - # Decode latents to images - # Unnormalize latents using VAE's BN stats - bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device) - bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(device) - latents = latents * bn_std + bn_mean - - # Unpatchify - latents = self._unpatchify_latents(latents) + images = latents + else: + # Decode latents to images + # Unnormalize latents using VAE's BN stats + # TODO: switch to `self.vae.config.batch_norm_eps` once the hub config is updated to match the trained value (1e-5). + bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device=device, dtype=latents.dtype) + bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to( + device=device, dtype=latents.dtype + ) + latents = latents * bn_std + bn_mean - # Decode - images = self.vae.decode(latents, return_dict=False)[0] + # Unpatchify + latents = self._unpatchify_latents(latents) - # Post-process - images = (images.clamp(-1, 1) + 1) / 2 - images = images.cpu().permute(0, 2, 3, 1).float().numpy() + # Decode + images = self.vae.decode(latents, return_dict=False)[0] - if output_type == "pil": - images = [Image.fromarray((img * 255).astype("uint8")) for img in images] + # Post-process + images = self.image_processor.postprocess(images, output_type=output_type) # Offload all models self.maybe_free_model_hooks()