Trouble using JIT when the function is in a class? #1251
Comments
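You might be able to work with this pattern:

```python
from functools import partial

from jax import jit
import jax.numpy as np
from jax.numpy import cos

class odes:
    def __init__(self):
        print("odes file initialized")

    # Index 0 is `self`: jit hashes it as a static argument instead of tracing it.
    @partial(jit, static_argnums=(0,))
    def simpleODE(self, t, q):
        return np.array([[q[1]], [cos(q[0])]])
```

In words, we're marking the first argument (`self`, at index 0) as a static argument. What do you think?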
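One caveat of the `static_argnums=(0,)` pattern, shown as a minimal sketch that assumes a recent JAX version: a static argument is part of jit's cache key through its hash and equality. A plain Python object hashes by identity, so mutating an attribute of `self` does not trigger a retrace, and the compiled function keeps the value it saw at trace time. The class `Scaler` is hypothetical and exists only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

class Scaler:
    """Hypothetical class used only to illustrate the static-self caveat."""

    def __init__(self, factor):
        self.factor = factor

    @partial(jax.jit, static_argnums=(0,))
    def scale(self, x):
        # self.factor is read at trace time and baked into the compiled function.
        return self.factor * x

s = Scaler(2.0)
before = s.scale(jnp.float32(1.0))  # traces and compiles with factor == 2.0
s.factor = 3.0
after = s.scale(jnp.float32(1.0))   # cache hit: hash(s) and s == s are unchanged
```

Both calls return 2.0. If the object's state has to change between calls, it should not be treated as static.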
Ah, that works for me. Thanks for the help! Was it due to some interaction between the JAX wrapper and the `self` object?
Glad to hear that helped! Yes, the issue is that …
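Here is a minimal reproduction of the error from the original report, assuming a recent JAX version. The exact message wording varies between releases, but the exception type is the same. Decorating a method with plain `jax.jit` makes `self` a traced argument, and an arbitrary Python object is not a valid JAX type. The class name `Odes` is illustrative.

```python
import jax
import jax.numpy as jnp

class Odes:
    @jax.jit  # no static_argnums: jit tries to treat `self` as an array
    def simple_ode(self, t, q):
        return jnp.array([[q[1]], [jnp.cos(q[0])]])

try:
    Odes().simple_ode(0.0, jnp.array([0.0, 1.0]))
    raised = False
except TypeError:
    # `self` is an Odes instance, which jit cannot turn into an abstract array.
    raised = True
```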
The `partial` decorator solved my problem! I was also trying to use syntax like …
@mattjj, any pointers on how this would be done for `jax.grad`?
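One possible approach for `jax.grad`, sketched under our own assumptions rather than taken from the thread: apply `jax.grad` to the bound method. The bound method already carries `self`, so `grad` differentiates with respect to the first remaining argument and never tries to trace the object. The class `Quadratic` and its `energy`/`force` methods are hypothetical.

```python
import jax
import jax.numpy as jnp

class Quadratic:
    """Hypothetical example: energy(q) = 0.5 * k * sum(q**2)."""

    def __init__(self, k):
        self.k = k

    def energy(self, q):
        return 0.5 * self.k * jnp.sum(q ** 2)

    def force(self, q):
        # self.energy is a bound method, so grad differentiates w.r.t. q only.
        return -jax.grad(self.energy)(q)

model = Quadratic(2.0)
f = model.force(jnp.array([1.0, -3.0]))  # -k * q
```

Here `f` is `[-2.0, 6.0]`, because the gradient of the energy is `k * q`.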
@mattjj, the recipe works very nicely. However, I am still wondering whether it would make sense to add the … If not, maybe the …
Thanks for the question!
We've previously discussed this kind of idea and decided not to go this route; see, e.g., #10061. Also, I should note that this question and its answer are quite old, and I would no longer recommend …
Hi,
I noticed significant slowdowns in my code after switching from numpy to jax.numpy, and from other issues it seems the solution is to use `jit`. When I use `jit` in a single script file for testing, it works. But when I move the function I want to jit into a separate class, I run into problems.
This gives the following error:
TypeError: Argument '<odes.odes object at 0x7fe440250810>' of type <class 'odes.odes'> is not a valid JAX type
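An alternative that avoids marking `self` as static is sketched below. It is an assumption on our part, not the fix given in this thread. The idea is to keep the jitted function pure and have the method pass whatever state it needs as explicit array arguments. The helper `_simple_ode` and the parameter `k` are hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def _simple_ode(t, q, k):
    # Pure function: everything it depends on is an explicit (traced) argument.
    return jnp.array([[q[1]], [k * jnp.cos(q[0])]])

class Odes:
    def __init__(self, k=1.0):
        self.k = k  # hypothetical parameter, for illustration only

    def simple_ode(self, t, q):
        # The method only forwards attributes, so a changed self.k is
        # picked up on the next call (k is traced, not static).
        return _simple_ode(t, q, self.k)

model = Odes()
out1 = model.simple_ode(0.0, jnp.array([0.0, 1.0]))
model.k = 2.0
out2 = model.simple_ode(0.0, jnp.array([0.0, 1.0]))
```

With this design, `out1` is `[[1.0], [1.0]]` and `out2` reflects the updated `k`: `[[1.0], [2.0]]`. The static-`self` version, by contrast, would keep returning the cached result.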