From 7869ff4964b6917f9d1f23d2c2b13f2303e391cd Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Thu, 4 Aug 2022 18:43:17 -0400 Subject: [PATCH] Annotate nn.initializers This was done to expose an Initializers type annotation that can be used in other libraries. --- jax/_src/nn/initializers.py | 134 +++++++++++++++++++++++++----------- jax/nn/initializers.py | 1 + 2 files changed, 95 insertions(+), 40 deletions(-) diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index f1b56414a1d4..00b3a17e66c3 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -18,9 +18,10 @@ """ -from typing import Any, Callable, Sequence, Union +from typing import Any, Sequence, Tuple, Union import numpy as np +from typing_extensions import Literal, Protocol import jax.numpy as jnp from jax import lax @@ -29,9 +30,25 @@ from jax._src.util import prod from jax._src import dtypes -DType = Any - -def zeros(key, shape, dtype: DType = jnp.float_): +KeyArray = random.KeyArray +Array = Any +# TODO: Import or define these to match +# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py. +DTypeLikeFloat = Any +DTypeLikeComplex = Any +DTypeLikeInexact = Any # DTypeLikeFloat | DTypeLikeComplex +RealNumeric = Any # Scalar jnp array or float + +class Initializer(Protocol): + @staticmethod + def __call__(key: KeyArray, + shape: core.Shape, + dtype: DTypeLikeInexact = jnp.float_) -> Array: + ... + +def zeros(key: KeyArray, + shape: core.Shape, + dtype: DTypeLikeInexact = jnp.float_) -> Array: """An initializer that returns a constant array full of zeros. The ``key`` argument is ignored. @@ -43,7 +60,9 @@ def zeros(key, shape, dtype: DType = jnp.float_): """ return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype)) -def ones(key, shape, dtype: DType = jnp.float_): +def ones(key: KeyArray, + shape: core.Shape, + dtype: DTypeLikeInexact = jnp.float_) -> Array: """An initializer that returns a constant array full of ones. The ``key`` argument is ignored. @@ -56,7 +75,9 @@ def ones(key, shape, dtype: DType = jnp.float_): """ return jnp.ones(shape, dtypes.canonicalize_dtype(dtype)) -def constant(value, dtype: DType = jnp.float_) -> Callable: +def constant(value: Array, + dtype: DTypeLikeInexact = jnp.float_ + ) -> Initializer: """Builds an initializer that returns arrays full of a constant ``value``. Args: @@ -69,12 +90,15 @@ def constant(value, dtype: DType = jnp.float_) -> Callable: DeviceArray([[-7., -7., -7.], [-7., -7., -7.]], dtype=float32) """ - def init(key, shape, dtype=dtype): + def init(key: KeyArray, + shape: core.Shape, + dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) return jnp.full(shape, value, dtype=dtype) return init -def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable: +def uniform(scale: RealNumeric = 1e-2, + dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds an initializer that returns real uniformly-distributed random arrays. Args: @@ -91,12 +115,15 @@ def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable: DeviceArray([[7.298188 , 8.691938 , 8.7230015], [2.0818567, 1.8662417, 5.5022564]], dtype=float32) """ - def init(key, shape, dtype=dtype): + def init(key: KeyArray, + shape: core.Shape, + dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) return random.uniform(key, shape, dtype) * scale return init -def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable: +def normal(stddev: RealNumeric = 1e-2, + dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds an initializer that returns real normally-distributed random arrays. Args: @@ -113,13 +140,18 @@ def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable: DeviceArray([[ 3.0613258 , 5.6129413 , 5.6866574 ], [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32) """ - def init(key, shape, dtype=dtype): + def init(key: KeyArray, + shape: core.Shape, + dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) return random.normal(key, shape, dtype) * stddev return init -def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1, - batch_axis=()): +def _compute_fans(shape: core.NamedShape, + in_axis: Union[int, Sequence[int]] = -2, + out_axis: Union[int, Sequence[int]] = -1, + batch_axis: Union[int, Sequence[int]] = () + ) -> Tuple[Array, Array]: """ Compute effective input and output sizes for a linear or convolutional layer. @@ -143,7 +175,9 @@ def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1, fan_out = out_size * receptive_field_size return fan_in, fan_out -def _complex_uniform(key, shape, dtype): +def _complex_uniform(key: KeyArray, + shape: Union[Sequence[int], core.NamedShape], + dtype: DTypeLikeInexact) -> Array: """ Sample uniform random values within a disk on the complex plane, with zero mean and unit variance. @@ -155,24 +189,33 @@ def _complex_uniform(key, shape, dtype): theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype) return r * jnp.exp(1j * theta) -def _complex_truncated_normal(key, upper, shape, dtype): +def _complex_truncated_normal(key: KeyArray, upper: Array, + shape: Union[Sequence[int], core.NamedShape], + dtype: DTypeLikeInexact) -> Array: """ Sample random values from a centered normal distribution on the complex plane, - whose modulus is truncated to `upper`, and the variance before the truncation is one. + whose modulus is truncated to `upper`, and the variance before the truncation + is one. """ key_r, key_theta = random.split(key) real_dtype = np.array(0, dtype).real.dtype dtype = dtypes._to_complex_dtype(real_dtype) - t = (1 - jnp.exp(jnp.array(-(upper ** 2), dtype))) * random.uniform(key_r, shape, real_dtype).astype(dtype) + t = ((1 - jnp.exp(jnp.array(-(upper ** 2), dtype))) + * random.uniform(key_r, shape, real_dtype).astype(dtype)) r = jnp.sqrt(-jnp.log(1 - t)) theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype) return r * jnp.exp(1j * theta) -def variance_scaling(scale, mode: str, distribution: str, - in_axis: Union[int, Sequence[int]] = -2, - out_axis: Union[int, Sequence[int]] = -1, - batch_axis: Sequence[int] = (), - dtype: DType = jnp.float_) -> Callable: +def variance_scaling( + scale: RealNumeric, + mode: Union[Literal["fan_in"], Literal["fan_out"], Literal["fan_avg"]], + distribution: Union[Literal["truncated_normal"], Literal["normal"], + Literal["uniform"]], + in_axis: Union[int, Sequence[int]] = -2, + out_axis: Union[int, Sequence[int]] = -1, + batch_axis: Sequence[int] = (), + dtype: DTypeLikeInexact = jnp.float_ +) -> Initializer: r""" Initializer that adapts its scale to the shape of the weights tensor. @@ -214,10 +257,12 @@ def variance_scaling(scale, mode: str, distribution: str, dtype: the dtype of the weights. """ - def init(key, shape, dtype=dtype): + def init(key: KeyArray, + shape: core.Shape, + dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) - shape = core.as_named_shape(shape) - fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis) + named_shape = core.as_named_shape(shape) + fan_in, fan_out = _compute_fans(named_shape, in_axis, out_axis, batch_axis) if mode == "fan_in": denominator = fan_in elif mode == "fan_out": denominator = fan_out elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2 @@ -230,18 +275,18 @@ def init(key, shape, dtype=dtype): if jnp.issubdtype(dtype, jnp.floating): # constant is stddev of standard normal truncated to (-2, 2) stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype) - return random.truncated_normal(key, -2, 2, shape, dtype) * stddev + return random.truncated_normal(key, -2, 2, named_shape, dtype) * stddev else: # constant is stddev of complex standard normal truncated to 2 stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype) - return _complex_truncated_normal(key, 2, shape, dtype) * stddev + return _complex_truncated_normal(key, 2, named_shape, dtype) * stddev elif distribution == "normal": - return random.normal(key, shape, dtype) * jnp.sqrt(variance) + return random.normal(key, named_shape, dtype) * jnp.sqrt(variance) elif distribution == "uniform": if jnp.issubdtype(dtype, jnp.floating): - return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance) + return random.uniform(key, named_shape, dtype, -1) * jnp.sqrt(3 * variance) else: - return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance) + return _complex_uniform(key, named_shape, dtype) * jnp.sqrt(variance) else: raise ValueError(f"invalid distribution for variance scaling initializer: {distribution}") @@ -250,7 +295,7 @@ def init(key, shape, dtype=dtype): def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2, out_axis: Union[int, Sequence[int]] = -1, batch_axis: Sequence[int] = (), - dtype: DType = jnp.float_) -> Callable: + dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Glorot uniform initializer (aka Xavier uniform initializer). A `Glorot uniform initializer`_ is a specialization of @@ -288,7 +333,7 @@ def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2, def glorot_normal(in_axis: Union[int, Sequence[int]] = -2, out_axis: Union[int, Sequence[int]] = -1, batch_axis: Sequence[int] = (), - dtype: DType = jnp.float_) -> Callable: + dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Glorot normal initializer (aka Xavier normal initializer). A `Glorot normal initializer`_ is a specialization of @@ -325,7 +370,7 @@ def glorot_normal(in_axis: Union[int, Sequence[int]] = -2, def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2, out_axis: Union[int, Sequence[int]] = -1, batch_axis: Sequence[int] = (), - dtype: DType = jnp.float_) -> Callable: + dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Lecun uniform initializer. A `Lecun uniform initializer`_ is a specialization of @@ -360,7 +405,7 @@ def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2, def lecun_normal(in_axis: Union[int, Sequence[int]] = -2, out_axis: Union[int, Sequence[int]] = -1, batch_axis: Sequence[int] = (), - dtype: DType = jnp.float_) -> Callable: + dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a Lecun normal initializer. A `Lecun normal initializer`_ is a specialization of @@ -396,7 +441,7 @@ def lecun_normal(in_axis: Union[int, Sequence[int]] = -2, def he_uniform(in_axis: Union[int, Sequence[int]] = -2, out_axis: Union[int, Sequence[int]] = -1, batch_axis: Sequence[int] = (), - dtype: DType = jnp.float_) -> Callable: + dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a He uniform initializer (aka Kaiming uniform initializer). A `He uniform initializer`_ is a specialization of @@ -434,7 +479,7 @@ def he_uniform(in_axis: Union[int, Sequence[int]] = -2, def he_normal(in_axis: Union[int, Sequence[int]] = -2, out_axis: Union[int, Sequence[int]] = -1, batch_axis: Sequence[int] = (), - dtype: DType = jnp.float_) -> Callable: + dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """Builds a He normal initializer (aka Kaiming normal initializer). A `He normal initializer`_ is a specialization of @@ -469,7 +514,9 @@ def he_normal(in_axis: Union[int, Sequence[int]] = -2, kaiming_normal = he_normal -def orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_): +def orthogonal(scale: RealNumeric = 1.0, + column_axis: int = -1, + dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """ Builds an initializer that returns uniformly distributed orthogonal matrices. @@ -492,7 +539,9 @@ def orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_): DeviceArray([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01], [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32) """ - def init(key, shape, dtype=dtype): + def init(key: KeyArray, + shape: core.Shape, + dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) if len(shape) < 2: raise ValueError("orthogonal initializer requires at least a 2D shape") @@ -509,7 +558,10 @@ def init(key, shape, dtype=dtype): return init -def delta_orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_): +def delta_orthogonal( + scale: RealNumeric = 1.0, + column_axis: int = -1, + dtype: DTypeLikeInexact = jnp.float_) -> Initializer: """ Builds an initializer for delta orthogonal kernels. @@ -542,7 +594,9 @@ def delta_orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_): .. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393 """ - def init(key, shape, dtype=dtype): + def init(key: KeyArray, + shape: core.Shape, + dtype: DTypeLikeInexact = dtype) -> Array: dtype = dtypes.canonicalize_dtype(dtype) if len(shape) not in [3, 4, 5]: raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D " diff --git a/jax/nn/initializers.py b/jax/nn/initializers.py index 1a3b84f4b596..0ae9cd1a2f7f 100644 --- a/jax/nn/initializers.py +++ b/jax/nn/initializers.py @@ -19,6 +19,7 @@ from jax._src.nn.initializers import ( constant as constant, + Initializer as Initializer, delta_orthogonal as delta_orthogonal, glorot_normal as glorot_normal, glorot_uniform as glorot_uniform,