diff --git a/CHANGELOG.md b/CHANGELOG.md index f050f54093b9..38aa61df3921 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ Remember to align the itemized text with the first line of an item within a list {meth}`~jax.numpy.ufunc.outer`, {meth}`~jax.numpy.ufunc.reduce`, {meth}`~jax.numpy.ufunc.accumulate`, {meth}`~jax.numpy.ufunc.at`, and {meth}`~jax.numpy.ufunc.reduceat` ({jax-issue}`#17054`). + * Added {func}`jax.scipy.integrate.trapezoid`. * When not running under IPython: when an exception is raised, JAX now filters out the entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace" that previously appeared.) This should produce much friendlier-looking tracebacks. See @@ -44,6 +45,7 @@ Remember to align the itemized text with the first line of an item within a list * `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`. * `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead. * `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead. + * `jax.numpy.trapz` has been deprecated. Use `jax.scipy.integrate.trapezoid` instead. * `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated, following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead. * `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11. diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index a058c0af0be1..e3a6be839083 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -14,6 +14,16 @@ jax.scipy.fft idct idctn +jax.scipy.integrate +------------------- + +.. automodule:: jax.scipy.integrate + +.. autosummary:: + :toctree: _autosummary + + trapezoid + jax.scipy.linalg ---------------- diff --git a/jax/_src/scipy/integrate.py b/jax/_src/scipy/integrate.py new file mode 100644 index 000000000000..d61ded3a0864 --- /dev/null +++ b/jax/_src/scipy/integrate.py @@ -0,0 +1,44 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from functools import partial + +import scipy.integrate + +from jax import jit +from jax._src.numpy import util +from jax._src.typing import Array, ArrayLike +import jax.numpy as jnp + +@util._wraps(scipy.integrate.trapezoid) +@partial(jit, static_argnames=('axis',)) +def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0, + axis: int = -1) -> Array: + # TODO(phawkins): remove this annotation after fixing jnp types. + dx_array: Array + if x is None: + util.check_arraylike('trapz', y) + y_arr, = util.promote_dtypes_inexact(y) + dx_array = jnp.asarray(dx) + else: + util.check_arraylike('trapz', y, x) + y_arr, x_arr = util.promote_dtypes_inexact(y, x) + if x_arr.ndim == 1: + dx_array = jnp.diff(x_arr) + else: + dx_array = jnp.moveaxis(jnp.diff(x_arr, axis=axis), axis, -1) + y_arr = jnp.moveaxis(y_arr, axis, -1) + return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index fdfa1a62a128..472f06040f34 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -226,7 +226,7 @@ tensordot as tensordot, tile as tile, trace as trace, - trapz as trapz, + trapz as _deprecated_trapz, transpose as transpose, tri as tri, tril as tril, @@ -474,6 +474,11 @@ "jax.numpy.in1d is deprecated. Use jax.numpy.isin instead.", _deprecated_in1d, ), + # Added Aug 24, 2023 + "trapz": ( + "jax.numpy.trapz is deprecated. Use jax.scipy.integrate.trapezoid instead.", + _deprecated_trapz, + ), } import typing @@ -488,6 +493,7 @@ PZERO = 0.0 issubsctype = _numpy.core.numerictypes.issubsctype in1d = _deprecated_in1d + trapz = _deprecated_trapz else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) @@ -496,3 +502,4 @@ del _numpy del _deprecated_in1d +del _deprecated_trapz diff --git a/jax/scipy/__init__.py b/jax/scipy/__init__.py index 6f650b55a300..c0746910dd3f 100644 --- a/jax/scipy/__init__.py +++ b/jax/scipy/__init__.py @@ -27,6 +27,7 @@ from jax.scipy import stats as stats from jax.scipy import fft as fft from jax.scipy import cluster as cluster + from jax.scipy import integrate as integrate else: import jax._src.lazy_loader as _lazy __getattr__, __dir__, __all__ = _lazy.attach(__name__, [ @@ -39,6 +40,7 @@ "stats", "fft", "cluster", + "integrate", ]) del _lazy diff --git a/jax/scipy/integrate.py b/jax/scipy/integrate.py new file mode 100644 index 000000000000..b19aa054ca00 --- /dev/null +++ b/jax/scipy/integrate.py @@ -0,0 +1,20 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/google/jax/issues/7570 + +from jax._src.scipy.integrate import ( + trapezoid as trapezoid +) diff --git a/pyproject.toml b/pyproject.toml index 3fc5acb2bea7..32f06ed1ec02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ filterwarnings = [ "ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning", "ignore:np.find_common_type is deprecated.*:DeprecationWarning", "ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning", + "ignore:jax.numpy.trapz is deprecated.*:DeprecationWarning", ] doctest_optionflags = [ "NUMBER", diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 12703a827e7f..7aa1471b28f8 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -20,10 +20,12 @@ from absl.testing import absltest import numpy as np +import scipy.integrate import scipy.special as osp_special import scipy.cluster as osp_cluster import jax +import jax.dtypes from jax import numpy as jnp from jax import lax from jax import scipy as jsp @@ -542,5 +544,33 @@ def test_spence(self, shape, dtype): self.assertArraysEqual(actual, nan_array, check_dtypes=False) + @jtu.sample_product( + [dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis) + for yshape, xshape, dx, axis in [ + ((10,), None, 1.0, -1), + ((3, 10), None, 2.0, -1), + ((3, 10), None, 3.0, -0), + ((10, 3), (10,), 1.0, -2), + ((3, 10), (10,), 1.0, -1), + ((3, 10), (3, 10), 1.0, -1), + ((2, 3, 10), (3, 10), 1.0, -2), + ] + ], + dtype=float_dtypes + int_dtypes, + ) + @jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test. + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + def testIntegrateTrapezoid(self, yshape, xshape, dtype, dx, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None] + np_fun = partial(scipy.integrate.trapezoid, dx=dx, axis=axis) + jnp_fun = partial(jax.scipy.integrate.trapezoid, dx=dx, axis=axis) + tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12, + jax.dtypes.bfloat16: 4e-2}) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol, + check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol, + check_dtypes=False) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())