Skip to content

Commit

Permalink
jnp.piecewise: support scalar inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 9, 2021
1 parent b0d14fd commit dbdb189
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
4 changes: 2 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -5141,12 +5141,12 @@ def piecewise(x, condlist, funclist, *args, **kw):
funclist = [0] + list(funclist)
else:
raise ValueError(f"with {nc} condition(s), either {nc} or {nc+1} functions are expected; got {nf}")
indices = argmax(cumsum(vstack([zeros_like(condlist[:1]), condlist]), 0), 0)
indices = argmax(cumsum(concatenate([zeros_like(condlist[:1]), condlist], 0), 0), 0)
dtype = _dtype(x)
def _call(f):
return lambda x: f(x, *args, **kw).astype(dtype)
def _const(v):
return lambda x: full_like(x, v)
return lambda x: array(v, dtype=dtype)
funclist = [_call(f) if callable(f) else _const(f) for f in funclist]
return vectorize(lax.switch, excluded=(1,))(indices, funclist, x)

Expand Down
11 changes: 6 additions & 5 deletions jax/test_util.py
Expand Up @@ -875,7 +875,7 @@ def assertMultiLineStrippedEqual(self, expected, what):
msg="Found\n{}\nExpecting\n{}".format(what, expected))

def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True,
rtol=None, atol=None):
rtol=None, atol=None, check_cache_misses=True):
"""Helper method for running JAX compilation and allclose assertions."""
args = args_maker()

Expand All @@ -892,10 +892,11 @@ def wrapped_fun(*args):

cache_misses = xla.xla_primitive_callable.cache_info().misses
python_ans = fun(*args)
self.assertEqual(
cache_misses, xla.xla_primitive_callable.cache_info().misses,
"Compilation detected during second call of {} in op-by-op "
"mode.".format(fun))
if check_cache_misses:
self.assertEqual(
cache_misses, xla.xla_primitive_callable.cache_info().misses,
"Compilation detected during second call of {} in op-by-op "
"mode.".format(fun))

cfun = api.jit(wrapped_fun)
python_should_be_executing = True
Expand Down
7 changes: 4 additions & 3 deletions tests/lax_numpy_test.py
Expand Up @@ -1647,17 +1647,18 @@ def testExtract(self, shape, dtype):
"shape": shape, "dtype": dtype, "ncond": ncond, "nfunc": nfunc}
for ncond in [1, 2, 3]
for nfunc in [ncond, ncond + 1]
for shape in nonempty_nonscalar_array_shapes
for shape in all_shapes
for dtype in all_dtypes))
def testPiecewise(self, shape, dtype, ncond, nfunc):
rng = jtu.rand_default(self.rng())
rng_bool = jtu.rand_int(self.rng(), 0, 2)
funclist = [lambda x: x - 1, 1, lambda x: x, 0][:nfunc]
args_maker = lambda: (rng(shape, dtype), list(rng_bool((ncond,) + shape, bool)))
args_maker = lambda: (rng(shape, dtype), [rng_bool(shape, bool) for i in range(ncond)])
np_fun = partial(np.piecewise, funclist=funclist)
jnp_fun = partial(jnp.piecewise, funclist=funclist)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True)
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True)
# This is a higher-order function, so the cache miss check will fail.
self._CompileAndCheck(jnp_fun, args_maker, check_dtypes=True, check_cache_misses=False)


@parameterized.named_parameters(jtu.cases_from_list(
Expand Down

0 comments on commit dbdb189

Please sign in to comment.