[x64] add promote_integers argument to jnp.prod & jnp.sum
jakevdp committed Sep 26, 2022
1 parent e034432 commit 1860f6d
Showing 2 changed files with 90 additions and 19 deletions.
64 changes: 45 additions & 19 deletions jax/_src/numpy/reductions.py
@@ -26,7 +26,7 @@
from jax._src import api
from jax._src import dtypes
from jax._src.numpy.ndarray import ndarray
-from jax._src.numpy.util import _broadcast_to, _check_arraylike, _complex_elem_type, _promote_dtypes_inexact, _where, _wraps
+from jax._src.numpy.util import _broadcast_to, _check_arraylike, _complex_elem_type, _promote_dtypes_inexact, _promote_dtypes_numeric, _where, _wraps
from jax._src.lax import lax as lax_internal
from jax._src.util import canonicalize_axis as _canonicalize_axis, maybe_named_axis

@@ -62,7 +62,7 @@ def _upcast_f16(dtype):
def _reduction(a, name, np_fun, op, init_val, has_identity=True,
               preproc=None, bool_op=None, upcast_f16_for_computation=False,
               axis=None, dtype=None, out=None, keepdims=False, initial=None,
-               where_=None, parallel_reduce=None):
+               where_=None, parallel_reduce=None, promote_integers=False):
  bool_op = bool_op or op
  # Note: we must accept out=None as an argument, because numpy reductions delegate to
  # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
@@ -86,7 +86,18 @@ def _reduction(a, name, np_fun, op, init_val, has_identity=True,
  if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
    raise ValueError(f"zero-size array to reduction operation {name} which has no identity")

-  result_dtype = dtypes.canonicalize_dtype(dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a)))))
+  result_dtype = dtypes.canonicalize_dtype(dtype or dtypes.dtype(a))
+
+  # promote_integers=True matches NumPy's behavior for sum() and prod(), which promotes
+  # all int-like inputs to the widest available dtype.
+  if dtype is None and promote_integers:
+    if dtypes.issubdtype(result_dtype, np.bool_):
+      result_dtype = dtypes.canonicalize_dtype(np.int64)
+    elif dtypes.issubdtype(result_dtype, np.unsignedinteger):
+      result_dtype = dtypes.canonicalize_dtype(np.uint64)
+    elif dtypes.issubdtype(result_dtype, np.integer):
+      result_dtype = dtypes.canonicalize_dtype(np.int64)

  if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact):
    computation_dtype = _upcast_f16(result_dtype)
  else:
@@ -146,6 +157,9 @@ def _cast_to_bool(operand):
    warnings.filterwarnings("ignore", category=np.ComplexWarning)
    return lax.convert_element_type(operand, np.bool_)

+def _cast_to_numeric(operand):
+  return _promote_dtypes_numeric(operand)[0]


def _ensure_optional_axes(x):
  def force(x):
@@ -159,34 +173,46 @@ def force(x):
    force, x, "The axis argument must be known statically.")


-@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
-def _reduce_sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
-                dtype=None, out=None, keepdims=None, initial=None, where=None):
-  return _reduction(a, "sum", np.sum, lax.add, 0,
+# TODO(jakevdp) change promote_integers default to False
+_PROMOTE_INTEGERS_DOC = """
+promote_integers : bool, default=True
+  If True, then integer inputs will be promoted to the widest available integer
+  dtype, following numpy's behavior. If False, the result will have the same dtype
+  as the input. ``promote_integers`` is ignored if ``dtype`` is specified.
+"""
+
+
+@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
+def _reduce_sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
+                out=None, keepdims=None, initial=None, where=None, promote_integers=True):
+  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
                     bool_op=lax.bitwise_or, upcast_f16_for_computation=True,
                     axis=axis, dtype=dtype, out=out, keepdims=keepdims,
-                    initial=initial, where_=where, parallel_reduce=lax.psum)
+                    initial=initial, where_=where, parallel_reduce=lax.psum,
+                    promote_integers=promote_integers)

-@_wraps(np.sum, skip_params=['out'])
+@_wraps(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
-        out=None, keepdims=None, initial=None, where=None):
+        out=None, keepdims=None, initial=None, where=None, promote_integers=True):
  return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
-                     keepdims=keepdims, initial=initial, where=where)
+                     keepdims=keepdims, initial=initial, where=where,
+                     promote_integers=promote_integers)


-@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
-def _reduce_prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
-                 dtype=None, out=None, keepdims=None, initial=None, where=None):
-  return _reduction(a, "prod", np.prod, lax.mul, 1,
+@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
+def _reduce_prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
+                 out=None, keepdims=None, initial=None, where=None, promote_integers=True):
+  return _reduction(a, "prod", np.prod, lax.mul, 1, preproc=_cast_to_numeric,
                     bool_op=lax.bitwise_and, upcast_f16_for_computation=True,
                     axis=axis, dtype=dtype, out=out, keepdims=keepdims,
-                    initial=initial, where_=where)
+                    initial=initial, where_=where, promote_integers=promote_integers)

-@_wraps(np.prod, skip_params=['out'])
+@_wraps(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC)
def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
-         out=None, keepdims=None, initial=None, where=None):
+         out=None, keepdims=None, initial=None, where=None, promote_integers=True):
  return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype,
-                     out=out, keepdims=keepdims, initial=initial, where=where)
+                     out=out, keepdims=keepdims, initial=initial, where=where,
+                     promote_integers=promote_integers)


@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
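As a quick illustration of the behavior this change introduces (a sketch, not part of the diff; output dtypes assume the default configuration, where jax_enable_x64 is off and int64/uint64 canonicalize to int32/uint32):

import jax.numpy as jnp

x = jnp.arange(10, dtype=jnp.int8)

jnp.sum(x).dtype                           # int32: int-like inputs promoted to the widest canonical dtype
jnp.sum(x, promote_integers=False).dtype   # int8: result keeps the input dtype
jnp.sum(x, dtype=jnp.int16).dtype          # int16: an explicit dtype wins; promote_integers is ignored
jnp.sum(jnp.array([True, False])).dtype    # int32: bool counts as int-like and is promoted

The default promote_integers=True preserves NumPy-compatible results; per the TODO in the diff above, the default may later change to False.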
45 changes: 45 additions & 0 deletions tests/lax_numpy_test.py
@@ -407,6 +407,11 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
op_record("ptp", 1, number_dtypes, nonempty_shapes, jtu.rand_default, []),
]

JAX_REDUCER_PROMOTE_INT_RECORDS = [
op_record("prod", 1, all_dtypes, all_shapes, jtu.rand_small_positive, []),
op_record("sum", 1, all_dtypes, all_shapes, jtu.rand_default, []),
]

JAX_ARGMINMAX_RECORDS = [
    op_record("argmin", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
    op_record("argmax", 1, default_dtypes, nonempty_shapes, jtu.rand_some_equal, []),
@@ -941,6 +946,46 @@ def np_fun(x):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
    self._CompileAndCheck(jnp_fun, args_maker)

+  @parameterized.named_parameters(itertools.chain.from_iterable(
+      jtu.cases_from_list(
+        {"testcase_name": "{}_inshape={}_axis={}_keepdims={}_initial={}_promote_integers={}".format(
+            rec.test_name.capitalize(),
+            jtu.format_shape_dtype_string(shape, dtype), axis, keepdims, initial, promote_integers),
+         "rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
+         "np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name),
+         "initial": initial, "axis": axis, "keepdims": keepdims, "inexact": rec.inexact,
+         "promote_integers": promote_integers}
+        for shape in rec.shapes for dtype in rec.dtypes
+        for axis in list(range(-len(shape), len(shape))) + [None]
+        for initial in [0, 1] for keepdims in [False, True]
+        for promote_integers in [True, False]
+        if jtu.is_valid_shape(shape, dtype))
+      for rec in JAX_REDUCER_PROMOTE_INT_RECORDS))
+  def testReducerPromoteInt(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
+                            keepdims, initial, inexact, promote_integers):
+    rng = rng_factory(self.rng())
+    is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan'
+    @jtu.ignore_warning(category=RuntimeWarning,
+                        message="Degrees of freedom <= 0 for slice.*")
+    @jtu.ignore_warning(category=np.ComplexWarning)
+    def np_fun(x):
+      x = np.asarray(x)
+      if inexact:
+        x = x.astype(dtypes.to_inexact_dtype(x.dtype))
+      x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
+      res = np_op(x_cast, axis, keepdims=keepdims, initial=initial)
+      res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
+      if not promote_integers and dtypes.issubdtype(res.dtype, np.integer):
+        res = res.astype(dtypes.to_numeric_dtype(x.dtype))
+      return res
+
+    jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, initial=initial, promote_integers=promote_integers)
+    jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
+    args_maker = lambda: [rng(shape, dtype)]
+    tol = {jnp.bfloat16: 3E-2}
+    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, rtol=tol)
+    self._CompileAndCheck(jnp_fun, args_maker)

  @parameterized.named_parameters(itertools.chain.from_iterable(
      jtu.cases_from_list(
        {"testcase_name": "{}_inshape={}_axis={}_keepdims={}".format(
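For reference, the expectation logic at the heart of testReducerPromoteInt, as a minimal standalone sketch (an illustration, not part of the diff): NumPy always accumulates integer sums and products in a wide dtype, so when promote_integers=False the NumPy reference result must be downcast back to the input's numeric dtype before comparing against JAX.

import numpy as np

def expected_result(np_op, x, promote_integers):
  # NumPy promotes int-like inputs unconditionally, e.g. np.sum on int8 yields a wide int.
  res = np_op(x)
  if not promote_integers and np.issubdtype(res.dtype, np.integer):
    # With promotion disabled, JAX keeps the input's numeric dtype, so downcast
    # the reference to match. (Bool inputs map to the default int dtype; the real
    # test handles that via dtypes.to_numeric_dtype.)
    res = res.astype(x.dtype)
  return res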
