Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections import namedtuple
from contextlib import contextmanager
import functools
from functools import partialmethod, total_ordering
from functools import partial, partialmethod, total_ordering
import gc
import itertools as it
import operator
Expand Down Expand Up @@ -207,8 +207,6 @@ def replace(self, *args, **kwargs):
return self._replace(*args, **kwargs)

def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None):
if primitive.call_primitive:
assert len(outvars) == len(params["call_jaxpr"].outvars)
source_info = source_info or source_info_util.new_source_info()
return JaxprEqn(invars, outvars, primitive, params, effects, source_info)

Expand Down Expand Up @@ -1822,6 +1820,17 @@ def call_impl(f: lu.WrappedFun, *args, **params):
named_call_p.def_impl(call_impl)


class ClosedCallPrimitive(CallPrimitive):
def get_bind_params(self, params):
new_params = dict(params)
jaxpr = new_params.pop('call_jaxpr')
subfun = lu.wrap_init(partial(eval_jaxpr, jaxpr.jaxpr, jaxpr.consts))
return [subfun], new_params

closed_call_p: ClosedCallPrimitive = ClosedCallPrimitive('closed_call')
closed_call_p.def_impl(call_impl)


outfeed_primitives: Set[Primitive] = set()
def jaxpr_uses_outfeed(jaxpr: Jaxpr) -> bool:
"""Finds if there are outfeed primitives anywhere inside a Jaxpr."""
Expand Down Expand Up @@ -2169,6 +2178,12 @@ class JaxprTypeError(TypeError): pass

custom_typechecks: Dict[Primitive, Callable] = {}

def _check_closed_call(*in_avals, call_jaxpr):
if list(in_avals) != list(call_jaxpr.in_avals):
raise JaxprTypeError("Closed call in_avals mismatch")
return call_jaxpr.out_avals, call_jaxpr.effects
custom_typechecks[closed_call_p] = _check_closed_call

def check_jaxpr(jaxpr: Jaxpr):
"""Checks well-formedness of a jaxpr.

Expand Down
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"reduce_precision",
"schur",
"name",
"closed_call",
"unreachable",
"bint",
"getslice",
Expand Down
10 changes: 10 additions & 0 deletions jax/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,16 @@ def call_transpose(primitive, params, call_jaxpr, args, ct, _, reduce_axes):
out_flat = primitive.bind(fun, *all_args, **params)
return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
primitive_transposes[core.named_call_p] = \
partial(call_transpose, core.named_call_p)


def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals, reduce_axes):
jaxpr_, consts = jaxpr.jaxpr, jaxpr.consts
jaxpr_ = pe.convert_constvars_jaxpr(jaxpr_)
return call_transpose(core.closed_call_p, params, jaxpr_, (*consts, *args),
ct, cts_in_avals, reduce_axes)
primitive_transposes[core.closed_call_p] = _closed_call_transpose


def remat_transpose(params, call_jaxpr, primals_in, cotangents_in,
Expand Down
8 changes: 5 additions & 3 deletions jax/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,13 +986,13 @@ def f_lowered(ctx, *args, **params):

def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
avals_out, tokens_in, *args):
if isinstance(call_jaxpr, core.Jaxpr):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To confirm, this check exists because not all call primitives use closed jaxprs yet. When they do, we can delete this.

call_jaxpr = core.ClosedJaxpr(call_jaxpr, ())
xla.check_backend_matches(backend, ctx.platform)
output_types = map(aval_to_ir_types, avals_out)
flat_output_types = util.flatten(output_types)
effects = tokens_in.effects()
symbol_name = lower_jaxpr_to_fun(ctx, fn_name,
core.ClosedJaxpr(call_jaxpr, ()),
effects).name.value
symbol_name = lower_jaxpr_to_fun(ctx, fn_name, call_jaxpr, effects).name.value
args = [*tokens_in.tokens(), *args]
call = func_dialect.CallOp(flat_output_types,
ir.FlatSymbolRefAttr.get(symbol_name),
Expand Down Expand Up @@ -1024,6 +1024,8 @@ def _named_call_lowering(ctx, *args, name, backend=None,

register_lowering(core.named_call_p, _named_call_lowering)
register_lowering(core.call_p, partial(_named_call_lowering, name="core_call"))
register_lowering(core.closed_call_p,
partial(_named_call_lowering, name="core_closed_call"))


def full_like_aval(value, aval: core.ShapedArray) -> ir.Value:
Expand Down
30 changes: 25 additions & 5 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,15 +222,17 @@ def process_call(self, primitive, f, tracers, params):
unknown_arg_tracers = [t for t in tracers if not t.is_known()]
# Adjust parameters (e.g. donated_invars) for the staged-out call's args.
num_new_args = len(const_tracers) + len(env_tracers)
staged_params = update_params(params, map(op.not_, in_knowns), num_new_args)
staged_params = dict(staged_params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
staged_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
staged_params = update_params(staged_params, map(op.not_, in_knowns),
num_new_args)
# The outputs of the staged-out call are Tracers with the new eqn as recipe.
out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
for a in out_avals]
name_stack = self._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
out_tracers, primitive, staged_params, jaxpr.effects, source)
out_tracers, primitive, staged_params, jaxpr.effects,
source)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)

Expand Down Expand Up @@ -511,6 +513,12 @@ def partial_eval_wrapper_nounits(
call_partial_eval_rules: Dict[Primitive, Callable] = {}
call_param_updaters: Dict[Primitive, Callable] = {}

def _closed_call_param_updater(params, _, __):
jaxpr = params.get('call_jaxpr')
if jaxpr is None: return params
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When is jaxpr None here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good question, I forget... let me see if I can exercise this.

Copy link
Collaborator Author

@mattjj mattjj May 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, it's because in JaxprTrace.process_call we actually call the same call_param_updater for both the bind-form and jaxpr-form parameter versions. Usually it's just used to update params like donated_invars, and it doesn't matter whether we're working with the bind-form or the jaxpr-form (e.g. for xla_call).

This required behavior was covered by the tests in core_test.py.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, thanks!

assert type(jaxpr) is core.Jaxpr
return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ()))
call_param_updaters[core.closed_call_p] = _closed_call_param_updater

def abstract_eval_fun(fun, *avals, debug_info=None, **params):
_, avals_out, _ = trace_to_jaxpr_dynamic(
Expand Down Expand Up @@ -666,8 +674,6 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
# TODO(necula): move these checks to core.check_jaxpr, and call in more places
if primitive.call_primitive or primitive.map_primitive:
assert "call_jaxpr" in params
# assert len(invars) == len(params["call_jaxpr"].invars) # TODO constvars?
assert len(out_tracers) == len(params["call_jaxpr"].outvars)
assert ("donated_invars" not in params or
len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
if primitive.map_primitive:
Expand Down Expand Up @@ -1254,6 +1260,20 @@ def dce_jaxpr_call_rule(used_outputs: List[bool], eqn: JaxprEqn
dce_rules[remat_call_p] = dce_jaxpr_call_rule


def dce_jaxpr_closed_call_rule(used_outputs: List[bool], eqn: JaxprEqn
) -> Tuple[List[bool], JaxprEqn]:
# TODO(mattjj): de-duplicate with above rule?
jaxpr_ = eqn.params['call_jaxpr']
jaxpr, consts = jaxpr_.jaxpr, jaxpr_.consts
new_jaxpr, used_inputs = dce_jaxpr(jaxpr, used_outputs)
new_params = dict(eqn.params, call_jaxpr=core.ClosedJaxpr(new_jaxpr, consts))
new_eqn = new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
return used_inputs, new_eqn
dce_rules[core.closed_call_p] = dce_jaxpr_closed_call_rule

def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]
) -> ClosedJaxpr:
"""Reorder `invars` by moving those indicated in `to_move` to the front."""
Expand Down
4 changes: 0 additions & 4 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,3 @@ def f(*args, **kw):
"Add an MLIR (MHLO) lowering via jax.interpreters.mlir "
"instead.")
return f


ad.primitive_transposes[core.named_call_p] = partial(ad.call_transpose,
core.named_call_p)
10 changes: 10 additions & 0 deletions tests/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ def core_call(f, *args):
out = core.call_p.bind(f, *args)
return tree_unflatten(out_tree(), out)

@util.curry
def core_closed_call(f, *args):
args, in_tree = tree_flatten(args)
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
out = core.closed_call_p.bind(f, *args)
return tree_unflatten(out_tree(), out)

def simple_fun(x, y):
return jnp.sin(x * y)

Expand Down Expand Up @@ -147,6 +154,9 @@ def jvp_unlinearized(f, primals, tangents):
test_specs.append(CallSpec(core_call(ts.fun), ts.args))
test_specs.append(CallSpec(core_call(jit(ts.fun)), ts.args))
test_specs.append(CallSpec(core_call(core_call(ts.fun)), ts.args))
test_specs.append(CallSpec(core_closed_call(ts.fun), ts.args))
test_specs.append(CallSpec(core_closed_call(jit(ts.fun)), ts.args))
test_specs.append(CallSpec(core_closed_call(core_closed_call(ts.fun)), ts.args))
test_specs.append(CallSpec(partial(jvp_unlinearized, ts.fun),
(ts.args, ts.args)))

Expand Down