Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the Type annotation for jit an pmap as there are additional attributes on the returned callable. #9535

Closed
wants to merge 1 commit into from

Conversation

copybara-service[bot]
Copy link

@copybara-service copybara-service bot commented Feb 11, 2022

Remove the Type annotation for jit an pmap as there are additional attributes on the returned callable.

Using the experimental jax.jit(lambda x: x+1).lower(...) is raising an error with pytype.

@hawkinsp
Copy link
Collaborator

@NeilGirdhar fyi

@copybara-service copybara-service bot force-pushed the test_427986940 branch 2 times, most recently from 4e02937 to 9103a5a Compare February 11, 2022 15:17
@NeilGirdhar
Copy link
Contributor

NeilGirdhar commented Feb 11, 2022

@hawkinsp What error are you getting? Seems like a bug in pytype. Could you file one there?

@jakevdp
Copy link
Collaborator

jakevdp commented Feb 11, 2022

@NeilGirdhar My understanding is that the type annotation is incorrect because the returned object adds some methods; e.g. type checkers will incorrectly fail on the following:

from jax import jit
@jit
def f(x):
  return x
lowered = f.lower(1.0)  # <--- original function has no `lower` method, so type checking fails

The full solution would proably be to define some custom protocol, but in the meantime Any is not incorrect.

@NeilGirdhar
Copy link
Contributor

The full solution would proably be to define some custom protocol

Right, that would be a better solution.

@copybara-service copybara-service bot force-pushed the test_427986940 branch 2 times, most recently from 03d32b7 to b4e47f9 Compare February 14, 2022 09:40
@NeilGirdhar
Copy link
Contributor

NeilGirdhar commented Feb 14, 2022

@jakevdp By the way, I looked into it, but MyPy has bugs relating to that protocol: python/mypy#12169. So I guess, I won't be able to get that protocol using ParamSpec until MyPy has fixes. Could we use a protocol that doesn't use ParamSpec?

Something like this:

T = TypeVar("T", covariant=True)

# TODO: Use ParamSpec to annotate function parameters.
class JittedFunction(Protocol, Generic[T]):
  def __call__(self, *args: Any, **kwargs: Any) -> T:
    ...

  def lower(self, *args: Any, **kwargs: Any) -> Lowered:
    ...

And then where you have Callable, you could have Callable[..., T] and you could return JittedFunction[T]. What do you think?

@copybara-service copybara-service bot force-pushed the test_427986940 branch 2 times, most recently from df9576c to a13d45e Compare February 14, 2022 14:29
@hawkinsp hawkinsp self-assigned this Feb 14, 2022
@jblespiau
Copy link
Contributor

Note also that ParamSpec is in Python 3.10. Internally, we are still at 3.7

…tributes on the returned callable.

Using the experimental jax.jit(lambda x: x+1).lower(...) is raising an error with pytype.

PiperOrigin-RevId: 427986940
@NeilGirdhar
Copy link
Contributor

@jblespiau You can import ParamSpec from typing_extensions.

@copybara-service copybara-service bot closed this Feb 15, 2022
@copybara-service copybara-service bot deleted the test_427986940 branch February 15, 2022 10:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants