[FEATURE]: Implement support for Python Array API Standard. #1244

Open
imh opened this issue Feb 1, 2024 · 19 comments

imh commented Feb 1, 2024

Description

JAX includes a NumPy-compatible jax.numpy module which has a bunch of nice features (automatic differentiation, JIT compilation, vectorized mapping, GPU runtime, JS export). They've taken great pains to make sure it's usually as simple as swapping import numpy as np for import jax.numpy as np. Likewise (but less extensively) for the jax.scipy module.

I'd like to do some optimization for which it would be really convenient to automatically differentiate some of the great stuff you've implemented and export it to JS. It should be as simple as changing the import numpy as np statements around the library:

if HOWEVER_WE_SET_THE_CONFIG:
    import jax.numpy as np
else:
    import numpy as np

Changing the type signatures offers more degrees of freedom to choose from, but is basically the same idea.

I'd be happy to implement it, but don't want to make a PR that you don't want.

I expect that the added maintenance burden would be pretty minimal.

imh added the Feature label on Feb 1, 2024
KelSolaar (Member) commented Feb 2, 2024

Hi imh,

This is a great topic, and I was thinking about giving JAX another crack just a few weeks ago, so you are quite on point here! :)

I have always wanted to provide GPU support in Colour, and the first thing I tried 3.5 years ago was JAX. You will even see some discussions I had with the developers here: google/jax#3689. During GSoC we ended up using CuPy as it was much more mature. The main issue with it is that it is tied to NVIDIA GPUs, and I don't have a Windows machine, nor is it trivial to test with them, so the PR is parked for now; it did highlight a few interesting things though.

Suffice to say that I actually did start prototyping with Jax: 69f8d3b#diff-3c8cad174551c24c304eefcdbb2a880c3c7e4ef58f83df1d85af343a6fd3f194R1

One of the core problems was that the JAX import was incredibly slow, but besides that it kind of worked. I'm certainly happy to give that a new crack, but we would go down the backend road, as it would allow us to swap JAX for CuPy or something else.

Keen to hear your thoughts!

imh (Author) commented Feb 2, 2024

Okay, after posting this, I did some reading about NumPy's NEPs and other libraries' takes on this kind of thing. tl;dr: there's a long history of proposals, with the Python array API standard finally gaining traction, which may be a good candidate here:

Instead of this:

def user_facing_function(x: ArrayLike):
    return np.exp(x)

we do this:

def user_facing_function(x: ImNotSureProbablyAProtocol):
    xp = get_namespace(x)
    return xp.exp(x)

Because support for the standard is still experimental in most libraries, get_namespace should depend on an opt-in config, along these lines:

def get_namespace(*xs):
    if not USE_ARRAY_API:
        return np
    
    # The rest of this would be exactly the implementation suggested in NEP 47
    # (https://numpy.org/neps/nep-0047-array-api-standard.html#appendix-a-possible-get-namespace-implementation),
    # or it would use https://github.com/data-apis/array-api-compat
    # as `return array_api_compat.array_namespace(x, y)`.
    ...
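
For concreteness, a minimal sketch of that elided part using array-api-compat (with USE_ARRAY_API as a hypothetical module-level opt-in flag) could look like:

import numpy as np
import array_api_compat

USE_ARRAY_API = False  # hypothetical opt-in flag, e.g. set from configuration


def get_namespace(*xs):
    # Default path: behave exactly as today and hand back plain NumPy.
    if not USE_ARRAY_API:
        return np
    # Opt-in path: let array-api-compat resolve the standard-compliant
    # namespace shared by all input arrays (it raises if libraries are mixed).
    return array_api_compat.array_namespace(*xs)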

Again, I'm happy to implement this if you agree with the direction.


In more detail:

The design choices appendix in NEP 37 lays out a lot of the accumulated thinking around them (opt-in vs opt-out, explicit vs implicit, local vs non-local vs global), which is probably useful for what Colour might choose to do. To make Colour composable with user code, local control at call time via xp = np.get_array_module(...) or xp = np.get_namespace(...) seems preferable to a config flag.

NEP 47 (the array API standard) seems to be the one with the most traction across the numeric python ecosystem, so I'm recommending that.

So far, both the array API standard and the implementations in NumPy, JAX, CuPy, Dask, etc. are still experimental, so it seems good to keep it opt-in, as sklearn is currently doing.

However, static typing could be more complex if Colour does this (see the static typing design doc in the API spec). A Protocol should suffice, but it may not be as clean as what you have right now.
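
For illustration, a minimal sketch of such a protocol (the name is hypothetical; the standard specifies an __array_namespace__ method on compliant arrays, although today you would still need array-api-compat for libraries that don't expose it yet):

from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class SupportsArrayNamespace(Protocol):
    # Hypothetical protocol: any object exposing the standard's
    # `__array_namespace__` method is accepted.
    def __array_namespace__(self, *, api_version: Optional[str] = None) -> Any:
        ...


def user_facing_function(x: SupportsArrayNamespace) -> Any:
    xp = x.__array_namespace__()
    return xp.exp(x)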

This way does increase the complexity and maintenance burden a bit, relative to just changing imports, but still seems worthwhile.

KelSolaar (Member) commented:

Thanks for digging deeper, I came across some of those in recent years but admittedly haven't checked recently!
The sklearn approach seems sensible. A cursory look at skimage did not show anything going in that direction though, I might have missed it!

Most of our functions start with a call to the colour.utilities.as_float_array and colour.utilities.as_int_array definitions. They could be a good entry point for the namespace jazz, as sketched below. I'm certainly happy for you to give it a go!
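
A rough sketch of that entry-point idea (not the actual colour.utilities implementation; the xp parameter, the dtype handling and the reuse of the get_namespace sketch above are all placeholders):

def as_float_array(a, xp=None):
    # Placeholder sketch: resolve the namespace from the input (falling back
    # to NumPy for array-likes) and coerce to a float dtype.
    xp = xp if xp is not None else get_namespace(a)
    return xp.asarray(a, dtype=xp.float64)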

Paging @tjdcs for VIS!

Cheers,

Thomas

tjdcs (Contributor) commented Feb 6, 2024

Thanks, you know I love performance.

Someone told me the other day that jaxtyping supports annotating array shapes. I'm not sure if that is true, but it would be very helpful.

https://jax.readthedocs.io/en/latest/jax.typing.html

One thing that concerns me long term for maintainability is putting a lot of configurable dependency switches in. Maybe we should consider just completely switching... or depending on both... I'm not sure the latter is really the best, but it might be required anyway.

imh (Author) commented Feb 8, 2024

FYI I'm not dropping this, but I'll return to it once data-apis/array-api-compat#83 is in, since it seems like the simplest, closest to "standard" way to do it.

lucascolley commented Feb 9, 2024

One thing that concerns me long term for maintainability is putting a lot of configurable dependency switches in.

It's worth mentioning that writing "array-agnostic" code like @imh's example with xp = get_namespace(...) (get_namespace is now a legacy alias for array_namespace, btw) does not introduce a dependency on other libraries like JAX. The idea is that all of the functions that are needed are contained in the xp namespace, which comes from the input arrays rather than being imported at this level. A namespace that complies with the standard will 'just work', without Colour even needing to know which library it is.

For now, the only additional dependency would be array-api-compat, which bridges the gap between the current implementations and full compliance with the standard, but eventually that will not be needed and the namespaces will be accessible directly via x.__array_namespace__.

If you'd like more info on the standard, see this talk from SciPy 2023; my blog post describing how we're using it in SciPy might also be helpful to see the perspective of an 'array-consumer' library (and feel free to ask me any questions!).

KelSolaar (Member) commented:

Thanks @lucascolley, fantastic insight! Will read/watch your links.

KelSolaar (Member) commented:

I watched the presentation and that was super helpful! I foresee a few issues, and I'm not sure how to solve them with what array-api-compat offers:

Functions Accepting Scalar Input (or ArrayLike)

We have plenty of functions that accept scalar (and more generally ArrayLike) input and this obviously does not work:

import array_api_compat


def gamma_function(x, y):
    xp = array_api_compat.array_namespace(x, y)

    return xp.power(x, 1 / y)


gamma_function(0.18, 2.2)
gamma_function(0.18, [2.0, 2.1, 2.2, 2.3])

It means that we would now need to do this:

import numpy as np
import array_api_compat


def gamma_function(x, y):
    xp = array_api_compat.array_namespace(x, y)

    return xp.power(x, 1 / y)

gamma_function(np.array(0.18), np.array(2.2))
gamma_function(np.array(0.18), np.array([2.0, 2.1, 2.2, 2.3]))

This would be a significant loss in user experience; not only that, it would break a TON of downstream dependent code. This slide surprised me because it seems like the focus was put on the low-level libraries and not on the libraries that are using them, e.g., Colour:

(slide image)

Datasets

We have a lot of datasets, matrices of all sorts that are imported:

CAT_VON_KRIES: NDArrayFloat = np.array(
    [
        [0.4002400, 0.7076000, -0.0808100],
        [-0.2263000, 1.1653200, 0.0457000],
        [0.0000000, 0.0000000, 0.9182200],
    ]
)
"""
*Von Kries* chromatic adaptation transform.

References
----------
:cite:`CIETC1-321994b`, :cite:`Fairchild2013ba`, :cite:`Lindbloom2009g`,
:cite:`Nayatani1995a`
"""

Those use NumPy at the moment, but how do we handle them with array_api_compat? This was actually one of the big problems when we tried to adopt CuPy.

Is there a better place to discuss those? They might merit more visibility as I'm sure we are not/will not be the only package with such questions!

lucascolley commented Feb 10, 2024

We have plenty of functions that accept scalar (and more generally ArrayLike) input and this obviously does not work

Here's how we do it in SciPy: we want to keep accepting array-likes (which often end up being turned into np arrays during computation) with the NumPy backend, but not for alternative backends. If you want to use a JAX array, you must pass a proper array. But other unrecognised inputs are used with the NumPy backend.

We do that by wrapping array_namespace as follows:

import os

import numpy as np
import array_api_compat
import array_api_compat.numpy as np_compat  # NumPy, wrapped for standard compliance
from array_api_compat import is_array_api_obj

# Experimental opt-in flag, read from the SCIPY_ARRAY_API environment variable.
_GLOBAL_CONFIG = {"SCIPY_ARRAY_API": os.environ.get("SCIPY_ARRAY_API", False)}


def array_namespace(*arrays):
    if not _GLOBAL_CONFIG["SCIPY_ARRAY_API"]:
        # here we could wrap the namespace if needed
        return np_compat

    arrays = [array for array in arrays if array is not None]

    arrays = compliance_scipy(arrays)

    return array_api_compat.array_namespace(*arrays)

def compliance_scipy(arrays):
    for i in range(len(arrays)):
        array = arrays[i]
        if isinstance(array, np.ma.MaskedArray):
            raise TypeError("Inputs of type `numpy.ma.MaskedArray` are not supported.")
        elif isinstance(array, np.matrix):
            raise TypeError("Inputs of type `numpy.matrix` are not supported.")
        if isinstance(array, (np.ndarray, np.generic)):
            dtype = array.dtype
            if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
                raise TypeError(f"An argument has dtype `{dtype!r}`; "
                                f"only boolean and numerical dtypes are supported.")
        elif not is_array_api_obj(array):
            try:
                array = np.asanyarray(array)
            except TypeError:
                raise TypeError("An argument is neither array API compatible nor "
                                "coercible by NumPy.")
            dtype = array.dtype
            if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)):
                message = (
                    f"An argument was coerced to an unsupported dtype `{dtype!r}`; "
                    f"only boolean and numerical dtypes are supported."
                )
                raise TypeError(message)
            arrays[i] = array
    return arrays

This means that array_api_compat.numpy is always returned unless the experimental environment variable is set, maintaining current behaviour for all array types for the time being. When we do set the env variable, we try to coerce array-likes with np.asanyarray before passing to array_api_compat.

This treatment of NumPy as the default backend allows you to have two code branches in functions, based on whether xp is (the compat version of) NumPy. In SciPy, since we have a lot of compiled code which only works with NumPy for now, this is needed quite a lot, but for pure Python + NumPy code, most of it should be easily convertible to array-agnostic code.
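
A sketch of what such a branch could look like, reusing the array_namespace wrapper above (logistic is just a stand-in function):

import numpy as np
import array_api_compat.numpy as np_compat


def logistic(x):
    xp = array_namespace(x)  # the SciPy-style wrapper defined above
    if xp is np_compat:
        # NumPy branch: array-likes were accepted, so coerce them here; this
        # is also where NumPy-only features or compiled code could live.
        x = np.asarray(x)
    # Array-agnostic branch: standard APIs only, works for any compliant xp.
    return 1 / (1 + xp.exp(-x))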

We have a lot of datasets, matrices of all sorts that are imported (using np.array)

While not guaranteed by the standard, every library we have worked with so far can coerce np arrays with xp.asarray. That is good enough for now - at some point in the future, when the goal is really to be as portable as possible, the standard includes some specification of device-interchange with DLPack.

If performance overhead is the concern, I suppose (just brainstorming) you could have a function like get_dataset(dataset_name, xp) which returns xp.asarray({dataset data}). As mentioned above, this would complicate things for static typing, so converting from NumPy may be the best option for now. Thinking more, this would still involve a device copy. If you want to boost performance for a specific library, the best thing would probably be to use that library's creation functions conditionally on the namespace and device of the input.
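
A sketch of that brainstormed helper (hypothetical; CAT_VON_KRIES as defined above):

def get_dataset(name, xp):
    # Hypothetical helper: look up the module-level NumPy constant and
    # promote it into the caller's namespace (this may incur a device copy).
    datasets = {"CAT_VON_KRIES": CAT_VON_KRIES}
    return xp.asarray(datasets[name])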

lucascolley commented Feb 10, 2024

Is there a better place to discuss those? They might merit more visibility as I'm sure we are not/will not be the only package with such questions!

There definitely needs to be some guidance on topics like this (perhaps a follow-up to my blog post), but it is still very early days in adoption. Dare I say you are a bit ahead of the curve here 😉 - packages which depend on the "low-level libraries" will have to wait for those libraries to become compatible before being able to become fully compatible themselves.

In SciPy, we are still figuring out the best way to go about things, and will likely upstream tools/helpers which are generally useful across the NumPy ecosystem to a new repo at some point. Once the foundations are in place and settled in the core libraries, it will be easier to give advice to downstream packages.

That said, feel free to open an issue on the array-api-compat or array-api repos if you'd prefer to discuss over there!

tjdcs pinned this issue on Feb 12, 2024
tjdcs (Contributor) commented Feb 12, 2024

I've gone ahead and pinned this issue. In general we are heavy users of array "stuff"; it's probably 90% of our code in some way, and I think this is a really important and exciting path. Being able to support multiple back-ends would be very powerful.

With that said... I also don't want the perfect to be the enemy of the good. Maybe it's worth continuing to pursue JAX integration and updating more of our NumPy-specific code? Thoughts, @KelSolaar @imh?

KelSolaar (Member) commented:

Thanks @lucascolley!

@tjdcs: The SciPy approach looks sensible to me as it ensures that we are not breaking backward compatibility, which is something I would like to avoid at all costs. There is too much code dependent on Colour for us not to be a bit careful. It seems like the stars are better aligned to start working on that compared to 4 years ago, so I'm keen to explore that more for sure.

imh (Author) commented Feb 13, 2024

+1 to the thanks, @lucascolley!

As far as array-like datasets go, it seems like they should be promoted into whatever the user is using, since they originate outside array-api-land. For example, Pointer's gamut remains a NumPy array where it's defined, and if the user passes in JAX arrays, then the gamut constant gets promoted from a NumPy array to a JAX array.

For functions accepting array-likes, it's too bad we can't reasonably tell which arguments were deliberately set by the user and which are just default parameters; otherwise we'd probably still want to promote to whatever the user deliberately passed in. We could approximate that with an array_namespace that promotes as follows (pure Python --> NumPy --> other array API). For example:

  • array_namespace(jax_array, python_list) = jax api
  • array_namespace(dask_array, numpy_array) = dask api
  • array_namespace(python_list, numpy_array) = numpy api
  • array_namespace(jax_array, cupy_array) = error

I have a nebulous bad feeling about that hiding user errors though, so we could alternatively just promote (pure Python --> array API) and handle constants separately (see the sketch after this list):

  • array_namespace(jax_array, python_list) = jax api
  • array_namespace(dask_array, numpy_array) = error
  • array_namespace(python_list, numpy_array) = numpy api
  • array_namespace(jax_array, cupy_array) = error
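
A rough sketch of that second promotion rule (hypothetical helper; is_array_api_obj comes from array-api-compat):

import numpy as np
import array_api_compat


def colour_array_namespace(*xs):
    # Keep only genuine array API objects; pure-Python scalars and lists are
    # promoted to whatever namespace the real array arguments define.
    arrays = [x for x in xs if array_api_compat.is_array_api_obj(x)]
    if not arrays:
        # Everything was pure Python: fall back to NumPy.
        return np
    # Mixing two different array libraries (e.g. JAX and CuPy) still raises.
    return array_api_compat.array_namespace(*arrays)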

lucascolley commented Feb 13, 2024

In SciPy it is simpler. If you want one of your array inputs to be a JAX array, they all must be (see e.g. scipy/scipy#18286 (comment)).

These array inputs tend not to have default parameters. Where they do, it should be possible to change the default to a new sentinel object in a backwards-compatible way, so that you can distinguish between default input and genuine Python-list / None input.
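
For example, a sketch of that sentinel-default pattern (weighted_mean is just an illustrative function):

_DEFAULT = object()  # sentinel: distinguishes "argument not passed" from None or a list


def weighted_mean(x, weights=_DEFAULT):
    if weights is _DEFAULT:
        # Default input: derive the namespace from `x` alone.
        xp = array_namespace(x)
        weights = xp.ones_like(x)
    else:
        # Genuine user input: it must be a proper array from the same library.
        xp = array_namespace(x, weights)
    return xp.sum(x * weights) / xp.sum(weights)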

lucascolley commented:

It would be worth giving scipy/scipy#18286 a look if you are considering going down the array API route. Clearly there is a lot more complexity for a package like SciPy, but it might be helpful to spell out the overall aims/strategy.

asmeurer commented Mar 15, 2024

A few comments on some things discussed here:

This would be a significant loss in user experience; not only that, it would break a TON of downstream dependent code. This slide surprised me because it seems like the focus was put on the low-level libraries and not on the libraries that are using them, e.g., Colour:

I wouldn't worry too much about this. The data I discussed on this slide was primarily used to get a reasonable list of APIs for inclusion in the initial version of the standard. Since then, more APIs and behaviors have been added to the standard based on user feedback. If something you need is missing, open an issue in the array-api repo.

While not guaranteed by the standard, every library we have worked with so far can coerce np arrays with xp.asarray. That is good enough for now - at some point in the future, when the goal is really to be as portable as possible, the standard includes some specification of device-interchange with DLPack.

asarray does support inputs that support the buffer protocol. There's also DLPack, which is probably preferable, as the buffer protocol is CPU-only (see https://data-apis.org/array-api/latest/design_topics/data_interchange.html).

I'm curious how scipy handles this. Does scipy not also have hard-coded data, or is that only in C and Fortran codes like fft?

We have plenty of functions that accept scalar (and more generally ArrayLike) input and this obviously does not work:

The way I see it is this: if a user is using a library that doesn't accept array-like inputs (for example, torch functions generally do not allow lists as inputs), then they should already have this expectation that everything should be an array first, and will carry that expectation to colour. If they are using a library like numpy that does allow it, then all the functions in array-api-compat will allow that (array-api-compat tries to maintain library behaviors that aren't required by the standard), so passing a list will work. That's with the minor caveat that you would need a wrapper like scipy's to make list-only arguments to array_namespace default to numpy (maybe we should add a default_library flag to array_namespace).

Of course, this does mean your internal function calls and tests will need to be a little more rigorous about calling asarray first. But that's really just a special case of a general fact, which is that if you want to support the array API in Colour, you will need to use only array-API-compatible APIs and behaviors everywhere. Functions only accepting arrays as inputs is one instance of that, but there are many others that you would need to update your code for as well (this slide from my presentation gives an idea of the sorts of changes typically required).
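
As one illustrative example of such a change (not taken from the slide): NumPy's concatenate is spelled concat in the standard, so array-agnostic code has to use the standard name:

import array_api_compat


def stack_and_square(a, b):
    xp = array_api_compat.array_namespace(a, b)
    # `concat` is the standard spelling of NumPy's `concatenate`; operators
    # such as `**` are part of the standard and work across namespaces.
    return xp.concat([a, b]) ** 2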

lucascolley commented:

I'm curious how scipy handles this

I'm not sure whether this has come up in public API functions yet. In tests, we just create NumPy arrays and convert (copying if a different device is involved) with xp.asarray.

KelSolaar changed the title from "[FEATURE]: Allow jax.numpy as an alternative to numpy" to "[FEATURE]: Implement support for Python Array API Standard." on Apr 6, 2024
KelSolaar (Member) commented:

@imh : Have you put some more thoughts into all this by any chance?

lucascolley commented:

I had a brief look into the Colour codebase today. My one thought was that you may want to wait for a resolution to data-apis/array-api#589, given that everything here is statically typed.

I also didn't spy any parametrized tests, which are the easiest way to test multiple xp backends without duplicating test code. Not sure whether this would require anything extra to work with unittest.
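
For reference, a minimal pytest-style sketch of what that could look like (the backend list and test are hypothetical; gamma_function is the example from earlier in this thread):

import array_api_compat.numpy
import pytest

# Hypothetical backend list; other compliant namespaces could be appended
# once they are supported.
XP_BACKENDS = [array_api_compat.numpy]


@pytest.mark.parametrize("xp", XP_BACKENDS)
def test_gamma_function(xp):
    x = xp.asarray([0.18, 0.5])
    result = gamma_function(x, xp.asarray(2.2))
    assert result.shape == x.shape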
