Skip to content

Commit 2bb39f4

Browse files
committed
Fix neg embed and divide / bug; Reuse pad zero tensor; Turn cat -> repeat; Add hint for attn processor.
1 parent 5b4c907 commit 2bb39f4

File tree

2 files changed

+41
-33
lines changed

2 files changed

+41
-33
lines changed

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ class ZSingleStreamAttnProcessor:
9090
_attention_backend = None
9191
_parallel_config = None
9292

93+
def __init__(self):
94+
if not hasattr(F, "scaled_dot_product_attention"):
95+
raise ImportError(
96+
"ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
97+
)
98+
9399
def __call__(
94100
self,
95101
attn: Attention,
@@ -493,7 +499,6 @@ def patchify_and_embed(
493499

494500
image_ori_len = len(image)
495501
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
496-
# padded_pos_ids
497502

498503
image_ori_pos_ids = self.create_coordinate_grid(
499504
size=(F_tokens, H_tokens, W_tokens),
@@ -574,11 +579,7 @@ def forward(
574579
x = list(x.split(x_item_seqlens, dim=0))
575580
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
576581

577-
pad_tensor = torch.zeros(
578-
(1, self.dim),
579-
dtype=x[0].dtype,
580-
device=device,
581-
)
582+
pad_tensor = torch.zeros((1, self.dim), dtype=x[0].dtype, device=device)
582583
freqs_pad_tensor = torch.zeros(
583584
(1, self.dim // self.n_heads // 2),
584585
dtype=x_freqs_cis[0].dtype,
@@ -613,22 +614,19 @@ def forward(
613614
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
614615
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
615616

616-
pad_tensor = torch.zeros(
617-
(1, self.dim),
618-
dtype=cap_feats[0].dtype,
619-
device=device,
620-
)
621-
freqs_pad_tensor = torch.zeros(
622-
(1, self.dim // self.n_heads // 2),
623-
dtype=cap_freqs_cis[0].dtype,
624-
device=device,
617+
# Reuse padding tensors (convert dtype if needed)
618+
cap_pad_tensor = pad_tensor.to(cap_feats[0].dtype) if pad_tensor.dtype != cap_feats[0].dtype else pad_tensor
619+
cap_freqs_pad_tensor = (
620+
freqs_pad_tensor.to(cap_freqs_cis[0].dtype)
621+
if freqs_pad_tensor.dtype != cap_freqs_cis[0].dtype
622+
else freqs_pad_tensor
625623
)
626624
cap_attn_mask = torch.ones((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
627625
for i, (item, freqs_item) in enumerate(zip(cap_feats, cap_freqs_cis)):
628626
seq_len = cap_item_seqlens[i]
629627
pad_len = cap_max_item_seqlen - seq_len
630-
cap_feats[i] = torch.cat([item, pad_tensor.repeat(pad_len, 1)])
631-
cap_freqs_cis[i] = torch.cat([freqs_item, freqs_pad_tensor.repeat(pad_len, 1)])
628+
cap_feats[i] = torch.cat([item, cap_pad_tensor.repeat(pad_len, 1)])
629+
cap_freqs_cis[i] = torch.cat([freqs_item, cap_freqs_pad_tensor.repeat(pad_len, 1)])
632630
cap_attn_mask[i, seq_len:] = 0
633631
cap_feats = torch.stack(cap_feats)
634632
cap_freqs_cis = torch.stack(cap_freqs_cis)
@@ -652,22 +650,18 @@ def forward(
652650
assert unified_item_seqlens == [len(_) for _ in unified]
653651
unified_max_item_seqlen = max(unified_item_seqlens)
654652

655-
pad_tensor = torch.zeros(
656-
(1, self.dim),
657-
dtype=unified[0].dtype,
658-
device=device,
659-
)
660-
freqs_pad_tensor = torch.zeros(
661-
(1, self.dim // self.n_heads // 2),
662-
dtype=unified_freqs_cis[0].dtype,
663-
device=device,
653+
unified_pad_tensor = pad_tensor.to(unified[0].dtype) if pad_tensor.dtype != unified[0].dtype else pad_tensor
654+
unified_freqs_pad_tensor = (
655+
freqs_pad_tensor.to(unified_freqs_cis[0].dtype)
656+
if freqs_pad_tensor.dtype != unified_freqs_cis[0].dtype
657+
else freqs_pad_tensor
664658
)
665659
unified_attn_mask = torch.ones((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
666660
for i, (item, freqs_item) in enumerate(zip(unified, unified_freqs_cis)):
667661
seq_len = unified_item_seqlens[i]
668662
pad_len = unified_max_item_seqlen - seq_len
669-
unified[i] = torch.cat([item, pad_tensor.repeat(pad_len, 1)])
670-
unified_freqs_cis[i] = torch.cat([freqs_item, freqs_pad_tensor.repeat(pad_len, 1)])
663+
unified[i] = torch.cat([item, unified_pad_tensor.repeat(pad_len, 1)])
664+
unified_freqs_cis[i] = torch.cat([freqs_item, unified_freqs_pad_tensor.repeat(pad_len, 1)])
671665
unified_attn_mask[i, seq_len:] = 0
672666
unified = torch.stack(unified)
673667
unified_freqs_cis = torch.stack(unified_freqs_cis)

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ def encode_prompt(
193193
prompt_embeds=negative_prompt_embeds,
194194
max_sequence_length=max_sequence_length,
195195
)
196+
else:
197+
negative_prompt_embeds = []
196198
return prompt_embeds, negative_prompt_embeds
197199

198200
def _encode_prompt(
@@ -398,6 +400,18 @@ def __call__(
398400
height = height or 1024
399401
width = width or 1024
400402

403+
vae_scale = self.vae_scale_factor * 2
404+
if height % vae_scale != 0:
405+
raise ValueError(
406+
f"Height must be divisible by {vae_scale} (got {height}). "
407+
f"Please adjust the height to a multiple of {vae_scale}."
408+
)
409+
if width % vae_scale != 0:
410+
raise ValueError(
411+
f"Width must be divisible by {vae_scale} (got {width}). "
412+
f"Please adjust the width to a multiple of {vae_scale}."
413+
)
414+
401415
assert self.dtype == torch.bfloat16
402416
dtype = self.dtype
403417
device = self._execution_device
@@ -447,7 +461,7 @@ def __call__(
447461
generator,
448462
latents,
449463
)
450-
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] / 2)
464+
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
451465

452466
# 5. Prepare timesteps
453467
mu = calculate_shift(
@@ -495,12 +509,12 @@ def __call__(
495509
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
496510

497511
if apply_cfg:
498-
# Prepare inputs for CFG
499-
latent_model_input = torch.cat([latents.to(dtype)] * 2)
512+
latents_typed = latents if latents.dtype == dtype else latents.to(dtype)
513+
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
500514
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
501-
timestep_model_input = torch.cat([timestep] * 2)
515+
timestep_model_input = timestep.repeat(2)
502516
else:
503-
latent_model_input = latents.to(dtype)
517+
latent_model_input = latents if latents.dtype == dtype else latents.to(dtype)
504518
prompt_embeds_model_input = prompt_embeds
505519
timestep_model_input = timestep
506520

0 commit comments

Comments
 (0)