
[core] fix autoencoderkl qwenimage for xla #13480

Merged
sayakpaul merged 3 commits into main from xla-autoencoder-qwenimage
Apr 16, 2026

Conversation

@sayakpaul
Member

What does this PR do?

Verify with the following:

Code
"""Minimal reproducer: XLA rejects out-of-bounds negative index that CPU/CUDA silently clamps."""

import torch
import torch_xla.core.xla_model as xm

CACHE_T = 2

# Temporal dim = 1 (single image, not video)
x_cpu = torch.randn(1, 4, 1, 8, 8)

# CPU: silently clamps -2 to 0, returns the whole tensor
result_cpu = x_cpu[:, :, -CACHE_T:, :, :]
print(f"CPU: x.shape={x_cpu.shape}, x[:,:,-2:,:,:].shape={result_cpu.shape}")
assert result_cpu.shape == (1, 4, 1, 8, 8), "CPU should clamp and return full tensor"

# CPU with fix: identical result
result_cpu_fix = x_cpu[:, :, -min(CACHE_T, x_cpu.shape[2]):, :, :]
print(f"CPU fix: x[:,:,-min(2,1):,:,:].shape={result_cpu_fix.shape}")
assert torch.equal(result_cpu, result_cpu_fix), "Fix produces same result on CPU"

# XLA: strict bounds checking
device = xm.xla_device()
x_xla = x_cpu.to(device)

print("\nXLA: trying x[:,:,-2:,:,:] with temporal dim=1...")
try:
    result_xla = x_xla[:, :, -CACHE_T:, :, :]
    print(f"XLA: succeeded (shape={result_xla.shape})")
except RuntimeError as e:
    print(f"XLA: FAILED as expected — {e}")

# XLA with fix: works
result_xla_fix = x_xla[:, :, -min(CACHE_T, x_xla.shape[2]):, :, :]
print(f"XLA fix: x[:,:,-min(2,1):,:,:].shape={result_xla_fix.shape}")

# Also verify temporal dim >= 2 works identically with both approaches
print("\n--- Temporal dim = 3 (no issue on either backend) ---")
x_cpu_3 = torch.randn(1, 4, 3, 8, 8)
x_xla_3 = x_cpu_3.to(device)

result_orig = x_xla_3[:, :, -CACHE_T:, :, :]
result_fix = x_xla_3[:, :, -min(CACHE_T, x_xla_3.shape[2]):, :, :]
print(f"Original: shape={result_orig.shape}, Fix: shape={result_fix.shape}")
assert result_orig.shape == result_fix.shape == (1, 4, 2, 8, 8)
assert torch.equal(result_orig.cpu(), result_fix.cpu()), "Fix must produce same values when temporal dim >= CACHE_T"

# Also compare CPU vs XLA values for the fix (temporal dim=1)
assert torch.equal(result_cpu_fix, result_xla_fix.cpu()), "CPU and XLA fix must produce same values"

# And CPU vs XLA for temporal dim >= 2
result_cpu_3_orig = x_cpu_3[:, :, -CACHE_T:, :, :]
assert torch.equal(result_cpu_3_orig, result_orig.cpu()), "CPU and XLA must produce same values when temporal dim >= CACHE_T"

print("\nAll checks passed!")
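The guard in the fix relies only on standard slicing semantics, so the same behavior can be illustrated with plain Python sequences, which clamp out-of-bounds negative slice starts exactly as CPU/CUDA tensors do. This is a minimal sketch of the pattern (no torch required; `CACHE_T` mirrors the constant in the reproducer above):

```python
CACHE_T = 2

# Analogous to a tensor with temporal dim = 1 (x.shape[2] == 1)
frames = [0]

# Python slicing silently clamps -2 to the start, returning the whole list
assert frames[-CACHE_T:] == [0]

# The guarded form computes an always-in-bounds start index,
# which is what XLA's strict bounds checking requires
start = -min(CACHE_T, len(frames))
assert frames[start:] == frames[-CACHE_T:]

# With enough frames, the guard is a no-op and both forms agree
frames3 = [0, 1, 2]
assert frames3[-min(CACHE_T, len(frames3)):] == [1, 2]

print("guard pattern verified")
```

The key point is that `min(CACHE_T, size)` never changes the result on backends that clamp; it only removes the out-of-bounds start index that XLA rejects.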

@github-actions bot added the models and size/S (PR with diff < 50 LOC) labels Apr 15, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul requested a review from dg845 April 15, 2026 10:14
@github-actions bot added and removed the size/S (PR with diff < 50 LOC) label Apr 16, 2026
Collaborator

@dg845 left a comment


Thanks!

@sayakpaul merged commit 33a1317 into main Apr 16, 2026
14 of 15 checks passed
@sayakpaul deleted the xla-autoencoder-qwenimage branch April 16, 2026 05:59