Skip to content

Commit

Permalink
Added 'where' keyword to 'jnp.{mean, var, std}'
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsokol committed Mar 12, 2021
1 parent 2fc2ff4 commit d743aa5
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 17 deletions.
41 changes: 24 additions & 17 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2013,15 +2013,15 @@ def min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,

@_wraps(np.all)
def all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None):
keepdims=None, *, where=None):
return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims)
axis=axis, out=out, keepdims=keepdims, where_=where)

@_wraps(np.any)
def any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
keepdims=None):
keepdims=None, *, where=None):
return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool,
axis=axis, out=out, keepdims=keepdims)
axis=axis, out=out, keepdims=keepdims, where_=where)

product = prod
amin = min
Expand All @@ -2040,16 +2040,20 @@ def _axis_size(a, axis):

@_wraps(np.mean)
def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, keepdims=False):
out=None, keepdims=False, *, where=None):
_check_arraylike("mean", a)
lax._check_user_dtype_supported(dtype, "mean")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")

if axis is None:
normalizer = size(a)
if where is None:
if axis is None:
normalizer = size(a)
else:
normalizer = _axis_size(a, axis)
else:
normalizer = _axis_size(a, axis)
normalizer = sum(broadcast_to(where, a.shape), axis, dtype=dtype, keepdims=keepdims)

if dtype is None:
if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
dtype = float_
Expand All @@ -2058,7 +2062,7 @@ def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
dtype = dtypes.canonicalize_dtype(dtype)

return lax.div(
sum(a, axis, dtype=dtype, keepdims=keepdims),
sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
lax.convert_element_type(normalizer, dtype))

@_wraps(np.average)
Expand Down Expand Up @@ -2113,27 +2117,30 @@ def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,

@_wraps(np.var)
def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, ddof=0, keepdims=False):
out=None, ddof=0, keepdims=False, *, where=None):
_check_arraylike("var", a)
lax._check_user_dtype_supported(dtype, "var")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.var is not supported.")

a_dtype, dtype = _var_promote_types(_dtype(a), dtype)
a_mean = mean(a, axis, dtype=a_dtype, keepdims=True)
a_mean = mean(a, axis, dtype=a_dtype, keepdims=True, where=where)
centered = a - a_mean
if issubdtype(centered.dtype, complexfloating):
centered = lax.real(lax.mul(centered, lax.conj(centered)))
else:
centered = lax.square(centered)

if axis is None:
normalizer = size(a)
if where is None:
if axis is None:
normalizer = size(a)
else:
normalizer = _axis_size(a, axis)
else:
normalizer = _axis_size(a, axis)
normalizer = sum(broadcast_to(where, a.shape), axis, dtype=dtype, keepdims=keepdims)
normalizer = normalizer - ddof

result = sum(centered, axis, keepdims=keepdims)
result = sum(centered, axis, keepdims=keepdims, where=where)
out = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
return lax.convert_element_type(out, dtype)

Expand All @@ -2160,12 +2167,12 @@ def _var_promote_types(a_dtype, dtype):

@_wraps(np.std)
def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
out=None, ddof=0, keepdims=False):
out=None, ddof=0, keepdims=False, *, where=None):
_check_arraylike("std", a)
lax._check_user_dtype_supported(dtype, "std")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
return sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims))
return sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))


@_wraps(np.ptp)
Expand Down
53 changes: 53 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,17 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
op_record("min", 1, all_dtypes, all_shapes, jtu.rand_default, []),
]

JAX_REDUCER_WHERE_NO_INITIAL_RECORDS = [
op_record("all", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []),
op_record("any", 1, bool_dtypes, all_shapes, jtu.rand_some_zero, []),
op_record("mean", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
op_record("var", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
op_record("std", 1, all_dtypes, nonempty_shapes, jtu.rand_default, [],
inexact=True),
]

JAX_REDUCER_NO_DTYPE_RECORDS = [
op_record("all", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
op_record("any", 1, all_dtypes, all_shapes, jtu.rand_some_zero, []),
Expand Down Expand Up @@ -853,6 +864,48 @@ def np_fun(x):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@unittest.skipIf(numpy_version < (1, 20), "where parameter not supported in older numpy")
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": "{}_inshape={}_axis={}_keepdims={}_whereshape={}".format(
rec.test_name.capitalize(),
jtu.format_shape_dtype_string(shape, dtype), axis, keepdims,
jtu.format_shape_dtype_string(whereshape, bool)),
"rng_factory": rec.rng_factory, "shape": shape, "dtype": dtype,
"np_op": getattr(np, rec.name), "jnp_op": getattr(jnp, rec.name), "whereshape": whereshape,
"axis": axis, "keepdims": keepdims, "inexact": rec.inexact}
for shape in rec.shapes for dtype in rec.dtypes
for whereshape in _compatible_shapes(shape)
for axis in list(range(-len(shape), len(shape))) + [None]
for keepdims in [False, True])
for rec in JAX_REDUCER_WHERE_NO_INITIAL_RECORDS))
def testReducerWhereNoInitial(self, np_op, jnp_op, rng_factory, shape, dtype, axis,
keepdims, inexact, whereshape):
rng = rng_factory(self.rng())
is_bf16_nan_test = dtype == jnp.bfloat16
# Do not pass where via args_maker as that is incompatible with _promote_like_jnp.
where = jtu.rand_bool(self.rng())(whereshape, np.bool_)
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.*")
@jtu.ignore_warning(category=RuntimeWarning,
message="Mean of empty slice.*")
@jtu.ignore_warning(category=RuntimeWarning,
message="invalid value encountered in true_divide*")
def np_fun(x):
x_cast = x if not is_bf16_nan_test else x.astype(np.float32)
res = np_op(x_cast, axis, keepdims=keepdims, where=where)
res = res if not is_bf16_nan_test else res.astype(jnp.bfloat16)
return res

np_fun = _promote_like_jnp(np_fun, inexact)
np_fun = jtu.ignore_warning(category=np.ComplexWarning)(np_fun)
jnp_fun = lambda x: jnp_op(x, axis, keepdims=keepdims, where=where)
jnp_fun = jtu.ignore_warning(category=jnp.ComplexWarning)(jnp_fun)
args_maker = lambda: [rng(shape, dtype)]
if numpy_version >= (1, 20, 2) or np_op.__name__ in ("all", "any"):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis),
Expand Down

0 comments on commit d743aa5

Please sign in to comment.