Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add repeat to the specification #690

Merged
merged 21 commits into from Feb 22, 2024
Merged

Add repeat to the specification #690

merged 21 commits into from Feb 22, 2024

Conversation

kgryte
Copy link
Contributor

@kgryte kgryte commented Sep 21, 2023

This PR

  • resolves RFC: add support for repeating each element of an array #654 by adding support for repeat to the array API specification.
  • allows repeats to be either an int or an array. NumPy and other inspired libraries and TensorFlow support one-dimensional arrays. NumPy also supports lists and tuples. CuPy docs suggest support for only lists and tuples. PyTorch supports a one-dimensional array; however, there has been discussion (linked to in the linked RFC) preferring sequences over arrays due to synchronization issues. However, it's not clear that providing a sequence of integers is particularly common or useful. In this PR, I've chosen to explicitly type repeats to support int and array. Should sequences be considered acceptable, this can be revisited in a future revision of the Array API standard.
  • adds a data-dependent admonition and allows array libraries to omit array support for the repeats argument.

@kgryte kgryte added API extension Adds new functions or objects to the API. topic: Manipulation Array manipulation and transformation. labels Sep 21, 2023
@kgryte kgryte added this to the v2023 milestone Sep 21, 2023
@kgryte
Copy link
Contributor Author

kgryte commented Oct 19, 2023

I've updated the proposed specification to include a note advising conforming array libraries to include a warning regarding device synchronization if repeats is an array.

@kgryte
Copy link
Contributor Author

kgryte commented Jan 25, 2024

@leofang Would you mind giving this PR a review? I believe this PR addresses the concerns you raised in #654 (comment), but I want to confirm before merging.

@kgryte
Copy link
Contributor Author

kgryte commented Feb 8, 2024

@leofang Pinging in case you missed the above.

Copy link
Contributor

@leofang leofang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pushing this, Athan. Sorry I missed the ping. Took a stab at it, no concerns.

src/array_api_stubs/_draft/manipulation_functions.py Outdated Show resolved Hide resolved
Copy link
Contributor

@leofang leofang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more question: Should we also add a note on "data-dependent output shapes" like what we do for unique*/nonzero?

@kgryte
Copy link
Contributor Author

kgryte commented Feb 13, 2024

@leofang Re: admonition. I am not certain. It's only when repeats is an array that this is an issue correct?

And for that case, we include a note regarding device synchronization. If we add a "data-dependent admonition" here, this would make this API optional, which I am not sure is desirable.

As a point of reference, in tile (#692), we did not include an admonition, even though repetitions can be a tuple of ints. If the Sequence[int] is problematic for repeats, I would imagine that we'd also need to add the admonition to tile.

@leofang
Copy link
Contributor

leofang commented Feb 13, 2024

Re: admonition. I am not certain. It's only when repeats is an array that this is an issue correct?

Doesn't this API always have the output shape determined by the input data (repeats), regardless of how repeats is provided?

@kgryte
Copy link
Contributor Author

kgryte commented Feb 20, 2024

@leofang I added the data-dependent shape admonition. Given that JAX requires a total_repeat_length in order for repeat to be compilable (ref: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.repeat.html), I think you are right that it makes sense to make this API optional due to data-dependent output shapes.

Now that this has been added, I believe that this PR should be ready for another review.

cc @rgommers

@rgommers
Copy link
Member

Doesn't this API always have the output shape determined by the input data (repeats), regardless of how repeats is provided?

Kinda sorta, but I think the "data-dependent shape" admonition is more aimed at the input values of the input array. E.g., the most common usage here will be with a literal int: repeat(x, 2). Which isn't data-dependent.

I played with this a bit with JAX:

>>> import jax
>>> import jax.numpy as jnp

>>> x = jnp.arange(3)
>>> jnp.repeat(x, 2)
Array([0, 0, 1, 1, 2, 2], dtype=int32)
>>> jnp.repeat(x, 2, total_repeat_length=x.size*2)  # the documented way to allow JIT-ing
Array([0, 0, 1, 1, 2, 2], dtype=int32)

>>> def func(x):
...     return jnp.repeat(x, 2)
... 
>>> func(x)
Array([0, 0, 1, 1, 2, 2], dtype=int32)
>>> # It's not actually needed to use `total_repeat_length` if `repeats` is a literal int:
>>> jax.jit(func)(x)
Array([0, 0, 1, 1, 2, 2], dtype=int32)

>>> # It is needed if we make `repeats` data-dependent:
>>> def func(x):
...     return jnp.repeat(x, x[2])
... 
>>> func(x)
Array([0, 0, 1, 1, 2, 2], dtype=int32)
>>> jax.jit(func)(x)
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>

A few conclusions:

  • the main use case of this function, repeats=a_literal_int, is not data-dependent,
  • the second most common use case of this function, repeats=an_int_derived_from_x, is data-dependent,
  • if we do want to include repeats as an array, that is almost always data-dependent
  • Now that I've played with this some more: in case repeats is not an integer, it has to have the length of the number of elements of the first input array. The argument is inherently an array. Both JAX and PyTorch don't allow list/tuple, and while CuPy does allow (in fact, require) a sequence, its argument doesn't seem strong because the main use case requires repeats=x.tolist() (workaround from Support array objects as the repeats argument to cupy.repeat cupy/cupy#3849) which gains nothing.
    • A hardcoded repeat(x, [2, 3, 4, 5]) is highly unusual. That code know that x is of length 4 - may happen in tests, but is unlikely to be useful in production code.
    • This kind of usage is expected more (from SciPy code):
# from scipy.signal tests
repeats[1::2] = x[1::2]
x = np.repeat(x, repeats)

# from scipy.integrate functionality
diff.data /= np.repeat(h, np.diff(diff.indptr))

My suggested resolution:

  • do not allow sequence input for repeats, only integer or array
  • do include the data-dependent note, with an explanation of when this occurs
  • state that this API must always be implemented for the non-data-dependent part (repeats is a literal int), and only the data-dependent part is optional

@kgryte
Copy link
Contributor Author

kgryte commented Feb 20, 2024

@rgommers We discussed making repeats a sequence previously (result from that discussion: #654 (comment)). In that thread, Mario considered it undesirable for repeats to be an array and would prefer PyTorch accept a list/sequence due to device synchronization.

@rgommers
Copy link
Member

rgommers commented Feb 20, 2024

We discussed making repeats a sequence previously (result from that discussion: #654 (comment)). In that thread, Mario considered it undesirable for repeats to be an array and would prefer PyTorch accept a list/sequence due to device synchronization.

I'd say I agree with Leo's comments in that thread. There's just not much of a point of it being a sequence. It is not like you can do repeat(x, [2, 3]) and have that mean some kind of generalization of repeat(x, 2). The 2 here is not an axis/shape, it's shorthand for repeat(x, [2] * x.size). EDIT: the language around repeats broadcasting also gives away that it is fundamentally an array.

NumPy also documents it as an ndarray; the only reason a sequence works for NumPy is because it calls asarray on all its array-like inputs.

@leofang
Copy link
Contributor

leofang commented Feb 20, 2024

Q: Does it make sense to say "when the input is an array, this API has data-dependent output shape" and followed by the note?

@rgommers
Copy link
Member

Q: Does it make sense to say "when the input is an array, this API has data-dependent output shape" and followed by the note?

I think so. I'd generalize it slightly - maybe "unless the repeats keyword contains a literal int, this API has data-dependent output shape"?

@kgryte
Copy link
Contributor Author

kgryte commented Feb 21, 2024

Okay. I've dropped support for sequences and updated the admonition. The admonition now only allows optional support for providing an array as the second argument; all conforming libraries must support providing an integer.

As providing sequences is still controversial and can be added in a subsequent revision of the standard, I think it is fine to omit for the time being.

Copy link
Member

@rgommers rgommers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @kgryte. The repeats treatment LGTM. Two other minor comments.

Copy link
Member

@rgommers rgommers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM now. I think this should be good to go; will aim to merge this at the end of today unless there are new comments.

@rgommers rgommers merged commit 9d200ea into data-apis:main Feb 22, 2024
3 checks passed
@rgommers
Copy link
Member

Thanks @kgryte & all reviewers!

@kgryte kgryte deleted the repeat branch February 22, 2024 20:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
API extension Adds new functions or objects to the API. topic: Manipulation Array manipulation and transformation.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

RFC: add support for repeating each element of an array
3 participants