Skip to content

Commit

Permalink
DOC: add ability to document extra_params within _wraps
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Feb 11, 2022
1 parent 8df1932 commit 22ff25b
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 55 deletions.
132 changes: 82 additions & 50 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -1987,15 +1987,19 @@ def in1d(ar1, ar2, assume_unique=False, invert=False):
else:
return (ar1[:, None] == ar2[None, :]).any(-1)

_SETDIFF1D_DOC = """\
Because the size of the output of ``setdiff1d`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional `size` argument which
specifies the size of the output array: it must be specified statically for ``jnp.setdiff1d``
to be compiled with non-static operands. If specified, the first `size` unique elements will
be returned; if there are fewer unique elements than `size` indicates, the return value will
be padded with the `fill_value`, which defaults to zero."""

@_wraps(np.setdiff1d, lax_description=_SETDIFF1D_DOC)
@_wraps(np.setdiff1d,
lax_description=_dedent("""
Because the size of the output of ``setdiff1d`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
must be specified statically for ``jnp.setdiff1d`` to be used within some of JAX's
transformations."""),
extra_params=_dedent("""
size : int, optional
If specified, the first ``size`` elements of the result will be returned. If there are
fewer elements than ``size`` indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
_check_arraylike("setdiff1d", ar1, ar2)
if size is None:
Expand All @@ -2019,15 +2023,20 @@ def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
return where(arange(size) < mask.sum(), ar1[where(mask, size=size)], fill_value)


_UNION1D_DOC = """\
Because the size of the output of ``union1d`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional `size` argument which
specifies the size of the output array: it must be specified statically for ``jnp.union1d``
to be compiled with non-static operands. If specified, the first `size` unique elements
will be returned; if there are fewer unique elements than `size` indicates, the return
value will be padded with `fill_value`, which defaults to the minimum value of the union."""

@_wraps(np.union1d, lax_description=_UNION1D_DOC)
@_wraps(np.union1d,
lax_description=_dedent("""
Because the size of the output of ``union1d`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
must be specified statically for ``jnp.union1d`` to be used within some of JAX's
transformations."""),
extra_params=_dedent("""
size : int, optional
If specified, the first ``size`` elements of the result will be returned. If there are
fewer elements than ``size`` indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to the minimum
value of the union."""))
def union1d(ar1, ar2, *, size=None, fill_value=None):
_check_arraylike("union1d", ar1, ar2)
if size is None:
Expand Down Expand Up @@ -2144,17 +2153,21 @@ def _where(condition, x=None, y=None):
return lax.select(condition, x, y) if not is_always_empty else x


_WHERE_DOC = """\
At present, JAX does not support JIT-compilation of the single-argument form
of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
three-argument form does not have a data-dependent shape and can be JIT-compiled
successfully. Alternatively, you can specify the optional ``size`` keyword:
if specified, the first ``size`` True elements will be returned; if there
are fewer True elements than ``size`` indicates, the index arrays will be
padded with ``fill_value`` (default is 0.)
"""

@_wraps(np.where, update_doc=False, lax_description=_WHERE_DOC)
@_wraps(np.where,
lax_description=_dedent("""
At present, JAX does not support JIT-compilation of the single-argument form
of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
three-argument form does not have a data-dependent shape and can be JIT-compiled
successfully. Alternatively, you can use the optional ``size`` keyword to
statically specify the expected size of the output."""),
extra_params=_dedent("""
size : int, optional
Only referenced when ``x`` and ``y`` are ``None``. If specified, the indices of the first
``size`` elements of the result will be returned. If there are fewer elements than ``size``
indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def where(condition, x=None, y=None, *, size=None, fill_value=None):
if x is None and y is None:
_check_arraylike("where", condition)
Expand Down Expand Up @@ -2843,15 +2856,20 @@ def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,

_NONZERO_DOC = """\
Because the size of the output of ``nonzero`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional `size` argument which
specifies the size of the output arrays: it must be specified statically for ``jnp.nonzero``
to be compiled with non-static operands. If specified, the first `size` nonzero elements
will be returned; if there are fewer nonzero elements than `size` indicates, the result
will be padded with ``fill_value``, which defaults to zero. ``fill_value`` may be a scalar,
or a tuple specifying the fill value in each dimension.
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
must be specified statically for ``jnp.nonzero`` to be used within some of JAX's
transformations.
"""
_NONZERO_EXTRA_PARAMS = """
size : int, optional
If specified, the indices of the first ``size`` True elements will be returned. If there are
fewer unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero.
"""

@_wraps(np.nonzero, lax_description=_NONZERO_DOC)
@_wraps(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def nonzero(a, *, size=None, fill_value=None):
a = atleast_1d(a)
mask = a != 0
Expand All @@ -2874,7 +2892,7 @@ def nonzero(a, *, size=None, fill_value=None):
out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value, out))
return out

@_wraps(np.flatnonzero, lax_description=_NONZERO_DOC)
@_wraps(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def flatnonzero(a, *, size=None, fill_value=None):
return nonzero(ravel(a), size=size, fill_value=fill_value)[0]

Expand Down Expand Up @@ -5236,7 +5254,19 @@ def vander(x, N=None, increasing=False):
nonzero elements than `size` indicates, the index arrays will be zero-padded.
"""

@_wraps(np.argwhere, lax_description=_ARGWHERE_DOC)
@_wraps(np.argwhere,
lax_description=_dedent("""
Because the size of the output of ``argwhere`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
must be specified statically for ``jnp.argwhere`` to be used within some of JAX's
transformations."""),
extra_params=_dedent("""
size : int, optional
If specified, the indices of the first ``size`` True elements will be returned. If there
are fewer results than ``size`` indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def argwhere(a, *, size=None, fill_value=None):
result = transpose(vstack(nonzero(a, size=size, fill_value=fill_value)))
if ndim(a) == 0:
Expand Down Expand Up @@ -5715,18 +5745,20 @@ def _unique(ar, axis, return_index=False, return_inverse=False, return_counts=Fa
ret += (mask.sum(),)
return ret[0] if len(ret) == 1 else ret


_UNIQUE_DOC = """\
Because the size of the output of ``unique`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional `size` argument which
specifies the size of the data-dependent output arrays: it must be specified statically
for ``jnp.unique`` to be compiled with non-static operands. If specified, the first `size`
unique elements will be returned; if there are fewer unique elements than `size` indicates,
the return value will be padded with `fill_value`, which defaults to the minimum value
along the specified axis of the input."""


@_wraps(np.unique, skip_params=['axis'], lax_description=_UNIQUE_DOC)
@_wraps(np.unique, skip_params=['axis'],
lax_description=_dedent("""
Because the size of the output of ``unique`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional ``size`` argument which
must be specified statically for ``jnp.unique`` to be used within some of JAX's
transformations."""),
extra_params=_dedent("""
size : int, optional
If specified, the first ``size`` unique elements will be returned. If there are fewer unique
elements than ``size`` indicates, the return value will be padded with ``fill_value``.
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``. The default is the minimum value
along the specified axis of the input."""))
def unique(ar, return_index=False, return_inverse=False,
return_counts=False, axis: Optional[int] = None, *, size=None, fill_value=None):
_check_arraylike("unique", ar)
Expand Down
20 changes: 15 additions & 5 deletions jax/_src/numpy/util.py
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
import re
import textwrap
from typing import Callable, NamedTuple, Optional, Dict, Sequence
Expand All @@ -37,7 +36,7 @@ class ParsedDoc(NamedTuple):
signature: str = ""
summary: str = ""
front_matter: str = ""
sections: Dict[str, str] = OrderedDict()
sections: Dict[str, str] = {}


def _parse_numpydoc(docstr: Optional[str]) -> ParsedDoc:
Expand Down Expand Up @@ -79,7 +78,7 @@ def _parse_numpydoc(docstr: Optional[str]) -> ParsedDoc:
section_list = _section_break.split(body)
if not _section_break.match(section_list[0]):
front_matter, *section_list = section_list
sections = OrderedDict((section.split('\n', 1)[0], section) for section in section_list)
sections = {section.split('\n', 1)[0]: section for section in section_list}

return ParsedDoc(docstr=docstr, signature=signature, summary=summary,
front_matter=front_matter, sections=sections)
Expand All @@ -91,12 +90,18 @@ def _parse_parameters(body: str) -> Dict[str, str]:
assert title == 'Parameters'
assert underline and not underline.strip('-')
parameters = _parameter_break.split(content)
return OrderedDict((p.partition(' : ')[0].partition(', ')[0], p) for p in parameters)
return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}


def _parse_extra_params(extra_params: str) -> Dict[str, str]:
"""Parse the extra parameters passed to _wraps()"""
parameters = _parameter_break.split(extra_params.strip('\n'))
return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters}


def _wraps(fun: Optional[Callable], update_doc: bool = True, lax_description: str = "",
sections: Sequence[str] = ('Parameters', 'Returns', 'References'),
skip_params: Sequence[str] = ()):
skip_params: Sequence[str] = (), extra_params: Optional[str]=None):
"""Specialized version of functools.wraps for wrapping numpy functions.
This produces a wrapped function with a modified docstring. In particular, if
Expand All @@ -116,6 +121,9 @@ def _wraps(fun: Optional[Callable], update_doc: bool = True, lax_description: st
["Parameters", "returns", "References"]
skip_params: a list of strings containing names of parameters accepted by the
function that should be skipped in the parameter list.
extra_params: an optional string containing additional parameter descriptions.
When ``update_doc=True``, these will be added to the list of parameter
descriptions in the updated doc.
"""
def wrap(op):
docstr = getattr(fun, "__doc__", None)
Expand All @@ -127,6 +135,8 @@ def wrap(op):
code = getattr(getattr(op, "__wrapped__", op), "__code__", None)
# Remove unrecognized parameter descriptions.
parameters = _parse_parameters(parsed.sections['Parameters'])
if extra_params:
parameters.update(_parse_extra_params(extra_params))
parsed.sections['Parameters'] = (
"Parameters\n"
"----------\n" +
Expand Down

0 comments on commit 22ff25b

Please sign in to comment.