Broadcasting of size-0 dimensions not implemented #33

Closed
hawkinsp opened this issue Dec 8, 2018 · 0 comments

@hawkinsp
Member

hawkinsp commented Dec 8, 2018

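NumPy supports broadcasting size-0 dimensions against size-1 dimensions (here `onp` is the original NumPy and `np` is `jax.numpy`, following JAX's naming convention at the time):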

import numpy as onp
onp.ones([0, 1]) + onp.ones([1, 128])

produces:

array([], shape=(0, 128), dtype=float64)
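NumPy's rule, roughly: align the shapes on their trailing dimensions; at each position the sizes must be equal or one of them must be 1, and a size-1 dimension stretches to the other size, including 0. Below is a minimal pure-Python sketch of that rule under those assumptions. The name `broadcast_shape` is hypothetical; recent NumPy (1.20+) exposes the real thing as `numpy.broadcast_shapes`.

```python
from itertools import zip_longest

def broadcast_shape(*shapes):
    """Hypothetical sketch of NumPy's broadcasting rule for any number of shapes."""
    result = []
    # Align on trailing dimensions; missing leading dimensions behave like 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        # Size-1 dims stretch to match anything, so only non-1 sizes must agree.
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"Incompatible shapes for broadcasting: {shapes}")
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))

print(broadcast_shape((0, 1), (1, 128)))  # (0, 128)
```

Note that 0 is treated like any other non-1 size: it wins against 1, but `(0,)` against `(5,)` is still an error.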

However, the same computation on JAX device arrays fails:

import jax
import jax.numpy as np

to_device = jax.jit(lambda x: x)  # jitted identity, used only to get device arrays
to_device(np.ones([0, 1])) + to_device(np.ones([1, 128]))

raises:

ValueError: Incompatible shapes for broadcasting: ((0, 1), (1, 128))

JAX's broadcasting rule computes the output shape as the elementwise maximum of the input shapes:

result_shape = onp.max(shapes, axis=0)
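To see how this leads to the error, here is a small pure-Python sketch that mirrors `onp.max(shapes, axis=0)` for equal-rank shapes. The per-input check `fits` is a hypothetical stand-in for whatever validation JAX performs afterwards, not its actual code.

```python
shapes = [(0, 1), (1, 128)]

# Elementwise maximum over the shapes, like onp.max(shapes, axis=0).
result_shape = tuple(max(dims) for dims in zip(*shapes))
print(result_shape)  # (1, 128), whereas NumPy gives (0, 128)

def fits(shape, result):
    """Hypothetical check: each input dim must be 1 or equal the result dim."""
    return all(d == 1 or d == r for d, r in zip(shape, result))

# The size-0 input cannot broadcast to a size-1 result dimension.
print([fits(s, result_shape) for s in shapes])  # [False, True]
```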

Because max(0, 1) = 1, this yields (1, 128) for the example above rather than NumPy's (0, 128). The rule probably needs to be something like this:

min_shape = onp.min(shapes, axis=0)
max_shape = onp.max(shapes, axis=0)
result_shape = onp.where(min_shape == 0, 0, max_shape)
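A pure-Python sketch of the proposed min/max/where rule, assuming all shapes have the same rank. On its own the rule would also accept genuinely incompatible pairs such as `(0,)` and `(5,)`, so this sketch adds a per-dimension compatibility check; that check is an addition for illustration, not part of the proposal above, and the function name is hypothetical.

```python
def broadcast_shape_proposed(*shapes):
    """Hypothetical equal-rank sketch of the min/max/where rule plus validation."""
    if len({len(s) for s in shapes}) != 1:
        raise ValueError("this sketch assumes all shapes have the same rank")
    result = []
    for dims in zip(*shapes):
        lo, hi = min(dims), max(dims)
        out = 0 if lo == 0 else hi  # onp.where(min_shape == 0, 0, max_shape)
        # Every input dim must be 1 or equal to the chosen output dim.
        if any(d != 1 and d != out for d in dims):
            raise ValueError(f"Incompatible shapes for broadcasting: {shapes}")
        result.append(out)
    return tuple(result)

print(broadcast_shape_proposed((0, 1), (1, 128)))  # (0, 128)
```

For same-rank inputs this agrees with NumPy on the cases above; shapes of different rank would first need to be left-padded with 1s.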