chronoedit model/pipeline review #13620

@hlky

Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Duplicate search: checked GitHub Issues/PRs for chronoedit, ChronoEditPipeline, ChronoEditTransformer3DModel, pipeline_chronoedit, transformer_chronoedit, temporal reasoning + model_outputs, num_videos_per_prompt + image_embeds, and slow-test coverage. Existing related items found: #12661, #12660, #12679, #13347. None duplicate Issues 1-6 below; PR #13347 is related to Issue 7's missing model-test coverage.

Issue 1: Temporal reasoning crashes with the default scheduler

Affected code:

if enable_temporal_reasoning and i == num_temporal_reasoning_steps:
    latents = latents[:, :, [0, -1]]
    condition = condition[:, :, [0, -1]]
    for j in range(len(self.scheduler.model_outputs)):
        if self.scheduler.model_outputs[j] is not None:
            if latents.shape[-3] != self.scheduler.model_outputs[j].shape[-3]:
                self.scheduler.model_outputs[j] = self.scheduler.model_outputs[j][:, :, [0, -1]]
    if self.scheduler.last_sample is not None:
        self.scheduler.last_sample = self.scheduler.last_sample[:, :, [0, -1]]

Problem:
enable_temporal_reasoning=True unconditionally accesses self.scheduler.model_outputs and self.scheduler.last_sample, but the pipeline imports, type-annotates, and tests against FlowMatchEulerDiscreteScheduler, which does not define those UniPC-style fields.

Impact:
The documented temporal reasoning path can fail before completing the first denoising step when used with the scheduler family the pipeline itself advertises.

Reproduction:

from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
print(hasattr(scheduler, "model_outputs"))  # False

# Same field access performed by ChronoEditPipeline when temporal reasoning truncates latents:
len(scheduler.model_outputs)  # raises AttributeError

Relevant precedent (UniPCMultistepScheduler.__init__, where these fields are defined):

# setable values
self.num_inference_steps = None
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
self.model_outputs = [None] * solver_order
self.timestep_list = [None] * solver_order
self.lower_order_nums = 0
self.disable_corrector = disable_corrector
self.solver_p = solver_p
self.last_sample = None

Suggested fix:

if hasattr(self.scheduler, "model_outputs"):
    for j, model_output in enumerate(self.scheduler.model_outputs):
        if model_output is not None and latents.shape[-3] != model_output.shape[-3]:
            self.scheduler.model_outputs[j] = model_output[:, :, [0, -1]]
if getattr(self.scheduler, "last_sample", None) is not None:
    self.scheduler.last_sample = self.scheduler.last_sample[:, :, [0, -1]]
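
The hasattr/getattr guards keep the truncation a no-op for schedulers such as FlowMatchEulerDiscreteScheduler that carry no multistep state, while still shrinking model_outputs and last_sample when a UniPC-style scheduler is installed.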

Issue 2: num_videos_per_prompt > 1 leaves image embeddings under-batched

Affected code:

if image_embeds is None:
    image_embeds = self.encode_image(image, device)
image_embeds = image_embeds.repeat(batch_size, 1, 1)
image_embeds = image_embeds.to(transformer_dtype)

Problem:
Prompt embeddings are expanded to batch_size * num_videos_per_prompt, but image embeddings are repeated only by batch_size. With num_videos_per_prompt=2, the transformer receives prompt batch 2 and image batch 1.

Impact:
Multi-video generation per prompt crashes or conditions samples with mismatched image context.

Reproduction:

import torch
from diffusers import ChronoEditTransformer3DModel

m = ChronoEditTransformer3DModel(
    patch_size=(1, 2, 2), num_attention_heads=2, attention_head_dim=12,
    in_channels=36, out_channels=16, text_dim=32, ffn_dim=32,
    num_layers=1, image_dim=4, rope_max_seq_len=32,
)
hidden = torch.randn(2, 36, 1, 2, 2)
text = torch.randn(2, 16, 32)
image = torch.randn(1, 257, 4)  # current pipeline batch after repeat(batch_size=1)
m(hidden, torch.tensor([1, 1]), text, image)  # image batch (1) mismatches hidden/text batch (2)

Relevant precedent:

# Get CLIP features from the reference image
if image_embeds is None:
    image_embeds = self.encode_image(image, device)
image_embeds = image_embeds.repeat(batch_size * num_videos_per_prompt, 1, 1)
image_embeds = image_embeds.to(transformer_dtype)

Suggested fix:

if image_embeds.shape[0] == 1:
    image_embeds = image_embeds.repeat(batch_size * num_videos_per_prompt, 1, 1)
else:
    image_embeds = image_embeds.repeat_interleave(num_videos_per_prompt, dim=0)
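
The branch matters because repeat tiles the whole batch while repeat_interleave keeps each prompt's copies adjacent, which matches how the prompt embeddings are expanded. A standalone illustration with toy tensors (not pipeline state):

import torch

x = torch.tensor([[0.0], [1.0]])      # two per-prompt image embeddings
print(x.repeat(2, 1))                 # [[0.], [1.], [0.], [1.]] - whole batch tiled
print(x.repeat_interleave(2, dim=0))  # [[0.], [0.], [1.], [1.]] - copies grouped per prompt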

Issue 3: image_embeds is accepted without image, but image is still required

Affected code:

if image is not None and image_embeds is not None:
    raise ValueError(
        f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
        " only forward one of the two."
    )
if image is None and image_embeds is None:
    raise ValueError(
        "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
    )
if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
    raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")

# 5. Prepare latent variables
num_channels_latents = self.vae.config.z_dim
image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
latents, condition = self.prepare_latents(
    image,
    batch_size * num_videos_per_prompt,
    num_channels_latents,
    height,
    width,
    num_frames,
    torch.float32,
    device,
    generator,
    latents,
)

Problem:
check_inputs allows image=None when image_embeds is provided, but __call__ always preprocesses image and uses it to build VAE conditioning latents.

Impact:
Users trying to skip only the CLIP image encoder get a later, less actionable processor error. The API contract is misleading because image_embeds alone is insufficient.

Reproduction:

import torch
from diffusers import ChronoEditPipeline

pipe = object.__new__(ChronoEditPipeline)
pipe._callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

ChronoEditPipeline.check_inputs(
    pipe, prompt="x", negative_prompt=None, image=None,
    image_embeds=torch.zeros(1, 257, 1280), height=16, width=16,
)
print("validation accepted image_embeds without image")

Relevant precedent:
image_embeds may skip CLIP encoding, but VAE conditioning still needs pixels.

Suggested fix:

if image is None:
    raise ValueError("`image` is required for VAE conditioning; `image_embeds` only skips CLIP image encoding.")
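
Since image must then accompany image_embeds, the existing mutual-exclusion branch would also have to go. A minimal sketch of the revised validation, assuming the rest of check_inputs stays unchanged:

# Sketch: image is always required; image_embeds merely bypasses the CLIP image encoder.
if image is None:
    raise ValueError("`image` is required for VAE conditioning; `image_embeds` only skips CLIP image encoding.")
if not isinstance(image, (torch.Tensor, PIL.Image.Image)):
    raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")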

Issue 4: num_frames is silently ignored unless temporal reasoning is enabled

Affected code:

num_frames = 5 if not enable_temporal_reasoning else num_frames
if num_frames % self.vae_scale_factor_temporal != 1:
    logger.warning(
        f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
    )
    num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)

Problem:
When enable_temporal_reasoning=False, num_frames is overwritten to 5 before validation and latent preparation.

Impact:
The default num_frames=81 and user-provided values do not describe the actual output shape in the default path.

Reproduction:

num_frames = 9
enable_temporal_reasoning = False
num_frames = 5 if not enable_temporal_reasoning else num_frames
print(num_frames)  # 5, not 9

Relevant precedent:
The num_frames docstring says it controls generated video length:

num_frames (`int`, defaults to `81`):
    The number of frames in the generated video.

Suggested fix:
Either honor num_frames in both modes, or remove it from the non-temporal API path and raise if users pass a non-5 value while temporal reasoning is disabled.
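
A minimal sketch of the second option (guard placement and message are illustrative, not existing code):

# Hypothetical guard in __call__, replacing the silent overwrite:
if not enable_temporal_reasoning and num_frames != 5:
    raise ValueError(
        f"`num_frames={num_frames}` is only honored when `enable_temporal_reasoning=True`; "
        "the default path always generates 5 frames."
    )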

Issue 5: Optional ftfy guard is incomplete

Affected code:

if is_ftfy_available():
    import ftfy


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()

Problem:
ftfy is imported only when available, but basic_clean() always calls ftfy.fix_text.

Impact:
A normal diffusers install with torch/transformers but without ftfy (shipped only in the test extras) can import the pipeline and then fail during prompt encoding with a NameError.

Reproduction:

import diffusers.pipelines.chronoedit.pipeline_chronoedit as chrono

old_ftfy = getattr(chrono, "ftfy", None)
if hasattr(chrono, "ftfy"):
    del chrono.ftfy
try:
    chrono.basic_clean("hello & world")  # NameError: name 'ftfy' is not defined
finally:
    if old_ftfy is not None:
        chrono.ftfy = old_ftfy

Relevant precedent:

def basic_clean(text):
    if is_ftfy_available():
        text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()

Suggested fix:

def basic_clean(text):
    if is_ftfy_available():
        text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()

Issue 6: RoPE precompute still requests float64

Affected code:

h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_cos = []
freqs_sin = []
for dim in [t_dim, h_dim, w_dim]:
    freq_cos, freq_sin = get_1d_rotary_pos_embed(
        dim,
        max_seq_len,
        theta,
        use_real=True,
        repeat_interleave_real=True,
        freqs_dtype=freqs_dtype,
    )

Problem:
ChronoEditRotaryPosEmbed selects torch.float64 on non-MPS systems. The helper later casts real cos/sin buffers to float32, so the float64 work is unnecessary and violates the review rule against unconditional float64 in models.

Impact:
This adds avoidable construction-time dtype work and keeps an NPU/MPS portability footgun in new model code.

Reproduction:

import inspect
from diffusers.models.transformers import transformer_chronoedit

src = inspect.getsource(transformer_chronoedit.ChronoEditRotaryPosEmbed.__init__)
print("torch.float64" in src)

Relevant precedent:

if use_real and repeat_interleave_real:
    # flux, hunyuan-dit, cogvideox
    freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
    freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
    return freqs_cos, freqs_sin
elif use_real:
    # stable audio, allegro
    freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
    freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
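
A quick empirical check that the helper returns float32 buffers either way, so the float64 request buys nothing:

import torch
from diffusers.models.embeddings import get_1d_rotary_pos_embed

# Same call twice, differing only in freqs_dtype; both outputs come back float32.
cos64, _ = get_1d_rotary_pos_embed(8, 16, use_real=True, repeat_interleave_real=True, freqs_dtype=torch.float64)
cos32, _ = get_1d_rotary_pos_embed(8, 16, use_real=True, repeat_interleave_real=True, freqs_dtype=torch.float32)
print(cos64.dtype, cos32.dtype)  # torch.float32 torch.float32
torch.testing.assert_close(cos64, cos32)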

Suggested fix:

freqs_dtype = torch.float32

Issue 7: Slow tests and dedicated model tests are missing

Affected code:

class ChronoEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = ChronoEditPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False
    supports_dduf = False

    @unittest.skip("Test not supported")
    def test_attention_slicing_forward_pass(self):
        pass

    @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
    def test_inference_batch_single_identical(self):
        pass

    @unittest.skip(
        "ChronoEditPipeline has to run in mixed precision. Save/Load the entire pipeline in FP16 will result in errors"
    )
    def test_save_load_float16(self):
        ...

Problem:
At the tested commit, ChronoEdit has a single fast pipeline test file, no @slow integration test, and no checked-in tests/models/transformers/test_models_transformer_chronoedit.py. Several common pipeline tests are skipped, including batch-identical inference and fp16 save/load.

Impact:
The real checkpoint path, temporal reasoning, LoRA examples, model save/load behavior, attention processors, gradient checkpointing, and batch/image edge cases above are not covered.

Reproduction:

from pathlib import Path

files = [str(p) for p in Path("tests").rglob("*chronoedit*.py")]
print(files)
print(any("@slow" in Path(p).read_text(encoding="utf-8") for p in files))
print(Path("tests/models/transformers/test_models_transformer_chronoedit.py").exists())

Relevant precedent:

@slow
@require_torch_accelerator
class WanPipelineIntegrationTests(unittest.TestCase):
    ...


class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin):
    """Core model tests for Wan Transformer 3D."""

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
        # Skip: fp16/bf16 require very high atol to pass, providing little signal.
        # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
        pytest.skip("Tolerance requirements too high for meaningful test")


class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for Wan Transformer 3D."""


class TestWanTransformer3DTraining(WanTransformer3DTesterConfig, TrainingTesterMixin):
    """Training tests for Wan Transformer 3D."""

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"WanTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class TestWanTransformer3DAttention(WanTransformer3DTesterConfig, AttentionTesterMixin):
    """Attention processor tests for Wan Transformer 3D."""


class TestWanTransformer3DCompile(WanTransformer3DTesterConfig, TorchCompileTesterMixin):
    ...

Suggested fix:
Add a ChronoEdit transformer model test file following Wan's model mixins, and add at least one @slow pipeline integration test for nvidia/ChronoEdit-14B-Diffusers, including temporal reasoning and the documented LoRA/scheduler path.
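
A minimal skeleton for the missing slow test, modeled on the Wan precedent above (the class name, prompt, placeholder image, and output attribute are illustrative assumptions, not existing code):

import unittest

import torch
from PIL import Image

from diffusers import ChronoEditPipeline
from diffusers.utils.testing_utils import require_torch_accelerator, slow


@slow
@require_torch_accelerator
class ChronoEditPipelineIntegrationTests(unittest.TestCase):
    def test_temporal_reasoning(self):
        # Exercise the real checkpoint and the documented temporal reasoning path.
        pipe = ChronoEditPipeline.from_pretrained(
            "nvidia/ChronoEdit-14B-Diffusers", torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload()
        image = Image.new("RGB", (832, 480))  # placeholder conditioning image
        frames = pipe(
            image=image,
            prompt="a test edit",
            height=480,
            width=832,
            enable_temporal_reasoning=True,
            num_temporal_reasoning_steps=5,
            num_inference_steps=8,
            generator=torch.Generator().manual_seed(0),
        ).frames[0]  # output attribute assumed to mirror Wan's .frames
        self.assertGreater(len(frames), 0)

A companion tests/models/transformers/test_models_transformer_chronoedit.py would follow the Wan model-test mixins shown above (core, memory, training, attention, compile).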
