mochi model/pipeline review #13615

@hlky

Description

Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules. I read the available referenced rule files; AGENTS.md is referenced by .ai/review-rules.md but is not present in this checkout.

Duplicate search status: searched GitHub Issues and PRs for mochi, pipeline_mochi.py, AutoencoderKLMochi, MochiPipelineOutput, linear_quadratic_schedule, force_zeros_for_empty_prompt, MochiAttention, and the failure modes below. I did not find likely duplicates. Related but not duplicates: closed issue #11291 covers a different Mochi VAE tiling-width bug, and open PR #13348 refactors Mochi transformer tests only.

Issue 1: Custom timesteps are unusable and one-step schedules crash

Affected code:

def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
    if linear_steps is None:
        linear_steps = num_steps // 2
    linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
    quadratic_steps = num_steps - linear_steps
    quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
    linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
    const = quadratic_coef * (linear_steps**2)
    quadratic_sigma_schedule = [
        quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
    ]
    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
    sigma_schedule = [1.0 - x for x in sigma_schedule]
    return sigma_schedule

# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
threshold_noise = 0.025
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
sigmas = np.array(sigmas)
if XLA_AVAILABLE:
    timestep_device = "cpu"
else:
    timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler,
    num_inference_steps,
    timestep_device,
    timesteps,
    sigmas,

Problem:
__call__ always builds Mochi’s internal sigmas schedule and passes it to retrieve_timesteps, even when the user also passes timesteps. retrieve_timesteps rejects receiving both. Separately, linear_quadratic_schedule(num_steps=1, ...) divides by zero.

Impact:
The public timesteps argument documented on MochiPipeline.__call__ cannot be used. num_inference_steps=1 also crashes before denoising.

Reproduction:

from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.mochi.pipeline_mochi import linear_quadratic_schedule, retrieve_timesteps

for name, fn in {
    "one step": lambda: linear_quadratic_schedule(1, 0.025),
    "custom timesteps": lambda: retrieve_timesteps(
        FlowMatchEulerDiscreteScheduler(),
        num_inference_steps=2,
        device="cpu",
        timesteps=[999, 500],
        sigmas=linear_quadratic_schedule(2, 0.025),
    ),
}.items():
    try:
        fn()
    except Exception as e:
        print(name, type(e).__name__, e)

Relevant precedent:
retrieve_timesteps already supports either custom timesteps or sigmas; the pipeline should not pass both.

Suggested fix:

sigmas = None
if timesteps is None:
    sigmas = np.array(linear_quadratic_schedule(num_inference_steps, threshold_noise), dtype=np.float32)

timesteps, num_inference_steps = retrieve_timesteps(
    self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
)

Also either support num_inference_steps=1 in linear_quadratic_schedule or reject it with a clear check_inputs error.
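If the rejection route is taken, a minimal sketch of the guard (the helper name and the error wording are mine, not existing code):

def _check_num_inference_steps(num_inference_steps: int) -> None:
    # linear_quadratic_schedule splits the steps into a linear and a quadratic part,
    # so anything below 2 ends up dividing by zero further down the call.
    if num_inference_steps < 2:
        raise ValueError(
            f"`num_inference_steps` must be at least 2 for the Mochi sigma schedule, got {num_inference_steps}."
        )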

Issue 2: Spatial and temporal input validation accepts unsupported requests

Affected code:

# TODO: determine these scaling factors from model parameters
self.vae_spatial_scale_factor = 8
self.vae_temporal_scale_factor = 6
self.patch_size = 2

if height % 8 != 0 or width % 8 != 0:
    raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

height = height // self.vae_spatial_scale_factor
width = width // self.vae_spatial_scale_factor
num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
shape = (batch_size, num_channels_latents, num_frames, height, width)

Problem:
The pipeline only checks height/width divisibility by 8, but the transformer also patches the VAE latents by patch_size=2, so image dimensions need to be divisible by 16 for the default model. num_frames is also accepted even when (num_frames - 1) is not divisible by the VAE temporal scale; those requests silently decode fewer frames than requested.

Impact:
Users get late transformer reshape errors for sizes like 24x24, or silently receive fewer frames than requested for values like num_frames=8.

Reproduction:

import torch
from diffusers import AutoencoderKLMochi, MochiPipeline, MochiTransformer3DModel

pipe = MochiPipeline(scheduler=None, vae=None, text_encoder=None, tokenizer=None, transformer=None)
pipe.check_inputs(prompt="x", height=24, width=24)
print("24x24 accepted")

model = MochiTransformer3DModel(
    patch_size=2, num_attention_heads=2, attention_head_dim=8, num_layers=1,
    pooled_projection_dim=16, in_channels=4, text_embed_dim=16, time_embed_dim=4, max_sequence_length=8,
)
try:
    model(torch.randn(1, 4, 1, 3, 3), torch.randn(1, 8, 16), torch.tensor([1.0]), torch.ones(1, 8).bool())
except RuntimeError as e:
    print(type(e).__name__, e)

vae = AutoencoderKLMochi(
    latent_channels=4, encoder_block_out_channels=(32, 32, 32, 32),
    decoder_block_out_channels=(32, 32, 32, 32), layers_per_block=(1, 1, 1, 1, 1),
)
pipe = MochiPipeline(scheduler=None, vae=vae, text_encoder=None, tokenizer=None, transformer=None)
latents = pipe.prepare_latents(1, 4, 16, 16, 8, torch.float32, "cpu", torch.Generator().manual_seed(0))
print("requested 8, decoded", vae.decode(latents, return_dict=False)[0].shape[2])

Relevant precedent:
Wan derives and enforces a 16 spatial factor and rounds invalid frame counts:

if height % 16 != 0 or width % 16 != 0:
    raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

if num_frames % self.vae_scale_factor_temporal != 1:
    logger.warning(
        f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
    )
    num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)

Suggested fix:

self.vae_spatial_scale_factor = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
self.vae_temporal_scale_factor = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 6
self.patch_size = self.transformer.config.patch_size if getattr(self, "transformer", None) else 2

spatial_factor = self.vae_spatial_scale_factor * self.patch_size
if height % spatial_factor != 0 or width % spatial_factor != 0:
    raise ValueError(f"`height` and `width` have to be divisible by {spatial_factor}.")

if num_frames % self.vae_temporal_scale_factor != 1:
    logger.warning(...)
    num_frames = num_frames // self.vae_temporal_scale_factor * self.vae_temporal_scale_factor + 1
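Put together as a standalone helper (illustrative only; the helper name, warning text, and default factors are assumptions based on the hard-coded values above):

import logging

logger = logging.getLogger(__name__)

def validate_mochi_resolution(height, width, num_frames, vae_spatial_scale_factor=8, vae_temporal_scale_factor=6, patch_size=2):
    # The transformer patches the VAE latents by `patch_size`, so pixel dimensions
    # must be divisible by the combined factor (16 for the default Mochi config).
    spatial_factor = vae_spatial_scale_factor * patch_size
    if height % spatial_factor != 0 or width % spatial_factor != 0:
        raise ValueError(f"`height` and `width` have to be divisible by {spatial_factor} but are {height} and {width}.")
    # Mirror Wan: warn and round invalid frame counts to the nearest valid value.
    if num_frames % vae_temporal_scale_factor != 1:
        logger.warning(
            "`num_frames - 1` has to be divisible by %d. Rounding to the nearest valid number.",
            vae_temporal_scale_factor,
        )
        num_frames = num_frames // vae_temporal_scale_factor * vae_temporal_scale_factor + 1
    return max(num_frames, 1)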

Issue 3: Empty-prompt zeroing is applied to the wrong batch rows

Affected code:

# The original Mochi implementation zeros out empty negative prompts
# but this can lead to overflow when placing the entire pipeline under the autocast context
# adding this here so that we can enable zeroing prompts if necessary
if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
    text_input_ids = torch.zeros_like(text_input_ids, device=device)
    prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)

Problem:
_get_t5_prompt_embeds checks prompt == "" or prompt[-1] == "". For batched prompts, this only looks at the last prompt and then zeroes the entire batch. If the first prompt is empty and the last is not, the empty row is not zeroed.

Impact:
With force_zeros_for_empty_prompt=True, batched negative prompts can be zeroed for the wrong rows or not zeroed at all, so individual samples end up conditioned on the wrong embeddings.

Reproduction:

import torch
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
from diffusers import MochiPipeline

config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
pipe = MochiPipeline(
    scheduler=None, vae=None, transformer=None,
    text_encoder=T5EncoderModel(config),
    tokenizer=AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5"),
    force_zeros_for_empty_prompt=True,
)

_, mask = pipe._get_t5_prompt_embeds(["not empty", ""], max_sequence_length=8, device=torch.device("cpu"))
print(mask.sum(dim=1).tolist())  # [0, 0], but only row 2 should be zeroed

_, mask = pipe._get_t5_prompt_embeds(["", "not empty"], max_sequence_length=8, device=torch.device("cpu"))
print(mask.sum(dim=1).tolist())  # first row is not zeroed

Relevant precedent:
The current behavior was made configurable in PR #10284, but I found no duplicate for the batched-row bug.

Suggested fix:

if self.config.force_zeros_for_empty_prompt:
    empty_prompt_mask = torch.tensor([p == "" for p in prompt], device=device)
    text_input_ids = text_input_ids.to(device)
    text_input_ids[empty_prompt_mask] = 0
    prompt_attention_mask[empty_prompt_mask] = False
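A quick standalone check of that per-row masking on toy tensors (independent of the T5 encoder, just to show which rows the boolean index touches):

import torch

prompt = ["", "not empty"]
text_input_ids = torch.ones(2, 8, dtype=torch.long)
prompt_attention_mask = torch.ones(2, 8, dtype=torch.bool)

empty_prompt_mask = torch.tensor([p == "" for p in prompt])
text_input_ids[empty_prompt_mask] = 0
prompt_attention_mask[empty_prompt_mask] = False

print(prompt_attention_mask.sum(dim=1).tolist())  # [0, 8]: only the empty first row is zeroed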

Issue 4: AutoencoderKLMochi.forward(return_dict=False) returns a DecoderOutput

Affected code:

def forward(
    self,
    sample: torch.Tensor,
    sample_posterior: bool = False,
    return_dict: bool = True,
    generator: torch.Generator | None = None,
) -> torch.Tensor | torch.Tensor:
    x = sample
    posterior = self.encode(x).latent_dist
    if sample_posterior:
        z = posterior.sample(generator=generator)
    else:
        z = posterior.mode()
    dec = self.decode(z)
    if not return_dict:
        return (dec,)
    return dec

Problem:
forward calls self.decode(z) without taking .sample. With return_dict=False, it returns (DecoderOutput(...),) instead of (tensor,).

Impact:
Generic model callers and tests expecting tuple outputs get the wrong type.

Reproduction:

import torch
from diffusers import AutoencoderKLMochi

vae = AutoencoderKLMochi(
    latent_channels=4, encoder_block_out_channels=(32, 32, 32, 32),
    decoder_block_out_channels=(32, 32, 32, 32), layers_per_block=(1, 1, 1, 1, 1),
)
out = vae(torch.randn(1, 3, 7, 16, 16), return_dict=False)[0]
print(type(out), hasattr(out, "sample"))

Relevant precedent:

x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
    z = posterior.sample(generator=generator)
else:
    z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
    return (dec,)
return DecoderOutput(sample=dec)

Suggested fix:

dec = self.decode(z).sample
if not return_dict:
    return (dec,)
return DecoderOutput(sample=dec)

Issue 5: Mochi attention ignores set_attention_backend

Affected code:

self.attn1 = MochiAttention(
    query_dim=dim,
    heads=num_attention_heads,
    dim_head=attention_head_dim,
    bias=False,
    added_kv_proj_dim=pooled_projection_dim,
    added_proj_bias=False,
    out_dim=dim,
    out_context_dim=pooled_projection_dim,
    context_pre_only=context_pre_only,
    processor=MochiAttnProcessor2_0(),
    eps=1e-5,

class MochiAttnProcessor2_0:
    """Attention processor used in Mochi."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "MochiAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))

        if attn.norm_added_q is not None:
            encoder_query = attn.norm_added_q(encoder_query)
        if attn.norm_added_k is not None:
            encoder_key = attn.norm_added_k(encoder_key)

        if image_rotary_emb is not None:

            def apply_rotary_emb(x, freqs_cos, freqs_sin):
                x_even = x[..., 0::2].float()
                x_odd = x[..., 1::2].float()

                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)

                return torch.stack([cos, sin], dim=-1).flatten(-2)

            query = apply_rotary_emb(query, *image_rotary_emb)
            key = apply_rotary_emb(key, *image_rotary_emb)

        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
        encoder_query, encoder_key, encoder_value = (
            encoder_query.transpose(1, 2),
            encoder_key.transpose(1, 2),
            encoder_value.transpose(1, 2),
        )

        sequence_length = query.size(2)
        encoder_sequence_length = encoder_query.size(2)
        total_length = sequence_length + encoder_sequence_length

        batch_size, heads, _, dim = query.shape
        attn_outputs = []
        for idx in range(batch_size):
            mask = attention_mask[idx][None, :]
            valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()

            valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
            valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
            valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]

            valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
            valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
            valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)

            attn_output = F.scaled_dot_product_attention(
                valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
            )

Problem:
MochiAttnProcessor2_0 has no _attention_backend attribute and calls F.scaled_dot_product_attention directly. ModelMixin.set_attention_backend() therefore skips the processor and the forward path ignores the selected backend.

Impact:
Users can call model.set_attention_backend(...) on the Mochi transformer and see no change in attention behavior. This also leaves Mochi outside the current dispatcher/context-parallel attention conventions.

Reproduction:

from diffusers import MochiTransformer3DModel

model = MochiTransformer3DModel(
    patch_size=2, num_attention_heads=2, attention_head_dim=8, num_layers=1,
    pooled_projection_dim=16, in_channels=4, text_embed_dim=16, time_embed_dim=4, max_sequence_length=8,
)
processor = model.transformer_blocks[0].attn1.processor
print(hasattr(processor, "_attention_backend"), getattr(processor, "_attention_backend", None))
model.set_attention_backend("_native_math")
processor = model.transformer_blocks[0].attn1.processor
print(hasattr(processor, "_attention_backend"), getattr(processor, "_attention_backend", None))

Relevant precedent:
QwenImage routes attention through dispatch_attention_fn:

joint_hidden_states = dispatch_attention_fn(
    joint_query,
    joint_key,
    joint_value,

Suggested fix:
This is larger than an inline patch: move Mochi attention to the current AttentionModuleMixin pattern, give the processor _attention_backend / _parallel_config, and route the joint valid-token attention through dispatch_attention_fn.
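For orientation only, a rough sketch of the call that would replace the direct F.scaled_dot_product_attention inside the per-sample loop, assuming dispatch_attention_fn from diffusers.models.attention_dispatch keeps its QwenImage-style keyword arguments (the AttentionModuleMixin refactor and the tensor-layout changes it implies are not shown):

# Hypothetical excerpt: `self._attention_backend` / `self._parallel_config` would be
# set on the processor by ModelMixin.set_attention_backend() and the parallel plumbing.
attn_output = dispatch_attention_fn(
    valid_query,
    valid_key,
    valid_value,
    dropout_p=0.0,
    is_causal=False,
    backend=self._attention_backend,
)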

Issue 6: MochiPipelineOutput is not exported from the Mochi package

Affected code:

_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_mochi"] = ["MochiPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_mochi import MochiPipeline

@dataclass
class MochiPipelineOutput(BaseOutput):
    r"""
    Output class for Mochi pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or list[list[PIL.Image.Image]]):
            list of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    frames: torch.Tensor

Problem:
pipeline_output.py defines MochiPipelineOutput, but src/diffusers/pipelines/mochi/__init__.py only exports MochiPipeline.

Impact:
from diffusers.pipelines.mochi import MochiPipelineOutput fails, unlike many newer pipeline families that expose their output dataclasses through the package lazy loader.

Reproduction:

from diffusers.pipelines.mochi import MochiPipelineOutput

Relevant precedent:

_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}

Suggested fix:

_import_structure = {"pipeline_output": ["MochiPipelineOutput"]}

# TYPE_CHECKING / slow import branch
from .pipeline_output import MochiPipelineOutput
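With that export in place, this sanity check should pass (it currently raises an ImportError):

from diffusers.pipelines.mochi import MochiPipeline, MochiPipelineOutput

print(MochiPipelineOutput.__dataclass_fields__.keys())  # dict_keys(['frames'])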

Issue 7: Slow coverage is missing and existing expected-output checks are placeholders

Affected code:

def test_inference(self):
    device = "cpu"

    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    inputs = self.get_dummy_inputs(device)
    video = pipe(**inputs).frames
    generated_video = video[0]

    self.assertEqual(generated_video.shape, (7, 3, 16, 16))
    expected_video = torch.randn(7, 3, 16, 16)
    max_diff = np.abs(generated_video - expected_video).max()
    self.assertLessEqual(max_diff, 1e10)


@nightly
@require_torch_accelerator
@require_big_accelerator
class MochiPipelineIntegrationTests(unittest.TestCase):
    prompt = "A painting of a squirrel eating a burger."

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_mochi(self):
        generator = torch.Generator("cpu").manual_seed(0)

        pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float16)
        pipe.enable_model_cpu_offload(device=torch_device)
        prompt = self.prompt

        videos = pipe(
            prompt=prompt,
            height=480,
            width=848,
            num_frames=19,
            generator=generator,
            num_inference_steps=2,
            output_type="pt",
        ).frames

        video = videos[0]
        expected_video = torch.randn(1, 19, 480, 848, 3).numpy()

        max_diff = numpy_cosine_similarity_distance(video.cpu(), expected_video)
        assert max_diff < 1e-3, f"Max diff is too high. got {video}"

Problem:
There are fast tests and a @nightly pipeline integration test, but no @slow Mochi tests in the pipeline/model/LoRA Mochi test files. The fast test compares to torch.randn(...) with max_diff <= 1e10, and the nightly test compares the full model output to a random tensor instead of a fixed expected slice.

Impact:
Shape-level smoke tests exist, but numerical regressions in Mochi pipeline/model behavior are unlikely to be caught by normal slow test runs.

Reproduction:

from pathlib import Path

paths = [
    Path("tests/pipelines/mochi/test_mochi.py"),
    Path("tests/models/autoencoders/test_models_autoencoder_mochi.py"),
    Path("tests/models/transformers/test_models_transformer_mochi.py"),
    Path("tests/lora/test_lora_layers_mochi.py"),
]
assert any("@slow" in path.read_text() for path in paths), "No @slow Mochi tests found"

Relevant precedent:
Other mature pipeline tests use fixed expected slices rather than random placeholders for integration assertions.

Suggested fix:
Add at least one @slow Mochi pipeline test using genmo/mochi-1-preview with a small deterministic run and fixed expected output slice, and replace the random placeholder assertions in the fast/nightly tests with meaningful deterministic assertions.
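A hedged sketch of what such a test could look like (the expected slice is a placeholder to be recorded once from a trusted run of this exact configuration; resolution and step count can be tuned for runtime):

import unittest

import numpy as np
import torch

from diffusers import MochiPipeline
from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, slow, torch_device


@slow
class MochiPipelineSlowTests(unittest.TestCase):
    def test_mochi_deterministic_slice(self):
        pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
        pipe.enable_model_cpu_offload(device=torch_device)

        video = pipe(
            prompt="A painting of a squirrel eating a burger.",
            height=480,
            width=848,
            num_frames=19,
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="pt",
        ).frames[0]

        # Placeholder: record these eight values from a known-good run, then assert
        # against them instead of the random tensor used by the current nightly test.
        expected_slice = np.zeros(8, dtype=np.float32)
        actual_slice = video.flatten()[:8].float().cpu().numpy()
        assert numpy_cosine_similarity_distance(actual_slice, expected_slice) < 1e-3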
