
Please allow @jax.jit on staticmethods #7702

Closed
wookayin opened this issue Aug 24, 2021 · 10 comments
Labels
enhancement New feature or request

@wookayin
Contributor

wookayin commented Aug 24, 2021

The following code does not work,

import jax

class Foo:
  @jax.jit
  @staticmethod
  def foo(x):
    return x * 2

with the error TypeError: Expected a callable value, got <staticmethod object at 0x7fe007e767f0>. This is because a staticmethod object is NOT callable until Python 3.10, even though, conceptually, a static method is of course something we can call.
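A small pure-Python illustration of the point above (no JAX needed; the names double and Foo are only for illustration). A decorator placed above @staticmethod receives the raw staticmethod wrapper, whose callability depends on the Python version, while ordinary attribute lookup on the class unwraps it through the descriptor protocol:

```python
import sys


def double(x):
    return x * 2


class Foo:
    foo = staticmethod(double)


# The object actually stored in the class body is the staticmethod wrapper.
raw = Foo.__dict__["foo"]

# staticmethod objects only became directly callable in Python 3.10.
print("Python >= 3.10:", sys.version_info >= (3, 10))
print("callable(raw):", callable(raw))

# Attribute access runs staticmethod.__get__, which returns the wrapped function.
print(Foo.foo is double)
print(raw.__func__ is double)
```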

In #1251 one possible workaround is suggested:

import functools
import jax

class Foo:
  @functools.partial(jax.jit, static_argnums=(0,))
  def foo(self, x):
    # this function must not use `self`, so that it stays purely functional
    return x * 2

which is a bit ugly, but basically does the same thing as long as self is not used.
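A runnable version of that workaround, assuming a recent JAX release. Marking argument 0 (self) as static means jit hashes it instead of tracing it; with the default identity-based __hash__, each new Foo instance will typically trigger its own compilation, and if the method did read attributes of self, changing them after the first call could give stale cached results:

```python
import functools

import jax
import jax.numpy as jnp


class Foo:
    # Argument 0 (`self`) is static: jit hashes it rather than tracing it.
    # The body does not depend on any state of `self`, so it stays pure.
    @functools.partial(jax.jit, static_argnums=(0,))
    def foo(self, x):
        return x * 2


result = Foo().foo(jnp.arange(3))
print(result)
```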

I don't see any reason a staticmethod cannot be jitted like any other non-method function. I think this is just a matter of extending _check_callable(fun) to support staticmethods (which will happen automatically in Python 3.10). Are there any pitfalls I might have missed?
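A minimal sketch of the proposed extension, written against a hypothetical check_callable helper (JAX's private _check_callable may look different): unwrap a staticmethod through its __func__ attribute before running the usual callable() check.

```python
def check_callable(fun):
    """Hypothetical stand-in for JAX's private _check_callable."""
    if isinstance(fun, staticmethod):
        # Before Python 3.10 the wrapper itself is not callable,
        # but the wrapped function is available as __func__.
        fun = fun.__func__
    if not callable(fun):
        raise TypeError(f"Expected a callable value, got {fun}")
    return fun


def double(x):
    return x * 2


print(check_callable(staticmethod(double)) is double)
```

Unwrapping alone is probably not the whole change: if the transformed function were stored back in the class without being re-wrapped in staticmethod, instance access would bind it like an ordinary method.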

@wookayin wookayin added the enhancement New feature or request label Aug 24, 2021
@jakevdp
Collaborator

jakevdp commented Aug 24, 2021

Thanks for the report! That's really interesting – I wasn't aware that staticmethod objects are not callable. I can't think of any reason why we shouldn't be able to extend _check_callable to recognize static methods: would you like to put together a pull request?

For the time being, another workaround is to switch the order of jit and staticmethod. This works as expected:

import jax

class Foo:
  @staticmethod
  @jax.jit
  def foo(x):
    return x * 2
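Why this order works, sketched under the assumption of a recent JAX release: decorators apply bottom-up, so jax.jit sees the plain function, and staticmethod then wraps the already-jitted callable. Attribute lookup on the class or on an instance hands that jitted callable back unchanged, so no self is bound:

```python
import jax


class Foo:
    @staticmethod  # applied second: wraps the already-jitted callable
    @jax.jit       # applied first: receives the plain function
    def foo(x):
        return x * 2


raw = Foo.__dict__["foo"]

# staticmethod.__get__ returns the jitted callable itself, from the class
# or from an instance, so `self` is never passed in.
print(Foo.foo is raw.__func__, Foo().foo is raw.__func__)
print(Foo.foo(3), Foo().foo(3))
```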

@jakevdp jakevdp self-assigned this Aug 24, 2021
@wookayin
Contributor Author

wookayin commented Aug 24, 2021

Thank you for your message.

Actually, I found some relevant test cases, introduced in bdd6545, that check that classmethod and staticmethod objects are not supported. /cc @jblespiau (an author of those tests): was there a particular reason for this?

@jakevdp
Collaborator

jakevdp commented Aug 24, 2021

Interesting. I don't see any discussion of this in the associated PR (#4169).

@jblespiau
Contributor

jblespiau commented Aug 24, 2021

I cannot remember, but I think this was always the case, and that I added these tests while writing the C++ codepath for jax.jit, to document the existing behavior (so there was no discussion in the PR, because the behavior did not change; I may be wrong).

Generally speaking, JAX follows a functional paradigm, and decorating static (or non-static) methods somewhat breaks it, since the usual practice is to jax.jit plain functions; so it's discouraged (at least, I discourage it). Also, there is a one-line solution if the user really wants that.

Within Google, we generally avoid static methods except for factories (see https://google.github.io/styleguide/pyguide.html#217-function-and-method-decorators). It's always possible to have a plain, top-level function doing the same thing. So we usually never have static methods, and thus never want to jax.jit them.
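A minimal sketch of the plain-function style described here, assuming the class only needs to delegate to a jitted helper; the names _double and compute are illustrative:

```python
import jax
import jax.numpy as jnp


@jax.jit
def _double(x):
    # Module-level pure function; its compilation cache is shared by all callers.
    return x * 2


class Foo:
    def compute(self, x):
        # The method itself is not jitted; it only forwards to the pure helper.
        return _double(x)


print(Foo().compute(jnp.arange(3)))
```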

For these reasons, my personal conclusion (which may not align with what the JAX team thinks) is that it's probably not worth supporting these somewhat non-canonical usages as first-class citizens, especially because they are already possible with the syntax above.

@wookayin
Contributor Author

wookayin commented Aug 26, 2021

Thank you @jblespiau for the explanation.

It's always possible to have a plain, top-level function doing the same thing. So usually, we never have static methods, and thus, never want to jax.jit them.

That is probably how Google usually writes JAX code, but I respectfully disagree that staticmethods should be prohibited: community code that does not strictly follow the Google convention might still want jitted staticmethods. Personally, I don't like having a top- or module-level plain function, because I have to place it outside the class, usually far from the methods that implement closely related logic.

That said, as an alternative one could define a (nested) plain function inside __init__ and assign it as an attribute, but I feel staticmethods give a bit more flexibility. I can prepare a PR to support this feature.
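A minimal sketch of the __init__ alternative, assuming a recent JAX release. One likely trade-off: every instance creates a new function object, so each instance typically gets its own jitted wrapper and compilation cache:

```python
import jax


class Foo:
    def __init__(self):
        # A nested plain function, jitted and stored on the instance.
        @jax.jit
        def foo(x):
            return x * 2

        # Descriptors are only consulted for attributes found on the class,
        # so calling self.foo(x) does not insert `self` as an argument.
        self.foo = foo


a = Foo()
print(a.foo(3))
```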

@jakevdp
Collaborator

jakevdp commented Aug 26, 2021

I think the cleanest approach is probably to use @staticmethod @jit rather than @jit @staticmethod.

If what you say about Python 3.10 is correct, though, the tests referred to above will (I think?) fail when run with Py3.10. If there are no fundamental reasons why staticmethod objects cannot be jitted, I would be fine with a PR that removes those tests and updates _is_callable to recognize static methods.

@jakevdp
Collaborator

jakevdp commented Aug 26, 2021

staticmethod objects are strange, though. For example:

>>> @staticmethod
... def f(x):
...   return x
>>> f(1)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-9-281ab0a37d7d> in <module>
----> 1 f(1)

TypeError: 'staticmethod' object is not callable

Since the JIT transform operates by calling the function passed to it, I'm not sure how easy it would be to make this work.

@jakevdp
Collaborator

jakevdp commented Aug 26, 2021

JIT would have to know to look at the __func__ attribute of a staticmethod object:

>>> f.__func__(1)
1

This change would have to be made both in JAX and in jaxlib, where the C++ JIT path is defined, so it's not a particularly trivial change.

All told, I think it would be cleaner to keep requiring that objects passed to jit are callable, and in the case of staticmethod use the workarounds mentioned here.
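To illustrate what looking at __func__ would involve, here is a minimal Python-level sketch using a hypothetical jit_allowing_staticmethod wrapper. It is not a JAX API, it does not touch the C++ jit path in jaxlib, and it is not how JAX itself handles this:

```python
import jax


def jit_allowing_staticmethod(fun, **jit_kwargs):
    """Hypothetical helper, not part of JAX: accept a function or a staticmethod."""
    if isinstance(fun, staticmethod):
        # Jit the wrapped function, then re-wrap it so the result still
        # behaves as a static method when stored in a class body.
        return staticmethod(jax.jit(fun.__func__, **jit_kwargs))
    return jax.jit(fun, **jit_kwargs)


class Foo:
    @jit_allowing_staticmethod  # works even though it sits above @staticmethod
    @staticmethod
    def foo(x):
        return x * 2


print(Foo.foo(3), Foo().foo(3))
```

The re-wrapping step matters: a bare jitted function stored on the class would be bound to instances, so Foo().foo(3) would receive the instance as its first argument.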

@cisprague

Is there any technical downside to using @staticmethod @jit? It seems like one reason to have a class of jitted static methods is to keep those methods in the namespace of the class. But are there other recommended ways?

@jakevdp
Collaborator

jakevdp commented Apr 5, 2024

I don't think there's any downside to using @staticmethod @jit.


4 participants