Implement variance scaling initializers with complex dtype

wdphy16 committed Aug 19, 2021
1 parent eb15207 commit 5138743
Showing 2 changed files with 83 additions and 20 deletions.
72 changes: 67 additions & 5 deletions jax/_src/nn/initializers.py
@@ -48,7 +48,59 @@ def _compute_fans(shape: core.NamedShape, in_axis=-2, out_axis=-1):
  fan_out = shape[out_axis] * receptive_field_size
  return fan_in, fan_out

def _complex_uniform(key, shape, dtype):
  """
  Sample uniform random values within a disk on the complex plane,
  with zero mean and unit variance.
  """
  key_r, key_theta = random.split(key)
  # sample in the real dtype underlying the (complex) target dtype
  dtype = np.array(0, dtype).real.dtype
  r = jnp.sqrt(2 * random.uniform(key_r, shape, dtype))
  theta = 2 * jnp.pi * random.uniform(key_theta, shape, dtype)
  return r * jnp.exp(1j * theta)
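
For a point drawn uniformly from a disk of radius R, E[|z|^2] = R^2 / 2, so the
radius-sqrt(2) disk above has unit variance, and `r = jnp.sqrt(2 * u)` is the
inverse CDF of P(|z| <= r) = r^2 / 2 on that disk. A minimal sketch to check the
moments numerically (not part of the commit; assumes it runs in this module's
scope, where `_complex_uniform`, `random`, and `jnp` are visible):

key = random.PRNGKey(0)
z = _complex_uniform(key, (200000,), jnp.complex64)
print(jnp.abs(jnp.mean(z)))       # close to 0: zero mean
print(jnp.mean(jnp.abs(z) ** 2))  # close to 1: unit variance
print(jnp.max(jnp.abs(z)))        # bounded by sqrt(2), the disk radius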

def _complex_truncated_normal(key, upper, shape, dtype):
  """
  Sample random values from a centered normal distribution on the complex plane,
  whose modulus is truncated to `upper`, and whose variance before truncation is one.
  """
  key_r, key_theta = random.split(key)
  # sample in the real dtype underlying the (complex) target dtype
  dtype = np.array(0, dtype).real.dtype
  # inverse transform sampling of the truncated modulus
  t = (1 - jnp.exp(jnp.array(-(upper ** 2), dtype))) * random.uniform(key_r, shape, dtype)
  r = jnp.sqrt(-jnp.log(1 - t))
  theta = 2 * jnp.pi * random.uniform(key_theta, shape, dtype)
  return r * jnp.exp(1j * theta)
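
For a standard complex normal with E[|z|^2] = 1, the squared modulus follows an
Exponential(1) distribution, so P(|z| <= r) = 1 - exp(-r^2); restricting to
r <= `upper` and inverting that CDF gives the `t` and `r` lines above. A minimal
sketch to check the truncation numerically (not part of the commit; assumes this
module's scope):

key = random.PRNGKey(1)
z = _complex_truncated_normal(key, 2, (200000,), jnp.complex64)
print(jnp.max(jnp.abs(z)))   # never exceeds the truncation bound `upper` = 2
print(jnp.abs(jnp.mean(z)))  # close to 0: the distribution stays centered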

def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float32):
  """
  Initializer capable of adapting its scale to the shape of the weights tensor.

  With `distribution="truncated_normal"` or `distribution="normal"`, samples are
  drawn from a truncated/untruncated normal distribution with a mean of zero and
  a standard deviation (after truncation, if used) of `stddev = sqrt(scale / n)`,
  where `n` is:
  - the number of input units in the weights tensor, if `mode="fan_in"`
  - the number of output units, if `mode="fan_out"`
  - the average of the numbers of input and output units, if `mode="fan_avg"`

  With `distribution="truncated_normal"`, the absolute values of the samples are
  truncated below 2 standard deviations of the untruncated distribution.

  With `distribution="uniform"`, samples are drawn with a mean of zero and a
  standard deviation of `stddev` from:
  - a uniform interval, if `dtype` is real
  - a uniform disk, if `dtype` is complex

  Args:
    scale: scaling factor (positive float).
    mode: one of "fan_in", "fan_out", and "fan_avg".
    distribution: random distribution to use. One of "truncated_normal",
      "normal" and "uniform".
    in_axis: axis of the input dimension in the weights tensor.
    out_axis: axis of the output dimension in the weights tensor.
    dtype: the dtype of the weights.
  """

  def init(key, shape, dtype=dtype):
    shape = core.as_named_shape(shape)
    fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
@@ -59,16 +111,26 @@ def init(key, shape, dtype=dtype):
      raise ValueError(
        "invalid mode for variance scaling initializer: {}".format(mode))
    variance = jnp.array(scale / denominator, dtype=dtype)

    if distribution == "truncated_normal":
-      # constant is stddev of standard normal truncated to (-2, 2)
-      stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
-      return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
+      if jnp.issubdtype(dtype, jnp.floating):
+        # constant is stddev of standard normal truncated to (-2, 2)
+        stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
+        return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
+      else:
+        # constant is stddev of complex standard normal truncated to 2
+        stddev = jnp.sqrt(variance) / jnp.array(.95311164380491208, dtype)
+        return _complex_truncated_normal(key, 2, shape, dtype) * stddev
    elif distribution == "normal":
      return random.normal(key, shape, dtype) * jnp.sqrt(variance)
    elif distribution == "uniform":
-      return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
+      if jnp.issubdtype(dtype, jnp.floating):
+        return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
+      else:
+        return _complex_uniform(key, shape, dtype) * jnp.sqrt(variance)
    else:
-      raise ValueError("invalid distribution for variance scaling initializer")
+      raise ValueError("invalid distribution for variance scaling initializer: {}".format(distribution))

  return init

xavier_uniform = glorot_uniform = partial(variance_scaling, 1.0, "fan_avg", "uniform")
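
With this commit, the variance scaling family can produce complex weights
directly. A minimal usage sketch (not part of the diff; the shape and seed are
arbitrary):

init = variance_scaling(1.0, "fan_in", "truncated_normal", dtype=jnp.complex64)
W = init(random.PRNGKey(42), (128, 64))
print(W.dtype)                    # complex64
print(jnp.mean(jnp.abs(W) ** 2))  # approximately scale / fan_in = 1 / 128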
31 changes: 16 additions & 15 deletions tests/nn_test.py
@@ -178,26 +178,26 @@ def testTanhExists(self):

InitializerRecord = collections.namedtuple(
    "InitializerRecord",
-    ["name", "initializer", "shapes"])
+    ["name", "initializer", "shapes", "dtypes"])

ALL_SHAPES = [(2,), (2, 2), (2, 3), (3, 2), (2, 3, 4), (4, 3, 2), (2, 3, 4, 5)]

-def initializer_record(name, initializer, min_dims=2, max_dims=4):
+def initializer_record(name, initializer, dtypes, min_dims=2, max_dims=4):
  shapes = [shape for shape in ALL_SHAPES
            if min_dims <= len(shape) <= max_dims]
-  return InitializerRecord(name, initializer, shapes)
+  return InitializerRecord(name, initializer, shapes, dtypes)

INITIALIZER_RECS = [
initializer_record("uniform", nn.initializers.uniform, 1),
initializer_record("normal", nn.initializers.normal, 1),
initializer_record("he_normal", nn.initializers.he_normal),
initializer_record("he_uniform", nn.initializers.he_uniform),
initializer_record("glorot_normal", nn.initializers.glorot_normal),
initializer_record("glorot_uniform", nn.initializers.glorot_uniform),
initializer_record("lecun_normal", nn.initializers.lecun_normal),
initializer_record("lecun_uniform", nn.initializers.lecun_uniform),
initializer_record("orthogonal", nn.initializers.orthogonal, 2, 2),
initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, 4, 4)
initializer_record("uniform", nn.initializers.uniform, jtu.dtypes.floating, 1),
initializer_record("normal", nn.initializers.normal, jtu.dtypes.inexact, 1),
initializer_record("he_normal", nn.initializers.he_normal, jtu.dtypes.inexact),
initializer_record("he_uniform", nn.initializers.he_uniform, jtu.dtypes.inexact),
initializer_record("glorot_normal", nn.initializers.glorot_normal, jtu.dtypes.inexact),
initializer_record("glorot_uniform", nn.initializers.glorot_uniform, jtu.dtypes.inexact),
initializer_record("lecun_normal", nn.initializers.lecun_normal, jtu.dtypes.inexact),
initializer_record("lecun_uniform", nn.initializers.lecun_uniform, jtu.dtypes.inexact),
initializer_record("orthogonal", nn.initializers.orthogonal, jtu.dtypes.floating, 2, 2),
initializer_record("delta_orthogonal", nn.initializers.delta_orthogonal, jtu.dtypes.floating, 4, 4)
]
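
The split between the two dtype sets is the point of this change: records tagged
`jtu.dtypes.inexact` now run with both floating and complex dtypes, while
`jtu.dtypes.floating` keeps the real-only initializers (uniform, orthogonal,
delta_orthogonal) out of the complex path. A rough sketch of the distinction
(exact contents depend on the x64 test configuration):

print(jtu.dtypes.floating)  # e.g. [float16, bfloat16, float32, float64]
print(jtu.dtypes.inexact)   # the floating dtypes plus [complex64, complex128]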

class NNInitializersTest(jtu.JaxTestCase):
@@ -219,10 +219,11 @@ def tearDown(self):
"shape": shape, "dtype": dtype}
for rec in INITIALIZER_RECS
for shape in rec.shapes
for dtype in jtu.dtypes.floating))
for dtype in rec.dtypes))
  def testInitializer(self, initializer, shape, dtype):
    rng = random.PRNGKey(0)
    val = initializer(rng, shape, dtype)

    self.assertEqual(shape, jnp.shape(val))
    self.assertEqual(jax.dtypes.canonicalize_dtype(dtype), jnp.dtype(val))

@@ -235,7 +236,7 @@ def testInitializer(self, initializer, shape, dtype):
"shape": shape, "dtype": dtype}
for rec in INITIALIZER_RECS
for shape in rec.shapes
for dtype in jtu.dtypes.floating))
for dtype in rec.dtypes))
  def testInitializerProvider(self, initializer_provider, shape, dtype):
    rng = random.PRNGKey(0)
    initializer = initializer_provider(dtype=dtype)
