Trouble using JIT when the function is in a class? #1251

Closed
vedantammihir opened this issue Aug 26, 2019 · 7 comments
Labels
question Questions for the JAX team

@vedantammihir

vedantammihir commented Aug 26, 2019

Hi,

I noticed some significant slowdowns in my code after switching from numpy to jax.numpy, and from other issues it seems the solution is to use jit. When I use jit in a single script file for testing purposes it works, but when I move the function I want to jit into a class in another file, I run into problems.

# odes.py
import jax.numpy as np
import numpy as onp
from jax import jit, jacfwd, grad

from jax.numpy import sin, cos, exp

class odes:
    def __init__(self):
        print("odes file initialized")
    @jit
    def simpleODE(self, t,q):
        return np.array([[q[1]], [cos(q[0])]])

# main script (imports the class defined in odes.py)
from odes import *
from jax import jit, jacfwd, grad

ODE = odes()

Jac = jacfwd(ODE.simpleODE, argnums = (1,))

q = np.ones(2)

A = Jac(0,q)
print(A)

This gives the following error:
TypeError: Argument '<odes.odes object at 0x7fe440250810>' of type <class 'odes.odes'> is not a valid JAX type

@mattjj
Collaborator

mattjj commented Aug 26, 2019

You might be able to work with this pattern:

from functools import partial

import jax.numpy as np
from jax import jit
from jax.numpy import cos

class odes:
  def __init__(self):
    print("odes file initialized")

  @partial(jit, static_argnums=(0,))
  def simpleODE(self, t, q):
    return np.array([[q[1]], [cos(q[0])]])

In words, we're marking the first argument (index 0) as a static argument.
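For reference, here is a self-contained sketch of the whole fix in one file, combining the class above with the original jacfwd call. It assumes a current JAX release, where static arguments must be hashable (a plain Python object hashes by identity, so `self` qualifies); `jnp` stands in for the `np` alias used above, and the print in `__init__` is dropped.

```python
from functools import partial

import jax
import jax.numpy as jnp


class odes:
    # Index 0 is `self`: jit treats it as a static (compile-time) value
    # instead of trying to convert it into an array.
    @partial(jax.jit, static_argnums=(0,))
    def simpleODE(self, t, q):
        return jnp.array([[q[1]], [jnp.cos(q[0])]])


ODE = odes()

# ODE.simpleODE is a bound method, so its positional arguments are (t, q);
# argnums=1 differentiates with respect to q. (An int argnums returns a
# single array; the tuple form (1,) used above returns a 1-tuple.)
Jac = jax.jacfwd(ODE.simpleODE, argnums=1)

q = jnp.ones(2)
A = Jac(0.0, q)
print(A.shape)  # (2, 1, 2): output shape (2, 1) followed by input shape (2,)
```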

What do you think?

mattjj added the question label (Questions for the JAX team) Aug 26, 2019
mattjj self-assigned this Aug 26, 2019
@vedantammihir
Author

Ah, that works for me.

Thanks for the help! Was it due to some interaction between the JAX wrapper and the self object?

@mattjj
Collaborator

mattjj commented Aug 27, 2019

Glad to hear that helped!

Yes, the issue is that jit only knows how to compile numerical computations on arrays (i.e. what XLA can do), not arbitrary Python computations. In particular, that means it only knows how to work with array data types, not arbitrary classes, and in this case the self argument is an instance of odes. By using static_argnums we're telling jit to compile only the computation that gets applied to the other arguments, and to re-trace and re-compile every time the first argument changes its Python object id. That re-tracing basically means jit lets Python handle everything to do with the self argument.
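One consequence of keying the cache on the object itself: mutating an attribute of a static `self` in place does not trigger a re-trace, so the old value stays baked into the compiled code. A minimal sketch, assuming a plain class with the default identity-based `__hash__`/`__eq__`; the `Shift` class is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


class Shift:
    """Hypothetical example class that adds a stored offset to its input."""

    def __init__(self, offset):
        self.offset = offset

    # self is static: jit looks it up in its cache by hash/equality,
    # which for a plain class means object identity.
    @partial(jax.jit, static_argnums=(0,))
    def apply(self, x):
        # self.offset is read at trace time and baked in as a constant.
        return x + self.offset


s = Shift(1.0)
first = s.apply(jnp.zeros(()))   # traces and compiles with offset 1.0
s.offset = 5.0                   # same object, so the cache key is unchanged
second = s.apply(jnp.zeros(()))  # cache hit: still computes x + 1.0

t = Shift(5.0)                   # a different object triggers a fresh trace
third = t.apply(jnp.zeros(()))   # computes x + 5.0
```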

mattjj closed this as completed Aug 27, 2019
@YukunXia

The partial decorator solved my problem! I had also tried syntax like jacfwd(test_class.test_func, argnums=[1]), where test_func is defined as def test_func(self, x), to avoid differentiating with respect to the self argument, but that gave me a tuple index out of range error. How should I fix this error?
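A likely explanation, assuming `test_class` here is an instance rather than the class itself: accessing a method through an instance gives a bound method, so `self` is already filled in and `x` becomes positional argument 0. With only one positional argument, `argnums=[1]` points past the end of the argument tuple. A sketch with a hypothetical `Quadratic` class (method left un-jitted for simplicity):

```python
import jax
import jax.numpy as jnp


class Quadratic:
    """Hypothetical class whose method computes coef * x**2 elementwise."""

    def __init__(self, coef):
        self.coef = coef

    def test_func(self, x):
        return self.coef * x ** 2  # elementwise, so the Jacobian is diagonal


obj = Quadratic(3.0)
x = jnp.array([1.0, 2.0])

# Bound method: the only positional argument is x, at index 0.
jac_bound = jax.jacfwd(obj.test_func, argnums=0)
bound_result = jac_bound(x)  # diag(6 * x)

# Function looked up on the class: arguments are (self, x), so x is index 1
# and the instance is passed explicitly as a non-differentiated argument.
jac_unbound = jax.jacfwd(Quadratic.test_func, argnums=1)
unbound_result = jac_unbound(obj, x)  # same Jacobian
```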

@araza6

araza6 commented Jun 17, 2020

@mattjj any pointers as to how it would be done for jax.grad?
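The static-`self` pattern carries over to `jax.grad` in the same way as to jacfwd: differentiate the bound method, and `self` never reaches the transformation as an argument. A sketch with a hypothetical `Model` class, assuming a scalar-valued method (which `jax.grad` requires):

```python
from functools import partial

import jax
import jax.numpy as jnp


class Model:
    """Hypothetical class with a scalar-valued method to differentiate."""

    def __init__(self, scale):
        self.scale = scale

    @partial(jax.jit, static_argnums=(0,))
    def loss(self, w, x):
        # jax.grad needs a scalar output, hence the sum.
        return jnp.sum((self.scale * w * x) ** 2)


m = Model(2.0)

# Bound method: positional arguments are (w, x); the default argnums=0
# differentiates with respect to w.
grad_w = jax.grad(m.loss)

w = jnp.array([1.0, -1.0])
x = jnp.array([3.0, 0.5])
g = grad_w(w, x)  # analytically 2 * scale**2 * x**2 * w
```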

@Anselmoo

Anselmoo commented Jan 1, 2024

from functools import partial

@mattjj, the recipe works very nicely.

❓ However, I am still wondering whether it would make sense to add the partial approach straight into the @jit decorator, because many JAX functions like jacfwd or vmap would benefit from the simplicity.

❓ If not, maybe the TypeError could provide a hint to consider using @partial?

TypeError: Cannot interpret value of type <class '__main__.TheMethod'> as an abstract array; it does not have a dtype attribute

@jakevdp
Collaborator

jakevdp commented Jan 1, 2024

Thanks for the question.

I am still wondering if it does not make sense to add the partial approach straight into the @jit decorator

We've previously discussed this kind of idea and decided not to go this route; see e.g. #10061

Also, I should note that this question and its answer are quite old, and I would no longer recommend static_argnums as a solution for JIT-compilation of a class method. For a fuller discussion of issues with this, see FAQ: how to use jit with methods.
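The FAQ entry discusses several alternatives; one of them is to register the class as a pytree, so that `self` is passed to `jit` as ordinary array data rather than as a static value. A minimal sketch using `jax.tree_util.register_pytree_node_class` with a hypothetical `Scaler` class; unlike the static-`self` version, updating an attribute is picked up instead of producing a stale result:

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Scaler:
    """Hypothetical class that multiplies its input by a stored array."""

    def __init__(self, scale):
        self.scale = scale

    @jax.jit
    def apply(self, x):
        # self arrives as a pytree, so self.scale is a traced array here.
        return self.scale * x

    def tree_flatten(self):
        children = (self.scale,)  # dynamic leaves that jit traces
        aux_data = None           # no static metadata in this example
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


s = Scaler(jnp.array(2.0))
x = jnp.ones(3)
first = s.apply(x)         # every element is 2.0
s.scale = jnp.array(3.0)   # same shape and dtype, so the compiled code is reused
second = s.apply(x)        # every element is 3.0: the new value is used
```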

cjlee7128 referenced this issue in xlab-ub/py-mlmodelscope Mar 20, 2024
6 participants