Skip to content

Commit

Permalink
re-implement custom_transpose without upfront staging.
Browse files Browse the repository at this point in the history
Whereas the previous `custom_transpose` implementation would stage its
callable arguments upfront, this one preserves them as callables. For
the time being, this requires callers to additionally supply the target
function's output types at call time.

Co-authored-by: Matthew Johnson <mattjj@google.com>
  • Loading branch information
froystig and mattjj committed Mar 5, 2022
1 parent 2a3f936 commit 947b7b8
Show file tree
Hide file tree
Showing 6 changed files with 277 additions and 80 deletions.
1 change: 1 addition & 0 deletions jax/_src/api.py
Expand Up @@ -2601,6 +2601,7 @@ def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
"[float or complex], and integer -> integer functions, "
f"but got {in_dtypes} -> {out_dtypes}.")

@api_boundary
def transposed_fun(consts, out_cotangent):
out_cotangents, out_tree2 = tree_flatten(out_cotangent)
if out_tree() != out_tree2:
Expand Down
176 changes: 115 additions & 61 deletions jax/_src/custom_transpose.py
Expand Up @@ -13,22 +13,19 @@
# limitations under the License.

import functools
from typing import Callable, Optional
from typing import Any, Callable, Optional, Tuple

from jax import core
from jax import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_leaves, tree_unflatten,
treedef_tuple)
from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
tree_structure, treedef_tuple, tree_unflatten)
from jax._src import ad_util
from jax._src import api_util
from jax._src import custom_api_util
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.api_util import flatten_fun_nokwargs


source_info_util.register_exclusion(__file__)
Expand All @@ -39,15 +36,37 @@
zip, unsafe_zip = util.safe_zip, zip


### bespoke linear_util and api_util deviations

class StoreEqual(lu.Store):
"""Stores an unchanging value. Checks empty reads and unequal overwrites."""
def store(self, val):
if self._val is not lu._EMPTY_STORE_VALUE and val != self._val:
raise lu.StoreException(
f"Store assignment mismatch, from {self._val} to {val}")
self._val = val

@util.curry
def transformation_with_aux(
gen, fun: lu.WrappedFun, *gen_static_args) -> Tuple[lu.WrappedFun, Any]:
out_store = StoreEqual()
out_thunk = lambda: out_store.val
return fun.wrap(gen, gen_static_args, out_store), out_thunk

flatten_fun_nokwargs = transformation_with_aux(
api_util.flatten_fun_nokwargs.args[0]) # type: ignore[has-type]


### api

@custom_api_util.register_custom_decorator_type
class custom_transpose:
fun: Callable
transpose: Optional[Callable]
transpose: Optional[Callable] = None

def __init__(self, fun: Callable):
functools.update_wrapper(self, fun)
self.fun = fun # type: ignore[assignment]
self.transpose = None

__getattr__ = custom_api_util.forward_attr

Expand All @@ -56,83 +75,118 @@ def def_transpose(self, transpose: Callable):
return transpose

@traceback_util.api_boundary
def __call__(self, residual_arg, linear_arg):
res_arg, lin_arg = residual_arg, linear_arg
def __call__(self, out_types, res_arg, lin_arg):
_, res_tree = tree_flatten(res_arg)
_, lin_tree = tree_flatten(lin_arg)
args_flat, in_tree = tree_flatten((res_arg, lin_arg))

flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
debug = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
assert not len(consts)
closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
out_flat = custom_transpose_p.bind(*consts, *args_flat,
call=closed_call,
rule=self.transpose,
# TODO(frostig,mattjj): check that out_trees match
# TODO(frostig,mattjj): could, and should, we avoid flattening
# self.fun at this point?

flat_fun, out_tree2 = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
out_types_flat, out_tree = tree_flatten(out_types)
out_flat = custom_transpose_p.bind(flat_fun, *args_flat,
transpose=self.transpose,
out_types=out_types_flat,
lin_tree=lin_tree,
res_tree=res_tree,
out_tree=out_tree())
return tree_unflatten(out_tree(), out_flat)
out_tree=out_tree)
return tree_unflatten(out_tree, out_flat)


### utils

def rule_name(rule):
return getattr(rule, '__name__', '<unnamed transpose rule>')
def tree_fill(x, treedef):
return tree_unflatten(treedef, [x] * treedef.num_leaves)

def check_transpose_rule_trees(rule, lin_tree, rule_out_tree):
if lin_tree != rule_out_tree and len(lin_tree.children()) == 1:
lin_tree2, = lin_tree.children()
else:
lin_tree2 = lin_tree
if lin_tree2 != rule_out_tree:
raise ValueError(
'structure of custom transpose rule\'s output does not match '
'structure of primal function\'s linear inputs under '
f'custom transpose rule ({rule_name(rule)}).\n'
f'Transpose rule output: {rule_out_tree}\n'
f'Linear primal inputs: {lin_tree}')
def tree_fill_like(x, tree):
return tree_fill(x, tree_structure(tree))

def tree_broadcast(full_treedef, tree, is_leaf=None):
full_tree = tree_fill(0, full_treedef)
return tree_map(tree_fill_like, tree, full_tree, is_leaf=is_leaf)

### custom_transpose_p rules
def is_treedef_prefix(entire, prefix):
entire = tree_fill(0, entire)
prefix = tree_fill(0, prefix)
try:
tree_map(lambda x, y: x, prefix, entire)
except ValueError:
return False
return True

def rule_name(rule):
return getattr(rule, '__name__', '<unnamed transpose rule>')

def custom_transpose_impl(*args, call, rule, res_tree, lin_tree, out_tree):
del rule, res_tree, lin_tree, out_tree
return core.jaxpr_as_fun(call)(*args)
def check_transpose_rule_trees(rule, lin_tree, rule_out_tree):
if not is_treedef_prefix(lin_tree, rule_out_tree):
if hasattr(rule, '_transpose_type_error'):
raise rule._transpose_type_error(lin_tree, rule_out_tree)
else:
raise TypeError(
'structure of custom transpose rule\'s output does not prefix-match '
'structure of primal function\'s linear inputs under '
f'custom transpose rule ({rule_name(rule)}).\n'
f'Transpose rule output: {rule_out_tree}\n'
f'Linear primal inputs: {lin_tree}')


### custom_transpose primitive and rules

class CustomTransposePrimitive(core.Primitive):
call_primitive = False
map_primitive = False
multiple_results = True

def bind(self, call, *args, **params):
# TODO(frostig,mattjj): This doesn't handle closures yet, which is
# a bit involved. Closures are complicated by us binding `call`
# twice in the JVP rule for custom transpose. The `env_trace_todo`
# output by `process_env_traces` due to one of those two bindings
# should be passable to the other, and need to be passed onward
# since the second bind is deferred by partial eval (since it
# typically receives unknowns)
top_trace = core.find_top_trace(args)
tracers = map(top_trace.full_raise, args)
outs = top_trace.process_custom_transpose(self, call, tracers, **params)
return outs

# TODO(frostig,mattjj): consider keeping `call` as a named parameter
# instead of following this "call primitive" convention.
def get_bind_params(self, params):
new_params = dict(params)
return [new_params.pop('call')], new_params


# TODO(frostig,mattjj): reinstate checks
def custom_transpose_typecheck(*avals, **params):
pass


def custom_transpose_transpose_rule(
cts, *args, call, rule, res_tree, lin_tree, out_tree):
cts, *args, call, transpose, out_types, res_tree, lin_tree, out_tree):
call_in_tree = treedef_tuple((res_tree, lin_tree))

# TODO(frostig,mattjj): `lin_arg` indicates the inputs with respect
# to which we are transposing (via `ad.is_undefined_primal`).
# Consider passing this information to the custom transpose rule?

res_arg, lin_arg = tree_unflatten(call_in_tree, args)
assert all(ad.is_undefined_primal(x) for x in tree_leaves(lin_arg))
del lin_arg
assert all(not ad.is_undefined_primal(x) for x in tree_leaves(res_arg))

cts = [ad_util.zeros_like_aval(ct_aval) if type(ct) is ad_util.Zero else ct
for ct, ct_aval in zip(cts, call.out_avals)]
cts = [ad_util.zeros_like_aval(ct.aval) if type(ct) is ad_util.Zero else ct
for ct in cts]
ct_out = tree_unflatten(out_tree, cts)
ct_lin = rule(res_arg, ct_out)
ct_lin_flat, ct_lin_tree = tree_flatten(ct_lin)
check_transpose_rule_trees(rule, lin_tree, ct_lin_tree)
ct_lin = transpose(res_arg, ct_out)
check_transpose_rule_trees(transpose, lin_tree, tree_structure(ct_lin))
ct_lin_flat, _ = tree_flatten(
tree_broadcast(lin_tree, ct_lin, is_leaf=lambda x: x is None),
is_leaf=lambda x: x is None)
return [None] * len(tree_leaves(res_arg)) + ct_lin_flat


def custom_transpose_abstract_eval(*in_avals, call, **_):
return call.out_avals


custom_transpose_p = core.Primitive('custom_transpose_call')
custom_transpose_p.multiple_results = True
custom_transpose_p.def_impl(custom_transpose_impl)
custom_transpose_p.def_abstract_eval(custom_transpose_abstract_eval)
custom_transpose_p = CustomTransposePrimitive('custom_transpose_call')
core.custom_typechecks[custom_transpose_p] = custom_transpose_typecheck
ad.primitive_transposes[custom_transpose_p] = custom_transpose_transpose_rule
xla.register_translation(custom_transpose_p,
xla.lower_fun(custom_transpose_impl, new_style=True,
multiple_results=True),
initial_style=True)
mlir.register_lowering(custom_transpose_p, mlir.lower_fun(
custom_transpose_impl, multiple_results=True))
16 changes: 13 additions & 3 deletions jax/core.py
Expand Up @@ -420,6 +420,11 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
"to handle custom_jvp primitives")
raise NotImplementedError(msg)

def process_custom_transpose(self, prim, call, tracers, **params):
msg = (f"{type(self)} must override process_custom_transpose "
"to handle custom_transpose_call primitives")
raise NotImplementedError(msg)

def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
msg = (f"{type(self)} must override process_custom_vjp_call "
"to handle custom_vjp primitives")
Expand Down Expand Up @@ -607,13 +612,18 @@ def process_call(self, primitive, f, tracers, params):
return primitive.impl(f, *tracers, **params)
process_map = process_call

def process_custom_transpose(self, primitive, call, tracers, **_):
del primitive
with new_sublevel():
return call.call_wrapped(*tracers)

def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
del primitive, jvp # Unused.
with new_sublevel():
return fun.call_wrapped(*tracers)

def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
del primitive, fwd, bwd, out_trees # Unused.
def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, **_):
del primitive, fwd, bwd # Unused.
with new_sublevel():
return fun.call_wrapped(*tracers)

Expand Down Expand Up @@ -1709,7 +1719,7 @@ def call_bind(primitive: CallPrimitive, fun, *args, **params):
return map(full_lower, apply_todos(env_trace_todo(), outs))

@lu.transformation_with_aux
def process_env_traces_call(primitive: CallPrimitive, level: int,
def process_env_traces_call(primitive: CallPrimitive, level: Optional[int],
params_tuple: tuple, *args):
outs = yield args, {}
params = dict(params_tuple)
Expand Down
32 changes: 32 additions & 0 deletions jax/interpreters/ad.py
Expand Up @@ -385,6 +385,38 @@ def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
def post_process_custom_vjp_call(self, out_tracers, _):
raise CustomVJPException()

def process_custom_transpose(self, prim, call, tracers, **params):
ps_in, ts_in = unzip2((t.primal, t.tangent) for t in tracers)
res_ps_in, lin_ps_in = split_list(ps_in, [params['res_tree'].num_leaves])
res_ts_in, lin_ts_in = split_list(ts_in, [params['res_tree'].num_leaves])

# TODO(frostig): Handle differentiation with respect to residual
# operands. Calling `call` twice on all operands invalid, since it
# isn't linear in the residuals. However, we know that if we
# write:
#
# jvp_call_res = lambda x: partial(jvp, lambda r: call(r, x))
#
# then:
#
# jvp(call, (r, x), (dr, dx)) == jvp_call_res(x)(r, dr) + call(r, dx)
#
# In words: a possible strategy is to take the jvp of `call` with
# respect to residuals, and with linear arguments fixed, then add
# that to a custom-transpose call to `call` (i.e. what we already
# do below in the all-linear argument case).

if any(type(t) is not Zero for t in res_ts_in):
raise NotImplementedError(
'JVP of custom transpose with respect to non-symbolic-zero residuals')

ps_out = prim.bind(call, *ps_in, **params)

lin_ts_in = map(instantiate_zeros, lin_ts_in)
ts_out = prim.bind(call, *res_ps_in, *lin_ts_in, **params)

return map(partial(JVPTracer, self), ps_out, ts_out)

def join(self, xt, yt):
xz, yz = type(xt) is Zero, type(yt) is Zero
if xz == yz:
Expand Down
18 changes: 18 additions & 0 deletions jax/interpreters/partial_eval.py
Expand Up @@ -422,6 +422,24 @@ def post_process_custom_jvp_call(self, out_tracers, _):
# respect to values over which a custom_jvp function closes is detected.
raise NotImplementedError # TODO(mattjj)

def process_custom_transpose(self, prim, call, tracers, **params):
res_ts, lin_ts = split_list(tracers, [params['res_tree'].num_leaves])
assert all(t.is_known() for t in res_ts)
lin_all_known = all(t.is_known() for t in lin_ts)
if lin_all_known:
res_cvals = [t.pval[1] for t in res_ts]
lin_cvals = [t.pval[1] for t in lin_ts]
return prim.bind(call, *res_cvals, *lin_cvals, **params)
else:
out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None)
for aval in params['out_types']]
in_tracers = map(self.instantiate_const, tracers)
new_params = dict(params, call=call)
eqn = new_eqn_recipe(in_tracers, out_tracers, prim, new_params,
source_info_util.current())
for t in out_tracers: t.recipe = eqn
return out_tracers

def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
tracers = map(self.instantiate_const_abstracted, tracers)
in_avals, in_consts = unzip2(t.pval for t in tracers) # in_consts are units
Expand Down

0 comments on commit 947b7b8

Please sign in to comment.