Backend-agnostic stateless PRNGs #333

Closed
polvalente opened this issue Mar 13, 2021 · 2 comments

@polvalente
Contributor

As discussed in #331, we need an approach to portable PRNGs that works with any backend.

We definitely need the ability to seed RNGs, and I'm confident this approach will work for Torchx and other backends, but I'm not 100% sure how it will extend to EXLA and other tensor compilers that can't depend on a stateful RNG. For example, we can set the RNG seed as an executable run option in EXLA, but that seed would apply to the entire execution rather than to individual calls to random_uniform within defn. It would also need to be passed as a compile option rather than directly to calls to random_x.

I think it is probably best to move forward with JAX-style stateless PRNGs instead, because we can implement them with our current API and get a solution that extends to every compiler and backend. The JAX PRNGs are also promoted as well suited to distributed and parallel computing, which aligns with some of our future goals. Reworking this from the EXLA perspective would probably involve using RngBitGenerator and reimplementing our current random functions in terms of that and other primitives.

Originally posted by @seanmor5 in #331 (comment)
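
To illustrate the concern in the quoted comment, here is a minimal Python sketch (not Nx or EXLA code). The stateful half uses Python's standard `random.Random`. The stateless half uses a hypothetical `random_uniform(key, n)` helper in which SHA-256 is only a stand-in for a real counter-based generator. With a seeded stateful generator, the value a call returns depends on how many calls came before it. With an explicit key, each call is a pure function of its arguments, which is the property a tensor compiler can rely on.

```python
import hashlib
import random
import struct


def stateful_draw(seed, calls_before):
    """Seed once, make some earlier calls, then draw one value."""
    rng = random.Random(seed)
    for _ in range(calls_before):
        rng.random()  # every earlier call advances the hidden state
    return rng.random()


def random_uniform(key, n):
    """Hypothetical stateless helper: n floats in [0, 1) from (key, index).

    `key` is an integer in [0, 2**64). SHA-256 is an assumption made for
    this sketch, not what JAX or XLA use.
    """
    values = []
    for i in range(n):
        digest = hashlib.sha256(struct.pack("<QQ", key, i)).digest()
        values.append(struct.unpack("<I", digest[:4])[0] / 2**32)
    return values


# Same seed, but the result changes with the number of earlier calls.
print(stateful_draw(42, 0) == stateful_draw(42, 1))
# Same key, same values, regardless of what else has run.
print(random_uniform(7, 3) == random_uniform(7, 3))
```

The trade-off is that the caller now has to manage keys explicitly instead of relying on one global seed.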

@vans163
Contributor

vans163 commented Apr 2, 2022

Leaving some notes here

JAX PRNG https://github.com/google/jax/blob/537e35b0fa2c2126cdd22a2e346b65ce11395f80/jax/_src/prng.py
uses ThreeFry https://bashtage.github.io/randomgen/bit_generators/threefry.html which allows parallel application use cases via its key system
XLA RngBitGenerator implementation https://github.com/google/jax/blob/537e35b0fa2c2126cdd22a2e346b65ce11395f80/jax/_src/prng.py#L543
https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator
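
As a rough illustration of the key system mentioned in these notes, here is a pure-Python sketch of the Threefry-2x32 block function with 20 rounds, written from the published algorithm rather than copied from JAX. It maps a 64-bit key and a 64-bit counter to 64 output bits with no hidden state, so any worker can compute any block on its own. The `worker_stream` helper and the way it derives per-worker keys are hypothetical, included only to suggest how keys enable parallel use.

```python
MASK = 0xFFFFFFFF
ROTATIONS = (13, 15, 26, 6, 17, 29, 16, 24)  # Threefry-2x32 rotation constants
PARITY = 0x1BD11BDA                          # key-schedule parity constant


def rotl32(x, r):
    """Rotate a 32-bit value left by r bits."""
    return ((x << r) | (x >> (32 - r))) & MASK


def threefry2x32(key, counter):
    """Encrypt a 2x32-bit counter under a 2x32-bit key (20 rounds)."""
    k0, k1 = key
    ks = (k0, k1, k0 ^ k1 ^ PARITY)
    x0 = (counter[0] + ks[0]) & MASK
    x1 = (counter[1] + ks[1]) & MASK
    for rnd in range(20):
        # Mix step: add, rotate, xor.
        x0 = (x0 + x1) & MASK
        x1 = rotl32(x1, ROTATIONS[rnd % 8]) ^ x0
        if rnd % 4 == 3:
            # Inject the key schedule after every fourth round.
            s = rnd // 4 + 1
            x0 = (x0 + ks[s % 3]) & MASK
            x1 = (x1 + ks[(s + 1) % 3] + s) & MASK
    return x0, x1


def worker_stream(key, worker_id, n):
    """Hypothetical helper: give each worker its own key, then count."""
    worker_key = threefry2x32(key, (0, worker_id))
    return [threefry2x32(worker_key, (0, i))[0] for i in range(n)]


root_key = (0, 42)
# Each worker's stream depends only on (root_key, worker_id), not on run order.
streams = [worker_stream(root_key, w, 4) for w in range(3)]
print(streams)
```

Because every output is a function of the key and counter alone, the streams can be produced in any order or on different devices and still match.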

If we set the RngSeed in XLA (https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor.cpp#L346), would it apply to all future kernels as well, since it seems to be set on the DeviceCtx? Would future kernels reset the seed to 0 if the context does not restart? (Just thinking of a temporary stopgap that can be deprecated later once proper PRNG support rolls out.)

Would implementing the JAX approach require every call to a random function inside defn to pass along the stateless PRNG context?
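
In JAX itself the answer is broadly yes: the caller holds a key, splits it, and hands a fresh subkey to each random call. A rough Python sketch of that threading pattern follows. The names `split`, `random_uniform`, `init_layer`, and `init_model` are hypothetical, and SHA-256 stands in for the real bit generator; how this would look inside defn is not settled in this thread.

```python
import hashlib
import struct


def split(key, num=2):
    """Derive `num` child keys from `key` without mutating anything."""
    return [
        hashlib.sha256(b"split" + key + struct.pack("<I", i)).digest()[:16]
        for i in range(num)
    ]


def random_uniform(key, n):
    """Return n floats in [0, 1) as a pure function of (key, index)."""
    values = []
    for i in range(n):
        digest = hashlib.sha256(b"bits" + key + struct.pack("<I", i)).digest()
        values.append(int.from_bytes(digest[:4], "little") / 2**32)
    return values


def init_layer(key, n_in, n_out):
    """Each random call gets its own subkey; the parent key is not reused."""
    w_key, b_key = split(key)
    weights = random_uniform(w_key, n_in * n_out)
    bias = random_uniform(b_key, n_out)
    return weights, bias


def init_model(key, sizes):
    """Thread one root key through every layer by splitting it up front."""
    layer_keys = split(key, len(sizes) - 1)
    return [
        init_layer(k, n_in, n_out)
        for k, n_in, n_out in zip(layer_keys, sizes[:-1], sizes[1:])
    ]


root = b"\x00" * 16  # root key chosen by the caller
model = init_model(root, [3, 4, 2])
print(len(model), len(model[0][0]), len(model[0][1]))
```

The cost is one extra argument per random call; in exchange, the whole initialization is reproducible from the single root key.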

@josevalim
Collaborator


3 participants