hidream_image model/pipeline review
Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423
Review performed against the repository review rules.
Duplicate search performed with gh for HiDream, class/file names, and specific failure modes. Existing duplicate found only for the torch.compile item: #11477 and #11430.
Issue 1: height and width are silently replaced with default-area dimensions
Affected code (diffusers/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py, lines 873 to 880 at 0f1abc4):
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor

division = self.vae_scale_factor * 2
S_max = (self.default_sample_size * self.vae_scale_factor) ** 2
scale = S_max / (width * height)
scale = math.sqrt(scale)
width, height = int(width * scale // division * division), int(height * scale // division * division)
Problem:
The pipeline rescales any requested height/width to the default sample area. A user asking for 64x64 gets default-area latents/images instead of 64x64.
Impact:
User-visible output dimensions do not match the public API or docstring, and precomputed latents for the requested size are rejected.
Reproduction:
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, HiDreamImagePipeline, HiDreamImageTransformer2DModel
m = HiDreamImageTransformer2DModel(patch_size=2, in_channels=4, out_channels=4, num_layers=0, num_single_layers=0,
attention_head_dim=8, num_attention_heads=1, caption_channels=[32, 16], text_emb_dim=64,
num_routed_experts=0, axes_dims_rope=(4, 2, 2), max_resolution=(32, 32), llama_layers=())
p = HiDreamImagePipeline(FlowMatchEulerDiscreteScheduler(), None, None, None, None, None, None, None, None, None, m)
out = p(height=64, width=64, num_inference_steps=0, guidance_scale=1.0, output_type="latent",
pooled_prompt_embeds=torch.randn(1, 64), prompt_embeds_t5=torch.randn(1, 1, 32),
prompt_embeds_llama3=torch.randn(0, 1, 1, 16)).images
print(out.shape) # torch.Size([1, 4, 128, 128])
Relevant precedent (diffusers/src/diffusers/pipelines/flux/pipeline_flux.py, lines 854 to 865 at 0f1abc4):
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height,
    width,
    prompt_embeds.dtype,
    device,
    generator,
    latents,
)
Suggested fix:
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
division = self.vae_scale_factor * self.transformer.config.patch_size
height = int(height) // division * division
width = int(width) // division * division
image_seq_len = (height // self.vae_scale_factor // self.transformer.config.patch_size) * (
width // self.vae_scale_factor // self.transformer.config.patch_size
)
if image_seq_len > self.transformer.max_seq:
raise ValueError(f"Requested image has {image_seq_len} latent tokens, but this model supports {self.transformer.max_seq}.")
Issue 2: attention_kwargs is accepted but never forwarded to the transformer
Affected code (diffusers/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py, lines 996 to 1003 at 0f1abc4):
noise_pred = self.transformer(
    hidden_states=latent_model_input,
    timesteps=timestep,
    encoder_hidden_states_t5=prompt_embeds_t5,
    encoder_hidden_states_llama3=prompt_embeds_llama3,
    pooled_embeds=pooled_prompt_embeds,
    return_dict=False,
)[0]
Problem:
__call__ stores attention_kwargs, and the transformer forward is decorated with @apply_lora_scale("attention_kwargs"), but the denoising call never passes the kwargs.
Impact:
Runtime LoRA scaling via attention_kwargs={"scale": ...} is ignored for HiDream.
Reproduction:
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, HiDreamImagePipeline, HiDreamImageTransformer2DModel
m = HiDreamImageTransformer2DModel(patch_size=2, in_channels=4, out_channels=4, num_layers=1, num_single_layers=0,
attention_head_dim=8, num_attention_heads=1, caption_channels=[32, 16], text_emb_dim=64,
num_routed_experts=0, axes_dims_rope=(4, 2, 2), max_resolution=(32, 32), llama_layers=(0,)).eval()
p = HiDreamImagePipeline(FlowMatchEulerDiscreteScheduler(), None, None, None, None, None, None, None, None, None, m)
seen = []
h = p.transformer.register_forward_pre_hook(lambda module, args, kwargs: seen.append(kwargs.get("attention_kwargs")), with_kwargs=True)
p(height=128, width=128, num_inference_steps=1, guidance_scale=1.0, output_type="latent",
attention_kwargs={"scale": 0.25}, pooled_prompt_embeds=torch.randn(1, 64),
prompt_embeds_t5=torch.randn(1, 4, 32), prompt_embeds_llama3=torch.randn(1, 1, 4, 16))
h.remove()
print(seen) # [None]
Relevant precedent (diffusers/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py, lines 695 to 703 at 0f1abc4):
noise_pred = self.transformer(
    hidden_states=latents,
    timestep=timestep / 1000,
    guidance=guidance,
    encoder_hidden_states_mask=prompt_embeds_mask,
    encoder_hidden_states=prompt_embeds,
    img_shapes=img_shapes,
    attention_kwargs=self.attention_kwargs,
    return_dict=False,
Suggested fix:
noise_pred = self.transformer(
hidden_states=latent_model_input,
timesteps=timestep,
encoder_hidden_states_t5=prompt_embeds_t5,
encoder_hidden_states_llama3=prompt_embeds_llama3,
pooled_embeds=pooled_prompt_embeds,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
Issue 3: Padded tokens are not actually masked in attention
Affected code (diffusers/src/diffusers/models/transformers/transformer_hidream_image.py, lines 227 to 260 at 0f1abc4):
if hidden_states_masks is not None:
    key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1)

if not attn.single:
    query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype)
    key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype)
    value_t = attn.to_v_t(encoder_hidden_states)

    query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
    key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
    value_t = value_t.view(batch_size, -1, attn.heads, head_dim)

    num_image_tokens = query_i.shape[1]
    num_text_tokens = query_t.shape[1]
    query = torch.cat([query_i, query_t], dim=1)
    key = torch.cat([key_i, key_t], dim=1)
    value = torch.cat([value_i, value_t], dim=1)
else:
    query = query_i
    key = key_i
    value = value_i

if query.shape[-1] == image_rotary_emb.shape[-3] * 2:
    query, key = apply_rope(query, key, image_rotary_emb)
else:
    query_1, query_2 = query.chunk(2, dim=-1)
    key_1, key_2 = key.chunk(2, dim=-1)
    query_1, key_1 = apply_rope(query_1, key_1, image_rotary_emb)
    query = torch.cat([query_1, query_2], dim=-1)
    key = torch.cat([key_1, key_2], dim=-1)

hidden_states = F.scaled_dot_product_attention(
    query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), dropout_p=0.0, is_causal=False
Problem:
For non-square/padded latent sequences, hidden_states_masks only multiplies image keys by 0. It does not mask attention logits and does not mask values, so masked tokens still change valid outputs.
Impact:
Non-square generation can be contaminated by padded latent tokens, and externally supplied padded hidden states are not semantically masked.
Reproduction:
import torch
from diffusers import HiDreamImageTransformer2DModel
torch.manual_seed(0)
m = HiDreamImageTransformer2DModel(patch_size=2, in_channels=1, out_channels=1, num_layers=1, num_single_layers=0,
attention_head_dim=8, num_attention_heads=1, caption_channels=[3, 2], text_emb_dim=5,
num_routed_experts=0, axes_dims_rope=(4, 2, 2), max_resolution=(4, 4), llama_layers=(0,)).eval()
hidden = torch.randn(1, 4, 4)
kwargs = dict(timesteps=torch.tensor([1]), encoder_hidden_states_t5=torch.randn(1, 1, 3),
encoder_hidden_states_llama3=torch.randn(1, 1, 1, 2), pooled_embeds=torch.randn(1, 5),
img_ids=torch.zeros(1, 4, 3), img_sizes=torch.tensor([[1, 2]]),
hidden_states_masks=torch.tensor([[1., 1., 0., 0.]]), return_dict=False)
with torch.no_grad():
out1 = m(hidden_states=hidden.clone(), **kwargs)[0]
hidden[:, 2:] += 1000
out2 = m(hidden_states=hidden, **kwargs)[0]
print((out1 - out2).abs().max().item()) # non-zero
Relevant precedent (diffusers/src/diffusers/models/transformers/transformer_qwenimage.py, lines 562 to 570 at 0f1abc4):
joint_hidden_states = dispatch_attention_fn(
    joint_query,
    joint_key,
    joint_value,
    attn_mask=attention_mask,
    dropout_p=0.0,
    is_causal=False,
    backend=self._attention_backend,
    parallel_config=self._parallel_config,
Suggested fix:
Route HiDream attention through dispatch_attention_fn with a real boolean attention mask. For double-stream blocks, concatenate the image mask with all-true text masks before dispatch; for single-stream blocks, pass the already-expanded image/text mask.
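A minimal sketch of the mask plumbing, assuming query/key/value in the (batch, seq, heads, head_dim) layout used by the Flux processor and a per-image-token hidden_states_masks of shape (batch, num_image_tokens); the helper name build_joint_attention_mask is hypothetical:

import torch

def build_joint_attention_mask(hidden_states_masks, num_image_tokens, num_text_tokens):
    # Hypothetical helper: expand the per-image-token keep mask into a boolean
    # attention mask over the concatenated [image, text] key sequence.
    batch_size = hidden_states_masks.shape[0]
    image_mask = hidden_states_masks[:, :num_image_tokens].bool()
    text_mask = torch.ones(batch_size, num_text_tokens, dtype=torch.bool, device=hidden_states_masks.device)
    joint_mask = torch.cat([image_mask, text_mask], dim=1)
    # Broadcastable (batch, 1, 1, kv_len); True = attend, False = ignore.
    return joint_mask[:, None, None, :]

The processor would pass this as attn_mask to dispatch_attention_fn (or, short term, to F.scaled_dot_product_attention) and additionally zero the masked output rows so padded image tokens do not leak back into valid positions.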
Issue 4: HiDream attention bypasses diffusers attention backend dispatch
Affected code (diffusers/src/diffusers/models/transformers/transformer_hidream_image.py, lines 201 to 260 at 0f1abc4):
class HiDreamAttnProcessor:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __call__(
        self,
        attn: HiDreamAttention,
        hidden_states: torch.Tensor,
        hidden_states_masks: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        dtype = hidden_states.dtype
        batch_size = hidden_states.shape[0]

        query_i = attn.q_rms_norm(attn.to_q(hidden_states)).to(dtype=dtype)
        key_i = attn.k_rms_norm(attn.to_k(hidden_states)).to(dtype=dtype)
        value_i = attn.to_v(hidden_states)

        inner_dim = key_i.shape[-1]
        head_dim = inner_dim // attn.heads

        query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
        key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
        value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
        if hidden_states_masks is not None:
            key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1)

        if not attn.single:
            query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype)
            key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype)
            value_t = attn.to_v_t(encoder_hidden_states)

            query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
            key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
            value_t = value_t.view(batch_size, -1, attn.heads, head_dim)

            num_image_tokens = query_i.shape[1]
            num_text_tokens = query_t.shape[1]
            query = torch.cat([query_i, query_t], dim=1)
            key = torch.cat([key_i, key_t], dim=1)
            value = torch.cat([value_i, value_t], dim=1)
        else:
            query = query_i
            key = key_i
            value = value_i

        if query.shape[-1] == image_rotary_emb.shape[-3] * 2:
            query, key = apply_rope(query, key, image_rotary_emb)
        else:
            query_1, query_2 = query.chunk(2, dim=-1)
            key_1, key_2 = key.chunk(2, dim=-1)
            query_1, key_1 = apply_rope(query_1, key_1, image_rotary_emb)
            query = torch.cat([query_1, query_2], dim=-1)
            key = torch.cat([key_1, key_2], dim=-1)

        hidden_states = F.scaled_dot_product_attention(
            query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), dropout_p=0.0, is_causal=False
Problem:
HiDreamAttnProcessor calls F.scaled_dot_product_attention directly and has no _attention_backend / _parallel_config. model.set_attention_backend(...) therefore has no effect.
Impact:
HiDream cannot use diffusers attention backends such as Flash/Sage/Flex through the standard API, and context-parallel backend plumbing is bypassed.
Reproduction:
from diffusers import HiDreamImageTransformer2DModel
m = HiDreamImageTransformer2DModel(patch_size=2, in_channels=4, out_channels=4, num_layers=1, num_single_layers=0,
attention_head_dim=8, num_attention_heads=1, caption_channels=[32, 16], text_emb_dim=64,
num_routed_experts=0, axes_dims_rope=(4, 2, 2), max_resolution=(32, 32), llama_layers=(0,))
processors = [mod.processor for mod in m.modules() if mod.__class__.__name__ == "HiDreamAttention"]
print(hasattr(processors[0], "_attention_backend")) # False
m.set_attention_backend("native")
print(getattr(processors[0], "_attention_backend", None)) # None
Relevant precedent (diffusers/src/diffusers/models/transformers/transformer_flux.py, lines 75 to 125 at 0f1abc4):
class FluxAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
Suggested fix:
Make HiDreamAttention follow the new custom-attention pattern: inherit AttentionModuleMixin, declare _default_processor_cls / _available_processors, add _attention_backend and _parallel_config to the processor, and replace the direct SDPA call with dispatch_attention_fn(...).
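A sketch of the processor side of that change, assuming dispatch_attention_fn is importable from diffusers.models.attention_dispatch (an internal module, so the import path is an assumption); only the image-stream projections are shown, and the text-stream and RoPE handling would stay as in the current processor:

import torch
from diffusers.models.attention_dispatch import dispatch_attention_fn  # assumed internal import path


class HiDreamAttnProcessorSketch:
    # Picked up by AttentionModuleMixin.set_attention_backend and the context-parallel plumbing.
    _attention_backend = None
    _parallel_config = None

    def __call__(self, attn, hidden_states, hidden_states_masks=None, encoder_hidden_states=None,
                 image_rotary_emb=None, attention_mask=None):
        batch_size = hidden_states.shape[0]

        # Same image-stream projections as the current processor (text stream and RoPE elided for brevity).
        query = attn.q_rms_norm(attn.to_q(hidden_states)).to(hidden_states.dtype)
        key = attn.k_rms_norm(attn.to_k(hidden_states)).to(hidden_states.dtype)
        value = attn.to_v(hidden_states)
        head_dim = key.shape[-1] // attn.heads
        query, key, value = (t.view(batch_size, -1, attn.heads, head_dim) for t in (query, key, value))

        # The behavioural change: route through the backend-aware dispatcher instead of
        # calling F.scaled_dot_product_attention directly.
        out = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        return out.flatten(2, 3).to(hidden_states.dtype)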
Issue 5: Transformer config defaults are invalid
Affected code (diffusers/src/diffusers/models/transformers/transformer_hidream_image.py, lines 612 to 625 at 0f1abc4):
patch_size: int | None = None,
in_channels: int = 64,
out_channels: int | None = None,
num_layers: int = 16,
num_single_layers: int = 32,
attention_head_dim: int = 128,
num_attention_heads: int = 20,
caption_channels: list[int] = None,
text_emb_dim: int = 2048,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
axes_dims_rope: tuple[int, int] = (32, 32),
max_resolution: tuple[int, int] = (128, 128),
llama_layers: list[int] = None,
Problem:
Several registered config defaults are unusable: patch_size=None, caption_channels=None, llama_layers=None, and axes_dims_rope=(32, 32) for three-axis ids.
Impact:
Default construction and configs missing these fields fail with low-level errors instead of a clear config error, which weakens serialization/backwards-compatibility behavior.
Reproduction:
from diffusers import HiDreamImageTransformer2DModel
try:
HiDreamImageTransformer2DModel()
except Exception as e:
print(type(e).__name__, e)
# TypeError unsupported operand type(s) for *: 'int' and 'NoneType'
Relevant precedent:
Flux and Qwen transformer constructors use internally consistent config defaults.
Suggested fix:
if patch_size is None:
raise ValueError("`patch_size` must be set for HiDreamImageTransformer2DModel.")
if caption_channels is None or len(caption_channels) != 2:
raise ValueError("`caption_channels` must contain [t5_dim, llama_dim].")
if llama_layers is None:
raise ValueError("`llama_layers` must be set.")
if len(axes_dims_rope) != 3:
raise ValueError("`axes_dims_rope` must contain three axes for HiDream image ids.")
Issue 6: MoE inference path is not torch.compile-friendly
Affected code (diffusers/src/diffusers/models/transformers/transformer_hidream_image.py, lines 388 to 392 at 0f1abc4):
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
    expert_cache = torch.zeros_like(x)
    idxs = flat_expert_indices.argsort()
    tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
Problem:
moe_infer performs bincount().cpu().numpy().cumsum(0) inside model forward. This creates host synchronization and graph breaks.
Impact:
torch.compile(fullgraph=True) fails for the default MoE inference path.
Reproduction:
import torch
from diffusers import HiDreamImageTransformer2DModel
m = HiDreamImageTransformer2DModel(patch_size=2, in_channels=1, out_channels=1, num_layers=1, num_single_layers=0,
attention_head_dim=8, num_attention_heads=1, caption_channels=[3, 2], text_emb_dim=5,
num_routed_experts=2, num_activated_experts=1, axes_dims_rope=(4, 2, 2), max_resolution=(4, 4),
llama_layers=(0,)).eval()
compiled = torch.compile(m, fullgraph=True)
try:
compiled(hidden_states=torch.randn(1, 1, 4, 4), timesteps=torch.tensor([1]),
encoder_hidden_states_t5=torch.randn(1, 1, 3), encoder_hidden_states_llama3=torch.randn(1, 1, 1, 2),
pooled_embeds=torch.randn(1, 5), return_dict=False)
except Exception as e:
print(type(e).__name__, str(e).splitlines()[0])
Relevant precedent (diffusers/src/diffusers/models/transformers/transformer_nucleusmoe_image.py, lines 487 to 498 at 0f1abc4):
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)

gate_up = F.grouped_mm(x, self.gate_up_proj, offs=offsets)
gate, up = gate_up.chunk(2, dim=-1)
out = F.grouped_mm(F.silu(gate) * up, self.down_proj, offs=offsets)

return out.type_as(x)

def forward(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor:
    if self.use_grouped_mm:
        return self._run_experts_grouped_mm(x, num_tokens_per_expert)
    return self._run_experts_for_loop(x, num_tokens_per_expert)
Suggested fix:
Existing duplicate: #11477. Continue that PR or replace this routing with a torch-only grouped implementation that avoids NumPy and host-side dynamic routing in the compiled path.
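For reference, a compile-friendly routing sketch that stays entirely in torch by evaluating each expert densely and masking, trading extra compute for static shapes and no host synchronization; the moe_infer_dense name and argument layout are illustrative, not the existing method signature:

import torch

def moe_infer_dense(x, flat_expert_indices, flat_expert_weights, experts, num_activated_experts):
    # x: (num_tokens, dim); flat_expert_indices / flat_expert_weights: (num_tokens * num_activated_experts,)
    num_tokens = x.shape[0]
    token_idx = torch.arange(num_tokens, device=x.device).repeat_interleave(num_activated_experts)
    out = torch.zeros_like(x)
    for expert_id, expert in enumerate(experts):
        # Per-slot routing weight for this expert; zero where the slot routed elsewhere.
        slot_weights = torch.where(
            flat_expert_indices == expert_id, flat_expert_weights, torch.zeros_like(flat_expert_weights)
        )
        token_weights = torch.zeros(num_tokens, device=x.device, dtype=x.dtype)
        token_weights.index_add_(0, token_idx, slot_weights.to(x.dtype))
        # Run the expert on every token (static shapes), then mask by the routing weight.
        out = out + expert(x) * token_weights.unsqueeze(-1)
    return out

A grouped-GEMM path like the precedent above recovers the redundant compute, but it needs the sort/offset bookkeeping to stay in torch as well.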
Issue 7: No slow HiDream pipeline/model coverage exists
Affected code (diffusers/tests/pipelines/hidream_image/test_pipeline_hidream.py, lines 45 to 135 at 0f1abc4):
class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = HiDreamImagePipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = PipelineTesterMixin.required_optional_params
    test_xformers_attention = False
    test_layerwise_casting = True
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = HiDreamImageTransformer2DModel(
            patch_size=2,
            in_channels=4,
            out_channels=4,
            num_layers=1,
            num_single_layers=1,
            attention_head_dim=8,
            num_attention_heads=4,
            caption_channels=[32, 16],
            text_emb_dim=64,
            num_routed_experts=4,
            num_activated_experts=2,
            axes_dims_rope=(4, 2, 2),
            max_resolution=(32, 32),
            llama_layers=(0, 1),
        ).eval()
        torch.manual_seed(0)
        vae = AutoencoderKL(scaling_factor=0.3611, shift_factor=0.1159)
        clip_text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
            max_position_embeddings=128,
        )

        torch.manual_seed(0)
        text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)

        torch.manual_seed(0)
        text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)

        torch.manual_seed(0)
        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
        text_encoder_3 = T5EncoderModel(config).eval()

        torch.manual_seed(0)
        text_encoder_4 = LlamaForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
        text_encoder_4.generation_config.pad_token_id = 1
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer_4 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

        scheduler = FlowMatchEulerDiscreteScheduler()

        components = {
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "text_encoder_2": text_encoder_2,
            "tokenizer_2": tokenizer_2,
            "text_encoder_3": text_encoder_3,
            "tokenizer_3": tokenizer_3,
            "text_encoder_4": text_encoder_4,
            "tokenizer_4": tokenizer_4,
            "transformer": transformer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
Problem:
HiDream has fast pipeline and model tests, plus GGUF coverage, but no @slow test for a published HiDream pipeline/model checkpoint.
Impact:
Large-checkpoint loading, real component configs, offload behavior, real output slices, and docs example compatibility are not covered.
Reproduction:
from pathlib import Path
paths = [Path("tests/pipelines/hidream_image"), Path("tests/models/transformers/test_models_transformer_hidream.py")]
hits = []
for path in paths:
files = [path] if path.is_file() else list(path.rglob("*.py"))
hits += [(str(f), i) for f in files for i, line in enumerate(f.read_text().splitlines(), 1) if "@slow" in line or "slow(" in line]
print(hits) # []
Relevant precedent (diffusers/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py, lines 227 to 245 at 0f1abc4):
@slow
@require_big_accelerator
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
    pipeline_class = StableDiffusion3Pipeline
    repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def get_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
Suggested fix:
Add a @slow HiDream test class that loads a real or hf-internal-testing published HiDream pipeline checkpoint through from_pretrained, runs a deterministic minimal inference, and asserts shape plus an output slice.
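A possible shape for that class, modeled on the SD3 precedent; the repo id, dtype, inference settings, and expected_slice values below are placeholders, and depending on the published checkpoint layout the Llama tokenizer_4/text_encoder_4 may need to be loaded separately and passed to from_pretrained:

import gc
import unittest

import numpy as np
import torch

from diffusers import HiDreamImagePipeline
from diffusers.utils.testing_utils import backend_empty_cache, require_big_accelerator, slow, torch_device


@slow
@require_big_accelerator
class HiDreamImagePipelineSlowTests(unittest.TestCase):
    pipeline_class = HiDreamImagePipeline
    repo_id = "HiDream-ai/HiDream-I1-Dev"  # placeholder checkpoint id

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_hidream_inference(self):
        # If the published checkpoint does not bundle the Llama text encoder,
        # tokenizer_4/text_encoder_4 would be loaded separately and passed here.
        pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
        pipe.enable_model_cpu_offload(device=torch_device)

        image = pipe(
            prompt="a photo of a cat",
            height=1024,
            width=1024,
            num_inference_steps=2,
            guidance_scale=5.0,
            generator=torch.Generator(device="cpu").manual_seed(0),
            output_type="np",
        ).images[0]

        assert image.shape == (1024, 1024, 3)
        expected_slice = np.array([0.0, 0.0, 0.0])  # placeholder, filled in from a reference run
        assert np.allclose(image[0, :3, -1], expected_slice, atol=1e-3)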
Local verification note: targeted snippets were run with .venv. Direct pytest collection for the HiDream model and pipeline test files failed in this environment because the installed torch build lacks torch._C._distributed_c10d, imported by shared test mixins.