
Add support for computing the cumulative sum to the standard #597

Closed
steff456 opened this issue Feb 14, 2023 · 20 comments · Fixed by #653
Labels: API extension (Adds new functions or objects to the API.)
@steff456
Member

This RFC requests the addition of a new API to the array API specification for computing the cumulative sum.

Overview

Based on array comparison data, the API is available in all the libraries in the PyData ecosystem.

Prior art

Proposal:

def cumsum(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[dtype] = None) -> array
  • dtype kwarg is for consistency with sum et al
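
For illustration, a minimal usage sketch of the proposed signature, with NumPy's existing cumsum standing in for the proposed function (the exact semantics would be pinned down by the spec):

import numpy as np

x = np.asarray([[1, 2, 3],
                [4, 5, 6]])

np.cumsum(x, axis=1)               # -> [[1, 3, 6], [4, 9, 15]]
np.cumsum(x, axis=1, dtype=float)  # same values, accumulated in float64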

cc @oleksandr-pavlyk

steff456 added the API extension label (Adds new functions or objects to the API.) on Feb 14, 2023
@asmeurer
Member

There's also np.add.accumulate, which may be relevant when looking at the comparison data.
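
For reference, a quick sketch of how the two relate (both are existing NumPy APIs):

import numpy as np

x = np.array([1, 2, 3, 4])
np.add.accumulate(x)  # -> array([ 1,  3,  6, 10])
np.cumsum(x)          # -> array([ 1,  3,  6, 10]), same result for 1-D input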

@rgommers
Member

Based on previous experience with complaints about bad naming (see, e.g., scipy/scipy#12924 and https://www.reddit.com/r/programminghorror/comments/j6sd61/i_was_just_looking_at_the_documentation_for/), I would very much prefer not to enshrine the names cumsum and cumprod (the latter proposed in gh-598) in this API standard.

This is pretty niche functionality, and arguably not "core" enough to the implementation of an array library for it to be in this standard at all, so I'd vote for leaving it out completely. In a compat layer they could perhaps be named cumulative_*, like SciPy did with a few methods, if it's desired for these functions to be there.

@soraros
Contributor

soraros commented Feb 16, 2023

Niche as it may be, I still think cumsum is pretty useful, as many indexing tricks depend on it. One such trick is turning "pauses" into "stairs":

p = jnp.zeros(10, int).at[jnp.array([1, 4, 8])].set(1)
# [0 1 0 0 1 0 0 0 1 0]
s = p.cumsum()
# [0 1 1 1 2 2 2 2 3 3]

This is especially true when one works with a statically-shaped system like JAX, where these tricks are more or less required. Its uses in the implementations of jax.numpy.nonzero and jax.numpy.repeat are pretty typical. These are also good examples: ex1, ex2, ex3.
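
As an illustration of the kind of trick meant here, a minimal NumPy sketch of a repeat built on cumsum (repeat_via_cumsum is an illustrative name, not how jax.numpy.repeat is actually implemented; it assumes all counts are positive):

import numpy as np

def repeat_via_cumsum(values, counts):
    # mark where each new run of repeated values starts ...
    total = int(counts.sum())
    starts = np.zeros(total, dtype=np.int64)
    starts[np.cumsum(counts)[:-1]] = 1
    # ... and cumulative-sum the markers into gather indices
    return values[np.cumsum(starts)]

repeat_via_cumsum(np.array([7, 8, 9]), np.array([2, 1, 3]))
# -> array([7, 7, 8, 9, 9, 9])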

@rgommers
Member

rgommers commented Feb 16, 2023

This is especially true when one works with a statically-shaped system like JAX, where these tricks are more or less required.

@soraros it seems like all that usage of cumsum in JAX is just a cumbersome way of working around not having boolean indexing support for in-place ops? And JAX should instead add some other primitive that's more suitable, rather than letting all users manually construct expanded integer index arrays to then use with .at[idx_cumsum].xxx?

# For an array `xs` and a boolean index `conds` of the same shape
# example JAX expression for `jax_filter`:
>>> xs = jnp.arange(5)
>>> conds = jnp.array([True, False, True, False, True])
>>> cumsum = jnp.cumsum(conds)
>>> cumsum
Array([1, 1, 2, 2, 3], dtype=int32)
>>> jnp.zeros_like(xs).at[cumsum - 1].add(jnp.where(conds, xs, 0))
Array([0, 2, 4, 0, 0], dtype=int32)

# NumPy:
>>> xs = np.arange(5)
>>> conds = np.array([True, False, True, False, True])
>>> xs[conds]
array([0, 2, 4])
>>> np.zeros_like(xs) + np.where(conds, xs, 0)  # note: not identical to JAX's `.at` result above; the zero-padding is not at the end
array([0, 0, 2, 0, 4])

That kind of JAX code looks very bad, and as a motivator for writing portable code based on a standard I don't think it's a positive. JAX could add support for x[ix_bool] += 1 tomorrow, and imho they should, rather than making users write x.at[cumsum(ix_bool)].add(1) - it's 100% equivalent. And for the "expanded boolean indexing" where the desired result is [0, 2, 4, 0, 0] as an output, it needs a new builtin function like jnp.bool_index(xs, conds). It has little to do with cumulative sums for data/statistical purposes; it's a type of indexing.
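
For concreteness, here is a sketch of what such a hypothetical bool_index could compute, written in plain NumPy (the name comes from the comment above; it is not an existing JAX or NumPy function):

import numpy as np

def bool_index(xs, conds):
    # selected elements packed at the front, zero-padded to the static input shape
    out = np.zeros_like(xs)
    k = int(conds.sum())
    out[:k] = xs[conds]
    return out

bool_index(np.arange(5), np.array([True, False, True, False, True]))
# -> array([0, 2, 4, 0, 0])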

There's a known blocker (brought up by the JAX team before) for adding more in-place operator support beyond what they have now, which is that this standard should be able to better guarantee that += doesn't modify views in NumPy et al. But if we can make that guarantee at some point, then we have what we want here.

@soraros
Contributor

soraros commented Feb 16, 2023

@soraros it seems like all that usage of cumsum in JAX is just a cumbersome way of working around not having boolean indexing support for in-place ops?

@rgommers Yes and no. It is a cumbersome workaround, but I also think the problem is more fundamental than that. JAX is essentially a front-end for XLA, and the primitives provided by XLA (for now) require static shapes. So the line that actually goes wrong is

>>> xs[ix_bool]
array([0, 2, 4])

Note that this code does work in JAX, though it is not jittable, because we don't know its output shape. Let's pretend for a moment that x[ix_bool] += 1 is syntactic sugar for x = x + where(ix_bool, 1, 0) (which works in JAX). The same problem appears when we want x[ix_bool] += [1, 3, 5]. Again, we somehow need to know the shape of the rhs, which is equivalent to knowing the shape of xs[ix_bool] as in the last example.

So what we are really working around is the static-shape requirement (recall the need for a size parameter for nonzero), which is not exclusive to JAX.

Now, for something a bit off-topic.

I think the JAX-style functional syntax a = a.at[...].set(...) for in-place operations looks (and arguably works) better than NumPy's, and I'd really like to have it in the array API. Some pros:

  • Looks familiar, and simulates the feel of in-place operations just fine.
  • Makes it clear nothing is modified in place. This restricted access pattern would work with any accelerator-backed system, and I think it would aid static analysis in systems like Numba as well.
  • More concise, can be chained, and sometimes expresses our intention better.
a = zeros(m)       # initialising a
a[I] += arange(n)  # semantically, still initialising a

# VS

# being concise here is not the important point
# this line becomes a "semantic block" for initialisation
a = zeros(m).at[I].add(arange(n))  # initialising a
  • Can specify indexing modes (more) easily.
# I think these are fairly cumbersome to represent in `numpy`, as we don't have kwargs for __getitem__
b = a.at[I].add(val, unique_indices=True)     # important info for accelerators
c = b.at[J].get(mode='fill', fill_value=nan)  # sure, we have `take`, but this is uniform and cool

@rgommers
Member

rgommers commented Feb 17, 2023

@soraros thanks for all this detail, it's very interesting actually. I think there's something to be said indeed for .at - especially if JAX can make a new cleaner API for it that doesn't rely on things like explicit use of cumsum. This discussion is effectively a follow-up to gh-84 (EDIT: and gh-24). What do you think about opening a new issue to continue this discussion, especially your "for something a bit off-topic" part?

@soraros
Contributor

soraros commented Feb 17, 2023

@rgommers Glad you find the exchange interesting!
gh-84 is indeed interesting, I will read it thoroughly later.
I'm not sure how to proceed regarding opening a new issue to continue the discussion though (not sure about the exact topic and/or scope you have in mind). If it's not too much trouble, could you open a new issue so I (or maybe you) can move the relevant part there? Thanks in advance!

@asmeurer
Member

How do you implement cumulative sum using only array API functions (and without using a Python loop)?

@oleksandr-pavlyk
Contributor

oleksandr-pavlyk commented Feb 17, 2023

One inefficient possibility:

In [4]: def cumsum(x):
   ...:     assert x.ndim == 1
   ...:     n = x.shape[0]
   ...:     # multiply by a lower-triangular matrix of ones: O(n**2) work and memory
   ...:     return np.tril(np.ones((n, n))) @ x
   ...:

In [5]: cumsum(np.array([1, 2, 3, 4, 5]))
Out[5]: array([ 1.,  3.,  6., 10., 15.])
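
For completeness, a sketch that avoids the O(n**2) intermediate: a Hillis-Steele style scan built from shifts and adds. It still uses a Python while loop, but only about log2(n) iterations rather than one per element (the function name is illustrative):

import numpy as np

def cumsum_scan(x):
    # inclusive prefix sum via repeated shift-and-add, doubling the shift each pass
    n = x.shape[0]
    y = np.asarray(x).copy()
    shift = 1
    while shift < n:
        shifted = np.concatenate([np.zeros(shift, dtype=y.dtype), y[:-shift]])
        y = y + shifted
        shift *= 2
    return y

cumsum_scan(np.array([1, 2, 3, 4, 5]))
# -> array([ 1,  3,  6, 10, 15])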

kgryte changed the title from "Add cumsum to the standard" to "Add support for computing the cumulative sum to the standard" on Feb 19, 2023
@rgommers
Member

rgommers commented Mar 9, 2023

I'm not sure how to proceed regarding opening a new issue to continue the discussion though (not sure about the exact topic and/or scope you have in mind). If it's not too much trouble, could you open a new issue so I (or maybe you) can move the relevant part there? Thanks in advance!

Done now in gh-609 - sorry for the delay! I spent a lot of time refreshing my memory and trying to put together something more coherent. But it's tricky; the need for a new API in JAX to avoid cumsum seems clear, but if it were easy to define exactly how it'd look, the JAX devs would have done that by now, I guess :)

@shoyer
Contributor

shoyer commented Apr 3, 2023

I think cumulative sum is a rather fundamental array operation and we should include it in the standard. None of the typical reasons for omitting a function from the API standard apply here:

  • It has well-defined behavior (aside from NumPy's funny flattening behavior if axis is omitted).
  • It is not particularly hard to implement in a distributed fashion or on accelerators, as evidenced by its availability in Dask and JAX.
  • It is not easy to implement in terms of other fundamental operations.

To give a few other examples of how I've used it:

  1. Calculating integrals
  2. Calculating moving window averages (see the sketch after this list)
  3. Calculating cumulative probabilities
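
As an aside on item 2, a minimal sketch of a moving-window average built on a cumulative sum; prepending a zero (as discussed further below) turns every window sum into a single difference (the function name is illustrative):

import numpy as np

def moving_average(x, k):
    # running totals with a leading zero, so each window sum is c[i+k] - c[i]
    c = np.concatenate([np.zeros(1), np.cumsum(x, dtype=float)])
    return (c[k:] - c[:-k]) / k

moving_average(np.array([1, 2, 3, 4]), 2)
# -> array([1.5, 2.5, 3.5])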

As a reference point on popularity, I see about twice as many uses of np.cumsum in Google's codebase as for np.roll or np.var.

I'll also second Ralf's point on calling it cumulative_sum rather than cumsum.

@WarrenWeckesser

It has well-defined behavior...

Before locking in the specification, it would be worthwhile taking a look at numpy/numpy#6044, and numpy/numpy#14542 that is referenced from numpy/numpy#6044. My comment in numpy/numpy#14542 provides some evidence for the usefulness of allowing the result to include a prepended 0. (More generally, for other cumulative operations such as cumprod, the identity of the operation would be prepended.)

@shoyer
Contributor

shoyer commented Apr 3, 2023

Before locking in the specification, it would be worthwhile taking a look at numpy/numpy#6044, and numpy/numpy#14542 that is referenced from numpy/numpy#6044. My comment in numpy/numpy#14542 provides some evidence for the usefulness of allowing the result to include a prepended 0. (More generally, for other cumulative operations such as cumprod, the identity of the operation would be prepended.)

I agree, starting with zero (and excluding the last value) is quite useful, and I would definitely support adding an optional argument.

Is it clear that this is a better default behavior? My inclination is that it would not be worth the trouble to break existing code.

@WarrenWeckesser

Is it clear that this is a better default behavior? My inclination is that it would not be worth the trouble to break existing code.

I agree, starting with zero is not a better default behavior, if only because of the long history of the existing behavior in numpy. The links to those previous discussions are more for increasing awareness of the interest and usefulness of this option. That is, don't lock in an API that can't be enhanced with an option later. I don't think that is a problem here, but I don't know how hard it will be to add options to a function in the array API specification later. Also, maybe the reminder will inspire someone to push forward with the enhancement in numpy 🤞 or in other libraries.

@rgommers
Member

rgommers commented Apr 6, 2023

Good points on the initial value feature. That doesn't look all that hard to push forward in numpy; there doesn't seem to be a blocker other than no one having done the work.

@arogozhnikov

arogozhnikov commented Jul 8, 2023

Just my opinion:

  • cumsum in my experience is very common, and I'd even vote for a cumulative logsumexp (but not cumprod).
  • Prepending zero - I've had a number of cases where this would be the desired behavior (I'd rather have start_with_zero=True as a kwarg).

@rgommers
Member

Looks like there's enough thumbs-ups for cumulative sum at least, and we should add it now. @kgryte just volunteered to open a PR for it.

@seberg
Contributor

seberg commented Jul 13, 2023

On the PR there are currently startpoint= and endpoint=, so I am wondering if anyone here has opinions on naming? (start_with_zero= was also suggested above.)
I don't want to include it here, but in principle it makes sense to have initial=value (either distinct from, or folded into, a start_with_initial=True/include_initial). I am wondering if there is a clearer name than startpoint=True?

@seberg
Contributor

seberg commented Jul 13, 2023

On the NumPy issue numpy/numpy#6044, include_identity and include_final were brought up. I think I lean towards those names being nicer; include_initial would pair better with an initial= argument for reductions, but I'm not sure it is as clear.

@kgryte
Contributor

kgryte commented Jul 27, 2023

I've updated the PR to use include_initial.
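
For illustration only, a minimal sketch of the prepend-the-identity semantics discussed above, using NumPy and the include_initial name from the PR (the normative definition is whatever the PR ends up specifying):

import numpy as np

def cumulative_sum(x, *, include_initial=False):
    # optionally prepend the additive identity (0) to the running sums
    result = np.cumsum(x)
    if include_initial:
        result = np.concatenate([np.zeros(1, dtype=result.dtype), result])
    return result

cumulative_sum(np.array([1, 2, 3]))                        # -> array([1, 3, 6])
cumulative_sum(np.array([1, 2, 3]), include_initial=True)  # -> array([0, 1, 3, 6])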
