Skip to content

Commit

Permalink
Merge pull request #9798 from hawkinsp:initdoc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 433187449
  • Loading branch information
jax authors committed Mar 8, 2022
2 parents dc2ca18 + ad5144f commit e8f1a02
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions jax/_src/nn/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def uniform(scale=1e-2, dtype: DType = jnp.float_) -> Callable:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.uniform(10.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[7.298188 , 8.691938 , 8.7230015],
[2.0818567, 1.8662417, 5.5022564]], dtype=float32)
"""
Expand All @@ -109,7 +109,7 @@ def normal(stddev=1e-2, dtype: DType = jnp.float_) -> Callable:
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.normal(5.0)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 3.0613258 , 5.6129413 , 5.6866574 ],
[-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)
"""
Expand Down Expand Up @@ -271,7 +271,7 @@ def glorot_uniform(in_axis: Union[int, Sequence[int]] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 0.50350785, 0.8088631 , 0.81566876],
[-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)
Expand Down Expand Up @@ -309,7 +309,7 @@ def glorot_normal(in_axis: Union[int, Sequence[int]] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.glorot_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 0.41770416, 0.75262755, 0.7619329 ],
[-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)
Expand Down Expand Up @@ -346,7 +346,7 @@ def lecun_uniform(in_axis: Union[int, Sequence[int]] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 0.56293887, 0.90433645, 0.9119454 ],
[-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)
Expand Down Expand Up @@ -381,7 +381,7 @@ def lecun_normal(in_axis: Union[int, Sequence[int]] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.lecun_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 0.46700746, 0.8414632 , 0.8518669 ],
[-0.61677957, -0.67402434, 0.09683388]], dtype=float32)
Expand Down Expand Up @@ -417,7 +417,7 @@ def he_uniform(in_axis: Union[int, Sequence[int]] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.kaiming_uniform()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 0.79611576, 1.2789248 , 1.2896855 ],
[-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)
Expand Down Expand Up @@ -455,7 +455,7 @@ def he_normal(in_axis: Union[int, Sequence[int]] = -2,
>>> import jax, jax.numpy as jnp
>>> initializer = jax.nn.initializers.kaiming_normal()
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32)
>>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP
DeviceArray([[ 0.6604483 , 1.1900088 , 1.2047218 ],
[-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)
Expand Down

0 comments on commit e8f1a02

Please sign in to comment.