Skip to content

Commit

Permalink
Documentation improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkinsp committed Mar 8, 2022
1 parent e8f1a02 commit 80aec7b
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 15 deletions.
18 changes: 18 additions & 0 deletions docs/jax.example_libraries.rst
@@ -1,10 +1,28 @@
jax.example_libraries package
=============================

JAX provides some small, experimental libraries for machine learning. These
libraries are in part about providing tools and in part about serving as
examples for how to build such libraries using JAX. Each one is only <300 source
lines of code, so take a look inside and adapt them as you need!

.. note::
Each mini-library is meant to be an *inspiration*, but not a prescription.

To serve that purpose, it is best to keep their code samples minimal; so we
generally **will not merge PRs** adding new features. Instead, please send your
lovely pull requests and design ideas to more fully-featured libraries like
`Haiku`_ or `Flax`_.


.. toctree::
:maxdepth: 1

jax.example_libraries.optimizers
jax.example_libraries.stax

.. automodule:: jax.example_libraries


.. _Haiku: https://github.com/deepmind/dm-haiku
.. _Flax: https://github.com/google/flax
6 changes: 6 additions & 0 deletions docs/jax.image.rst
Expand Up @@ -15,3 +15,9 @@ Image manipulation functions
resize
scale_and_translate

Argument classes
----------------

.. currentmodule:: jax.image

.. autoclass:: ResizeMethod
10 changes: 5 additions & 5 deletions jax/_src/dlpack.py
Expand Up @@ -29,13 +29,13 @@


def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False):
"""Returns a DLPack tensor that encapsulates a DeviceArray `x`.
"""Returns a DLPack tensor that encapsulates a ``DeviceArray`` `x`.
Takes ownership of the contents of `x`; leaves `x` in an invalid/deleted
Takes ownership of the contents of ``x``; leaves `x` in an invalid/deleted
state.
Args:
x: a `DeviceArray`, on either CPU or GPU.
x: a ``DeviceArray``, on either CPU or GPU.
take_ownership: If ``True``, JAX hands ownership of the buffer to DLPack,
and the consumer is free to mutate the buffer; the JAX buffer acts as if
it were deleted. If ``False``, JAX retains ownership of the buffer; it is
Expand All @@ -49,9 +49,9 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False)
x.device_buffer, take_ownership=take_ownership)

def from_dlpack(dlpack):
"""Returns a `DeviceArray` representation of a DLPack tensor `dlpack`.
"""Returns a ``DeviceArray`` representation of a DLPack tensor.
The returned `DeviceArray` shares memory with `dlpack`.
The returned ``DeviceArray`` shares memory with ``dlpack``.
Args:
dlpack: a DLPack tensor, on either CPU or GPU.
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/flatten_util.py
Expand Up @@ -27,7 +27,7 @@


def ravel_pytree(pytree):
"""Ravel (i.e. flatten) a pytree of arrays down to a 1D array.
"""Ravel (flatten) a pytree of arrays down to a 1D array.
Args:
pytree: a pytree of arrays and scalars to ravel.
Expand Down
33 changes: 29 additions & 4 deletions jax/_src/image/scale.py
Expand Up @@ -102,11 +102,36 @@ def _scale_and_translate(x, output_shape: core.Shape,


class ResizeMethod(enum.Enum):
"""Image resize method.
Possible values are:
NEAREST:
Nearest-neighbor interpolation.
LINEAR:
`Linear interpolation`_.
LANCZOS3:
`Lanczos resampling`_, using a kernel of radius 3.
LANCZOS3:
`Lanczos resampling`_, using a kernel of radius 5.
CUBIC:
`Cubic interpolation`_, using the Keys cubic kernel.
.. _Linear interpolation: https://en.wikipedia.org/wiki/Bilinear_interpolation
.. _Cubic interpolation: https://en.wikipedia.org/wiki/Bicubic_interpolation
.. _Lanczos resampling: https://en.wikipedia.org/wiki/Lanczos_resampling
"""

NEAREST = 0
LINEAR = 1
LANCZOS3 = 2
LANCZOS5 = 3
CUBIC = 4

# Caution: The current resize implementation assumes that the resize kernels
# are interpolating, i.e. for the identity warp the output equals the input.
# This is not true for, e.g. a Gaussian kernel, so if such kernels are added
Expand Down Expand Up @@ -152,10 +177,10 @@ def scale_and_translate(image, shape: core.Shape,
(x * scale[1] + translation[1], y * scale[0] + translation[0])
(Note the _inverse_ warp is used to generate the sample locations.)
Assumes half-centered pixels, i.e the pixel at integer location row,col has
coordinates y, x = row + 0.5, col + 0.5.
Similarly for other input image dimensions.
(Note the *inverse* warp is used to generate the sample locations.)
Assumes half-centered pixels, i.e the pixel at integer location ``row, col``
has coordinates ``y, x = row + 0.5, col + 0.5``, and similarly for other input
image dimensions.
If an output location(pixel) maps to an input sample location that is outside
the input boundaries then the value for the output location will be set to
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/nn/functions.py
Expand Up @@ -320,7 +320,7 @@ def normalize(x: Array,
variance: Optional[Array] = None,
epsilon: Array = 1e-5,
where: Optional[Array] = None) -> Array:
"""Normalizes an array by subtracting mean and dividing by sqrt(var)."""
r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
if mean is None:
mean = jnp.mean(x, axis, keepdims=True, where=where)
if variance is None:
Expand Down
5 changes: 4 additions & 1 deletion jax/example_libraries/optimizers.py
Expand Up @@ -16,7 +16,7 @@
You likely do not mean to import this module! The optimizers in this library
are intended as examples only. If you are looking for a fully featured optimizer
library, we recommend `optax` (https://github.com/deepmind/optax).
library, we recommend `Optax`_.
This module contains some convenient optimizer definitions, specifically
initialization and update functions, which can be used with ndarrays or
Expand Down Expand Up @@ -83,6 +83,9 @@ def step(step, opt_state):
for i in range(num_steps):
value, opt_state = step(i, opt_state)
.. _Optax: https://github.com/deepmind/optax
"""

from typing import Any, Callable, NamedTuple, Tuple, Union
Expand Down
12 changes: 10 additions & 2 deletions jax/example_libraries/stax.py
Expand Up @@ -14,9 +14,17 @@

"""Stax is a small but flexible neural net specification library from scratch.
For an example of its use, see examples/resnet50.py.
"""
You likely do not mean to import this module! Stax is intended as an example
library only. There are a number of other much more fully-featured neural
network libraries for JAX, including `Flax`_ from Google, and `Haiku`_ from
DeepMind.
For an example of how to use Stax, see the `Stax Resnet-50 example`
<https://github.com/google/jax/blob/main/examples/resnet50.py>`_.
.. _Haiku: https://github.com/deepmind/dm-haiku
.. _Flax: https://github.com/google/flax
"""

import functools
import operator as op
Expand Down
8 changes: 7 additions & 1 deletion jax/image/__init__.py
Expand Up @@ -12,7 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common functions for neural network libraries."""
"""Image manipulation functions.
More image manipulation functions can be found in libraries built on top of
JAX, such as `PIX`_.
.. _PIX: https://github.com/deepmind/dm_pix
"""

# flake8: noqa: F401
from jax._src.image.scale import (
Expand Down

0 comments on commit 80aec7b

Please sign in to comment.