Skip to content

Commit

Permalink
Deprecate jax.numpy.trapz.
Browse files Browse the repository at this point in the history
Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead.

Fixes #17244
  • Loading branch information
hawkinsp committed Aug 25, 2023
1 parent a454081 commit 975dae3
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -14,6 +14,7 @@ Remember to align the itemized text with the first line of an item within a list
{meth}`~jax.numpy.ufunc.outer`, {meth}`~jax.numpy.ufunc.reduce`,
{meth}`~jax.numpy.ufunc.accumulate`, {meth}`~jax.numpy.ufunc.at`, and
{meth}`~jax.numpy.ufunc.reduceat` ({jax-issue}`#17054`).
* Added {func}`jax.scipy.integrate.trapezoid`.
* When not running under IPython: when an exception is raised, JAX now filters out the
entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace"
that previously appeared.) This should produce much friendlier-looking tracebacks. See
Expand Down Expand Up @@ -44,6 +45,7 @@ Remember to align the itemized text with the first line of an item within a list
* `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`.
* `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead.
* `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead.
* `jax.numpy.trapz` has been deprecated. Use `jax.scipy.integrate.trapezoid` instead.
* `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated,
following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
* `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11.
Expand Down
10 changes: 10 additions & 0 deletions docs/jax.scipy.rst
Expand Up @@ -14,6 +14,16 @@ jax.scipy.fft
idct
idctn

jax.scipy.integrate
-------------------

.. automodule:: jax.scipy.integrate

.. autosummary::
:toctree: _autosummary

trapezoid

jax.scipy.linalg
----------------

Expand Down
44 changes: 44 additions & 0 deletions jax/_src/scipy/integrate.py
@@ -0,0 +1,44 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from functools import partial

import scipy.integrate

from jax import jit
from jax._src.numpy import util
from jax._src.typing import Array, ArrayLike
import jax.numpy as jnp

@util._wraps(scipy.integrate.trapezoid)
@partial(jit, static_argnames=('axis',))
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
axis: int = -1) -> Array:
# TODO(phawkins): remove this annotation after fixing jnp types.
dx_array: Array
if x is None:
util.check_arraylike('trapz', y)
y_arr, = util.promote_dtypes_inexact(y)
dx_array = jnp.asarray(dx)
else:
util.check_arraylike('trapz', y, x)
y_arr, x_arr = util.promote_dtypes_inexact(y, x)
if x_arr.ndim == 1:
dx_array = jnp.diff(x_arr)
else:
dx_array = jnp.moveaxis(jnp.diff(x_arr, axis=axis), axis, -1)
y_arr = jnp.moveaxis(y_arr, axis, -1)
return 0.5 * (dx_array * (y_arr[..., 1:] + y_arr[..., :-1])).sum(-1)
9 changes: 8 additions & 1 deletion jax/numpy/__init__.py
Expand Up @@ -226,7 +226,7 @@
tensordot as tensordot,
tile as tile,
trace as trace,
trapz as trapz,
trapz as _deprecated_trapz,
transpose as transpose,
tri as tri,
tril as tril,
Expand Down Expand Up @@ -474,6 +474,11 @@
"jax.numpy.in1d is deprecated. Use jax.numpy.isin instead.",
_deprecated_in1d,
),
# Added Aug 24, 2023
"trapz": (
"jax.numpy.trapz is deprecated. Use jax.scipy.integrate.trapezoid instead.",
_deprecated_trapz,
),
}

import typing
Expand All @@ -488,6 +493,7 @@
PZERO = 0.0
issubsctype = _numpy.core.numerictypes.issubsctype
in1d = _deprecated_in1d
trapz = _deprecated_trapz
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
Expand All @@ -496,3 +502,4 @@
del _numpy

del _deprecated_in1d
del _deprecated_trapz
2 changes: 2 additions & 0 deletions jax/scipy/__init__.py
Expand Up @@ -27,6 +27,7 @@
from jax.scipy import stats as stats
from jax.scipy import fft as fft
from jax.scipy import cluster as cluster
from jax.scipy import integrate as integrate
else:
import jax._src.lazy_loader as _lazy
__getattr__, __dir__, __all__ = _lazy.attach(__name__, [
Expand All @@ -39,6 +40,7 @@
"stats",
"fft",
"cluster",
"integrate",
])
del _lazy

Expand Down
20 changes: 20 additions & 0 deletions jax/scipy/integrate.py
@@ -0,0 +1,20 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.scipy.integrate import (
trapezoid as trapezoid
)
1 change: 1 addition & 0 deletions pyproject.toml
Expand Up @@ -73,6 +73,7 @@ filterwarnings = [
"ignore:JAX_USE_PJRT_C_API_ON_TPU=false will no longer be supported.*:UserWarning",
"ignore:np.find_common_type is deprecated.*:DeprecationWarning",
"ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning",
"ignore:jax.numpy.trapz is deprecated.*:DeprecationWarning",
]
doctest_optionflags = [
"NUMBER",
Expand Down
30 changes: 30 additions & 0 deletions tests/lax_scipy_test.py
Expand Up @@ -20,10 +20,12 @@
from absl.testing import absltest

import numpy as np
import scipy.integrate
import scipy.special as osp_special
import scipy.cluster as osp_cluster

import jax
import jax.dtypes
from jax import numpy as jnp
from jax import lax
from jax import scipy as jsp
Expand Down Expand Up @@ -542,5 +544,33 @@ def test_spence(self, shape, dtype):
self.assertArraysEqual(actual, nan_array, check_dtypes=False)


@jtu.sample_product(
[dict(yshape=yshape, xshape=xshape, dx=dx, axis=axis)
for yshape, xshape, dx, axis in [
((10,), None, 1.0, -1),
((3, 10), None, 2.0, -1),
((3, 10), None, 3.0, -0),
((10, 3), (10,), 1.0, -2),
((3, 10), (10,), 1.0, -1),
((3, 10), (3, 10), 1.0, -1),
((2, 3, 10), (3, 10), 1.0, -2),
]
],
dtype=float_dtypes + int_dtypes,
)
@jtu.skip_on_devices("tpu") # TODO(jakevdp): fix and reenable this test.
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testIntegrateTrapezoid(self, yshape, xshape, dtype, dx, axis):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(yshape, dtype), rng(xshape, dtype) if xshape is not None else None]
np_fun = partial(scipy.integrate.trapezoid, dx=dx, axis=axis)
jnp_fun = partial(jax.scipy.integrate.trapezoid, dx=dx, axis=axis)
tol = jtu.tolerance(dtype, {np.float16: 2e-3, np.float64: 1e-12,
jax.dtypes.bfloat16: 4e-2})
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=tol,
check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker, atol=tol, rtol=tol,
check_dtypes=False)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 975dae3

Please sign in to comment.