Which XlaOps are used for random number generation in JAX?
#22708
I created an XLA issue to try to understand this here. I'm interested in understanding how a high-level API can be built on top of XLA in general, and RNGs are essential to this. In summary, we have the op In contrast, in JAX, every RNG function takes a specific seed and generates a specific distribution. I'd like to understand how this works at a lower level, but I'm not sure where to look.
Replies: 1 comment
The random generators and distributions are implemented entirely in JAX; see the `jax/_src/random.py` file. For example, see here for the code for `jax.random.normal`.

If you are interested in seeing the XLA code for random number generation, you can dump the MLIR code resulting from JIT compilation of the particular random function. Example:

```python
import jax

key = jax.random.key(0)
print(jax.jit(jax.random.normal).lower(key).as_text())
```

This will produce something like this (note that the body of the `_normal_real` function indeed corresponds to the JAX code for `_normal_real`):
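If the full MLIR dump is more than you need, a rough way to see which low-level ops a given random function lowers to is to tally the op names in the lowered text. This is only a minimal sketch. It assumes the default lowering emits ops in the `stablehlo` dialect, and the regular expression is an illustrative heuristic rather than a real MLIR parser. The set of ops you see may differ between JAX versions and PRNG implementations.

```python
import re
from collections import Counter

import jax

key = jax.random.key(0)

# Lower the jitted function without running it and get the MLIR text.
text = jax.jit(jax.random.normal).lower(key).as_text()

# Assumption: ops appear with a `stablehlo.` prefix in the printed module.
# This regex is a heuristic tally, not a parser.
ops = Counter(re.findall(r"\bstablehlo\.[a-z_]+", text))

for name, count in ops.most_common():
    print(f"{name}: {count}")
```

Swapping `jax.random.normal` for another function such as `jax.random.uniform` lets you compare how different distributions are built from the same underlying bit-generation ops.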