jnp.asarray making copies in latest jax #17702

Closed
mattbarrett98 opened this issue Sep 21, 2023 · 6 comments
@mattbarrett98

Description

import numpy as np
import jax.numpy as jnp
import os, psutil

process = psutil.Process(os.getpid())
print(process.memory_info().rss / 1e9)  # 0.1 GB
x = np.random.uniform(0, 1, (10000, 20000)).astype("float32")
print(process.memory_info().rss / 1e9)  # 0.9 GB
y = jnp.asarray(x)
print(process.memory_info().rss / 1e9)  # 1.7 GB

In version 0.4.16, jnp.asarray seems to make a copy when called on numpy arrays. In earlier versions, it would reuse the underlying data, so little or no extra memory was consumed. Is this intended?

What jax/jaxlib version are you using?

jax, jaxlib 0.4.16

Which accelerator(s) are you using?

CPU

Additional system info

No response

NVIDIA GPU info

No response

@mattbarrett98 mattbarrett98 added the bug Something isn't working label Sep 21, 2023
@hawkinsp
Member

Yash, is it possible that this occurs because of the changes to the primitive dispatch path and the fact that we now copy for trivial computations?

@yashk2810
Member

Yeah, that sounds correct to me.

We removed the trivial dispatch path from JAX because it was a pre-omnistaging optimization that doesn't have much use post-omnistaging. Removing it also helped speed up the dispatch path by 100x.

Is the copy causing problems?

@hawkinsp
Member

I think not copying (while it has only ever been best effort) is an important optimization.

copybara-service bot pushed a commit that referenced this issue Sep 21, 2023
…ax.Array via device_put to avoid a copy.

Do a similar thing for jax.Array too if dtypes match.

Fixes #17702

PiperOrigin-RevId: 567406208
@yashk2810
Member

#17721 should fix this.

@mattbarrett98
Author

Thanks guys! Before the next release, is there another way of converting without copies?

@jakevdp
Collaborator

jakevdp commented Sep 22, 2023

If you want to convert a numpy array to a jax array without copying in version 0.4.16, you can use y = jax.device_put(x)

This will only avoid copies when using a CPU device, and only if your numpy array's byte alignment is compatible with XLA.
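
One way to check whether a given conversion actually aliased the numpy buffer is to compare raw data pointers rather than watching process memory. The following is a minimal sketch, not something from this thread: the helper name shares_buffer is hypothetical, and it assumes a single-device array on the CPU backend, where unsafe_buffer_pointer() returns a host address. Whether either call aliases depends on the jax version, the device, and the array's alignment, so no particular result is assumed here.

```python
import numpy as np
import jax
import jax.numpy as jnp


def shares_buffer(host_array: np.ndarray, device_array: jax.Array) -> bool:
    """Hypothetical helper: True if the JAX buffer starts at the NumPy data pointer.

    Only meaningful on the CPU backend; on GPU/TPU the pointer is a device
    address and the comparison will simply be False.
    """
    return device_array.unsafe_buffer_pointer() == host_array.ctypes.data


# Smaller than the array in the report, to keep the check cheap.
x = np.random.uniform(0, 1, (1000, 2000)).astype("float32")

y_put = jax.device_put(x)  # workaround suggested above
y_as = jnp.asarray(x)      # the call reported to copy in 0.4.16

print("device_put aliases x:", shares_buffer(x, y_put))
print("jnp.asarray aliases x:", shares_buffer(x, y_as))
```

If the first line prints False on CPU, the alignment caveat above is the likely reason; in that case the copy is made regardless of which function is called.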
