Print return of jit method in class #1567

Closed
rdaems opened this issue Oct 25, 2019 · 7 comments
Labels: question (Questions for the JAX team), useful read (PR or issue that contains useful design discussion)


rdaems commented Oct 25, 2019

Hi all,

I'm trying to use classes with jax and one of the problems I have is that I can't print a value that is manipulated in a JIT compiled class method. Example code:

import jax.numpy as np
from jax import jit
from functools import partial


class World:
    def __init__(self, p, v):
        self.p = p
        self.v = v

    @partial(jit, static_argnums=(0,))
    def step(self, dt):
        a = - 9.8
        self.v += a * dt
        self.p += self.v *dt


world = World(np.array([0, 0]), np.array([1, 1]))

for i in range(1000):
    world.step(0.01)
print(world.p)

This prints Traced<ShapedArray(float32[2]):JaxprTrace(level=-1/1)>.
I'm aware this is expected behavior when you print something inside the function, but this print is not inside the function, right?

More generally, I'm wondering if object oriented programming is well suited for jax? Should I avoid this kind of stuff? Is JIT capable of working optimally this way?

Thanks for your time!

Collaborator

mattjj commented Oct 25, 2019

Thanks for the question!

The issue is that the step method violates functional purity: it has a side-effect (of updating self.v and self.p). You can only use jit on pure functions. (Unfortunately, we don't have good ways of checking whether a function is pure and warning you; it's just an un-checked user promise.)
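
To make the purity requirement concrete, here is a minimal sketch (not taken from the thread) of what happens to a Python side effect under jit. The names `trace_log` and `impure_add` are made up for illustration. The side effect runs while JAX traces the function, not on every call, which is why mutating `self` inside a jitted method does not behave like ordinary Python.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical global used only to observe the side effect


@jax.jit
def impure_add(x):
    # Python side effect: executed while tracing, not in the compiled computation.
    trace_log.append("traced")
    return x + 1.0


first = impure_add(jnp.array([1.0, 2.0]))
second = impure_add(jnp.array([3.0, 4.0]))  # same shape and dtype, so the cached trace is reused

print(len(trace_log))  # 1: the append ran once, during the first trace
```

The second call reuses the compiled computation, so the append never runs again. In the original `World.step`, the assignments to `self.v` and `self.p` are the same kind of side effect, and they store tracer objects on `self` instead of concrete arrays.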

See also the "What's supported" section in the readme.

> More generally, I'm wondering if object oriented programming is well suited for jax? Should I avoid this kind of stuff?

Objects are okay, but jit functions can't have side-effects, and that means the pattern in your example (which is pretty canonical Python OOP) won't work with jit.

(@dougalm your leak-detector could have caught this issue. I wonder if we should look into reviving it.)

WDYT?

Collaborator

mattjj commented Oct 25, 2019

Here are a couple styles that do work well with JAX:

import jax.numpy as np
from jax import jit
from collections import namedtuple

World = namedtuple("World", ["p", "v"])

@jit
def step(world, dt):
  a = -9.8
  new_v = world.v + a * dt
  new_p = world.p + new_v * dt
  return World(new_p, new_v)

world = World(np.array([0, 0]), np.array([1, 1]))

for i in range(1000):
  world = step(world, 0.01)
print(world.p)

That's just a functional version of your code. The key is that step returns a new World, rather than modifying the existing one.

We can organize the same thing into Python classes if we want:

from jax.tree_util import register_pytree_node
from functools import partial

class World:
  def __init__(self, p, v):
    self.p = p
    self.v = v

  @jit
  def step(self, dt):
    a = -9.8
    new_v = self.v + a * dt
    new_p = self.p + new_v * dt
    return World(new_p, new_v)

# By registering 'World' as a pytree, it turns into a transparent container and
# can be used as an argument to any JAX-transformed functions.
register_pytree_node(World,
                     lambda x: ((x.p, x.v), None),
                     lambda _, tup: World(tup[0], tup[1]))


world = World(np.array([0, 0]), np.array([1, 1]))

for i in range(1000):
  world = world.step(0.01)
print(world.p)

The key difference there is that step returns a new World instance.

Here's one last pattern that works, using your original World class, though it's a bit more subtle:

class World:
  def __init__(self, p, v):
    self.p = p
    self.v = v

  def step(self, dt):
    a = - 9.8
    self.v += a * dt
    self.p += self.v *dt

@jit
def run(init_p, init_v):
  world = World(init_p, init_v)
  for i in range(1000):
    world.step(0.01)
  return world.p, world.v

out = run(np.array([0, 0]), np.array([1, 1]))
print(out)

(That last one takes much longer to compile, because we're unrolling 1000 steps into a single XLA computation and compiling that; in practice we'd use something like lax.fori_loop or lax.scan to avoid those long compile times.)

The reason your original class works in that last example is that we're only using it under a jit, so the jit function itself doesn't have any side-effects.
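
As a rough sketch of the lax.fori_loop alternative mentioned above (an illustration, not code from the thread): the loop body is written as a pure function of the carried state, so the 1000 steps are compiled as one loop rather than unrolled. The helper name `body` and the use of float inputs are assumptions here; fori_loop requires the carried values to keep the same dtype across iterations, so integer starting arrays would not work with this body.

```python
import jax.numpy as jnp
from jax import jit, lax


@jit
def run(init_p, init_v):
    dt = 0.01
    a = -9.8

    def body(i, state):
        # Pure update: takes the old (p, v) and returns the new (p, v).
        p, v = state
        new_v = v + a * dt
        new_p = p + new_v * dt
        return new_p, new_v

    # The loop is traced once instead of being unrolled 1000 times.
    return lax.fori_loop(0, 1000, body, (init_p, init_v))


p, v = run(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0]))
print(p, v)
```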

Of those styles, I personally have grown to like the first. I wrote all my code in grad school in an OOP-heavy style, and I regret it: it was hard to compose with other code, even other code that I wrote, and that really limited its reach. Functional code, by forcing explicit state management, solves the composition problem. It's also a great fit for numerical computing in general, since numerical computing is much closer to math than, say, writing a web server.

Hope that's helpful :)

@mattjj mattjj self-assigned this Oct 25, 2019
@mattjj mattjj added the question Questions for the JAX team label Oct 25, 2019
Author

rdaems commented Nov 4, 2019

Hi mattjj

Thanks a lot for your elaborate answers! They're very enlightening.
It's interesting to hear your experiences from grad school, I'll try to learn from them.

I've been thinking about why exactly I want to use OOP, and the pros and cons in relation to JAX.
The main reason to use OOP in my case is because I'm building a physics engine, and OOP provides a neat way of defining objects in the scene with properties and states.
Eventually you get a global state vector with the states of all the objects in it, so I think I'll still try to use OOP to build up the scene and the objects, and define what's what in the global state vector.
But I could keep the vector itself out of the OOP.
That way, there's nothing in self that changes throughout the simulation. And jit would work, right?
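
A minimal sketch of that idea, under some assumptions: the class (here a hypothetical `Scene`) holds only configuration that never changes, `self` is marked static for jit, and the state is passed in and returned explicitly. Marking `self` static means it must be hashable, and anything read from `self` is baked into the compiled function at trace time, so later changes to its attributes would not be picked up.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit


class Scene:
    """Hypothetical container for configuration that stays fixed during the simulation."""

    def __init__(self, gravity):
        self.gravity = gravity

    # self is a static argument: it is hashed (by identity, by default), not traced.
    @partial(jit, static_argnums=0)
    def step(self, state, dt):
        p, v = state
        new_v = v + self.gravity * dt
        new_p = p + new_v * dt
        return new_p, new_v  # the new state is returned, nothing on self changes


scene = Scene(gravity=-9.8)
state = (jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0]))
for _ in range(1000):
    state = scene.step(state, 0.01)
print(state[0])
```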


phinate commented Mar 3, 2020

Just wanted to let you know that this answer allowed me to convert from an OOP to a functional mindset for the first time! Was v helpful and illustrative, and I can now @jax.jit everything! ;)

@mattjj mattjj added the useful read PR or issue that contains useful design discussion label Apr 29, 2020
Collaborator

gnecula commented May 12, 2020

I believe we can close this issue, will keep the useful_read label.

@gnecula gnecula closed this as completed May 12, 2020
Contributor

AdrienCorenflos commented Jul 15, 2020

Otherwise, why not add support for attr classes? They're basically namedtuples on steroids.

Looks to me like it would need a small contrib to the pytree file, mostly at lines:

Then you'd be able to do something like:

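import attr
import jax
import jax.numpy as np

@attr.s
class World:
    p: np.ndarray = attr.ib()
    v: np.ndarray = attr.ib()

    @jax.jit
    def step(self, dt):
        a = -9.8
        v = self.v + a * dt  # integrate acceleration into velocity
        p = self.p + v * dt  # integrate the updated velocity into position
        return attr.evolve(self, p=p, v=v)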

That would work because World instances would be recognized as acceptable arguments and outputs of jitted functions.

EDIT:
Might even help to do AOT compilation à la numba.

EDIT 2:
Additionally you could verify that the methods are pure by enforcing that the class be declared as being frozen.
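
Something close to this can be sketched today without new attrs support. The example below is an assumption-laden illustration, not the proposal itself: it swaps in a frozen standard-library dataclass for the attr class, uses dataclasses.replace in place of attr.evolve, and registers the class by hand with register_pytree_node. Freezing the class does not prove the method is pure, but it does make accidental attribute mutation fail loudly.

```python
from dataclasses import FrozenInstanceError, dataclass, replace

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


@dataclass(frozen=True)
class World:
    p: jnp.ndarray
    v: jnp.ndarray

    @jax.jit
    def step(self, dt):
        a = -9.8
        new_v = self.v + a * dt
        new_p = self.p + new_v * dt
        # Build a new instance instead of mutating self.
        return replace(self, p=new_p, v=new_v)


# Manual registration: the arrays are the children, there is no auxiliary data.
register_pytree_node(
    World,
    lambda w: ((w.p, w.v), None),
    lambda _, children: World(*children),
)

world = World(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0]))
new_world = world.step(0.01)
print(new_world.p, new_world.v)

try:
    world.p = new_world.p  # in-place mutation is rejected by the frozen dataclass
except FrozenInstanceError:
    print("World is immutable")
```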

Collaborator

jakevdp commented May 6, 2022

In case anyone comes across this, there is now a section in the FAQ addressing jit compilation of methods: https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods
