jnp.ufunc: support where argument in ufunc.reduce
jakevdp committed Aug 24, 2023
1 parent d452eea commit ac1233b
Showing 2 changed files with 62 additions and 15 deletions.
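In effect, this commit lets ufunc.reduce accept a NumPy-style where mask. A minimal usage sketch of the new behavior, assuming the jnp.frompyfunc API exercised in the tests below (the lambda and the values are illustrative, not from the commit):

import jax.numpy as jnp

# A binary ufunc with an identity, built the way the tests below build theirs.
add = jnp.frompyfunc(lambda x, y: x + y, nin=2, nout=1, identity=0)

x = jnp.array([1., 2., 3., 4.])
mask = jnp.array([True, False, True, False])

add.reduce(x)              # 10.0: ordinary reduction
add.reduce(x, where=mask)  # 4.0: masked-out elements are skipped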
42 changes: 28 additions & 14 deletions jax/_src/numpy/ufunc_api.py
@@ -83,26 +83,33 @@ def __call__(self, *args, out=None, where=None, **kwargs):
   @_wraps(np.ufunc.reduce, module="numpy.ufunc")
   @partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims'])
   def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False, initial=None, where=None):
     check_arraylike(f"{self.__name__}.reduce", a)
     if self.nin != 2:
       raise ValueError("reduce only supported for binary ufuncs")
     if self.nout != 1:
       raise ValueError("reduce only supported for functions returning a single value")
     if out is not None:
       raise NotImplementedError(f"out argument of {self.__name__}.reduce()")
-    # TODO(jakevdp): implement where.
+    if initial is not None:
+      check_arraylike(f"{self.__name__}.reduce", initial)
     if where is not None:
-      raise NotImplementedError(f"where argument of {self.__name__}.reduce()")
-    return self._reduce_via_scan(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial)
+      check_arraylike(f"{self.__name__}.reduce", where)
+      if self.identity is None and initial is None:
+        raise ValueError(f"reduction operation {self.__name__!r} does not have an identity, "
+                         "so to use a where mask one has to specify 'initial'.")
+      if lax_internal._dtype(where) != bool:
+        raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}")
+    return self._reduce_via_scan(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)

-  def _reduce_via_scan(self, arr, axis=0, dtype=None, keepdims=False, initial=None):
+  def _reduce_via_scan(self, arr, axis=0, dtype=None, keepdims=False, initial=None, where=None):
     assert self.nin == 2 and self.nout == 1
     check_arraylike(f"{self.__name__}.reduce", arr)
     arr = lax_internal.asarray(arr)
     if initial is None:
       initial = self.identity
     if dtype is None:
       dtype = jax.eval_shape(self._func, lax_internal._one(arr), lax_internal._one(arr)).dtype
+
+    if where is not None:
+      where = _broadcast_to(where, arr.shape)
     if isinstance(axis, tuple):
       axis = tuple(canonicalize_axis(a, arr.ndim) for a in axis)
       raise NotImplementedError("tuple of axes")
@@ -112,6 +119,8 @@ def _reduce_via_scan(self, arr, axis=0, dtype=None, keepdims=False, initial=None
       else:
         final_shape = ()
       arr = arr.ravel()
+      if where is not None:
+        where = where.ravel()
       axis = 0
     else:
       axis = canonicalize_axis(axis, arr.ndim)
@@ -123,23 +132,28 @@ def _reduce_via_scan(self, arr, axis=0, dtype=None, keepdims=False, initial=None
     # TODO: handle without transpose?
     if axis != 0:
       arr = _moveaxis(arr, axis, 0)
+      if where is not None:
+        where = _moveaxis(where, axis, 0)

     if initial is None and arr.shape[0] == 0:
       raise ValueError(f"zero-size array to reduction operation {self.__name__} which has no identity")

     def body_fun(i, val):
-      return self._call(val, arr[i].astype(dtype))
+      if where is None:
+        return self._call(val, arr[i].astype(dtype))
+      else:
+        return _where(where[i], self._call(val, arr[i].astype(dtype)), val)

     if initial is None:
-      start = 1
-      initial = arr[0]
+      start_index = 1
+      start_value = arr[0]
     else:
-      check_arraylike(f"{self.__name__}.reduce", arr)
-      start = 0
-
-    initial = _broadcast_to(lax_internal.asarray(initial).astype(dtype), arr.shape[1:])
+      start_index = 0
+      start_value = initial
+    start_value = _broadcast_to(lax_internal.asarray(start_value).astype(dtype), arr.shape[1:])

-    result = jax.lax.fori_loop(start, arr.shape[0], body_fun, initial)
+    result = jax.lax.fori_loop(start_index, arr.shape[0], body_fun, start_value)
     if keepdims:
       result = result.reshape(final_shape)
     return result
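The implementation threads an accumulator through jax.lax.fori_loop; body_fun applies the binary op and, via _where, keeps the previous accumulator value wherever the mask is False. A self-contained sketch of that update rule (masked_reduce is a hypothetical helper, not part of the commit):

import jax
import jax.numpy as jnp

def masked_reduce(binop, arr, mask, init):
  # Mirrors body_fun in _reduce_via_scan: a masked-out element
  # leaves the accumulator unchanged.
  def body_fun(i, val):
    return jnp.where(mask[i], binop(val, arr[i]), val)
  return jax.lax.fori_loop(0, arr.shape[0], body_fun, init)

masked_reduce(jnp.add, jnp.array([1., 2., 3.]),
              jnp.array([True, False, True]), jnp.array(0.))  # 4.0

This is also why a where mask requires an identity or an explicit initial: the accumulator needs a starting value even when every element is masked out.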
35 changes: 34 additions & 1 deletion tests/lax_numpy_ufuncs_test.py
@@ -145,7 +145,40 @@ def test_reduce(self, func, nin, nout, identity, shape, axis, dtype):
     args_maker = lambda: [rng(shape, dtype)]

     self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
-    self._CompileAndCheck(jnp_fun, args_maker, check_cache_misses=False)  # TODO(jakevdp): why the cache misses?
+    self._CompileAndCheck(jnp_fun, args_maker)

+  @jtu.sample_product(
+      SCALAR_FUNCS,
+      [{'shape': shape, 'axis': axis}
+       for shape in nonscalar_shapes
+       for axis in [None, *range(-len(shape), len(shape))]],
+      dtype=jtu.dtypes.floating,
+  )
+  def test_reduce_where(self, func, nin, nout, identity, shape, axis, dtype):
+    if (nin, nout) != (2, 1):
+      self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
+
+    # Need initial if identity is None
+    initial = 1 if identity is None else None
+
+    def jnp_fun(arr, where):
+      return jnp.frompyfunc(func, nin, nout, identity=identity).reduce(
+          arr, where=where, axis=axis, initial=initial)
+
+    @cast_outputs
+    def np_fun(arr, where):
+      # Workaround for https://github.com/numpy/numpy/issues/24530
+      # TODO(jakevdp): remove this when possible.
+      initial_workaround = identity if initial is None else initial
+      return np.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce(
+          arr, where=where, axis=axis, initial=initial_workaround)
+
+    rng = jtu.rand_default(self.rng())
+    rng_where = jtu.rand_bool(self.rng())
+    args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)]
+
+    self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
+    self._CompileAndCheck(jnp_fun, args_maker)

   @jtu.sample_product(
       SCALAR_FUNCS,
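For completeness, the validation added in reduce() rejects a where mask when the ufunc has no identity and no initial is given. A hedged sketch of both paths (the subtraction ufunc and the values are illustrative):

import jax.numpy as jnp

sub = jnp.frompyfunc(lambda x, y: x - y, nin=2, nout=1)  # identity=None

x = jnp.array([5., 1., 2.])
mask = jnp.array([True, True, False])

sub.reduce(x, where=mask, initial=10.0)  # 4.0: computes 10 - 5 - 1, skips masked entry
# sub.reduce(x, where=mask)  # ValueError: reduction operation ... does not have
#                            # an identity, so to use a where mask one has to
#                            # specify 'initial'.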
