DiT model/pipeline review
Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423
Review performed against the repository review rules.
Reviewed: target pipeline/model files, top-level and package lazy exports, config/loading/serialization paths, runtime dtype/device/scheduler behavior, offload/device-map behavior, docs, examples, and tests.
Duplicate search: searched GitHub Issues and PRs in huggingface/diffusers for dit, DiTPipeline, DiTTransformer2DModel, pipeline_dit.py, get_label_ids, id2label, scale_model_input, out_channels, _no_split_modules, device_map, class_null, and rectangular/unpatchify failures. I found no likely duplicates for the findings below.
Local verification: direct .venv Python reproductions were run. Full fast pytest collection was attempted, but both DiT test files failed before running tests because this .venv Torch build is missing torch._C._distributed_c10d, imported via diffusers.training_utils.
Issue 1: CFG null class id is hardcoded to 1000
Affected code:
src/diffusers/pipelines/dit/pipeline_dit.py, lines 173-175 at 0f1abc4:
class_labels = torch.tensor(class_labels, device=self._execution_device).reshape(-1)
class_null = torch.tensor([1000] * batch_size, device=self._execution_device)
class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels
Problem:
DiTPipeline always uses 1000 as the unconditional class id for classifier-free guidance. That only works for ImageNet-1k DiT configs. The transformer already stores the class count in num_embeds_ada_norm; for tiny/custom configs, the null class id should be that value, not a hardcoded ImageNet constant.
Impact:
Any custom or tiny DiT pipeline with num_embeds_ada_norm != 1000 crashes with an embedding index error under the default guidance_scale=4.0. The fast pipeline test misses this because its dummy transformer also uses num_embeds_ada_norm=1000.
Reproduction:
import torch
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel
transformer = DiTTransformer2DModel(
    sample_size=16, num_layers=1, patch_size=4,
    attention_head_dim=8, num_attention_heads=2,
    in_channels=4, out_channels=8, num_embeds_ada_norm=8,
).eval()
pipe = DiTPipeline(transformer=transformer, vae=AutoencoderKL().eval(), scheduler=DDIMScheduler())
pipe.set_progress_bar_config(disable=True)
pipe(class_labels=[1], generator=torch.Generator(device="cpu").manual_seed(0), num_inference_steps=1, output_type="np")
Relevant precedent:
LabelEmbedding reserves index self.num_classes for the null (dropped) class, so callers should derive the null id from the configured class count.
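For illustration, the reserved index can be checked on LabelEmbedding directly (a sketch; the diffusers.models.embeddings import path, constructor signature, and token_drop semantics are assumed from the current implementation):
import torch
from diffusers.models.embeddings import LabelEmbedding

emb = LabelEmbedding(num_classes=8, hidden_size=32, dropout_prob=0.1)
# one extra row in the embedding table is reserved for the CFG null class
print(emb.embedding_table.num_embeddings)  # expected: 9
# dropped labels map to the reserved index == num_classes, not to 1000
print(emb.token_drop(torch.tensor([1, 2]), force_drop_ids=torch.tensor([1, 1])))  # expected: tensor([8, 8])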
Suggested fix:
num_classes = self.transformer.config.num_embeds_ada_norm
class_null = torch.full((batch_size,), num_classes, device=self._execution_device, dtype=class_labels.dtype)
Issue 2: Scaled model input is passed to scheduler.step
Affected code:
src/diffusers/pipelines/dit/pipeline_dit.py, lines 178-183 at 0f1abc4:
self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
    if guidance_scale > 1:
        half = latent_model_input[: len(latent_model_input) // 2]
        latent_model_input = torch.cat([half, half], dim=0)
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

src/diffusers/pipelines/dit/pipeline_dit.py, lines 221-222 at 0f1abc4 (inside the same loop):
# compute previous image: x_t -> x_t-1
latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample
Problem:
The loop overwrites latent_model_input with scheduler.scale_model_input(...), then passes that scaled tensor as the sample argument to scheduler.step. For schedulers where scale_model_input is not identity, such as Euler or LMS, step receives the wrong latent state.
Impact:
DiT produces incorrect samples when users swap in supported Karras-style schedulers whose scale_model_input actually rescales the input. DDIM and DPMSolver do not surface the bug because their scale_model_input is currently the identity.
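A standalone illustration of that difference, using only public scheduler APIs (a sketch; exact scaling values depend on the default configs):
import torch
from diffusers import DDIMScheduler, EulerDiscreteScheduler

sample = torch.ones(1, 4, 8, 8)

ddim = DDIMScheduler()
ddim.set_timesteps(10)
# DDIM's scale_model_input is a pass-through, so the bug stays hidden
print(torch.equal(ddim.scale_model_input(sample, ddim.timesteps[0]), sample))  # True

euler = EulerDiscreteScheduler()
euler.set_timesteps(10)
# Euler divides by sqrt(sigma**2 + 1), so the scaled tensor is not the latent state
print(torch.equal(euler.scale_model_input(sample, euler.timesteps[0]), sample))  # False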
Reproduction:
import torch
from diffusers import AutoencoderKL, DiTPipeline, DiTTransformer2DModel, EulerDiscreteScheduler
class TrackingEulerScheduler(EulerDiscreteScheduler):
    def scale_model_input(self, sample, timestep):
        self.pre_scale = sample.detach().clone()
        scaled = super().scale_model_input(sample, timestep)
        self.scaled = scaled.detach().clone()
        return scaled

    def step(self, model_output, timestep, sample, *args, **kwargs):
        print("matches unscaled:", torch.equal(sample, self.pre_scale))
        print("matches scaled:", torch.equal(sample, self.scaled))
        raise SystemExit

transformer = DiTTransformer2DModel(
    sample_size=16, num_layers=1, patch_size=4,
    attention_head_dim=8, num_attention_heads=2,
    in_channels=4, out_channels=8, num_embeds_ada_norm=1000,
).eval()
pipe = DiTPipeline(transformer=transformer, vae=AutoencoderKL().eval(), scheduler=TrackingEulerScheduler())
pipe.set_progress_bar_config(disable=True)
pipe(class_labels=[1], generator=torch.Generator(device="cpu").manual_seed(0), num_inference_steps=1, output_type="np")
Relevant precedent:
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py, lines 1036-1062 at 0f1abc4:
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if hasattr(self.scheduler, "scale_model_input"):
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(
    latent_model_input,
    t,
    encoder_hidden_states=prompt_embeds,
    timestep_cond=timestep_cond,
    cross_attention_kwargs=self.cross_attention_kwargs,
    added_cond_kwargs=added_cond_kwargs,
    return_dict=False,
)[0]

# perform guidance
if self.do_classifier_free_guidance:
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
    # Based on 3.4. in https://huggingface.co/papers/2305.08891
    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py, lines 831-866 at 0f1abc4:
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
    dtype=latent_model_input.dtype
)

# predict the noise residual
noise_pred = self.transformer(
    latent_model_input,
    t_expand,
    encoder_hidden_states=prompt_embeds,
    text_embedding_mask=prompt_attention_mask,
    encoder_hidden_states_t5=prompt_embeds_2,
    text_embedding_mask_t5=prompt_attention_mask_2,
    image_meta_size=add_time_ids,
    style=style,
    image_rotary_emb=image_rotary_emb,
    return_dict=False,
)[0]

noise_pred, _ = noise_pred.chunk(2, dim=1)

# perform guidance
if self.do_classifier_free_guidance:
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

if self.do_classifier_free_guidance and guidance_rescale > 0.0:
    # Based on 3.4. in https://huggingface.co/papers/2305.08891
    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
Suggested fix:
self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
# keep the doubled (cond + uncond) batch in `latents`, matching the existing pipeline layout
latents = torch.cat([latents] * 2) if guidance_scale > 1 else latents
for t in self.progress_bar(self.scheduler.timesteps):
    # scale only the copy that is fed to the transformer
    latent_model_input = self.scheduler.scale_model_input(latents, t)
    ...
    # step on the unscaled latents, not on the scaled model input
    latents = self.scheduler.step(model_output, t, latents).prev_sample
# after the loop, split off the conditional half as the pipeline already does: latents, _ = latents.chunk(2, dim=0)
Issue 3: Pipeline crashes when out_channels uses the model default
Affected code:
src/diffusers/pipelines/dit/pipeline_dit.py, lines 215-219 at 0f1abc4:
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
    model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
else:
    model_output = noise_pred
Problem:
DiTTransformer2DModel accepts out_channels=None and internally resolves it to in_channels, but the pipeline checks self.transformer.config.out_channels // 2. When the config value is None, this raises TypeError.
Impact:
A valid default-config DiT transformer cannot be used in DiTPipeline.
Reproduction:
import torch
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel
transformer = DiTTransformer2DModel(
    sample_size=16, num_layers=1, patch_size=4,
    attention_head_dim=8, num_attention_heads=2,
    in_channels=4, num_embeds_ada_norm=1000,
).eval()
pipe = DiTPipeline(transformer=transformer, vae=AutoencoderKL().eval(), scheduler=DDIMScheduler())
pipe.set_progress_bar_config(disable=True)
pipe(class_labels=[1], generator=torch.Generator(device="cpu").manual_seed(0), num_inference_steps=1, output_type="np")
Relevant precedent:
The model resolves the effective output channel count onto self.out_channels, falling back to in_channels when out_channels is None; the config itself keeps None.
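A minimal check of the mismatch (a sketch; attribute names assumed from the current model code):
from diffusers import DiTTransformer2DModel

model = DiTTransformer2DModel(
    sample_size=16, num_layers=1, patch_size=4,
    attention_head_dim=8, num_attention_heads=2,
    in_channels=4, num_embeds_ada_norm=1000,
)
print(model.config.out_channels)  # None, so `config.out_channels // 2` raises TypeError
print(model.out_channels)         # 4, resolved to in_channels on the module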
Suggested fix:
out_channels = self.transformer.config.out_channels
out_channels = latent_channels if out_channels is None else out_channels
if out_channels // 2 == latent_channels:
    model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
else:
    model_output = noise_pred
Issue 4: id2label is not serialized
Affected code:
src/diffusers/pipelines/dit/pipeline_dit.py, lines 58-74 at 0f1abc4:
def __init__(
    self,
    transformer: DiTTransformer2DModel,
    vae: AutoencoderKL,
    scheduler: KarrasDiffusionSchedulers,
    id2label: dict[int, str] | None = None,
):
    super().__init__()
    self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)

    # create a imagenet -> id dictionary for easier use
    self.labels = {}
    if id2label is not None:
        for key, value in id2label.items():
            for label in value.split(","):
                self.labels[label.lstrip().rstrip()] = int(key)
        self.labels = dict(sorted(self.labels.items()))
Problem:
id2label is consumed to build self.labels, but it is never registered into the pipeline config. save_pretrained() therefore drops the label map, and a reloaded pipeline loses get_label_ids functionality unless the original model index supplied id2label.
Impact:
Custom or resaved DiT pipelines silently lose their class-name mapping.
Reproduction:
from diffusers import DiTPipeline
pipe = DiTPipeline(transformer=None, vae=None, scheduler=None, id2label={0: "vase"})
print(pipe.labels)
print("id2label" in pipe.config)
Relevant precedent:
src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py, lines 239-241 at 0f1abc4:
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
Suggested fix:
self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)
self.register_to_config(id2label=id2label)
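With id2label registered, a save/load round trip is expected to keep the label map (a sketch that assumes the fix above is applied; reuses the tiny transformer from the Issue 1 reproduction):
import tempfile
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline

pipe = DiTPipeline(
    transformer=transformer, vae=AutoencoderKL().eval(), scheduler=DDIMScheduler(),
    id2label={0: "vase", 1: "goldfish"},
)
with tempfile.TemporaryDirectory() as tmp:
    pipe.save_pretrained(tmp)
    reloaded = DiTPipeline.from_pretrained(tmp)
print(reloaded.labels)  # expected after the fix: {'goldfish': 1, 'vase': 0}; currently {} because id2label is dropped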
Issue 5: get_label_ids breaks for a single string
Affected code:
src/diffusers/pipelines/dit/pipeline_dit.py, lines 90-99 at 0f1abc4:
if not isinstance(label, list):
    label = list(label)

for l in label:
    if l not in self.labels:
        raise ValueError(
            f"{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}."
        )

return [self.labels[l] for l in label]
Problem:
The method advertises label: str | list[str], but label = list(label) turns "vase" into ["v", "a", "s", "e"].
Impact:
The public helper fails for the documented single-string input.
Reproduction:
from diffusers import DiTPipeline
pipe = DiTPipeline(transformer=None, vae=None, scheduler=None, id2label={0: "vase"})
print(pipe.get_label_ids("vase"))
Relevant precedent:
Most pipeline helpers normalize scalar string input with [value], not list(value).
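The difference for a single string, in plain Python:
label = "vase"
print(list(label))  # ['v', 'a', 's', 'e'] -> every character is treated as a label
print([label])      # ['vase']             -> the intended normalization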
Suggested fix:
if isinstance(label, str):
    label = [label]
Issue 6: DiT unpatchify assumes a square token grid
Affected code:
src/diffusers/models/transformers/dit_transformer_2d.py, lines 180-181 at 0f1abc4:
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = self.pos_embed(hidden_states)

src/diffusers/models/transformers/dit_transformer_2d.py, lines 213-220 at 0f1abc4:
# unpatchify
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
    shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
    shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
Problem:
The model computes the patch-grid height, width from the input, but later overwrites both with int(hidden_states.shape[1] ** 0.5). Rectangular inputs with valid patch dimensions fail during reshape.
Impact:
DiTTransformer2DModel cannot process non-square latent tensors even though PatchEmbed supports interpolated rectangular positional embeddings and the forward docstring does not document a square-only restriction.
Reproduction:
import torch
from diffusers import DiTTransformer2DModel
model = DiTTransformer2DModel(
    sample_size=8, num_layers=1, patch_size=2,
    attention_head_dim=4, num_attention_heads=2,
    in_channels=4, out_channels=8, num_embeds_ada_norm=8,
).eval()
model(torch.randn(1, 4, 8, 12), timestep=torch.tensor([1]), class_labels=torch.tensor([1]))
Relevant precedent:
src/diffusers/models/transformers/pixart_transformer_2d.py, lines 303-306 at 0f1abc4:
    hidden_states.shape[-2] // self.config.patch_size,
    hidden_states.shape[-1] // self.config.patch_size,
)
hidden_states = self.pos_embed(hidden_states)

src/diffusers/models/transformers/pixart_transformer_2d.py, lines 350-357 at 0f1abc4:
# unpatchify
hidden_states = hidden_states.reshape(
    shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
    shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
)
Suggested fix:
# keep the height, width computed before self.pos_embed(hidden_states)
hidden_states = hidden_states.reshape(
    shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
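With that change, the reproduction above is expected to accept the rectangular 8x12 latent (a sketch reusing model from the reproduction; the shape shown assumes out_channels=8):
out = model(
    torch.randn(1, 4, 8, 12),
    timestep=torch.tensor([1]),
    class_labels=torch.tensor([1]),
)
print(out.sample.shape)  # expected after the fix: torch.Size([1, 8, 8, 12])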
Issue 7: device_map="auto" is unsupported because _no_split_modules is missing
Affected code:
|
_skip_layerwise_casting_patterns = ["pos_embed", "norm"] |
|
_supports_gradient_checkpointing = True |
|
_supports_group_offloading = False |
Problem:
DiTTransformer2DModel is a ModelMixin subclass but does not define _no_split_modules. Diffusers rejects device_map="auto" for such models.
Impact:
Large DiT checkpoints cannot use model-level automatic device placement, unlike related transformer models.
Reproduction:
from diffusers import DiTTransformer2DModel
model = DiTTransformer2DModel(
    sample_size=8, num_layers=1, patch_size=2,
    attention_head_dim=4, num_attention_heads=2,
    in_channels=4, out_channels=8, num_embeds_ada_norm=8,
)
print(model._get_no_split_modules("auto"))
Relevant precedent:
src/diffusers/models/transformers/pixart_transformer_2d.py, lines 80-82 at 0f1abc4:
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]

src/diffusers/models/transformers/transformer_2d.py, lines 67-68 at 0f1abc4:
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock"]
Suggested fix:
_no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
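Once _no_split_modules is defined, model-level auto placement is expected to be accepted (a sketch; requires accelerate, and uses the public facebook/DiT-XL-2-256 checkpoint purely as an example):
from diffusers import DiTTransformer2DModel

model = DiTTransformer2DModel.from_pretrained(
    "facebook/DiT-XL-2-256", subfolder="transformer", device_map="auto"
)
# placed by accelerate instead of raising on device_map="auto"
print(next(model.parameters()).device)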
Issue 8: Pipeline slow tests are missing
Affected code:
tests/pipelines/dit/test_dit.py, lines 117-119 at 0f1abc4:
@nightly
@require_torch_accelerator
class DiTPipelineIntegrationTests(unittest.TestCase):
Problem:
The DiT pipeline has fast tests and nightly accelerator integration tests, but no @slow pipeline test. The model file has one @slow remapping test, but the pipeline integration coverage is @nightly only.
Impact:
Standard slow CI does not exercise pretrained DiT pipeline loading/inference, so regressions can be missed outside nightly jobs.
Reproduction:
from pathlib import Path
text = Path("tests/pipelines/dit/test_dit.py").read_text()
print("@slow" in text, "@nightly" in text)
Relevant precedent:
tests/pipelines/hunyuandit/test_hunyuan_dit.py, lines 316-318 at 0f1abc4:
@slow
@require_torch_accelerator
class HunyuanDiTPipelineIntegrationTests(unittest.TestCase):
Suggested fix:
from ...testing_utils import slow
@slow
@require_torch_accelerator
class DiTPipelineIntegrationTests(unittest.TestCase):
    ...