RFC: static vs. dynamic shapes and JAX's .at for simulating in-place ops #609

Open
rgommers opened this issue Mar 9, 2023 · 7 comments
Labels
Needs Discussion Needs further discussion. RFC Request for comments. Feature requests and proposed changes. topic: Dynamic Shapes Data-dependent shapes. topic: Indexing Array indexing.


@rgommers
Member

rgommers commented Mar 9, 2023

This is a continuation of a discussion that started a few weeks ago in gh-597 (Cc @soraros). It is closely related to gh-84 (boolean indexing) and gh-24 (mutability and copy/views).

I'll copy the content of @soraros's comment here in full:

Start of comment

I also think the problem is more fundamental than that. JAX is essentially a front-end for XLA, and the primitives provided by XLA (for now) require static shapes. So the line that actually goes wrong is:

>>> xs[ix_bool]
array([0, 2, 4])

Note that this code does work in JAX, though it is not jittable, because its output shape is not known in advance. For a moment, let's pretend that x[ix_bool] += 1 is syntactic sugar for x = x + where(ix_bool, 1, 0) (which works in JAX). The same problem appears when we want x[ix_bool] += [1, 3, 5]: we again need to know the shape of the right-hand side, which is equivalent to knowing the shape of xs[ix_bool] in the previous example.
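To make the static-shape point concrete, here is a small NumPy sketch. The inputs are assumptions (the original xs and ix_bool are not shown in this thread), chosen so that the boolean-indexing result matches the array([0, 2, 4]) output above. Only the where-based form has an output shape that follows from the input shape alone; the boolean-indexed form, and any right-hand side scattered through it, depends on how many mask entries are True.

```python
import numpy as np

# Hypothetical inputs chosen for illustration; the original xs and ix_bool are not shown.
x = np.arange(6)
ix_bool = x % 2 == 0

# Mask-based update: the result always has shape x.shape,
# no matter how many entries of ix_bool are True.
y = x + np.where(ix_bool, 1, 0)

# Boolean indexing: the result's shape depends on the *values* in ix_bool.
selected = x[ix_bool]

# Scattering a right-hand side through the mask only works if its length
# matches the (data-dependent) number of True entries, here 3.
vals = np.array([1, 3, 5])
z = x.copy()
z[ix_bool] += vals
```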

So what we really need to work around is the static-shape requirement (recall the need for a size parameter for nonzero), which is not exclusive to JAX.
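For reference, JAX's jnp.nonzero accepts a size argument (plus a fill_value used for padding) so that its output shape is static. The helper below is a hypothetical NumPy sketch of that idea for 1-D masks only; nonzero_static is not a NumPy or array API function. It pads or truncates the matching indices to a length fixed in advance.

```python
import numpy as np


def nonzero_static(mask, size, fill_value=0):
    """Hypothetical helper: indices of True entries in a 1-D mask, with a fixed output length.

    Matches beyond `size` are dropped; unused slots are filled with `fill_value`.
    """
    idx = np.flatnonzero(mask)[:size]
    out = np.full(size, fill_value, dtype=np.intp)
    out[: idx.shape[0]] = idx
    return out


mask = np.array([True, False, True, False, True, False])
padded = nonzero_static(mask, size=4, fill_value=-1)  # [0, 2, 4, -1]
truncated = nonzero_static(mask, size=2)              # [0, 2]
```

The price of a static shape is that the caller must pick size up front and handle the padding or the dropped matches.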

Now, for something a bit off-topic:

I think the JAX-style functional syntax a = a.at[...].set(...) for in-place operations looks (and arguably works) better than NumPy's, and I'd really like to have it in the array API. Some pros:

  • Looks familiar, and simulates the feel of in-place operation just fine.
  • Makes it clear that nothing is modified. This restricted access pattern would work with any accelerator-backed system. I think it would aid static analysis in systems like Numba as well.
  • More concise, can be chained, and sometimes expresses our intention better.
a = zeros(m)       # initialising a
a[I] += arange(n)  # semantically, still initialising a

# VS

# being concise here is not the important point
# this line becomes a "semantic block" for initialisation
a = zeros(m).at[I].add(arange(n))  # initialising a
  • Can specify the indexing mode (more) easily.
# I think these are fairly cumbersome to represent in `numpy`, as we don't have kwargs for __getitem__
b = a.at[I].add(val, unique_indices=True)     # important info for accelerators
c = b.at[J].get(mode='fill', fill_value=nan)  # sure, we have `take`, but this is uniform and cool
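The proposed semantics can be emulated in NumPy to see how they differ from in-place updates. The at() function and its helper classes below are a hypothetical sketch, not JAX's implementation or any library's API, and they ignore keywords like unique_indices or mode. Each method returns a new array, and add() uses np.add.at, which accumulates repeated indices the way JAX documents for .at[...].add; NumPy's own a[I] += v applies the update to a repeated index only once.

```python
import numpy as np


class _AtUpdate:
    """Pending update for one index expression; every method returns a new array."""

    def __init__(self, arr, idx):
        self._arr = arr
        self._idx = idx

    def set(self, values):
        out = self._arr.copy()
        out[self._idx] = values
        return out

    def add(self, values):
        out = self._arr.copy()
        # np.add.at applies every update, so repeated indices accumulate.
        np.add.at(out, self._idx, values)
        return out

    def get(self):
        return self._arr[self._idx]


class _AtIndexer:
    def __init__(self, arr):
        self._arr = arr

    def __getitem__(self, idx):
        return _AtUpdate(self._arr, idx)


def at(arr):
    """Hypothetical stand-in for JAX's x.at property (not a real NumPy API)."""
    return _AtIndexer(arr)


a = np.zeros(5)
idx = np.array([0, 0, 2])

b = at(a)[idx].add(1.0)  # repeated index 0 accumulates: [2, 0, 1, 0, 0]
c = at(b)[1].set(7.0)    # [2, 7, 1, 0, 0]; b is left unchanged
g = at(c)[idx].get()     # [2, 2, 1]

d = np.zeros(5)
d[idx] += 1.0            # NumPy in place: index 0 is incremented once: [1, 0, 1, 0, 0]
```

The difference between b and d is exactly the kind of semantic detail that a flag such as unique_indices lets an accelerator backend exploit.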

Some of my thoughts

  • The last two lines of code, which annotate getitem/setitem-like operations with information for accelerators, make an argument that hasn't been made before. If that's something we'd want to support, then this is one way to do it. A context manager would be another, as would Numba's approach (e.g., a boundscheck keyword to @njit).
  • As discussed in Copy-view behaviour and mutating arrays #24, the x = x.at[... syntax and NumPy et al.'s in-place support are completely equivalent when you have a JIT, and NumPy's version is more efficient if you don't, as long as you can guarantee that you are not modifying a view. NumPy's syntax is also arguably nicer: more concise and more familiar. So, from that perspective, .at isn't ideal.
  • It seems like we do need better static shape support though. The dynamic shape support is marked as optional in the standard, so what's the alternative?

The last point is important. Writing generic code is currently difficult when you need to, e.g., update values with a mask. Doing that only the JAX way seems like a nonstarter, because it's far too inefficient for NumPy et al. The question, though, is whether there's something that would work for JAX, TF and Dask. Dask also struggles to some extent with dynamic shapes, although most of it now works (xref dask/dask#2000 and dask/dask#7393). @jakirkham any thoughts on whether you need anything more (possibly JAX-like) for dynamic shape support in Dask?
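One pattern that already works across array API libraries for the mask-update case is to express it with where, which keeps the output shape equal to the input shape. The helper name masked_set is hypothetical, and the sketch uses NumPy as the namespace xp. Its trade-off is the one raised above: it always materializes a full-size result instead of updating in place.

```python
import numpy as np


def masked_set(xp, x, mask, value):
    """Hypothetical helper: a copy of x with the entries selected by mask set to value.

    Only xp.where is used, so the result always has shape x.shape and
    nothing is mutated; the cost is one full-size temporary array.
    """
    return xp.where(mask, xp.asarray(value, dtype=x.dtype), x)


x = np.arange(6)
out = masked_set(np, x, x % 2 == 0, -1)  # [-1, 1, -1, 3, -1, 5]; x is unchanged
```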

@rgommers
Member Author

rgommers commented Mar 9, 2023

For completeness, let me copy the comparison between the two syntaxes from https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html:

[Image: syntax comparison table from the JAX jax.numpy.ndarray.at documentation]

@asmeurer
Member

asmeurer commented Mar 9, 2023

I'm curious if the JAX developers have thought about what would be needed for Python's syntax to allow the more readable x += x[i] but to still have the same functionality. Maybe https://peps.python.org/pep-0637/ (keyword arguments in getitem)? Or would you also need more than that (like a += walrus operator or something, I don't know)?

@jakevdp

jakevdp commented Jun 21, 2023

For what it's worth, x += x[i] does work in JAX, and is compatible with JIT. JAX arrays don't override __iadd__, so Python falls back to essentially x = x + x[i].
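The fallback described here is plain Python behaviour: if a type defines __add__ but not __iadd__, x += y is evaluated as x = x + y, which rebinds the name instead of mutating the object. The toy class below is hypothetical and only stands in for an immutable array.

```python
class FrozenVec:
    """Toy immutable vector: defines __add__ but deliberately not __iadd__."""

    def __init__(self, data):
        self._data = tuple(data)

    def __add__(self, other):
        return FrozenVec(a + b for a, b in zip(self._data, other._data))

    def tolist(self):
        return list(self._data)


x = FrozenVec([1, 2, 3])
alias = x                     # second reference to the original object
x += FrozenVec([10, 10, 10])  # no __iadd__, so Python runs x = x + FrozenVec(...)

print(x.tolist())      # [11, 12, 13]
print(alias.tolist())  # [1, 2, 3]: the original object was not mutated
```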

Since JAX arrays do not have mutable view semantics, this is not at all problematic.

@shoyer
Contributor

shoyer commented Jun 21, 2023

> I'm curious if the JAX developers have thought about what would be needed for Python's syntax to allow the more readable x += x[i] but to still have the same functionality. Maybe https://peps.python.org/pep-0637/ (keyword arguments in getitem)? Or would you also need more than that (like a += walrus operator or something, I don't know)?

Seems like the tricky case would be x[i] += y. This fails silently with NumPy if x[i] does not create a view. In contrast, x.at[i].add(y) always works.
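The view/copy distinction behind this is easy to see in NumPy: basic slicing returns a view, while integer-array (fancy) indexing returns a copy. A direct x[i] += y goes through __setitem__ and does update x, but an update applied to an intermediate copy is silently lost. This sketch illustrates that hazard in general; it is not necessarily the exact case meant above.

```python
import numpy as np

x = np.zeros(4)
v = x[1:3]      # basic slice: a view into x
v += 1          # updates x through the view -> x == [0, 1, 1, 0]

c = x[[0, 3]]   # fancy indexing: a copy
c += 1          # c == [1, 1], but x is unchanged

x[[0, 3]] += 1  # direct augmented assignment calls x.__setitem__ -> x == [1, 1, 1, 1]

y = np.zeros(4)
y[[0, 3]][0] += 5  # the update lands in a temporary copy; y stays all zeros
```

With a functional x.at[i].add(y), there is no intermediate object whose view-or-copy status matters: the result is always a new array.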

@kgryte kgryte changed the title Static vs. dynamic shapes and JAX's .at for simulating in-place ops RFC: static vs. dynamic shapes and JAX's .at for simulating in-place ops Apr 4, 2024
@kgryte kgryte added RFC Request for comments. Feature requests and proposed changes. topic: Indexing Array indexing. Needs Discussion Needs further discussion. labels Apr 4, 2024
@lucascolley
Contributor

> Seems like the tricky case would be x[i] += y. This fails silently with NumPy if x[i] does not create a view. In contrast, x.at[i].add(y) always works.

This seems to be the only unsolved problem for testing JAX arrays inside SciPy over at scipy/scipy#20085. Many of these cases already occur in the small portion of the code base that has been ported to array API compatibility.
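One possible workaround in library code is a small helper that dispatches on whether the array type offers a functional .at interface. The sketch below is hypothetical (set_at is not a SciPy or array API function): for arrays exposing .at, as JAX arrays do, it returns x.at[idx].set(value); otherwise it optionally copies and then assigns in place, which keeps NumPy on its efficient path when mutation is acceptable.

```python
import numpy as np


def set_at(x, idx, value, *, copy=True):
    """Hypothetical helper: return x with x[idx] = value for eager or functional arrays."""
    at = getattr(x, "at", None)
    if at is not None:
        # Functional arrays (e.g. JAX) expose .at[idx].set(), which returns a new array.
        return at[idx].set(value)
    # Mutable arrays (e.g. NumPy): copy first unless the caller allows in-place mutation.
    out = x.copy() if copy else x
    out[idx] = value
    return out


x = np.arange(5)
y = set_at(x, 2, 9)              # [0, 1, 9, 3, 4]; x is untouched
z = set_at(x, 0, 7, copy=False)  # mutates and returns x itself
```

Callers must still use the return value, since the functional branch never mutates its input.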

@ogrisel

ogrisel commented Aug 9, 2024

Out of curiosity, I tried to run the existing array API test suite of scikit-learn with JAX, and many of the tests failed because of JAX's in-place assignment limitation, which makes this mostly useless in its current state:

@lucascolley
Contributor

For anyone following this issue but not my SciPy JAX PR, scipy/scipy#20085 (comment) is (I think) as far as we got with this.

7 participants