Skip to content

Commit

Permalink
checkify: cache jaxpr formation so we don't always retrace
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Feb 1, 2023
1 parent fcb9dfb commit 684846b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
30 changes: 21 additions & 9 deletions jax/_src/checkify.py
Expand Up @@ -31,7 +31,7 @@
from jax._src.sharding import OpShardingSharding
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
unzip3)
unzip3, weakref_lru_cache)
from jax.api_util import flatten_fun
from jax.experimental import maps
from jax.experimental import pjit
Expand Down Expand Up @@ -383,6 +383,20 @@ def out_axes_thunk():
def get_shaped_aval(val):
return core.raise_to_shaped(core.get_aval(val))

def initial_style_jaxpr(
fun: Callable, in_tree: PyTreeDef, in_avals: Sequence[core.AbstractValue]
) -> Tuple[core.Jaxpr, List[Any], PyTreeDef]:
return _initial_style_jaxpr(fun, in_tree, tuple(in_avals))

@weakref_lru_cache
def _initial_style_jaxpr(fun, in_tree, in_avals):
# like control_flow._initial_style_jaxpr, but use flatten_fun not _nokwargs
fun_, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun_, in_tree, False, 'checkify')
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun_, in_avals, debug)
return jaxpr, consts, out_tree()


def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
error: Error, *args) -> Tuple[Error, List[core.Value]]:
err_vals, err_tree = jtu.tree_flatten(error)
Expand Down Expand Up @@ -1065,16 +1079,14 @@ def checkify(f: Callable[..., Out],
@traceback_util.api_boundary
def checked_fun(*args, **kwargs):
# stage:
fun = lu.wrap_init(f)
flat_args, in_tree = jtu.tree_flatten((args, kwargs))
flat_fun, out_tree = flatten_fun(fun, in_tree)
flat_avals = map(get_shaped_aval, flat_args)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
out_tree = out_tree()
flat_args, in_tree = tree_flatten((args, kwargs))
in_avals = map(get_shaped_aval, flat_args)
jaxpr_, consts, out_tree = initial_style_jaxpr(f, in_tree, in_avals)
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
# checkify:
flat_args = jtu.tree_leaves((args, kwargs))
error, out_flat = checkify_jaxpr(core.ClosedJaxpr(jaxpr, consts), errors,
init_error, *flat_args)
error, out_flat = checkify_jaxpr(jaxpr, errors, init_error,
*consts, *flat_args)
return error, jtu.tree_unflatten(out_tree, out_flat)
return checked_fun

Expand Down
7 changes: 7 additions & 0 deletions tests/checkify_test.py
Expand Up @@ -773,6 +773,13 @@ def g(x):
err, _ = checked_f(jnp.ones((2, 4)))
self.assertIsNone(err.get())

def test_retracing(self):
f = checkify.checkify(jax.jit(lambda x: jnp.sin(x) ** 2))
_ = f(3.)
with jtu.count_primitive_compiles() as count:
_ = f(3.)
self.assertEqual(count[0], 0)


@jtu.with_config(jax_check_tracer_leaks=True)
class AssertPrimitiveTests(jtu.JaxTestCase):
Expand Down

0 comments on commit 684846b

Please sign in to comment.