Skip to content

Conversation

@lime-j
Copy link
Contributor

@lime-j lime-j commented Nov 28, 2025

What does this PR do?

In ‎src/diffusers/models/transformers/transformer_z_image.py, there exists a image_padding_len for padding image to be multiple of SEQ_MULTI_OF. However, when the image shape is already multiple of SEQ_MULTI_OF, this will create a tensor with zero shape. This triggers INVALID_ARGUMENT: Concatenate expects at least one argument. for PyTorch/XLA on TPU. It may also fail on other devices than cuda.

Fixes #12742

Who can review?

@yiyixuxu @asomoza

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Dec 1, 2025

@bot /style

@github-actions
Copy link
Contributor

github-actions bot commented Dec 1, 2025

Style bot fixed some files and pushed the changes.

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Dec 1, 2025

cc @JerryWu-code can you take a look too?

@JerryWu-code
Copy link
Contributor

cc @JerryWu-code can you take a look too?

Sure, I'm happy to proceed that, this error does exist on TPU with using torch_xla, and I add a minimal test case to report the problem and fix more on #12770. Ready to merge ~

@lime-j
Copy link
Contributor Author

lime-j commented Dec 3, 2025

Should I close the pr now? It's my first time contributing to diffusers :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Zero tensor in transformer_z_image

4 participants