
Ernie image#1389

Merged
Artiprocher merged 11 commits into modelscope:main from mi804:ernie-image
Apr 13, 2026

Conversation

Collaborator

@mi804 mi804 commented Apr 12, 2026

Baidu/ERNIE-Image Integration

Summary

Integrate Baidu/ERNIE-Image into DiffSynth-Studio, supporting:

  • ✅ Text-to-Image — Text prompt to high-quality image generation
  • ✅ Text-to-Image with Prompt Enhancement — Optional PE module (Ministral3) that rewrites prompts for richer, more detailed outputs

This is a new_series integration. ERNIE-Image uses a custom SharedAdaLN + RoPE 3D + joint image-text attention architecture, distinct from existing DiT models (FLUX, Qwen-Image, Wan, etc.) in DiffSynth-Studio. The DiT is a 36-layer ~4.7B parameter transformer.

Model Components

| Component | File | Lines | Params | Converter |
| --- | --- | --- | --- | --- |
| DiT | diffsynth/models/ernie_image_dit.py | 367 | ~4.7B | ErnieImageDiTStateDictConverter |
| VAE | reuses existing Flux2VAE (diffsynth/models/flux2_vae.py) | n/a | n/a | reuses existing |
| Text Encoder | diffsynth/models/ernie_image_text_encoder.py | 77 | ~3.8B | ErnieImageTextEncoderStateDictConverter |
| PE | diffsynth/models/ernie_image_pe.py | 116 | ~3.8B | ErnieImagePEStateDictConverter |

Pipeline: diffsynth/pipelines/ernie_image.py (386 lines) — ErnieImagePipeline

Dependencies Added

  • diffusers>=0.38.0.dev0 (optional; used only for validation against the target library, not required by the DiffSynth runtime)
  • flash-attn (optional, CUDA acceleration)
  • All LLM components use transformers directly: Ministral3Model (text encoder), Ministral3ForCausalLM (PE). Weights loaded automatically via transformers.

Key Features

Text-to-Image

  • Standard text-to-image generation with FlowMatchScheduler (ERNIE-Image template)
  • Default: 1024x1024, 50 steps, CFG=4.0, bfloat16
  • Supports full training and LoRA training (training pipeline has I2I data path)
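For intuition, the kind of timestep warping a flow-match scheduler applies can be sketched in a few lines. This is an illustrative reconstruction only; the actual ERNIE-Image template lives in diffsynth/diffusion/flow_match.py and its shift function may differ:

```python
import math

def exponential_time_shift(t: float, mu: float = 1.0) -> float:
    # Illustrative "exponential shift": larger mu concentrates more of the
    # step budget near the high-noise end of the denoising trajectory.
    return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0))

steps = 50
# t runs from 1.0 (pure noise) down toward 0.0 (clean image)
timesteps = [exponential_time_shift(1.0 - i / steps) for i in range(steps)]
```

The resulting schedule stays within (0, 1], starts at exactly 1.0, and decreases monotonically, which is the shape a flow-match sampler expects.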

Text-to-Image with Prompt Enhancement

  • Optional use_pe=True flag enables prompt rewriting via Ministral3 Causal LM
  • Returns both generated image and the revised prompt
  • PE temperature and top_p configurable (pe_temperature=0.6, pe_top_p=0.95 defaults)
  • PE module is entirely optional — disabling it has no impact on base generation
  • No training support for PE (inference only)

Example Usage

```python
import torch
from diffsynth.pipelines.ernie_image import ErnieImagePipeline, ModelConfig

# Basic Text-to-Image
pipe = ErnieImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    model_configs=[
        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
)
# Prompt: "a black-and-white Chinese pastoral dog"
image = pipe(prompt="一只黑白相间的中华田园犬", num_inference_steps=50, cfg_scale=4.0)
image.save("output.jpg")

# With Prompt Enhancement
pipe_pe = ErnieImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    model_configs=[
        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
        ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="pe/model.safetensors"),
    ],
)
image, revised_prompt = pipe_pe(prompt="一只黑白相间的中华田园犬", use_pe=True)
print(f"Revised prompt: {revised_prompt}")
```

Architecture Details

  • DiT: SharedAdaLN structure with shift/scale/gate parameters shared between self-attention and MLP. RoPE 3D positional encoding across (t, y, x) dimensions. Joint attention over image tokens + text tokens. Internal format: [S, B, H] for transformer blocks, [B, S, H] for attention.
  • Text Encoder: Ministral3Model (transformers), 26 layers, hidden_size=3072, yarn RoPE. Outputs hidden_states[-2] (second-to-last layer) as text embeddings.
  • PE: Ministral3ForCausalLM (transformers), same config as text encoder. Uses chat_template for prompt rewriting, carries prompt text + target resolution in JSON format.
  • VAE: Directly reuses Flux2VAE — all parameters match (block_out_channels, latent_channels=32, patch_size=(2,2), etc.).
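A minimal NumPy sketch of the SharedAdaLN idea: one (shift, scale, gate) triple, produced from the timestep conditioning, modulates both the self-attention and the MLP branches of a block. Names and shapes are illustrative, not the actual ErnieImageDiT code:

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def shared_adaln_block(x, params, attn_fn, mlp_fn):
    # The SAME (shift, scale, gate) triple is applied to both residual
    # branches, instead of per-branch modulation as in standard AdaLN-Zero.
    shift, scale, gate = params
    h = layer_norm(x) * (1.0 + scale) + shift
    x = x + gate * attn_fn(h)
    h = layer_norm(x) * (1.0 + scale) + shift
    x = x + gate * mlp_fn(h)
    return x

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))                       # [tokens, hidden]
params = tuple(0.1 * rng.standard_normal(8) for _ in range(3))
out = shared_adaln_block(x, params, lambda h: h, lambda h: h)
```

Note that with a zero gate the residual branches contribute nothing and the block reduces to the identity, a useful property for stable training initialization.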

Scripts

| Script | Feature | Description |
| --- | --- | --- |
| examples/ernie_image/model_inference/Ernie-Image-T2I.py | T2I | Base text-to-image (no PE) |
| examples/ernie_image/model_inference/Ernie-Image-T2I-PE.py | T2I + PE | Text-to-image with prompt enhancement |
| examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py | T2I (low VRAM) | VRAM offload to CPU |
| examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I-PE.py | T2I + PE (low VRAM) | VRAM offload + PE |
| examples/ernie_image/model_training/train.py | Training | Training framework for full + LoRA |
| examples/ernie_image/model_training/full/Ernie-Image-T2I.sh | Full training | DeepSpeed ZeRO-3 config |
| examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh | LoRA training | LoRA training script |
| examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py | Full training validation | Load epoch-1 checkpoint, generate images |
| examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py | LoRA training validation | Load epoch-4 LoRA checkpoint, generate images |

Notes

  • Hardware: Recommended 24GB+ VRAM for bfloat16 inference. Low VRAM configurations available with CPU offload (see examples/ernie_image/model_inference_low_vram/).
  • Scheduler: Uses FlowMatchScheduler with ERNIE-Image template (exponential shift). Scheduler parameters are fixed and not exposed to the user.
  • VAE: Reuses the existing Flux2VAE class — no new VAE code added.
  • Text Encoder & PE: Both components wrap transformers Ministral3 models (not Mistral3). Weights are loaded automatically via transformers; custom converters exist for config mapping.
  • Training: Supports full training (DeepSpeed ZeRO-3) and LoRA training. PE has no training support.
  • VRAM Management: Pipeline supports VRAM offload via vram_config in ModelConfig and vram_limit parameter. DiT is the only model kept on-device during denoising iterations.
  • Tiled VAE: Supports tiled VAE decode via tiled=True parameter for large images.
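The tiled VAE decode noted above can be illustrated with a generic overlap-and-average tiling loop. This is a sketch under the assumption of uniform blending; the real Flux2VAE tiling may use feathered weights and different defaults:

```python
import numpy as np

def tiled_apply(latents, fn, tile_size=64, tile_stride=32):
    # Run `fn` on overlapping latent tiles and average the overlaps, so
    # peak memory scales with the tile size rather than the full image.
    _, _, h, w = latents.shape
    out = np.zeros(latents.shape, dtype=np.float64)
    weight = np.zeros((1, 1, h, w))
    for top in range(0, h, tile_stride):
        for left in range(0, w, tile_stride):
            bottom, right = min(top + tile_size, h), min(left + tile_size, w)
            out[:, :, top:bottom, left:right] += fn(latents[:, :, top:bottom, left:right])
            weight[:, :, top:bottom, left:right] += 1.0
            if right == w:
                break
        if bottom == h:
            break
    return out / weight
```

With fn set to the identity, the overlap averaging reconstructs the input exactly, which is a handy sanity check before plugging in a real decoder.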

@mi804 mi804 requested review from Artiprocher and Copilot April 12, 2026 02:47
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for the ERNIE-Image model family, including text-to-image generation, prompt enhancement, and training capabilities. Key additions include the ErnieImagePipeline, DiT architecture with 3D RoPE, and wrappers for text encoding and prompt enhancement. Feedback focuses on improving the robustness of the pipeline by ensuring tiled VAE operations are correctly parameterized, avoiding hardcoded constants for latent dimensions, and refining prompt enhancement settings to prevent potential memory or performance issues.

```python
self.load_models_to_device(['vae'])
latents = inputs_shared["latents"]
# VAE decode handles BN unnormalization and unpatchify internally (Flux2VAE.decode L2105-2110)
image = self.vae.decode(latents)
```

high

The vae.decode call is missing the tiled parameters. This means that even if the user sets tiled=True in the pipeline call, the VAE decoding will not use tiling, which could lead to out-of-memory (OOM) issues for large images.

Suggested change

```diff
-image = self.vae.decode(latents)
+image = self.vae.decode(latents, tiled=inputs_shared.get("tiled", False), tile_size=inputs_shared.get("tile_size", 64), tile_stride=inputs_shared.get("tile_stride", 32))
```

Comment thread diffsynth/pipelines/ernie_image.py Outdated
```python
output_ids = pipe.pe.generate(
    **inputs,
    max_new_tokens=pipe.pe_tokenizer.model_max_length,
    do_sample=temperature != 1.0 or top_p != 1.0,
```

medium

Using model_max_length as max_new_tokens for prompt enhancement might be excessive and could lead to extremely long generation times or memory issues if the tokenizer's limit is very high (e.g., 32k). It is generally safer to use a fixed reasonable limit for prompt rewriting, such as 1024.

Suggested change

```diff
-max_new_tokens=pipe.pe_tokenizer.model_max_length,
+max_new_tokens=1024,
```

Comment thread diffsynth/pipelines/ernie_image.py Outdated
Comment on lines +227 to +229
```python
height = getattr(pipe, '_pe_height', 1024)
width = getattr(pipe, '_pe_width', 1024)
enhanced = self.enhance_prompt(pipe, prompt, height, width, pe_temperature, pe_top_p)
```

medium

The height and width variables are being overwritten by getattr calls. Since these values are already passed as arguments to the process method via input_params, they should be used directly. The current implementation might ignore the user-specified resolution during prompt enhancement if those attributes are not set on the pipeline object.

Suggested change

```diff
-height = getattr(pipe, '_pe_height', 1024)
-width = getattr(pipe, '_pe_width', 1024)
 enhanced = self.enhance_prompt(pipe, prompt, height, width, pe_temperature, pe_top_p)
```

Comment thread diffsynth/pipelines/ernie_image.py Outdated
Comment on lines +314 to +315
```python
latent_h = height // 16
latent_w = width // 16
```

medium

The latent dimension calculation uses a hardcoded factor of 16. It is better to use pipe.height_division_factor and pipe.width_division_factor to ensure consistency with the pipeline's configuration and avoid potential mismatches if the division factor changes.

Suggested change

```diff
-latent_h = height // 16
-latent_w = width // 16
+latent_h = height // pipe.height_division_factor
+latent_w = width // pipe.width_division_factor
```

```python
# I2I path: VAE encode input image
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae.encode(image)
```

medium

The vae.encode call should also support tiling to handle large input images in the Image-to-Image path, consistent with the decoding logic.

Suggested change

```diff
-input_latents = pipe.vae.encode(image)
+input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
```


Copilot AI left a comment


Pull request overview

Adds a new ERNIE-Image “new_series” integration to DiffSynth-Studio, including a dedicated pipeline, model implementations (DiT/Text Encoder/PE), state-dict converters, scheduler template wiring, VRAM-management mappings, and user-facing docs/examples.

Changes:

  • Introduces ErnieImagePipeline with optional Prompt Enhancement (PE) flow and registers ERNIE-Image with the model loader/config system.
  • Adds ERNIE-Image model implementations + state-dict converters and a FlowMatchScheduler template.
  • Adds documentation and multiple inference/training/validation example scripts (including low-VRAM variants).

Reviewed changes

Copilot reviewed 26 out of 26 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| README.md | Adds ERNIE-Image changelog entry + quick-start/examples section. |
| README_zh.md | Chinese README updates mirroring ERNIE-Image docs/quick-start/examples. |
| docs/en/index.rst | Adds ERNIE-Image page to English docs toctree. |
| docs/zh/index.rst | Adds ERNIE-Image page to Chinese docs toctree. |
| docs/en/Model_Details/ERNIE-Image.md | New English ERNIE-Image model documentation (usage/training args). |
| docs/zh/Model_Details/ERNIE-Image.md | New Chinese ERNIE-Image model documentation (usage/training args). |
| examples/ernie_image/model_inference/Ernie-Image-T2I.py | ERNIE-Image text-to-image inference example. |
| examples/ernie_image/model_inference/Ernie-Image-T2I-PE.py | Inference example enabling PE prompt rewriting. |
| examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py | Low-VRAM inference example using offload config. |
| examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I-PE.py | Low-VRAM inference + PE example. |
| examples/ernie_image/model_training/train.py | Training entrypoint using the shared DiffSynth diffusion training framework. |
| examples/ernie_image/model_training/full/Ernie-Image-T2I.sh | Full training launcher script (DeepSpeed ZeRO-3 config). |
| examples/ernie_image/model_training/full/accelerate_config_zero3.yaml | Accelerate/DeepSpeed ZeRO-3 configuration for full training. |
| examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh | LoRA training launcher script. |
| examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py | Full-checkpoint validation example. |
| examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py | LoRA-checkpoint validation example. |
| diffsynth/pipelines/ernie_image.py | New ERNIE-Image pipeline implementation + PE unit + embedding/noise/VAE units. |
| diffsynth/models/ernie_image_dit.py | New ERNIE-Image DiT implementation (SharedAdaLN + RoPE3D + joint attention). |
| diffsynth/models/ernie_image_text_encoder.py | Text encoder wrapper around Ministral3Model (hidden-states output). |
| diffsynth/models/ernie_image_pe.py | PE wrapper around Ministral3ForCausalLM for prompt rewriting. |
| diffsynth/utils/state_dict_converters/ernie_image_dit.py | DiT state-dict converter stub (pass-through). |
| diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py | Text-encoder state-dict key mapping converter. |
| diffsynth/utils/state_dict_converters/ernie_image_pe.py | PE state-dict converter (key remapping into CausalLM format). |
| diffsynth/diffusion/flow_match.py | Adds ERNIE-Image scheduler template + timestep function. |
| diffsynth/configs/vram_management_module_maps.py | Adds VRAM management module maps for ERNIE-Image DiT/Text Encoder. |
| diffsynth/configs/model_configs.py | Registers ERNIE-Image models into MODEL_CONFIGS series list. |


Comment on lines +141 to +144
```python
else:
    raise ValueError(
        f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'."
    )
```

Copilot AI Apr 12, 2026


ErnieImageAttention.__init__ raises for qk_norm=None even though the error message indicates None is a supported option. This makes qk_layernorm=False (which passes qk_norm=None) crash at model construction; handle qk_norm is None explicitly (e.g., set norm_q/norm_k=None).
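A hedged sketch of the suggested fix, with placeholder tuples standing in for real norm modules (names and return values are illustrative, not the actual DiffSynth-Studio code):

```python
def build_qk_norms(qk_norm, head_dim):
    # Treat qk_norm=None as "no normalization" instead of falling through
    # to the error branch, so qk_layernorm=False constructs successfully.
    if qk_norm is None:
        return None, None
    if qk_norm == "layer_norm":
        return ("layer_norm", head_dim), ("layer_norm", head_dim)
    if qk_norm == "rms_norm":
        return ("rms_norm", head_dim), ("rms_norm", head_dim)
    raise ValueError(
        f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'."
    )
```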

Comment thread diffsynth/pipelines/ernie_image.py Outdated
Comment on lines +225 to +228
```python
# Positive prompt: enhance with PE
pipe.load_models_to_device(self.onload_model_names)
height = getattr(pipe, '_pe_height', 1024)
width = getattr(pipe, '_pe_width', 1024)
```

Copilot AI Apr 12, 2026


ErnieImageUnit_PromptEnhancer.process ignores the height/width values passed in and instead reads pipe._pe_height/_pe_width, which are never set anywhere in this pipeline. This causes PE to always use the fallback 1024x1024 resolution metadata regardless of user input; use the provided height/width arguments (or set these attributes earlier if that was the intent).

Suggested change

```diff
-# Positive prompt: enhance with PE
-pipe.load_models_to_device(self.onload_model_names)
-height = getattr(pipe, '_pe_height', 1024)
-width = getattr(pipe, '_pe_width', 1024)
+# Positive prompt: enhance with PE using the resolution provided to this unit
+pipe.load_models_to_device(self.onload_model_names)
```

Comment thread diffsynth/pipelines/ernie_image.py Outdated
Comment on lines +203 to +210
```python
output_ids = pipe.pe.generate(
    **inputs,
    max_new_tokens=pipe.pe_tokenizer.model_max_length,
    do_sample=temperature != 1.0 or top_p != 1.0,
    temperature=temperature,
    top_p=top_p,
    pad_token_id=pipe.pe_tokenizer.pad_token_id,
    eos_token_id=pipe.pe_tokenizer.eos_token_id,
```

Copilot AI Apr 12, 2026


max_new_tokens is set to pipe.pe_tokenizer.model_max_length, which for this model/config is extremely large and can lead to very long / runaway generations and high compute costs if EOS isn't produced promptly. Consider using a small, task-appropriate cap (and optionally a stop condition) for prompt rewriting.

```python
    ],
)

state_dict = load_state_dict("./models/train/Ernie-Image-T2I_full/epoch-1.safetensors")
```

Copilot AI Apr 12, 2026


load_state_dict defaults to loading tensors on CPU; since the pipeline is created with device="cuda", calling pipe.dit.load_state_dict(state_dict) is likely to fail due to device mismatch. Load the checkpoint with device="cuda" (and matching dtype) or move the model to CPU before loading.

Suggested change

```diff
-state_dict = load_state_dict("./models/train/Ernie-Image-T2I_full/epoch-1.safetensors")
+state_dict = load_state_dict(
+    "./models/train/Ernie-Image-T2I_full/epoch-1.safetensors",
+    device="cuda",
+    torch_dtype=torch.bfloat16,
+)
```

Comment on lines +18 to +26
```python
model_configs=[
    ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
    ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
    ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="pe/model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
pe_tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="pe/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
```

Copilot AI Apr 12, 2026


This is labeled as a low-VRAM example, but the PE model's `ModelConfig(...)` is missing the `**vram_config` settings used for the other components. Without an offload config, the PE model (several billion parameters) is loaded fully onto the compute device, which likely defeats the purpose of low-VRAM mode; apply the same VRAM/offload config to the PE model as well.

Comment thread diffsynth/pipelines/ernie_image.py Outdated
```python
from ..models.flux2_vae import Flux2VAE


# ============================================================
```
Collaborator


Unnecessary annotation. Please delete it.

Comment thread diffsynth/pipelines/ernie_image.py Outdated
```python
self.dit: ErnieImageDiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoTokenizer = None
self.pe: ErnieImagePE = None
```
Collaborator


Prompt enhancement is not a capability of the base model. Please remove the prompt enhancement modules, including the model and the pipeline unit. We can assume that the input prompt has already been enhanced by other tools outside.

```diff
@@ -0,0 +1,3 @@
def ErnieImageDiTStateDictConverter(state_dict):
```
Collaborator


This StateDictConverter doesn't do anything. Please remove it.

@Artiprocher Artiprocher merged commit 960d8c6 into modelscope:main Apr 13, 2026