-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove the Type annotation for jit an pmap as there are additional attributes on the returned callable. #9535
Conversation
14cf746
to
96ee1b9
Compare
@NeilGirdhar fyi |
4e02937
to
9103a5a
Compare
@hawkinsp What error are you getting? Seems like a bug in pytype. Could you file one there? |
@NeilGirdhar My understanding is that the type annotation is incorrect because the returned object adds some methods; e.g. type checkers will incorrectly fail on the following: from jax import jit
@jit
def f(x):
return x
lowered = f.lower(1.0) # <--- original function has no `lower` method, so type checking fails The full solution would proably be to define some custom protocol, but in the meantime |
Right, that would be a better solution. |
03d32b7
to
b4e47f9
Compare
@jakevdp By the way, I looked into it, but MyPy has bugs relating to that protocol: python/mypy#12169. So I guess, I won't be able to get that protocol using Something like this: T = TypeVar("T", covariant=True)
# TODO: Use ParamSpec to annotate function parameters.
class JittedFunction(Protocol, Generic[T]):
def __call__(self, *args: Any, **kwargs: Any) -> T:
...
def lower(self, *args: Any, **kwargs: Any) -> Lowered:
... And then where you have |
df9576c
to
a13d45e
Compare
Note also that ParamSpec is in Python 3.10. Internally, we are still at 3.7 |
…tributes on the returned callable. Using the experimental jax.jit(lambda x: x+1).lower(...) is raising an error with pytype. PiperOrigin-RevId: 427986940
a13d45e
to
42de6c7
Compare
@jblespiau You can import |
Remove the Type annotation for jit an pmap as there are additional attributes on the returned callable.
Using the experimental jax.jit(lambda x: x+1).lower(...) is raising an error with pytype.