Skip to content

Commit

Permalink
rollback of #9596
Browse files Browse the repository at this point in the history
Why? Shape annotations are inaccurate and cause pytype failures

PiperOrigin-RevId: 465337386
  • Loading branch information
Jake VanderPlas authored and jax authors committed Aug 4, 2022
1 parent 89ce078 commit d52017a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 95 deletions.
134 changes: 40 additions & 94 deletions jax/_src/nn/initializers.py
Expand Up @@ -18,10 +18,9 @@
"""


from typing import Any, Sequence, Tuple, Union
from typing import Any, Callable, Sequence, Union

import numpy as np
from typing_extensions import Literal, Protocol

import jax.numpy as jnp
from jax import lax
Expand All @@ -30,25 +29,9 @@
from jax._src.util import prod
from jax._src import dtypes

KeyArray = random.KeyArray
Array = Any
# TODO: Import or define these to match
# https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
DTypeLikeFloat = Any
DTypeLikeComplex = Any
DTypeLikeInexact = Any # DTypeLikeFloat | DTypeLikeComplex
RealNumeric = Any # Scalar jnp array or float

class Initializer(Protocol):
@staticmethod
def __call__(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
...

def zeros(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
DType = Any

def zeros(key, shape, dtype: DType = jnp.float_):
"""An initializer that returns a constant array full of zeros.
The ``key`` argument is ignored.
Expand All @@ -60,9 +43,7 @@ def zeros(key: KeyArray,
"""
return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))

def ones(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = jnp.float_) -> Array:
def ones(key, shape, dtype: DType = jnp.float_):
"""An initializer that returns a constant array full of ones.
The ``key`` argument is ignored.
Expand All @@ -75,9 +56,7 @@ def ones(key: KeyArray,
"""
return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))

def constant(value: Array,
dtype: DTypeLikeInexact = jnp.float_
) -> Initializer:
def constant(value, dtype: DType = jnp.float_) -> Callable:
"""Builds an initializer that returns arrays full of a constant ``value``.
Args:
Expand All @@ -90,15 +69,12 @@ def constant(value: Array,
DeviceArray([[-7., -7., -7.],
[-7., -7., -7.]], dtype=float32)
"""
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
return jnp.full(shape, value, dtype=dtype)
return init

def uniform(scale: RealNumeric = 1e-2,
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable:
"""Builds an initializer that returns real uniformly-distributed random arrays.
Args:
Expand All @@ -115,15 +91,12 @@ def uniform(scale: RealNumeric = 1e-2,
DeviceArray([[7.298188 , 8.691938 , 8.7230015],
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
"""
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
return random.uniform(key, shape, dtype) * scale
return init

def normal(stddev: RealNumeric = 1e-2,
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable:
"""Builds an initializer that returns real normally-distributed random arrays.
Args:
Expand All @@ -140,18 +113,13 @@ def normal(stddev: RealNumeric = 1e-2,
DeviceArray([[ 3.0613258 , 5.6129413 , 5.6866574 ],
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
"""
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
return random.normal(key, shape, dtype) * stddev
return init

def _compute_fans(shape: core.NamedShape,
in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Union[int, Sequence[int]] = ()
) -> Tuple[Array, Array]:
def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1,
batch_axis=()):
"""
Compute effective input and output sizes for a linear or convolutional layer.
Expand All @@ -175,9 +143,7 @@ def _compute_fans(shape: core.NamedShape,
fan_out = out_size * receptive_field_size
return fan_in, fan_out

def _complex_uniform(key: KeyArray,
shape: core.NamedShape,
dtype: DTypeLikeInexact) -> Array:
def _complex_uniform(key, shape, dtype):
"""
Sample uniform random values within a disk on the complex plane,
with zero mean and unit variance.
Expand All @@ -189,33 +155,24 @@ def _complex_uniform(key: KeyArray,
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
return r * jnp.exp(1j * theta)

def _complex_truncated_normal(key: KeyArray, upper: Array,
shape: core.NamedShape,
dtype: DTypeLikeInexact) -> Array:
def _complex_truncated_normal(key, upper, shape, dtype):
"""
Sample random values from a centered normal distribution on the complex plane,
whose modulus is truncated to `upper`, and the variance before the truncation
is one.
whose modulus is truncated to `upper`, and the variance before the truncation is one.
"""
key_r, key_theta = random.split(key)
real_dtype = np.array(0, dtype).real.dtype
dtype = dtypes._to_complex_dtype(real_dtype)
t = ((1 - jnp.exp(jnp.array(-(upper ** 2), dtype)))
* random.uniform(key_r, shape, real_dtype).astype(dtype))
t = (1 - jnp.exp(jnp.array(-(upper ** 2), dtype))) * random.uniform(key_r, shape, real_dtype).astype(dtype)
r = jnp.sqrt(-jnp.log(1 - t))
theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
return r * jnp.exp(1j * theta)

def variance_scaling(
scale: RealNumeric,
mode: Union[Literal["fan_in"], Literal["fan_out"], Literal["fan_avg"]],
distribution: Union[Literal["truncated_normal"], Literal["normal"],
Literal["uniform"]],
in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DTypeLikeInexact = jnp.float_
) -> Initializer:
def variance_scaling(scale, mode: str, distribution: str,
in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DType = jnp.float_) -> Callable:
r"""
Initializer that adapts its scale to the shape of the weights tensor.
Expand Down Expand Up @@ -257,12 +214,10 @@ def variance_scaling(
dtype: the dtype of the weights.
"""

def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
named_shape = core.as_named_shape(shape)
fan_in, fan_out = _compute_fans(named_shape, in_axis, out_axis, batch_axis)
shape = core.as_named_shape(shape)
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis, batch_axis)
if mode == "fan_in": denominator = fan_in
elif mode == "fan_out": denominator = fan_out
elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2
Expand All @@ -275,18 +230,18 @@ def init(key: KeyArray,
if jnp.issubdtype(dtype, jnp.floating):
# constant is stddev of standard normal truncated to (-2, 2)
stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
return random.truncated_normal(key, -2, 2, named_shape, dtype) * stddev
return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
else:
# constant is stddev of complex standard normal truncated to 2
stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype)
return _complex_truncated_normal(key, 2, named_shape, dtype) * stddev
return _complex_truncated_normal(key, 2, shape, dtype) * stddev
elif distribution == "normal":
return random.normal(key, named_shape, dtype) * jnp.sqrt(variance)
return random.normal(key, shape, dtype) * jnp.sqrt(variance)
elif distribution == "uniform":
if jnp.issubdtype(dtype, jnp.floating):
return random.uniform(key, named_shape, dtype, -1) * jnp.sqrt(3 * variance)
return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
else:
return _complex_uniform(key, named_shape, dtype) * jnp.sqrt(variance)
return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance)
else:
raise ValueError(f"invalid distribution for variance scaling initializer: {distribution}")

Expand All @@ -295,7 +250,7 @@ def init(key: KeyArray,
def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
dtype: DType = jnp.float_) -> Callable:
"""Builds a Glorot uniform initializer (aka Xavier uniform initializer).
A `Glorot uniform initializer`_ is a specialization of
Expand Down Expand Up @@ -333,7 +288,7 @@ def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
dtype: DType = jnp.float_) -> Callable:
"""Builds a Glorot normal initializer (aka Xavier normal initializer).
A `Glorot normal initializer`_ is a specialization of
Expand Down Expand Up @@ -370,7 +325,7 @@ def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
dtype: DType = jnp.float_) -> Callable:
"""Builds a Lecun uniform initializer.
A `Lecun uniform initializer`_ is a specialization of
Expand Down Expand Up @@ -405,7 +360,7 @@ def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
dtype: DType = jnp.float_) -> Callable:
"""Builds a Lecun normal initializer.
A `Lecun normal initializer`_ is a specialization of
Expand Down Expand Up @@ -441,7 +396,7 @@ def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
dtype: DType = jnp.float_) -> Callable:
"""Builds a He uniform initializer (aka Kaiming uniform initializer).
A `He uniform initializer`_ is a specialization of
Expand Down Expand Up @@ -479,7 +434,7 @@ def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
def he_normal(in_axis: Union[int, Sequence[int]] = -2,
out_axis: Union[int, Sequence[int]] = -1,
batch_axis: Sequence[int] = (),
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
dtype: DType = jnp.float_) -> Callable:
"""Builds a He normal initializer (aka Kaiming normal initializer).
A `He normal initializer`_ is a specialization of
Expand Down Expand Up @@ -514,9 +469,7 @@ def he_normal(in_axis: Union[int, Sequence[int]] = -2,
kaiming_normal = he_normal


def orthogonal(scale: RealNumeric = 1.0,
column_axis: int = -1,
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
def orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
"""
Builds an initializer that returns uniformly distributed orthogonal matrices.
Expand All @@ -539,9 +492,7 @@ def orthogonal(scale: RealNumeric = 1.0,
DeviceArray([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],
[ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)
"""
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
if len(shape) < 2:
raise ValueError("orthogonal initializer requires at least a 2D shape")
Expand All @@ -558,10 +509,7 @@ def init(key: KeyArray,
return init


def delta_orthogonal(
scale: RealNumeric = 1.0,
column_axis: int = -1,
dtype: DTypeLikeInexact = jnp.float_) -> Initializer:
def delta_orthogonal(scale=1.0, column_axis=-1, dtype: DType = jnp.float_):
"""
Builds an initializer for delta orthogonal kernels.
Expand Down Expand Up @@ -594,9 +542,7 @@ def delta_orthogonal(
.. _delta orthogonal initializer: https://arxiv.org/abs/1806.05393
"""
def init(key: KeyArray,
shape: core.Shape,
dtype: DTypeLikeInexact = dtype) -> Array:
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
if len(shape) not in [3, 4, 5]:
raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D "
Expand Down
1 change: 0 additions & 1 deletion jax/nn/initializers.py
Expand Up @@ -19,7 +19,6 @@

from jax._src.nn.initializers import (
constant as constant,
Initializer as Initializer,
delta_orthogonal as delta_orthogonal,
glorot_normal as glorot_normal,
glorot_uniform as glorot_uniform,
Expand Down

0 comments on commit d52017a

Please sign in to comment.