Skip to content

Commit

Permalink
Better errors for array scalar/boolean conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Sep 19, 2023
1 parent 3b66fbf commit 0dc2252
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 23 deletions.
14 changes: 9 additions & 5 deletions jax/_src/array.py
Expand Up @@ -256,29 +256,33 @@ def __len__(self):
raise TypeError("len() of unsized object") from err # same as numpy error

def __bool__(self):
return bool(self._value)

def __nonzero__(self):
# deprecated 2023 September 18.
# TODO(jakevdp) change to warn_on_empty=False
core.check_bool_conversion(self, warn_on_empty=True)
return bool(self._value)

def __float__(self):
core.check_scalar_conversion(self)
return self._value.__float__()

def __int__(self):
core.check_scalar_conversion(self)
return self._value.__int__()

def __complex__(self):
core.check_scalar_conversion(self)
return self._value.__complex__()

def __hex__(self):
assert self.ndim == 0, 'hex only works on scalar values'
core.check_integer_conversion(self)
return hex(self._value) # type: ignore

def __oct__(self):
assert self.ndim == 0, 'oct only works on scalar values'
core.check_integer_conversion(self)
return oct(self._value) # type: ignore

def __index__(self):
core.check_integer_conversion(self)
return op.index(self._value)

def tobytes(self, order="C"):
Expand Down
87 changes: 70 additions & 17 deletions jax/_src/core.py
Expand Up @@ -598,6 +598,36 @@ def escaped_tracer_error(tracer, detail=None):
return UnexpectedTracerError(msg)


def check_scalar_conversion(arr: Array):
if arr.size != 1:
raise TypeError("Only length-1 arrays can be converted to Python scalars.")
if arr.shape != ():
# Added 2023 September 18.
warnings.warn("Conversion of an array with ndim > 0 to a scalar is deprecated, "
"and will error in future.", DeprecationWarning, stacklevel=3)


def check_integer_conversion(arr: Array):
if not (arr.shape == () and dtypes.issubdtype(arr.dtype, np.integer)):
raise TypeError("Only integer scalar arrays can be converted to a scalar index.")


def check_bool_conversion(arr: Array, warn_on_empty=False):
if arr.size == 0:
if warn_on_empty:
warnings.warn(
"The truth value of an empty array is ambiguous. Returning False. In the future this "
"will result in an error. Use `array.size > 0` to check that an array is not empty.",
DeprecationWarning, stacklevel=3)
else:
raise ValueError("The truth value of an empty array is ambiguous. Use "
"`array.size > 0` to check that an array is not empty.")
if arr.size > 1:
raise ValueError("The truth value of an array with more than one element is "
"ambiguous. Use a.any() or a.all()")



class Tracer(typing.Array):
__array_priority__ = 1000
__slots__ = ['_trace', '_line_info']
Expand All @@ -615,9 +645,6 @@ def __dlpack__(self, *args, **kw):
f"The __dlpack__() method was called on {self._error_repr()}."
f"{self._origin_msg()}")

def __index__(self):
raise TracerIntegerConversionError(self)

def tolist(self):
raise ConcretizationTypeError(self,
f"The tolist() method was called on {self._error_repr()}."
Expand Down Expand Up @@ -670,12 +697,33 @@ def _assert_live(self) -> None:
def get_referent(self) -> Any:
return self # Override for object equivalence checking

def __bool__(self): return self.aval._bool(self)
def __int__(self): return self.aval._int(self)
def __hex__(self): return self.aval._hex(self)
def __oct__(self): return self.aval._oct(self)
def __float__(self): return self.aval._float(self)
def __complex__(self): return self.aval._complex(self)
def __bool__(self):
check_bool_conversion(self)
return self.aval._bool(self)

def __int__(self):
check_scalar_conversion(self)
return self.aval._int(self)

def __float__(self):
check_scalar_conversion(self)
return self.aval._float(self)

def __complex__(self):
check_scalar_conversion(self)
return self.aval._complex(self)

def __hex__(self):
check_integer_conversion(self)
return self.aval._hex(self)

def __oct__(self):
check_integer_conversion(self)
return self.aval._oct(self)

def __index__(self):
check_integer_conversion(self)
raise self.aval._index(self)

# raises a useful error on attempts to pickle a Tracer.
def __reduce__(self):
Expand Down Expand Up @@ -1394,6 +1442,9 @@ def concretization_function_error(fun, suggest_astype=False):
if fun is bool:
def error(self, arg):
raise TracerBoolConversionError(arg)
elif fun in (hex, oct, operator.index):
def error(self, arg):
raise TracerIntegerConversionError(arg)
else:
def error(self, arg):
raise ConcretizationTypeError(arg, fname_context)
Expand Down Expand Up @@ -1495,12 +1546,13 @@ def __repr__(self):
return '{}({}{})'.format(self.__class__.__name__, self.str_short(),
", weak_type=True" if self.weak_type else "")

_bool = _nonzero = concretization_function_error(bool)
_float = concretization_function_error(float, True)
_bool = concretization_function_error(bool)
_int = concretization_function_error(int, True)
_float = concretization_function_error(float, True)
_complex = concretization_function_error(complex, True)
_hex = concretization_function_error(hex)
_oct = concretization_function_error(oct)
_index = concretization_function_error(operator.index)

def at_least_vspace(self) -> AbstractValue:
return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype),
Expand Down Expand Up @@ -1659,13 +1711,14 @@ def str_short(self, short_dtypes=False) -> str:
dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name
return f'{self.val}, dtype={dt_str}'

_bool = _nonzero = partialmethod(_forward_to_value, bool)
_int = partialmethod(_forward_to_value, int)
_hex = partialmethod(_forward_to_value, hex)
_oct = partialmethod(_forward_to_value, oct)
_bool = partialmethod(_forward_to_value, bool)
_int = partialmethod(_forward_to_value, int)
_hex = partialmethod(_forward_to_value, hex)
_oct = partialmethod(_forward_to_value, oct)
_index = partialmethod(_forward_to_value, operator.index)

_float = concretization_function_error(float, True)
_complex = concretization_function_error(complex, True)
_float = concretization_function_error(float, True)
_complex = concretization_function_error(complex, True)

def primal_dtype_to_tangent_dtype(primal_dtype):
# TODO(frostig,mattjj): determines that all extended dtypes have
Expand Down
47 changes: 46 additions & 1 deletion tests/api_test.py
Expand Up @@ -66,7 +66,8 @@
import jax.custom_batching
import jax.custom_derivatives
import jax.custom_transpose
from jax.errors import UnexpectedTracerError
from jax.errors import (UnexpectedTracerError, TracerIntegerConversionError,
ConcretizationTypeError, TracerBoolConversionError)
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import batching
Expand Down Expand Up @@ -4406,6 +4407,50 @@ def test_jvp_asarray_returns_array(self):
_check_instance(self, p)
_check_instance(self, t)

def test_scalar_conversion_errors(self):
array_int = jnp.arange(10, dtype=int)
scalar_float = jnp.float32(0)
scalar_int = jnp.int32(0)
array1_float = jnp.arange(1, dtype='float32')

assertIntError = partial(self.assertRaisesRegex, TypeError,
"Only integer scalar arrays can be converted to a scalar index.")
for func in [operator.index, hex, oct]:
assertIntError(func, array_int)
assertIntError(func, scalar_float)
assertIntError(jax.jit(func), array_int)
assertIntError(jax.jit(func), scalar_float)
self.assertRaises(TracerIntegerConversionError, jax.jit(func), scalar_int)
_ = func(scalar_int) # no error

assertScalarError = partial(self.assertRaisesRegex, TypeError,
"Only length-1 arrays can be converted to Python scalars.")
for func in [int, float, complex]:
assertScalarError(func, array_int)
assertScalarError(jax.jit(func), array_int)
self.assertRaises(ConcretizationTypeError, jax.jit(func), scalar_int)
_ = func(scalar_int) # no error
# TODO(jakevdp): remove this ignore warning when possible
with jtu.ignore_warning(category=DeprecationWarning):
self.assertRaises(ConcretizationTypeError, jax.jit(func), array1_float)
_ = func(array1_float) # no error

# TODO(jakevdp): add these tests once these deprecated operations error.
# empty_int = jnp.arange(0, dtype='int32')
# assertEmptyBoolError = partial(
# self.assertRaisesRegex, ValueError,
# "The truth value of an empty array is ambiguous.")
# assertEmptyBoolError(bool, empty_int)
# assertEmptyBoolError(jax.jit(bool), empty_int)

assertBoolError = partial(
self.assertRaisesRegex, ValueError,
"The truth value of an array with more than one element is ambiguous.")
assertBoolError(bool, array_int)
assertBoolError(jax.jit(bool), array_int)
self.assertRaises(TracerBoolConversionError, jax.jit(bool), scalar_int)
_ = bool(scalar_int) # no error


class RematTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 0dc2252

Please sign in to comment.