Skip to content

Commit

Permalink
allow for recursive uses of custom_transpose
Browse files Browse the repository at this point in the history
Co-authored-by: Matthew Johnson <mattjj@google.com>
  • Loading branch information
froystig and mattjj committed Mar 26, 2022
1 parent f7df3ee commit a6a43e2
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 15 deletions.
30 changes: 28 additions & 2 deletions jax/_src/custom_transpose.py
Expand Up @@ -19,6 +19,7 @@
from jax import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
tree_structure, treedef_tuple, tree_unflatten)
Expand Down Expand Up @@ -133,6 +134,16 @@ def check_transpose_rule_trees(rule, lin_tree, rule_out_tree):
f'Transpose rule output: {rule_out_tree}\n'
f'Linear primal inputs: {lin_tree}')

def make_transpose_from_thunk(thunk, lin_tree):
transpose_jaxpr, transpose_consts = thunk()
transpose_jaxpr = core.ClosedJaxpr(
pe.convert_constvars_jaxpr(transpose_jaxpr), ())
def transpose(res_arg, ct_out):
args_flat = tree_leaves((res_arg, ct_out))
ct_ins = core.jaxpr_as_fun(transpose_jaxpr)(*transpose_consts, *args_flat)
return tree_unflatten(lin_tree, ct_ins)
return transpose


### custom_transpose primitive and rules

Expand All @@ -157,8 +168,14 @@ def bind(self, call, *args, **params):
# TODO(frostig,mattjj): consider keeping `call` as a named parameter
# instead of following this "call primitive" convention.
def get_bind_params(self, params):
assert 'call_jaxpr' in params
assert 'transpose_jaxpr_thunk' in params
new_params = dict(params)
return [new_params.pop('call')], new_params
new_params['transpose'] = make_transpose_from_thunk(
new_params.pop('transpose_jaxpr_thunk'),
new_params['lin_tree'])
call = lu.wrap_init(core.jaxpr_as_fun(new_params.pop('call_jaxpr')))
return [call], new_params


# TODO(frostig,mattjj): reinstate checks
Expand All @@ -167,7 +184,16 @@ def custom_transpose_typecheck(*avals, **params):


def custom_transpose_transpose_rule(
cts, *args, call, transpose, out_types, res_tree, lin_tree, out_tree):
cts, *args, out_types, res_tree, lin_tree, out_tree, **params):

if 'transpose_jaxpr_thunk' in params:
assert 'call_jaxpr' in params
transpose = make_transpose_from_thunk(
params['transpose_jaxpr_thunk'], lin_tree)
else:
assert 'call' in params
transpose = params['transpose']

call_in_tree = treedef_tuple((res_tree, lin_tree))

# TODO(frostig,mattjj): `lin_arg` indicates the inputs with respect
Expand Down
12 changes: 9 additions & 3 deletions jax/_src/lax/control_flow.py
Expand Up @@ -805,7 +805,7 @@ def switch(index, branches, operand):


def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
operand=_no_operand_sentinel):
operand=_no_operand_sentinel, linear=None):
"""Conditionally apply ``true_fun`` or ``false_fun``.
``cond()`` has equivalent semantics to this Python implementation::
Expand Down Expand Up @@ -865,6 +865,12 @@ def cond(pred, true_fun, false_fun, *operands):
return false_fun(*operands)

ops, ops_tree = tree_flatten(operands)
if linear is None:
linear_ops = [False] * len(ops)
else:
linear_ops, ops_tree2 = tree_flatten(linear)
if ops_tree != ops_tree2:
raise TypeError('linear tree and operand tree mismatch')
ops_avals = tuple(_map(_abstractify, ops))

jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
Expand All @@ -878,10 +884,10 @@ def cond(pred, true_fun, false_fun, *operands):

index = lax.convert_element_type(pred, np.int32)

linear = (False,) * (len(consts) + len(ops))
linear = [False] * len(consts) + linear_ops
out = cond_p.bind(
index, *consts, *ops,
branches=(false_jaxpr, true_jaxpr), linear=linear)
branches=(false_jaxpr, true_jaxpr), linear=tuple(linear))
return tree_unflatten(out_tree, out)

@api_boundary
Expand Down
6 changes: 4 additions & 2 deletions jax/interpreters/ad.py
Expand Up @@ -164,8 +164,10 @@ def recast_to_float0(primal, tangent):
else:
return tangent

# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack, consts, primals_in, cotangents_in):
# NOTE: The FIXMEs below are caused by primal/tangent mixups (type
# errors if you will)
def backward_pass(jaxpr: core.Jaxpr, reduce_axes, transform_stack,
consts, primals_in, cotangents_in):
if all(type(ct) is Zero for ct in cotangents_in):
return map(lambda v: Zero(v.aval), jaxpr.invars)

Expand Down
16 changes: 9 additions & 7 deletions jax/interpreters/partial_eval.py
Expand Up @@ -1626,20 +1626,22 @@ def process_custom_transpose(self, prim, call, tracers,

transpose_flat, in_tree2 = flatten_fun_nokwargs(
lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree)))
transpose_jaxpr, in_avals2, transpose_consts = trace_to_subjaxpr_dynamic(
transpose_flat, self.main, in_avals_t)
closed_transpose_jaxpr = core.ClosedJaxpr(
convert_constvars_jaxpr(transpose_jaxpr), ())

main_ = ref(self.main)
# the following thunk evaluates to a pair: transpose_jaxpr, transpose_consts
transpose_jaxpr_thunk = _memoize(
lambda: trace_to_subjaxpr_dynamic(
transpose_flat, main_(), in_avals_t)[::2])

out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, call_consts))
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
dict(call_jaxpr=closed_call_jaxpr,
transpose_jaxpr=(closed_transpose_jaxpr,
transpose_consts),
num_consts=len(call_consts)),
transpose_jaxpr_thunk=transpose_jaxpr_thunk,
out_types=out_types, res_tree=res_tree,
lin_tree=lin_tree, out_tree=out_tree),
source_info_util.current())
self.frame.eqns.append(eqn)
return out_tracers
Expand Down
51 changes: 50 additions & 1 deletion tests/api_test.py
Expand Up @@ -6891,7 +6891,6 @@ def tp(r, t): return 2 * t / r
self.assertAllClose(f_t(x), g_t(x))

def test_jit_recursive(self):
raise unittest.SkipTest('unimplemented') # TODO(frostig,mattjj)
def f(x, y):
@custom_transpose(jnp.ones(2))
def fn(r, x): return x / r
Expand All @@ -6913,6 +6912,56 @@ def tp(r, t): return 2 * fn(r, t)
self.assertAllClose(f_(x), g_(x))
self.assertAllClose(f_t(x), g_t(x))

def test_cond(self):
def f(x, y):
@custom_transpose(jnp.ones(2))
def fn(r, x): return x / r
@fn.def_transpose
def tp(r, t): return 2 * t / r

return x + fn(y, x)

def cond_wrap(f):
return lambda i, x: lax.cond(i > 0, f, lambda x: x, x,
linear=(True,))

i = 7.
x = jnp.ones(2) * 6.
y = jnp.ones(2) * 3.

f_ = lambda x: f(x, y)
f_t = transpose_unary(f_, x)
g_ = partial(cond_wrap(f_), i)
g_t = transpose_unary(g_, x)

self.assertAllClose(f_(x), g_(x))
self.assertAllClose(f_t(x), g_t(x))

def test_cond_recursive(self):
def f(x, y):
@custom_transpose(jnp.ones(2))
def fn(r, x): return x / r
@fn.def_transpose
def tp(r, t): return 2 * fn(r, t)

return x + fn(y, x)

def cond_wrap(f):
return lambda i, x: lax.cond(i > 0, f, lambda x: x, x,
linear=(True,))

i = 7.
x = jnp.ones(2) * 6.
y = jnp.ones(2) * 3.

f_ = lambda x: f(x, y)
f_t = transpose_unary(f_, x)
g_ = partial(cond_wrap(f_), i)
g_t = transpose_unary(g_, x)

self.assertAllClose(f_(x), g_(x))
self.assertAllClose(f_t(x), g_t(x))


class CustomVmapTest(jtu.JaxTestCase):

Expand Down

0 comments on commit a6a43e2

Please sign in to comment.