Skip to content

Commit

Permalink
Checkify: add axis and axis size to OOB error message.
Browse files Browse the repository at this point in the history
  • Loading branch information
LenaMartens committed Mar 31, 2022
1 parent c5f9c42 commit 15d2cca
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 30 deletions.
40 changes: 27 additions & 13 deletions jax/experimental/checkify/checkify_impl.py
Expand Up @@ -58,9 +58,18 @@ def setnewattr(obj, name, val):
Int = Union[int, core.Tracer]
Payload = Union[np.ndarray, jnp.ndarray, core.Tracer]

# For now, the payload needs to be a fixed-size array (int32 scalar).
# For now, the payload needs to be a fixed-size array: 3 int32s, used for the
# OOB message.
# TODO(lenamartens): Relax this fixed-size constraint.
init_payload = np.ones((), np.int32)
init_payload = np.ones((3,), np.int32)


def _format_msg(msg, payloads):
payload_mapping = {}
for i, pl in enumerate(payloads):
payload_mapping[f'payload{i}'] = pl
return msg.format(**payload_mapping)


@dataclass(frozen=True)
class Error:
Expand All @@ -76,11 +85,11 @@ def get(self) -> Optional[str]:
assert np.shape(self.err) == np.shape(self.code)
if np.size(self.err) == 1:
if self.err:
return self.msgs[int(self.code)].format(payload=self.payload)
return _format_msg(self.msgs[int(self.code)], self.payload)
else:
return '\n'.join(
f'at mapped index {", ".join(map(str, idx))}: ' # type: ignore
f'{self.msgs[int(self.code[idx])].format(payload=self.payload[idx])}' # type: ignore
f'{_format_msg(self.msgs[int(self.code[idx])], self.payload[idx])}' # type: ignore
for idx, e in np.ndenumerate(self.err) if e) or None
return None

Expand Down Expand Up @@ -250,9 +259,8 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
return [CheckifyTracer(self, x) for x in out]

def _reduce_any_error(errs, codes, payloads):
errs_, codes_ = lax.sort_key_val(errs, codes, dimension=0)
errs_, payload_ = lax.sort_key_val(errs, payloads, dimension=0)
return errs_[-1], codes_[-1], payload_[-1]
reduced_idx = jnp.argsort(errs)[-1]
return errs[reduced_idx], codes[reduced_idx], payloads[reduced_idx]

ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
Expand Down Expand Up @@ -490,12 +498,18 @@ def gather_error_check(error, enabled_errors, operand, start_indices, *,
upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims)))
in_bounds = (start_indices >= 0) & (start_indices <= upper_bound)

msg = f'out-of-bounds indexing at {summary()}: '
msg += 'index {payload} is out of bounds for '
msg += f'array of shape {operand.shape}.'
start_indices, in_bounds = jnp.ravel(start_indices), jnp.ravel(in_bounds)
# Report first index which is out-of-bounds (in row-major order).
payload = start_indices[jnp.argsort(in_bounds, axis=0)[0]]
# Get first OOB index, axis and axis size so it can be added to the error msg.
flat_idx = jnp.argmin(in_bounds)
multi_idx = jnp.unravel_index(flat_idx, start_indices.shape)
oob_axis = jnp.array(dnums.start_index_map)[multi_idx[-1]]
oob_axis_size = jnp.array(operand.shape)[oob_axis]
oob_index = jnp.ravel(start_indices)[flat_idx]
payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32)

msg = (f'out-of-bounds indexing at {summary()} for array of '
f'shape {operand.shape}: '
'index {payload0} is out of bounds for axis {payload1} '
'with size {payload2}.')

return out, assert_func(error, jnp.all(in_bounds), msg, payload)
error_checks[lax.gather_p] = gather_error_check
Expand Down
54 changes: 37 additions & 17 deletions tests/checkify_test.py
Expand Up @@ -138,27 +138,36 @@ def f(x, i):
self.assertStartsWith(err.get(), "nan generated by primitive cos")

def test_numpy_indexing_oobs(self):
x = jnp.ones((2, 3, 7))
def assert_raises_oob(fn, idx, expected_str):
def raises_oob(fn, idx, *expected_strs):
err, _ = checkify.checkify(fn, errors=checkify.index_checks)(x, idx)
error_txt = err.get()
self.assertIsNotNone(error_txt)
self.assertStartsWith(error_txt, "out-of-bounds indexing")
self.assertIn(expected_str, error_txt)

simple_indexing = lambda x, i: x[i]
assert_raises_oob(simple_indexing, 5, "index 5")
assert_raises_oob(simple_indexing, -5, "index -3")
assert_raises_oob(simple_indexing, (0, 100), "index 100")
assert_raises_oob(simple_indexing, (0, 5, 100), "index 5")
assert_raises_oob(simple_indexing, ((1, 20), (1, 4)), "index 20")
assert_raises_oob(simple_indexing, ((1, 20), (3, 4)), "index 3")

multiple_axis_indexing = lambda x, i: x[i[0], :, i[1]]
assert_raises_oob(multiple_axis_indexing, (0, 9), "index 9")
assert_raises_oob(multiple_axis_indexing, (-5, 9), "index -3")
assert_raises_oob(multiple_axis_indexing, (5, -9), "index 5")
assert_raises_oob(multiple_axis_indexing, ((0, 9), 0), "index 9")
for s in expected_strs:
self.assertIn(s, error_txt)

x = jnp.ones((2, 3, 7))
axis0_msg = "axis 0 with size 2"
axis1_msg = "axis 1 with size 3"
axis2_msg = "axis 2 with size 7"

single_idx = lambda x, i: x[i]
raises_oob(single_idx, 5, "index 5", axis0_msg)
raises_oob(single_idx, -5, "index -3", axis0_msg)
raises_oob(single_idx, (0, 100), "index 100", axis1_msg)
raises_oob(single_idx, (0, 5, 100), "index 5", axis1_msg)
raises_oob(single_idx, (0, 0, 100), "index 100", axis2_msg)
raises_oob(single_idx, ((1, 20), (1, 4)), "index 20", axis0_msg)
raises_oob(single_idx, ((1, 20), (3, 4)), "index 3", axis1_msg)
raises_oob(single_idx, (((1, 1), (1, 20)), 3), "index 3", axis1_msg)
raises_oob(single_idx, (((1, 1), (1, 20)), 0), "index 20", axis0_msg)

multi_idx = lambda x, i: x[i[0], :, i[1]]
raises_oob(multi_idx, (0, 9), "index 9", axis2_msg)
# TODO(lenamartens): numpy reports index -5 here, need to normalize?
raises_oob(multi_idx, (-5, 9), "index -3", axis0_msg)
raises_oob(multi_idx, (5, -9), "index 5", axis0_msg)
raises_oob(multi_idx, ((0, 9), 0), "index 9", axis0_msg)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_jit={}".format(jit), "jit": jit}
Expand Down Expand Up @@ -598,6 +607,17 @@ def test_nd_payloads(self):
self.assertIn("index 5", errs.get())
self.assertIn("index 100", errs.get())

def test_mapped_error_one_payload(self):
def f(x, i):
x = x[i]
return x/0

cf = checkify.checkify(f, errors=checkify.automatic_checks)
errs, _ = jax.vmap(cf)(jnp.ones((2, 1)), jnp.array([0, 100]))
self.assertIsNotNone(errs.get())
self.assertIn("divided by zero", errs.get())
self.assertIn("index 100", errs.get())


class AssertPrimitiveTests(jtu.JaxTestCase):

Expand Down

0 comments on commit 15d2cca

Please sign in to comment.