Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Array API backend #317

Merged
merged 9 commits into from
Nov 23, 2023
Merged

Array API backend #317

merged 9 commits into from
Nov 23, 2023

Conversation

tomwhite
Copy link
Member

Fixes #315

This uses https://github.com/data-apis/array-api-compat which makes NumPy conform to the array API. There are a couple of places where we use functions not in the array API (which fall back to NumPy): take_along_axis for arg reductions, and nan functions.

@@ -19,6 +19,7 @@
float64,
)
from cubed.array_api.linear_algebra_functions import matmul
from cubed.backend_array_api import namespace as nxp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not simply import namespace as np? This would minimize the diff.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would, but I wanted to signal that this may not be regular NumPy. By using nxp here, it makes it easy to search for np to find places that are still using regular NumPy.

@@ -0,0 +1,23 @@
import numpy as np
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm excited to see this as simple place to drop in a Jax array implementation :) Great work.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be interesting to see if this works with the recent work in JAX on the Array API.

@tomwhite
Copy link
Member Author

Thanks for taking a look at this @alxmrs!

This has conflicts that I need to resolve, but it also conflicts with the work I've been doing for TensorStore in #187, which is why I opened #322. But maybe it's OK to merge this, and figure out the TensorStore part later?

@TomNicholas
Copy link
Collaborator

This uses https://github.com/data-apis/array-api-compat which makes NumPy conform to the array API.

You might be able to avoid the extra dependency on array-api-compat now that numpy v1.26.0 introduced numpy.array_api.

@TomNicholas
Copy link
Collaborator

But maybe it's OK to merge this, and figure out the TensorStore part later?

I vote merge this and think first about generalizing the array type in cases that don't also require changing the storage layer (e.g. cupy/sparse/pint).

@tomwhite
Copy link
Member Author

tomwhite commented Nov 21, 2023

This uses https://github.com/data-apis/array-api-compat which makes NumPy conform to the array API.

You might be able to avoid the extra dependency on array-api-compat now that numpy v1.26.0 introduced numpy.array_api.

NumPy 1.22 introduced numpy.array_api, 1.26 updates it to the latest spec. The reason for using the compat layer is that it falls back to some functions that we use that are not in the standard yet (like take_along_axis).

If we're worried about the extra dependency then array-api-compat can be vendored, but I'm not sure that's necessary.

@tomwhite tomwhite merged commit e87671d into main Nov 23, 2023
7 checks passed
@tomwhite tomwhite deleted the array-api-backend branch November 23, 2023 12:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Use array API for internal array operations
3 participants