Skip to content

Commit

Permalink
Merge pull request #13584 from LenaMartens:debug-check
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 494754750
  • Loading branch information
jax authors committed Dec 12, 2022
2 parents 13c34f9 + 3db909e commit 4a9e9d5
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 19 deletions.
45 changes: 32 additions & 13 deletions docs/debugging/checkify_guide.md
Expand Up @@ -8,16 +8,16 @@ import jax
import jax.numpy as jnp

def f(x, i):
checkify.check(i >= 0, "index needs to be non-negative!")
checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
y = x[i]
z = jnp.sin(y)
return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))
```

You can also use checkify to automatically add common checks:
Expand Down Expand Up @@ -58,7 +58,7 @@ But the checkify transformation functionalizes (or discharges) these effects. A
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
ValueError:
.. at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
.. at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""
Expand Down Expand Up @@ -108,19 +108,19 @@ The error is a regular value computed by the function, and the error is raised o

```python
def f(x):
checkify.check(x > 0., "must be positive!") # convenient but effectful API
checkify.check(x > 0., "{} must be positive!", x) # convenient but effectful API
return jnp.log(x)

f_checked = checkify(f)

err, x = jax.jit(f_checked)(0.)
err, x = jax.jit(f_checked)(-1.)
err.throw()
# ValueError: must be positive! (check failed at <...>:2 (f))
# ValueError: -1. must be positive! (check failed at <...>:2 (f))
```

We call this functionalizing or discharging the effect introduced by calling check. (In the "manual" example above the error value is just a boolean. checkify's error values are conceptually similar but also track error messages and expose throw and get methods; see {mod}`jax.experimental.checkify`).
We call this functionalizing or discharging the effect introduced by calling check. (In the "manual" example above the error value is just a boolean. checkify's error values are conceptually similar but also track error messages and expose throw and get methods; see {mod}`jax.experimental.checkify`). `checkify.check` also allows you to add run-time values to your error message by providing them as format arguments to the error message.

You could now instrument your code with run-time checks, but `checkify` can also automatically add checks for common errors!
You could now manually instrument your code with run-time checks, but `checkify` can also automatically add checks for common errors!
Consider these error cases:

```python
Expand Down Expand Up @@ -158,10 +158,29 @@ jitted. Here's a few more examples of `checkify` with other JAX
transformations. Note that checkified functions are functionally pure, and
should trivially compose with all JAX transformations!

### `jit`

You can safely add `jax.jit` to a checkified function, or `checkify` a jitted
function, both will work.

```python
def f(x, i):
return x[i]

checkify_of_jit = checkify.checkify(jax.jit(f))
jit_of_checkify = jax.jit(checkify.checkify(f))
err, _ = checkify_of_jit(jnp.ones((5,)), 100)
err.get()
# out-of-bounds indexing at <..>:2 (f)
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
# out-of-bounds indexing at <..>:2 (f)
```

### `vmap`/`pmap`

Mapping a checkified function will give you a mapped error, which can contain
different errors for every element of the mapped dimension.
You can `vmap` and `pmap` checkified functions (or `checkify` mapped functions).
Mapping a checkified function will give you a mapped error, which can contain
different errors for every element of the mapped dimension.

```python
def f(x, i):
Expand Down Expand Up @@ -206,7 +225,7 @@ f = pjit(
f,
in_axis_resources=PartitionSpec('x', None),
out_axis_resources=(None, PartitionSpec('x', None)))

with maps.Mesh(mesh.devices, mesh.axis_names):
err, data = f(input_data)
err.throw()
Expand Down Expand Up @@ -264,4 +283,4 @@ jax.grad(assert_gradient_negative)(-1.)
* Requires threading error values out of functions and manually throwing the
error. If the error is not explicitly thrown, you might miss out on errors!
* Throwing an error value will materialize that error value on the host, meaning
it's a blocking operation which defeats JAX's async run-ahead.
it's a blocking operation which defeats JAX's async run-ahead.
20 changes: 14 additions & 6 deletions jax/_src/checkify.py
Expand Up @@ -598,38 +598,46 @@ def checkify_fun_to_jaxpr(



def check(pred: Bool, msg: str, *args, **kwargs) -> None:
def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
"""Check a predicate, add an error with msg if predicate is False.
This is an effectful operation, and can't be staged (jitted/scanned/...).
Before staging a function with checks, :func:`~checkify` it!
Args:
pred: if False, an error is added.
msg: error message if error is added.
msg: error message if error is added. Can be a format string.
fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
`msg`, eg.:
``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)``
Note that these arguments can be traced values allowing you to add
run-time values to the error message.
Note that tracking these run-time arrays will increase your memory usage,
even if no error happens.
For example:
>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>> def f(x):
... checkify.check(x!=0, "cannot be zero!")
... checkify.check(x>0, "{x} needs to be positive!", x=x)
... return 1/x
>>> checked_f = checkify.checkify(f)
>>> err, out = jax.jit(checked_f)(0)
>>> err, out = jax.jit(checked_f)(-3.)
>>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
jax._src.checkify.JaxRuntimeError: cannot be zero!
jax._src.checkify.JaxRuntimeError: -3. needs to be positive!
"""
if not is_scalar_pred(pred):
raise TypeError(f'check takes a scalar pred as argument, got {pred}')
new_error = FailedCheckError(summary(), msg, *args, **kwargs)
new_error = FailedCheckError(summary(), msg, *fmt_args, **fmt_kwargs)
error = assert_func(init_error, jnp.logical_not(pred), new_error)
return check_error(error)


def is_scalar_pred(pred) -> bool:
return (isinstance(pred, bool) or
isinstance(pred, jnp.ndarray) and pred.shape == () and
Expand Down

0 comments on commit 4a9e9d5

Please sign in to comment.