Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 9 additions & 17 deletions src/diffusers/modular_pipelines/ernie_image/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
from PIL import Image

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLFlux2
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
Expand Down Expand Up @@ -44,6 +43,12 @@ def expected_components(self) -> list[ComponentSpec]:
config=FrozenDict({"patch_size": 2}),
default_creation_method="from_config",
),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
default_creation_method="from_config",
),
]

@property
Expand Down Expand Up @@ -75,26 +80,13 @@ def __call__(self, components: ErnieImageModularPipeline, state: PipelineState)

latents = block_state.latents
bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(device=device, dtype=latents.dtype)
bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
device=device, dtype=latents.dtype
)
bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(device=device, dtype=latents.dtype)
latents = latents * bn_std + bn_mean

latents = components.pachifier.unpack_latents(latents)

images = vae.decode(latents.to(vae.dtype), return_dict=False)[0]
images = (images.clamp(-1, 1) + 1) / 2

output_type = block_state.output_type
if output_type == "pt":
block_state.images = images
elif output_type == "np":
block_state.images = images.cpu().permute(0, 2, 3, 1).float().numpy()
elif output_type == "pil":
images_np = images.cpu().permute(0, 2, 3, 1).float().numpy()
block_state.images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images_np]
else:
raise ValueError(f"Unsupported `output_type`: {output_type!r}. Expected one of 'pil', 'np', 'pt'.")
block_state.images = components.image_processor.postprocess(images, output_type=block_state.output_type)

self.set_block_state(state, block_state)
return components, state
10 changes: 5 additions & 5 deletions src/diffusers/modular_pipelines/ernie_image/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import json

import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer, Ministral3ForCausalLM, Mistral3Model

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
Expand All @@ -38,7 +38,7 @@ def description(self) -> str:
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("pe", AutoModelForCausalLM),
ComponentSpec("pe", Ministral3ForCausalLM),
ComponentSpec("pe_tokenizer", AutoTokenizer),
]

Expand Down Expand Up @@ -83,7 +83,7 @@ def intermediate_outputs(self) -> list[OutputParam]:

@staticmethod
def _enhance_prompt(
pe: AutoModelForCausalLM,
pe: Ministral3ForCausalLM,
pe_tokenizer: AutoTokenizer,
prompt: str,
device: torch.device,
Expand Down Expand Up @@ -160,7 +160,7 @@ def description(self) -> str:
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("text_encoder", AutoModel),
ComponentSpec("text_encoder", Mistral3Model),
ComponentSpec("tokenizer", AutoTokenizer),
ComponentSpec(
"guider",
Expand Down Expand Up @@ -200,7 +200,7 @@ def intermediate_outputs(self) -> list[OutputParam]:

@staticmethod
def _encode(
text_encoder: AutoModel,
text_encoder: Mistral3Model,
tokenizer: AutoTokenizer,
prompt: list[str],
device: torch.device,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline import ConditionalPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import OutputParam
from .before_denoise import (
ErnieImagePrepareLatentsStep,
Expand All @@ -29,11 +29,11 @@


# auto_docstring
class ErnieImageAutoPromptEnhancerStep(AutoPipelineBlocks):
class ErnieImageAutoPromptEnhancerStep(ConditionalPipelineBlocks):
"""
Auto block that runs the optional prompt enhancer when `use_pe` is provided.
- `ErnieImagePromptEnhancerStep` is used when `use_pe` is set.
- If `use_pe` is not provided, the step is skipped.
Conditional block that runs the optional prompt enhancer when `use_pe` is truthy.
- `ErnieImagePromptEnhancerStep` is used when `use_pe=True`.
- If `use_pe` is `None` or `False`, the step is skipped.

Components:
pe (`AutoModelForCausalLM`) pe_tokenizer (`AutoTokenizer`)
Expand Down Expand Up @@ -66,12 +66,17 @@ class ErnieImageAutoPromptEnhancerStep(AutoPipelineBlocks):
block_names = ["prompt_enhancer"]
block_trigger_inputs = ["use_pe"]

def select_block(self, use_pe=None) -> str | None:
if use_pe:
return "prompt_enhancer"
return None

@property
def description(self):
return (
"Auto block that runs the optional prompt enhancer when `use_pe` is provided.\n"
" - `ErnieImagePromptEnhancerStep` is used when `use_pe` is set.\n"
" - If `use_pe` is not provided, the step is skipped."
"Conditional block that runs the optional prompt enhancer when `use_pe` is truthy.\n"
" - `ErnieImagePromptEnhancerStep` is used when `use_pe=True`.\n"
" - If `use_pe` is `None` or `False`, the step is skipped."
)


Expand Down
42 changes: 21 additions & 21 deletions src/diffusers/pipelines/ernie_image/pipeline_ernie_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
from typing import Callable, List, Optional, Union

import torch
from PIL import Image
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer, Ministral3ForCausalLM, Mistral3Model

from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLFlux2
from ...models.transformers import ErnieImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
Expand Down Expand Up @@ -51,10 +51,10 @@ def __init__(
self,
transformer: ErnieImageTransformer2DModel,
vae: AutoencoderKLFlux2,
text_encoder: AutoModel,
text_encoder: Mistral3Model,
tokenizer: AutoTokenizer,
scheduler: FlowMatchEulerDiscreteScheduler,
pe: Optional[AutoModelForCausalLM] = None,
pe: Optional[Ministral3ForCausalLM] = None,
pe_tokenizer: Optional[AutoTokenizer] = None,
):
super().__init__()
Expand All @@ -68,6 +68,7 @@ def __init__(
pe_tokenizer=pe_tokenizer,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

@property
def guidance_scale(self):
Expand Down Expand Up @@ -361,26 +362,25 @@ def __call__(
progress_bar.update()

if output_type == "latent":
return latents

# Decode latents to images
# Unnormalize latents using VAE's BN stats
bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device)
bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(device)
latents = latents * bn_std + bn_mean

# Unpatchify
latents = self._unpatchify_latents(latents)
images = latents
else:
# Decode latents to images
# Unnormalize latents using VAE's BN stats
# TODO: switch to `self.vae.config.batch_norm_eps` once the hub config is updated to match the trained value (1e-5).
bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device=device, dtype=latents.dtype)
bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(
device=device, dtype=latents.dtype
)
latents = latents * bn_std + bn_mean

# Decode
images = self.vae.decode(latents, return_dict=False)[0]
# Unpatchify
latents = self._unpatchify_latents(latents)

# Post-process
images = (images.clamp(-1, 1) + 1) / 2
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
# Decode
images = self.vae.decode(latents, return_dict=False)[0]

if output_type == "pil":
images = [Image.fromarray((img * 255).astype("uint8")) for img in images]
# Post-process
images = self.image_processor.postprocess(images, output_type=output_type)

# Offload all models
self.maybe_free_model_hooks()
Expand Down
Loading