Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[typing] Use ParamSpec in JIT annotation #14688

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

NeilGirdhar
Copy link
Contributor

@NeilGirdhar NeilGirdhar commented Feb 26, 2023

This pull request would be a huge improvement for Jax users who use type checkers like MyPy or Pyright, which now support ParamSpec.

Consider:

from jax import Array, jit

@jit
def f(x: Array, y: Array) -> Array:
    pass

reveal_type(f.__call__)

Previous to this PR, users who decorate a function with jit lose all annotations of the method:

Pyright: Type of "f.__call__" is "(*args: Unknown, **kwargs: Unknown) -> Unknown"
MyPy: Revealed type is "def (*args: Any, **kwargs: Any) -> Any"

After this PR, we get:

Pyright: Type of "f.__call__" is "(x: Array, y: Array) -> Array"
MyPy: Revealed type is "def (x: Array, y: Array) -> Array"

This unblocks a lot of type checking.

Part of #12049 cc: @jakevdp

@NeilGirdhar NeilGirdhar changed the title Use ParamSpec in jit annotation; bump MyPy to 1.0 [typing] Use ParamSpec in jit annotation; bump MyPy to 1.0 Feb 26, 2023
@NeilGirdhar
Copy link
Contributor Author

Not sure what's going on with the type checking, but perhaps it's not running MyPy 1.0.1?

Also, I'm still working on fixing this for applying jit to methods, which currently works, but does not type check correctly.

setup.py Outdated
@@ -67,6 +67,7 @@ def generate_proto(source):
'numpy>=1.20',
'opt_einsum',
'scipy>=1.5',
'typing_extensions>=4.5.0',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason for adding this as a compulsory dependency?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could make it a dev-dependency. Would that be better?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made it a dev-dependency in a subsequence change (within this PR) so that you can assess whether you like that better.

Normally, typing-extensions is seen as a light dependency since it only contains typing annotations, so it won't realistically break any runtime code.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to add as a dependency, but I'm also curious why it's necessary

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jakevdp We need ParamSpec, which was added in Python 3.10 (Jax's minimum version per NEP 29 is 3.8. We could make it a dev-dependency (as I've now done) although I think it complicates the code.

@NeilGirdhar NeilGirdhar force-pushed the jit_annotation branch 2 times, most recently from c877b16 to ee60835 Compare February 26, 2023 11:50
@NeilGirdhar NeilGirdhar marked this pull request as draft February 26, 2023 12:31
@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented Feb 26, 2023

After testing this a lot with my code, I think MyPy is still not ready to check code with complex ParamSpec usage (although Pyright is).

If we check something like this in, it may trip up MyPy quite a bit. On the other hand, it has a the amazing effect of exposing plenty Jax type annotations that were previously hidden. Plenty of functions in jax.numpy are stripped of perfectly good annotations by @jit.

Also, there are some limitations in Python typing wrt method decorators. I've posted about this on python/typing. If people are using jit on methods, then they will get errors, and there doesn't seem to be a universal fix.

What I will do is add typed decorators to my tjax library. Please let me know if you'd like to get this in to Jax sooner or if we should wait on it.

.pre-commit-config.yaml Outdated Show resolved Hide resolved
@jakevdp
Copy link
Collaborator

jakevdp commented Feb 26, 2023

Thanks for looking at this! ParamSpec seems like it could be interesting; that said it still looks to be relatively unstable and probably the remarks here still apply.

@JesseFarebro
Copy link

Hasn't this already been discussed here: #10311?

@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented Feb 26, 2023

@JesseFarebro My mistake, I didn't find that when I searched! MyPy has progressed since then, but maybe still not enough.

@jakevdp Yes, fair enough. ParamSpec seems to work for more cases than before, but it seems to be causing some errors still. I'll check back in a few months to see how support is coming.

@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented Aug 16, 2023

Great news: Thanks to the recently merged python/mypy#15837, it appears that the main MyPy error with this pull request may have been solved: python/mypy#12169. Also, many of the MyPy errors that may have affected its usage may have been solved: python/mypy#11846, python/mypy#12986, python/mypy#14802.

@jakevdp
Copy link
Collaborator

jakevdp commented Aug 16, 2023

We've updated mypy to 1.4.1 in the meantime, and #17147 bumps it to 1.5.0 – can you sync your PR to the current main branch?

@NeilGirdhar NeilGirdhar force-pushed the jit_annotation branch 3 times, most recently from aaec678 to 0ffe6c8 Compare August 16, 2023 21:50
@NeilGirdhar
Copy link
Contributor Author

@jakevdp Done. FYI the MyPy pull I linked is merged, but it's not in any released MyPy yet.

@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented Oct 10, 2023

@jakevdp MyPy 1.6 is out today, and it may now support using ParamSpec as in this pull request. (See the section on ParamSpec improvements.)

Is there a process or plan for upgrading Jax to use MyPy 1.6? Looks like you guys have been upgrading regularly.

@jakevdp
Copy link
Collaborator

jakevdp commented Oct 10, 2023

As soon as the new mypy version is mirrored at https://github.com/pre-commit/mirrors-mypy, we can bump the version in the pre-commit configuration here:

rev: 'v1.5.1'

It looks like the mirror update will happen automatically in about 12.5 hours: https://github.com/pre-commit/mirrors-mypy/blob/08cbc46b6e135adec84911b20e98e5bc52032152/.github/workflows/main.yml#L6

@jakevdp
Copy link
Collaborator

jakevdp commented Oct 11, 2023

#18066 updates mypy to v1.6.0

@NeilGirdhar NeilGirdhar force-pushed the jit_annotation branch 2 times, most recently from ee38fd6 to 534cae8 Compare October 11, 2023 21:10
@NeilGirdhar NeilGirdhar marked this pull request as ready for review October 11, 2023 21:11
@NeilGirdhar
Copy link
Contributor Author

Okay, I've rebased this to take advantage of the new MyPy version

@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented Oct 11, 2023

Just FYI if you want to get rid of the typing_extensions dependency, one option would be to follow SPEC 0, and drop Python 3.9 (since you already support Python 3.12). Python 3.10 has ParamSpec.

@NeilGirdhar
Copy link
Contributor Author

Running MyPy 1.6.1, I'm still getting a lot of bad errors:


jax/experimental/sparse/linalg.py:102: error: Argument 2 to "__call__" of "Wrapped" has incompatible type "Array"; expected "P.args"  [arg-type]
jax/experimental/sparse/linalg.py:102: error: Argument 3 to "__call__" of "Wrapped" has incompatible type "int"; expected "P.args"  [arg-type]
jax/experimental/sparse/linalg.py:102: error: Argument 4 to "__call__" of "Wrapped" has incompatible type "Array | float | None"; expected "P.args"  [arg-type]

It seems like more MyPy bugs. PyRight seems okay with it. I guess we'll have to wait longer? And someone should consider submitting this bug to MyPy?

@jakevdp
Copy link
Collaborator

jakevdp commented Nov 3, 2023

Seems like maybe a bad interaction with functools.partial?

@overload
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Split out these changes in a separate PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, okay, I will do that when we get close to having this checked in. This PR will expose a lot of typing errors as code that's jitted becomes typed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Factored out as requested: #18395

@NeilGirdhar
Copy link
Contributor Author

Seems like maybe a bad interaction with functools.partial?

That makes sense. I'll look at it as soon as I have more time.

@NeilGirdhar
Copy link
Contributor Author

Filed a MyPy bug python/mypy#16404

@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented Nov 4, 2023

@jakevdp I think we should wait one more release of MyPy, which should fix the above errors. After that, although this change will not help MyPy users (the inferred type of a jit-decorated function will just be Callable), this change won't hurt, and it will help Pyright users. What do you think?

@jakevdp
Copy link
Collaborator

jakevdp commented Nov 6, 2023

@jakevdp I think we should wait one more release of MyPy

Sounds good

@NeilGirdhar
Copy link
Contributor Author

@jakevdp I've rebased this now. Is there a way to get the tests to run? python/mypy#1484 is still unsolved, so there may only be modest gains in MyPy.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! I'm going to approve and then pull it in to run internal pytype tests

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Mar 13, 2024
@jakevdp
Copy link
Collaborator

jakevdp commented Mar 14, 2024

Pytype is failing with this message:

File ".../jax/_src/api.py", line 100, in <module>: argument "covariant" to TypeVar not supported yet [not-supported-yet]

Seems this is a known issue (google/pytype#1471). Is there any way to do this kind of improvement without using covariant TypeVars?

@NeilGirdhar
Copy link
Contributor Author

NeilGirdhar commented Mar 14, 2024

@jakevdp

I'm going to approve and then pull it in to run internal pytype tests'

Thanks for taking the time!

Is there any way to do this kind of improvement without using covariant TypeVars?

The return type does need to be covariant (MyPy gives an error without it). One option is to do this in two steps. Just keep the parameter specification annotation and remove the return type annotation.

It's funny because in Python 3.12, we won't need these markers at all as the type checker can infer them, although I think we'd have to use the new PEP 695 syntax.

Should I remove the annotations on the return type?

@jakevdp
Copy link
Collaborator

jakevdp commented Mar 14, 2024

I'm honestly not sure what the best fix is here.

@jakevdp
Copy link
Collaborator

jakevdp commented Mar 14, 2024

I think maybe the discussion in https://jax.readthedocs.io/en/latest/jep/12049-type-annotations.html#avoid-unstable-typing-mechanisms is still applicable.

@NeilGirdhar
Copy link
Contributor Author

@jakevdp No worries. I thought we were just waiting for MyPy. I didn't realize that we were waiting for Pytype as well!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants