Skip to content

Commit

Permalink
DOC: Improve remaining jax.scipy docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 2, 2024
1 parent 51fc4f8 commit 18e4cfa
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 54 deletions.
10 changes: 10 additions & 0 deletions docs/jax.scipy.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
``jax.scipy`` module
====================

jax.scipy.cluster
-----------------

.. automodule:: jax.scipy.cluster.vq

.. autosummary::
:toctree: _autosummary

vq

jax.scipy.fft
-------------

Expand Down
84 changes: 56 additions & 28 deletions jax/_src/scipy/cluster/vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,64 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import operator

import scipy.cluster.vq
import textwrap

from jax import vmap
import jax.numpy as jnp
from jax._src.numpy.util import implements, check_arraylike, promote_dtypes_inexact


_no_chkfinite_doc = textwrap.dedent("""
Does not support the Scipy argument ``check_finite=True``,
because compiled JAX code cannot perform checks of array values at runtime
""")


@implements(scipy.cluster.vq.vq, lax_description=_no_chkfinite_doc, skip_params=('check_finite',))
def vq(obs, code_book, check_finite=True):
check_arraylike("scipy.cluster.vq.vq", obs, code_book)
if obs.ndim != code_book.ndim:
raise ValueError("Observation and code_book should have the same rank")
obs, code_book = promote_dtypes_inexact(obs, code_book)
if obs.ndim == 1:
obs, code_book = obs[..., None], code_book[..., None]
if obs.ndim != 2:
raise ValueError("ndim different than 1 or 2 are not supported")

# explicitly rank promotion
dist = vmap(lambda ob: jnp.linalg.norm(ob[None] - code_book, axis=-1))(obs)
code = jnp.argmin(dist, axis=-1)
dist_min = vmap(operator.getitem)(dist, code)
return code, dist_min
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike


def vq(obs: ArrayLike, code_book: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]:
"""Assign codes from a code book to a set of observations.
JAX implementation of :func:`scipy.cluster.vq.vq`.
Assigns each observation vector in ``obs`` to a code from ``code_book``
based on the nearest Euclidean distance.
Args:
obs: array of observation vectors of shape ``(M, N)``. Each row represents
a single observation. If ``obs`` is one-dimensional, then each entry is
treated as a length-1 observation.
code_book: array of codes with shape ``(K, N)``. Each row represents a single
code vector. If ``code_book`` is one-dimensional, then each entry is treated
as a length-1 code.
check_finite: unused in JAX
Returns:
A tuple of arrays ``(code, dist)``
- ``code`` is an integer array of shape ``(M,)`` containing indices ``0 <= i < K``
of the closest entry in ``code_book`` for the given entry in ``obs``.
- ``dist`` is a float array of shape ``(M,)`` containing the euclidean
distance between each observation and the nearest code.
Examples:
>>> obs = jnp.array([[1.1, 2.1, 3.1],
... [5.9, 4.8, 6.2]])
>>> code_book = jnp.array([[1., 2., 3.],
... [2., 3., 4.],
... [3., 4., 5.],
... [4., 5., 6.]])
>>> codes, distances = jax.scipy.cluster.vq.vq(obs, code_book)
>>> print(codes)
[0 3]
>>> print(distances)
[0.17320499 1.9209373 ]
"""
del check_finite # unused
check_arraylike("scipy.cluster.vq.vq", obs, code_book)
obs_arr, cb_arr = promote_dtypes_inexact(obs, code_book)
if obs_arr.ndim != cb_arr.ndim:
raise ValueError("Observation and code_book should have the same rank")
if obs_arr.ndim == 1:
obs_arr, cb_arr = obs_arr[..., None], cb_arr[..., None]
if obs_arr.ndim != 2:
raise ValueError("ndim different than 1 or 2 are not supported")
dist = vmap(lambda ob: jnp.linalg.norm(ob[None] - cb_arr, axis=-1))(obs_arr)
code = jnp.argmin(dist, axis=-1)
dist_min = vmap(operator.getitem)(dist, code)
return code, dist_min
47 changes: 43 additions & 4 deletions jax/_src/scipy/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,54 @@

from functools import partial

import scipy.integrate

from jax import jit
from jax._src.numpy import util
from jax._src.typing import Array, ArrayLike
import jax.numpy as jnp

@util.implements(scipy.integrate.trapezoid)

@partial(jit, static_argnames=('axis',))
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
axis: int = -1) -> Array:
r"""
Integrate along the given axis using the composite trapezoidal rule.
JAX implementation of :func:`scipy.integrate.trapezoid`
The trapezoidal rule approximates the integral under a curve by summing the
areas of trapezoids formed between adjacent data points.
Args:
y: array of data to integrate.
x: optional array of sample points corresponding to the ``y`` values. If not
provided, ``x`` defaults to equally spaced with spacing given by ``dx``.
dx: The spacing between sample points when `x` is None (default: 1.0).
axis: The axis along which to integrate (default: -1)
Returns:
The definite integral approximated by the trapezoidal rule.
See also:
:func:`jax.numpy.trapezoid`: NumPy-style API for trapezoidal integration
Examples:
Integrate over a regular grid, with spacing 1.0:
>>> y = jnp.array([1, 2, 3, 2, 3, 2, 1])
>>> jax.scipy.integrate.trapezoid(y, dx=1.0)
Array(13., dtype=float32)
Integrate over an irregular grid:
>>> x = jnp.array([0, 2, 5, 7, 10, 15, 20])
>>> jax.scipy.integrate.trapezoid(y, x)
Array(43., dtype=float32)
Approximate :math:`\int_0^{2\pi} \sin^2(x)dx`, which equals :math:`\pi`:
>>> x = jnp.linspace(0, 2 * jnp.pi, 1000)
>>> y = jnp.sin(x) ** 2
>>> result = jax.scipy.integrate.trapezoid(y, x)
>>> jnp.allclose(result, jnp.pi)
Array(True, dtype=bool)
"""
return jnp.trapezoid(y, x, dx, axis)
60 changes: 49 additions & 11 deletions jax/_src/scipy/ndimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,12 @@
import functools
import itertools
import operator
import textwrap
from typing import Callable

import scipy.ndimage

from jax._src import api
from jax._src import util
from jax import lax
import jax.numpy as jnp
from jax._src.numpy.util import implements
from jax._src.typing import ArrayLike, Array
from jax._src.util import safe_zip as zip

Expand Down Expand Up @@ -127,15 +123,57 @@ def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike],
return result.astype(input_arr.dtype)


@implements(scipy.ndimage.map_coordinates, lax_description=textwrap.dedent("""\
"""
Only nearest neighbor (``order=0``), linear interpolation (``order=1``) and
modes ``'constant'``, ``'nearest'``, ``'wrap'`` ``'mirror'`` and ``'reflect'`` are currently supported.
Note that interpolation near boundaries differs from the scipy function,
because we fixed an outstanding bug (https://github.com/scipy/scipy/issues/2640);
this function interprets the ``mode`` argument as documented by SciPy, but
not as implemented by SciPy.
"""))
"""

def map_coordinates(
input: ArrayLike, coordinates: Sequence[ArrayLike], order: int, mode: str = 'constant', cval: ArrayLike = 0.0,
input: ArrayLike, coordinates: Sequence[ArrayLike], order: int,
mode: str = 'constant', cval: ArrayLike = 0.0,
):
"""
Map the input array to new coordinates using interpolation.
JAX implementation of :func:`scipy.ndimage.map_coordinates`
Given an input array and a set of coordinates, this function returns the
interpolated values of the input array at those coordinates.
Args:
input: N-dimensional input array from which values are interpolated.
coordinates: length-N sequence of arrays specifying the coordinates
at which to evaluate the interpolated values
order: The order of interpolation. JAX supports the following:
* 0: Nearest-neighbor
* 1: Linear
mode: Points outside the boundaries of the input are filled according to the given mode.
JAX supports one of ``('constant', 'nearest', 'mirror', 'wrap', 'reflect')``.
Default is 'constant'.
cval: Value used for points outside the boundaries of the input if ``mode='constant'``
Default is 0.0.
Returns:
The interpolated values at the specified coordinates.
Examples:
>>> input = jnp.arange(12.0).reshape(3, 4)
>>> input
Array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.]], dtype=float32)
>>> coordinates = [jnp.array([0.5, 1.5]),
... jnp.array([1.5, 2.5])]
>>> jax.scipy.ndimage.map_coordinates(input, coordinates, order=1)
Array([3.5, 8.5], dtype=float32)
Note:
Interpolation near boundaries differs from the scipy function, because JAX
fixed an outstanding bug; see https://github.com/google/jax/issues/11097.
This function interprets the ``mode`` argument as documented by SciPy, but
not as implemented by SciPy.
"""
return _map_coordinates(input, coordinates, order, mode, cval)
65 changes: 58 additions & 7 deletions jax/_src/scipy/spatial/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,46 @@
import re
import typing

import scipy.spatial.transform

import jax
import jax.numpy as jnp
from jax._src.numpy.util import implements


@implements(scipy.spatial.transform.Rotation)
class Rotation(typing.NamedTuple):
"""Rotation in 3 dimensions."""
"""Rotation in 3 dimensions.
JAX implementation of :class:`scipy.spatial.transform.Rotation`.
Examples:
Construct an object describing a 90 degree rotation about the z-axis:
>>> from jax.scipy.spatial.transform import Rotation
>>> r = Rotation.from_euler('z', 90, degrees=True)
Convert to a rotation vector:
>>> r.as_rotvec()
Array([0. , 0. , 1.5707964], dtype=float32)
Convert to rotation matrix:
>>> r.as_matrix()
Array([[ 0. , -0.99999994, 0. ],
[ 0.99999994, 0. , 0. ],
[ 0. , 0. , 0.99999994]], dtype=float32)
Compose with another rotation:
>>> r2 = Rotation.from_euler('x', 90, degrees=True)
>>> r3 = r * r2
>>> r3.as_matrix()
Array([[0., 0., 1.],
[1., 0., 0.],
[0., 1., 0.]], dtype=float32)
See the scipy :class:`~scipy.spatial.transform.Rotation` documentation for
further examples of manipulating Rotation objects.
"""
quat: jax.Array

@classmethod
Expand Down Expand Up @@ -86,7 +115,7 @@ def identity(cls, num: int | None = None, dtype=float):
def random(cls, random_key: jax.Array, num: int | None = None):
"""Generate uniformly distributed rotations."""
# Need to implement scipy.stats.special_ortho_group for this to work...
raise NotImplementedError
raise NotImplementedError()

def __getitem__(self, indexer):
"""Extract rotation(s) at given index(es) from object."""
Expand Down Expand Up @@ -169,9 +198,31 @@ def single(self) -> bool:
return self.quat.ndim == 1


@implements(scipy.spatial.transform.Slerp)
class Slerp(typing.NamedTuple):
"""Spherical Linear Interpolation of Rotations."""
"""Spherical Linear Interpolation of Rotations.
JAX implementation of :class:`scipy.spatial.transform.Slerp`.
Examples:
Create a Slerp instance from a series of rotations:
>>> import math
>>> from jax.scipy.spatial.transform import Rotation, Slerp
>>> rots = jnp.array([[90, 0, 0],
... [0, 45, 0],
... [0, 0, -30]])
>>> key_rotations = Rotation.from_euler('zxy', rots, degrees=True)
>>> key_times = [0, 1, 2]
>>> slerp = Slerp.init(key_times, key_rotations)
>>> times = [0, 0.5, 1, 1.5, 2]
>>> interp_rots = slerp(times)
>>> interp_rots.as_euler('zxy')
Array([[ 1.5707963e+00, 0.0000000e+00, 0.0000000e+00],
[ 8.5309029e-01, 3.8711953e-01, 1.7768645e-01],
[-2.3841858e-07, 7.8539824e-01, 0.0000000e+00],
[-5.6668043e-02, 3.9213133e-01, -2.8347540e-01],
[ 0.0000000e+00, 0.0000000e+00, -5.2359891e-01]], dtype=float32)
"""

times: jnp.ndarray
timedelta: jnp.ndarray
Expand Down
4 changes: 0 additions & 4 deletions tests/scipy_ndimage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,6 @@ def testMapCoordinatesErrors(self):
with self.assertRaisesRegex(ValueError, 'sequence of length'):
lsp_ndimage.map_coordinates(x, [c, c], order=1)

def testMapCoordinateDocstring(self):
self.assertIn("Only nearest neighbor",
lsp_ndimage.map_coordinates.__doc__)

@jtu.sample_product(
dtype=float_dtypes + int_dtypes,
order=[0, 1],
Expand Down

0 comments on commit 18e4cfa

Please sign in to comment.