From 80aec7b25fa8227dc2ce88a7a96504788cb90193 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 8 Mar 2022 09:35:36 -0500 Subject: [PATCH] Documentation improvements. --- docs/jax.example_libraries.rst | 18 ++++++++++++++++ docs/jax.image.rst | 6 ++++++ jax/_src/dlpack.py | 10 ++++----- jax/_src/flatten_util.py | 2 +- jax/_src/image/scale.py | 33 +++++++++++++++++++++++++---- jax/_src/nn/functions.py | 2 +- jax/example_libraries/optimizers.py | 5 ++++- jax/example_libraries/stax.py | 12 +++++++++-- jax/image/__init__.py | 8 ++++++- 9 files changed, 81 insertions(+), 15 deletions(-) diff --git a/docs/jax.example_libraries.rst b/docs/jax.example_libraries.rst index 4b5d561d75fe..cf75b87963c0 100644 --- a/docs/jax.example_libraries.rst +++ b/docs/jax.example_libraries.rst @@ -1,6 +1,20 @@ jax.example_libraries package ============================= +JAX provides some small, experimental libraries for machine learning. These +libraries are in part about providing tools and in part about serving as +examples for how to build such libraries using JAX. Each one is only <300 source +lines of code, so take a look inside and adapt them as you need! + +.. note:: + Each mini-library is meant to be an *inspiration*, but not a prescription. + + To serve that purpose, it is best to keep their code samples minimal; so we + generally **will not merge PRs** adding new features. Instead, please send your + lovely pull requests and design ideas to more fully-featured libraries like + `Haiku`_ or `Flax`_. + + .. toctree:: :maxdepth: 1 @@ -8,3 +22,7 @@ jax.example_libraries package jax.example_libraries.stax .. automodule:: jax.example_libraries + + +.. _Haiku: https://github.com/deepmind/dm-haiku +.. _Flax: https://github.com/google/flax diff --git a/docs/jax.image.rst b/docs/jax.image.rst index 364887dabfb6..8d4aad34394c 100644 --- a/docs/jax.image.rst +++ b/docs/jax.image.rst @@ -15,3 +15,9 @@ Image manipulation functions resize scale_and_translate +Argument classes +---------------- + +.. currentmodule:: jax.image + +.. autoclass:: ResizeMethod diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index 09c418653a88..05d66227547b 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -29,13 +29,13 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False): - """Returns a DLPack tensor that encapsulates a DeviceArray `x`. + """Returns a DLPack tensor that encapsulates a ``DeviceArray`` `x`. - Takes ownership of the contents of `x`; leaves `x` in an invalid/deleted + Takes ownership of the contents of ``x``; leaves `x` in an invalid/deleted state. Args: - x: a `DeviceArray`, on either CPU or GPU. + x: a ``DeviceArray``, on either CPU or GPU. take_ownership: If ``True``, JAX hands ownership of the buffer to DLPack, and the consumer is free to mutate the buffer; the JAX buffer acts as if it were deleted. If ``False``, JAX retains ownership of the buffer; it is @@ -49,9 +49,9 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False) x.device_buffer, take_ownership=take_ownership) def from_dlpack(dlpack): - """Returns a `DeviceArray` representation of a DLPack tensor `dlpack`. + """Returns a ``DeviceArray`` representation of a DLPack tensor. - The returned `DeviceArray` shares memory with `dlpack`. + The returned ``DeviceArray`` shares memory with ``dlpack``. Args: dlpack: a DLPack tensor, on either CPU or GPU. diff --git a/jax/_src/flatten_util.py b/jax/_src/flatten_util.py index b7d7dc5b8644..f85367db6b48 100644 --- a/jax/_src/flatten_util.py +++ b/jax/_src/flatten_util.py @@ -27,7 +27,7 @@ def ravel_pytree(pytree): - """Ravel (i.e. flatten) a pytree of arrays down to a 1D array. + """Ravel (flatten) a pytree of arrays down to a 1D array. Args: pytree: a pytree of arrays and scalars to ravel. diff --git a/jax/_src/image/scale.py b/jax/_src/image/scale.py index 3c0713b1eec8..1bca94e6deb9 100644 --- a/jax/_src/image/scale.py +++ b/jax/_src/image/scale.py @@ -102,11 +102,36 @@ def _scale_and_translate(x, output_shape: core.Shape, class ResizeMethod(enum.Enum): + """Image resize method. + + Possible values are: + + NEAREST: + Nearest-neighbor interpolation. + + LINEAR: + `Linear interpolation`_. + + LANCZOS3: + `Lanczos resampling`_, using a kernel of radius 3. + + LANCZOS3: + `Lanczos resampling`_, using a kernel of radius 5. + + CUBIC: + `Cubic interpolation`_, using the Keys cubic kernel. + + .. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation + .. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation + .. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling + """ + NEAREST = 0 LINEAR = 1 LANCZOS3 = 2 LANCZOS5 = 3 CUBIC = 4 + # Caution: The current resize implementation assumes that the resize kernels # are interpolating, i.e. for the identity warp the output equals the input. # This is not true for, e.g. a Gaussian kernel, so if such kernels are added @@ -152,10 +177,10 @@ def scale_and_translate(image, shape: core.Shape, (x * scale[1] + translation[1], y * scale[0] + translation[0]) - (Note the _inverse_ warp is used to generate the sample locations.) - Assumes half-centered pixels, i.e the pixel at integer location row,col has - coordinates y, x = row + 0.5, col + 0.5. - Similarly for other input image dimensions. + (Note the *inverse* warp is used to generate the sample locations.) + Assumes half-centered pixels, i.e the pixel at integer location ``row, col`` + has coordinates ``y, x = row + 0.5, col + 0.5``, and similarly for other input + image dimensions. If an output location(pixel) maps to an input sample location that is outside the input boundaries then the value for the output location will be set to diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 53c5888fbf59..9026207ed4f9 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -320,7 +320,7 @@ def normalize(x: Array, variance: Optional[Array] = None, epsilon: Array = 1e-5, where: Optional[Array] = None) -> Array: - """Normalizes an array by subtracting mean and dividing by sqrt(var).""" + r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`.""" if mean is None: mean = jnp.mean(x, axis, keepdims=True, where=where) if variance is None: diff --git a/jax/example_libraries/optimizers.py b/jax/example_libraries/optimizers.py index 75f48174e404..b28a6d692887 100644 --- a/jax/example_libraries/optimizers.py +++ b/jax/example_libraries/optimizers.py @@ -16,7 +16,7 @@ You likely do not mean to import this module! The optimizers in this library are intended as examples only. If you are looking for a fully featured optimizer -library, we recommend `optax` (https://github.com/deepmind/optax). +library, we recommend `Optax`_. This module contains some convenient optimizer definitions, specifically initialization and update functions, which can be used with ndarrays or @@ -83,6 +83,9 @@ def step(step, opt_state): for i in range(num_steps): value, opt_state = step(i, opt_state) + + +.. _Optax: https://github.com/deepmind/optax """ from typing import Any, Callable, NamedTuple, Tuple, Union diff --git a/jax/example_libraries/stax.py b/jax/example_libraries/stax.py index 66f1a9d43f0d..bebb3306084f 100644 --- a/jax/example_libraries/stax.py +++ b/jax/example_libraries/stax.py @@ -14,9 +14,17 @@ """Stax is a small but flexible neural net specification library from scratch. -For an example of its use, see examples/resnet50.py. -""" +You likely do not mean to import this module! Stax is intended as an example +library only. There are a number of other much more fully-featured neural +network libraries for JAX, including `Flax`_ from Google, and `Haiku`_ from +DeepMind. + +For an example of how to use Stax, see the `Stax Resnet-50 example` +`_. +.. _Haiku: https://github.com/deepmind/dm-haiku +.. _Flax: https://github.com/google/flax +""" import functools import operator as op diff --git a/jax/image/__init__.py b/jax/image/__init__.py index a38a194d1a92..b10cf89d1317 100644 --- a/jax/image/__init__.py +++ b/jax/image/__init__.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Common functions for neural network libraries.""" +"""Image manipulation functions. + +More image manipulation functions can be found in libraries built on top of +JAX, such as `PIX`_. + +.. _PIX: https://github.com/deepmind/dm_pix +""" # flake8: noqa: F401 from jax._src.image.scale import (