Skip to content

Commit

Permalink
Checkify: address initial feedback.
Browse files Browse the repository at this point in the history
  - add a way to run checkify with no errors enabled
  - clarify "can't be staged" error message
  - export init_error: a way for users to set a default Error value
  • Loading branch information
LenaMartens committed Feb 18, 2022
1 parent 1486be7 commit 2eeb683
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
1 change: 1 addition & 0 deletions jax/experimental/checkify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
div_checks as div_checks,
float_checks as float_checks,
index_checks as index_checks,
init_error as init_error,
nan_checks as nan_checks,
user_checks as user_checks,
)
13 changes: 9 additions & 4 deletions jax/experimental/checkify/checkify_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,12 @@ def assert_impl(pred, code, *, msgs):

@assert_p.def_abstract_eval
def assert_abstract_eval(pred, code, *, msgs):
raise Exception("can't be staged!")

# TODO(lenamartens) add in-depth explanation to link to in module docs.
raise ValueError('Cannot abstractly evaluate a checkify.check which was not'
' functionalized. This probably means you tried to stage'
' (jit/scan/pmap/...) a `check` without functionalizing it'
' through `checkify.checkify`.'
)

## checkify rules

Expand Down Expand Up @@ -742,8 +746,9 @@ def checkify(fun: Callable[..., Out],
"""
if not errors:
raise ValueError('Checkify needs to be called with at least one enabled'
' ErrorCategory, was called with an empty errors set.')
def checked_fun_trivial(*args, **kwargs):
return init_error, fun(*args, **kwargs)
return checked_fun_trivial

@traceback_util.api_boundary
def checked_fun(*args, **kwargs):
Expand Down
17 changes: 13 additions & 4 deletions tests/checkify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,8 +377,17 @@ def f(cond_val, body_val):
self.assertStartsWith(err.get(), "nan generated by primitive sin")

def test_empty_enabled_errors(self):
with self.assertRaisesRegex(ValueError, "called with an empty errors set"):
checkify.checkify(lambda x: x, errors={})
def multi_errors(x):
x = x/0 # DIV
x = jnp.sin(x) # NAN
x = x[500] # OOB
# TODO(lenamartens): this error should also be disabled.
# checkify.check(x < 0, "must be negative!") # ASSERT
return x

x = jnp.ones((2,))
err, _ = checkify.checkify(multi_errors, errors={})(x)
self.assertIsNone(err.get())

@parameterized.named_parameters(
("assert", checkify.user_checks, "must be negative!"),
Expand Down Expand Up @@ -504,7 +513,7 @@ def sin_bwd(x2, g):
err, y = checkify.checkify(jax.grad(sin),
errors=checkify.float_checks)(jnp.inf)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
self.assertStartsWith(err.get(), "nan generated by primitive sin")


class AssertPrimitiveTests(jtu.JaxTestCase):
Expand All @@ -521,7 +530,7 @@ def test_assert_primitive_(self):
def f():
checkify.check(False, "hi")

with self.assertRaisesRegex(Exception, "can't be staged"):
with self.assertRaisesRegex(ValueError, "Cannot abstractly evaluate"):
f()

def test_assert_discharging(self):
Expand Down

0 comments on commit 2eeb683

Please sign in to comment.