diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 4dc9c1a38481..f1b56414a1d4 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -18,10 +18,9 @@ """ -from typing import Any, Sequence, Tuple, Union +from typing import Any, Callable, Sequence, Union import numpy as np -from typing_extensions import Literal, Protocol import jax.numpy as jnp from jax import lax @@ -30,25 +29,9 @@ from jax._src.util import prod from jax._src import dtypes -KeyArray = random.KeyArray -Array = Any -# TODO: Import or define these to match -# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py. -DTypeLikeFloat = Any -DTypeLikeComplex = Any -DTypeLikeInexact = Any # DTypeLikeFloat | DTypeLikeComplex -RealNumeric = Any # Scalar jnp array or float - -class Initializer(Protocol): - @staticmethod - def __call__(key: KeyArray, - shape: core.Shape, - dtype: DTypeLikeInexact = jnp.float_) -> Array: - ... - -def zeros(key: KeyArray, - shape: core.Shape, - dtype: DTypeLikeInexact = jnp.float_) -> Array: +DType = Any + +def zeros(key, shape, dtype: DType = jnp.float_): """An initializer that returns a constant array full of zeros. The ``key`` argument is ignored. @@ -60,9 +43,7 @@ def zeros(key: KeyArray, """ return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype)) -def ones(key: KeyArray, - shape: core.Shape, - dtype: DTypeLikeInexact = jnp.float_) -> Array: +def ones(key, shape, dtype: DType = jnp.float_): """An initializer that returns a constant array full of ones. The ``key`` argument is ignored. @@ -75,9 +56,7 @@ def ones(key: KeyArray, """ return jnp.ones(shape, dtypes.canonicalize_dtype(dtype)) -def constant(value: Array, - dtype: DTypeLikeInexact = jnp.float_ - ) -> Initializer: +def constant(value, dtype: DType = jnp.float_) -> Callable: """Builds an initializer that returns arrays full of a constant ``value``. Args: @@ -90,15 +69,12 @@ def constant(value: Array, DeviceArray([[-7., -7., -7.], [-7., -7., -7.]], dtype=float32) """ - def init(key: KeyArray, - shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + def init(key, shape, dtype=dtype): dtype = dtypes.canonicalize_dtype(dtype) return jnp.full(shape, value, dtype=dtype) return init -def uniform(scale: RealNumeric = 1e-2, - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: +def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable: """Builds an initializer that returns real uniformly-distributed random arrays. Args: @@ -115,15 +91,12 @@ def uniform(scale: RealNumeric = 1e-2, DeviceArray([[7.298188 , 8.691938 , 8.7230015], [2.0818567, 1.8662417, 5.5022564]], dtype=float32) """ - def init(key: KeyArray, - shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + def init(key, shape, dtype=dtype): dtype = dtypes.canonicalize_dtype(dtype) return random.uniform(key, shape, dtype) * scale return init -def normal(stddev: RealNumeric = 1e-2, - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: +def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable: """Builds an initializer that returns real normally-distributed random arrays. Args: @@ -140,18 +113,13 @@ def normal(stddev: RealNumeric = 1e-2, DeviceArray([[ 3.0613258 , 5.6129413 , 5.6866574 ], [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32) """ - def init(key: KeyArray, - shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + def init(key, shape, dtype=dtype): dtype = dtypes.canonicalize_dtype(dtype) return random.normal(key, shape, dtype) * stddev return init -def _compute_fans(shape: core.NamedShape, - in_axis: Union[int, Sequence[int]] = -2, - out_axis: Union[int, Sequence[int]] = -1, - batch_axis: Union[int, Sequence[int]] = () - ) -> Tuple[Array, Array]: +def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1, + batch_axis=()): """ Compute effective input and output sizes for a linear or convolutional layer. @@ -175,9 +143,7 @@ def _compute_fans(shape: core.NamedShape, fan_out = out_size * receptive_field_size return fan_in, fan_out -def _complex_uniform(key: KeyArray, - shape: core.NamedShape, - dtype: DTypeLikeInexact) -> Array: +def _complex_uniform(key, shape, dtype): """ Sample uniform random values within a disk on the complex plane, with zero mean and unit variance. @@ -189,33 +155,24 @@ def _complex_uniform(key: KeyArray, theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype) return r * jnp.exp(1j * theta) -def _complex_truncated_normal(key: KeyArray, upper: Array, - shape: core.NamedShape, - dtype: DTypeLikeInexact) -> Array: +def _complex_truncated_normal(key, upper, shape, dtype): """ Sample random values from a centered normal distribution on the complex plane, - whose modulus is truncated to `upper`, and the variance before the truncation - is one. + whose modulus is truncated to `upper`, and the variance before the truncation is one. """ key_r, key_theta = random.split(key) real_dtype = np.array(0, dtype).real.dtype dtype = dtypes._to_complex_dtype(real_dtype) - t = ((1 - jnp.exp(jnp.array(-(upper ** 2), dtype))) - * random.uniform(key_r, shape, real_dtype).astype(dtype)) + t = (1 - jnp.exp(jnp.array(-(upper ** 2), dtype))) * random.uniform(key_r, shape, real_dtype).astype(dtype) r = jnp.sqrt(-jnp.log(1 - t)) theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype) return r * jnp.exp(1j * theta) -def variance_scaling( - scale: RealNumeric, - mode: Union[Literal["fan_in"], Literal["fan_out"], Literal["fan_avg"]], - distribution: Union[Literal["truncated_normal"], Literal["normal"], - Literal["uniform"]], - in_axis: Union[int, Sequence[int]] = -2, - out_axis: Union[int, Sequence[int]] = -1, - batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_ -) -> Initializer: +def variance_scaling(scale, mode: str, distribution: str, + in_axis: Union[int, Sequence[int]] = -2, + out_axis: Union[int, Sequence[int]] = -1, + batch_axis: Sequence[int] = (), + dtype: DType = jnp.float_) -> Callable: r""" Initializer that adapts its scale to the shape of the weights tensor. @@ -257,12 +214,10 @@ def variance_scaling( dtype: the dtype of the weights. """ - def init(key: KeyArray, - shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + def init(key, shape, dtype=dtype): dtype = dtypes.canonicalize_dtype(dtype) - named_shape = core.as_named_shape(shape) - fan_in, fan_out = _compute_fans(named_shape, in_axis, out_axis, batch_axis) + shape = core.as_named_shape(shape) + fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis) if mode == "fan_in": denominator = fan_in elif mode == "fan_out": denominator = fan_out elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2 @@ -275,18 +230,18 @@ def init(key: KeyArray, if jnp.issubdtype(dtype, jnp.floating): # constant is stddev of standard normal truncated to (-2, 2) stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype) - return random.truncated_normal(key, -2, 2, named_shape, dtype) * stddev + return random.truncated_normal(key, -2, 2, shape, dtype) * stddev else: # constant is stddev of complex standard normal truncated to 2 stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype) - return _complex_truncated_normal(key, 2, named_shape, dtype) * stddev + return _complex_truncated_normal(key, 2, shape, dtype) * stddev elif distribution == "normal": - return random.normal(key, named_shape, dtype) * jnp.sqrt(variance) + return random.normal(key, shape, dtype) * jnp.sqrt(variance) elif distribution == "uniform": if jnp.issubdtype(dtype, jnp.floating): - return random.uniform(key, named_shape, dtype, -1) * jnp.sqrt(3 * variance) + return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance) else: - return _complex_uniform(key, named_shape, dtype) * jnp.sqrt(variance) + return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance) else: raise ValueError(f"invalid distribution for variance scaling initializer: {distribution}") @@ -295,7 +250,7 @@ def init(key: KeyArray, def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2, out_axis: Union[int, Sequence[int]] = -1, batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + dtype: DType = jnp.float_) -> Callable: """Builds a Glorot uniform initializer (aka Xavier uniform initializer). A `Glorot uniform initializer`_ is a specialization of @@ -333,7 +288,7 @@ def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2, def glorot_normal(in_axis: Union[int, Sequence[int]] = -2, out_axis: Union[int, Sequence[int]] = -1, batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + dtype: DType = jnp.float_) -> Callable: """Builds a Glorot normal initializer (aka Xavier normal initializer). A `Glorot normal initializer`_ is a specialization of @@ -370,7 +325,7 @@ def glorot_normal(in_axis: Union[int, Sequence[int]] = -2, def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2, out_axis: Union[int, Sequence[int]] = -1, batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + dtype: DType = jnp.float_) -> Callable: """Builds a Lecun uniform initializer. A `Lecun uniform initializer`_ is a specialization of @@ -405,7 +360,7 @@ def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2, def lecun_normal(in_axis: Union[int, Sequence[int]] = -2, out_axis: Union[int, Sequence[int]] = -1, batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + dtype: DType = jnp.float_) -> Callable: """Builds a Lecun normal initializer. A `Lecun normal initializer`_ is a specialization of @@ -441,7 +396,7 @@ def lecun_normal(in_axis: Union[int, Sequence[int]] = -2, def he_uniform(in_axis: Union[int, Sequence[int]] = -2, out_axis: Union[int, Sequence[int]] = -1, batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + dtype: DType = jnp.float_) -> Callable: """Builds a He uniform initializer (aka Kaiming uniform initializer). A `He uniform initializer`_ is a specialization of @@ -479,7 +434,7 @@ def he_uniform(in_axis: Union[int, Sequence[int]] = -2, def he_normal(in_axis: Union[int, Sequence[int]] = -2, out_axis: Union[int, Sequence[int]] = -1, batch_axis: Sequence[int] = (), - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: + dtype: DType = jnp.float_) -> Callable: """Builds a He normal initializer (aka Kaiming normal initializer). A `He normal initializer`_ is a specialization of @@ -514,9 +469,7 @@ def he_normal(in_axis: Union[int, Sequence[int]] = -2, kaiming_normal = he_normal -def orthogonal(scale: RealNumeric = 1.0, - column_axis: int = -1, - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: +def orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_): """ Builds an initializer that returns uniformly distributed orthogonal matrices. @@ -539,9 +492,7 @@ def orthogonal(scale: RealNumeric = 1.0, DeviceArray([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01], [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32) """ - def init(key: KeyArray, - shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + def init(key, shape, dtype=dtype): dtype = dtypes.canonicalize_dtype(dtype) if len(shape) < 2: raise ValueError("orthogonal initializer requires at least a 2D shape") @@ -558,10 +509,7 @@ def init(key: KeyArray, return init -def delta_orthogonal( - scale: RealNumeric = 1.0, - column_axis: int = -1, - dtype: DTypeLikeInexact = jnp.float_) -> Initializer: +def delta_orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_): """ Builds an initializer for delta orthogonal kernels. @@ -594,9 +542,7 @@ def delta_orthogonal( .. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393 """ - def init(key: KeyArray, - shape: core.Shape, - dtype: DTypeLikeInexact = dtype) -> Array: + def init(key, shape, dtype=dtype): dtype = dtypes.canonicalize_dtype(dtype) if len(shape) not in [3, 4, 5]: raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D " diff --git a/jax/nn/initializers.py b/jax/nn/initializers.py index 0ae9cd1a2f7f..1a3b84f4b596 100644 --- a/jax/nn/initializers.py +++ b/jax/nn/initializers.py @@ -19,7 +19,6 @@ from jax._src.nn.initializers import ( constant as constant, - Initializer as Initializer, delta_orthogonal as delta_orthogonal, glorot_normal as glorot_normal, glorot_uniform as glorot_uniform,