I'm working through https://flax.readthedocs.io/en/latest/guides/randomness.html and the shorthand methods don't work for me (snippet from the site):
import jax
from flax import nnx

rngs = nnx.Rngs(0, params=1)
# using jax.random
z1 = jax.random.normal(rngs(), (2, 3))
z2 = jax.random.bernoulli(rngs.params(), 0.5, (10,))
# shorthand methods
z1 = rngs.normal((2, 3)) # generates key from rngs.default
z2 = rngs.params.bernoulli(0.5, (10,)) # generates key from rngs.params
The missing shorthands also break the very first example on the introductory page,
https://flax.readthedocs.io/en/latest/nnx_basics.html:
import jax
import jax.numpy as jnp
from flax import nnx

class Linear(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        self.w = nnx.Param(rngs.params.uniform((din, dout)))
        self.b = nnx.Param(jnp.zeros((dout,)))
        self.din, self.dout = din, dout

    def __call__(self, x: jax.Array):
        return x @ self.w + self.b
Instantiating it fails with AttributeError: 'RngStream' object has no attribute 'uniform'.
Is there a way to make these shorthand methods work, or is the only option to call jax.random.normal etc. directly?
Thanks!
Flax version: 0.11.1
Jax version: 0.7.0