
dit model/pipeline review #13616

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Reviewed: target pipeline/model files, top-level and package lazy exports, config/loading/serialization paths, runtime dtype/device/scheduler behavior, offload/device-map behavior, docs, examples, and tests.

Duplicate search: searched GitHub Issues and PRs in huggingface/diffusers for dit, DiTPipeline, DiTTransformer2DModel, pipeline_dit.py, get_label_ids, id2label, scale_model_input, out_channels, _no_split_modules, device_map, class_null, and rectangular/unpatchify failures. I found no likely duplicates for the findings below.

Local verification: direct .venv Python reproductions were run. Full fast pytest collection was attempted, but both DiT test files failed during collection because this .venv Torch build lacks torch._C._distributed_c10d, which diffusers.training_utils imports.

Issue 1: CFG null class id is hardcoded to 1000

Affected code:

class_labels = torch.tensor(class_labels, device=self._execution_device).reshape(-1)
class_null = torch.tensor([1000] * batch_size, device=self._execution_device)
class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels

Problem:
DiTPipeline always uses 1000 as the unconditional class id for classifier-free guidance. That only works for ImageNet-1k DiT configs. The transformer already stores the class count in num_embeds_ada_norm; for tiny/custom configs, the null class id should be that value, not a hardcoded ImageNet constant.

Impact:
Any custom or tiny DiT pipeline with num_embeds_ada_norm != 1000 crashes with an embedding index error under the default guidance_scale=4.0. The fast pipeline test misses this because its dummy transformer also uses num_embeds_ada_norm=1000.

Reproduction:

import torch
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel

transformer = DiTTransformer2DModel(
    sample_size=16, num_layers=1, patch_size=4,
    attention_head_dim=8, num_attention_heads=2,
    in_channels=4, out_channels=8, num_embeds_ada_norm=8,
).eval()

pipe = DiTPipeline(transformer=transformer, vae=AutoencoderKL().eval(), scheduler=DDIMScheduler())
pipe.set_progress_bar_config(disable=True)
pipe(class_labels=[1], generator=torch.Generator(device="cpu").manual_seed(0), num_inference_steps=1, output_type="np")

Relevant precedent:
LabelEmbedding creates the null class id as self.num_classes when it drops labels, so callers should derive the null id from the configured class count.
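
For context, a condensed paraphrase of that LabelEmbedding logic (not verbatim; exact code may differ across diffusers versions):

class LabelEmbedding(nn.Module):
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        # one extra embedding row is reserved for the null/CFG class
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes

    def token_drop(self, labels, force_drop_ids=None):
        ...
        # dropped labels are remapped to num_classes, i.e. the null class id
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels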

Suggested fix:

num_classes = self.transformer.config.num_embeds_ada_norm
class_null = torch.full((batch_size,), num_classes, device=self._execution_device, dtype=class_labels.dtype)

Issue 2: Scaled model input is passed to scheduler.step

Affected code:

self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
    if guidance_scale > 1:
        half = latent_model_input[: len(latent_model_input) // 2]
        latent_model_input = torch.cat([half, half], dim=0)
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

    # compute previous image: x_t -> x_t-1
    latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample

Problem:
The loop overwrites latent_model_input with scheduler.scale_model_input(...), then passes that scaled tensor as the sample argument to scheduler.step. For schedulers where scale_model_input is not identity, such as Euler or LMS, step receives the wrong latent state.

Impact:
DiT produces incorrect samples when users swap in supported Karras-style schedulers that scale the model input. DDIM and DPMSolver mask the bug because their scale_model_input is currently the identity.
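
To see the size of the error: Euler-style schedulers divide the model input by sqrt(sigma^2 + 1), so a step() that receives the scaled tensor operates on a latent that is off by that factor. A standalone numeric illustration (not pipeline code):

import torch

# EulerDiscreteScheduler.scale_model_input divides the sample by sqrt(sigma^2 + 1)
sigma = 10.0
latents = torch.ones(1)
scaled = latents / ((sigma**2 + 1) ** 0.5)
print(scaled)  # tensor([0.0995]) -- ~10x smaller than the latents step() expects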

Reproduction:

import torch
from diffusers import AutoencoderKL, DiTPipeline, DiTTransformer2DModel, EulerDiscreteScheduler

class TrackingEulerScheduler(EulerDiscreteScheduler):
    def scale_model_input(self, sample, timestep):
        self.pre_scale = sample.detach().clone()
        scaled = super().scale_model_input(sample, timestep)
        self.scaled = scaled.detach().clone()
        return scaled

    def step(self, model_output, timestep, sample, *args, **kwargs):
        print("matches unscaled:", torch.equal(sample, self.pre_scale))
        print("matches scaled:", torch.equal(sample, self.scaled))
        raise SystemExit

transformer = DiTTransformer2DModel(
    sample_size=16, num_layers=1, patch_size=4,
    attention_head_dim=8, num_attention_heads=2,
    in_channels=4, out_channels=8, num_embeds_ada_norm=1000,
).eval()

pipe = DiTPipeline(transformer=transformer, vae=AutoencoderKL().eval(), scheduler=TrackingEulerScheduler())
pipe.set_progress_bar_config(disable=True)
pipe(class_labels=[1], generator=torch.Generator(device="cpu").manual_seed(0), num_inference_steps=1, output_type="np")

Relevant precedent:

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if hasattr(self.scheduler, "scale_model_input"):
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(
    latent_model_input,
    t,
    encoder_hidden_states=prompt_embeds,
    timestep_cond=timestep_cond,
    cross_attention_kwargs=self.cross_attention_kwargs,
    added_cond_kwargs=added_cond_kwargs,
    return_dict=False,
)[0]

# perform guidance
if self.do_classifier_free_guidance:
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
    # Based on 3.4. in https://huggingface.co/papers/2305.08891
    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
    dtype=latent_model_input.dtype
)

# predict the noise residual
noise_pred = self.transformer(
    latent_model_input,
    t_expand,
    encoder_hidden_states=prompt_embeds,
    text_embedding_mask=prompt_attention_mask,
    encoder_hidden_states_t5=prompt_embeds_2,
    text_embedding_mask_t5=prompt_attention_mask_2,
    image_meta_size=add_time_ids,
    style=style,
    image_rotary_emb=image_rotary_emb,
    return_dict=False,
)[0]

noise_pred, _ = noise_pred.chunk(2, dim=1)

# perform guidance
if self.do_classifier_free_guidance:
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

if self.do_classifier_free_guidance and guidance_rescale > 0.0:
    # Based on 3.4. in https://huggingface.co/papers/2305.08891
    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

Suggested fix:

self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)

for t in self.progress_bar(self.scheduler.timesteps):
    # expand only the model input for classifier-free guidance
    latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
    ...
    # step on the unscaled latents, not the scaled model input
    latents = self.scheduler.step(model_output, t, latents).prev_sample

Issue 3: Pipeline crashes when out_channels uses the model default

Affected code:

# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
    model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
else:
    model_output = noise_pred

Problem:
DiTTransformer2DModel accepts out_channels=None and internally resolves it to in_channels, but the pipeline checks self.transformer.config.out_channels // 2. When the config value is None, this raises TypeError.

Impact:
A valid default-config DiT transformer cannot be used in DiTPipeline.

Reproduction:

import torch
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel

transformer = DiTTransformer2DModel(
    sample_size=16, num_layers=1, patch_size=4,
    attention_head_dim=8, num_attention_heads=2,
    in_channels=4, num_embeds_ada_norm=1000,
).eval()

pipe = DiTPipeline(transformer=transformer, vae=AutoencoderKL().eval(), scheduler=DDIMScheduler())
pipe.set_progress_bar_config(disable=True)
pipe(class_labels=[1], generator=torch.Generator(device="cpu").manual_seed(0), num_inference_steps=1, output_type="np")

Relevant precedent:
The model itself resolves the effective output channel count and stores it as self.out_channels.
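
Roughly (paraphrased, not verbatim):

# inside DiTTransformer2DModel.__init__
self.out_channels = in_channels if out_channels is None else out_channels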

Suggested fix:

out_channels = self.transformer.config.out_channels
out_channels = latent_channels if out_channels is None else out_channels

if out_channels // 2 == latent_channels:
    model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
else:
    model_output = noise_pred

Issue 4: id2label is not serialized

Affected code:

def __init__(
    self,
    transformer: DiTTransformer2DModel,
    vae: AutoencoderKL,
    scheduler: KarrasDiffusionSchedulers,
    id2label: dict[int, str] | None = None,
):
    super().__init__()
    self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)

    # create a imagenet -> id dictionary for easier use
    self.labels = {}
    if id2label is not None:
        for key, value in id2label.items():
            for label in value.split(","):
                self.labels[label.lstrip().rstrip()] = int(key)
        self.labels = dict(sorted(self.labels.items()))

Problem:
id2label is consumed to build self.labels, but it is never registered into the pipeline config. save_pretrained() therefore drops the label map, and a reloaded pipeline loses get_label_ids functionality unless the original model index supplied id2label.

Impact:
Custom or resaved DiT pipelines silently lose their class-name mapping.

Reproduction:

from diffusers import DiTPipeline

pipe = DiTPipeline(transformer=None, vae=None, scheduler=None, id2label={0: "vase"})
print(pipe.labels)
print("id2label" in pipe.config)

Relevant precedent:

self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)

Suggested fix:

self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)
self.register_to_config(id2label=id2label)
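
A round-trip sketch of what the fix preserves, reusing the tiny components from the earlier reproductions (illustrative only; relies on from_pretrained passing stored config values back to __init__):

import tempfile

from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel

transformer = DiTTransformer2DModel(
    sample_size=16, num_layers=1, patch_size=4,
    attention_head_dim=8, num_attention_heads=2,
    in_channels=4, out_channels=8, num_embeds_ada_norm=1000,
).eval()

pipe = DiTPipeline(
    transformer=transformer, vae=AutoencoderKL().eval(), scheduler=DDIMScheduler(),
    id2label={0: "vase"},
)

with tempfile.TemporaryDirectory() as tmp:
    pipe.save_pretrained(tmp)
    reloaded = DiTPipeline.from_pretrained(tmp)
    print(reloaded.labels)  # {} today; {"vase": 0} once id2label is registered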

Issue 5: get_label_ids breaks for a single string

Affected code:

if not isinstance(label, list):
    label = list(label)

for l in label:
    if l not in self.labels:
        raise ValueError(
            f"{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}."
        )

return [self.labels[l] for l in label]

Problem:
The method advertises label: str | list[str], but label = list(label) turns "vase" into ["v", "a", "s", "e"].

Impact:
The public helper fails for the documented single-string input.

Reproduction:

from diffusers import DiTPipeline

pipe = DiTPipeline(transformer=None, vae=None, scheduler=None, id2label={0: "vase"})
print(pipe.get_label_ids("vase"))  # raises ValueError because "v", not "vase", is looked up

Relevant precedent:
Most pipeline helpers normalize scalar string input with [value], not list(value).
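
The difference is easy to see in isolation:

print(list("vase"))  # ['v', 'a', 's', 'e'] -- a string is iterated per character
print(["vase"])      # ['vase']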

Suggested fix:

if isinstance(label, str):
    label = [label]

Issue 6: DiT unpatchify assumes a square token grid

Affected code:

height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = self.pos_embed(hidden_states)

# unpatchify
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
    shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
    shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)

Problem:
The model computes the patch-grid height and width from the input, but later overwrites both with int(hidden_states.shape[1] ** 0.5). Rectangular inputs with valid patch dimensions then fail during the reshape.

Impact:
DiTTransformer2DModel cannot process non-square latent tensors even though PatchEmbed supports interpolated rectangular positional embeddings and the forward docstring does not document a square-only restriction.
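
Worked arithmetic for the reproduction below (8x12 latent, patch_size=2):

height, width = 8 // 2, 12 // 2        # 4, 6 -- a valid rectangular patch grid
num_tokens = height * width            # 24
square_guess = int(num_tokens ** 0.5)  # 4 -- the overwrite discards width=6
# the reshape then expects 4 * 4 = 16 tokens but receives 24 -> RuntimeError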

Reproduction:

import torch
from diffusers import DiTTransformer2DModel

model = DiTTransformer2DModel(
    sample_size=8, num_layers=1, patch_size=2,
    attention_head_dim=4, num_attention_heads=2,
    in_channels=4, out_channels=8, num_embeds_ada_norm=8,
).eval()

model(torch.randn(1, 4, 8, 12), timestep=torch.tensor([1]), class_labels=torch.tensor([1]))

Relevant precedent:

height, width = (
    hidden_states.shape[-2] // self.config.patch_size,
    hidden_states.shape[-1] // self.config.patch_size,
)
hidden_states = self.pos_embed(hidden_states)

# unpatchify
hidden_states = hidden_states.reshape(
    shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
    shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
)

Suggested fix:

# keep the height, width computed before self.pos_embed(hidden_states)
hidden_states = hidden_states.reshape(
    shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)

Issue 7: device_map="auto" is unsupported because _no_split_modules is missing

Affected code:

_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_supports_gradient_checkpointing = True
_supports_group_offloading = False

Problem:
DiTTransformer2DModel is a ModelMixin subclass but does not define _no_split_modules. Diffusers rejects device_map="auto" for such models.
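
The gate lives in ModelMixin; roughly (condensed paraphrase, exact wording varies by version):

# condensed paraphrase of ModelMixin._get_no_split_modules (not verbatim)
def _get_no_split_modules(self, device_map: str):
    if self._no_split_modules is None:
        raise ValueError(
            f"{self.__class__.__name__} does not support `device_map='{device_map}'` yet."
        )
    ...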

Impact:
Large DiT checkpoints cannot use model-level automatic device placement, unlike related transformer models.

Reproduction:

from diffusers import DiTTransformer2DModel

model = DiTTransformer2DModel(
    sample_size=8, num_layers=1, patch_size=2,
    attention_head_dim=4, num_attention_heads=2,
    in_channels=4, out_channels=8, num_embeds_ada_norm=8,
)
model._get_no_split_modules("auto")  # raises ValueError because _no_split_modules is None

Relevant precedent:

_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]

_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock"]

Suggested fix:

_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]

Issue 8: Pipeline slow tests are missing

Affected code:

@nightly
@require_torch_accelerator
class DiTPipelineIntegrationTests(unittest.TestCase):

Problem:
The DiT pipeline has fast tests and @nightly accelerator integration tests, but no @slow pipeline test. Apart from one @slow remapping test in the model file, integration coverage runs only in nightly jobs.

Impact:
Standard slow CI does not exercise pretrained DiT pipeline loading/inference, so regressions can be missed outside nightly jobs.

Reproduction:

from pathlib import Path

text = Path("tests/pipelines/dit/test_dit.py").read_text()
print("@slow" in text, "@nightly" in text)

Relevant precedent:

@slow
@require_torch_accelerator
class HunyuanDiTPipelineIntegrationTests(unittest.TestCase):

Suggested fix:

from ...testing_utils import slow

@slow
@require_torch_accelerator
class DiTPipelineIntegrationTests(unittest.TestCase):
    ...
