Conversation
Code Review
This pull request introduces support for the ERNIE-Image model family, including text-to-image generation, prompt enhancement, and training capabilities. Key additions include the ErnieImagePipeline, DiT architecture with 3D RoPE, and wrappers for text encoding and prompt enhancement. Feedback focuses on improving the robustness of the pipeline by ensuring tiled VAE operations are correctly parameterized, avoiding hardcoded constants for latent dimensions, and refining prompt enhancement settings to prevent potential memory or performance issues.
```python
self.load_models_to_device(['vae'])
latents = inputs_shared["latents"]
# VAE decode handles BN unnormalization and unpatchify internally (Flux2VAE.decode L2105-2110)
image = self.vae.decode(latents)
```
The `vae.decode` call is missing the tiled parameters. Even if the user sets `tiled=True` in the pipeline call, VAE decoding will not use tiling, which can lead to out-of-memory (OOM) errors for large images.
Suggested change:
```diff
- image = self.vae.decode(latents)
+ image = self.vae.decode(latents, tiled=inputs_shared.get("tiled", False), tile_size=inputs_shared.get("tile_size", 64), tile_stride=inputs_shared.get("tile_stride", 32))
```
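The same fallback pattern can be factored into a small helper so every tiled VAE call stays consistent. The helper name is an assumption (it is not part of the PR); the defaults mirror the suggested call.

```python
def tiling_kwargs(inputs_shared: dict) -> dict:
    """Collect tiled-VAE options from the shared pipeline inputs, with safe defaults."""
    return {
        "tiled": inputs_shared.get("tiled", False),
        "tile_size": inputs_shared.get("tile_size", 64),
        "tile_stride": inputs_shared.get("tile_stride", 32),
    }

# e.g. image = self.vae.decode(latents, **tiling_kwargs(inputs_shared))
```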
```python
output_ids = pipe.pe.generate(
    **inputs,
    max_new_tokens=pipe.pe_tokenizer.model_max_length,
    do_sample=temperature != 1.0 or top_p != 1.0,
```
Using `model_max_length` as `max_new_tokens` for prompt enhancement may be excessive and could lead to extremely long generation times or memory issues if the tokenizer's limit is very high (e.g., 32k). It is generally safer to use a fixed, reasonable limit for prompt rewriting, such as 1024.
Suggested change:
```diff
- do_sample=temperature != 1.0 or top_p != 1.0,
+ max_new_tokens=1024,
```
```python
height = getattr(pipe, '_pe_height', 1024)
width = getattr(pipe, '_pe_width', 1024)
enhanced = self.enhance_prompt(pipe, prompt, height, width, pe_temperature, pe_top_p)
```
The `height` and `width` variables are being overwritten by `getattr` calls. Since these values are already passed as arguments to the `process` method via `input_params`, they should be used directly. The current implementation might ignore the user-specified resolution during prompt enhancement if those attributes are not set on the pipeline object.
Suggested change:
```diff
- height = getattr(pipe, '_pe_height', 1024)
- width = getattr(pipe, '_pe_width', 1024)
- enhanced = self.enhance_prompt(pipe, prompt, height, width, pe_temperature, pe_top_p)
+ enhanced = self.enhance_prompt(pipe, prompt, height, width, pe_temperature, pe_top_p)
```
```python
latent_h = height // 16
latent_w = width // 16
```
The latent dimension calculation uses a hardcoded factor of 16. It is better to use `pipe.height_division_factor` and `pipe.width_division_factor` to ensure consistency with the pipeline's configuration and avoid potential mismatches if the division factor changes.
Suggested change:
```diff
- latent_h = height // 16
- latent_w = width // 16
+ latent_h = height // pipe.height_division_factor
+ latent_w = width // pipe.width_division_factor
```
```python
# I2I path: VAE encode input image
pipe.load_models_to_device(['vae'])
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
input_latents = pipe.vae.encode(image)
```
Pull request overview
Adds a new ERNIE-Image “new_series” integration to DiffSynth-Studio, including a dedicated pipeline, model implementations (DiT/Text Encoder/PE), state-dict converters, scheduler template wiring, VRAM-management mappings, and user-facing docs/examples.
Changes:
- Introduces `ErnieImagePipeline` with optional Prompt Enhancement (PE) flow and registers ERNIE-Image with the model loader/config system.
- Adds ERNIE-Image model implementations + state-dict converters and a FlowMatchScheduler template.
- Adds documentation and multiple inference/training/validation example scripts (including low-VRAM variants).
Reviewed changes
Copilot reviewed 26 out of 26 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| README.md | Adds ERNIE-Image changelog entry + quick-start/examples section. |
| README_zh.md | Chinese README updates mirroring ERNIE-Image docs/quick-start/examples. |
| docs/en/index.rst | Adds ERNIE-Image page to English docs toctree. |
| docs/zh/index.rst | Adds ERNIE-Image page to Chinese docs toctree. |
| docs/en/Model_Details/ERNIE-Image.md | New English ERNIE-Image model documentation (usage/training args). |
| docs/zh/Model_Details/ERNIE-Image.md | New Chinese ERNIE-Image model documentation (usage/training args). |
| examples/ernie_image/model_inference/Ernie-Image-T2I.py | ERNIE-Image text-to-image inference example. |
| examples/ernie_image/model_inference/Ernie-Image-T2I-PE.py | Inference example enabling PE prompt rewriting. |
| examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py | Low-VRAM inference example using offload config. |
| examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I-PE.py | Low-VRAM inference + PE example. |
| examples/ernie_image/model_training/train.py | Training entrypoint using the shared DiffSynth diffusion training framework. |
| examples/ernie_image/model_training/full/Ernie-Image-T2I.sh | Full training launcher script (DeepSpeed ZeRO-3 config). |
| examples/ernie_image/model_training/full/accelerate_config_zero3.yaml | Accelerate/DeepSpeed ZeRO-3 configuration for full training. |
| examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh | LoRA training launcher script. |
| examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py | Full-checkpoint validation example. |
| examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py | LoRA-checkpoint validation example. |
| diffsynth/pipelines/ernie_image.py | New ERNIE-Image pipeline implementation + PE unit + embedding/noise/VAE units. |
| diffsynth/models/ernie_image_dit.py | New ERNIE-Image DiT implementation (SharedAdaLN + RoPE3D + joint attention). |
| diffsynth/models/ernie_image_text_encoder.py | Text encoder wrapper around Ministral3Model (hidden-states output). |
| diffsynth/models/ernie_image_pe.py | PE wrapper around Ministral3ForCausalLM for prompt rewriting. |
| diffsynth/utils/state_dict_converters/ernie_image_dit.py | DiT state-dict converter stub (pass-through). |
| diffsynth/utils/state_dict_converters/ernie_image_text_encoder.py | Text-encoder state-dict key mapping converter. |
| diffsynth/utils/state_dict_converters/ernie_image_pe.py | PE state-dict converter (key remapping into CausalLM format). |
| diffsynth/diffusion/flow_match.py | Adds ERNIE-Image scheduler template + timestep function. |
| diffsynth/configs/vram_management_module_maps.py | Adds VRAM management module maps for ERNIE-Image DiT/Text Encoder. |
| diffsynth/configs/model_configs.py | Registers ERNIE-Image models into MODEL_CONFIGS series list. |
```python
else:
    raise ValueError(
        f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'."
    )
```
`ErnieImageAttention.__init__` raises for `qk_norm=None` even though the error message indicates `None` is a supported option. This makes `qk_layernorm=False` (which passes `qk_norm=None`) crash at model construction; handle `qk_norm is None` explicitly (e.g., set `norm_q`/`norm_k` to `None`).
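A runnable sketch of the suggested handling, factored into a helper for illustration. The helper name and `eps` default are assumptions; the actual code inlines this logic in `__init__`.

```python
import torch.nn as nn

def build_qk_norms(qk_norm, head_dim, eps=1e-6):
    """Return (norm_q, norm_k); qk_norm=None disables QK normalization instead of raising."""
    if qk_norm is None:
        return None, None
    if qk_norm == "layer_norm":
        return nn.LayerNorm(head_dim, eps=eps), nn.LayerNorm(head_dim, eps=eps)
    if qk_norm == "rms_norm":
        # torch.nn.RMSNorm requires torch >= 2.4; fall back to LayerNorm if unavailable
        rms = getattr(nn, "RMSNorm", nn.LayerNorm)
        return rms(head_dim, eps=eps), rms(head_dim, eps=eps)
    raise ValueError(f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'rms_norm'.")
```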
```python
# Positive prompt: enhance with PE
pipe.load_models_to_device(self.onload_model_names)
height = getattr(pipe, '_pe_height', 1024)
width = getattr(pipe, '_pe_width', 1024)
```
`ErnieImageUnit_PromptEnhancer.process` ignores the `height`/`width` values passed in and instead reads `pipe._pe_height`/`_pe_width`, which are never set anywhere in this pipeline. This causes PE to always use the fallback 1024x1024 resolution metadata regardless of user input; use the provided `height`/`width` arguments (or set these attributes earlier if that was the intent).
Suggested change:
```diff
- # Positive prompt: enhance with PE
- pipe.load_models_to_device(self.onload_model_names)
- height = getattr(pipe, '_pe_height', 1024)
- width = getattr(pipe, '_pe_width', 1024)
+ # Positive prompt: enhance with PE using the resolution provided to this unit
+ pipe.load_models_to_device(self.onload_model_names)
```
```python
output_ids = pipe.pe.generate(
    **inputs,
    max_new_tokens=pipe.pe_tokenizer.model_max_length,
    do_sample=temperature != 1.0 or top_p != 1.0,
    temperature=temperature,
    top_p=top_p,
    pad_token_id=pipe.pe_tokenizer.pad_token_id,
    eos_token_id=pipe.pe_tokenizer.eos_token_id,
```
`max_new_tokens` is set to `pipe.pe_tokenizer.model_max_length`, which for this model/config is extremely large and can lead to very long / runaway generations and high compute costs if EOS isn't produced promptly. Consider using a small, task-appropriate cap (and optionally a stop condition) for prompt rewriting.
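One way to express the cap. The helper name and the 1024 default are illustrative, not part of the PR.

```python
def pe_max_new_tokens(model_max_length: int, cap: int = 1024) -> int:
    """Clamp prompt-enhancement generation length: tokenizers often report a huge
    model_max_length (e.g. 131072), far more than a rewritten prompt needs."""
    return min(model_max_length, cap)
```

Passing `max_new_tokens=pe_max_new_tokens(pipe.pe_tokenizer.model_max_length)` to `generate` then keeps the rewrite bounded even when EOS is never emitted.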
```python
    ],
)

state_dict = load_state_dict("./models/train/Ernie-Image-T2I_full/epoch-1.safetensors")
```
`load_state_dict` defaults to loading tensors on CPU; since the pipeline is created with `device="cuda"`, calling `pipe.dit.load_state_dict(state_dict)` is likely to fail due to device mismatch. Load the checkpoint with `device="cuda"` (and matching dtype) or move the model to CPU before loading.
Suggested change:
```diff
- state_dict = load_state_dict("./models/train/Ernie-Image-T2I_full/epoch-1.safetensors")
+ state_dict = load_state_dict(
+     "./models/train/Ernie-Image-T2I_full/epoch-1.safetensors",
+     device="cuda",
+     torch_dtype=torch.bfloat16,
+ )
```
```python
model_configs=[
    ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
    ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
    ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="pe/model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="tokenizer/"),
pe_tokenizer_config=ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="pe/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
```
This is labeled as a low-VRAM example, but the PE model `ModelConfig(...)` is missing the `**vram_config` settings used for the other components. Without offload config, the PE (multi-B parameter) model will be loaded fully onto the compute device, which likely defeats low-VRAM mode; apply the same VRAM/offload config to the PE model as well.
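A minimal fix would mirror the other entries in this example, assuming the same `vram_config` dict defined earlier in the script applies cleanly to the PE model:

```python
ModelConfig(model_id="baidu/ERNIE-Image", origin_file_pattern="pe/model.safetensors", **vram_config),
```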
```python
from ..models.flux2_vae import Flux2VAE


# ============================================================
```
Unnecessary annotation. Please delete it.
```python
self.dit: ErnieImageDiT = None
self.vae: Flux2VAE = None
self.tokenizer: AutoTokenizer = None
self.pe: ErnieImagePE = None
```
Prompt enhancement is not a capability of the base model. Please remove the prompt-enhancement modules, including the model and the pipeline unit. We can assume that the input prompt has already been enhanced by other tools outside the pipeline.
```diff
@@ -0,0 +1,3 @@
+def ErnieImageDiTStateDictConverter(state_dict):
```
This StateDictConverter doesn't do anything. Please remove it.
Baidu/ERNIE-Image Integration
Summary
Integrate Baidu/ERNIE-Image into DiffSynth-Studio, supporting text-to-image generation, prompt enhancement, and training.
This is a new_series integration. ERNIE-Image uses a custom SharedAdaLN + RoPE 3D + joint image-text attention architecture, distinct from existing DiT models (FLUX, Qwen-Image, Wan, etc.) in DiffSynth-Studio. The DiT is a 36-layer ~4.7B parameter transformer.
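The 3D RoPE idea mentioned above can be sketched as follows: the per-head feature dimension is split across the (t, h, w) axes, and each axis gets its own rotary angle table. This is an illustrative sketch only; the axis split sizes and `theta` below are assumptions, not the model's actual configuration.

```python
import torch

def rope_angles_1d(dim: int, positions: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    """Standard RoPE angle table for one axis: shape (len(positions), dim // 2)."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    return torch.outer(positions.to(torch.float32), inv_freq)

def rope_angles_3d(axes_dims, t_pos, h_pos, w_pos) -> torch.Tensor:
    """3D RoPE: concatenate per-axis angle tables; axes_dims must sum to the head dim."""
    parts = [rope_angles_1d(d, p) for d, p in zip(axes_dims, (t_pos, h_pos, w_pos))]
    return torch.cat(parts, dim=-1)
```

Each row of the resulting table gives the rotation angles applied to one token's query/key features before attention.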
Model Components
- `diffsynth/models/ernie_image_dit.py` — DiT (`ErnieImageDiTStateDictConverter`)
- VAE: reuses `Flux2VAE` (`diffsynth/models/flux2_vae.py`)
- `diffsynth/models/ernie_image_text_encoder.py` — text encoder (`ErnieImageTextEncoderStateDictConverter`)
- `diffsynth/models/ernie_image_pe.py` — PE (`ErnieImagePEStateDictConverter`)

Pipeline:
- `diffsynth/pipelines/ernie_image.py` (386 lines) — `ErnieImagePipeline`

Dependencies Added
- `diffusers>=0.38.0.dev0` (optional, only for target library validation, not for DiffSynth runtime)
- `flash-attn` (optional, CUDA acceleration)
- `Ministral3Model` (text encoder), `Ministral3ForCausalLM` (PE); weights loaded automatically via transformers.

Key Features
Text-to-Image
- FlowMatchScheduler (`ERNIE-Image` template)

Text-to-Image with Prompt Enhancement
- `use_pe=True` flag enables prompt rewriting via the Ministral3 causal LM
- `pe_temperature=0.6`, `pe_top_p=0.95` defaults

Example Usage
Architecture Details
- DiT: `[S, B, H]` layout for transformer blocks, `[B, S, H]` for attention.
- Text encoder: `Ministral3Model` (transformers), 26 layers, hidden_size=3072, yarn RoPE. Outputs `hidden_states[-2]` (second-to-last layer) as text embeddings.
- PE: `Ministral3ForCausalLM` (transformers), same config as the text encoder. Uses chat_template for prompt rewriting; carries prompt text + target resolution in JSON format.
- VAE: `Flux2VAE` — all parameters match (block_out_channels, latent_channels=32, patch_size=(2,2), etc.).

Scripts
- `examples/ernie_image/model_inference/Ernie-Image-T2I.py`
- `examples/ernie_image/model_inference/Ernie-Image-T2I-PE.py`
- `examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I.py`
- `examples/ernie_image/model_inference_low_vram/Ernie-Image-T2I-PE.py`
- `examples/ernie_image/model_training/train.py`
- `examples/ernie_image/model_training/full/Ernie-Image-T2I.sh`
- `examples/ernie_image/model_training/lora/Ernie-Image-T2I.sh`
- `examples/ernie_image/model_training/validate_full/Ernie-Image-T2I.py`
- `examples/ernie_image/model_training/validate_lora/Ernie-Image-T2I.py`

Notes
- Low-VRAM variants are provided (`examples/ernie_image/model_inference_low_vram/`).
- `ERNIE-Image` scheduler template (exponential shift). Scheduler parameters are fixed and not exposed to the user.
- Reuses the `Flux2VAE` class — no new VAE code added.
- VRAM management via `vram_config` in ModelConfig and the `vram_limit` parameter. DiT is the only model kept on-device during denoising iterations.
- Tiled VAE via the `tiled=True` parameter for large images.