@JerryWu-code JerryWu-code commented Dec 2, 2025

What does this PR do?

In ‎src/diffusers/models/transformers/transformer_z_image.py, there is an image_padding_len used to pad the image sequence length to a multiple of SEQ_MULTI_OF. However, when the length is already a multiple of SEQ_MULTI_OF, the padding path creates a zero-size tensor, which triggers INVALID_ARGUMENT: Concatenate expects at least one argument. on torch_xla.device() (TPU). The error does not surface when the device is cpu or cuda, but it breaks XLA on TPU.

Fixes #12742 and #12743. This is the final version of #12743 and builds upon it.


  • This PR fixes a torch.Tensor.repeat error with an empty dimension on TPU under torch_xla.
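
The fix boils down to skipping the pad-and-concat path entirely when the sequence length is already aligned, so the empty repeat is never materialized. A minimal sketch of that guard (the function name pad_to_multiple and the SEQ_MULTI_OF value here are illustrative, not the actual diffusers code):

```python
import torch

SEQ_MULTI_OF = 32  # illustrative value; mirrors the alignment constant in transformer_z_image.py


def pad_to_multiple(x: torch.Tensor) -> torch.Tensor:
    """Pad the first (sequence) dim of x up to a multiple of SEQ_MULTI_OF.

    Guarding on padding_len == 0 avoids tensor[-1:].repeat(0, 1), which
    would build a zero-size tensor and crash XLA's concat lowering on TPU.
    """
    padding_len = (-x.shape[0]) % SEQ_MULTI_OF
    if padding_len == 0:
        return x  # already aligned: skip the empty repeat/cat entirely
    pad = x[-1:].repeat(padding_len, 1)  # repeat the last row to fill the gap
    return torch.cat([x, pad], dim=0)
```

With this guard, aligned inputs pass through untouched while unaligned ones are padded up, on every backend.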

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Minimal Test Case on TPU

You can try this on Colab with the TPU backend:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

def test_device(device_name, device):
    """Test repeat(0, 1) behavior"""
    print(f"\n{'='*60}")
    print(f"Device: {device_name} ({device})")
    print('='*60)
    
    tensor = torch.randn(64, 2048).to(device)
    
    # Test 1: repeat(0, 1)
    try:
        result = tensor[-1:].repeat(0, 1)
        print(f"✅ Success repeat(0, 1): {result.shape}")
    except RuntimeError as e:
        print(f"❌ Failed repeat(0, 1): {str(e)[:50]}")
    
    # Test 2: Avoid repeat(0)
    padding_len = 0
    result = tensor if padding_len == 0 else torch.cat([tensor, tensor[-1:].repeat(padding_len, 1)], dim=0)
    print(f"✅ Success fix: {result.shape}")

# (1) Test TPU
torch_xla.experimental.eager_mode(True)
test_device("TPU/XLA", torch_xla.device())

# (2) CUDA/CPU
if torch.cuda.is_available():
    test_device("CUDA", torch.device("cuda:0"))
else:
    test_device("CPU", torch.device("cpu"))

You should get output like:

============================================================
Device: TPU/XLA (xla:0)
============================================================
❌ Failed repeat(0, 1): Concatenate expects at least one argument.
✅ Success fix: torch.Size([64, 2048])

============================================================
Device: CPU (cpu)
============================================================
✅ Success repeat(0, 1): torch.Size([0, 2048])
✅ Success fix: torch.Size([64, 2048])

@yiyixuxu (Collaborator) commented:
thanks!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu yiyixuxu merged commit 9379b23 into huggingface:main Dec 2, 2025
9 of 11 checks passed


Successfully merging this pull request may close these issues.

Zero tensor in transformer_z_image

4 participants