nn.stochastic context does not extend into jitted functions #112
@adarob could you provide a minimal repro for this?
Doesn't work:
Output:
Works:
Output:
In that second example you meant to write
I don't know which line you're referring to, but it looks like what I intended.
Ah, my apologies, I misread it the first time.
This is part of a larger issue concerning mixing state and JAX transformations.
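A minimal illustration of why Python-side state and `jax.jit` mix badly, in plain JAX (the `trace_count` dict is a hypothetical stand-in for any stateful context, not part of Flax): Python code inside a jitted function runs only while JAX traces it, and later calls with the same input shape and dtype reuse the cached compilation without re-running that Python code.

```python
import jax
import jax.numpy as jnp

# Hypothetical Python-side "state" used only for illustration.
trace_count = {"n": 0}

@jax.jit
def double(x):
    # This Python statement runs only while JAX traces `double`,
    # not every time the compiled function executes.
    trace_count["n"] += 1
    return x * 2

for _ in range(3):
    double(jnp.ones(3))
print(trace_count["n"])  # 1: traced once, then the cached compilation was reused

double(jnp.ones(4))  # a new input shape triggers a retrace
print(trace_count["n"])  # 2
```

Any state that a stateful context updates in Python (an RNG counter, a collection of variables) is therefore frozen at whatever value it had during tracing.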
Btw, we are also looking into automatically supporting things like stateful and stochastic in combination with JAX transforms, together with the Haiku folks and the JAX core team. But for now we just try to avoid silent errors.
@jheek assigning to you because I believe you're looking into this.
This causes silent failures: nn.make_rng() inside the jitted function always produces the same "random" numbers. This pitfall should be well documented, and an error should be raised to protect against it.
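A sketch of the failure mode in plain JAX. `HypotheticalRngContext` is an assumption made for illustration, not Flax's `nn.stochastic` / `nn.make_rng` implementation; it only assumes that the context keeps its RNG state in Python. Because `make_rng` runs only at trace time, the compiled `noisy_bad` reuses the same subkey on every call. Passing the key as an explicit argument to the jitted function avoids this.

```python
import jax
import jax.numpy as jnp

class HypotheticalRngContext:
    """Hypothetical stateful RNG context (not Flax's API).

    Derives a fresh key on each make_rng() call by folding a Python-side
    counter into a base key.
    """

    def __init__(self, seed):
        self.base_key = jax.random.PRNGKey(seed)
        self.count = 0  # plain Python int, advanced only when Python code runs

    def make_rng(self):
        self.count += 1
        return jax.random.fold_in(self.base_key, self.count)

ctx = HypotheticalRngContext(0)

@jax.jit
def noisy_bad(x):
    # make_rng() executes only while jit traces this function, so the
    # counter value (1) is baked into the compiled computation.
    return x + jax.random.normal(ctx.make_rng(), x.shape)

@jax.jit
def noisy_good(x, key):
    # The key is a traced argument, so each call can receive a new one.
    return x + jax.random.normal(key, x.shape)

x = jnp.zeros(3)
a, b = noisy_bad(x), noisy_bad(x)
print(bool(jnp.array_equal(a, b)))  # True: same "random" numbers every call

# Draw keys outside the jitted function and pass them in explicitly.
c = noisy_good(x, ctx.make_rng())
d = noisy_good(x, ctx.make_rng())
```

Threading the key through as an argument keeps the jitted function pure, which is the pattern JAX's own `jax.random` API is built around.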