Skip to content

Commit

Permalink
Merge pull request #22257 from dfm:fix-numpy-nightly-test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 649108684
  • Loading branch information
jax authors committed Jul 3, 2024
2 parents e6ebd55 + 9e9acc9 commit 82f4b5a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,6 @@ def test_f16_mean(self, dtype):
actual = jnp.mean(x)
self.assertAllClose(expected, actual, atol=0)


@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in all_shapes
Expand Down Expand Up @@ -815,10 +814,14 @@ def np_mock_op(x, axis=None, dtype=None, include_initial=False):
out = jnp.concat([jnp.zeros(zeros_shape, dtype=out.dtype), out], axis=axis)
return out


# We currently "cheat" to ensure we have JAX arrays, not NumPy arrays as
# input because we rely on JAX-specific casting behavior
args_maker = lambda: [jnp.array(rng(shape, dtype))]
def args_maker():
x = jnp.array(rng(shape, dtype))
if out_dtype in unsigned_dtypes:
x = 10 * jnp.abs(x)
return [x]

np_op = getattr(np, "cumulative_sum", np_mock_op)
kwargs = dict(axis=axis, dtype=out_dtype, include_initial=include_initial)

Expand All @@ -827,7 +830,6 @@ def np_mock_op(x, axis=None, dtype=None, include_initial=False):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)


@jtu.sample_product(
shape=filter(lambda x: len(x) != 1, all_shapes), dtype=all_dtypes,
include_initial=[False, True])
Expand Down
10 changes: 10 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5971,6 +5971,7 @@ def testWrappedSignaturesMatch(self):
'copy': ['subok'],
'corrcoef': ['ddof', 'bias', 'dtype'],
'cov': ['dtype'],
'cumulative_sum': ['out'],
'empty_like': ['subok', 'order'],
'einsum': ['kwargs'],
'einsum_path': ['einsum_call'],
Expand Down Expand Up @@ -6021,6 +6022,15 @@ def testWrappedSignaturesMatch(self):
# numpy 1.24 re-orders the density and weights arguments.
# TODO(jakevdp): migrate histogram APIs to match newer numpy versions.
continue
if name == "clip":
# JAX's support of the Array API spec for clip, and the way it handles
# backwards compatibility was introduced in
# https://github.com/google/jax/pull/20550 with a different signature
# from the one in numpy, introduced in
# https://github.com/numpy/numpy/pull/26724
# TODO(dfm): After our deprecation period for the clip arguments ends
# it should be possible to reintroduce the check.
continue
# Note: can't use inspect.getfullargspec due to numpy issue
# https://github.com/numpy/numpy/issues/12225
try:
Expand Down

0 comments on commit 82f4b5a

Please sign in to comment.