-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
make device_array.copy() return a device array #10069
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
jakevdp
force-pushed
the
devicearray-copy
branch
from
March 29, 2022 17:33
e63f61c
to
de9a948
Compare
mattjj
approved these changes
Mar 29, 2022
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This sounds right to me. It's reversing an ancient API decision, so it might break some very old (and hopefully easily fixable) code, but I think jax.device_get
or np.array(x)
is the right way to spell "give me a NumPy array".
google-ml-butler
bot
added
kokoro:force-run
pull ready
Ready for copybara import and testing
labels
Mar 29, 2022
jakevdp
added a commit
to jakevdp/jax
that referenced
this pull request
Mar 30, 2022
copybara-service bot
pushed a commit
that referenced
this pull request
Mar 30, 2022
PiperOrigin-RevId: 438323845
copybara-service bot
pushed a commit
that referenced
this pull request
Mar 31, 2022
PiperOrigin-RevId: 438323845
copybara-service bot
pushed a commit
that referenced
this pull request
Mar 31, 2022
PiperOrigin-RevId: 438323845
copybara-service bot
pushed a commit
that referenced
this pull request
Mar 31, 2022
PiperOrigin-RevId: 438323845
copybara-service bot
pushed a commit
that referenced
this pull request
Mar 31, 2022
PiperOrigin-RevId: 438323845
copybara-service bot
pushed a commit
to google/trax
that referenced
this pull request
Apr 1, 2022
In google/jax#10069, JAX changes the behavior of DeviceArray.copy() so that it returns a DeviceArray rather than returning a numpy array. For converting a DeviceArray to numpy, the preferred method is now np.asarray(device_array). PiperOrigin-RevId: 438711926
copybara-service bot
pushed a commit
to google/trax
that referenced
this pull request
Apr 1, 2022
In google/jax#10069, JAX changes the behavior of DeviceArray.copy() so that it returns a DeviceArray rather than returning a numpy array. For converting a DeviceArray to numpy, the preferred method is to explicitly call np.asarray(device_array). PiperOrigin-RevId: 438718784
copybara-service bot
pushed a commit
to google/trax
that referenced
this pull request
Apr 4, 2022
In google/jax#10069, JAX changes the behavior of DeviceArray.copy() so that it returns a DeviceArray rather than returning a numpy array. For converting a DeviceArray to numpy, the preferred method is to explicitly call np.asarray(device_array). PiperOrigin-RevId: 438718784
copybara-service bot
pushed a commit
to google/trax
that referenced
this pull request
Apr 4, 2022
In google/jax#10069, JAX changes the behavior of DeviceArray.copy() so that it returns a DeviceArray rather than returning a numpy array. For converting a DeviceArray to numpy, the preferred method is to explicitly call np.asarray(device_array). PiperOrigin-RevId: 439338425
copybara-service bot
pushed a commit
that referenced
this pull request
Apr 4, 2022
PiperOrigin-RevId: 438323845
copybara-service bot
pushed a commit
that referenced
this pull request
Apr 4, 2022
PiperOrigin-RevId: 438323845
copybara-service bot
pushed a commit
that referenced
this pull request
Apr 4, 2022
PiperOrigin-RevId: 439381161
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Addresses part of #2632
In #9732 we made
jnp.copy
return a device array rather than a numpy array. This gives the same treatment to thecopy()
method of device arrays.I think this is the behavior that we want – note that it will not chgne the behavior of
copy.copy()
orcopy.deepcopy()
, which current return numpy arrays due to their use of thereduce()
method (similar topickle
). We could address those separately via the__copy__
and__deepcopy__
methods if we wish.