From 8082e855c291961dbad82295dc2387d308cf895c Mon Sep 17 00:00:00 2001 From: Mingjia Li Date: Sat, 29 Nov 2025 02:34:25 +0800 Subject: [PATCH 1/2] Refactor image padding logic to pervent zero tensor in transformer_z_image.py --- .../transformers/transformer_z_image.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 4f2d56ea8f4d..2c3faf77c95e 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -515,16 +515,19 @@ def patchify_and_embed( start=(cap_ori_len + cap_padding_len + 1, 0, 0), device=device, ).flatten(0, 2) - image_padding_pos_ids = ( - self.create_coordinate_grid( - size=(1, 1, 1), - start=(0, 0, 0), - device=device, + if image_padding_len > 0: + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) ) - .flatten(0, 2) - .repeat(image_padding_len, 1) - ) - image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + else : + image_padded_pos_ids = image_ori_pos_ids all_image_pos_ids.append(image_padded_pos_ids) # pad mask all_image_pad_mask.append( @@ -534,10 +537,10 @@ def patchify_and_embed( torch.ones((image_padding_len,), dtype=torch.bool, device=device), ], dim=0, - ) + ) if image_padding_len > 0 else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) ) # padded feature - image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) if image_padding_len > 0 else image all_image_out.append(image_padded_feat) return ( From e113de6d4ecab883af00a45d6d8496c4bd765d16 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 1 Dec 2025 19:39:01 +0000 Subject: [PATCH 2/2] Apply style fixes --- .../models/transformers/transformer_z_image.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 2c3faf77c95e..6f7b3dd72920 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -526,7 +526,7 @@ def patchify_and_embed( .repeat(image_padding_len, 1) ) image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) - else : + else: image_padded_pos_ids = image_ori_pos_ids all_image_pos_ids.append(image_padded_pos_ids) # pad mask @@ -537,10 +537,14 @@ def patchify_and_embed( torch.ones((image_padding_len,), dtype=torch.bool, device=device), ], dim=0, - ) if image_padding_len > 0 else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) + ) + if image_padding_len > 0 + else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) ) # padded feature - image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) if image_padding_len > 0 else image + image_padded_feat = ( + torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) if image_padding_len > 0 else image + ) all_image_out.append(image_padded_feat) return (