Can we replace Array = Any with Array = ndarray now? #10186

Closed
YouJiacheng opened this issue Apr 7, 2022 · 18 comments

@YouJiacheng
Contributor

It seems that problems in #943 have been resolved?

@YouJiacheng YouJiacheng added the enhancement New feature or request label Apr 7, 2022
@jakevdp
Collaborator

jakevdp commented Apr 8, 2022

I don't think it's as simple as that – for example, inputs to JAX functions will accept any array-like object, including Python scalar types, numpy scalar types, numpy arrays, jax arrays/tracers, plus for many functions, arbitrary objects that define a __jax_array__ method. So we'd need to account for those too, which makes it more complicated than simply setting Array = jnp.ndarray.

@jakevdp
Collaborator

jakevdp commented Apr 8, 2022

I personally prefer Array = Any because its false positive rate is zero 😁

@YouJiacheng
Contributor Author

Oh, I forgot about __jax_array__.
However, IIUC, we can still annotate the return type as ndarray, and its false positive rate is zero as well.
Return type annotations are also very helpful for IDE auto-completion and semantic highlighting.
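
A minimal sketch of this suggestion, assuming a hypothetical stand-in class `NDArray` in place of the real jax.numpy.ndarray and a hypothetical function `double`. The input stays `Any`, so no valid call is rejected, while the concrete return annotation gives an IDE or type checker something to resolve attributes against.

```python
from typing import Any, get_type_hints

Array = Any  # permissive alias for inputs, as in the issue title


class NDArray:
    """Hypothetical stand-in for jax.numpy.ndarray (not the real class)."""

    def __init__(self, values):
        self.values = [float(v) for v in values]

    @property
    def shape(self):
        return (len(self.values),)


def double(x: Array) -> NDArray:
    """Input is unconstrained; the return type is concrete."""
    values = x.values if isinstance(x, NDArray) else [x]
    return NDArray(2 * v for v in values)


result = double(3)
# A checker now knows `result` is an NDArray, so `.shape` can be completed.
hints = get_type_hints(double)
```

Because only the return side is narrowed, this kind of annotation cannot reject a call that works at runtime.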

@YouJiacheng
Contributor Author

@jakevdp WDYT about annotating the return type as ndarray, or even Union[ndarray, Any] (IIUC Union[ndarray, PRNGKeyArray] is enough)? That would be very useful for development in an IDE such as VS Code.

@jakevdp
Collaborator

jakevdp commented Apr 14, 2022

Sure, we could try it, and just be ready to roll it back quickly if it breaks things.

@jakevdp
Collaborator

jakevdp commented Apr 14, 2022

I have to say though, I have a knee-jerk dislike for using Union[Any, ndarray] when the result is an array. What it tells me is that the static typing system is not actually expressive enough to statically type our API, and if that's necessary I'd prefer to continue to just use Any or leave things undeclared.

@YouJiacheng
Contributor Author

@jakevdp
I have an idea: use overload to provide a fully annotated version for the most common case, and an Any-annotated version for corner cases.
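
The overload idea could look roughly like the sketch below. `NDArray` and `negate` are hypothetical names used only for illustration; they are not part of JAX.

```python
from typing import Any, overload


class NDArray:
    """Hypothetical stand-in for jax.numpy.ndarray (not the real class)."""

    def __init__(self, values):
        self.values = list(values)


@overload
def negate(x: NDArray) -> NDArray: ...  # common case: fully annotated


@overload
def negate(x: Any) -> Any: ...  # corner cases: fall back to Any


def negate(x):
    """Runtime implementation; the overloads above exist only for checkers."""
    if isinstance(x, NDArray):
        return NDArray(-v for v in x.values)
    return -x


out = negate(NDArray([1, 2]))  # a checker should pick the NDArray overload
other = negate(3)              # a checker should fall through to Any
```

The order matters: checkers try overloads top to bottom, so the precise signature has to come before the catch-all.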

@NeilGirdhar
Contributor

What's JAX's long-term plan for this? Perhaps the Array API will provide annotations that can be used in JAX?

@jakevdp jakevdp self-assigned this Jul 6, 2022
@jakevdp
Collaborator

jakevdp commented Jul 6, 2022

I wrote a design doc a while ago regarding how the project should approach type annotations (surprise, it's not that simple because "type annotations" are a multiply-overloaded concept in Python!)... it got lost in the shuffle. I'll dust it off and try to generate some discussion.

@YouJiacheng
Contributor Author

❤️ Does the design doc also discuss the things mentioned in #10322?

@YouJiacheng
Contributor Author

BTW, I now realize the __jax_array__ case can be handled with a Protocol, i.e. Union[ndarray, JAXArray], where JAXArray is a protocol that declares __jax_array__(self) -> ndarray.
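
A sketch of that Protocol approach, under the assumption that `NDArray` is a hypothetical stand-in for the real array class; `ArrayInput`, `Wrapper`, and `to_array` are also illustrative names, not JAX APIs.

```python
from typing import Protocol, Union, runtime_checkable


class NDArray:
    """Hypothetical stand-in for jax.numpy.ndarray (not the real class)."""

    def __init__(self, values):
        self.values = list(values)


@runtime_checkable
class JAXArray(Protocol):
    """Structural type: anything that defines __jax_array__."""

    def __jax_array__(self) -> NDArray: ...


ArrayInput = Union[NDArray, JAXArray]  # hypothetical alias name


class Wrapper:
    """Never inherits from JAXArray; it matches by structure alone."""

    def __init__(self, values):
        self._values = values

    def __jax_array__(self) -> NDArray:
        return NDArray(self._values)


def to_array(x: ArrayInput) -> NDArray:
    # Check the concrete class first, then fall back to the protocol hook.
    return x if isinstance(x, NDArray) else x.__jax_array__()


arr = to_array(Wrapper([1, 2, 3]))
```

`runtime_checkable` is optional here; it only enables `isinstance` checks against the protocol, which test for the presence of the method and not its signature.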

@jakevdp
Collaborator

jakevdp commented Jul 7, 2022

It's still unclear whether __jax_array__ is something we want to support broadly; see the discussion in #10065.

@jakevdp
Collaborator

jakevdp commented Aug 15, 2022

The design doc is coming together here: #11859

In particular, it answers the main question of this issue:

> Can we replace Array = Any with Array = ndarray now?

The answer is No, because we need inputs to functions to be far more flexible than ndarray. Inputs should be ArrayLike, where this accepts jax.numpy.ndarray, numpy.ndarray, numpy scalars, Python scalars, objects with __array__ attributes, etc.

For return values, we can annotate with jnp.ndarray, or perhaps a new jax.typing.NDArray.
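
The split between permissive inputs and a concrete return type can be sketched as follows. This is an illustration only: `NDArray` is a hypothetical stand-in class, the `ArrayLike` alias here is a simplified assumption that covers just Python scalars and the stand-in array (it leaves out NumPy arrays, NumPy scalars, and objects with `__array__`), and `asarray` is a toy function, not `jnp.asarray`.

```python
from typing import Union


class NDArray:
    """Hypothetical stand-in for jax.numpy.ndarray (not the real class)."""

    def __init__(self, values):
        self.values = [float(v) for v in values]


# Simplified, hypothetical alias: broad on the input side.
ArrayLike = Union[NDArray, bool, int, float]


def asarray(x: ArrayLike) -> NDArray:
    """Accept a broad input type, but always return the concrete array type."""
    if isinstance(x, NDArray):
        return x
    if isinstance(x, (bool, int, float)):
        return NDArray([x])
    raise TypeError(f"unsupported input type: {type(x).__name__}")


from_scalar = asarray(2)
existing = NDArray([1, 2])
passthrough = asarray(existing)
```

Callers get a precise type back regardless of which member of the union they passed in.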

@NeilGirdhar
Contributor

NeilGirdhar commented Aug 15, 2022

> The answer is No, because we need inputs to functions to be far more flexible than ndarray. Inputs should be ArrayLike, where this accepts jax.numpy.ndarray, numpy.ndarray, numpy scalars, Python scalars, objects with __array__ attributes, etc.

Would it be possible to create a type alias for jax.numpy.typing.ArrayLike like the one in Numpy?

@YouJiacheng
Contributor Author

Actually, in JAX the problem is simpler than in NumPy, since jax.numpy is not designed to be compatible with sequences, e.g. list. A union of scalar, ndarray, and supportsjaxarray would not be too bad.

@jakevdp
Collaborator

jakevdp commented Aug 15, 2022

> Would it be possible to create a type alias for jax.numpy.typing.ArrayLike like the one in Numpy?

Yes, the draft doc I linked to mentions doing that in the roadmap.

@jakevdp
Collaborator

jakevdp commented Aug 15, 2022

I'm going to close this issue, because I think the work going forward is better outlined and tracked elsewhere. If you have thoughts about the type annotation roadmap, please feel free to comment on #11859

@jakevdp jakevdp closed this as completed Aug 15, 2022
@jakevdp
Collaborator

jakevdp commented Oct 7, 2022

Hi @YouJiacheng - if you're still interested in this, we've started the process of doing these sorts of annotations in #12049. See the attached PRs for a few examples. Sorry it took such a long time to get there, but we'd love your help now if you'd like to contribute!
