jnp.asarray making copies in latest jax #17702
Comments
Yash, is it possible that this occurs because of the changes to the primitive dispatch path and the fact that we now copy for trivial computations?
Yeah, that sounds correct to me. We removed the trivial dispatch path from JAX because it was a pre-omnistaging optimization which doesn't have much use post-omnistaging. Removing it also helped speed up the dispatch path by 100x. Is the copy causing problems?
I think not copying (while it has only ever been best effort) is an important optimization.
…ax.Array via device_put to avoid a copy. Do a similar thing for jax.Array too if dtypes match. Fixes #17702 PiperOrigin-RevId: 567406208
#17721 should fix this.
Thanks, guys! Before the next release, is there another way of converting without copies?
If you want to convert a NumPy array to a JAX array without copying in version 0.4.16, you can use [...]. This will only avoid copies if you are using a CPU device and your NumPy array's byte alignment is compatible with XLA.
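The snippet in the comment above did not survive in this copy of the thread. As a rough sketch only, assuming the intended call was `jax.device_put` (the function named in the fix commit referenced earlier), the conversion and a best-effort check for buffer sharing might look like the following. The array `x` and the pointer comparison are illustrative choices, not part of the original discussion.

```python
import numpy as np
import jax

# Hypothetical input; any NumPy array with a JAX-supported dtype would do.
x = np.arange(1024, dtype=np.float32)

# Assumption: jax.device_put is the workaround the comment refers to.
# Placing the array explicitly on a CPU device matches the condition
# stated in the comment (zero-copy is only possible on CPU).
cpu = jax.devices("cpu")[0]
x_jax = jax.device_put(x, cpu)

# Best-effort check: compare the address of the JAX buffer with the
# address of the NumPy data. Equal addresses mean the data was reused;
# unequal addresses mean a copy was made (for example, because the
# NumPy buffer's alignment was not compatible with XLA).
shares_buffer = x_jax.unsafe_buffer_pointer() == x.ctypes.data
print("zero-copy:", shares_buffer)
```

Whether this prints `True` depends on the JAX version and on how NumPy happened to align the allocation, so treat the result as a diagnostic rather than a guarantee. If the buffer is shared, avoid mutating `x` afterwards, since JAX arrays are expected to be immutable.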
Description
In version 0.4.16, jnp.asarray now seems to be making copies when called on NumPy arrays. In earlier versions, it would just reuse the underlying data, and little or no extra memory would be consumed. Is this intended?
What jax/jaxlib version are you using?
jax, jaxlib 0.4.16
Which accelerator(s) are you using?
CPU
Additional system info
No response
NVIDIA GPU info
No response