Please allow @jax.jit on staticmethods #7702
Comments
Thanks for the report! That's really interesting; I wasn't aware that … For the time being, another workaround is to switch the order of the decorators:

import jax

class Foo:
    @staticmethod
    @jax.jit
    def foo(x):
        return x * 2
Thank you for your message. I found some relevant test cases, introduced in bdd6545, that check that classmethod (staticmethod) is not supported. /cc @jblespiau (an author of those tests): was there a particular reason for this?
Interesting. I don't see any discussion of this in the associated PR (#4169).
I can't remember for sure, but I think this was always the case. While writing the C++ code path for jax.jit, I believe I added these tests to document the existing behavior. That would explain why the PR has no discussion: the behavior did not change (I may be wrong). Generally speaking, JAX follows a functional paradigm, and decorating static (or non-static) methods is somewhat at odds with it. The usual practice is to jax.jit plain functions, so decorating methods is discouraged (at least, I discourage it). There is also a one-liner solution if the user really wants this. Within Google, we generally avoid static methods except for factories, as described in https://google.github.io/styleguide/pyguide.html#217-function-and-method-decorators. It's always possible to write a plain, top-level function that does the same thing. In practice we rarely have static methods, so we never need to jax.jit them. For these reasons, my personal conclusion (which may not match what the JAX team thinks) is that it's probably not worth supporting these non-canonical usages as first-class citizens, especially since they already work with the syntax above.
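Here is a minimal sketch of the alternatives described above, assuming a recent JAX install. The jitted logic lives in a plain top-level function, and the class only exposes it. The names `_double`, `Foo`, `double` and `triple` are hypothetical. The `triple` attribute shows the one-liner form, which wraps an already-jitted function in `staticmethod` explicitly. That has the same effect as placing `@staticmethod` above `@jax.jit`.

```python
import jax
import jax.numpy as jnp


# Plain, top-level function: the canonical thing to pass to jax.jit.
@jax.jit
def _double(x):
    return x * 2


class Foo:
    # Expose the already-jitted top-level function as a static method.
    # staticmethod(...) is applied last, so jax.jit only ever sees a plain function.
    double = staticmethod(_double)

    # One-liner variant: jit a function inline, then wrap it in staticmethod.
    triple = staticmethod(jax.jit(lambda x: x * 3))


print(Foo.double(jnp.arange(3)))  # called through the class
print(Foo().triple(2.0))          # called through an instance; no self is passed
```

Because `staticmethod` is the outermost wrapper, attribute access returns the jitted function unchanged on every Python version.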
Thank you @jblespiau for the explanation.
That is probably how JAX code is usually written at Google, but I respectfully disagree that staticmethods should be prohibited. Community code that doesn't strictly follow the Google conventions may still want jitted staticmethods. Personally, I don't like having a top-level (module-level) plain function, because I have to place it outside the class. That often puts it far away from the methods that implement closely related logic. That said, as an alternative, one could define a (nested) plain function inside …
I think the cleanest approach is probably to use … If what you say about Python 3.10 is correct, though, the tests referred to above will (I think?) fail when run with Python 3.10. If there are no fundamental reasons why staticmethod objects cannot be jitted, I would be fine with a PR that removes those tests and updates …
>>> @staticmethod
... def f(x):
...     return x
>>> f(1)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-9-281ab0a37d7d> in <module>
----> 1 f(1)

TypeError: 'staticmethod' object is not callable

Since the JIT transform works by calling the function passed to it, I'm not sure how easy it would be to make this work.
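The traceback above depends on the Python version. Python 3.10 made staticmethod objects callable, so on 3.10 and later the same call succeeds. The sketch below shows which behavior a given interpreter has. The names `f`, `sm` and `Foo` are only examples.

```python
import sys


def f(x):
    return x


sm = staticmethod(f)  # the raw descriptor object, as seen inside a class body

# callable() on the raw staticmethod object is False before Python 3.10
# and True from 3.10 on, when staticmethod became callable.
print(sys.version_info[:2], callable(sm))

# __func__ gives back the plain function on every Python 3 version.
print(sm.__func__(1))


# Accessing the attribute through a class goes through the descriptor
# protocol, which also returns the plain function on every version.
class Foo:
    g = sm


print(Foo.g(1), Foo.g is f)
```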
JIT would have to know to look at the __func__ attribute of the staticmethod object:

>>> f.__func__(1)
1

This change would have to be made in both JAX and jaxlib, where the C++ JIT path is defined, so it's not a trivial change. All told, I think it would be cleaner to keep requiring that objects passed to …
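The idea of "looking at `__func__`" can be illustrated with a hypothetical wrapper. When it receives a staticmethod object, it transforms the underlying function and re-wraps the result, so decorator order no longer matters. This is a pure-Python illustration only, not JAX or jaxlib code. `transform_allowing_staticmethod`, `trace_calls` and `my_jit` are invented names, and `trace_calls` is a toy stand-in for a transformation such as `jax.jit`.

```python
import functools


def transform_allowing_staticmethod(transform):
    """Hypothetical: wrap a function transformation so it also accepts staticmethods."""

    @functools.wraps(transform)
    def wrapped(fun, *args, **kwargs):
        if isinstance(fun, staticmethod):
            # Transform the plain function stored in __func__, then restore
            # the staticmethod wrapper so the result still works in a class body.
            return staticmethod(transform(fun.__func__, *args, **kwargs))
        return transform(fun, *args, **kwargs)

    return wrapped


def trace_calls(fun):
    """Toy stand-in for a transformation such as jax.jit: it requires a callable."""
    if not callable(fun):
        raise TypeError(f"Expected a callable value, got {fun}")

    @functools.wraps(fun)
    def inner(*args, **kwargs):
        inner.calls += 1
        return fun(*args, **kwargs)

    inner.calls = 0
    return inner


my_jit = transform_allowing_staticmethod(trace_calls)


class Foo:
    @my_jit  # outermost decorator: it receives a staticmethod object
    @staticmethod
    def foo(x):
        return x * 2


print(Foo.foo(3), Foo.foo.calls)
```

Calling `trace_calls` directly on a staticmethod object fails on Python versions before 3.10, because the object is not callable there. It raises a TypeError with the same wording as the error quoted in the issue.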
Is there any technical downside to using …
I don't think there's any downside to using …
The following code does not work; it fails with the error

TypeError: Expected a callable value, got <staticmethod object at 0x7fe007e767f0>

This is because a staticmethod object is NOT callable until Python 3.10. In the everyday sense a static method is callable, of course, since it is something we can call. In #1251 one possible workaround is suggested:
which is a bit ugly, but basically does the same thing as long as self is not used.

I don't see any reason why a staticmethod cannot be jitted like other non-method functions. I think this is just a matter of extending _check_callable(fun) to support staticmethods (which will happen automatically in Python 3.10). Are there any pitfalls I've missed?
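Here is a sketch of what extending the check might look like. JAX's actual `_check_callable` is internal, and its code is not reproduced here. The hypothetical `check_callable_or_staticmethod` below only mirrors the error message quoted above. It returns the plain function that a transformation should trace.

```python
def check_callable_or_staticmethod(fun):
    """Hypothetical variant of a callable check that also accepts staticmethods.

    Returns the plain function to transform.
    """
    if isinstance(fun, staticmethod):
        # On Python < 3.10 the staticmethod object itself is not callable,
        # but the function it wraps always is.
        fun = fun.__func__
    if not callable(fun):
        raise TypeError(f"Expected a callable value, got {fun}")
    return fun


def double(x):
    return x * 2


plain = check_callable_or_staticmethod(staticmethod(double))
print(plain is double, plain(4))

try:
    check_callable_or_staticmethod(42)
except TypeError as e:
    print(e)
```

One possible pitfall is that this returns a bare function. If the caller puts that result back into a class body without re-wrapping it in `staticmethod`, instance access to a plain Python function would pass `self` as the first argument.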