Skip to content

Commit

Permalink
staging and compilation for custom_transpose
Browse files Browse the repository at this point in the history
Co-authored-by: Matthew Johnson <mattjj@google.com>
  • Loading branch information
froystig and mattjj committed Mar 17, 2022
1 parent 5354a01 commit 45af307
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 4 deletions.
14 changes: 14 additions & 0 deletions jax/_src/custom_transpose.py
Expand Up @@ -18,6 +18,8 @@
from jax import core
from jax import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
tree_structure, treedef_tuple, tree_unflatten)
from jax._src import ad_util
Expand Down Expand Up @@ -187,6 +189,18 @@ def custom_transpose_transpose_rule(
return [None] * len(tree_leaves(res_arg)) + ct_lin_flat


def custom_transpose_lowering(*args, call_jaxpr, **params):
return core.jaxpr_as_fun(call_jaxpr)(*args)


custom_transpose_p = CustomTransposePrimitive('custom_transpose_call')
core.custom_typechecks[custom_transpose_p] = custom_transpose_typecheck
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
mlir.register_lowering(
custom_transpose_p,
mlir.lower_fun(custom_transpose_lowering, multiple_results=True))
xla.register_translation(
custom_transpose_p,
xla.lower_fun(
custom_transpose_lowering, new_style=True, multiple_results=True),
initial_style=True)
40 changes: 38 additions & 2 deletions jax/interpreters/partial_eval.py
Expand Up @@ -30,8 +30,9 @@
from jax import linear_util as lu
from jax._src import profiler
from jax._src.ad_util import Zero
from jax._src.api_util import flattened_fun_in_tree
from jax._src.tree_util import PyTreeDef, tree_unflatten, tree_leaves
from jax._src.api_util import flattened_fun_in_tree, flatten_fun_nokwargs
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
tree_leaves)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, cache, OrderedSet,
as_hashable_function)
Expand Down Expand Up @@ -1610,6 +1611,41 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
def post_process_custom_vjp_call(self, out_tracers, _):
assert False # unreachable

def process_custom_transpose(self, prim, call, tracers,
transpose, out_types,
lin_tree, res_tree, out_tree):
tracers_res, tracers_lin = split_list(tracers, [res_tree.num_leaves])

in_avals_p = [t.aval for t in tracers]
in_avals_t = [*[t.aval for t in tracers_res], *out_types]

with core.new_sublevel():
call_jaxpr, out_avals, call_consts = trace_to_subjaxpr_dynamic(
call, self.main, in_avals_p)
closed_call_jaxpr = core.ClosedJaxpr(
convert_constvars_jaxpr(call_jaxpr), ())

transpose_flat, in_tree2 = flatten_fun_nokwargs(
lu.wrap_init(transpose), treedef_tuple((res_tree, out_tree)))
transpose_jaxpr, in_avals2, transpose_consts = trace_to_subjaxpr_dynamic(
transpose_flat, self.main, in_avals_t)
closed_transpose_jaxpr = core.ClosedJaxpr(
convert_constvars_jaxpr(transpose_jaxpr), ())

out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
invars = map(self.getvar, tracers)
constvars = map(self.getvar, map(self.instantiate_const, call_consts))
outvars = map(self.makevar, out_tracers)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim,
dict(call_jaxpr=closed_call_jaxpr,
transpose_jaxpr=(closed_transpose_jaxpr,
transpose_consts),
num_consts=len(call_consts)),
source_info_util.current())
self.frame.eqns.append(eqn)
return out_tracers


custom_staging_rules: Dict[Primitive, Callable] = {}

def _memoize(thunk):
Expand Down
49 changes: 47 additions & 2 deletions tests/api_test.py
Expand Up @@ -6778,14 +6778,55 @@ def f_ref(x, y): return x + x / y
self.assertAllClose(api.jvp(f, [x, y], [tx, ty]),
api.jvp(f_ref, [x, y], [tx, ty]))

def test_make_jaxpr(self):
def f(x, y):
@custom_transpose(jnp.ones(2))
def fn(r, x): return x / r
@fn.def_transpose
def tp(r, t): return 2 * t / r

return x + fn(y, x)

x = jnp.ones(2) * 6.
y = jnp.ones(2) * 3.
f_ = lambda x: f(x, y)
f_t = transpose_unary(f_, x)

jaxpr = api.make_jaxpr(f_)(x)
self.assertIn('custom_transpose_call', str(jaxpr))

jaxpr_t = api.make_jaxpr(f_t)(x)
self.assertNotIn('custom_transpose_call', str(jaxpr_t))

def test_jit(self):
raise unittest.SkipTest('unimplemented') # TODO(frostig,mattjj)
def f(x, y):
@custom_transpose(jnp.ones(2))
def fn(r, x): return x / r
@fn.def_transpose
def tp(r, t): return 2 * t / r

return x + fn(y, x)

x = jnp.ones(2) * 6.
y = jnp.ones(2) * 3.
self.assertAllClose(f(x, y), jax.jit(f)(x, y))

f_ = lambda x: f(x, y)
f_t = transpose_unary(f_, x)
g_ = jax.jit(f_)
g_t = transpose_unary(g_, x)
self.assertAllClose(f_(x), jax.jit(f_)(x))
self.assertAllClose(f_t(x), jax.jit(f_t)(x))
self.assertAllClose(f_(x), g_(x))
self.assertAllClose(f_t(x), g_t(x))

def test_jit_recursive(self):
raise unittest.SkipTest('unimplemented') # TODO(frostig,mattjj)
def f(x, y):
@custom_transpose(jnp.ones(2))
def fn(r, x): return x / r
@fn.def_transpose
def tp(r, t): return t / r
def tp(r, t): return 2 * fn(r, t)

return x + fn(y, x)

Expand All @@ -6795,8 +6836,12 @@ def tp(r, t): return t / r

f_ = lambda x: f(x, y)
f_t = transpose_unary(f_, x)
g_ = jax.jit(f_)
g_t = transpose_unary(g_, x)
self.assertAllClose(f_(x), jax.jit(f_)(x))
self.assertAllClose(f_t(x), jax.jit(f_t)(x))
self.assertAllClose(f_(x), g_(x))
self.assertAllClose(f_t(x), g_t(x))


class CustomVmapTest(jtu.JaxTestCase):
Expand Down

0 comments on commit 45af307

Please sign in to comment.