Skip to content

Commit

Permalink
refine constant-hoisting heuristic for closure_convert
Browse files Browse the repository at this point in the history
Instead of hoisting all float-type arrays during closure conversion,
only hoist JVPTracers (or tracers carrying such tracers
indirectly). Doing so better approximates the subset of
closure-captured values that participate in AD.

Co-authored-by: Matthew Johnson <mattjj@google.com>
  • Loading branch information
froystig and mattjj committed Jul 27, 2021
1 parent d1e1d65 commit 258ae44
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
16 changes: 9 additions & 7 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from typing import Callable, Generic, Optional, Sequence, Tuple, TypeVar, Any

from jax import core
from jax._src import dtypes
from jax import linear_util as lu
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
tree_multimap, treedef_is_leaf, treedef_tuple,
Expand Down Expand Up @@ -831,18 +830,21 @@ def rev(objective_fn, res, g):
else:
return _closure_convert_for_avals(fun, in_tree, in_avals)

def _is_perturbed(x: Any) -> bool:
if isinstance(x, ad.JVPTracer):
return True
elif isinstance(x, core.Tracer):
return any(_is_perturbed(attr) for name, attr in x._contents())
else:
return False

@cache()
def _closure_convert_for_avals(fun, in_tree, in_avals):
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
out_tree = out_tree()

# We only want to closure convert for constants with respect to which we're
# differentiating. As a proxy for that, we hoist consts with float dtype.
# TODO(frostig,mattjj): revise this approach
from jax.numpy import inexact
is_float = lambda c: dtypes.issubdtype(dtypes.dtype(c), inexact)
(closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
(closure_consts, hoisted_consts), merge = partition_list(_is_perturbed, consts)
num_consts = len(hoisted_consts)

def converted_fun(*args_hconsts):
Expand Down
4 changes: 2 additions & 2 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,15 +568,15 @@ def __getattr__(self, name):

def __repr__(self):
base = pp('Traced<{}>with<{}>'.format(self.aval, self._trace))
contents = self._contents()
contents = [(name, pp(repr(attr))) for name, attr in self._contents()]
if contents:
base += pp(' with ') >> vcat(pp('{} = '.format(name)) >> pp_payload
for name, pp_payload in contents)
return str(base)

def _contents(self):
try:
return [(name, pp(repr(getattr(self, name)))) for name in self.__slots__]
return [(name, getattr(self, name)) for name in self.__slots__]
except AttributeError:
return ()

Expand Down

0 comments on commit 258ae44

Please sign in to comment.