Skip to content

Commit

Permalink
Annotate nn.initializers
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Aug 4, 2022
1 parent a3ad01a commit 1bd3784
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 40 deletions.
134 changes: 94 additions & 40 deletions jax/_src/nn/initializers.py
Expand Up @@ -18,9 +18,10 @@
"""


from typing import Any, Callable, Sequence, Union
from typing import Any, Sequence, Tuple, Union

import numpy as np
from typing_extensions import Literal, Protocol

import jax.numpy as jnp
from jax import lax
Expand All @@ -29,9 +30,25 @@
from jax._src.util import prod
from jax._src import dtypes

DType = Any

def zeros(key, shape, dtype: DType = jnp.float_):
KeyArray = random.KeyArray
Array = Any
# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
DTypeLikeFloat = Any
DTypeLikeComplex = Any
DTypeLikeInexact = Any # DTypeLikeFloat | DTypeLikeComplex
RealNumeric = Any # Scalar jnp array or float

class Initializer(Protocol):
@staticmethod
def __call__(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
...

def zeros(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
"""An initializer that returns a constant array full of zeros.
The ``key`` argument is ignored.
Expand All @@ -43,7 +60,9 @@ def zeros(key, shape, dtype: DType = jnp.float_):
"""
return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))

def ones(key, shape, dtype: DType = jnp.float_):
def ones(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
"""An initializer that returns a constant array full of ones.
The ``key`` argument is ignored.
Expand All @@ -56,7 +75,9 @@ def ones(key, shape, dtype: DType = jnp.float_):
"""
return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))

def constant(value, dtype: DType = jnp.float_) -> Callable:
def constant(value: Array,
dtype: DTypeLikeInexact = jnp.float_
) -> Initializer:
"""Builds an initializer that returns arrays full of a constant ``value``.
Args:
Expand All @@ -69,12 +90,15 @@ def constant(value, dtype: DType = jnp.float_) -> Callable:
DeviceArray([[-7., -7., -7.],
[-7., -7., -7.]], dtype=float32)
"""
def init(key, shape, dtype=dtype):
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
return jnp.full(shape, value, dtype=dtype)
return init

def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable:
def uniform(scale: RealNumeric = 1e-2,
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds an initializer that returns real uniformly-distributed random arrays.
Args:
Expand All @@ -91,12 +115,15 @@ def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable:
DeviceArray([[7.298188 , 8.691938 , 8.7230015],
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
"""
def init(key, shape, dtype=dtype):
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
return random.uniform(key, shape, dtype) * scale
return init

def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable:
def normal(stddev: RealNumeric = 1e-2,
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds an initializer that returns real normally-distributed random arrays.
Args:
Expand All @@ -113,13 +140,18 @@ def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable:
DeviceArray([[ 3.0613258 , 5.6129413 , 5.6866574 ],
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
"""
def init(key, shape, dtype=dtype):
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
return random.normal(key, shape, dtype) * stddev
return init

def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1,
batch_axis=()):
def _compute_fans(shape: core.NamedShape,
in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Union[int, Sequence[int]] = ()
) -> Tuple[Array, Array]:
"""
Compute effective input and output sizes for a linear or convolutional layer.
Expand All @@ -143,7 +175,9 @@ def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1,
fan_out = out_size * receptive_field_size
return fan_in, fan_out

def _complex_uniform(key, shape, dtype):
def _complex_uniform(key: KeyArray,
shape: core.NamedShape,
dtype: DTypeLikeInexact) -> Array:
"""
Sample uniform random values within a disk on the complex plane,
with zero mean and unit variance.
Expand All @@ -155,24 +189,33 @@ def _complex_uniform(key, shape, dtype):
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
return r * jnp.exp(1j * theta)

def _complex_truncated_normal(key, upper, shape, dtype):
def _complex_truncated_normal(key: KeyArray, upper: Array,
shape: core.NamedShape,
dtype: DTypeLikeInexact) -> Array:
"""
Sample random values from a centered normal distribution on the complex plane,
whose modulus is truncated to `upper`, and the variance before the truncation is one.
whose modulus is truncated to `upper`, and the variance before the truncation
is one.
"""
key_r, key_theta = random.split(key)
real_dtype = np.array(0, dtype).real.dtype
dtype = dtypes._to_complex_dtype(real_dtype)
t = (1 - jnp.exp(jnp.array(-(upper ** 2), dtype))) * random.uniform(key_r, shape, real_dtype).astype(dtype)
t = ((1 - jnp.exp(jnp.array(-(upper ** 2), dtype)))
* random.uniform(key_r, shape, real_dtype).astype(dtype))
r = jnp.sqrt(-jnp.log(1 - t))
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
return r * jnp.exp(1j * theta)

def variance_scaling(scale, mode: str, distribution: str,
in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
def variance_scaling(
scale: RealNumeric,
mode: Union[Literal["fan_in"], Literal["fan_out"], Literal["fan_avg"]],
distribution: Union[Literal["truncated_normal"], Literal["normal"],
Literal["uniform"]],
in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DTypeLikeInexact = jnp.float_
) -> Initializer:
r"""
Initializer that adapts its scale to the shape of the weights tensor.
Expand Down Expand Up @@ -214,10 +257,12 @@ def variance_scaling(scale, mode: str, distribution: str,
dtype: the dtype of the weights.
"""

def init(key, shape, dtype=dtype):
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
shape = core.as_named_shape(shape)
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis)
named_shape = core.as_named_shape(shape)
fan_in, fan_out = _compute_fans(named_shape, in_axis, out_axis, batch_axis)
if mode == "fan_in": denominator = fan_in
elif mode == "fan_out": denominator = fan_out
elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2
Expand All @@ -230,18 +275,18 @@ def init(key, shape, dtype=dtype):
if jnp.issubdtype(dtype, jnp.floating):
# constant is stddev of standard normal truncated to (-2, 2)
stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
return random.truncated_normal(key, -2, 2, named_shape, dtype) * stddev
else:
# constant is stddev of complex standard normal truncated to 2
stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype)
return _complex_truncated_normal(key, 2, shape, dtype) * stddev
return _complex_truncated_normal(key, 2, named_shape, dtype) * stddev
elif distribution == "normal":
return random.normal(key, shape, dtype) * jnp.sqrt(variance)
return random.normal(key, named_shape, dtype) * jnp.sqrt(variance)
elif distribution == "uniform":
if jnp.issubdtype(dtype, jnp.floating):
return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
return random.uniform(key, named_shape, dtype, -1) * jnp.sqrt(3 * variance)
else:
return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance)
return _complex_uniform(key, named_shape, dtype) * jnp.sqrt(variance)
else:
raise ValueError(f"invalid distribution for variance scaling initializer: {distribution}")

Expand All @@ -250,7 +295,7 @@ def init(key, shape, dtype=dtype):
def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds a Glorot uniform initializer (aka Xavier uniform initializer).
A `Glorot uniform initializer`_ is a specialization of
Expand Down Expand Up @@ -288,7 +333,7 @@ def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds a Glorot normal initializer (aka Xavier normal initializer).
A `Glorot normal initializer`_ is a specialization of
Expand Down Expand Up @@ -325,7 +370,7 @@ def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds a Lecun uniform initializer.
A `Lecun uniform initializer`_ is a specialization of
Expand Down Expand Up @@ -360,7 +405,7 @@ def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds a Lecun normal initializer.
A `Lecun normal initializer`_ is a specialization of
Expand Down Expand Up @@ -396,7 +441,7 @@ def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds a He uniform initializer (aka Kaiming uniform initializer).
A `He uniform initializer`_ is a specialization of
Expand Down Expand Up @@ -434,7 +479,7 @@ def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
def he_normal(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""Builds a He normal initializer (aka Kaiming normal initializer).
A `He normal initializer`_ is a specialization of
Expand Down Expand Up @@ -469,7 +514,9 @@ def he_normal(in_axis: Union[int, Sequence[int]] = -2,
kaiming_normal = he_normal


def orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
def orthogonal(scale: RealNumeric = 1.0,
column_axis: int = -1,
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""
Builds an initializer that returns uniformly distributed orthogonal matrices.
Expand All @@ -492,7 +539,9 @@ def orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
DeviceArray([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
"""
def init(key, shape, dtype=dtype):
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
if len(shape) < 2:
raise ValueError("orthogonal initializer requires at least a 2D shape")
Expand All @@ -509,7 +558,10 @@ def init(key, shape, dtype=dtype):
return init


def delta_orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
def delta_orthogonal(
scale: RealNumeric = 1.0,
column_axis: int = -1,
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
"""
Builds an initializer for delta orthogonal kernels.
Expand Down Expand Up @@ -542,7 +594,9 @@ def delta_orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
.. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393
"""
def init(key, shape, dtype=dtype):
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
dtype = dtypes.canonicalize_dtype(dtype)
if len(shape) not in [3, 4, 5]:
raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D "
Expand Down
1 change: 1 addition & 0 deletions jax/nn/initializers.py
Expand Up @@ -19,6 +19,7 @@

from jax._src.nn.initializers import (
constant as constant,
Initializer as Initializer,
delta_orthogonal as delta_orthogonal,
glorot_normal as glorot_normal,
glorot_uniform as glorot_uniform,
Expand Down

0 comments on commit 1bd3784

Please sign in to comment.