NumPy supports broadcasting a size-0 dimension against a size-1 dimension:
onp.ones([0,1]) + onp.ones([1,128])
produces:
array([], shape=(0, 128), dtype=float64)
However, the same computation in JAX raises an error:
to_device = jax.jit(lambda x: x)
to_device(np.ones([0,1])) + to_device(np.ones([1,128]))
ValueError: Incompatible shapes for broadcasting: ((0, 1), (1, 128))
The broadcasting rule computes the output shape as
result_shape = onp.max(shapes, axis=0)
but it probably needs to be something like this:
min_shape = onp.min(shapes, axis=0)
max_shape = onp.max(shapes, axis=0)
result_shape = onp.where(min_shape == 0, 0, max_shape)
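For illustration, here is a minimal standalone sketch of the proposed rule, checked against NumPy's own result. The helper name `broadcast_result_shape` and the compatibility check are assumptions made for this example, not JAX's actual implementation, and the sketch assumes all shapes already have the same rank.

```python
import numpy as onp


def broadcast_result_shape(*shapes):
    """Hypothetical helper: broadcast equal-rank shapes, allowing size-0 dims."""
    shapes = onp.array(shapes, dtype=onp.int64)
    if shapes.ndim != 2:
        raise ValueError("all shapes must have the same rank")
    min_shape = onp.min(shapes, axis=0)
    max_shape = onp.max(shapes, axis=0)
    # A size-0 dimension takes precedence over size-1 when broadcasting.
    result_shape = onp.where(min_shape == 0, 0, max_shape)
    # Each input dimension must be 1 or match the result dimension.
    compatible = (shapes == 1) | (shapes == result_shape)
    if not onp.all(compatible):
        raise ValueError(
            "Incompatible shapes for broadcasting: "
            f"{tuple(map(tuple, shapes.tolist()))}"
        )
    return tuple(int(d) for d in result_shape)


# The case from this issue: size-0 against size-1.
shape = broadcast_result_shape((0, 1), (1, 128))
# NumPy's result for the same operands, for comparison.
expected = onp.broadcast(onp.ones([0, 1]), onp.ones([1, 128])).shape
print(shape, expected)
```

With this rule a size-0 dimension still fails against any size other than 0 or 1, for example `(0, 3)` against `(2, 3)`, which matches NumPy's behavior.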
hawkinsp