ahead-of-time lowering and compilation for jit #7997
apaszke
left a comment
Should this have more caching? I would expect calling jit(f).lower(*args).compile() to cause cache hits when I jit(f)(*args) later.
Yep, that's part of the plan. I might scope this PR to a cache-free version of things first, to get the refactoring and basic functionality in place.
apaszke
left a comment
AFAIK this is only supposed to be a first step and to that end LGTM. I left some comments that you might want to address, but it's fine by me to do it later.
out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
                   name=flat_fun.__name__, donated_invars=donated_invars, inline=inline)
return tree_unflatten(out_tree(), out)
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
consider returning a dict you later splat as kwargs? that's what I do with xmap and pjit to avoid messing up the order of returns/binders
I also called it _infer_params, which I think is a better description for this function than _prepare_jit, but I don't have a super strong preference
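The dict-and-splat suggestion above can be sketched in plain Python. This is only an illustration of the pattern: the helper bodies, the `_call` function, and the chosen keys are hypothetical and do not reflect the actual JAX internals or the signature of `_prepare_jit`.

```python
def _infer_params(fun, args):
    """Hypothetical helper: return everything the call site needs, keyed by name."""
    return dict(
        fun=fun,
        args_flat=tuple(args),
        donated_invars=(False,) * len(args),
    )


def _call(*, fun, args_flat, donated_invars):
    # Stand-in for the real call site. Keyword-only parameters mean the
    # values are bound by name, so the order of the dict entries is irrelevant.
    assert len(args_flat) == len(donated_invars)
    return fun(*args_flat)


params = _infer_params(lambda a, b: a + b, (1, 2))
result = _call(**params)  # splat the dict as keyword arguments
print(result)
```

Compared with returning a tuple, adding or reordering entries in the dict cannot silently shift a value into the wrong binder at the call site.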
This is an initial take on #7733 corresponding to jax.jit.
Example usage:
Note the optimization that XLA applies between lowering and compilation.
We have yet to determine exactly how to expose (functions of) the lowered computation and compiled executable, so for now the example above reaches for the underlying computation and executable directly, via a hidden method (leading underscore).
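For orientation, the lower-then-compile flow described above can be sketched as follows. This is not the snippet from the PR description: it assumes a recent JAX release in which the lowered and compiled objects expose a public `as_text()` method, instead of the underscored accessors the PR mentions. The function `f` and its inputs are illustrative.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Illustrative function, not taken from the PR.
    return 2 * x + y


x, y = jnp.float32(3.0), jnp.float32(4.0)

# Stage 1: trace and lower for these argument shapes and dtypes, without compiling.
lowered = jax.jit(f).lower(x, y)
lowered_text = lowered.as_text()  # textual form of the lowered (pre-optimization) computation

# Stage 2: hand the lowered computation to XLA, which optimizes and compiles it.
compiled = lowered.compile()
compiled_text = compiled.as_text()  # post-optimization text; may be None on some backends

# The compiled object is callable on arguments matching the lowered signature.
result = compiled(x, y)
print(result)
```

Printing `lowered_text` and `compiled_text` side by side is one way to see what XLA changed between the two stages.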