
latte model/pipeline review #13595

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Duplicate search status: searched existing huggingface/diffusers issues and PRs for latte, LatteTransformer3DModel, pipeline_latte, prompt_embeds, encoder_attention_mask, latents dtype, temp_pos_embed, output_type, and docs/test coverage terms. Related but not duplicate: #11137/#11139 fixed a different Latte dtype mismatch in temp_pos_embed; #10558 fixed output_type="latent" handling.

Test coverage status: fast model and pipeline tests exist, as does a slow pipeline test, but Issue 6 covers why the slow test is currently ineffective. Pytest collection in this .venv failed before the Latte tests could run because the local PyTorch build is missing torch._C._distributed_c10d.

Issue 1: Provided latents are not cast to the pipeline dtype

Affected code:

if latents is None:
    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
    latents = latents.to(device)

Problem:
prepare_latents() casts provided latents only to the target device, not to the requested dtype. In float16/bfloat16 pipelines, user-provided float32 latents are passed into a half-precision transformer and fail at the first projection.

Impact:
Resuming or reusing precomputed latents with torch_dtype=torch.float16/bfloat16 can crash at inference time.

Reproduction:

import torch
from diffusers import AutoencoderKL, DDIMScheduler, LattePipeline, LatteTransformer3DModel

cfg = dict(sample_size=8, num_layers=1, patch_size=2, attention_head_dim=4, num_attention_heads=2,
           caption_channels=8, in_channels=4, cross_attention_dim=8, out_channels=8,
           num_embeds_ada_norm=1000, norm_type="ada_norm_single")
pipe = LattePipeline(None, None, AutoencoderKL().eval(), LatteTransformer3DModel(**cfg).eval().to(torch.float16), DDIMScheduler())
latents = torch.randn(1, 4, 1, 8, 8, dtype=torch.float32)
print(pipe.prepare_latents(1, 4, 1, 8, 8, torch.float16, torch.device("cpu"), None, latents).dtype)
pipe(prompt_embeds=torch.randn(1, 8, 8, dtype=torch.float16), guidance_scale=1.0,
     num_inference_steps=1, height=8, width=8, video_length=1, latents=latents,
     output_type="latent", mask_feature=False)

Relevant precedent:

if latents is None:
    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
    latents = latents.to(device=device, dtype=dtype)

Suggested fix:

else:
    latents = latents.to(device=device, dtype=dtype)

Issue 2: video_length values other than 1 or the configured length crash

Affected code:

# define temporal positional embedding
temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
    inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt"
)  # 1152 hidden size
self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False)

if i == 0 and num_frame > 1:
    hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype)

Problem:
The model registers a fixed-length temp_pos_embed, then adds the full buffer to the runtime hidden states. Any runtime frame count between 2 and config.video_length - 1, or above config.video_length, raises a size mismatch. The pipeline exposes video_length but does not validate this.

Impact:
Users can pass a seemingly supported video_length=8 or video_length=24 and get a low-level tensor error.

Reproduction:

import torch
from diffusers import LatteTransformer3DModel

model = LatteTransformer3DModel(sample_size=8, num_layers=1, patch_size=2, attention_head_dim=4,
    num_attention_heads=2, caption_channels=8, in_channels=4, cross_attention_dim=8,
    out_channels=8, num_embeds_ada_norm=1000, norm_type="ada_norm_single", video_length=16).eval()
model(hidden_states=torch.randn(1, 4, 2, 8, 8),
      encoder_hidden_states=torch.randn(1, 8, 8),
      timestep=torch.tensor([1]))

Relevant precedent:
No exact duplicate found. Related PR #11139 only fixed dtype casting of the same buffer.

Suggested fix:

if i == 0 and num_frame > 1:
    if num_frame > self.temp_pos_embed.shape[1]:
        raise ValueError(f"`num_frame` must be <= {self.temp_pos_embed.shape[1]}, got {num_frame}.")
    hidden_states = hidden_states + self.temp_pos_embed[:, :num_frame].to(hidden_states.dtype)
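A pipeline-level guard would surface the error earlier. A minimal sketch, assuming a check is added to LattePipeline.check_inputs (the exact placement, and video_length being available there, are assumptions):

# Hypothetical early validation in LattePipeline; config.video_length is the
# frame count the temporal positional embedding was built for.
max_frames = self.transformer.config.video_length
if video_length > max_frames:
    raise ValueError(f"`video_length` must be <= {max_frames}, got {video_length}.")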

Issue 3: encoder_attention_mask is not expanded for temporal flattening

Affected code:

encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
    num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])

# Prepare timesteps for spatial and temporal block
timestep_spatial = timestep.repeat_interleave(
    num_frame, dim=0, output_size=timestep.shape[0] * num_frame
).view(-1, timestep.shape[-1])
timestep_temp = timestep.repeat_interleave(
    num_patches, dim=0, output_size=timestep.shape[0] * num_patches
).view(-1, timestep.shape[-1])

# Spatial and temporal transformer blocks
for i, (spatial_block, temp_block) in enumerate(
    zip(self.transformer_blocks, self.temporal_transformer_blocks)
):
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        hidden_states = self._gradient_checkpointing_func(
            spatial_block,
            hidden_states,
            None,  # attention_mask
            encoder_hidden_states_spatial,
            encoder_attention_mask,

Problem:
encoder_hidden_states are repeated from batch B to B * num_frame, but encoder_attention_mask remains shape (B, seq). With temporal inputs, attention reshaping expects the mask batch to match B * num_frame and crashes.

Impact:
The public model API documents encoder_attention_mask, but it is unusable for normal multi-frame inputs.

Reproduction:

import torch
from diffusers import LatteTransformer3DModel

model = LatteTransformer3DModel(sample_size=8, num_layers=1, patch_size=2, attention_head_dim=4,
    num_attention_heads=2, caption_channels=8, in_channels=4, cross_attention_dim=8,
    out_channels=8, num_embeds_ada_norm=1000, norm_type="ada_norm_single", video_length=2).eval()
model(hidden_states=torch.randn(1, 4, 2, 8, 8),
      encoder_hidden_states=torch.randn(1, 8, 8),
      encoder_attention_mask=torch.ones(1, 8),
      timestep=torch.tensor([1]))

Relevant precedent:

# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
    encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
    encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

Suggested fix:
Convert 2D masks to an additive bias, then repeat them with repeat_interleave(num_frame, dim=0) before passing them to the spatial blocks.
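A minimal sketch of that change inside LatteTransformer3DModel.forward, reusing the bias conversion from the precedent above (the exact placement before the block loop is an assumption):

# Convert a 2D mask to an additive bias, as in the 2D precedent above.
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
    encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
    encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

# Repeat the mask so its batch matches encoder_hidden_states_spatial (B * num_frame).
encoder_attention_mask_spatial = (
    encoder_attention_mask.repeat_interleave(num_frame, dim=0)
    if encoder_attention_mask is not None
    else None
)

The spatial blocks would then receive encoder_attention_mask_spatial instead of the unexpanded mask.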

Issue 4: Prompt-embedding-only CFG path builds an invalid attention mask

Affected code:

else:
    prompt_embeds_attention_mask = torch.ones_like(prompt_embeds)

if self.text_encoder is not None:
    dtype = self.text_encoder.dtype
elif self.transformer is not None:
    dtype = self.transformer.dtype
else:
    dtype = None

prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed, -1)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(num_images_per_prompt, 1)

# Perform additional masking.
if mask_feature and not embeds_initially_provided:
    prompt_embeds = prompt_embeds.unsqueeze(1)
    masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
    masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
    masked_negative_prompt_embeds = (
        negative_prompt_embeds[:, :keep_indices, :] if negative_prompt_embeds is not None else None
    )
    return masked_prompt_embeds, masked_negative_prompt_embeds

Problem:
When prompt_embeds are supplied without negative_prompt_embeds, the pipeline creates prompt_embeds_attention_mask = torch.ones_like(prompt_embeds), a 3D tensor. mask_text_embeddings() expects a 2D mask, so batched prompt embeddings crash. The mask duplication order also differs from PixArt and can mismatch masks when num_images_per_prompt > 1.

Impact:
A documented prompt-embedding workflow fails unless users also pass negative embeddings or disable mask_feature.

Reproduction:

from types import SimpleNamespace
import torch
from diffusers import LattePipeline

class Tok:
    def __call__(self, texts, **kw):
        texts = texts if isinstance(texts, list) else [texts]
        n = kw["max_length"]
        return SimpleNamespace(input_ids=torch.ones(len(texts), n, dtype=torch.long),
                               attention_mask=torch.ones(len(texts), n, dtype=torch.long))

class Enc(torch.nn.Module):
    @property
    def dtype(self): return torch.float32
    def forward(self, input_ids, attention_mask=None):
        return (torch.randn(input_ids.shape[0], input_ids.shape[1], 32),)

pipe = LattePipeline(Tok(), Enc(), None, None, None)
pipe.encode_prompt(None, prompt_embeds=torch.randn(2, 8, 32), do_classifier_free_guidance=True,
                   negative_prompt="", mask_feature=True, device=torch.device("cpu"))

Relevant precedent:

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)

Suggested fix:

prompt_embeds_attention_mask = torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=prompt_embeds.device)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.repeat(1, num_images_per_prompt)
prompt_embeds_attention_mask = prompt_embeds_attention_mask.view(bs_embed * num_images_per_prompt, -1)

Issue 5: Latte transformer does not expose attention processor APIs

Affected code:

from ..attention import BasicTransformerBlock
from ..cache_utils import CacheMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

Problem:
The model is built from BasicTransformerBlock attention layers, but it does not inherit AttentionMixin and does not define _no_split_modules. Users cannot call set_attn_processor(), set_default_attn_processor(), or fuse_qkv_projections() on the model.

Impact:
Attention processor tests are skipped, model-level processor replacement is unavailable, and device-map splitting has less guidance than similar PixArt-style transformer models.

Reproduction:

from diffusers import LatteTransformer3DModel

model = LatteTransformer3DModel(sample_size=8, num_layers=1, patch_size=2, attention_head_dim=4,
    num_attention_heads=2, caption_channels=8, in_channels=4, cross_attention_dim=8,
    out_channels=8, num_embeds_ada_norm=1000, norm_type="ada_norm_single")
print(hasattr(model, "set_attn_processor"))
print(getattr(type(model), "_no_split_modules", None))

Relevant precedent:

from ..attention import AttentionMixin, BasicTransformerBlock
from ..attention_processor import Attention, AttnProcessor, FusedAttnProcessor2_0
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class PixArtTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin):
    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]

Suggested fix:

from ..attention import AttentionMixin, BasicTransformerBlock

class LatteTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, CacheMixin):
    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
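With the mixin in place, the reproduction above should report the processor hooks as available; a quick check (the expected outputs assume the fix is applied):

from diffusers import LatteTransformer3DModel

model = LatteTransformer3DModel(sample_size=8, num_layers=1, patch_size=2, attention_head_dim=4,
    num_attention_heads=2, caption_channels=8, in_channels=4, cross_attention_dim=8,
    out_channels=8, num_embeds_ada_norm=1000, norm_type="ada_norm_single")
print(hasattr(model, "set_attn_processor"))             # True once AttentionMixin is inherited
print(getattr(type(model), "_no_split_modules", None))  # ["BasicTransformerBlock", "PatchEmbed"]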

Issue 6: Latte pipeline tests do not provide meaningful regression coverage

Affected code:

def test_inference(self):
    device = "cpu"

    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    inputs = self.get_dummy_inputs(device)
    video = pipe(**inputs).frames
    generated_video = video[0]

    self.assertEqual(generated_video.shape, (1, 3, 8, 8))
    expected_video = torch.randn(1, 3, 8, 8)
    max_diff = np.abs(generated_video - expected_video).max()
    self.assertLessEqual(max_diff, 1e10)

def test_latte(self):
    generator = torch.Generator("cpu").manual_seed(0)

    pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
    pipe.enable_model_cpu_offload(device=torch_device)

    prompt = self.prompt

    videos = pipe(
        prompt=prompt,
        height=512,
        width=512,
        generator=generator,
        num_inference_steps=2,
        clean_caption=False,
    ).frames

    video = videos[0]
    expected_video = torch.randn(1, 512, 512, 3).numpy()
    max_diff = numpy_cosine_similarity_distance(video.flatten(), expected_video)
    assert max_diff < 1e-3, f"Max diff is too high. got {video.flatten()}"

Problem:
The fast test compares the output against random noise with a 1e10 tolerance, so it only checks that execution returns a tensor. The slow test uses the default output_type="pil", so videos[0] is a list of PIL frames and video.flatten() is invalid.

Impact:
Regressions in Latte output values, output type, dtype paths, and prompt embedding paths can land without a useful test failure.

Reproduction:

from pathlib import Path

text = Path("tests/pipelines/latte/test_latte.py").read_text()
assert "assertLessEqual(max_diff, 1e10)" in text
assert "video.flatten()" in text
print("Latte fast test has a non-regression threshold and slow test flattens PIL output.")

Relevant precedent:

# fmt: off
expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
# fmt: on
generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

Suggested fix:
Use deterministic expected slices in the fast test, set output_type="np" in the slow test (or flatten np.array(video)), and add targeted tests for provided-latents dtype, prompt_embeds without negative_prompt_embeds, temporal frame counts, and encoder_attention_mask.
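A sketch of both changes, following the deterministic-slice precedent above (the expected values are placeholders to regenerate from a seeded run, not real reference data):

# Fast test: compare a deterministic slice instead of random noise.
generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
# fmt: off
expected_slice = torch.tensor([...])  # placeholder: fill from a seeded reference run
# fmt: on
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

# Slow test: request numpy frames so video.flatten() is valid.
videos = pipe(prompt=prompt, height=512, width=512, generator=generator,
              num_inference_steps=2, clean_caption=False, output_type="np").frames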

Issue 7: Pipeline overview documents Latte as text-to-image

Affected code:

| [Latte](latte) | text2image |

Problem:
The API overview lists Latte as text2image, but LattePipeline is a text-to-video pipeline.

Impact:
Users browsing the model family table get the wrong task type.

Reproduction:

from pathlib import Path

line = next(line for line in Path("docs/source/en/api/pipelines/overview.md").read_text().splitlines() if "[Latte]" in line)
print(line)
assert "text2video" in line

Relevant precedent:

| [AnimateDiff](animatediff) | text2video |
| [AudioLDM2](audioldm2) | text2audio |
| [LongCat-AudioDiT](longcat_audio_dit) | text2audio |
| [AuraFlow](aura_flow) | text2image |
| [Bria 3.2](bria_3_2) | text2image |
| [CogVideoX](cogvideox) | text2video |

Suggested fix:

| [Latte](latte) | text2video |
