Skip to content

Commit

Permalink
Remove deprecated function as_named_shape.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 655920752
  • Loading branch information
The e3x Authors committed Jul 25, 2024
1 parent ab86199 commit e2e9312
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions e3x/nn/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def _random_array(
return stddev * _complex_truncated_normal(
key=key,
upper=2,
shape=jax.core.as_named_shape(shape),
shape=shape,
dtype=dtype,
)
elif distribution == 'normal':
Expand All @@ -192,7 +192,7 @@ def _random_array(
)
else:
return jnp.sqrt(variance) * _complex_uniform(
key=key, shape=jax.core.as_named_shape(shape), dtype=dtype
key=key, shape=shape, dtype=dtype
)
else:
raise ValueError(f"invalid distribution '{distribution}' for _random_array")
Expand Down

0 comments on commit e2e9312

Please sign in to comment.