From 15d2ccaeba368b3fa22794cf923c31fe1e8abafe Mon Sep 17 00:00:00 2001 From: Lena Martens Date: Wed, 30 Mar 2022 17:19:15 +0100 Subject: [PATCH] Checkify: add axis and axis size to OOB error message. --- jax/experimental/checkify/checkify_impl.py | 40 ++++++++++------ tests/checkify_test.py | 54 +++++++++++++++------- 2 files changed, 64 insertions(+), 30 deletions(-) diff --git a/jax/experimental/checkify/checkify_impl.py b/jax/experimental/checkify/checkify_impl.py index ee9695a73ead..33618a530268 100644 --- a/jax/experimental/checkify/checkify_impl.py +++ b/jax/experimental/checkify/checkify_impl.py @@ -58,9 +58,18 @@ def setnewattr(obj, name, val): Int = Union[int, core.Tracer] Payload = Union[np.ndarray, jnp.ndarray, core.Tracer] -# For now, the payload needs to be a fixed-size array (int32 scalar). +# For now, the payload needs to be a fixed-size array: 3 int32s, used for the +# OOB message. # TODO(lenamartens): Relax this fixed-size constraint. -init_payload = np.ones((), np.int32) +init_payload = np.ones((3,), np.int32) + + +def _format_msg(msg, payloads): + payload_mapping = {} + for i, pl in enumerate(payloads): + payload_mapping[f'payload{i}'] = pl + return msg.format(**payload_mapping) + @dataclass(frozen=True) class Error: @@ -76,11 +85,11 @@ def get(self) -> Optional[str]: assert np.shape(self.err) == np.shape(self.code) if np.size(self.err) == 1: if self.err: - return self.msgs[int(self.code)].format(payload=self.payload) + return _format_msg(self.msgs[int(self.code)], self.payload) else: return '\n'.join( f'at mapped index {", ".join(map(str, idx))}: ' # type: ignore - f'{self.msgs[int(self.code[idx])].format(payload=self.payload[idx])}' # type: ignore + f'{_format_msg(self.msgs[int(self.code[idx])], self.payload[idx])}' # type: ignore for idx, e in np.ndenumerate(self.err) if e) or None return None @@ -250,9 +259,8 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): return [CheckifyTracer(self, x) for x in out] def _reduce_any_error(errs, codes, payloads): - errs_, codes_ = lax.sort_key_val(errs, codes, dimension=0) - errs_, payload_ = lax.sort_key_val(errs, payloads, dimension=0) - return errs_[-1], codes_[-1], payload_[-1] + reduced_idx = jnp.argsort(errs)[-1] + return errs[reduced_idx], codes[reduced_idx], payloads[reduced_idx] ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error) error_checks: Dict[core.Primitive, ErrorCheckRule] = {} @@ -490,12 +498,18 @@ def gather_error_check(error, enabled_errors, operand, start_indices, *, upper_bound = jnp.expand_dims(upper_bound, axis=tuple(range(num_batch_dims))) in_bounds = (start_indices >= 0) & (start_indices <= upper_bound) - msg = f'out-of-bounds indexing at {summary()}: ' - msg += 'index {payload} is out of bounds for ' - msg += f'array of shape {operand.shape}.' - start_indices, in_bounds = jnp.ravel(start_indices), jnp.ravel(in_bounds) - # Report first index which is out-of-bounds (in row-major order). - payload = start_indices[jnp.argsort(in_bounds, axis=0)[0]] + # Get first OOB index, axis and axis size so it can be added to the error msg. + flat_idx = jnp.argmin(in_bounds) + multi_idx = jnp.unravel_index(flat_idx, start_indices.shape) + oob_axis = jnp.array(dnums.start_index_map)[multi_idx[-1]] + oob_axis_size = jnp.array(operand.shape)[oob_axis] + oob_index = jnp.ravel(start_indices)[flat_idx] + payload = jnp.array([oob_index, oob_axis, oob_axis_size], dtype=jnp.int32) + + msg = (f'out-of-bounds indexing at {summary()} for array of ' + f'shape {operand.shape}: ' + 'index {payload0} is out of bounds for axis {payload1} ' + 'with size {payload2}.') return out, assert_func(error, jnp.all(in_bounds), msg, payload) error_checks[lax.gather_p] = gather_error_check diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 06fd88a3f0f6..a0e5b7f1c441 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -138,27 +138,36 @@ def f(x, i): self.assertStartsWith(err.get(), "nan generated by primitive cos") def test_numpy_indexing_oobs(self): - x = jnp.ones((2, 3, 7)) - def assert_raises_oob(fn, idx, expected_str): + def raises_oob(fn, idx, *expected_strs): err, _ = checkify.checkify(fn, errors=checkify.index_checks)(x, idx) error_txt = err.get() self.assertIsNotNone(error_txt) self.assertStartsWith(error_txt, "out-of-bounds indexing") - self.assertIn(expected_str, error_txt) - - simple_indexing = lambda x, i: x[i] - assert_raises_oob(simple_indexing, 5, "index 5") - assert_raises_oob(simple_indexing, -5, "index -3") - assert_raises_oob(simple_indexing, (0, 100), "index 100") - assert_raises_oob(simple_indexing, (0, 5, 100), "index 5") - assert_raises_oob(simple_indexing, ((1, 20), (1, 4)), "index 20") - assert_raises_oob(simple_indexing, ((1, 20), (3, 4)), "index 3") - - multiple_axis_indexing = lambda x, i: x[i[0], :, i[1]] - assert_raises_oob(multiple_axis_indexing, (0, 9), "index 9") - assert_raises_oob(multiple_axis_indexing, (-5, 9), "index -3") - assert_raises_oob(multiple_axis_indexing, (5, -9), "index 5") - assert_raises_oob(multiple_axis_indexing, ((0, 9), 0), "index 9") + for s in expected_strs: + self.assertIn(s, error_txt) + + x = jnp.ones((2, 3, 7)) + axis0_msg = "axis 0 with size 2" + axis1_msg = "axis 1 with size 3" + axis2_msg = "axis 2 with size 7" + + single_idx = lambda x, i: x[i] + raises_oob(single_idx, 5, "index 5", axis0_msg) + raises_oob(single_idx, -5, "index -3", axis0_msg) + raises_oob(single_idx, (0, 100), "index 100", axis1_msg) + raises_oob(single_idx, (0, 5, 100), "index 5", axis1_msg) + raises_oob(single_idx, (0, 0, 100), "index 100", axis2_msg) + raises_oob(single_idx, ((1, 20), (1, 4)), "index 20", axis0_msg) + raises_oob(single_idx, ((1, 20), (3, 4)), "index 3", axis1_msg) + raises_oob(single_idx, (((1, 1), (1, 20)), 3), "index 3", axis1_msg) + raises_oob(single_idx, (((1, 1), (1, 20)), 0), "index 20", axis0_msg) + + multi_idx = lambda x, i: x[i[0], :, i[1]] + raises_oob(multi_idx, (0, 9), "index 9", axis2_msg) + # TODO(lenamartens): numpy reports index -5 here, need to normalize? + raises_oob(multi_idx, (-5, 9), "index -3", axis0_msg) + raises_oob(multi_idx, (5, -9), "index 5", axis0_msg) + raises_oob(multi_idx, ((0, 9), 0), "index 9", axis0_msg) @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_jit={}".format(jit), "jit": jit} @@ -598,6 +607,17 @@ def test_nd_payloads(self): self.assertIn("index 5", errs.get()) self.assertIn("index 100", errs.get()) + def test_mapped_error_one_payload(self): + def f(x, i): + x = x[i] + return x/0 + + cf = checkify.checkify(f, errors=checkify.automatic_checks) + errs, _ = jax.vmap(cf)(jnp.ones((2, 1)), jnp.array([0, 100])) + self.assertIsNotNone(errs.get()) + self.assertIn("divided by zero", errs.get()) + self.assertIn("index 100", errs.get()) + class AssertPrimitiveTests(jtu.JaxTestCase):