Skip to content

Commit

Permalink
Annotate nn.initializers
Browse files Browse the repository at this point in the history
This was done to expose an Initializer type annotation that can be used
in other libraries.
  • Loading branch information
NeilGirdhar committed Aug 5, 2022
1 parent 07da502 commit 7869ff4
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 40 deletions.
134 changes: 94 additions & 40 deletions jax/_src/nn/initializers.py
Expand Up @@ -18,9 +18,10 @@
"""


from typing import Any, Callable, Sequence, Union
from typing import Any, Sequence, Tuple, Union

import numpy as np
from typing_extensions import Literal, Protocol

import jax.numpy as jnp
from jax import lax
Expand All @@ -29,9 +30,25 @@
from jax._src.util import prod
from jax._src import dtypes

DType = Any

def zeros(key, shape, dtype: DType = jnp.float_):
KeyArray = random.KeyArray
Array = Any
# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
DTypeLikeFloat = Any
DTypeLikeComplex = Any
DTypeLikeInexact = Any # DTypeLikeFloat | DTypeLikeComplex
RealNumeric = Any # Scalar jnp array or float

class Initializer(Protocol):
@staticmethod
def __call__(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
...

def zeros(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
"""An initializer that returns a constant array full of zeros.
The ``key`` argument is ignored.
Expand All @@ -43,7 +60,9 @@ def zeros(key, shape, dtype: DType = jnp.float_):
"""
return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))

def ones(key, shape, dtype: DType = jnp.float_):
def ones(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
"""An initializer that returns a constant array full of ones.
The ``key`` argument is ignored.
Expand All @@ -56,7 +75,9 @@ def ones(key, shape, dtype: DType = jnp.float_):
"""
return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))

def constant(value, dtype: DType = jnp.float_) -> Callable:
def constant(value: Array,
dtype: DTypeLikeInexact = jnp.float_
) -> Initializer:
"""Builds an initializer that returns arrays full of a constant ``value``.
Args:
Expand All @@ -69,12 +90,15 @@ def constant(value, dtype: DType = jnp.float_) -> Callable:
DeviceArray([[-7., -7., -7.],
[-7., -7., -7.]], dtype=float32)
"""
def init(key, shape, dtype=dtype):
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
return jnp.full(shape, value, dtype=dtype)
return init

def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable:
def uniform(scale: RealNumeric = 1e-2,
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds an initializer that returns real uniformly-distributed random arrays.
Args:
Expand All @@ -91,12 +115,15 @@ def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable:
DeviceArray([[7.298188 , 8.691938 , 8.7230015],
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
"""
def init(key, shape, dtype=dtype):
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
return random.uniform(key, shape, dtype) * scale
return init

def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable:
def normal(stddev: RealNumeric = 1e-2,
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds an initializer that returns real normally-distributed random arrays.
Args:
Expand All @@ -113,13 +140,18 @@ def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable:
DeviceArray([[ 3.0613258 , 5.6129413 , 5.6866574 ],
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
"""
def init(key, shape, dtype=dtype):
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
return random.normal(key, shape, dtype) * stddev
return init

def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1,
batch_axis=()):
def _compute_fans(shape: core.NamedShape,
in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Union[int, Sequence[int]] = ()
) -> Tuple[Array, Array]:
"""
Compute effective input and output sizes for a linear or convolutional layer.
Expand All @@ -143,7 +175,9 @@ def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1,
fan_out = out_size * receptive_field_size
return fan_in, fan_out

def _complex_uniform(key, shape, dtype):
def _complex_uniform(key: KeyArray,
shape: Union[Sequence[int], core.NamedShape],
dtype: DTypeLikeInexact) -> Array:
"""
Sample uniform random values within a disk on the complex plane,
with zero mean and unit variance.
Expand All @@ -155,24 +189,33 @@ def _complex_uniform(key, shape, dtype):
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
return r * jnp.exp(1j * theta)

def _complex_truncated_normal(key, upper, shape, dtype):
def _complex_truncated_normal(key: KeyArray, upper: Array,
shape: Union[Sequence[int], core.NamedShape],
dtype: DTypeLikeInexact) -> Array:
"""
Sample random values from a centered normal distribution on the complex plane,
whose modulus is truncated to `upper`, and the variance before the truncation is one.
whose modulus is truncated to `upper`, and the variance before the truncation
is one.
"""
key_r, key_theta = random.split(key)
real_dtype = np.array(0, dtype).real.dtype
dtype = dtypes._to_complex_dtype(real_dtype)
t = (1 - jnp.exp(jnp.array(-(upper ** 2), dtype))) * random.uniform(key_r, shape, real_dtype).astype(dtype)
t = ((1 - jnp.exp(jnp.array(-(upper ** 2), dtype)))
* random.uniform(key_r, shape, real_dtype).astype(dtype))
r = jnp.sqrt(-jnp.log(1 - t))
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
return r * jnp.exp(1j * theta)

def variance_scaling(scale, mode: str, distribution: str,
in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
def variance_scaling(
scale: RealNumeric,
mode: Union[Literal["fan_in"], Literal["fan_out"], Literal["fan_avg"]],
distribution: Union[Literal["truncated_normal"], Literal["normal"],
Literal["uniform"]],
in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DTypeLikeInexact = jnp.float_
) -> Initializer:
r"""
Initializer that adapts its scale to the shape of the weights tensor.
Expand Down Expand Up @@ -214,10 +257,12 @@ def variance_scaling(scale, mode: str, distribution: str,
dtype: the dtype of the weights.
"""

def init(key, shape, dtype=dtype):
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
shape = core.as_named_shape(shape)
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis)
named_shape = core.as_named_shape(shape)
fan_in, fan_out = _compute_fans(named_shape, in_axis, out_axis, batch_axis)
if mode == "fan_in": denominator = fan_in
elif mode == "fan_out": denominator = fan_out
elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2
Expand All @@ -230,18 +275,18 @@ def init(key, shape, dtype=dtype):
if jnp.issubdtype(dtype, jnp.floating):
# constant is stddev of standard normal truncated to (-2, 2)
stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
return random.truncated_normal(key, -2, 2, named_shape, dtype) * stddev
else:
# constant is stddev of complex standard normal truncated to 2
stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype)
return _complex_truncated_normal(key, 2, shape, dtype) * stddev
return _complex_truncated_normal(key, 2, named_shape, dtype) * stddev
elif distribution == "normal":
return random.normal(key, shape, dtype) * jnp.sqrt(variance)
return random.normal(key, named_shape, dtype) * jnp.sqrt(variance)
elif distribution == "uniform":
if jnp.issubdtype(dtype, jnp.floating):
return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
return random.uniform(key, named_shape, dtype, -1) * jnp.sqrt(3 * variance)
else:
return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance)
return _complex_uniform(key, named_shape, dtype) * jnp.sqrt(variance)
else:
raise ValueError(f"invalid distribution for variance scaling initializer: {distribution}")

Expand All @@ -250,7 +295,7 @@ def init(key, shape, dtype=dtype):
def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds a Glorot uniform initializer (aka Xavier uniform initializer).
A `Glorot uniform initializer`_ is a specialization of
Expand Down Expand Up @@ -288,7 +333,7 @@ def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds a Glorot normal initializer (aka Xavier normal initializer).
A `Glorot normal initializer`_ is a specialization of
Expand Down Expand Up @@ -325,7 +370,7 @@ def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds a Lecun uniform initializer.
A `Lecun uniform initializer`_ is a specialization of
Expand Down Expand Up @@ -360,7 +405,7 @@ def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds a Lecun normal initializer.
A `Lecun normal initializer`_ is a specialization of
Expand Down Expand Up @@ -396,7 +441,7 @@ def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds a He uniform initializer (aka Kaiming uniform initializer).
A `He uniform initializer`_ is a specialization of
Expand Down Expand Up @@ -434,7 +479,7 @@ def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
def he_normal(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds a He normal initializer (aka Kaiming normal initializer).
A `He normal initializer`_ is a specialization of
Expand Down Expand Up @@ -469,7 +514,9 @@ def he_normal(in_axis: Union[int, Sequence[int]] = -2,
kaiming_normal = he_normal


def orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
def orthogonal(scale: RealNumeric = 1.0,
column_axis: int = -1,
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""
Builds an initializer that returns uniformly distributed orthogonal matrices.
Expand All @@ -492,7 +539,9 @@ def orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
DeviceArray([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
"""
def init(key, shape, dtype=dtype):
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
if len(shape) < 2:
raise ValueError("orthogonal initializer requires at least a 2D shape")
Expand All @@ -509,7 +558,10 @@ def init(key, shape, dtype=dtype):
return init


def delta_orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
def delta_orthogonal(
scale: RealNumeric = 1.0,
column_axis: int = -1,
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""
Builds an initializer for delta orthogonal kernels.
Expand Down Expand Up @@ -542,7 +594,9 @@ def delta_orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
.. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393
"""
def init(key, shape, dtype=dtype):
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
if len(shape) not in [3, 4, 5]:
raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D "
Expand Down
1 change: 1 addition & 0 deletions jax/nn/initializers.py
Expand Up @@ -19,6 +19,7 @@

from jax._src.nn.initializers import (
constant as constant,
Initializer as Initializer,
delta_orthogonal as delta_orthogonal,
glorot_normal as glorot_normal,
glorot_uniform as glorot_uniform,
Expand Down

0 comments on commit 7869ff4

Please sign in to comment.