jax.jit recompiles nested jitted functions #284
Comments
Our trace-caching logic is pretty simple: it's just a What's complicated is the transformations we need to set up to guard against other traces' tracers hiding in function closures. We can add a
we see this (ignoring the stuff that prints out due to
So the second time we're seeing In this case, there's special structure here:

But in general, when a function has a non-empty closure, we can't tell whether that's a benign closure (with no hidden traces) or whether that closure contains other tracers (maybe very indirectly, buried in arbitrary Python objects, including closed-over Python function objects) until we actually run the function. And at the point where we call and memoize

I'm inclined to err on the side of simplicity and not try to detect this special closure-free structure until we have a use case that needs it. (However, it's possible that @jonasrauber has already articulated possible use cases in other issues, and I just haven't grokked them yet.)

@dougalm did I get that right? Should we consider special handling of empty-closure functions, which might mean a special bind_call that is promised there are no traces in the closure? (There's a related todo in ad.py in
I'll re-mark this as a question and close it, but if we find a use case where we would really like cache hits here, we can reopen (or open a new issue). Thanks for asking this!
Following up on this, now that I have worked with it a bit more: calling jitted functions from within jitted functions does not seem to be a good idea. That is, not jitting the inner function explicitly (it will still be jitted when the whole outer function is jitted) seems to improve performance. Is this a general rule one can keep in mind (i.e. for maximum performance put
@mattjj Curious if what @jonasrauber said is expected behavior? Should it affect performance in the general case if some "inner" functions have already been
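For anyone who wants to check this on their own setup, here is a minimal sketch comparing the two styles: an outer jitted function that calls an explicitly jitted inner function, versus one that calls the plain inner function. The names `inner`, `outer_nested`, and `outer_flat` are made up for illustration and are not from this thread. No timings are claimed here, since they will likely depend on the JAX version and hardware.

```python
import timeit

import jax
import jax.numpy as jnp


def inner(x):
    # Hypothetical inner computation, chosen only for illustration.
    return jnp.tanh(x) @ x


inner_jitted = jax.jit(inner)


@jax.jit
def outer_nested(x):
    # Outer jit that calls an explicitly jitted inner function.
    return inner_jitted(x) + 1.0


@jax.jit
def outer_flat(x):
    # Outer jit that calls the plain inner function; it is still compiled
    # as part of the outer function.
    return inner(x) + 1.0


x = jnp.ones((64, 64))

# Warm up both versions so tracing and compilation are not included in the timing.
result_nested = outer_nested(x).block_until_ready()
result_flat = outer_flat(x).block_until_ready()

t_nested = timeit.timeit(lambda: outer_nested(x).block_until_ready(), number=100)
t_flat = timeit.timeit(lambda: outer_flat(x).block_until_ready(), number=100)

print(f"nested jit: {t_nested:.4f} s, outer jit only: {t_flat:.4f} s")
```

Both versions compute the same values, so any difference is in tracing, compilation, or dispatch overhead rather than in the result.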
This isn't really a bug and more of a question, I guess, but who knows.
Why is f being compiled twice in this example:
I guess it has something to do with the level, but I don't really get why a jitted function is recompiled once it gets called from within another jitted function.
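The original snippet did not survive in this copy of the thread, so here is a hypothetical reproduction of the kind of setup being asked about. It assumes a jitted `f` that is first called directly and then called from inside another jitted function `g`. The names `f_impl`, `g`, and `trace_counts` are illustrative. A Python-side counter is used because Python side effects run only while JAX traces a function, so it counts traces of `f` rather than XLA compilations. The exact counts may differ between JAX versions, so none are asserted here.

```python
import jax
import jax.numpy as jnp

# Counts how many times the Python body of f is executed, i.e. traced.
trace_counts = {"f": 0}


def f_impl(x):
    # This increment is a Python side effect: it runs at trace time only,
    # not on calls that hit the cache.
    trace_counts["f"] += 1
    return jnp.sin(x) * 2.0


f = jax.jit(f_impl)


@jax.jit
def g(x):
    # Calls the already-jitted f from inside another jitted function.
    return f(x) + 1.0


x = jnp.arange(3.0)

y_direct = f(x)                      # first direct call traces f
after_direct = trace_counts["f"]

y_nested = g(x)                      # f is now called while g is being traced
after_nested = trace_counts["f"]

print("traces of f after direct call:", after_direct)
print("traces of f after call through g:", after_nested)
```

If the second number is larger than the first, `f` was traced again when called from within `g`, which is the behavior this issue asks about.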