From d2f80ef117c2aef4ec287f5164cc989a13b00171 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 9 Jun 2022 15:21:49 -0700 Subject: [PATCH] [x64] deprecate unsafe type casting in scatter-update operations --- CHANGELOG.md | 5 ++ jax/_src/ops/scatter.py | 8 +++ jax/_src/scipy/sparse/linalg.py | 7 +-- tests/checkify_test.py | 2 +- tests/lax_numpy_indexing_test.py | 89 +++++++++++++++++++++++++++++--- 5 files changed, 100 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cc17898fbbe4..ffff13ca8e14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. traces as an alternative to the Tensorboard UI. * Added a `jax.named_scope` context manager that adds profiler metadata to Python programs (similar to `jax.named_call`). + * In scatter-update operations (i.e. :attr:`jax.numpy.ndarray.at`), unsafe implicit + dtype casts are deprecated, and now result in a `FutureWarning`. + In a future release, this will become an error. An example of an unsafe implicit + cast is `jnp.zeros(4, dtype=int).at[0].set(1.5)`, in which `1.5` previously was + silently truncated to `1`. ## jaxlib 0.3.11 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main). diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 342eedebea67..15a97961ec01 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -16,6 +16,7 @@ import sys from typing import Any, Callable, Optional, Sequence, Tuple, Union +import warnings import numpy as np @@ -81,6 +82,13 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, dtype = lax.dtype(x) weak_type = dtypes.is_weakly_typed(x) + if dtype != dtypes.result_type(x, y): + # TODO(jakevdp): change this to an error after the deprecation period. + warnings.warn("scatter inputs have incompatible types: cannot safely cast " + f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. " + "In future JAX releases this will result in an error.", + FutureWarning) + idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) indexer = jnp._index_to_gather(jnp.shape(x), idx, normalize_indices=normalize_indices) diff --git a/jax/_src/scipy/sparse/linalg.py b/jax/_src/scipy/sparse/linalg.py index 498ebef3d2ca..aa838d8288b3 100644 --- a/jax/_src/scipy/sparse/linalg.py +++ b/jax/_src/scipy/sparse/linalg.py @@ -350,9 +350,10 @@ def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2): # This assumes that Q's leaves all have the same dimension in the last # axis. - r = jnp.zeros(tree_leaves(Q)[0].shape[-1]) + Q0 = tree_leaves(Q)[0] + r = jnp.zeros(Q0.shape[-1], dtype=Q0.dtype) q = x - xnorm_scaled = xnorm / jnp.sqrt(2) + xnorm_scaled = xnorm / jnp.sqrt(2.0) def body_function(carry): k, q, r, qnorm_scaled = carry @@ -368,7 +369,7 @@ def qnorm_cond(carry): def qnorm(carry): k, _, q, qnorm_scaled = carry _, qnorm = _safe_normalize(q) - qnorm_scaled = qnorm / jnp.sqrt(2) + qnorm_scaled = qnorm / jnp.sqrt(2.0) return (k, False, q, qnorm_scaled) init = (k, True, q, qnorm_scaled) diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 699c45f31d5b..a506de317596 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -77,7 +77,7 @@ def f(x, i): "max", "get"]) def test_jit_oob_update(self, update_fn): def f(x, i): - return getattr(x.at[i], update_fn)(1.) + return getattr(x.at[i], update_fn)(1) f = jax.jit(f) checked_f = checkify.checkify(f, errors=checkify.index_checks) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index ed91871b4c34..f740922faa62 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -960,7 +960,7 @@ def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245 jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32)) def testIndexingWeakTypes(self): - x = lax_internal._convert_element_type(jnp.arange(5), int, weak_type=True) + x = lax_internal._convert_element_type(jnp.arange(5), float, weak_type=True) a = x.at[0].set(1.0) self.assertEqual(a.dtype, x.dtype) @@ -974,6 +974,67 @@ def testIndexingWeakTypes(self): self.assertEqual(c.dtype, x.dtype) self.assertTrue(dtypes.is_weakly_typed(c)) + def testIndexingTypePromotion(self): + def _check(x_type, y_type): + x = jnp.arange(5, dtype=x_type) + y = y_type(0) + out = x.at[0].set(y) + self.assertEqual(x.dtype, out.dtype) + + @jtu.ignore_warning(category=np.ComplexWarning, + message="Casting complex values to real") + def _check_warns(x_type, y_type, msg): + with self.assertWarnsRegex(FutureWarning, msg): + _check(x_type, y_type) + + def _check_raises(x_type, y_type, msg): + with self.assertRaisesRegex(ValueError, msg): + _check(x_type, y_type) + + # Matching dtypes are always OK + _check(jnp.int32, jnp.int32) + _check(jnp.float32, jnp.float32) + _check(jnp.complex64, jnp.complex64) + + # Weakly-typed y values promote. + _check(jnp.int32, int) + _check(jnp.float32, int) + _check(jnp.float32, float) + _check(jnp.complex64, int) + _check(jnp.complex64, float) + _check(jnp.complex64, complex) + + # in standard promotion mode, strong types can promote. + msg = "scatter inputs have incompatible types" + with jax.numpy_dtype_promotion('standard'): + _check(jnp.int32, jnp.int16) + _check(jnp.float32, jnp.float16) + _check(jnp.float32, jnp.int32) + _check(jnp.complex64, jnp.int32) + _check(jnp.complex64, jnp.float32) + + # TODO(jakevdp): make these _check_raises + _check_warns(jnp.int16, jnp.int32, msg) + _check_warns(jnp.int32, jnp.float32, msg) + _check_warns(jnp.int32, jnp.complex64, msg) + _check_warns(jnp.float16, jnp.float32, msg) + _check_warns(jnp.float32, jnp.complex64, msg) + + # in strict promotion mode, strong types do not promote. + msg = "Input dtypes .* have no available implicit dtype promotion path" + with jax.numpy_dtype_promotion('strict'): + _check_raises(jnp.int32, jnp.int16, msg) + _check_raises(jnp.float32, jnp.float16, msg) + _check_raises(jnp.float32, jnp.int32, msg) + _check_raises(jnp.complex64, jnp.int32, msg) + _check_raises(jnp.complex64, jnp.float32, msg) + + _check_raises(jnp.int16, jnp.int32, msg) + _check_raises(jnp.int32, jnp.float32, msg) + _check_raises(jnp.int32, jnp.complex64, msg) + _check_raises(jnp.float16, jnp.float32, msg) + _check_raises(jnp.float32, jnp.complex64, msg) + def _broadcastable_shapes(shape): """Returns all shapes that broadcast to `shape`.""" @@ -989,6 +1050,20 @@ def f(rshape): yield list(reversed(x)) +# TODO(jakevdp): move this implementation to jax.dtypes & use in scatter? +def _can_cast(from_, to): + return lax.dtype(to) == dtypes.result_type(from_, to) + + +def _compatible_dtypes(op, dtype, inexact=False): + if op == UpdateOps.ADD: + return [dtype] + elif inexact: + return [dt for dt in float_dtypes if _can_cast(dt, dtype)] + else: + return [dt for dt in all_dtypes if _can_cast(dt, dtype)] + + class UpdateOps(enum.Enum): UPDATE = 0 ADD = 1 @@ -1060,7 +1135,7 @@ class IndexedUpdateTest(jtu.JaxTestCase): for op in s(UpdateOps) for dtype in s(UpdateOps.dtypes(op)) for update_shape in s(_broadcastable_shapes(update_shape)) - for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes) + for update_dtype in s(_compatible_dtypes(op, dtype)) for mode in s(MODES)))) def testStaticIndexing(self, shape, dtype, update_shape, update_dtype, indexer, op, mode): @@ -1083,7 +1158,7 @@ def testStaticIndexing(self, shape, dtype, update_shape, update_dtype, for op in s(UpdateOps) for dtype in s(UpdateOps.dtypes(op)) for update_shape in s(_broadcastable_shapes(update_shape)) - for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) + for update_dtype in s(_compatible_dtypes(op, dtype))))) def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, indexer, op): rng = jtu.rand_default(self.rng()) @@ -1106,7 +1181,7 @@ def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, for op in s(UpdateOps) for dtype in s(UpdateOps.dtypes(op)) for update_shape in s(_broadcastable_shapes(update_shape)) - for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) + for update_dtype in s(_compatible_dtypes(op, dtype))))) def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype, indexer, op): rng = jtu.rand_default(self.rng()) @@ -1130,7 +1205,7 @@ def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype, for op in s(UpdateOps) for dtype in s(UpdateOps.dtypes(op)) for update_shape in s(_broadcastable_shapes(update_shape)) - for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)))) + for update_dtype in s(_compatible_dtypes(op, dtype))))) def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, indexer, op): rng = jtu.rand_default(self.rng()) @@ -1157,7 +1232,7 @@ def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype, for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE] for dtype in float_dtypes for update_shape in _broadcastable_shapes(update_shape) - for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes))) + for update_dtype in _compatible_dtypes(op, dtype, inexact=True))) def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype, indexer, op, mode): rng = jtu.rand_default(self.rng()) @@ -1184,7 +1259,7 @@ def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype, else [UpdateOps.ADD]) for dtype in s(float_dtypes) for update_shape in s(_broadcastable_shapes(update_shape)) - for update_dtype in s([dtype] if op == UpdateOps.ADD else float_dtypes)))) + for update_dtype in s(_compatible_dtypes(op, dtype, inexact=True))))) def testAdvancedIndexingGrads(self, shape, dtype, update_shape, update_dtype, indexer, op, unique_indices): rng = jtu.rand_default(self.rng())