Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

experiment: add jax.typing module and use it in jax.lax #12018

Closed
wants to merge 2 commits into from

Conversation

jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Aug 19, 2022

Part of #12049

Changes:

  • created jax.typing module, following roadmap in JEP: Type Annotations #11859
  • moved core.Shape definition into jax.typing
  • moved jax.numpy.ndarray definition into jax.typing
  • applied these new types throughout jax.lax, following the style recommended in JEP: Type Annotations #11859
  • fixed annotations in other files that were broken by the new jax.lax types (what a tangled web we weave...)

Some observations:

  • When implementing annotations in jax.lax, I used ArrayLike for inputs and NDArray for outputs. Originally I had NDArray aliased to the ndarray metaclass, but because jax.lax has several functions that explicitly return Tracer types, it seemed necessary to make NDArray = Union[ndarray, Tracer].
  • This change had wide-ranging ripple effects: any function annotated with -> jnp.ndarray and returning the result of a lax operation caused a typecheck failure (because Tracer is inompatible), so it meant that in many places I had to change -> jnp.ndarray to -> NDarray, which caused even more ripples.
  • An alternative would be to use -> jnp.ndarray by convention (with the understanding that it implies compatiblity with tracers) and use #type: ignore statements where tracers are explicitly returned.

Happy to hear any thoughts/opinions on this.

@jakevdp jakevdp changed the title experiment: add typing module and use in jax._src.lax experiment: add jax.typing module and use it in jax.lax Aug 19, 2022
@jakevdp jakevdp force-pushed the typing branch 3 times, most recently from 362a6d9 to caa4c7a Compare August 19, 2022 19:41
@jakevdp jakevdp marked this pull request as draft August 19, 2022 20:19
@jakevdp jakevdp force-pushed the typing branch 5 times, most recently from 0bd2c89 to 6c3044b Compare August 22, 2022 23:05
@jakevdp jakevdp added the pull ready Ready for copybara import and testing label Aug 22, 2022
@jakevdp jakevdp force-pushed the typing branch 2 times, most recently from bb62f5e to d08a30d Compare August 23, 2022 21:34
@patrick-kidger
Copy link
Collaborator

patrick-kidger commented Aug 23, 2022

I think folks are definitely going to get confused between google/jax.typing and google/jaxtyping. Plus I'm very wary of yet more sharding within the JAX ecosystem.

Obviously the former has more restricted aims, but wants to stay within the static type system.

Can we see if there's a reasonable way to bring both projects together? For example switched on the typing.TYPE_CHECKING flag.

(EDIT: jakevdp and I are discussing this offline, and think we have most-of-a-plan.)

@jakevdp
Copy link
Collaborator Author

jakevdp commented Aug 23, 2022

Can we see if there's a reasonable way to bring both projects together? For example switched on the typing.TYPE_CHECKING flag.

I'm not entirely sure what you have in mind... Do you mean that jax should depend on jaxtyping for static type checks? If so I don't think this is a good idea (I discussed this briefly in the roadmap draft). In my view, what jaxtyping includes is far more complicated than the level of type annotations we want in the core library, at least for now.

@jakevdp
Copy link
Collaborator Author

jakevdp commented Aug 23, 2022

For context, this draft PR is about answering the question "can we use simple annotations that are slightly better than Array = Any without breaking mypy and pytype". My understanding of jaxtyping is that it's far beyond the scope of that question.

@jakevdp jakevdp force-pushed the typing branch 2 times, most recently from 381cc77 to 0dfe493 Compare August 25, 2022 17:28
copybara-service bot pushed a commit that referenced this pull request Aug 25, 2022
Why? This is required according to https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases, and something about the change in #12018 resulted in mypy errors from this line.

PiperOrigin-RevId: 470025070
@copybara-service copybara-service bot mentioned this pull request Aug 25, 2022
copybara-service bot pushed a commit that referenced this pull request Aug 25, 2022
Why? This is required according to https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases, and something about the change in #12018 resulted in mypy errors from this line.

PiperOrigin-RevId: 470025070
copybara-service bot pushed a commit that referenced this pull request Aug 25, 2022
Why? This is required according to https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases, and something about the change in #12018 resulted in mypy errors from this line.

PiperOrigin-RevId: 470025070
copybara-service bot pushed a commit that referenced this pull request Aug 25, 2022
Why? This is required according to https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases, and something about the change in #12018 resulted in mypy errors from this line.

PiperOrigin-RevId: 470046276
@jakevdp jakevdp closed this Oct 14, 2022
@jakevdp jakevdp deleted the typing branch October 14, 2022 19:11
@jakevdp
Copy link
Collaborator Author

jakevdp commented Oct 14, 2022

(closing in favor of #12300)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants