Skip to content

Commit

Permalink
Merge pull request #17750 from hawkinsp:mode
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568199918
  • Loading branch information
jax authors committed Sep 25, 2023
2 parents c478282 + 2fd6df4 commit f093b55
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ Remember to align the itemized text with the first line of an item within a list
* `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
* `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
* `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
* Changes:
* {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken
across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy
1.11.

# jaxlib 0.4.17

Expand Down
3 changes: 2 additions & 1 deletion jax/_src/scipy/stats/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def mode(a: ArrayLike, axis: Optional[int] = 0, nan_policy: str = "propagate", k
def _mode_helper(x: jax.Array) -> tuple[jax.Array, jax.Array]:
"""Helper function to return mode and count of a given array."""
if x.size == 0:
return jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)), jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_))
return (jnp.array(jnp.nan, dtype=dtypes.canonicalize_dtype(jnp.float_)),
jnp.array(0, dtype=dtypes.canonicalize_dtype(jnp.float_)))
else:
vals, counts = jnp.unique(x, return_counts=True, size=x.size)
return vals[jnp.argmax(counts)], counts.max()
Expand Down
19 changes: 13 additions & 6 deletions tests/scipy_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,13 +1481,20 @@ def testMode(self, shape, dtype, axis, contains_nans, keepdims):

def scipy_mode_wrapper(a, axis=0, nan_policy='propagate', keepdims=None):
"""Wrapper to manage the shape discrepancies between scipy and jax"""
if scipy_version < (1, 9, 0) and a.size == 0 and keepdims == True:
if axis == None:
output_shape = tuple(1 for _ in a.shape)
if scipy_version < (1, 11, 0) and a.size == 0:
if keepdims:
if axis == None:
output_shape = tuple(1 for _ in a.shape)
else:
output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape))
else:
output_shape = tuple(1 if i == axis else s for i, s in enumerate(a.shape))
return (np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_)),
np.full(output_shape, np.nan, dtype=dtypes.canonicalize_dtype(jax.numpy.float_)))
if axis == None:
output_shape = ()
else:
output_shape = np.delete(np.array(a.shape, dtype=np.int64), axis)
t = dtypes.canonicalize_dtype(jax.numpy.float_)
return (np.full(output_shape, np.nan, dtype=t),
np.zeros(output_shape, dtype=t))

if scipy_version < (1, 9, 0):
result = osp_stats.mode(a, axis=axis, nan_policy=nan_policy)
Expand Down

0 comments on commit f093b55

Please sign in to comment.