Print return of jit method in class #1567
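Thanks for the question! The issue is that the jitted `step` method has side effects: it updates `self.v` and `self.p` in place, and `jit` is meant only for functionally pure functions. See also What's supported in the readme.

Objects are okay, but a `jit`-compiled function shouldn't mutate them. Anything it stores on `self` is a tracer, an abstract stand-in for a value that exists only while JAX traces the function. That tracer leaks out of the computation, which is why printing the attribute afterwards shows a `Traced<ShapedArray(...)>` object. (@dougalm your leak-detector could have caught this issue. I wonder if we should look into reviving it.) WDYT?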
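Here is a minimal sketch of that failure mode. A hypothetical `Holder` object stands in for `World`. The assignment inside the `jit`-compiled function runs only while JAX traces it, so what gets stored is a tracer rather than an array. The pure version simply returns the new value.

```python
import jax
import jax.numpy as jnp


class Holder:
    """Hypothetical mutable container, used only for illustration."""
    value = None


holder = Holder()


@jax.jit
def impure_step(x):
    # Side effect: runs at trace time and stores an abstract tracer,
    # not a concrete array, on an object outside the function.
    holder.value = x * 2.0
    return x


@jax.jit
def pure_step(x):
    # Pure: the new value is handed back to the caller instead.
    return x * 2.0


impure_step(jnp.ones(2))
print(holder.value)             # prints a tracer's repr, not a concrete array
print(pure_step(jnp.ones(2)))   # [2. 2.]
```

Calling `impure_step` again with the same input shape would not rerun the assignment at all. The cached compiled function skips the Python body, which is another reason side effects and `jit` don't mix.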
Here are a couple of styles that do work well with JAX:

import jax.numpy as np
from jax import jit
from collections import namedtuple

World = namedtuple("World", ["p", "v"])

@jit
def step(world, dt):
    a = -9.8
    new_v = world.v + a * dt
    new_p = world.p + new_v * dt
    return World(new_p, new_v)

world = World(np.array([0, 0]), np.array([1, 1]))
for i in range(1000):
    world = step(world, 0.01)
print(world.p)

That's just a functional version of your code. The key is that `step` returns a new `World` instead of modifying the one it was given. Because namedtuples are pytrees, `jit` can accept them as arguments and return them as outputs. We can organize the same thing into Python classes if we want:

from jax.tree_util import register_pytree_node
import jax.numpy as np
from jax import jit

class World:
    def __init__(self, p, v):
        self.p = p
        self.v = v

    @jit
    def step(self, dt):
        a = -9.8
        new_v = self.v + a * dt
        new_p = self.p + new_v * dt
        return World(new_p, new_v)

# By registering `World` as a pytree, it turns into a transparent container and
# can be used as an argument to any JAX-transformed function.
register_pytree_node(World,
                     lambda x: ((x.p, x.v), None),
                     lambda _, tup: World(tup[0], tup[1]))

world = World(np.array([0, 0]), np.array([1, 1]))
for i in range(1000):
    world = world.step(0.01)
print(world.p)

The key difference there is that `World` is registered as a pytree, so its instances can be passed into and returned from `jit`-compiled functions. As before, `step` returns a new `World` rather than mutating `self`. Here's one last pattern that works, using your original class:

class World:
    def __init__(self, p, v):
        self.p = p
        self.v = v

    def step(self, dt):
        a = -9.8
        self.v += a * dt
        self.p += self.v * dt

@jit
def run(init_p, init_v):
    world = World(init_p, init_v)
    for i in range(1000):
        world.step(0.01)
    return world.p, world.v

out = run(np.array([0, 0]), np.array([1, 1]))
print(out)

(That last one takes much longer to compile because we're unrolling 1000 steps into a single XLA computation and compiling that. In practice we'd use something like a JAX loop primitive rather than a Python loop.)

The reason your original class works in that last example is that we're only using it inside a `jit`-compiled function. The mutation happens on traced values while `run` is being traced, and `run` returns only plain arrays, so no tracers escape.

Of those styles, I personally have grown to like the first. I wrote all my code in grad school in an OOP-heavy style, and I regret it: it was hard to compose with other code, even other code that I wrote, and that really limited its reach. Functional code, by forcing explicit state management, solves the composition problem. It's also a great fit for numerical computing in general, since numerical computing is much closer to math than, say, writing a web server.

Hope that's helpful :)
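One way to avoid that long compile is to replace the unrolled Python loop with a loop primitive that XLA compiles once. The sketch below uses `jax.lax.fori_loop` with the namedtuple `World` from the first style. Choosing `fori_loop` is an assumption on our part, not necessarily the construct meant above.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

World = namedtuple("World", ["p", "v"])


def step(world, dt):
    # Same pure update as the first style: returns a new World.
    a = -9.8
    new_v = world.v + a * dt
    new_p = world.p + new_v * dt
    return World(new_p, new_v)


@jax.jit
def run(init_p, init_v):
    def body(i, w):
        # The loop index is unused; the carry is the whole World pytree.
        return step(w, 0.01)

    # fori_loop traces `body` once instead of unrolling 1000 copies of step.
    world = jax.lax.fori_loop(0, 1000, body, World(init_p, init_v))
    return world.p, world.v


p, v = run(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0]))
print(p, v)
```

The initial arrays are floats here because a `fori_loop` carry must keep the same dtype on every iteration. Integer starting arrays would be promoted to float by the first step and trigger a type mismatch.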
Hi mattjj, thanks a lot for your elaborate answers! They're very enlightening. I've been thinking about why exactly I want to use OOP, and about its pros and cons in relation to JAX.
Just wanted to let you know that this answer allowed me to switch from an OOP to a functional mindset for the first time! It was very helpful and illustrative, and I can now
I believe we can close this issue; I'll keep the useful_read label.
Otherwise, why not add support for attrs classes? They're basically namedtuples on steroids. It looks to me like it would need a small contribution to the pytree file, mostly at lines:

Then you'd be able to do something like:

import attr
import jax
import jax.numpy as np

# Hypothetical: assumes JAX treats attrs classes as pytrees (the proposed change).
@attr.s
class World:
    p: np.ndarray = attr.ib()
    v: np.ndarray = attr.ib()

    @jax.jit
    def step(self, dt):
        a = -9.8
        v = self.v + a * dt
        p = self.p + v * dt
        return attr.evolve(self, p=p, v=v)

This works because World instances would be recognized as valid arguments and outputs of jitted functions.

EDIT:
EDIT 2:
In case anyone comes across this, there is now a section in the FAQ addressing jit compilation of methods: https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods
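For comparison, here is a sketch of the pytree approach from earlier in the thread, rewritten with the decorator form `jax.tree_util.register_pytree_node_class`. The decorator expects a `tree_flatten` method and a `tree_unflatten` classmethod. This sketch makes no claim about which form the FAQ itself uses.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class World:
    def __init__(self, p, v):
        self.p = p
        self.v = v

    def tree_flatten(self):
        # (children, aux_data): both fields are array leaves, no static data.
        return (self.p, self.v), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jax.jit
    def step(self, dt):
        a = -9.8
        new_v = self.v + a * dt
        new_p = self.p + new_v * dt
        return World(new_p, new_v)


world = World(jnp.zeros(2), jnp.ones(2))
for _ in range(1000):
    world = world.step(0.01)
print(world.p)
```

Keeping the flatten and unflatten logic on the class itself means the registration cannot drift out of sync with the class's fields.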
Hi all,
I'm trying to use classes with JAX, and one of the problems I have is that I can't print a value that is manipulated in a JIT-compiled class method. Example code:
This prints:
Traced<ShapedArray(float32[2]):JaxprTrace(level=-1/1)>
I'm aware this is expected behavior when you print something inside the function, but this print isn't inside the function, right?
More generally, I'm wondering whether object-oriented programming is well suited to JAX. Should I avoid this kind of thing? Can JIT work optimally this way?
Thanks for your time!