-
Notifications
You must be signed in to change notification settings - Fork 3.3k
add core.closed_call_p #10711
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add core.closed_call_p #10711
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -222,15 +222,17 @@ def process_call(self, primitive, f, tracers, params): | |
| unknown_arg_tracers = [t for t in tracers if not t.is_known()] | ||
| # Adjust parameters (e.g. donated_invars) for the staged-out call's args. | ||
| num_new_args = len(const_tracers) + len(env_tracers) | ||
| staged_params = update_params(params, map(op.not_, in_knowns), num_new_args) | ||
| staged_params = dict(staged_params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) | ||
| staged_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) | ||
| staged_params = update_params(staged_params, map(op.not_, in_knowns), | ||
| num_new_args) | ||
| # The outputs of the staged-out call are Tracers with the new eqn as recipe. | ||
| out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) | ||
| for a in out_avals] | ||
| name_stack = self._current_truncated_name_stack() | ||
| source = source_info_util.current().replace(name_stack=name_stack) | ||
| eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), | ||
| out_tracers, primitive, staged_params, jaxpr.effects, source) | ||
| out_tracers, primitive, staged_params, jaxpr.effects, | ||
| source) | ||
| for t in out_tracers: t.recipe = eqn | ||
| return merge_lists(out_knowns, out_tracers, out_consts) | ||
|
|
||
|
|
@@ -511,6 +513,12 @@ def partial_eval_wrapper_nounits( | |
| call_partial_eval_rules: Dict[Primitive, Callable] = {} | ||
| call_param_updaters: Dict[Primitive, Callable] = {} | ||
|
|
||
| def _closed_call_param_updater(params, _, __): | ||
| jaxpr = params.get('call_jaxpr') | ||
| if jaxpr is None: return params | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When is jaxpr None here?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah good question, I forget... let me see if I can exercise this.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, it's because in This required behavior was covered by the tests in core_test.py.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense, thanks! |
||
| assert type(jaxpr) is core.Jaxpr | ||
| return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ())) | ||
| call_param_updaters[core.closed_call_p] = _closed_call_param_updater | ||
|
|
||
| def abstract_eval_fun(fun, *avals, debug_info=None, **params): | ||
| _, avals_out, _ = trace_to_jaxpr_dynamic( | ||
|
|
@@ -666,8 +674,6 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], | |
| # TODO(necula): move these checks to core.check_jaxpr, and call in more places | ||
| if primitive.call_primitive or primitive.map_primitive: | ||
| assert "call_jaxpr" in params | ||
| # assert len(invars) == len(params["call_jaxpr"].invars) # TODO constvars? | ||
| assert len(out_tracers) == len(params["call_jaxpr"].outvars) | ||
| assert ("donated_invars" not in params or | ||
| len(params["donated_invars"]) == len(params["call_jaxpr"].invars)) | ||
| if primitive.map_primitive: | ||
|
|
@@ -1254,6 +1260,20 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn | |
| dce_rules[remat_call_p] = dce_jaxpr_call_rule | ||
|
|
||
|
|
||
| def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn | ||
| ) -> Tuple[List[bool], JaxprEqn]: | ||
| # TODO(mattjj): de-duplicate with above rule? | ||
| jaxpr_ = eqn.params['call_jaxpr'] | ||
| jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts | ||
| new_jaxpr, used_inputs = dce_jaxpr(jaxpr, used_outputs) | ||
| new_params = dict(eqn.params, call_jaxpr=core.ClosedJaxpr(new_jaxpr, consts)) | ||
| new_eqn = new_jaxpr_eqn( | ||
| [v for v, used in zip(eqn.invars, used_inputs) if used], | ||
| [v for v, used in zip(eqn.outvars, used_outputs) if used], | ||
| eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info) | ||
| return used_inputs, new_eqn | ||
| dce_rules[core.closed_call_p] = dce_jaxpr_closed_call_rule | ||
|
|
||
| def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool] | ||
| ) -> ClosedJaxpr: | ||
| """Reorder `invars` by moving those indicated in `to_move` to the front.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To confirm, this check exists because not all call primitives use closed jaxprs yet. When they do, we can delete this.