
hidream_image model/pipeline review #13617


Description

@hlky

hidream_image model/pipeline review

Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Duplicate search performed with gh for HiDream, class/file names, and specific failure modes. Existing duplicates were found only for the torch.compile item: #11477 and #11430.

Issue 1: height and width are silently replaced with default-area dimensions

Affected code:

height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
division = self.vae_scale_factor * 2
S_max = (self.default_sample_size * self.vae_scale_factor) ** 2
scale = S_max / (width * height)
scale = math.sqrt(scale)
width, height = int(width * scale // division * division), int(height * scale // division * division)

Problem:
The pipeline rescales any requested height/width to the default sample area. A user asking for 64x64 gets default-area latents/images instead of 64x64.

Impact:
User-visible output dimensions do not match the public API or docstring, and precomputed latents for the requested size are rejected.

Reproduction:

import torch
from diffusers import FlowMatchEulerDiscreteScheduler, HiDreamImagePipeline, HiDreamImageTransformer2DModel

m = HiDreamImageTransformer2DModel(patch_size=2, in_channels=4, out_channels=4, num_layers=0, num_single_layers=0,
    attention_head_dim=8, num_attention_heads=1, caption_channels=[32, 16], text_emb_dim=64,
    num_routed_experts=0, axes_dims_rope=(4, 2, 2), max_resolution=(32, 32), llama_layers=())
p = HiDreamImagePipeline(FlowMatchEulerDiscreteScheduler(), None, None, None, None, None, None, None, None, None, m)
out = p(height=64, width=64, num_inference_steps=0, guidance_scale=1.0, output_type="latent",
        pooled_prompt_embeds=torch.randn(1, 64), prompt_embeds_t5=torch.randn(1, 1, 32),
        prompt_embeds_llama3=torch.randn(0, 1, 1, 16)).images
print(out.shape)  # torch.Size([1, 4, 128, 128])

Relevant precedent:

# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height,
    width,
    prompt_embeds.dtype,
    device,
    generator,
    latents,
)

Suggested fix:

height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor

division = self.vae_scale_factor * self.transformer.config.patch_size
height = int(height) // division * division
width = int(width) // division * division
image_seq_len = (height // self.vae_scale_factor // self.transformer.config.patch_size) * (
    width // self.vae_scale_factor // self.transformer.config.patch_size
)
if image_seq_len > self.transformer.max_seq:
    raise ValueError(f"Requested image has {image_seq_len} latent tokens, but this model supports {self.transformer.max_seq}.")

Issue 2: attention_kwargs is accepted but never forwarded to the transformer

Affected code:

noise_pred = self.transformer(
    hidden_states=latent_model_input,
    timesteps=timestep,
    encoder_hidden_states_t5=prompt_embeds_t5,
    encoder_hidden_states_llama3=prompt_embeds_llama3,
    pooled_embeds=pooled_prompt_embeds,
    return_dict=False,
)[0]

Problem:
__call__ stores attention_kwargs, and the transformer forward is decorated with @apply_lora_scale("attention_kwargs"), but the denoising call never passes the kwargs.

Impact:
Runtime LoRA scaling via attention_kwargs={"scale": ...} is ignored for HiDream.

Reproduction:

import torch
from diffusers import FlowMatchEulerDiscreteScheduler, HiDreamImagePipeline, HiDreamImageTransformer2DModel

m = HiDreamImageTransformer2DModel(patch_size=2, in_channels=4, out_channels=4, num_layers=1, num_single_layers=0,
    attention_head_dim=8, num_attention_heads=1, caption_channels=[32, 16], text_emb_dim=64,
    num_routed_experts=0, axes_dims_rope=(4, 2, 2), max_resolution=(32, 32), llama_layers=(0,)).eval()
p = HiDreamImagePipeline(FlowMatchEulerDiscreteScheduler(), None, None, None, None, None, None, None, None, None, m)
seen = []
h = p.transformer.register_forward_pre_hook(lambda module, args, kwargs: seen.append(kwargs.get("attention_kwargs")), with_kwargs=True)
p(height=128, width=128, num_inference_steps=1, guidance_scale=1.0, output_type="latent",
  attention_kwargs={"scale": 0.25}, pooled_prompt_embeds=torch.randn(1, 64),
  prompt_embeds_t5=torch.randn(1, 4, 32), prompt_embeds_llama3=torch.randn(1, 1, 4, 16))
h.remove()
print(seen)  # [None]

Relevant precedent:

noise_pred = self.transformer(
    hidden_states=latents,
    timestep=timestep / 1000,
    guidance=guidance,
    encoder_hidden_states_mask=prompt_embeds_mask,
    encoder_hidden_states=prompt_embeds,
    img_shapes=img_shapes,
    attention_kwargs=self.attention_kwargs,
    return_dict=False,
)[0]

Suggested fix:

noise_pred = self.transformer(
    hidden_states=latent_model_input,
    timesteps=timestep,
    encoder_hidden_states_t5=prompt_embeds_t5,
    encoder_hidden_states_llama3=prompt_embeds_llama3,
    pooled_embeds=pooled_prompt_embeds,
    attention_kwargs=self.attention_kwargs,
    return_dict=False,
)[0]

Issue 3: Padded tokens are not actually masked in attention

Affected code:

if hidden_states_masks is not None:
    key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1)

if not attn.single:
    query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype)
    key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype)
    value_t = attn.to_v_t(encoder_hidden_states)

    query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
    key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
    value_t = value_t.view(batch_size, -1, attn.heads, head_dim)

    num_image_tokens = query_i.shape[1]
    num_text_tokens = query_t.shape[1]

    query = torch.cat([query_i, query_t], dim=1)
    key = torch.cat([key_i, key_t], dim=1)
    value = torch.cat([value_i, value_t], dim=1)
else:
    query = query_i
    key = key_i
    value = value_i

if query.shape[-1] == image_rotary_emb.shape[-3] * 2:
    query, key = apply_rope(query, key, image_rotary_emb)
else:
    query_1, query_2 = query.chunk(2, dim=-1)
    key_1, key_2 = key.chunk(2, dim=-1)
    query_1, key_1 = apply_rope(query_1, key_1, image_rotary_emb)
    query = torch.cat([query_1, query_2], dim=-1)
    key = torch.cat([key_1, key_2], dim=-1)

hidden_states = F.scaled_dot_product_attention(
    query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), dropout_p=0.0, is_causal=False
)

Problem:
For non-square/padded latent sequences, hidden_states_masks only multiplies the image keys by 0. Zeroed keys produce a logit of 0 rather than -inf, so padded positions still receive attention weight after the softmax, and their unmasked values leak into the outputs of valid tokens.

Impact:
Non-square generation can be contaminated by padded latent tokens, and externally supplied padded hidden states are not semantically masked.

Reproduction:

import torch
from diffusers import HiDreamImageTransformer2DModel

torch.manual_seed(0)
m = HiDreamImageTransformer2DModel(patch_size=2, in_channels=1, out_channels=1, num_layers=1, num_single_layers=0,
    attention_head_dim=8, num_attention_heads=1, caption_channels=[3, 2], text_emb_dim=5,
    num_routed_experts=0, axes_dims_rope=(4, 2, 2), max_resolution=(4, 4), llama_layers=(0,)).eval()
hidden = torch.randn(1, 4, 4)
kwargs = dict(timesteps=torch.tensor([1]), encoder_hidden_states_t5=torch.randn(1, 1, 3),
    encoder_hidden_states_llama3=torch.randn(1, 1, 1, 2), pooled_embeds=torch.randn(1, 5),
    img_ids=torch.zeros(1, 4, 3), img_sizes=torch.tensor([[1, 2]]),
    hidden_states_masks=torch.tensor([[1., 1., 0., 0.]]), return_dict=False)
with torch.no_grad():
    out1 = m(hidden_states=hidden.clone(), **kwargs)[0]
    hidden[:, 2:] += 1000
    out2 = m(hidden_states=hidden, **kwargs)[0]
print((out1 - out2).abs().max().item())  # non-zero

Relevant precedent:

joint_hidden_states = dispatch_attention_fn(
    joint_query,
    joint_key,
    joint_value,
    attn_mask=attention_mask,
    dropout_p=0.0,
    is_causal=False,
    backend=self._attention_backend,
    parallel_config=self._parallel_config,
)

Suggested fix:
Route HiDream attention through dispatch_attention_fn with a real boolean attention mask. For double-stream blocks, concatenate the image mask with all-true text masks before dispatch; for single-stream blocks, pass the already-expanded image/text mask.
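
A minimal sketch of that shape, assuming the processor's existing variable names (batch_size, num_text_tokens, query/key/value) and that dispatch_attention_fn accepts a broadcastable boolean attn_mask the way F.scaled_dot_product_attention does; the exact shapes and integration point would need to match the final refactor:

# inside the processor, after query/key/value are assembled (hypothetical sketch)
if hidden_states_masks is not None:
    image_mask = hidden_states_masks.bool()                   # (batch, num_image_tokens)
    if not attn.single:
        text_mask = image_mask.new_ones(batch_size, num_text_tokens)
        keep = torch.cat([image_mask, text_mask], dim=1)       # (batch, seq)
    else:
        keep = image_mask
    # broadcastable boolean mask: True = attend, False = masked out
    attention_mask = keep[:, None, None, :]                    # (batch, 1, 1, seq)
else:
    attention_mask = None

hidden_states = dispatch_attention_fn(
    query,
    key,
    value,
    attn_mask=attention_mask,
    dropout_p=0.0,
    is_causal=False,
    backend=self._attention_backend,
    parallel_config=self._parallel_config,
)

Padded positions still produce outputs of their own (which are discarded downstream), but valid tokens no longer attend to them, which is the property this issue is about.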

Issue 4: HiDream attention bypasses diffusers attention backend dispatch

Affected code:

class HiDreamAttnProcessor:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __call__(
        self,
        attn: HiDreamAttention,
        hidden_states: torch.Tensor,
        hidden_states_masks: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        dtype = hidden_states.dtype
        batch_size = hidden_states.shape[0]

        query_i = attn.q_rms_norm(attn.to_q(hidden_states)).to(dtype=dtype)
        key_i = attn.k_rms_norm(attn.to_k(hidden_states)).to(dtype=dtype)
        value_i = attn.to_v(hidden_states)

        inner_dim = key_i.shape[-1]
        head_dim = inner_dim // attn.heads

        query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
        key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
        value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
        if hidden_states_masks is not None:
            key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1)

        if not attn.single:
            query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype)
            key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype)
            value_t = attn.to_v_t(encoder_hidden_states)

            query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
            key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
            value_t = value_t.view(batch_size, -1, attn.heads, head_dim)

            num_image_tokens = query_i.shape[1]
            num_text_tokens = query_t.shape[1]

            query = torch.cat([query_i, query_t], dim=1)
            key = torch.cat([key_i, key_t], dim=1)
            value = torch.cat([value_i, value_t], dim=1)
        else:
            query = query_i
            key = key_i
            value = value_i

        if query.shape[-1] == image_rotary_emb.shape[-3] * 2:
            query, key = apply_rope(query, key, image_rotary_emb)
        else:
            query_1, query_2 = query.chunk(2, dim=-1)
            key_1, key_2 = key.chunk(2, dim=-1)
            query_1, key_1 = apply_rope(query_1, key_1, image_rotary_emb)
            query = torch.cat([query_1, query_2], dim=-1)
            key = torch.cat([key_1, key_2], dim=-1)

        hidden_states = F.scaled_dot_product_attention(
            query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), dropout_p=0.0, is_causal=False
        )

Problem:
HiDreamAttnProcessor calls F.scaled_dot_product_attention directly and has no _attention_backend / _parallel_config. model.set_attention_backend(...) therefore has no effect.

Impact:
HiDream cannot use diffusers attention backends such as Flash/Sage/Flex through the standard API, and context-parallel backend plumbing is bypassed.

Reproduction:

from diffusers import HiDreamImageTransformer2DModel

m = HiDreamImageTransformer2DModel(patch_size=2, in_channels=4, out_channels=4, num_layers=1, num_single_layers=0,
    attention_head_dim=8, num_attention_heads=1, caption_channels=[32, 16], text_emb_dim=64,
    num_routed_experts=0, axes_dims_rope=(4, 2, 2), max_resolution=(32, 32), llama_layers=(0,))
processors = [mod.processor for mod in m.modules() if mod.__class__.__name__ == "HiDreamAttention"]
print(hasattr(processors[0], "_attention_backend"))  # False
m.set_attention_backend("native")
print(getattr(processors[0], "_attention_backend", None))  # None

Relevant precedent:

class FluxAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )

Suggested fix:
Make HiDreamAttention follow the new custom-attention pattern: inherit AttentionModuleMixin, declare _default_processor_cls / _available_processors, add _attention_backend and _parallel_config to the processor, and replace the direct SDPA call with dispatch_attention_fn(...).
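
A structural sketch of that wiring, with the projection/RoPE/mask construction elided to comments; the import paths and class attributes follow the Flux pattern quoted above and are assumptions to verify against the current module layout:

# sketch only: paths and attributes assumed from the Flux precedent, not verified
import torch.nn as nn

from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_dispatch import dispatch_attention_fn


class HiDreamAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __call__(self, attn, hidden_states, hidden_states_masks=None,
                 encoder_hidden_states=None, image_rotary_emb=None):
        # ... build query/key/value, apply RoPE, and build attention_mask as today ...
        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        # ... split image/text streams and project out as today ...
        return hidden_states


class HiDreamAttention(nn.Module, AttentionModuleMixin):
    _default_processor_cls = HiDreamAttnProcessor
    _available_processors = [HiDreamAttnProcessor]
    # existing __init__ / forward stay as they are

With the mixin in place, model.set_attention_backend(...) can reach the processor, so the reproduction above should print the selected backend instead of None.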

Issue 5: Transformer config defaults are invalid

Affected code:

patch_size: int | None = None,
in_channels: int = 64,
out_channels: int | None = None,
num_layers: int = 16,
num_single_layers: int = 32,
attention_head_dim: int = 128,
num_attention_heads: int = 20,
caption_channels: list[int] = None,
text_emb_dim: int = 2048,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
axes_dims_rope: tuple[int, int] = (32, 32),
max_resolution: tuple[int, int] = (128, 128),
llama_layers: list[int] = None,

Problem:
Several registered config defaults are unusable: patch_size=None, caption_channels=None, llama_layers=None, and axes_dims_rope=(32, 32) for three-axis ids.

Impact:
Default construction and configs missing these fields fail with low-level errors instead of a clear config error, which weakens serialization/backwards-compatibility behavior.

Reproduction:

from diffusers import HiDreamImageTransformer2DModel

try:
    HiDreamImageTransformer2DModel()
except Exception as e:
    print(type(e).__name__, e)
# TypeError unsupported operand type(s) for *: 'int' and 'NoneType'

Relevant precedent:
Flux and Qwen transformer constructors use internally consistent config defaults.

Suggested fix:

if patch_size is None:
    raise ValueError("`patch_size` must be set for HiDreamImageTransformer2DModel.")
if caption_channels is None or len(caption_channels) != 2:
    raise ValueError("`caption_channels` must contain [t5_dim, llama_dim].")
if llama_layers is None:
    raise ValueError("`llama_layers` must be set.")
if len(axes_dims_rope) != 3:
    raise ValueError("`axes_dims_rope` must contain three axes for HiDream image ids.")

Issue 6: MoE inference path is not torch.compile-friendly

Affected code:

@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
    expert_cache = torch.zeros_like(x)
    idxs = flat_expert_indices.argsort()
    tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)

Problem:
moe_infer performs bincount().cpu().numpy().cumsum(0) inside the model's forward pass. This forces a device-to-host synchronization and causes graph breaks.

Impact:
torch.compile(fullgraph=True) fails for the default MoE inference path.

Reproduction:

import torch
from diffusers import HiDreamImageTransformer2DModel

m = HiDreamImageTransformer2DModel(patch_size=2, in_channels=1, out_channels=1, num_layers=1, num_single_layers=0,
    attention_head_dim=8, num_attention_heads=1, caption_channels=[3, 2], text_emb_dim=5,
    num_routed_experts=2, num_activated_experts=1, axes_dims_rope=(4, 2, 2), max_resolution=(4, 4),
    llama_layers=(0,)).eval()
compiled = torch.compile(m, fullgraph=True)
try:
    compiled(hidden_states=torch.randn(1, 1, 4, 4), timesteps=torch.tensor([1]),
        encoder_hidden_states_t5=torch.randn(1, 1, 3), encoder_hidden_states_llama3=torch.randn(1, 1, 1, 2),
        pooled_embeds=torch.randn(1, 5), return_dict=False)
except Exception as e:
    print(type(e).__name__, str(e).splitlines()[0])

Relevant precedent:

    offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
    gate_up = F.grouped_mm(x, self.gate_up_proj, offs=offsets)
    gate, up = gate_up.chunk(2, dim=-1)
    out = F.grouped_mm(F.silu(gate) * up, self.down_proj, offs=offsets)
    return out.type_as(x)

def forward(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor:
    if self.use_grouped_mm:
        return self._run_experts_grouped_mm(x, num_tokens_per_expert)
    return self._run_experts_for_loop(x, num_tokens_per_expert)

Suggested fix:
Existing duplicate: #11477. Continue that PR or replace this routing with a torch-only grouped implementation that avoids NumPy and host-side dynamic routing in the compiled path.
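
For illustration only, a torch-only dense fallback (not the approach from the duplicate issue) that avoids NumPy and host-side routing by running every expert on every token and mixing by the routing weights; self.experts and the flat index/weight layout are assumptions about the existing MoE module:

import torch

def moe_infer_dense(self, x, flat_expert_indices, flat_expert_weights):
    # x: (num_tokens, dim); flat_* hold num_tokens * num_activated_experts entries
    num_tokens = x.shape[0]
    k = flat_expert_indices.shape[0] // num_tokens
    expert_indices = flat_expert_indices.view(num_tokens, k)
    expert_weights = flat_expert_weights.view(num_tokens, k).to(x.dtype)

    out = torch.zeros_like(x)
    for expert_id, expert in enumerate(self.experts):
        # run the expert densely: no data-dependent shapes, so it stays in-graph
        expert_out = expert(x)
        # per-token weight for this expert (0 where the token was not routed to it)
        weight = torch.where(expert_indices == expert_id, expert_weights, torch.zeros_like(expert_weights))
        out = out + expert_out * weight.sum(dim=-1, keepdim=True)
    return out

This trades extra FLOPs for graph-friendliness; a grouped implementation along the lines of the torch.cumsum/grouped_mm precedent above keeps the sparse cost but computes its offsets on-device.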

Issue 7: No slow HiDream pipeline/model coverage exists

Affected code:

class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = HiDreamImagePipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = PipelineTesterMixin.required_optional_params
    test_xformers_attention = False
    test_layerwise_casting = True
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = HiDreamImageTransformer2DModel(
            patch_size=2,
            in_channels=4,
            out_channels=4,
            num_layers=1,
            num_single_layers=1,
            attention_head_dim=8,
            num_attention_heads=4,
            caption_channels=[32, 16],
            text_emb_dim=64,
            num_routed_experts=4,
            num_activated_experts=2,
            axes_dims_rope=(4, 2, 2),
            max_resolution=(32, 32),
            llama_layers=(0, 1),
        ).eval()

        torch.manual_seed(0)
        vae = AutoencoderKL(scaling_factor=0.3611, shift_factor=0.1159)

        clip_text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
            max_position_embeddings=128,
        )

        torch.manual_seed(0)
        text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)
        torch.manual_seed(0)
        text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)

        torch.manual_seed(0)
        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
        text_encoder_3 = T5EncoderModel(config).eval()

        torch.manual_seed(0)
        text_encoder_4 = LlamaForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
        text_encoder_4.generation_config.pad_token_id = 1

        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer_4 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

        scheduler = FlowMatchEulerDiscreteScheduler()

        components = {
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "text_encoder_2": text_encoder_2,
            "tokenizer_2": tokenizer_2,
            "text_encoder_3": text_encoder_3,
            "tokenizer_3": tokenizer_3,
            "text_encoder_4": text_encoder_4,
            "tokenizer_4": tokenizer_4,
            "transformer": transformer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,

Problem:
HiDream has fast pipeline and model tests, plus GGUF coverage, but no @slow test for a published HiDream pipeline/model checkpoint.

Impact:
Large-checkpoint loading, real component configs, offload behavior, real output slices, and docs example compatibility are not covered.

Reproduction:

from pathlib import Path

paths = [Path("tests/pipelines/hidream_image"), Path("tests/models/transformers/test_models_transformer_hidream.py")]
hits = []
for path in paths:
    files = [path] if path.is_file() else list(path.rglob("*.py"))
    hits += [(str(f), i) for f in files for i, line in enumerate(f.read_text().splitlines(), 1) if "@slow" in line or "slow(" in line]
print(hits)  # []

Relevant precedent:

@slow
@require_big_accelerator
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
    pipeline_class = StableDiffusion3Pipeline
    repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def get_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)

Suggested fix:
Add a @slow HiDream test class that loads a real or hf-internal-testing published HiDream pipeline checkpoint through from_pretrained, runs a deterministic minimal inference, and asserts shape plus an output slice.
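
A hedged skeleton of such a test, modeled on the SD3 precedent above; the repo id is a placeholder, the expected-slice assertion is left as a comment until reference values exist, and the 1024x1024 default size is an assumption about the published checkpoint:

import gc
import unittest

import torch

from diffusers import HiDreamImagePipeline
from diffusers.utils.testing_utils import backend_empty_cache, require_big_accelerator, slow, torch_device


@slow
@require_big_accelerator
class HiDreamImagePipelineSlowTests(unittest.TestCase):
    repo_id = "<published-hidream-checkpoint>"  # placeholder, to be filled with the real repo

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_hidream_image_inference(self):
        pipe = HiDreamImagePipeline.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
        pipe.enable_model_cpu_offload(device=torch_device)

        image = pipe(
            "A photo of a cat",
            height=1024,
            width=1024,
            num_inference_steps=2,
            guidance_scale=5.0,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images[0]

        assert image.shape == (1024, 1024, 3)
        # compare image[-3:, -3:, -1].flatten() against a recorded expected slice
        # once reference values have been generated on the target hardware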

Local verification note: the targeted snippets above were run with the repository .venv. Direct pytest collection for the HiDream model and pipeline test files failed in this environment because the installed torch build lacks torch._C._distributed_c10d, which is imported by shared test mixins.
