add core.closed_call_p #10711
sharadmv
left a comment
LGTM. Thanks for the detailed PR description.
def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
                   avals_out, tokens_in, *args):
  if isinstance(call_jaxpr, core.Jaxpr):
To confirm, this check exists because not all call primitives use closed jaxprs yet. When they do, we can delete this.
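As a rough illustration of the check discussed in this comment, the sketch below accepts either form and always hands back the closed form. ToyJaxpr, ToyClosedJaxpr and normalize_call_jaxpr are hypothetical stand-ins invented for this example; they are not JAX's real classes or the actual lowering code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyJaxpr:
    """Hypothetical open jaxpr: just records its equations."""
    eqns: tuple


@dataclass(frozen=True)
class ToyClosedJaxpr:
    """Hypothetical closed jaxpr: an open jaxpr paired with constant values."""
    jaxpr: ToyJaxpr
    consts: tuple


def normalize_call_jaxpr(call_jaxpr):
    """Return the closed form regardless of which form was passed in."""
    if isinstance(call_jaxpr, ToyJaxpr):
        # Assumption: a call primitive that still uses the open form
        # carries no constants, so we close over an empty tuple.
        return ToyClosedJaxpr(call_jaxpr, ())
    return call_jaxpr


open_form = ToyJaxpr(eqns=("mul", "add"))
closed_form = ToyClosedJaxpr(open_form, consts=(1.0,))
```

Once every caller passes the closed form, the isinstance branch becomes dead code, which matches the note that the check can eventually be deleted.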
def _closed_call_param_updater(params, _, __):
  jaxpr = params.get('call_jaxpr')
  if jaxpr is None: return params
When is jaxpr None here?
Yeah good question, I forget... let me see if I can exercise this.
Ah, it's because in JaxprTrace.process_call we actually call the same call_param_updater for both the bind-form and jaxpr-form parameter versions. Usually it's just used to update params like donated_invars, and it doesn't matter whether we're working with the bind-form or the jaxpr-form (e.g. for xla_call).
This required behavior was covered by the tests in core_test.py.
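A toy version of that behavior, for illustration only: the updater below is called on both parameter shapes and leaves the bind-form parameters untouched because they have no 'call_jaxpr' entry. The dictionaries and the (jaxpr, consts) tuple are assumptions made for this sketch, not the real JAX data structures.

```python
def toy_closed_call_param_updater(params, _, __):
    """Toy updater: close 'call_jaxpr' over no constants if it is present."""
    jaxpr = params.get('call_jaxpr')
    if jaxpr is None:
        # Bind-form params carry the function separately, so there is
        # nothing to wrap; return them unchanged.
        return params
    # Jaxpr-form params: model the closed jaxpr as a (jaxpr, consts) pair.
    return dict(params, call_jaxpr=(jaxpr, ()))


bind_form_params = {'name': 'f'}
jaxpr_form_params = {'name': 'f', 'call_jaxpr': 'OPEN_JAXPR'}

updated_bind = toy_closed_call_param_updater(bind_form_params, None, None)
updated_jaxpr = toy_closed_call_param_updater(jaxpr_form_params, None, None)
```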
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, thanks!
This PR adds a variant of core.call_p called core.closed_call_p. The only difference is that in its 'jaxpr form' its call_jaxpr parameter is a core.ClosedJaxpr rather than a core.Jaxpr.
Some background:
- core.call_p is the most vanilla call primitive possible: unlike, say, xla_call_p (the primitive underlying jax.jit), its impl rule isn't to compile an XLA computation, but instead it's just to interpret its jaxpr (staying in Python, using core.eval_jaxpr). Correspondingly, it doesn't need to raise the abstraction level of its arguments. It's basically a model for other "final-style" call primitives, each of which is interesting in precisely how it deviates from core.call_p (e.g. xla_call_p's impl rule stages out for compilation; remat_call_p has a special partial evaluation rule; custom_jvp_call_p has a special JVP rule; etc). Historically it was the first call primitive we introduced, just to test the system; core.call_p is not really used anywhere.
- core.ClosedJaxpr is a data type which would be better named as PartiallyAppliedJaxpr. When we form jaxprs, they usually get paired with "constants" (e.g. trace_to_jaxpr_nounits and trace_to_jaxpr_dynamic output a list of constants), which are values that are not arguments and that we don't want to turn into literals (e.g. because we want to de-duplicate them, or even just avoid inlining them in pretty-prints). In some cases, these "constants" can be core.Tracers, like when we form the jaxprs for jax.lax.scan and the body function closes over some Tracer; when that's possible, because Tracers have to be handled with core.Primitive.bind, we typically just convert them to arguments (via pe.convert_constvars_jaxpr). But in other cases the constants that come out can't be Tracers (e.g. in the JVP rule of an initial-style primitive, when we run ad.jvp_jaxpr, we can get new constants out which can't be Tracers and must be raw array values). That's when core.ClosedJaxpr comes in handy: it lets us pair a jaxpr with some array constants so that the caller, e.g. a JVP rule for an initial-style higher-order primitive, doesn't need to deal with handling new constant values and their input binders.
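To make the "partially applied jaxpr" reading concrete, here is a minimal pure-Python sketch. OpenJaxpr, Closed, eval_closed and convert_constvars are hypothetical names invented for this example: they model the idea with plain callables and do not use JAX's actual types or pe.convert_constvars_jaxpr.

```python
from dataclasses import dataclass
from typing import Callable, Tuple


@dataclass(frozen=True)
class OpenJaxpr:
    """Toy open jaxpr: body is called as body(*consts, *args)."""
    n_consts: int
    body: Callable[..., float]


@dataclass(frozen=True)
class Closed:
    """Toy closed jaxpr: an open jaxpr already paired with its constants."""
    jaxpr: OpenJaxpr
    consts: Tuple[float, ...]


def eval_closed(closed, *args):
    # The caller supplies only the real arguments; the constants ride along.
    return closed.jaxpr.body(*closed.consts, *args)


def convert_constvars(jaxpr):
    """Alternative strategy: turn constant binders into leading arguments."""
    return OpenJaxpr(0, jaxpr.body)


scale_and_shift = OpenJaxpr(2, lambda a, b, x: a * x + b)
closed = Closed(scale_and_shift, (2.0, 1.0))
result_closed = eval_closed(closed, 3.0)

converted = convert_constvars(scale_and_shift)
result_converted = converted.body(2.0, 1.0, 3.0)  # caller must now pass consts
```

Both strategies compute the same value; the difference is who has to carry the constants. With the closed form the caller never sees them, which is the convenience the description is pointing at.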
In other words, primitives which are parameterized by ClosedJaxprs can have simpler rules, especially jaxpr-to-jaxpr rules, since those rules don't need to worry about handling new constants/binders introduced by the rule.
On that last point, when working on #10576 we ran into a situation where
- the current signature for "custom-policy partial eval rules" didn't allow a custom partial evaluation rule to introduce new constants (because such rules just get to output a pair of Optional[JaxprEqn]s and have no output for "new constants for the caller to handle appropriately");
- but to perform an optimization, namely hoisting loop-invariant residual computations out of a scan body, we might need such a rule to introduce multiple equations as well as new constants.
To proceed, there were at least two options:
- make the signature for custom-policy partial evaluation rules even more complex (to support outputting multiple equations, new variable names being introduced, new constants, etc)
- just use a call primitive to handle the "multiple equations with new variables" problem, and as long as it was a call primitive with a ClosedJaxpr it would handle the constants problem too.
I chose the second approach, which led to this PR.
For simplicity, we could delete core.call_p in favor of this core.closed_call_p; after all, the former is not used at all. Going further, we might want to make all higher-order primitives (i.e. even the final-style ones, not just the initial-style ones as at present) take ClosedJaxprs rather than Jaxprs; further still, at that point we could de-duplicate Jaxpr and ClosedJaxpr so that we only have one such type. Those simplifications sound reasonable, but they're out of scope for this PR. Here I just want to land a change for enabling the new remat implementation with scan inside!
Finally, some notes on the changes here. Final-style primitives (like the new closed_call_p) have two forms, with different parameters: the 'bind form' used during tracing, which takes a Python callable as a parameter representing the function to be called (really a linear_util.WrappedFun), and the 'jaxpr form' which appears in a jaxpr and which itself takes a Jaxpr (or, after this PR, alternatively a ClosedJaxpr). Since we're introducing a primitive which is like core.call_p except that it takes a ClosedJaxpr parameter, we need to
- update places where the bind-form primitive is converted to the jaxpr-form primitive (i.e. JaxprTrace.process_call and DynamicJaxprTrace.process_call in partial_eval.py, both of which can be handled by using the existing "call param updater" hook) to actually produce a ClosedJaxpr parameter;
- update places where the jaxpr-form is converted to the bind-form (namely ClosedCallPrimitive.get_bind_params in core.py);
- update rules which consume the jaxpr-form to handle the ClosedJaxpr parameter (namely the MLIR lowering rule in mlir.py, the transpose rule in ad.py, the typecheck rule in core.py, the DCE rule in partial_eval.py, and (once it exists for any calls) the forwarding rule in partial_eval.py); note that we do not need to update rules which consume the bind form (e.g. JVPTrace.process_call or BatchTrace.process_call) since the bind forms of call_p and closed_call_p are identical;
- update core_test.py to cover the new call primitive.
Only the second-to-last bullet seems burdensome. That would be mitigated by moving to make all call primitives take ClosedJaxpr parameters, which I think was already a good idea. But again that's out of scope!