
jax.jit recompiles nested jitted functions #284

Closed
jonasrauber opened this issue Jan 28, 2019 · 5 comments
Labels: question (Questions for the JAX team)

@jonasrauber (Contributor)

This isn't really a bug and more of a question, I guess, but who knows.

Why is f being compiled twice in this example:

import jax
import jax.numpy as np
import numpy as onp

x = onp.random.rand(1024).astype(onp.float32)
x = jax.device_put(x)

def f(x):
    print('f x', x)
    return np.square(x)

def g(f, x):
    print('g f', f)
    print('g x', x)
    return f(x)

def h(f, x):
    print('h f', f)
    print('h x', x)
    return f(x)

f = jax.jit(f)
g = jax.jit(g, static_argnums=(0,))
h = jax.jit(h, static_argnums=(0,))

f(x)  # will trigger compilation of f
f(x)  # reuse cache
print('')
g(f, x)  # will trigger compilation of g and another compilation of f
g(f, x)  # reuse cache
print('')
h(f, x)  # will trigger compilation of h, but uses cached f
h(f, x)  # reuse cache
Output:

f x Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/1)>

g f <function jit.<locals>.f_jitted at 0x7faa5bb22ae8>
g x Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/1)>
f x Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/2)>

h f <function jit.<locals>.f_jitted at 0x7faa5bb22ae8>
h x Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/1)>

I guess it has something to do with the level, but I don't really get why a jitted function is recompiled once it gets called from within another jitted function.

@hawkinsp hawkinsp added the enhancement New feature or request label Feb 1, 2019
@mattjj (Collaborator) commented Feb 6, 2019

but I don't really get why a jitted function is recompiled once it gets called from within another jitted function.

Our trace-caching logic is pretty simple: it's just a @memoize decorator on the function that takes a wrapped Python callable fun and a set of abstract arguments and returns an executable XLA computation. The wrapping of fun just records what transformations have been applied to the underlying Python callable (and any auxiliary information that they need to smuggle out after the function has been called, like the tuple/list/dict tree structure of the output), and so that @memoize decorator is taking that transformation stack into account too.
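The caching scheme described above can be sketched in plain Python. This is an illustration only: the names `xla_callable` and the signature tuple stand in for JAX internals, and the real cache key also includes the transformation stack recorded on the wrapped function.

```python
# Hypothetical sketch of trace caching: a memoize decorator keyed on the
# function plus the abstract argument signature. A second call with the same
# (fun, signature) pair is a cache hit and triggers no recompilation.
from functools import lru_cache

@lru_cache(maxsize=None)
def xla_callable(fun, abstract_args):
    # In real JAX this would trace `fun` at the abstract values and compile
    # an executable XLA computation; here we just count "compilations".
    xla_callable.compiles += 1
    return lambda *args: fun(*args)

xla_callable.compiles = 0

def square(x):
    return x * x

sig = ("float32", (1024,))       # stand-in for ShapedArray(float32[1024])
xla_callable(square, sig)(3.0)   # "compiles"
xla_callable(square, sig)(4.0)   # cache hit: same fun, same signature
assert xla_callable.compiles == 1
```

A cache miss happens exactly when either the function object (including its wrapping transformations) or the abstract signature differs, which is why tracing `f` inside `g` — with a different transform context — misses the cache.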

What's complicated is the transformations we need to set up to guard against other traces' tracers hiding in function closures.

We can add a print(fun) to the top of the memoized xla_callable function to see why we're getting cache misses running your script. If we run just this part:

f(x)
g(f, x)

we see this (ignoring the stuff that prints out due to device_put):

Wrapped function:
0   : flatten_fun   ((*,),)
1   : process_env_traces   (xla_call, -1)
2   : pytree_fun_to_jaxtupletree_fun   ((*,),)
3   : argnums_partial_   ((0,), (None,))
Core: f

('f x', Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/1)>)
Wrapped function:
0   : flatten_fun   ((*,),)
1   : process_env_traces   (xla_call, -1)
2   : pytree_fun_to_jaxtupletree_fun   ((*,),)
3   : argnums_partial_   ((1,), (<jax.util.WrapHashably object at 0x7f9d54442990>, None))
Core: g

('g f', <function jit(f) at 0x7f9d54457848>)
('g x', Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/1)>)
Wrapped function:
0   : flatten_fun   ((JTupleTreeDef(child_specs=()),),)
1   : process_env_traces   (xla_call, -2)
2   : partial_eval_wrapper   ((ShapedArray(float32[1024]),),)
3   : trace_to_subjaxpr   (MasterTrace(-1,JaxprTrace),)
4   : process_env_traces   (xla_call, -1)
5   : pytree_fun_to_jaxtupletree_fun   ((*,),)
6   : argnums_partial_   ((0,), (None,))
Core: f

So the second time we're seeing f, the transform context is pretty different. It's tricky to unpack the details, but the high-level issue is that the second time we trace f we don't know if it closed over values that are traced by g. If it did, we need to generate different code (because, in effect, the computation carried out by f has more inputs the second time), even though from f's point of view those are just constants. This is basically like lambda lifting.

In this case, there's special structure here: f is a top-level function in the original Python source text, and in particular doesn't close over any values that could be traced, so we're safe from this closure issue. Maybe we could detect this special structure (by checking the Python runtime function object and noticing it has an empty closure?) and get a cache hit here.
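The empty-closure check floated here can be illustrated in plain Python (a sketch, not JAX internals): a function defined at top level has `__closure__` set to `None`, while a function that closes over anything carries cell objects that could, in principle, hide tracers.

```python
# Sketch of detecting the "safe" special structure: an empty closure means
# the function cannot be hiding another trace's tracers.
def f(x):
    return x * x           # top-level: no closure

def make_g(c):
    def g(x):
        return x * c       # closes over c, which might be a tracer
    return g

def has_empty_closure(fn):
    # CPython sets __closure__ to None for functions with no free variables.
    return fn.__closure__ is None

assert has_empty_closure(f)
assert not has_empty_closure(make_g(2.0))
```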

But in general, when a function has a non-empty closure, we can't tell whether that's a benign closure (with no hidden traces) or whether that closure contains other tracers (maybe very indirectly, buried in arbitrary Python objects, including closed-over Python function objects) until we actually run the function. And at the point where we call and memoize xla_callable, we haven't actually run the function yet, so we don't know if we're safe from nested tracers, and we need to be defensive.

I'm inclined to err on the side of simplicity and not try to detect this special closure-free structure until we have a use case that needs it. (However, it's possible that @jonasrauber has already articulated possible use cases in other issues, and I just haven't grokked them yet.)

@dougalm did I get that right? Should we consider special handling of empty-closure functions, which might mean a special bind_call that is promised there are no traces in the closure? (There's a related todo in ad.py in call_transpose.)

@jonasrauber (Contributor, Author)

Great explanation, thanks @mattjj. I am not sure how it relates to other issues I opened (in particular #282). I think for now it was just a question out of curiosity.

@mattjj mattjj self-assigned this Feb 6, 2019
@mattjj (Collaborator) commented Feb 6, 2019

I'll re-mark this as a question and close it, but if we find a use case where we would really like cache hits here, we can reopen (or open a new issue).

Thanks for asking this!

@mattjj mattjj closed this as completed Feb 6, 2019
@mattjj mattjj added question Questions for the JAX team and removed enhancement New feature or request labels Feb 6, 2019
@jonasrauber (Contributor, Author)

Following up on this, now that I've worked with it a bit more: calling jitted functions from within jitted functions does not seem to be a good idea. Not jitting the inner function explicitly (it will still be jitted when the whole outer function is jitted) seems to improve performance. Is this a general rule one can keep in mind (i.e., for maximum performance, put jit only on the outermost function), or are there cases where jitting functions inside other jitted functions is advantageous?
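The pattern being asked about looks like this (a minimal sketch, assuming JAX is installed): the inner function is left un-jitted and is traced inline when the outer function is compiled, so the whole pipeline becomes a single XLA computation.

```python
import jax
import jax.numpy as jnp

def inner(x):
    # Not jitted on its own: it is traced inline as part of `outer`.
    return jnp.square(x)

@jax.jit
def outer(x):
    # One compilation covers both inner and outer, letting XLA fuse
    # the square and the add into a single kernel.
    return inner(x) + 1.0

x = jnp.arange(4.0)
result = outer(x)  # [1., 2., 5., 10.]
```

Wrapping `inner` in its own `jax.jit` would still be functionally correct, but as the thread above explains, it can trigger an extra trace/compile step when called from inside `outer`.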

@josephrocca (Contributor)

@mattjj Curious whether what @jonasrauber said is expected behavior. Should it affect performance in the general case if some "inner" functions have already been jitted?
