diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py
index 6abb1beac2bc..cbed4eac0a70 100644
--- a/jax/_src/custom_transpose.py
+++ b/jax/_src/custom_transpose.py
@@ -19,6 +19,7 @@ from jax import linear_util as lu
 from jax.interpreters import ad
 from jax.interpreters import mlir
+from jax.interpreters import partial_eval as pe
 from jax.interpreters import xla
 from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
                            tree_structure, treedef_tuple, tree_unflatten)
@@ -133,6 +134,16 @@ def check_transpose_rule_trees(rule, lin_tree, rule_out_tree):
         f'Transpose rule output: {rule_out_tree}\n'
         f'Linear primal inputs: {lin_tree}')
 
+def make_transpose_from_thunk(thunk, lin_tree):
+  transpose_jaxpr, transpose_consts = thunk()
+  transpose_jaxpr = core.ClosedJaxpr(
+      pe.convert_constvars_jaxpr(transpose_jaxpr), ())
+  def transpose(res_arg, ct_out):
+    args_flat = tree_leaves((res_arg, ct_out))
+    ct_ins = core.jaxpr_as_fun(transpose_jaxpr)(*transpose_consts, *args_flat)
+    return tree_unflatten(lin_tree, ct_ins)
+  return transpose
+
 
 ### custom_transpose primitive and rules
 
@@ -157,8 +168,14 @@ def bind(self, call, *args, **params):
   # TODO(frostig,mattjj): consider keeping `call` as a named parameter
   # instead of following this "call primitive" convention.
   def get_bind_params(self, params):
+    assert 'call_jaxpr' in params
+    assert 'transpose_jaxpr_thunk' in params
     new_params = dict(params)
-    return [new_params.pop('call')], new_params
+    new_params['transpose'] = make_transpose_from_thunk(
+        new_params.pop('transpose_jaxpr_thunk'),
+        new_params['lin_tree'])
+    call = lu.wrap_init(core.jaxpr_as_fun(new_params.pop('call_jaxpr')))
+    return [call], new_params
 
 
 # TODO(frostig,mattjj): reinstate checks
@@ -167,7 +184,16 @@ def custom_transpose_typecheck(*avals, **params):
 
 
 def custom_transpose_transpose_rule(
-    cts, *args, call, transpose, out_types, res_tree, lin_tree, out_tree):
+    cts, *args, out_types, res_tree, lin_tree, out_tree, **params):
+
+  if 'transpose_jaxpr_thunk' in params:
+    assert 'call_jaxpr' in params
+    transpose = make_transpose_from_thunk(
+        params['transpose_jaxpr_thunk'], lin_tree)
+  else:
+    assert 'call' in params
+    transpose = params['transpose']
+
   call_in_tree = treedef_tuple((res_tree, lin_tree))
 
   # TODO(frostig,mattjj): `lin_arg` indicates the inputs with respect
diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py
index a4f2d099cc8e..018a5d4a5e07 100644
--- a/jax/_src/lax/control_flow.py
+++ b/jax/_src/lax/control_flow.py
@@ -805,7 +805,7 @@ def switch(index, branches, operand):
 
 
 def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
-          operand=_no_operand_sentinel):
+          operand=_no_operand_sentinel, linear=None):
   """Conditionally apply ``true_fun`` or ``false_fun``.
 
   ``cond()`` has equivalent semantics to this Python implementation::
@@ -865,6 +865,12 @@ def cond(pred, true_fun, false_fun, *operands):
       return false_fun(*operands)
 
   ops, ops_tree = tree_flatten(operands)
+  if linear is None:
+    linear_ops = [False] * len(ops)
+  else:
+    linear_ops, ops_tree2 = tree_flatten(linear)
+    if ops_tree != ops_tree2:
+      raise TypeError('linear tree and operand tree mismatch')
   ops_avals = tuple(_map(_abstractify, ops))
 
   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
@@ -878,10 +884,10 @@ def cond(pred, true_fun, false_fun, *operands):
 
   index = lax.convert_element_type(pred, np.int32)
 
-  linear = (False,) * (len(consts) + len(ops))
+  linear = [False] * len(consts) + linear_ops
   out = cond_p.bind(
       index, *consts, *ops,
-      branches=(false_jaxpr, true_jaxpr), linear=linear)
+      branches=(false_jaxpr, true_jaxpr), linear=tuple(linear))
   return tree_unflatten(out_tree, out)
 
 @api_boundary
diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py
index 6af9fd2abbe9..3caa733931df 100644
--- a/jax/interpreters/ad.py
+++ b/jax/interpreters/ad.py
@@ -164,8 +164,10 @@ def recast_to_float0(primal, tangent):
   else:
     return tangent
 
-# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
-def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack, consts, primals_in, cotangents_in):
+# NOTE: The FIXMEs below are caused by primal/tangent mixups (type
+# errors if you will)
+def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
+                  consts, primals_in, cotangents_in):
   if all(type(ct) is Zero for ct in cotangents_in):
     return map(lambda v: Zero(v.aval), jaxpr.invars)
 
diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py
index 3112c4d91062..a7d71846d4ca 100644
--- a/jax/interpreters/partial_eval.py
+++ b/jax/interpreters/partial_eval.py
@@ -1626,10 +1626,12 @@ def process_custom_transpose(self, prim, call, tracers,
 
     transpose_flat, in_tree2 = flatten_fun_nokwargs(
         lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree)))
-    transpose_jaxpr, in_avals2, transpose_consts = trace_to_subjaxpr_dynamic(
-        transpose_flat, self.main, in_avals_t)
-    closed_transpose_jaxpr = core.ClosedJaxpr(
-        convert_constvars_jaxpr(transpose_jaxpr), ())
+
+    main_ = ref(self.main)
+    # the following thunk evaluates to a pair: transpose_jaxpr, transpose_consts
+    transpose_jaxpr_thunk = _memoize(
+        lambda: trace_to_subjaxpr_dynamic(
+            transpose_flat, main_(), in_avals_t)[::2])
 
     out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
     invars = map(self.getvar, tracers)
@@ -1637,9 +1639,9 @@ def process_custom_transpose(self, prim, call, tracers,
     outvars = map(self.makevar, out_tracers)
     eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
                         dict(call_jaxpr=closed_call_jaxpr,
-                             transpose_jaxpr=(closed_transpose_jaxpr,
-                                              transpose_consts),
-                             num_consts=len(call_consts)),
+                             transpose_jaxpr_thunk=transpose_jaxpr_thunk,
+                             out_types=out_types, res_tree=res_tree,
+                             lin_tree=lin_tree, out_tree=out_tree),
                         source_info_util.current())
     self.frame.eqns.append(eqn)
     return out_tracers
diff --git a/tests/api_test.py b/tests/api_test.py
index 9999e3be9545..7c8006d91025 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -6891,7 +6891,6 @@ def tp(r, t): return 2 * t / r
     self.assertAllClose(f_t(x), g_t(x))
 
   def test_jit_recursive(self):
-    raise unittest.SkipTest('unimplemented')  # TODO(frostig,mattjj)
     def f(x, y):
       @custom_transpose(jnp.ones(2))
       def fn(r, x): return x / r
@@ -6913,6 +6912,56 @@ def tp(r, t): return 2 * fn(r, t)
     self.assertAllClose(f_(x), g_(x))
     self.assertAllClose(f_t(x), g_t(x))
 
+  def test_cond(self):
+    def f(x, y):
+      @custom_transpose(jnp.ones(2))
+      def fn(r, x): return x / r
+      @fn.def_transpose
+      def tp(r, t): return 2 * t / r
+
+      return x + fn(y, x)
+
+    def cond_wrap(f):
+      return lambda i, x: lax.cond(i > 0, f, lambda x: x, x,
+                                   linear=(True,))
+
+    i = 7.
+    x = jnp.ones(2) * 6.
+    y = jnp.ones(2) * 3.
+
+    f_ = lambda x: f(x, y)
+    f_t = transpose_unary(f_, x)
+    g_ = partial(cond_wrap(f_), i)
+    g_t = transpose_unary(g_, x)
+
+    self.assertAllClose(f_(x), g_(x))
+    self.assertAllClose(f_t(x), g_t(x))
+
+  def test_cond_recursive(self):
+    def f(x, y):
+      @custom_transpose(jnp.ones(2))
+      def fn(r, x): return x / r
+      @fn.def_transpose
+      def tp(r, t): return 2 * fn(r, t)
+
+      return x + fn(y, x)
+
+    def cond_wrap(f):
+      return lambda i, x: lax.cond(i > 0, f, lambda x: x, x,
+                                   linear=(True,))
+
+    i = 7.
+    x = jnp.ones(2) * 6.
+    y = jnp.ones(2) * 3.
+
+    f_ = lambda x: f(x, y)
+    f_t = transpose_unary(f_, x)
+    g_ = partial(cond_wrap(f_), i)
+    g_t = transpose_unary(g_, x)
+
+    self.assertAllClose(f_(x), g_(x))
+    self.assertAllClose(f_t(x), g_t(x))
+
 
 class CustomVmapTest(jtu.JaxTestCase):
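Usage note (not part of the patch): the new tests exercise `lax.cond`'s `linear`
keyword together with `jax.linear_transpose`, additionally placing a
`custom_transpose`-decorated function inside the cond. Below is a minimal sketch
of that pattern using only public APIs; the names `f`, `g`, and `ct` are ours,
purely illustrative, and `linear=(True,)` marks the single operand as linear
exactly as in the tests above.

    import jax
    import jax.numpy as jnp
    from jax import lax

    # A function that is linear in its operand.
    def f(x):
      return x / 3.

    # The same computation behind a cond; `linear=(True,)` is the keyword this
    # patch threads from the cond wrapper down to cond_p.bind.
    def g(x):
      return lax.cond(True, f, lambda x: x, x, linear=(True,))

    x = jnp.ones(2) * 6.
    ct = jnp.ones(2)  # an output cotangent to pull back through the transpose

    f_t = jax.linear_transpose(f, x)
    g_t = jax.linear_transpose(g, x)

    print(f_t(ct))  # the two results should agree
    print(g_t(ct))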