Skip to content

Commit

Permalink
[x64] deprecate unsafe type casting in scatter-update operations
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 9, 2022
1 parent ca01d1b commit d2f80ef
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 11 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Expand Up @@ -45,6 +45,11 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
traces as an alternative to the Tensorboard UI.
* Added a `jax.named_scope` context manager that adds profiler metadata to
Python programs (similar to `jax.named_call`).
* In scatter-update operations (i.e. :attr:`jax.numpy.ndarray.at`), unsafe implicit
dtype casts are deprecated, and now result in a `FutureWarning`.
In a future release, this will become an error. An example of an unsafe implicit
cast is `jnp.zeros(4, dtype=int).at[0].set(1.5)`, in which `1.5` previously was
silently truncated to `1`.

## jaxlib 0.3.11 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...main).
Expand Down
8 changes: 8 additions & 0 deletions jax/_src/ops/scatter.py
Expand Up @@ -16,6 +16,7 @@

import sys
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import warnings

import numpy as np

Expand Down Expand Up @@ -81,6 +82,13 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
dtype = lax.dtype(x)
weak_type = dtypes.is_weakly_typed(x)

if dtype != dtypes.result_type(x, y):
# TODO(jakevdp): change this to an error after the deprecation period.
warnings.warn("scatter inputs have incompatible types: cannot safely cast "
f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
"In future JAX releases this will result in an error.",
FutureWarning)

idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = jnp._index_to_gather(jnp.shape(x), idx,
normalize_indices=normalize_indices)
Expand Down
7 changes: 4 additions & 3 deletions jax/_src/scipy/sparse/linalg.py
Expand Up @@ -350,9 +350,10 @@ def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):

# This assumes that Q's leaves all have the same dimension in the last
# axis.
r = jnp.zeros(tree_leaves(Q)[0].shape[-1])
Q0 = tree_leaves(Q)[0]
r = jnp.zeros(Q0.shape[-1], dtype=Q0.dtype)
q = x
xnorm_scaled = xnorm / jnp.sqrt(2)
xnorm_scaled = xnorm / jnp.sqrt(2.0)

def body_function(carry):
k, q, r, qnorm_scaled = carry
Expand All @@ -368,7 +369,7 @@ def qnorm_cond(carry):
def qnorm(carry):
k, _, q, qnorm_scaled = carry
_, qnorm = _safe_normalize(q)
qnorm_scaled = qnorm / jnp.sqrt(2)
qnorm_scaled = qnorm / jnp.sqrt(2.0)
return (k, False, q, qnorm_scaled)

init = (k, True, q, qnorm_scaled)
Expand Down
2 changes: 1 addition & 1 deletion tests/checkify_test.py
Expand Up @@ -77,7 +77,7 @@ def f(x, i):
"max", "get"])
def test_jit_oob_update(self, update_fn):
def f(x, i):
return getattr(x.at[i], update_fn)(1.)
return getattr(x.at[i], update_fn)(1)

f = jax.jit(f)
checked_f = checkify.checkify(f, errors=checkify.index_checks)
Expand Down
89 changes: 82 additions & 7 deletions tests/lax_numpy_indexing_test.py
Expand Up @@ -960,7 +960,7 @@ def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245
jnp.array([7, 7, 1, 2, 1, 4, 5, 7, 7, 7], jnp.int32))

def testIndexingWeakTypes(self):
x = lax_internal._convert_element_type(jnp.arange(5), int, weak_type=True)
x = lax_internal._convert_element_type(jnp.arange(5), float, weak_type=True)

a = x.at[0].set(1.0)
self.assertEqual(a.dtype, x.dtype)
Expand All @@ -974,6 +974,67 @@ def testIndexingWeakTypes(self):
self.assertEqual(c.dtype, x.dtype)
self.assertTrue(dtypes.is_weakly_typed(c))

def testIndexingTypePromotion(self):
def _check(x_type, y_type):
x = jnp.arange(5, dtype=x_type)
y = y_type(0)
out = x.at[0].set(y)
self.assertEqual(x.dtype, out.dtype)

@jtu.ignore_warning(category=np.ComplexWarning,
message="Casting complex values to real")
def _check_warns(x_type, y_type, msg):
with self.assertWarnsRegex(FutureWarning, msg):
_check(x_type, y_type)

def _check_raises(x_type, y_type, msg):
with self.assertRaisesRegex(ValueError, msg):
_check(x_type, y_type)

# Matching dtypes are always OK
_check(jnp.int32, jnp.int32)
_check(jnp.float32, jnp.float32)
_check(jnp.complex64, jnp.complex64)

# Weakly-typed y values promote.
_check(jnp.int32, int)
_check(jnp.float32, int)
_check(jnp.float32, float)
_check(jnp.complex64, int)
_check(jnp.complex64, float)
_check(jnp.complex64, complex)

# in standard promotion mode, strong types can promote.
msg = "scatter inputs have incompatible types"
with jax.numpy_dtype_promotion('standard'):
_check(jnp.int32, jnp.int16)
_check(jnp.float32, jnp.float16)
_check(jnp.float32, jnp.int32)
_check(jnp.complex64, jnp.int32)
_check(jnp.complex64, jnp.float32)

# TODO(jakevdp): make these _check_raises
_check_warns(jnp.int16, jnp.int32, msg)
_check_warns(jnp.int32, jnp.float32, msg)
_check_warns(jnp.int32, jnp.complex64, msg)
_check_warns(jnp.float16, jnp.float32, msg)
_check_warns(jnp.float32, jnp.complex64, msg)

# in strict promotion mode, strong types do not promote.
msg = "Input dtypes .* have no available implicit dtype promotion path"
with jax.numpy_dtype_promotion('strict'):
_check_raises(jnp.int32, jnp.int16, msg)
_check_raises(jnp.float32, jnp.float16, msg)
_check_raises(jnp.float32, jnp.int32, msg)
_check_raises(jnp.complex64, jnp.int32, msg)
_check_raises(jnp.complex64, jnp.float32, msg)

_check_raises(jnp.int16, jnp.int32, msg)
_check_raises(jnp.int32, jnp.float32, msg)
_check_raises(jnp.int32, jnp.complex64, msg)
_check_raises(jnp.float16, jnp.float32, msg)
_check_raises(jnp.float32, jnp.complex64, msg)


def _broadcastable_shapes(shape):
"""Returns all shapes that broadcast to `shape`."""
Expand All @@ -989,6 +1050,20 @@ def f(rshape):
yield list(reversed(x))


# TODO(jakevdp): move this implementation to jax.dtypes & use in scatter?
def _can_cast(from_, to):
return lax.dtype(to) == dtypes.result_type(from_, to)


def _compatible_dtypes(op, dtype, inexact=False):
if op == UpdateOps.ADD:
return [dtype]
elif inexact:
return [dt for dt in float_dtypes if _can_cast(dt, dtype)]
else:
return [dt for dt in all_dtypes if _can_cast(dt, dtype)]


class UpdateOps(enum.Enum):
UPDATE = 0
ADD = 1
Expand Down Expand Up @@ -1060,7 +1135,7 @@ class IndexedUpdateTest(jtu.JaxTestCase):
for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes)
for update_dtype in s(_compatible_dtypes(op, dtype))
for mode in s(MODES))))
def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
indexer, op, mode):
Expand All @@ -1083,7 +1158,7 @@ def testStaticIndexing(self, shape, dtype, update_shape, update_dtype,
for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
for update_dtype in s(_compatible_dtypes(op, dtype)))))
def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
indexer, op):
rng = jtu.rand_default(self.rng())
Expand All @@ -1106,7 +1181,7 @@ def testAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
for update_dtype in s(_compatible_dtypes(op, dtype)))))
def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype,
indexer, op):
rng = jtu.rand_default(self.rng())
Expand All @@ -1130,7 +1205,7 @@ def testAdvancedIndexingSorted(self, shape, dtype, update_shape, update_dtype,
for op in s(UpdateOps)
for dtype in s(UpdateOps.dtypes(op))
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else all_dtypes))))
for update_dtype in s(_compatible_dtypes(op, dtype)))))
def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
indexer, op):
rng = jtu.rand_default(self.rng())
Expand All @@ -1157,7 +1232,7 @@ def testMixedAdvancedIndexing(self, shape, dtype, update_shape, update_dtype,
for op in [UpdateOps.ADD, UpdateOps.MUL, UpdateOps.UPDATE]
for dtype in float_dtypes
for update_shape in _broadcastable_shapes(update_shape)
for update_dtype in ([dtype] if op == UpdateOps.ADD else float_dtypes)))
for update_dtype in _compatible_dtypes(op, dtype, inexact=True)))
def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
indexer, op, mode):
rng = jtu.rand_default(self.rng())
Expand All @@ -1184,7 +1259,7 @@ def testStaticIndexingGrads(self, shape, dtype, update_shape, update_dtype,
else [UpdateOps.ADD])
for dtype in s(float_dtypes)
for update_shape in s(_broadcastable_shapes(update_shape))
for update_dtype in s([dtype] if op == UpdateOps.ADD else float_dtypes))))
for update_dtype in s(_compatible_dtypes(op, dtype, inexact=True)))))
def testAdvancedIndexingGrads(self, shape, dtype, update_shape, update_dtype,
indexer, op, unique_indices):
rng = jtu.rand_default(self.rng())
Expand Down

0 comments on commit d2f80ef

Please sign in to comment.