New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
experiment: add jax.typing
module and use it in jax.lax
#12018
Conversation
jax.typing
module and use it in jax.lax
362a6d9
to
caa4c7a
Compare
0bd2c89
to
6c3044b
Compare
bb62f5e
to
d08a30d
Compare
I think folks are definitely going to get confused between Obviously the former has more restricted aims, but wants to stay within the static type system. Can we see if there's a reasonable way to bring both projects together? For example switched on the (EDIT: jakevdp and I are discussing this offline, and think we have most-of-a-plan.) |
I'm not entirely sure what you have in mind... Do you mean that |
For context, this draft PR is about answering the question "can we use simple annotations that are slightly better than |
381cc77
to
0dfe493
Compare
Why? This is required according to https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases, and something about the change in #12018 resulted in mypy errors from this line. PiperOrigin-RevId: 470025070
Why? This is required according to https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases, and something about the change in #12018 resulted in mypy errors from this line. PiperOrigin-RevId: 470025070
Why? This is required according to https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases, and something about the change in #12018 resulted in mypy errors from this line. PiperOrigin-RevId: 470025070
Why? This is required according to https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases, and something about the change in #12018 resulted in mypy errors from this line. PiperOrigin-RevId: 470046276
59c09ef
to
7993a14
Compare
(closing in favor of #12300) |
Part of #12049
Changes:
jax.typing
module, following roadmap in JEP: Type Annotations #11859core.Shape
definition intojax.typing
jax.numpy.ndarray
definition intojax.typing
jax.lax
, following the style recommended in JEP: Type Annotations #11859jax.lax
types (what a tangled web we weave...)Some observations:
jax.lax
, I usedArrayLike
for inputs andNDArray
for outputs. Originally I hadNDArray
aliased to thendarray
metaclass, but becausejax.lax
has several functions that explicitly returnTracer
types, it seemed necessary to makeNDArray = Union[ndarray, Tracer]
.-> jnp.ndarray
and returning the result of a lax operation caused a typecheck failure (becauseTracer
is inompatible), so it meant that in many places I had to change-> jnp.ndarray
to-> NDarray
, which caused even more ripples.-> jnp.ndarray
by convention (with the understanding that it implies compatiblity with tracers) and use#type: ignore
statements where tracers are explicitly returned.Happy to hear any thoughts/opinions on this.