nn.stochastic context does not extend into jitted functions #112

Closed
adarob opened this issue Mar 25, 2020 · 10 comments

@adarob
Member

adarob commented Mar 25, 2020

This causes silent failures where nn.make_rng() inside of the jitted function always produces the same "random" numbers. This issue should be well-documented and errors should be thrown to protect against it.

@AlexeyG
Collaborator

AlexeyG commented Mar 25, 2020

@adarob could you provide a minimal repro for this?

@adarob
Member Author

adarob commented Mar 25, 2020

Doesn't work:

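import jax
from flax import nn

@jax.jit
def rnd():
  # nn.make_rng() is plain Python, so under jit it runs only while rnd is traced.
  return (jax.random.randint(nn.make_rng(), (5,), 0, 10),
          jax.random.randint(nn.make_rng(), (5,), 0, 10))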

with nn.stochastic(jax.random.PRNGKey(0)):
  for _ in range(5):
    print(rnd())

Output:

(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
(DeviceArray([0, 2, 0, 3, 2], dtype=int32), DeviceArray([4, 1, 7, 8, 9], dtype=int32))
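
The repeated output is consistent with how jax.jit handles Python side effects in general: the function body runs once during tracing, and later calls with the same argument shapes reuse the compiled result. A minimal sketch of that behaviour without Flax, assuming that nn.make_rng() behaves like any other Python-level side effect here (trace_log and add_one are hypothetical names used only for illustration):

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical list, used only to observe how often the body runs


@jax.jit
def add_one(x):
    # Python side effect: executes while JAX traces the function,
    # not on calls that hit the compilation cache.
    trace_log.append("traced")
    return x + 1


x = jnp.ones(3)
results = [add_one(x) for _ in range(5)]

# Five calls, but the Python body ran only once.
print(len(trace_log))  # prints 1
```

Anything computed from Python-side state during that single trace is baked into the compiled function, which would explain why every call returns the same "random" numbers.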

Works:

@jax.jit
def rnd(rng):
  with nn.stochastic(rng):
    return (jax.random.randint(nn.make_rng(), (5,), 0, 10),
            jax.random.randint(nn.make_rng(), (5,), 0, 10))

with nn.stochastic(jax.random.PRNGKey(0)):
  for _ in range(5):
    print(rnd(nn.make_rng()))

Output:

(DeviceArray([8, 5, 6, 6, 7], dtype=int32), DeviceArray([4, 9, 7, 1, 5], dtype=int32))
(DeviceArray([9, 3, 1, 6, 0], dtype=int32), DeviceArray([6, 0, 5, 3, 9], dtype=int32))
(DeviceArray([2, 7, 8, 8, 1], dtype=int32), DeviceArray([9, 2, 5, 0, 6], dtype=int32))
(DeviceArray([0, 1, 2, 8, 1], dtype=int32), DeviceArray([5, 4, 6, 1, 1], dtype=int32))
(DeviceArray([1, 8, 4, 8, 3], dtype=int32), DeviceArray([1, 3, 6, 6, 4], dtype=int32))
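
The working version differs in that the key enters the jitted function as an argument, so it is a traced input rather than a value fixed at trace time. The same pattern can be sketched with jax.random alone, assuming you are willing to thread the key by hand instead of relying on nn.stochastic (the variable names below are illustrative):

```python
import jax
import jax.numpy as jnp


@jax.jit
def rnd(key):
    # The key is a traced argument, so every call uses the value passed in.
    k1, k2 = jax.random.split(key)
    return (jax.random.randint(k1, (5,), 0, 10),
            jax.random.randint(k2, (5,), 0, 10))


key = jax.random.PRNGKey(0)
outputs = []
for _ in range(5):
    # Advance the key in ordinary Python, outside the jitted function.
    key, subkey = jax.random.split(key)
    outputs.append(rnd(subkey))
    print(outputs[-1])
```

Calling rnd twice with the same key still gives the same result, which is the expected behaviour for JAX's explicit PRNG keys.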

@levskaya
Collaborator

In that second example you meant to write rng in place of nn.make_rng(), no?

@adarob
Member Author

adarob commented Mar 26, 2020

I don't know which line you're referring to but it looks like what I intended.

@levskaya
Collaborator

Ah, my apologies, I misread it the first time.

@jheek
Member

jheek commented Mar 26, 2020

This is part of a larger issue concerning mixing state and jax transformations. nn.stochastic should throw an exception in this case because mixing jax transformations with internal state is ambiguous. I will make a PR for this, but it might lead to some false positives that need to be fixed.

@avital
Contributor

avital commented Mar 27, 2020

I think #125 is the PR that should address this.

Effectively, it should make your code @adarob throw an explicit error, and then you can decide how to deal with the PRNGs. E.g. if you're using vmap you will have to explicitly choose whether you split them or reuse the PRNG.
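
To illustrate the vmap choice, here is a rough sketch using only jax.random and jax.vmap. The function noisy and the array shapes are made up for the example and are not part of Flax:

```python
import jax
import jax.numpy as jnp


def noisy(key, x):
    # Hypothetical per-example function: adds Gaussian noise to x.
    return x + jax.random.normal(key, x.shape)


key = jax.random.PRNGKey(0)
xs = jnp.zeros((4, 3))

# Option 1: split the key so each batch element gets its own noise.
keys = jax.random.split(key, xs.shape[0])
split_out = jax.vmap(noisy)(keys, xs)

# Option 2: reuse one key so every batch element gets the same noise.
shared_out = jax.vmap(noisy, in_axes=(None, 0))(key, xs)

print(split_out)
print(shared_out)
```

With in_axes=(None, 0) the key is not mapped over, so all rows of shared_out are identical, while the rows of split_out differ. Which of the two is right depends on the model, which is why an explicit error seems preferable to a silent default.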

@jheek
Member

jheek commented Mar 27, 2020

Btw, we are also looking into automatically supporting things like stateful and stochastic in combination with jax transforms, together with the Haiku folks and the jax core team. But for now we are just trying to avoid silent errors.

@avital
Contributor

avital commented Apr 1, 2020

@jheek assigning to you because I believe you're looking into this

@jheek
Member

jheek commented Apr 2, 2020

nn.stochastic correctly throws an error but it does now extend into init_by_shape (as of PR #159).

jheek closed this as completed Apr 2, 2020