Backend-agnostic stateless PRNGs #333
Leaving some notes here.

JAX PRNG: https://github.com/google/jax/blob/537e35b0fa2c2126cdd22a2e346b65ce11395f80/jax/_src/prng.py

If we set the RngSeed in XLA (https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor.cpp#L346), would it also apply to all future kernels, since it seems to be set on the DeviceCtx? Or would future kernels reset the seed to 0 if the context does not restart? (I'm just thinking of a temporary stopgap that can be deprecated once proper PRNG support rolls out.)

Implementing the JAX way would require all calls to random functions inside
JAX PRNG reference: https://github.com/google/jax/blob/main/docs/design_notes/prng.md
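To make the "JAX way" concrete, here is a minimal sketch of a stateless, splittable PRNG in plain Python. In this design the key is an ordinary value that is passed explicitly to every random call and split to derive fresh keys, so no hidden device state (such as a seed on the DeviceCtx) is involved. Assumptions: the names `prng_key`, `split` and `uniform` are hypothetical and only mirror the shape of JAX's API. The mixing function is the SplitMix64 step, swapped in for simplicity. JAX itself uses Threefry-2x32 by default, and this sketch makes no claims about statistical quality.

```python
from typing import List

MASK64 = (1 << 64) - 1
SPLIT_DOMAIN = 1    # hypothetical domain tag so split keys differ from sample streams
UNIFORM_DOMAIN = 2  # hypothetical domain tag for uniform samples


def mix64(x: int) -> int:
    """SplitMix64 step: a bijective mixing function on 64-bit integers."""
    x = (x + 0x9E3779B97F4A7C15) & MASK64
    x = ((x ^ (x >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    x = ((x ^ (x >> 27)) * 0x94D049BB133111EB) & MASK64
    return x ^ (x >> 31)


def _hash(key: int, domain: int, counter: int) -> int:
    # Counter-based: the output depends only on (key, domain, counter),
    # never on any previously generated value.
    inner = mix64((key + domain) & MASK64)
    return mix64((inner + counter) & MASK64)


def prng_key(seed: int) -> int:
    """Turn an integer seed into a key (hypothetical name)."""
    return mix64(seed & MASK64)


def split(key: int, num: int = 2) -> List[int]:
    """Derive `num` new keys from `key` without mutating anything."""
    return [_hash(key, SPLIT_DOMAIN, i) for i in range(num)]


def uniform(key: int, n: int) -> List[float]:
    """Return n floats in [0, 1); the same key always yields the same values."""
    # Keep the top 53 bits of each hash so the result fits a double's mantissa.
    return [(_hash(key, UNIFORM_DOMAIN, i) >> 11) * (1.0 / (1 << 53)) for i in range(n)]


if __name__ == "__main__":
    key = prng_key(42)
    key, sub = split(key)      # thread the key explicitly, as in JAX
    print(uniform(sub, 3))     # reproducible for the same seed
```

Because the key is just a 64-bit integer and every operation is integer arithmetic, any backend that can do (or emulate) 64-bit integer ops could, in principle, produce identical streams, which is the portability property this issue asks for. The cost, as noted above, is that every random function has to take a key as an argument.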
As discussed in #331, we need an approach to portable PRNGs that works on any backend.
Originally posted by @seanmor5 in #331 (comment)