Skip to content

Commit

Permalink
Merge pull request #6750 from romanngg:init_dtypes
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 392071088
  • Loading branch information
jax authors committed Aug 20, 2021
2 parents a4bd0e7 + b65f39c commit 693d2e2
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 13 deletions.
20 changes: 13 additions & 7 deletions jax/_src/nn/initializers.py
Expand Up @@ -28,17 +28,20 @@
from jax import random
from jax import core
from jax._src.util import prod
from jax import dtypes

def zeros(key, shape, dtype=jnp.float32): return jnp.zeros(shape, dtype)
def ones(key, shape, dtype=jnp.float32): return jnp.ones(shape, dtype)
def zeros(key, shape, dtype=jnp.float_): return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
def ones(key, shape, dtype=jnp.float_): return jnp.ones(shape, dtypes.canonicalize_dtype(dtype))

def uniform(scale=1e-2, dtype=jnp.float32):
def uniform(scale=1e-2, dtype=jnp.float_):
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
return random.uniform(key, shape, dtype) * scale
return init

def normal(stddev=1e-2, dtype=jnp.float32):
def normal(stddev=1e-2, dtype=jnp.float_):
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
return random.normal(key, shape, dtype) * stddev
return init

Expand All @@ -48,8 +51,9 @@ def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1):
fan_out = shape[out_axis] * receptive_field_size
return fan_in, fan_out

def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float32):
def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float_):
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
shape = core.as_named_shape(shape)
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
if mode == "fan_in": denominator = fan_in
Expand Down Expand Up @@ -78,14 +82,15 @@ def init(key, shape, dtype=dtype):
kaiming_uniform = he_uniform = partial(variance_scaling, 2.0, "fan_in", "uniform")
kaiming_normal = he_normal = partial(variance_scaling, 2.0, "fan_in", "truncated_normal")

def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
def orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float_):
"""
Construct an initializer for uniformly distributed orthogonal matrices.
If the shape is not square, the matrices will have orthonormal rows or columns
depending on which side is smaller.
"""
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
if len(shape) < 2:
raise ValueError("orthogonal initializer requires at least a 2D shape")
n_rows, n_cols = prod(shape) // shape[column_axis], shape[column_axis]
Expand All @@ -101,13 +106,14 @@ def init(key, shape, dtype=dtype):
return init


def delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32):
def delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float_):
"""
Construct an initializer for delta orthogonal kernels; see arXiv:1806.05393.
The shape must be 3D, 4D or 5D.
"""
def init(key, shape, dtype=dtype):
dtype = dtypes.canonicalize_dtype(dtype)
if len(shape) not in [3, 4, 5]:
raise ValueError("Delta orthogonal initializer requires a 3D, 4D or 5D "
"shape.")
Expand Down
8 changes: 4 additions & 4 deletions tests/jet_test.py
Expand Up @@ -111,12 +111,12 @@ def test_conv(self):

rng = np.random.RandomState(0)

x = rng.randn(*input_shape).astype("float32")
x = rng.randn(*input_shape)
primals = (W, b, x)

series_in1 = [rng.randn(*W.shape).astype("float32") for _ in range(order)]
series_in2 = [rng.randn(*b.shape).astype("float32") for _ in range(order)]
series_in3 = [rng.randn(*x.shape).astype("float32") for _ in range(order)]
series_in1 = [rng.randn(*W.shape) for _ in range(order)]
series_in2 = [rng.randn(*b.shape) for _ in range(order)]
series_in3 = [rng.randn(*x.shape) for _ in range(order)]

series_in = (series_in1, series_in2, series_in3)

Expand Down
3 changes: 2 additions & 1 deletion tests/stax_test.py
Expand Up @@ -22,14 +22,15 @@
from jax import test_util as jtu
from jax import random
from jax.experimental import stax
from jax import dtypes

from jax.config import config
config.parse_flags_with_absl()


def random_inputs(rng, input_shape):
if type(input_shape) is tuple:
return rng.randn(*input_shape).astype(np.float32)
return rng.randn(*input_shape).astype(dtypes.canonicalize_dtype(np.float_))
elif type(input_shape) is list:
return [random_inputs(rng, shape) for shape in input_shape]
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/xmap_test.py
Expand Up @@ -734,7 +734,7 @@ def testVarianceScaling(self, map_in, map_out, fan, distr):
shape = (80, 50, 7)
fan_in, fan_out = jax._src.nn.initializers._compute_fans(
NamedShape(*shape), 0, 1)
key = jax.random.PRNGKey(0)
key = jax.random.PRNGKey(1)
base_scaling = partial(jax.nn.initializers.variance_scaling, 100, fan, distr)
ref_sampler = lambda: base_scaling(in_axis=0, out_axis=1)(key, shape)
if map_in and map_out:
Expand Down

0 comments on commit 693d2e2

Please sign in to comment.