From d3d666d081a74df749fda4a60da86c51e4d58651 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 7 Mar 2022 17:05:51 -0500 Subject: [PATCH] Document jax.nn.initializers. --- docs/jax.nn.initializers.rst | 24 +- jax/_src/nn/initializers.py | 423 +++++++++++++++++++++++++++++++---- jax/_src/numpy/lax_numpy.py | 2 +- 3 files changed, 396 insertions(+), 53 deletions(-) diff --git a/docs/jax.nn.initializers.rst b/docs/jax.nn.initializers.rst index 2fa44ca23498..e500ef7a8783 100644 --- a/docs/jax.nn.initializers.rst +++ b/docs/jax.nn.initializers.rst @@ -13,17 +13,25 @@ Initializers This module provides common neural network layer initializers, consistent with definitions used in Keras and Sonnet. +An initializer is a function that takes three arguments: +``(key, shape, dtype)`` and returns an array with dimensions ``shape`` and +data type ``dtype``. Argument ``key`` is a :class:`jax.random.PRNGKey` random +key used when generating random numbers to initialize the array. + .. autosummary:: :toctree: _autosummary - zeros + constant + delta_orthogonal + glorot_normal + glorot_uniform + he_normal + he_uniform + lecun_normal + lecun_uniform + normal ones + orthogonal uniform - normal variance_scaling - glorot_uniform - glorot_normal - lecun_uniform - lecun_normal - he_uniform - he_normal + zeros diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 6480fc67fb42..df65edaf50e7 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -18,7 +18,7 @@ """ -from functools import partial +from typing import Any, Callable, Sequence, Union import numpy as np @@ -29,22 +29,90 @@ from jax._src.util import prod from jax import dtypes -def zeros(key, shape, dtype=jnp.float_): return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype)) -def ones(key, shape, dtype=jnp.float_): return jnp.ones(shape, dtypes.canonicalize_dtype(dtype)) +DType = Any -def constant(value, dtype=jnp.float_): +def zeros(key, shape, dtype: DType = jnp.float_): + """An initializer that returns a constant array full of zeros. + + The ``key`` argument is ignored. + + >>> import jax, jax.numpy as jnp + >>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32) + DeviceArray([[0., 0., 0.], + [0., 0., 0.]], dtype=float32) + """ + return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype)) + +def ones(key, shape, dtype: DType = jnp.float_): + """An initializer that returns a constant array full of ones. + + The ``key`` argument is ignored. + + >>> import jax, jax.numpy as jnp + >>> jax.nn.initializers.ones(jax.random.PRNGKey(42), (3, 2), jnp.float32) + DeviceArray([[1., 1.], + [1., 1.], + [1., 1.]], dtype=float32) + """ + return jnp.ones(shape, dtypes.canonicalize_dtype(dtype)) + +def constant(value, dtype: DType = jnp.float_) -> Callable: + """Builds an initializer that returns arrays full of a constant ``value``. + + Args: + value: the constant value with which to fill the initializer. + dtype: optional; the initializer's default dtype. + + >>> import jax, jax.numpy as jnp + >>> initializer = jax.nn.initializers.constant(-7) + >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) + DeviceArray([[-7., -7., -7.], + [-7., -7., -7.]], dtype=float32) + """ def init(key, shape, dtype=dtype): dtype = dtypes.canonicalize_dtype(dtype) return jnp.full(shape, value, dtype=dtype) return init -def uniform(scale=1e-2, dtype=jnp.float_): +def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable: + """Builds an initializer that returns real uniformly-distributed random arrays. 
+ + Args: + scale: optional; the upper bound of the random distribution. + dtype: optional; the initializer's default dtype. + + Returns: + An initializer that returns arrays whose values are uniformly distributed in + the range ``[0, scale)``. + + >>> import jax, jax.numpy as jnp + >>> initializer = jax.nn.initializers.uniform(10.0) + >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) + DeviceArray([[7.298188 , 8.691938 , 8.7230015], + [2.0818567, 1.8662417, 5.5022564]], dtype=float32) + """ def init(key, shape, dtype=dtype): dtype = dtypes.canonicalize_dtype(dtype) return random.uniform(key, shape, dtype) * scale return init -def normal(stddev=1e-2, dtype=jnp.float_): +def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable: + """Builds an initializer that returns real normally-distributed random arrays. + + Args: + stddev: optional; the standard deviation of the distribution. + dtype: optional; the initializer's default dtype. + + Returns: + An initializer that returns arrays whose values are normally distributed + with mean ``0`` and standard deviation ``stddev``. + + >>> import jax, jax.numpy as jnp + >>> initializer = jax.nn.initializers.normal(5.0) + >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) + DeviceArray([[ 3.0613258 , 5.6129413 , 5.6866574 ], + [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32) + """ def init(key, shape, dtype=dtype): dtype = dtypes.canonicalize_dtype(dtype) return random.normal(key, shape, dtype) * stddev @@ -98,40 +166,49 @@ def _complex_truncated_normal(key, upper, shape, dtype): theta = 2 * jnp.pi * random.uniform(key_theta, shape, dtype) return r * jnp.exp(1j * theta) -def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, - batch_axis=(), dtype=jnp.float_): - """ - Initializer capable of adapting its scale to the shape of the weights tensor. +def variance_scaling(scale, mode: str, distribution: str, + in_axis: Union[int, Sequence[int]] = -2, + out_axis: Union[int, Sequence[int]] = -1, + batch_axis: Sequence[int] = (), + dtype: DType = jnp.float_) -> Callable: + r""" + Initializer that adapts its scale to the shape of the weights tensor. + + With ``distribution="truncated_normal"`` or ``distribution="normal"``, samples + are drawn from a (truncated) normal distribution with a mean of zero + and a standard deviation (after truncation, if applicable) of + :math:`\sqrt{\frac{scale}{n}}`, where `n` is: - With `distribution="truncated_normal" or "normal"`, samples are - drawn from a truncated/untruncated normal distribution with a mean of zero and - a standard deviation (after truncation, if used) `stddev = sqrt(scale / n)`, - where `n` is: - - number of input units in the weights tensor, if `mode="fan_in"` - - number of output units, if `mode="fan_out"` - - average of the numbers of input and output units, if `mode="fan_avg"` + * the number of input units in the weights tensor, if ``mode="fan_in"``, + * the number of output units, if ``mode="fan_out"``, or + * the average of the numbers of input and output units, if ``mode="fan_avg"``. - This initializer can be configured with in_axis, out_axis, and batch_axis to - work with general convolutional or dense layers; axes that are not in any of - those arguments are assumed to be the "receptive field" (convolution kernel - spatial axes). 
+ This initializer can be configured with ``in_axis``, ``out_axis``, and + ``batch_axis`` to work with general convolutional or dense layers; axes that + are not in any of those arguments are assumed to be the "receptive field" + (convolution kernel spatial axes). - With `distribution="truncated_normal"`, the absolute values of the samples are - truncated below 2 standard deviations before truncation. + With ``distribution="truncated_normal"``, the absolute values of the samples + are truncated at 2 standard deviations before scaling. - With `distribution="uniform"`, samples are drawn from: - - a uniform interval, if `dtype` is real - - a uniform disk, if `dtype` is complex - with a mean of zero and a standard deviation of `stddev`. + With ``distribution="uniform"``, samples are drawn from: + + * a uniform interval, if `dtype` is real, or + * a uniform disk, if `dtype` is complex, + + with a mean of zero and a standard deviation of ``stddev``. Args: scale: scaling factor (positive float). - mode: one of "fan_in", "fan_out", and "fan_avg". - distribution: random distribution to use. One of "truncated_normal", - "normal" and "uniform". - in_axis: axis or sequence of axes of the input dimension in the weights tensor. - out_axis: axis or sequence of axes of the output dimension in the weights tensor. - batch_axis: axis or sequence of axes in the weight tensor that should be ignored. + mode: one of ``"fan_in"``, ``"fan_out"``, and ``"fan_avg"``. + distribution: random distribution to use. One of ``"truncated_normal"``, + ``"normal"`` and ``"uniform"``. + in_axis: axis or sequence of axes of the input dimension in the weights + array. + out_axis: axis or sequence of axes of the output dimension in the weights + array. + batch_axis: axis or sequence of axes in the weight array that should be + ignored. dtype: the dtype of the weights. """ @@ -168,19 +245,250 @@ def init(key, shape, dtype=dtype): return init -xavier_uniform = glorot_uniform = partial(variance_scaling, 1.0, "fan_avg", "uniform") -xavier_normal = glorot_normal = partial(variance_scaling, 1.0, "fan_avg", "truncated_normal") -lecun_uniform = partial(variance_scaling, 1.0, "fan_in", "uniform") -lecun_normal = partial(variance_scaling, 1.0, "fan_in", "truncated_normal") -kaiming_uniform = he_uniform = partial(variance_scaling, 2.0, "fan_in", "uniform") -kaiming_normal = he_normal = partial(variance_scaling, 2.0, "fan_in", "truncated_normal") +def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2, + out_axis: Union[int, Sequence[int]] = -1, + batch_axis: Sequence[int] = (), + dtype: DType = jnp.float_) -> Callable: + """Builds a Glorot uniform initializer (aka Xavier uniform initializer). + + A `Glorot uniform initializer`_ is a specialization of + :func:`jax.nn.initializers.variance_scaling` where ``scale = 1.0``, + ``mode="fan_avg"``, and ``distribution="uniform"``. -def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float_): + Args: + in_axis: axis or sequence of axes of the input dimension in the weights + array. + out_axis: axis or sequence of axes of the output dimension in the weights + array. + batch_axis: axis or sequence of axes in the weight array that should be + ignored. + dtype: the dtype of the weights. + + Returns: + An initializer. + + Example: + + >>> import jax, jax.numpy as jnp + >>> initializer = jax.nn.initializers.glorot_uniform() + >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) + DeviceArray([[ 0.50350785, 0.8088631 , 0.81566876], + [-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32) + + .. 
_Glorot uniform initializer: http://proceedings.mlr.press/v9/glorot10a.html + """ + return variance_scaling(1.0, "fan_avg", "uniform", in_axis=in_axis, + out_axis=out_axis, batch_axis=batch_axis, dtype=dtype) + +xavier_uniform = glorot_uniform + + +def glorot_normal(in_axis: Union[int, Sequence[int]] = -2, + out_axis: Union[int, Sequence[int]] = -1, + batch_axis: Sequence[int] = (), + dtype: DType = jnp.float_) -> Callable: + """Builds a Glorot normal initializer (aka Xavier normal initializer). + + A `Glorot normal initializer`_ is a specialization of + :func:`jax.nn.initializers.variance_scaling` where ``scale = 1.0``, + ``mode="fan_avg"``, and ``distribution="truncated_normal"``. + + Args: + in_axis: axis or sequence of axes of the input dimension in the weights + array. + out_axis: axis or sequence of axes of the output dimension in the weights + array. + batch_axis: axis or sequence of axes in the weight array that should be + ignored. + dtype: the dtype of the weights. + + Returns: + An initializer. + + Example: + + >>> import jax, jax.numpy as jnp + >>> initializer = jax.nn.initializers.glorot_normal() + >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) + DeviceArray([[ 0.41770416, 0.75262755, 0.7619329 ], + [-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32) + + .. _Glorot normal initializer: http://proceedings.mlr.press/v9/glorot10a.html """ - Construct an initializer for uniformly distributed orthogonal matrices. + return variance_scaling(1.0, "fan_avg", "truncated_normal", in_axis=in_axis, + out_axis=out_axis, batch_axis=batch_axis, dtype=dtype) + +xavier_normal = glorot_normal + +def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2, + out_axis: Union[int, Sequence[int]] = -1, + batch_axis: Sequence[int] = (), + dtype: DType = jnp.float_) -> Callable: + """Builds a Lecun uniform initializer. + + A `Lecun uniform initializer`_ is a specialization of + :func:`jax.nn.initializers.variance_scaling` where ``scale = 1.0``, + ``mode="fan_in"``, and ``distribution="uniform"``. + + Args: + in_axis: axis or sequence of axes of the input dimension in the weights + array. + out_axis: axis or sequence of axes of the output dimension in the weights + array. + batch_axis: axis or sequence of axes in the weight array that should be + ignored. + dtype: the dtype of the weights. + + Returns: + An initializer. + + Example: + + >>> import jax, jax.numpy as jnp + >>> initializer = jax.nn.initializers.lecun_uniform() + >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) + DeviceArray([[ 0.56293887, 0.90433645, 0.9119454 ], + [-0.71479625, -0.7676109 , 0.12302713]], dtype=float32) + + .. _Lecun uniform initializer: https://arxiv.org/abs/1706.02515 + """ + return variance_scaling(1.0, "fan_in", "uniform", in_axis=in_axis, + out_axis=out_axis, batch_axis=batch_axis, dtype=dtype) + +def lecun_normal(in_axis: Union[int, Sequence[int]] = -2, + out_axis: Union[int, Sequence[int]] = -1, + batch_axis: Sequence[int] = (), + dtype: DType = jnp.float_) -> Callable: + """Builds a Lecun normal initializer. + + A `Lecun normal initializer`_ is a specialization of + :func:`jax.nn.initializers.variance_scaling` where ``scale = 1.0``, + ``mode="fan_in"``, and ``distribution="truncated_normal"``. + + Args: + in_axis: axis or sequence of axes of the input dimension in the weights + array. + out_axis: axis or sequence of axes of the output dimension in the weights + array. + batch_axis: axis or sequence of axes in the weight array that should be + ignored. + dtype: the dtype of the weights. 
+
+  Returns:
+    An initializer.
+
+  Example:
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.lecun_normal()
+  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  DeviceArray([[ 0.46700746,  0.8414632 ,  0.8518669 ],
+               [-0.61677957, -0.67402434,  0.09683388]], dtype=float32)
+
+  .. _Lecun normal initializer: https://arxiv.org/abs/1706.02515
+  """
+  return variance_scaling(1.0, "fan_in", "truncated_normal", in_axis=in_axis,
+                          out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
+
+
+def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
+               out_axis: Union[int, Sequence[int]] = -1,
+               batch_axis: Sequence[int] = (),
+               dtype: DType = jnp.float_) -> Callable:
+  """Builds a He uniform initializer (aka Kaiming uniform initializer).
+
+  A `He uniform initializer`_ is a specialization of
+  :func:`jax.nn.initializers.variance_scaling` where ``scale = 2.0``,
+  ``mode="fan_in"``, and ``distribution="uniform"``.
+
+  Args:
+    in_axis: axis or sequence of axes of the input dimension in the weights
+      array.
+    out_axis: axis or sequence of axes of the output dimension in the weights
+      array.
+    batch_axis: axis or sequence of axes in the weight array that should be
+      ignored.
+    dtype: the dtype of the weights.
+
+  Returns:
+    An initializer.
+
+  Example:
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.kaiming_uniform()
+  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  DeviceArray([[ 0.79611576,  1.2789248 ,  1.2896855 ],
+               [-1.0108745 , -1.0855657 ,  0.17398663]], dtype=float32)
+
+  .. _He uniform initializer: https://arxiv.org/abs/1502.01852
+  """
+  return variance_scaling(2.0, "fan_in", "uniform", in_axis=in_axis,
+                          out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
+
+kaiming_uniform = he_uniform
+
+
+def he_normal(in_axis: Union[int, Sequence[int]] = -2,
+              out_axis: Union[int, Sequence[int]] = -1,
+              batch_axis: Sequence[int] = (),
+              dtype: DType = jnp.float_) -> Callable:
+  """Builds a He normal initializer (aka Kaiming normal initializer).
+
+  A `He normal initializer`_ is a specialization of
+  :func:`jax.nn.initializers.variance_scaling` where ``scale = 2.0``,
+  ``mode="fan_in"``, and ``distribution="truncated_normal"``.
+
+  Args:
+    in_axis: axis or sequence of axes of the input dimension in the weights
+      array.
+    out_axis: axis or sequence of axes of the output dimension in the weights
+      array.
+    batch_axis: axis or sequence of axes in the weight array that should be
+      ignored.
+    dtype: the dtype of the weights.
+
+  Returns:
+    An initializer.
+
+  Example:
+
+  >>> import jax, jax.numpy as jnp
+  >>> initializer = jax.nn.initializers.kaiming_normal()
+  >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
+  DeviceArray([[ 0.6604483 ,  1.1900088 ,  1.2047218 ],
+               [-0.87225807, -0.95321447,  0.1369438 ]], dtype=float32)
+
+  .. _He normal initializer: https://arxiv.org/abs/1502.01852
+  """
+  return variance_scaling(2.0, "fan_in", "truncated_normal", in_axis=in_axis,
+                          out_axis=out_axis, batch_axis=batch_axis, dtype=dtype)
+
+kaiming_normal = he_normal
+
+
+def orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
+  """
+  Builds an initializer that returns uniformly distributed orthogonal matrices.
 
   If the shape is not square, the matrices will have orthonormal rows or
   columns depending on which side is smaller.
+
+  Args:
+    scale: the upper bound of the uniform distribution.
+    column_axis: the axis that contains the columns that should be orthogonal.
+    dtype: the default dtype of the weights.
+ + Returns: + An orthogonal initializer. + + Example: + + >>> import jax, jax.numpy as jnp + >>> initializer = jax.nn.initializers.orthogonal() + >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + DeviceArray([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01], + [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32) """ def init(key, shape, dtype=dtype): dtype = dtypes.canonicalize_dtype(dtype) @@ -199,11 +507,38 @@ def init(key, shape, dtype=dtype): return init -def delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float_): +def delta_orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_): """ - Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393. + Builds an initializer for delta orthogonal kernels. + + Args: + scale: the upper bound of the uniform distribution. + column_axis: the axis that contains the columns that should be orthogonal. + dtype: the default dtype of the weights. + + Returns: + A `delta orthogonal initializer`_. The shape passed to the initializer must + be 3D, 4D, or 5D. + + Example: + + >>> import jax, jax.numpy as jnp + >>> initializer = jax.nn.initializers.delta_orthogonal() + >>> initializer(jax.random.PRNGKey(42), (3, 3, 3), jnp.float32) # doctest: +SKIP + DeviceArray([[[ 0. , 0. , 0. ], + [ 0. , 0. , 0. ], + [ 0. , 0. , 0. ]], + + [[ 0.27858758, -0.7949833 , -0.53887904], + [ 0.9120717 , 0.04322892, 0.40774566], + [-0.30085585, -0.6050892 , 0.73712474]], + + [[ 0. , 0. , 0. ], + [ 0. , 0. , 0. ], + [ 0. , 0. , 0. ]]], dtype=float32) + - The shape must be 3D, 4D or 5D. + .. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393 """ def init(key, shape, dtype=dtype): dtype = dtypes.canonicalize_dtype(dtype) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a203d2c0e866..bc7160e8eef3 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -146,7 +146,7 @@ def _make_scalar_type(np_scalar_type): int_ = int32 if dtypes.int_ == np.int32 else int64 uint = uint32 if dtypes.uint == np.uint32 else uint64 -float_ = float32 if dtypes.float_ == np.float32 else float64 +float_: Any = float32 if dtypes.float_ == np.float32 else float64 complex_ = complex64 if dtypes.complex_ == np.complex64 else complex128 generic = np.generic
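
The calling convention documented in docs/jax.nn.initializers.rst above, an initializer invoked as ``(key, shape, dtype)``, can be exercised end to end with the short sketch below. The sketch is illustrative only: the seed and the layer shape are arbitrary values chosen for the example, not anything taken from the patch.

import jax
import jax.numpy as jnp

# Every initializer in jax.nn.initializers follows the same signature:
# it takes a PRNG key, an output shape, and a dtype, and returns an array
# of that shape and dtype.
key = jax.random.PRNGKey(0)                  # arbitrary seed, for illustration
init = jax.nn.initializers.glorot_uniform()
w = init(key, (784, 128), jnp.float32)       # hypothetical dense-layer shape
assert w.shape == (784, 128) and w.dtype == jnp.float32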
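
The named initializers above are thin wrappers around ``variance_scaling``, which itself carries no doctest in this patch. A minimal sketch of that relationship, again using an arbitrary key and shape:

import jax
import jax.numpy as jnp

# Per the he_normal docstring, he_normal() is variance_scaling(2.0, "fan_in",
# "truncated_normal") with default axes, so the two builders return identical
# samples when given the same key, shape, and dtype.
key = jax.random.PRNGKey(42)
generic = jax.nn.initializers.variance_scaling(2.0, "fan_in", "truncated_normal")
named = jax.nn.initializers.he_normal()
a = generic(key, (3, 4), jnp.float32)
b = named(key, (3, 4), jnp.float32)
assert (a == b).all()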