From 8082e855c291961dbad82295dc2387d308cf895c Mon Sep 17 00:00:00 2001 From: Mingjia Li Date: Sat, 29 Nov 2025 02:34:25 +0800 Subject: [PATCH 1/4] Refactor image padding logic to prevent zero tensor in transformer_z_image.py --- .../transformers/transformer_z_image.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 4f2d56ea8f4d..2c3faf77c95e 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -515,16 +515,19 @@ def patchify_and_embed( start=(cap_ori_len + cap_padding_len + 1, 0, 0), device=device, ).flatten(0, 2) - image_padding_pos_ids = ( - self.create_coordinate_grid( - size=(1, 1, 1), - start=(0, 0, 0), - device=device, + if image_padding_len > 0: + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) ) - .flatten(0, 2) - .repeat(image_padding_len, 1) - ) - image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + else : + image_padded_pos_ids = image_ori_pos_ids all_image_pos_ids.append(image_padded_pos_ids) # pad mask all_image_pad_mask.append( @@ -534,10 +537,10 @@ def patchify_and_embed( torch.ones((image_padding_len,), dtype=torch.bool, device=device), ], dim=0, - ) + ) if image_padding_len > 0 else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) ) # padded feature - image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) if image_padding_len > 0 else image all_image_out.append(image_padded_feat) return ( From e113de6d4ecab883af00a45d6d8496c4bd765d16 Mon Sep 17 
00:00:00 2001 From: "github-actions[bot]" Date: Mon, 1 Dec 2025 19:39:01 +0000 Subject: [PATCH 2/4] Apply style fixes --- .../models/transformers/transformer_z_image.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 2c3faf77c95e..6f7b3dd72920 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -526,7 +526,7 @@ def patchify_and_embed( .repeat(image_padding_len, 1) ) image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) - else : + else: image_padded_pos_ids = image_ori_pos_ids all_image_pos_ids.append(image_padded_pos_ids) # pad mask @@ -537,10 +537,14 @@ def patchify_and_embed( torch.ones((image_padding_len,), dtype=torch.bool, device=device), ], dim=0, - ) if image_padding_len > 0 else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) + ) + if image_padding_len > 0 + else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) ) # padded feature - image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) if image_padding_len > 0 else image + image_padded_feat = ( + torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) if image_padding_len > 0 else image + ) all_image_out.append(image_padded_feat) return ( From faa73897954de577e301b5801cb75b7f3b590e9a Mon Sep 17 00:00:00 2001 From: Jerry Qilong Wu Date: Tue, 2 Dec 2025 07:54:07 +0000 Subject: [PATCH 3/4] Add more support to fix repeat bug on tpu devices. 
--- .../transformers/transformer_z_image.py | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 6f7b3dd72920..807d4ae5ae7e 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -490,13 +490,18 @@ def patchify_and_embed( ], dim=0, ) + if cap_padding_len > 0 + else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device) ) # padded feature - cap_padded_feat = torch.cat( - [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], - dim=0, + all_cap_feats_out.append( + torch.cat( + [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], + dim=0, + ) + if cap_padding_len > 0 + else cap_feat ) - all_cap_feats_out.append(cap_padded_feat) ### Process Image C, F, H, W = image.size() @@ -515,20 +520,19 @@ def patchify_and_embed( start=(cap_ori_len + cap_padding_len + 1, 0, 0), device=device, ).flatten(0, 2) - if image_padding_len > 0: - image_padding_pos_ids = ( - self.create_coordinate_grid( - size=(1, 1, 1), - start=(0, 0, 0), - device=device, - ) - .flatten(0, 2) - .repeat(image_padding_len, 1) + all_image_pos_ids.append( + torch.cat( + [ + image_ori_pos_ids, + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(image_padding_len, 1), + ], + dim=0, ) - image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) - else: - image_padded_pos_ids = image_ori_pos_ids - all_image_pos_ids.append(image_padded_pos_ids) + if image_padding_len > 0 + else image_ori_pos_ids + ) # pad mask all_image_pad_mask.append( torch.cat( From b5bb7252b0ae609dfdf43e4d30eff9d82e2325a4 Mon Sep 17 00:00:00 2001 From: Jerry Qilong Wu Date: Tue, 2 Dec 2025 16:20:25 +0000 Subject: [PATCH 4/4] Fix for dynamo compile error for multi if-branches. 
--- .../transformers/transformer_z_image.py | 85 ++++++++++--------- 1 file changed, 44 insertions(+), 41 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 807d4ae5ae7e..097672e0f73b 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -482,26 +482,23 @@ def patchify_and_embed( ).flatten(0, 2) all_cap_pos_ids.append(cap_padded_pos_ids) # pad mask + cap_pad_mask = torch.cat( + [ + torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), + torch.ones((cap_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) all_cap_pad_mask.append( - torch.cat( - [ - torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), - torch.ones((cap_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) - if cap_padding_len > 0 - else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device) + cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device) ) + # padded feature - all_cap_feats_out.append( - torch.cat( - [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], - dim=0, - ) - if cap_padding_len > 0 - else cap_feat + cap_padded_feat = torch.cat( + [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], + dim=0, ) + all_cap_feats_out.append(cap_padded_feat if cap_padding_len > 0 else cap_feat) ### Process Image C, F, H, W = image.size() @@ -520,36 +517,35 @@ def patchify_and_embed( start=(cap_ori_len + cap_padding_len + 1, 0, 0), device=device, ).flatten(0, 2) - all_image_pos_ids.append( - torch.cat( - [ - image_ori_pos_ids, - self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) - .flatten(0, 2) - .repeat(image_padding_len, 1), - ], - dim=0, - ) - if image_padding_len > 0 - else image_ori_pos_ids + image_padded_pos_ids = torch.cat( + [ + image_ori_pos_ids, + self.create_coordinate_grid(size=(1, 1, 1), 
start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(image_padding_len, 1), + ], + dim=0, ) + all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids) # pad mask + image_pad_mask = torch.cat( + [ + torch.zeros((image_ori_len,), dtype=torch.bool, device=device), + torch.ones((image_padding_len,), dtype=torch.bool, device=device), + ], + dim=0, + ) all_image_pad_mask.append( - torch.cat( - [ - torch.zeros((image_ori_len,), dtype=torch.bool, device=device), - torch.ones((image_padding_len,), dtype=torch.bool, device=device), - ], - dim=0, - ) + image_pad_mask if image_padding_len > 0 else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) ) # padded feature - image_padded_feat = ( - torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) if image_padding_len > 0 else image + image_padded_feat = torch.cat( + [image, image[-1:].repeat(image_padding_len, 1)], + dim=0, ) - all_image_out.append(image_padded_feat) + all_image_out.append(image_padded_feat if image_padding_len > 0 else image) return ( all_image_out, @@ -599,10 +595,13 @@ def forward( adaln_input = t.type_as(x) x[torch.cat(x_inner_pad_mask)] = self.x_pad_token x = list(x.split(x_item_seqlens, dim=0)) - x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)) + x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)) x = pad_sequence(x, batch_first=True, padding_value=0.0) x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + x_freqs_cis = x_freqs_cis[:, : x.shape[1]] + x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device) for i, seq_len in enumerate(x_item_seqlens): x_attn_mask[i, :seq_len] = 1 @@ -616,17 +615,21 @@ def forward( # cap embed & refine cap_item_seqlens = [len(_) for _ in 
cap_feats] - assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens) cap_max_item_seqlen = max(cap_item_seqlens) cap_feats = torch.cat(cap_feats, dim=0) cap_feats = self.cap_embedder(cap_feats) cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0)) - cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)) + cap_freqs_cis = list( + self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0) + ) cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0) cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0) + # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors + cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]] + cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device) for i, seq_len in enumerate(cap_item_seqlens): cap_attn_mask[i, :seq_len] = 1