Initial implementation of the Python Array API standard #16099

Merged (1 commit) on Nov 16, 2023


jakevdp
Collaborator

@jakevdp jakevdp commented May 23, 2023

Part of #18353

Usage:

from jax.experimental import array_api as xp

And then xp is the Array API namespace backed by JAX.

This initial implementation still has some missing features (see array-api-skips.txt for examples of known failures) and so I'm not yet making it available via jax.Array.__array_namespace__ unless jax.experimental.array_api is first explicitly imported.

@jakevdp jakevdp force-pushed the array-api branch 2 times, most recently from 2932adb to 87ff9c7 Compare May 23, 2023 23:44
@jakevdp jakevdp marked this pull request as draft May 23, 2023 23:46
@jakevdp jakevdp self-assigned this May 23, 2023
@jakevdp jakevdp force-pushed the array-api branch 2 times, most recently from 596164b to b35c628 Compare May 24, 2023 17:47
@jakevdp
Collaborator Author

jakevdp commented May 24, 2023

It looks like there's a bug in the 2022.12 release of array-api; I filed an issue here: data-apis/array-api#631

@asmeurer

Nice to see this work. Feel free to reach out if you have any questions on the spec or the test suite.

@asmeurer

I notice that you're using a separate namespace here. I'm curious what your rationale is for that. Is it mostly just so that you can experiment without having to worry about breaking things?

I would recommend aiming to make the main jax (or jax.numpy) namespace array API compliant. Our experience with NumPy and numpy.array_api is that this approach works much better. numpy.array_api will still remain as a separate namespace because it's useful as a minimal implementation, but this is something that only needs to be implemented once. Other libraries like JAX don't really need to try to make things minimal and restrict what they allow beyond what they already do. In practice, numpy.array_api is only useful for libraries to test that they are not deviating from the standard. Actual end users use the main numpy namespace and array object.

I should also point out that we have the array-api-compat library, https://github.com/data-apis/array-api-compat, which can be used to provide a compatibility layer if there are places where JAX deviates from the standard and cannot easily change because of backwards compatibility concerns. It already supports NumPy, CuPy, and PyTorch, and is being used by scikit-learn and SciPy (and hopefully others soon).

@jakevdp
Collaborator Author

jakevdp commented May 24, 2023

I'm using a separate namespace for ease of experimentation. Making the main namespace compliant will be a much larger project because of the number of existing behaviors that will have to be deprecated. (I suspect this is the same reason e.g. numpy first implemented the array API in numpy.array_api rather than in the main namespace).

@jakevdp
Collaborator Author

jakevdp commented May 24, 2023

In any case, it's not clear that JAX can be made compatible with the array-api at all, because nearly every test in array-api-tests fails when it tries to mutate an array here: https://github.com/HypothesisWorks/hypothesis/blob/661af850bbfcb091a820eb16c84f56048c8e21c8/hypothesis-python/src/hypothesis/extra/array_api.py#L448-L454

Is mutation a necessary feature of the array API standard? If so then JAX is basically disqualified entirely. If not, then perhaps this is a bug in hypothesis?

@jakevdp
Collaborator Author

jakevdp commented May 24, 2023

We are running into other issues as well; e.g. jax.Array has a device() method that when called returns the device that the array lives on.

In the array API standard, device must be a property rather than a method.

@asmeurer

Mutation is not a requirement of the array API. See https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html. This was done specifically so that libraries like JAX can be made compliant.

However, it is true that the test suite currently expects mutation to work. This is the first time someone has tried to implement the array API on a library that doesn't allow mutation. So we're going to have to do some work to update it to only do mutation in places where it is actually required.

We are running into other issues as well; e.g. jax.Array has a device() method that when called returns the device that the array lives on.

In the array API standard, device must be a property rather than a method.

That's a harder thing for sure (the good news is that at least this shouldn't be an issue for the test suite, as device support isn't really tested right now). PyTorch has a similar issue for size, which it defines as a method instead of a property.

The approach we're using for that is to define a size function in array-api-compat, and have libraries use that instead. It looks like we'll have to do something similar for device once people want to start using JAX. Fortunately, the spec is mostly functional, so this particular compatibility issue is minimized.
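To illustrate the helper-function approach described above, here is a minimal sketch of a `size` helper that derives the element count from `shape` instead of relying on `x.size`. `MethodSizeArray` is a hypothetical stand-in for a library whose `size` is a method, and the actual array-api-compat implementation may differ in its details.

```python
import math


def size(x):
    """Return the number of elements of x, computed from x.shape.

    Going through the shape avoids depending on whether the library
    exposes `size` as a property or as a method.
    """
    if None in x.shape:
        # An unknown dimension means the total size is unknown too.
        return None
    return math.prod(x.shape)


class MethodSizeArray:
    """Hypothetical array whose `size` is a method returning the shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    def size(self):
        return self.shape


n_elements = size(MethodSizeArray((2, 3)))
```

Consumer libraries would call the helper rather than touching the attribute directly, which is why the property-versus-method mismatch matters less in a mostly functional spec.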

@asmeurer

CCing @honno on the test suite stuff. Fixing things to not require mutation could be a hard problem, given that it's part of hypothesis itself (maybe hypothesis should instead generate a list of lists and pass it to asarray).
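As a rough sketch of the "list of lists" idea mentioned above: instead of creating an array and assigning elements one at a time, a test strategy could arrange the generated values into nested Python lists and hand the result to `asarray` in one step. The `nest` helper below is hypothetical and is not part of hypothesis or array-api-tests.

```python
import math


def nest(flat, shape):
    """Arrange a flat sequence of values into nested lists of the given shape."""
    if math.prod(shape) != len(flat):
        raise ValueError("number of values does not match shape")
    if len(shape) == 0:
        # A zero-dimensional "array" is just the single scalar.
        return flat[0]
    if len(shape) == 1:
        return list(flat)
    # Number of values that belong to each entry along the first axis.
    step = math.prod(shape[1:])
    return [nest(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


nested = nest([1, 2, 3, 4, 5, 6], (2, 3))
```

The nested result could then be passed to `xp.asarray(nested, dtype=...)`, so no in-place mutation of the array is needed.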

Does your separate namespace here somehow enable mutation? As far as I can tell, you're just reusing the same JAX array object. If not, the test suite will still be unusable for you in its current state.

What you're doing here is somewhat akin to what might go in array_api_compat.jax. I'm open to adding this to array-api-compat, or you can keep it in JAX itself (it's up to you). Either way, once this is in a usable state you'll want to either make jax.Array.__array_namespace__() return this namespace or else add JAX to https://github.com/data-apis/array-api-compat/blob/main/array_api_compat/common/_helpers.py#L56. Otherwise JAX arrays won't be usable from libraries using the array API, because array_namespace(jax_array) will need to return an array API compliant namespace that can operate on jax_array.
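For readers unfamiliar with the lookup being discussed, the following is a simplified sketch of how a consumer library might resolve a namespace from an array. `ArrayWithNamespace` and `compliant_namespace` are hypothetical placeholders, and the real `array_namespace` helper in array-api-compat does considerably more (for example, special-casing known libraries).

```python
from types import SimpleNamespace

# Hypothetical placeholder for an array API compliant namespace.
compliant_namespace = SimpleNamespace(name="hypothetical_array_api_namespace")


class ArrayWithNamespace:
    """Hypothetical array type that implements the namespace protocol."""

    def __array_namespace__(self, api_version=None):
        return compliant_namespace


def array_namespace(x):
    """Return the array API namespace for x, if x advertises one."""
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    raise TypeError(f"{type(x).__name__} does not expose an array API namespace")


xp = array_namespace(ArrayWithNamespace())
```

If the array type neither implements `__array_namespace__` nor is known to the compat layer, the lookup fails, which is the situation described above for JAX arrays.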

@jakevdp
Collaborator Author

jakevdp commented May 24, 2023

Does your separate namespace here somehow enable mutation? As far as I can tell, you're just reusing the same JAX array object. If not, the test suite will still be unusable for you in its current state.

No, there's no meaningful way to make JAX programs support in-place array mutation.

I'm open to adding this to array-api-compat, or you can keep it in JAX itself (it's up to you).

Sounds good! I'd propose working off this branch for now, and long-term we can discuss what makes the most sense.

Let me know if there's anything I can do to help get the test suite working for non-mutable objects.

@jakevdp
Collaborator Author

jakevdp commented May 24, 2023

@shoyer pointed out that we could fix the device()/device issue by defining Device.__call__ = lambda self: self which is not an entirely terrible idea...

At the very least it gives a viable deprecation path if we want to make that change to the core API.
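A minimal sketch of that idea, using hypothetical `Device` and `Array` classes rather than JAX's real ones: if the device object is callable and returns itself, then `arr.device` (the property the standard expects) and `arr.device()` (the existing method-style call) both yield the same device.

```python
class Device:
    """Hypothetical stand-in for a device object (not jax.Device itself)."""

    def __init__(self, platform, index):
        self.platform = platform
        self.index = index

    def __eq__(self, other):
        if not isinstance(other, Device):
            return NotImplemented
        return (self.platform, self.index) == (other.platform, other.index)

    def __hash__(self):
        return hash((self.platform, self.index))

    def __call__(self):
        # Calling a device returns the device itself, so the legacy
        # spelling `arr.device()` keeps working.
        return self


class Array:
    """Hypothetical array that exposes `device` as a property."""

    def __init__(self, device):
        self._device = device

    @property
    def device(self):
        return self._device


arr = Array(Device("cpu", 0))
same_device = arr.device == arr.device()
```

This only works when the existing `device()` call takes no arguments, since `__call__` here accepts none.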

@asmeurer

Assuming your device() doesn't take any arguments I guess that would work. For pytorch .size it's a lot harder because they define size() as the array shape rather than number of elements. If it does take arguments, then maybe that's something we should have considered for the array API.

@honno

honno commented May 31, 2023

Note there's no specified way in the array API to check what dtypes a given namespace supports.

Currently in Hypothesis and array-api-tests we essentially use hasattr(<array_namespace>, "<dtype>") (e.g. hasattr(torch, "uint32") == False) to check if a dtype is supported. I see this doesn't work for JAX as the double-precision dtype objects always exist, but when double-precision is not enabled it will lead to single-precision results.

A dirty "fix"—is there a world where jax.experimental.array_api could delete the double-precision dtype objects from its namespace if they're not enabled?

Also commented on this in data-apis/array-api#499 (comment)
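To make the `hasattr`-based check and the proposed "delete the dtypes" workaround concrete, here is a small sketch. The namespaces are plain `SimpleNamespace` objects standing in for a real array namespace, `without_double_precision` is a hypothetical helper, and treating `float64` and `complex128` as the double-precision dtypes is an assumption.

```python
from types import SimpleNamespace

DTYPE_NAMES = ("float32", "float64", "complex64", "complex128", "uint32")


def supported_dtypes(xp, names=DTYPE_NAMES):
    """List the dtype names that exist as attributes of the namespace."""
    return [name for name in names if hasattr(xp, name)]


def without_double_precision(xp, x64_enabled):
    """Return a copy of the namespace, dropping 64-bit dtypes if x64 is off."""
    attrs = dict(vars(xp))
    if not x64_enabled:
        for name in ("float64", "complex128"):
            attrs.pop(name, None)
    return SimpleNamespace(**attrs)


# Stand-in namespace; the string values are placeholders for dtype objects.
full = SimpleNamespace(float32="f4", float64="f8", complex64="c8", complex128="c16")
restricted = without_double_precision(full, x64_enabled=False)
```

With the dtypes removed, a `hasattr` check reports them as unsupported, so a test suite that already skips undefined dtypes would skip them as well.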

@asmeurer

asmeurer commented May 31, 2023

A dirty "fix"—is there a world where jax.experimental.array_api could delete the double-precision dtype objects from its namespace if they're not enabled?

This is what I would recommend. Or at least make them error when they are used.

JAX creation functions returning float32 arrays when given dtype=float64 basically goes directly against the standard. So this would not only make things much easier for the test suite (which, outside of the tests for the creation functions themselves, implicitly assumes that they return the correct dtypes; and OTOH it already has logic for skipping undefined dtypes), but it is also better behavior for users of the array API. Otherwise, libraries like SciPy or scikit-learn using the array API would have to work around this behavior. Or more realistically, we would have to work around it in the compat layer. And given that jax.experimental.array_api basically is the compat layer for JAX, we should just do the right thing there.

Making things work for array API consumers should be the top priority. We can always adjust the test suite.

@jakevdp
Collaborator Author

jakevdp commented Jun 1, 2023

For what it's worth - my solution to this has been to only run the tests with JAX_ENABLE_X64=True.

@asmeurer

asmeurer commented Jun 2, 2023

For what it's worth - my solution to this has been to only run the tests with JAX_ENABLE_X64=True.

I guess libraries like scikit-learn and scipy would need to discuss whether this sort of solution works for them (or indeed whether this is an actual problem at all for them; I guess it depends if there's anywhere where they require float64). The problem with an environment variable is that it can only be set by an end-user.

@asmeurer

asmeurer commented Jun 2, 2023

I opened an upstream issue to track test suite support for non-mutation data-apis/array-api-tests#188

@NeilGirdhar
Contributor

In any case, it's not clear that JAX can be made compatible with the array-api at all, because nearly every test in array-api-tests fails when it tries to mutate an array here:

Has the Array API team considered adding mutable to the interface as described in @shoyer's comment? This would mean updating the test linked by Jake above to:

try:
    rm = result.mutable()
    rm[i] = val
except Exception as e:
    ...

and adding a mutable method to Jax arrays that exposes __setitem__.

@shoyer
Member

shoyer commented Jun 21, 2023

and adding a mutable method to Jax arrays that exposes __setitem__.

To be clear, my statement "I think we could make it work in JAX" in the linked comment was my own opinion, not one vetted with the JAX team.

In principle, the object returned by mutable() could look something like the following:

import dataclasses

import jax

@dataclasses.dataclass
class Mutable:
  value: jax.Array

  def __setitem__(self, key, value):
    self.value = self.value.at[key].set(value)

  def __jax_array__(self):
    return self.value

  # TODO(shoyer): implement array methods

That said, I'm not sure it would actually be worth the cognitive overhead of adding this into JAX. It might be just as sane to imagine adding x.at[i].set(y) into the array standard, as sugar for:

x = x.copy()
x[i] = y
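The copy-then-assign equivalence above can be written as a small functional helper. This is only a sketch: `set_at` is a hypothetical name, and it assumes the object supports `copy()` and item assignment (as Python lists and NumPy arrays do), which JAX arrays do not.

```python
def set_at(x, index, value):
    """Return a copy of x with x[index] replaced, leaving x untouched."""
    out = x.copy()
    out[index] = value
    return out


original = [1, 2, 3]
updated = set_at(original, 0, 10)
```

The caller always rebinds the result rather than relying on in-place mutation, which matches the `x = x.at[i].set(y)` pattern.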

@NeilGirdhar
Contributor

NeilGirdhar commented Jun 21, 2023

That said, I'm not sure it would actually be worth the cognitive overhead of adding this into JAX.

I think it might be worth elaborating on your idea a bit more completely to really evaluate it. Yes, the mutable method is additional cognitive overhead, but it saves the cognitive overhead of using the at method. I wonder what the Array API people would prefer: adding at or adding mutable?

@jakevdp
Collaborator Author

jakevdp commented Nov 14, 2023

One wart here is that the array API specifies that arr.device should return a device identifier... however jax.Array already has a device method, such that arr.device() returns a device identifier.

We work around this because the array API defines the expected device type as the type of the object returned by arr.device... so in JAX's case, canonically a device identifier is a Callable that when called with no arguments returns a Device!

A bit messy, but it works, and avoids us having to deprecate the device() method in favor of a property.

@asmeurer

The compat library already has a device() helper function https://github.com/data-apis/array-api-compat#helper-functions, so we could make it call x.device() on JAX arrays.

We work around this because the array API defines the expected device type as the type of the object returned by arr.device... so in JAX's case, canonically a device identifier is a Callable that when called with no arguments returns a Device!

The object returned by .device needs to be comparable with ==, and needs to be accepted as the device keyword for creation functions like ones, linspace, etc. If you can make these both work then there is no issue.

@jakevdp
Collaborator Author

jakevdp commented Nov 15, 2023

The object returned by .device needs to be comparable with ==

Oh well, there goes my clever hack... how bad is it in practice if device is a method rather than a property? That design choice was made a long time ago, and it would be quite painful to change.

@asmeurer

Well right now everyone using the array API is using the device() helper function from the compat library, because numpy doesn't even have .device yet.

@shoyer
Member

shoyer commented Nov 15, 2023

What about adding a dummy __call__ method that returns self to jax.Device?

@jakevdp
Collaborator Author

jakevdp commented Nov 15, 2023

What about adding a dummy __call__ method that returns self to jax.Device?

Yeah, I think that would be a good way forward. A deeper problem, though, is that in general jax arrays can be sharded across multiple devices, which seems to conflict with the core assumption of the Array API that each array lives on a single device.

@leofang

leofang commented Nov 15, 2023

I think it's DLPack assuming a single-device setting. For Array API, .device can be a logical device (that represents multiple physical devices over which the array is sharded). It's OK to use logical devices as long as they are comparable as Aaron noted.
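One way to picture a logical device that stays comparable is sketched below. `LogicalDevice` is purely hypothetical and is not how JAX represents shardings; it only shows that a value object wrapping a set of physical device ids supports `==` regardless of the order in which the ids are listed.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LogicalDevice:
    """Hypothetical logical device covering one or more physical devices."""

    physical_ids: frozenset

    @classmethod
    def of(cls, *ids):
        # Order does not matter: the ids are stored as a frozenset.
        return cls(frozenset(ids))


sharded_a = LogicalDevice.of(0, 1)
sharded_b = LogicalDevice.of(1, 0)
single = LogicalDevice.of(0)
```

Because the dataclass is frozen, instances are also hashable, so they could be used as dictionary keys or passed around as a `device` argument.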

@jakevdp
Collaborator Author

jakevdp commented Nov 15, 2023

Complicating this, it looks like there is a TODO to remove the existing jax.Array.device method:

jax/jax/_src/array.py

Lines 444 to 445 in 5c3da21

# TODO(yashkatariya): Remove this method when everyone is using devices().
def device(self) -> Device:

@jakevdp jakevdp force-pushed the array-api branch 7 times, most recently from ef5bed7 to c35eac2 Compare November 15, 2023 23:09
@jakevdp jakevdp marked this pull request as ready for review November 16, 2023 16:10
@jakevdp jakevdp requested a review from pschuh November 16, 2023 16:10
@jakevdp jakevdp added the pull ready Ready for copybara import and testing label Nov 16, 2023
@jakevdp jakevdp force-pushed the array-api branch 3 times, most recently from f11e740 to 421df9f Compare November 16, 2023 22:14
@copybara-service copybara-service bot merged commit 1fbcb24 into google:main Nov 16, 2023
14 of 15 checks passed
@jakevdp jakevdp deleted the array-api branch November 16, 2023 23:14