Skip to content

Commit

Permalink
Better pprint rule for check_p primitive.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 539703344
  • Loading branch information
LenaMartens authored and jax authors committed Jun 12, 2023
1 parent ed073aa commit 55da62f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
15 changes: 15 additions & 0 deletions jax/_src/checkify.py
Expand Up @@ -459,6 +459,21 @@ def _reduce_any_error(error: Error):
check_p = core.Primitive('check')
check_p.multiple_results = True # zero results


def _pp_check(eqn, context, settings) -> core.pp.Doc:
annotation = (source_info_util.summarize(eqn.source_info)
if settings.source_info else None)
name_stack_annotation = (f'[{eqn.source_info.name_stack}]'
if settings.name_stack else None)
trimmed_params = sorted((k, v) for (k, v) in eqn.params.items()
if k != "err_tree")
rhs = [core.pp.text(eqn.primitive.name, annotation=name_stack_annotation),
core.pp_kv_pairs(trimmed_params, context, settings),
core.pp.text(" ") + core.pp_vars(eqn.invars, context)]
return core.pp.concat([core.pp.text("", annotation), *rhs])

core.pp_eqn_rules[check_p] = _pp_check

# TODO(lenamartens): inherit from Exception instead of ValueError.
class JaxRuntimeError(ValueError):
pass
Expand Down
4 changes: 4 additions & 0 deletions tests/checkify_test.py
Expand Up @@ -1295,6 +1295,10 @@ def f(x):

_ = jax.jit(f, static_argnums=(0,))(True)

def test_check_pp_rule(self):
jaxpr = jax.make_jaxpr(lambda: checkify.check(False, "hi"))()
jaxpr.pretty_print(source_info=True, name_stack=True) # Does not crash.


class LowerableChecksTest(jtu.JaxTestCase):
def setUp(self):
Expand Down

0 comments on commit 55da62f

Please sign in to comment.