diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py index 4f2d56ea8f4d..6f7b3dd72920 100644 --- a/src/diffusers/models/transformers/transformer_z_image.py +++ b/src/diffusers/models/transformers/transformer_z_image.py @@ -515,16 +515,19 @@ def patchify_and_embed( start=(cap_ori_len + cap_padding_len + 1, 0, 0), device=device, ).flatten(0, 2) - image_padding_pos_ids = ( - self.create_coordinate_grid( - size=(1, 1, 1), - start=(0, 0, 0), - device=device, + if image_padding_len > 0: + image_padding_pos_ids = ( + self.create_coordinate_grid( + size=(1, 1, 1), + start=(0, 0, 0), + device=device, + ) + .flatten(0, 2) + .repeat(image_padding_len, 1) ) - .flatten(0, 2) - .repeat(image_padding_len, 1) - ) - image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0) + else: + image_padded_pos_ids = image_ori_pos_ids all_image_pos_ids.append(image_padded_pos_ids) # pad mask all_image_pad_mask.append( @@ -535,9 +538,13 @@ def patchify_and_embed( ], dim=0, ) + if image_padding_len > 0 + else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) ) # padded feature - image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) + image_padded_feat = ( + torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0) if image_padding_len > 0 else image + ) all_image_out.append(image_padded_feat) return (