Skip to content

Fix torchax tensor flatten bug on empty tensors#93

Merged
wdhongtw merged 1 commit into
google:mainfrom
cychiuak:main
May 11, 2026
Merged

Fix torchax tensor flatten bug on empty tensors#93
wdhongtw merged 1 commit into
google:mainfrom
cychiuak:main

Conversation

@cychiuak
Copy link
Copy Markdown
Contributor

@cychiuak cychiuak commented May 6, 2026

Flatten on tensor of shape contains zero is supported in PyTorch, but not through torchax.

import torch
import torchax

with torchax.default_env():
    value = torch.ones((16, 0))
    print(value.flatten(0, -2).shape) # works in CPU > torch.Size([16, 0])
    value = value.to(device='jax')
    print(value.flatten(0, -2)) # throws exception

This is due to jax.numpy fails to calculate the correct dimension when input dimension contains 0 and -1 at the same time. It returns a division by zero error.

import jax.numpy as jnp

print(jnp.reshape(jnp.ones((2, 3)), (-1, 3)).shape) # works > (2, 3)
value = jnp.ones((16, 0))
print(jnp.reshape(value, (16, 0)).shape) # works > (16, 0)
print(jnp.reshape(value, (-1, 0)).shape) # should work but throws exception

The new fix avoids sending -1 into jnp.shape by explicitly calculating the flattening dimension to to fix this error:

flattened_size = math.prod(self._elem.shape[start_dim : end_dim + 1])
new_shape = self._elem.shape[:start_dim] + (flattened_size,) + self._elem.shape[end_dim + 1 :]

Copy link
Copy Markdown
Collaborator

@wdhongtw wdhongtw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! 👍

Just a few suggestions.

Comment thread test/test_functions.py Outdated
Comment thread test/test_functions.py Outdated
Comment thread test/test_functions.py
@cychiuak cychiuak force-pushed the main branch 4 times, most recently from 1736f4e to 35e5b9b Compare May 8, 2026 07:15
Signed-off-by: Anderson Chiu <andersonchiu@google.com>
@wdhongtw wdhongtw added the bug Something isn't working label May 11, 2026
@wdhongtw wdhongtw self-assigned this May 11, 2026
@wdhongtw wdhongtw changed the title [BUG FIX] Fix torchax tensor flatten bug on empty tensors Fix torchax tensor flatten bug on empty tensors May 11, 2026
@wdhongtw wdhongtw merged commit 7605a49 into google:main May 11, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants