
Adding random number generation API #431

Closed
jakirkham opened this issue May 13, 2022 · 51 comments
Labels
RFC Request for comments. Feature requests and proposed changes.

Comments

@jakirkham
Member

jakirkham commented May 13, 2022

In conjunction with the Array API, one often needs some form of random number generation as well. Since this would generate arrays and be used with arrays the user already has, there is some benefit, from the downstream user's perspective, to having a standard API for random number generation.

For example, the following scikit-learn discussion ( scikit-learn/scikit-learn#22352 (comment) ) may shed some light. A couple of points that came up were whether to have a stateful or stateless API, and NumPy's old vs. new APIs.

@leofang
Contributor

leofang commented May 18, 2022

A couple of points that came up were whether to have a stateful or stateless API, and NumPy's old vs. new APIs.

I'd like to quote @rgommers from there (scikit-learn/scikit-learn#22352 (comment)) because I feel it clarified a lot:

It's not a stateless PRNG (there's no such thing)

@jakirkham
Member Author

Yeah, I was just summarizing the discussion up to that point.

Stateless isn't really the right term for this, I think. A more accurate description might be functional, as many functional languages use a similar strategy. IOW, something like this:

state0 = seed(42)
state1, n = randint(state0)
state2, m = randint(state1)

IOW the user is responsible for tracking the state in this model and threading it through subsequent random number generation function calls.

The other model is object-oriented, like NumPy's, where state lives in some object that gets mutated each time a random number is generated.
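To make the functional model concrete, here is a minimal, self-contained sketch in the same spirit. The `seed` and `randint` names mirror the pseudocode above, and the LCG constants are purely illustrative (this is not a production-quality PRNG, just a demonstration of explicit state threading):

```python
# Illustrative only: a pure-function PRNG in the style sketched above.
# The LCG constants are Knuth's MMIX values; this is NOT a recommended PRNG.
M = 2**64
A = 6364136223846793005
C = 1442695040888963407

def seed(n):
    """Create an initial state from an integer seed (pure function)."""
    return n % M

def randint(state):
    """Return (new_state, value) without mutating anything."""
    new_state = (A * state + C) % M
    return new_state, new_state >> 33  # discard the lower-quality low bits

# The caller threads the state through every call explicitly:
state0 = seed(42)
state1, n = randint(state0)
state2, m = randint(state1)
assert randint(state0) == (state1, n)  # pure: same state in, same result out
```

Threading the state by hand like this is exactly the discipline that JAX's key-passing API enforces.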

@rgommers
Member

We had a previous discussion on PRNG APIs, and IIRC no one was attached to the legacy numpy.random API (which most libraries currently have as their default, although TensorFlow is quite different: tf.random.Generator, also class-based).

The new NumPy API seems a little more user-friendly than the JAX APIs, and offers more functionality too. And it's no less safe: I think it's just as possible to shoot yourself in the foot with JAX - just forget to manually create a new sub-key once. This is more of a philosophical difference: there's only one way of doing things in JAX, which is more verbose but the same for serial and parallel execution, while NumPy has a more concise way for serial (its default mode), but the user must remember to use a second method (equivalent to jax.random.split) to get the right parallel behavior.

One advantage of the JAX API is that there's only one way to do things, while NumPy has multiple ways to deal with parallelism - so more functionality, but also harder to understand or standardize. That may be partly because NumPy provides multiple PRNG algorithms, while JAX provides only the Threefry algorithm (not entirely true - there's a second, experimental one, the XLA Random Bit Generator). PyTorch, for example, provides Philox (which JAX also could have used) and MT19937. Now, a standard doesn't have to deal with exact reproducibility across libraries, but it should allow a design in which libraries can choose their own algorithms. I think this works with either JAX's key (which contains an algorithm selector) or NumPy's class-based API. NumPy's SeedSequence is like the seeding part of JAX's key, and should work with any algorithm too. NumPy's .jumped cannot be supported by all algorithms, AFAIK.
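To make the NumPy-side mechanisms mentioned above concrete, here is a minimal sketch of the two parallel-stream approaches: `SeedSequence.spawn` (algorithm-agnostic) and `BitGenerator.jumped` (supported by PCG64 and Philox, but not by every algorithm). The seed value is arbitrary:

```python
import numpy as np

seed = 2022  # arbitrary

# Option 1: spawn independent child seeds (works with any BitGenerator).
children = np.random.SeedSequence(seed).spawn(4)
spawned = [np.random.default_rng(s) for s in children]

# Option 2: jump one stream far ahead (PCG64/Philox only).
jumped = [np.random.Generator(np.random.PCG64(seed).jumped(i + 1))
          for i in range(4)]

draws_spawned = [rng.uniform() for rng in spawned]
draws_jumped = [rng.uniform() for rng in jumped]
```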

The main issues for standardization I see are:

  • (a) JAX cannot adopt any class-based API,
  • (b) NumPy probably isn't going to be very receptive to yet another API, and
  • (c) it'd be a lot of work for other libraries like PyTorch to switch from their old API to any other API.

tl;dr this will not be easy to standardize

@jakirkham
Member Author

Thanks for the detailed write up Ralf! 🙏

FWIW Dask is already working on adopting the new NumPy API ( dask/dask#9038 )

It looks like CuPy already did this in 9.0.0 with PR ( cupy/cupy#4177 ) (though Leo should feel free to correct me)

That all being said, maybe it is worth asking if a subset of NumPy's API might be easier to adopt and if so what that looks like.

Also, the other important question here is what API is going to be most useful for downstream users. I started with scikit-learn, as they make heavy use of random number generation, and creating a usable API with scikit-learn would be a win. Though maybe there are other downstream libraries that would make sense as well (perhaps statsmodels? others?).

@oleksandr-pavlyk
Contributor

oleksandr-pavlyk commented May 18, 2022

@leofang While there is no stateless PRNG, HW-based true random number generators have empty state (e.g. MKL's NONDETERM basic random number generator), and these should be supported by the spec as well.

@rgommers
Member

rgommers commented Jun 1, 2022

To get a better feel for the tradeoffs, here is code for a couple of things:

  1. Serial and parallel (multiprocessing) usage for numpy.random and jax.random
  2. An implementation of the JAX API (PRNGKey, split, uniform) on top of the numpy.random infrastructure. Vice-versa is not possible.

For all the below code in a single gist, see here.

import secrets
import multiprocessing

import numpy as np
import jax


USE_FIXED_SEED = False

if USE_FIXED_SEED:
    seed = 38968222334307
else:
    # Generate a random high-entropy seed for use in the below examples
    # jax.random.PRNGKey doesn't accept None to do this automatically
    seed = secrets.randbits(32)  # JAX can't deal with >32-bits


# NumPy serial
rng = np.random.default_rng(seed=seed)
vals = rng.uniform(size=3)
val = rng.uniform(size=1)

# NumPy parallel
sseq = np.random.SeedSequence(entropy=seed)
child_seeds = sseq.spawn(4)
rngs = [np.random.default_rng(seed=s) for s in child_seeds]

def use_rngs_numpy(rng):
    vals = rng.uniform(size=3)
    val = rng.uniform(size=1)
    print(vals, val)

def main_numpy():
    with multiprocessing.Pool(processes=4) as pool:
        pool.map(use_rngs_numpy, rngs)


# JAX serial (also auto-parallelizes fine by design)
key = jax.random.PRNGKey(seed)
key, subkey = jax.random.split(key)  # this one could be left out, but best practice is probably to always use `split` first
vals = jax.random.uniform(subkey, shape=(3,))
key, subkey = jax.random.split(key)  # don't forget this!
val = jax.random.uniform(subkey, shape=(1,))


# JAX parallel with multiprocessing
def use_rngs_jax(key):
    key, subkey = jax.random.split(key)
    vals = jax.random.uniform(subkey, shape=(3,))
    key, subkey = jax.random.split(key)
    val = jax.random.uniform(subkey, shape=(1,))
    print(vals, val)


def main_jax():
    key = jax.random.PRNGKey(seed)
    key, *subkeys = jax.random.split(key, 5)  # gotcha: "5" gives us 4 subkeys
    with multiprocessing.Pool(processes=4) as pool:
        pool.map(use_rngs_jax, subkeys)


if __name__ == '__main__':
    # JAX does not work with the default `fork` (due to internal threading)
    multiprocessing.set_start_method('forkserver')

    print('\nNumPy with multiprocessing:\n')
    main_numpy()
    print('\n\nJAX with multiprocessing:\n')
    main_jax()

JAX does seem to have a few gotchas with seed creation; it apparently can't deal with high-entropy seeds (at least in 0.2.27, released 18 Jan 2022):

In [24]: seed = secrets.randbits(64)

In [25]: jax.random.PRNGKey(seed)
---------------------------------------------------------------------------
OverflowError                             Traceback (most recent call last)
<ipython-input-25-7a8d328c270c> in <module>
----> 1 jax.random.PRNGKey(seed)

~/anaconda3/envs/many-libs/lib/python3.7/site-packages/jax/_src/random.py in PRNGKey(seed)
     57   # Explicitly cast to int64 for JIT invariance of behavior on large ints.
     58   if isinstance(seed, int):
---> 59     seed = np.int64(seed)
     60   # Converting to jnp.array may truncate bits when jax_enable_x64=False, but this
     61   # is necessary for the sake of JIT invariance of the result for such values.

OverflowError: Python int too large to convert to C long

In [26]: seed = secrets.randbits(128)

In [27]: jax.random.PRNGKey(seed)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-27-7a8d328c270c> in <module>
----> 1 jax.random.PRNGKey(seed)

~/anaconda3/envs/many-libs/lib/python3.7/site-packages/jax/_src/random.py in PRNGKey(seed)
     53     raise TypeError(f"PRNGKey seed must be a scalar; got {seed!r}.")
     54   if not np.issubdtype(np.result_type(seed), np.integer):
---> 55     raise TypeError(f"PRNGKey seed must be an integer; got {seed!r}")
     56
     57   # Explicitly cast to int64 for JIT invariance of behavior on large ints.

TypeError: PRNGKey seed must be an integer; got 67681183633192462759155065893448052088

In [28]: seed = secrets.randbits(64)

In [29]: jax.random.PRNGKey(seed)
---------------------------------------------------------------------------
OverflowError                             Traceback (most recent call last)
<ipython-input-29-7a8d328c270c> in <module>
----> 1 jax.random.PRNGKey(seed)

~/anaconda3/envs/many-libs/lib/python3.7/site-packages/jax/_src/random.py in PRNGKey(seed)
     57   # Explicitly cast to int64 for JIT invariance of behavior on large ints.
     58   if isinstance(seed, int):
---> 59     seed = np.int64(seed)
     60   # Converting to jnp.array may truncate bits when jax_enable_x64=False, but this
     61   # is necessary for the sake of JIT invariance of the result for such values.

OverflowError: Python int too large to convert to C long

In [30]: seed = secrets.randbits(32)

In [31]: jax.random.PRNGKey(seed)
Out[31]: DeviceArray([         0, 3279739543], dtype=uint32)

A JAX-style API with numpy.random under the hood:

"""
Implement `jax.random` APIs with NumPy, and `numpy.random` APIs with JAX.
The purpose of this is to be able to compare APIs more easily, and clarify
where they are and aren't similar.
"""

import secrets
import multiprocessing

import numpy as np
import jax


USE_FIXED_SEED = False

if USE_FIXED_SEED:
    seed = 38968222334307
else:
    # Generate a random high-entropy seed for use in the below examples
    # jax.random.PRNGKey doesn't accept None to do this automatically
    seed = secrets.randbits(32)  # JAX can't deal with >32-bits



def PRNGKey(seed):
    """
    Create a key from a seed. `seed` must be a 32-bit (or 64-bit?) integer.
    """
    # Note: selecting a non-default PRNG algorithm is done via a global config
    #       flag (not good, should be a keyword or similar ...)
    seed = np.random.SeedSequence(seed)
    rng = np.random.default_rng(seed)
    key = (seed, rng)
    return key


def split(key, num=2):
    """
    Parameters
    ----------
    key : tuple
        Size-2 tuple: the first element a `SeedSequence` instance, the second
        a `Generator` whose `BitGenerator` type acts as the algorithm selector.
    num : int, optional
        The number of keys to produce (default: 2).

    Returns
    -------
    keys : list of 2-tuples
        `num` keys (each key being a 2-tuple)
    """
    seed, rng = key
    child_seeds = seed.spawn(num)
    # Use a list (not a generator expression) so keys can be reused/unpacked
    keys = [(s, rng) for s in child_seeds]
    return keys


def uniform(key, shape=(), dtype=np.float64, minval=0.0, maxval=1.0):
    seed, rng = key
    # Creating a new Generator instance from an old one with the same
    # underlying BitGenerator type requires using non-public API:
    rng = np.random.Generator(rng._bit_generator.__class__(seed))
    return rng.uniform(low=minval, high=maxval, size=shape).astype(dtype)


def use_jaxlike_api(key=None):
    if key is None:
        key = PRNGKey(seed)

    key, subkey = split(key)
    vals = uniform(subkey, shape=(3,))
    key, subkey = split(key)  # don't forget this!
    val = uniform(subkey, shape=(1,))
    print(vals, val)


def use_jaxlike_api_mp():
    key = PRNGKey(seed)
    key, *subkeys = split(key, 5)
    with multiprocessing.Pool(processes=4) as pool:
        pool.map(use_jaxlike_api, subkeys)


if __name__ == '__main__':
    # JAX does not work with the default `fork` (due to internal threading)
    multiprocessing.set_start_method('forkserver')

    print('\n\nUse JAX-like API (serial):\n')
    use_jaxlike_api()
    print('\n\nUse JAX-like API (multiprocessing):\n')
    use_jaxlike_api_mp()

A couple of thoughts:

  • JAX has several rough edges, like not being able to deal with all Python integers, not working well with multiprocessing, not allowing seed=None, and the num keyword to split being odd (to use 4 subkeys in your parallel generation, you need to say num=5).
    • The first one and last two are fairly minor.
    • The multiprocessing one is quite annoying, but an inherent issue due to JAX's and multiprocessing's internals not getting along, and in principle unrelated to API design.
  • Tradeoff: NumPy's API is more concise for both serial usage and parallel usage with multiprocessing. However, JAX's API works naturally with auto-parallelization/threading, while with NumPy the user has to switch to a completely different method using SeedSequence (note: .jumped is similar, doesn't change this tradeoff).
    • How do you shoot yourself in the foot with JAX? Forget to use split() before drawing random numbers.
    • How do you shoot yourself in the foot with NumPy? For serial use (NumPy's default): safe, no footguns. For parallel use: quite easy to shoot yourself in the foot, users have to realize that they need to switch paradigm and then fiddle with SeedSequence or .jumped.
  • Creating new BitGenerator and Generator instances each time one wants to generate more random numbers (see the uniform implementation) will have significant overhead. I haven't measured it, but I'm fairly sure it matters. This probably can't be avoided with the current design; it would need a larger change in numpy.random to make it efficient, I think (disclaimer: I haven't thought about this super hard).
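As a rough check of the overhead point, here is a quick timing sketch (not a rigorous benchmark; absolute numbers will vary by machine). It compares reusing one Generator against recreating the BitGenerator + Generator on every draw, as the JAX-style shim's `uniform` must:

```python
import time
import numpy as np

N = 10_000
seed = 0

# Reuse one Generator (normal NumPy usage).
rng = np.random.default_rng(seed)
t0 = time.perf_counter()
for _ in range(N):
    rng.uniform()
t_reuse = time.perf_counter() - t0

# Recreate the BitGenerator + Generator for every draw, as the shim does.
# Each PCG64(seed) call re-runs the SeedSequence entropy mixing.
t0 = time.perf_counter()
for _ in range(N):
    np.random.Generator(np.random.PCG64(seed)).uniform()
t_recreate = time.perf_counter() - t0

print(f"reuse: {t_reuse:.4f}s  recreate: {t_recreate:.4f}s")
```

On typical machines the recreate loop is one to two orders of magnitude slower, consistent with the concern above.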

One other thing to point out: the JAX docs comparing to NumPy are wildly outdated/unfair - they use the non-recommended (global-state) way of using the legacy API. In general, numpy.random provides a superset of the capabilities of jax.random, as evidenced by how easy it is to implement a JAX-like API on top of the current numpy.random. Related: see this informative comment by @rkern for more thoughts on jax.random vs. numpy.random. In particular:

Just to provide some mathematical background, Jax's PRNG is in the same weak-crypto family as our Philox BitGenerator. The method by which it splits only works well for that family of weak-crypto PRNGs because that family keeps its initial seed around as the key value and only evolves a counter as one draws numbers from it. The other PRNGs iterate the state.

My tentative conclusions based on the above:

  • If consumer libraries want to write portable code across NumPy, JAX, PyTorch and other libraries with NumPy-matching APIs (Dask, CuPy), then there's two options:
    • (1) A shim like the above JAX-like API.
    • (2) Use the numpy.random API as the common case, and simply special-case JAX.
    • It will be quite an effort to teach everyone (1); however, I'm not sure that's a show-stopping concern, since the new-style numpy.random.Generator API is still not widely used (even experienced scikit-learn devs weren't aware of it in a recent thread). I've given two talks that covered the new numpy.random.Generator API, and each time many people were surprised (and enthusiastic) - so uptake is low. That would be nice to improve on either way.
    • That said, (2) is going to be more performant, given that split only works well for counter-based algorithms, and special-casing JAX is not that much work.
  • The JAX API needs some improvements for the rough edges, but is in principle easier to teach.
  • Regarding @jakirkham's question "That all being said, maybe it is worth asking if a subset of NumPy's API might be easier to adopt and if so what that looks like.": I think SeedSequence is probably the thing to include for parallelism; having multiple ways of doing parallelism is nice for power users, but probably confusing to many others.

This is all a little nontrivial, so let me ping a few folks for input: @rkern for design/implementation thoughts and whether I missed anything important related to the NumPy implementation. @shoyer, @apaszke for thoughts from the JAX side.

@rkern

rkern commented Jun 1, 2022

  • For parallel use: quite easy to shoot yourself in the foot, users have to realize that they need to switch paradigm and then fiddle with SeedSequence or .jumped.

I'll need to spend more time to read the whole thread, but I will add here that we have always had a plan to lift .spawn() up to Generator. We left it out in the initial release to build some more comfort with SeedSequence spawning before we committed to it. But I think it's time. It would make it a lot easier to avoid the foot-guns without having to go back and change a lot of code to pass around more state.

@rkern

rkern commented Jun 1, 2022

Specifically, this code:

# NumPy parallel
sseq = np.random.SeedSequence(entropy=seed)
child_seeds = sseq.spawn(4)
rngs = [np.random.default_rng(seed=s) for s in child_seeds]

would become:

rngs = rng.spawn(4)

@rgommers
Member

rgommers commented Jun 1, 2022

I'll need to spend more time to read the whole thread, but I will add here that we have always had a plan to lift .spawn() up to Generator.

Thanks - I think that would indeed be quite helpful!

@jakirkham
Member Author

Sorry for a bit of a tangent here, but is calling .spawn() incrementally valid? Like

rng1 = rng.spawn(1)
rng2 = rng.spawn(1)
...
rngn = rng.spawn(1)

This can come up when new child processes are created/destroyed on demand (IOW, autoscaling).

@seberg
Contributor

seberg commented Jun 1, 2022

Yes it is. More importantly:

rng3 = rng2.spawn()

is also valid. You can spawn any children to get more independent streams. There will never be a collision (at least not with any reasonable probability).
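Both patterns can be sketched with `SeedSequence` directly, which is the mechanism that a `Generator.spawn` method would build on:

```python
import numpy as np

root = np.random.SeedSequence(12345)

# Incremental: the parent tracks how many children it has spawned,
# so repeated spawn(1) calls yield distinct children.
child_a = root.spawn(1)[0]
child_b = root.spawn(1)[0]

# Nested: a child can spawn its own independent grandchildren.
grandchild = child_a.spawn(1)[0]

# Each spawn_key is a distinct path in the spawn tree, which is why
# the resulting streams are independent.
keys = [s.spawn_key for s in (child_a, child_b, grandchild)]
rngs = [np.random.default_rng(s) for s in (child_a, child_b, grandchild)]
draws = [rng.uniform() for rng in rngs]
```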

@rgommers
Member

rgommers commented Jun 2, 2022

I went searching through some older notes and found this from @alextp: "I think the direction in TensorFlow is to follow JAX. Have functions for stateless random generation. They compose well. But they are not ergonomic. Layer on top of this will be something stateful. Once you do this, you introduce checkpoints etc, for determinism. So I would OK to add stateless API, but wouldn't be okay to add stateful one." That's from more than a year ago, so I don't know if it has been implemented like that, or whether anything has changed in the meantime in TensorFlow. Maybe @edloper you can tell us?

@rkern

rkern commented Jun 2, 2022

I'm perfectly content with having 0 PRNG APIs in the standard (and far prefer 0 in the standard to having 2 in the standard). It seems like there is significant variety in what different communities need and want.

@shoyer
Contributor

shoyer commented Jun 2, 2022

I think it's equally straightforward to write an explicitly stateful RNG system like numpy.random.Generator using JAX. Here's a prototype:

import jax


class JaxGenerator:
  def __init__(self, state):
    self.state = state

  def uniform(self):
    self.state, key = jax.random.split(self.state)
    return jax.random.uniform(key)
  
  def spawn(self, count):
    self.state, *keys = jax.random.split(self.state, count + 1)
    return [JaxGenerator(key) for key in keys]

  def __repr__(self):
    return f'{type(self).__name__}(state={self.state})'


def jax_default_rng(seed):
  return JaxGenerator(jax.random.PRNGKey(seed))


rng = jax_default_rng(0)
print(rng.uniform())  # 0.10536897
print(rng.uniform())  # 0.2787192
print(rng)  # JaxGenerator(state=[2384771982 3928867769])

rng2 = JaxGenerator(rng.state)
print(rng2)  # JaxGenerator(state=[2384771982 3928867769])

rngs = rng.spawn(2)
print(rngs)  # [JaxGenerator(state=[1777981902 3244208681]), JaxGenerator(state=[ 669635267 2816531647])]

From a JAX design perspective, stateful RNGs like this are not encouraged, because JAX's function transforms will break if you pass stateful objects into them. But you can still safely use this sort of API with JAX, as long as you're careful to create Generator objects inside pure functions, e.g.,

import jax.numpy as jnp  # needed below; not imported earlier in the snippet

@jax.jit
@jax.vmap
def batched_random_uniform(seed):
  return jax_default_rng(seed).uniform()

print(batched_random_uniform(jnp.arange(5)))
# DeviceArray([0.10536897, 0.12568676, 0.4336251 , 0.47652578, 0.7844808 ], dtype=float32)

Or explicitly keeping track of updated RNG state:

@jax.jit
def explicit_state_random_uniform(state):
  rng = JaxGenerator(state)
  sample = rng.uniform()  # must happen *before* calculating new_state 
  new_state = rng.state
  return new_state, sample

state = jax.random.PRNGKey(0)
state2, sample = explicit_state_random_uniform(state)
print(state2, sample)

In fact, Haiku, which is one of the most popular neural net libraries in JAX, does something very similar with haiku.next_rng_key.

These stateful interfaces are perhaps easier to misuse with JAX's functional transforms than pure functions, but are still a major improvement over global state.

@shoyer
Contributor

shoyer commented Jun 2, 2022

To bridge the gap between NumPy's and JAX's random number APIs, I would suggest slightly extending NumPy's API so it's easier to be explicit about state. Namely, we should add the spawn() method, and also support directly accessing/setting .state on Generator objects, rather than only on BitGenerator.

Here's my minimal wrapper of NumPy's RNG API with these extensions, to match the JAX API above:

import numpy as np


class NumpyGenerator:

  def __init__(self, state):
    # TODO: avoid initializing this dummy Generator/BitGenerator?
    self._rng = np.random.default_rng()
    self.state = state

  def uniform(self):
    return self._rng.uniform()
  
  def spawn(self, count):
    entropy = self._rng.integers(2**63)  # TODO: better entropy?
    sseq = np.random.SeedSequence(entropy)
    child_seeds = sseq.spawn(count)
    return [numpy_default_rng(seed) for seed in child_seeds]

  @property
  def state(self):
    return self._rng.bit_generator.state

  @state.setter
  def state(self, value):
    self._rng.bit_generator.state = value

  def __repr__(self):
    return f'{type(self).__name__}(state={self.state})'


def numpy_default_rng(seed):
  return NumpyGenerator(np.random.default_rng(seed).bit_generator.state)


rng = numpy_default_rng(0)
print(rng.uniform())  # 0.6369616873214543
print(rng.uniform())  # 0.2697867137638703
print(rng)  # NumpyGenerator(state={'bit_generator': 'PCG64', 'state': {'state': 143609658456486183636066271097634410721, 'inc': 87136372517582989555478159403783844777}, 'has_uint32': 0, 'uinteger': 0})

rng2 = NumpyGenerator(rng.state)
print(rng2)  # NumpyGenerator(state={'bit_generator': 'PCG64', 'state': {'state': 143609658456486183636066271097634410721, 'inc': 87136372517582989555478159403783844777}, 'has_uint32': 0, 'uinteger': 0})

rngs = rng.spawn(2)
print(rngs)  # [NumpyGenerator(state={'bit_generator': 'PCG64', 'state': {'state': 164868448360684748498847325894109072011, 'inc': 241524822143570234404080558697197945801}, 'has_uint32': 0, 'uinteger': 0}), NumpyGenerator(state={'bit_generator': 'PCG64', 'state': {'state': 219573644955839246335449654370252341036, 'inc': 175512849395095609630857841553467033115}, 'has_uint32': 0, 'uinteger': 0})]

@jakirkham
Member Author

It's nice to see the alternative wrapping for perspective

Namely, we should add the spawn() method, and also support directly accessing/setting .state on Generator objects, rather than only on BitGenerator.

Am curious how much work it would be to add these to NumPy?

@rkern

rkern commented Jun 3, 2022

Not much.

@jakevdp

jakevdp commented Jun 3, 2022

One note regarding JAX:

JAX does seem to have a few gotchas with seed creation, it can't deal with high-entropy seeds apparently (at least in 0.2.27, released 18 Jan 2022):

JAX disables 64-bit data types by default, but you can enable this if you wish to use them:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision

If you enable 64-bit values, then 64-bit seeds will be fine.

@jakevdp

jakevdp commented Jun 3, 2022

One other thing to point out: the JAX docs comparing to NumPy are wildly outdated/unfair, they use the non-recommended (global state) way of using the legacy API.

I'm happy to update these if you think it would be helpful. In my experience, the global default rng is how most users use numpy's random APIs, despite more recent changes to recommendations in numpy's docs, so it's a useful way to introduce how JAX differs. As for the JAX discussion, whether the seed is global or a mutated rng object, I think the bulk of the discussion (regarding side effects, implicit vs. explicit statefulness / pure functions) still holds true.

@edloper

edloper commented Jun 3, 2022

I think the recommendation from the TF side would still be to use the stateless API. (As mentioned above, a stateful API can be layered on top of the stateless one, if desired.)

@rkern

rkern commented Jun 4, 2022

We can build shim APIs across either implementation to get stateless or stateful APIs, but the way that each shim API needs to be implemented has its own costs, constraints, and tradeoffs. I would suggest exploring those implementation strategies and how well they actually support a set of use cases before relying on the mere existence proof of shim APIs as a reason to go ahead and standardize on one API instead of another. What would the shim APIs actually allow us to do?

Let's say a stateless API becomes the standard. data = xp.uniform(key, low, high), etc. We can certainly build a stateful shim API on top that just uses the standard and manages the key splitting internally. What does this shim let us do? Can I write some Monte Carlo routines using that stateful shim API (completely ignorant of JAX) and still use that routine in JAX code? @shoyer suggests that has limitations. It seems like I'd be taking all the compromises necessary to make a stateless API without getting the corresponding benefits. Splittable PRNG schemes don't grow on trees, and are difficult to design robustly; the API standard would have to define an implementation as well, and I don't think I can recommend JAX's after investigating it. I can go into that in more detail here or somewhere else, if anyone would like.

Vice versa, we can definitely make a stateless API on top of the current implementation of np.random with just a bit of manual manipulation of SeedSequence inside the xp.split() function (SeedSequence could also be rewritten slightly to support more efficient pure-function splitting). Now that might allow me, a pure NumPy user, to use a sampling routine that was written with JAX in mind (but otherwise restricted to the xp API). But making the stateful API the standard just to allow an implementation to wrap a stateless API around it doesn't solve JAX's problem of not wanting to have a stateful API implementation in the first place (and you should not have to, IMO).

Again, I'm pretty happy for there to be no API standard on this subject. I would be interested in talking more about how to hand off from one style of API to the other, though. There's room in the SeedSequence design space to make this easier.
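A sketch of what a pure-function split on top of SeedSequence might look like; the `split`/`uniform` names are hypothetical, and the key is plain immutable data (entropy plus a spawn-key path), manipulated the way SeedSequence does internally:

```python
import numpy as np

def split(key, num=2):
    """Pure split: extend the spawn-key path; nothing is mutated."""
    entropy, spawn_key = key
    return [(entropy, spawn_key + (i,)) for i in range(num)]

def uniform(key, shape=None):
    """Pure draw: rebuild the stream from the key on every call."""
    entropy, spawn_key = key
    sseq = np.random.SeedSequence(entropy, spawn_key=spawn_key)
    return np.random.default_rng(sseq).uniform(size=shape)

root = (12345, ())
left_key, right_key = split(root)
a = uniform(left_key)
b = uniform(right_key)
assert uniform(left_key) == a  # pure: same key, same output
```

Note the inefficiency rkern points out elsewhere in the thread: every draw pays for rebuilding the SeedSequence and Generator.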

@oleksandr-pavlyk
Contributor

I can go into that in more detail here or somewhere else, if anyone would like.

I vote for presenting them here, please.

@rkern

rkern commented Jun 6, 2022

It's, uh, getting long. Not because my list of issues is long, just that I am long-winded and am including a lot of background.

@rkern

rkern commented Jun 6, 2022

Apologies

This is long, and everyone has my apologies for that, but I want to make sure the background is laid out.

Introduction

So first off, I want to say that the issues I am going to lay out don't make JAX's PRNG a bad one for all purposes. JAX has a fairly specific usage profile, and these are limitations that can be usefully lived within, if one is willing to. But these are limitations that I would hesitate to propagate up to general use amongst all Array API users, particularly in contexts where one doesn't get the benefits paid for by those limitations. I will mention a number of issues that I have with specific implementation choices that are, in principle, addressable. I don't really count those as determinative (they can indeed be fixed while still maintaining the style of API). But I do think that, taken as a whole, they show that imposing this style of API places serious constraints on implementations, which is probably out of scope for this standardization effort.

Reviewing JAX's Design

To review, JAX's PRNG API is based around explicit splitting of the PRNG state and then pure functions to generate possibly-large arrays of data from the leaf keys. The split(prng_key, num=2) function takes prng_key and derives num new PRNG keys from it. Generally, you'll take the leftmost PRNG key and keep it around for later and then pass the other split PRNG keys down to functions (either those in jax.random or others that will do further split()s). The core PRNG algorithm for actually drawing random bits in those jax.random functions comes from Parallel Random Numbers: As Easy as 1, 2, 3 ([Random123]). In particular, it uses the threefry2x32 weak cryptographic block cipher as a primitive. This is a block cipher that takes a 64-bit plaintext block and encrypts it using the given 64-bit cipher key to get a 64-bit ciphertext. For this variant, all of the 64-bit blocks are broken up into 2 32-bit words (hence the name) so we can use cheap 32-bit arithmetic.

cipher_text = threefry2x32(cipher_key, plain_text)

This primitive is a keyed bijection: given a fixed cipher_key, there is an invertible one-to-one mapping between plain_text and cipher_text. Changing the cipher_key, you select a different bijection.

For the core PRNG algorithm to draw random bits, we use a simple incrementing counter as each plaintext block. This is nice for GPUs because we can create that counter array and then have the GPU run threefry2x32 in parallel across the whole thing. The JAX implementation of threefry_random_bits() which does this operation has some quirks that don't make too much of a difference for the quality of the pseudorandom numbers from a single call, but will show up a bit later when we talk about splitting.

In particular, instead of incrementing the counter in 64-bit blocks, it creates a 32-bit counter array, splits it in half to use the first half as the upper 32-bit word and the second half as the lower 32-bit word. That's a bit wasteful (you could instead just use zeros(rem//2) and lax.iota(rem//2) for rem requested words instead), but they are valid inputs. The two arrays which are output (being separately the upper and lower words of the cipher_text blocks) are then concatenated together end-to-end instead of reassembling those words. I assume this is done because it's cheaper to concatenate than to interleave again. This is fine as far as it goes; it doesn't hugely affect the quality of numbers coming out of one jax.random.uniform() call, for example. But it does wreck the bijection properties, which will become important when we talk about implementation choices for split().

So far, so good. Despite that quirk, I have only one qualm with the core PRNG scheme for drawing bits in the jax.random functions: 64 bits is quite small. I consider 128-bit PRNGs to be de rigueur today, particularly in contexts where we are talking about large amounts of parallelism (and the JAX API enforces a large amount of parallelism). By itself, that doesn't have to be too much of an issue, and you get the benefit of a cheaper PRNG by using this smaller variant. But it does place a limit on the number of jax.random calls that you can make with different PRNG keys. Unless you have a tightly controlled mechanism for allocating those PRNG keys to ensure that they are distinct, you are subject to Birthday Collisions, which happen at around the square root of the state space size, in this case, about 2**32. That's a lot! But you probably want to stay well under that, maybe even the square root of the Birthday Bound (2**16, which is not a lot). Big distributed reinforcement learning runs should definitely be wary.
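The collision arithmetic here follows the standard birthday approximation p ≈ 1 - exp(-n²/2N) for n draws from a space of size N; a few lines make the 2**32 danger zone and the comfort of staying near 2**16 concrete (illustrative numbers only):

```python
import math

def birthday_collision_prob(n_draws, space_size):
    # Standard birthday approximation: p ≈ 1 - exp(-n^2 / (2N))
    return 1.0 - math.exp(-(n_draws ** 2) / (2.0 * space_size))

N = 2 ** 64  # 64-bit key space

# At the Birthday Bound itself, a collision is likely.
print(birthday_collision_prob(2 ** 32, N))  # ~0.39

# Staying near the square root of the Birthday Bound is very safe.
print(birthday_collision_prob(2 ** 16, N))  # ~1.2e-10
```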

The mechanism that JAX uses to split() reuses this implementation of threefry_random_bits(). For the moment, let's just talk about split(prng_key, num=2): it's convenient to label the two outputs as "left" and "right". In fact, let's define two operations left() and right() as follows:

def left(prng_key):
    return jax.random.split(prng_key)[0]

def right(prng_key):
    return jax.random.split(prng_key)[1]

None of the following really depend on restricting ourselves to that, but it's handy to talk about things.

So we use the current prng_key to encrypt the 2 64-bit counter blocks and get out 2 64-bit ciphertexts to use as the left and right keys (modulo the quirks discussed above). This follows the general treatment in Splittable pseudorandom number generators using cryptographic hashing ([Splittable2013]) which uses the full-strength cryptographic Threefish block cipher ([Random123]'s Threefry is a weakened variant of this cipher). That is, we use the old key as the cipher_key, encrypt a counter block that's different whether it's the left() or right() branch, and take the resulting cipher_text as the next prng_key for that branch. So we're feeding back the output of this algorithm as an input to the next iteration. This is a standard construction for cryptographic hashes. In fact, people do use this construction as a way to make sequential PRNGs, essentially keep appending 0 blocks to a cryptographic hash and yielding the internal state of that hash as the random bits (equivalently, keep calling left() on its own output).

It has nice properties for cryptographic purposes, but it has some significant drawbacks for PRNGs used for scientific and statistical purposes. left() defines a non-invertible mapping (right() defines a different one). As discussed above, a block cipher defines a keyed bijection: fixing cipher_key, there is a one-to-one mapping between the plain_text and cipher_text. But if instead you fix plain_text (i.e. we use the same counter block in left() every time), then the mapping from cipher_key to cipher_text is many-to-one, and thus non-invertible. This is important in cryptography; if we could invert the cipher_text to get the cipher_key, then the cipher would allow an attacker to recover the cipher_key in a variety of scenarios. But iterations of non-invertible mappings have some unfortunate effects for scientific PRNGs where those security properties don't matter and other properties become more desirable.

Damn Statistics

While there are few guarantees without knowing something about the structure of the mapping, you can calculate some useful statistics about the population of random non-invertible mappings. Particularly since the left() iteration is constructed from a (weak) block cipher, it is about as good a random sample from that population as any. @imneme has a good pair of articles about the statistics of non-invertible mappings and also of invertible mappings. The state space of a random non-invertible mapping is not good for a scientific PRNG. You have a forest of trees that filter down to roots that are on cycles of various sizes. When a cycle completes, obviously, you have a Birthday Collision and start repeating values. The "rho length" statistic (the average number of iterations until a repeat is found) reflects that. Increasing the state space reduces the chance of these collisions predictably.

However, such state spaces are inherently biased and non-uniform, and this doesn't go away as the state space increases. The states that follow states which have in-degrees greater than 1 will be overrepresented if you repeat with different initial seeds. States on small cycles will also be overrepresented. Small cycles will have big trees attached to them. A large fraction of states (on the order of half) are not reachable from any other state; you have to start with them as the initial seed to ever observe them. This is particularly an issue when you reduce the initial seeding space to just 32 bits, as is the default configuration of JAX. I will never recommend a PRNG based on non-invertible mappings for scientific use.
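These population statistics are easy to observe empirically. A quick simulation (using Python's stdlib random purely for illustration) shows that a large fraction of states in a random non-invertible mapping — about 1/e ≈ 37%, consistent with the "on the order of half" figure above — have no preimage at all, so they can never be observed except as an initial seed:

```python
import random

def unreachable_fraction(n, seed=0):
    # Build a uniformly random mapping f: {0..n-1} -> {0..n-1} and
    # measure the fraction of states with in-degree zero, i.e. states
    # reachable only by choosing them as the initial seed.
    rng = random.Random(seed)
    image = {rng.randrange(n) for _ in range(n)}
    return 1.0 - len(image) / n

# For a random mapping, the expected unreachable fraction is 1/e ≈ 0.368.
print(unreachable_fraction(100_000))
```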

So far, I've talked about iterating left(left(left(...))). Why do I care about this, even though we are going to be using the right() branch, too? Well [Splittable2013] also throws up its hands to say that talking about things like periods doesn't make sense for splittable PRNGs. After all, you are guaranteed to get a collision (by a pigeonhole argument) within only 64 steps down a full binary tree exploring all of the left() and right() branches (which sounds bad, but I'm not particularly concerned). But in practice, our trees are very left-heavy due to certain conventions (we tend to reserve the left-most key to pass on in the data-flow for further splitting and use the right key(s) to be consumed in other functions). After all, this is how one translates serial PRNG code that calls multiple methods in sequence:

x = rng.uniform(0, 1)
y = rng.uniform(10, 20)
z = rng.uniform(-1, 1)

# -->

key, subkey = split(key)
x = uniform(subkey, 0, 1)
key, subkey = split(key)
y = uniform(subkey, 10, 20)
key, subkey = split(key)
z = uniform(subkey, -1, 1)

So the iteration of left() is kind of important to characterize, IMO. There are things we can say about it, and maybe improve it.

So I did construct a PRNG out of repeated iteration of left() and fed its output to PractRand. And we can observe that we do in fact find one of those cycles and fail right around the Birthday Bound (2**35 bytes corresponds to 2**32 iterations yielding 8-byte keys).

PractRand output for jax.random.split(key)[0]
❯ ./jax_keys.py
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
RNG_test using PractRand version 0.93
RNG = RNG_stdin64, seed = 0x6ae890b2
test set = expanded, folding = extra

rng=RNG_stdin64, seed=0x6ae890b2
length= 8 megabytes (2^23 bytes), time= 2.6 seconds
  no anomalies in 694 test result(s)

rng=RNG_stdin64, seed=0x6ae890b2
length= 16 megabytes (2^24 bytes), time= 7.8 seconds
  no anomalies in 747 test result(s)

rng=RNG_stdin64, seed=0x6ae890b2
length= 32 megabytes (2^25 bytes), time= 15.6 seconds
  no anomalies in 796 test result(s)

rng=RNG_stdin64, seed=0x6ae890b2
length= 64 megabytes (2^26 bytes), time= 28.1 seconds
  no anomalies in 843 test result(s)

rng=RNG_stdin64, seed=0x6ae890b2
length= 128 megabytes (2^27 bytes), time= 50.3 seconds
  Test Name                         Raw       Processed     Evaluation
  [Low4/32]Gap-16:B                 R=  -4.7  p =1-5.0e-4   unusual          
  ...and 890 test result(s) without anomalies

rng=RNG_stdin64, seed=0x6ae890b2
length= 256 megabytes (2^28 bytes), time= 92.0 seconds
  no anomalies in 938 test result(s)

rng=RNG_stdin64, seed=0x6ae890b2
length= 512 megabytes (2^29 bytes), time= 171 seconds
  no anomalies in 985 test result(s)

rng=RNG_stdin64, seed=0x6ae890b2
length= 1 gigabyte (2^30 bytes), time= 326 seconds
  no anomalies in 1039 test result(s)

rng=RNG_stdin64, seed=0x6ae890b2
length= 2 gigabytes (2^31 bytes), time= 628 seconds
  no anomalies in 1092 test result(s)

rng=RNG_stdin64, seed=0x6ae890b2
length= 4 gigabytes (2^32 bytes), time= 1237 seconds
  no anomalies in 1151 test result(s)

rng=RNG_stdin64, seed=0x6ae890b2
length= 8 gigabytes (2^33 bytes), time= 2439 seconds
  no anomalies in 1224 test result(s)

rng=RNG_stdin64, seed=0x6ae890b2
length= 16 gigabytes (2^34 bytes), time= 4831 seconds
  Test Name                         Raw       Processed     Evaluation
  [Low4/64]BCFN_FF(2+2,13-1,T)      R=  +9.4  p =  1.6e-4   unusual          
  ...and 1300 test result(s) without anomalies

rng=RNG_stdin64, seed=0x6ae890b2
length= 32 gigabytes (2^35 bytes), time= 9622 seconds
  Test Name                         Raw       Processed     Evaluation
  BCFN_FF(2+0,13-0,T)               R= +41.3  p =  1.3e-21    FAIL !!        
  BCFN_FF(2+1,13-0,T)               R= +33.9  p =  1.1e-17    FAIL !         
  BCFN_FF(2+2,13-0,T)               R= +44.0  p =  4.4e-23    FAIL !!        
  BCFN_FF(2+3,13-0,T)               R= +48.1  p =  3.1e-25    FAIL !!        
  BCFN_FF(2+4,13-0,T)               R= +44.8  p =  1.6e-23    FAIL !!        
  BCFN_FF(2+5,13-0,T)               R= +48.1  p =  3.0e-25    FAIL !!        
  BCFN_FF(2+6,13-0,T)               R= +40.9  p =  2.2e-21    FAIL !!        
  BCFN_FF(2+7,13-1,T)               R= +29.0  p =  4.5e-15    FAIL           
  BCFN_FF(2+8,13-1,T)               R= +30.6  p =  6.1e-16    FAIL           
  DC6-9x1Bytes-1                    R=+230.8  p =  4.5e-155   FAIL !!!!!     

[RTK: snipping 630 more failures; this is definitely a cycle]

  ...and 717 test result(s) without anomalies

As mentioned in the article, increasing the state size does help with this. You can hold off those collisions more and more with larger state sizes. You can hide this flaw with larger state spaces, as done in the original [Splittable2013] paper, which uses a full-strength 256-bit block cipher.

Just to be sure that this wasn't trivially fixable by fixing the aforementioned quirk, I retried with a version that did the threefry2x32 manually with a fixed plaintext of 0. This also failed at 2**36 bytes (essentially about the same), confirming that the general scheme of using the old prng_key as the cipher_key is still a non-invertible mapping.

Alternatives

There are alternatives, which are to use invertible mappings to define left() and right() (and further for num > 2). Random invertible mappings have much nicer expected properties. Every state is on a cycle and has an in-degree of 1. Given a uniform distribution of root seeds, you have an equal chance of seeing any state (if you don't draw enough to hit a cycle). We have a good idea of the distribution of sizes of those cycles. Very, very roughly, in expectation, half of the states are on the biggest cycle, 1/4 are on the next one, 1/8 on the next one and so on. There still are small cycles, but the chance that your initial seed actually places you on one goes down as the size does. For 64-bit states, the chance of your root seed being on a cycle less than 2**32 (that Birthday Bound) is 1/2**(64-32) == 1/2**32, a pretty tiny probability. Whereas for the non-invertible mapping used in JAX, your chance of hitting a repeat in 2**32 draws is about half.
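The contrast with invertible mappings is just as easy to demonstrate empirically: in a random permutation every state has in-degree exactly 1, and the cycle lengths partition the entire state space, so there are no trees feeding into cycles and no unreachable states (again using stdlib random as an illustrative toy):

```python
import random

def cycle_lengths(n, seed=0):
    # A uniformly random permutation: the invertible analogue of the
    # random mapping discussed above.
    rng = random.Random(seed)
    perm = list(range(n))
    rng.shuffle(perm)
    # Walk each cycle once, recording its length.
    seen, lengths = [False] * n, []
    for start in range(n):
        if seen[start]:
            continue
        x, length = start, 0
        while not seen[x]:
            seen[x] = True
            x = perm[x]
            length += 1
        lengths.append(length)
    return sorted(lengths, reverse=True)

lengths = cycle_lengths(100_000)
# Every state is on exactly one cycle: the lengths partition the space.
assert sum(lengths) == 100_000
```

A random starting state's cycle length is uniform on 1..n, so landing on a tiny cycle is correspondingly unlikely, matching the 1/2**32 figure above for 64-bit states.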

How do we construct useful invertible mappings? One way is to use SeedSequence, whose construction is conceptually very similar to the splitting mechanism here. Instead of a cryptographic hash, it folds data into the internal pool using invertible operations such that bijections are maintained where possible. So repeated iterations of left() or right() will at least have the invertible mapping statistics. There is also an MCG in the middle there driving it, much like SFC64 incorporates a 64-bit counter. That has the effect of creating a minimum cycle size of at least 2**28 (I think) when we repeatedly take the leftmost spawn child.

So for example, I have here a PractRand run where I use a tiny 32-bit pool size (though this is a little misleading as the MCG has to be tracked as well, so it's really a 64-bit state size). I run it like I did the JAX splitting test, where I just repeatedly fold in the 0 counter.

PractRand output for seed_sequence.spawn(1)[0]
❯ ./ss_split.py -p 1
RNG_test using PractRand version 0.93
RNG = RNG_stdin64, seed = 0x64fd6e66
test set = expanded, folding = extra

rng=RNG_stdin64, seed=0x64fd6e66
length= 1 megabyte (2^20 bytes), time= 2.1 seconds
  Test Name                         Raw       Processed     Evaluation
  [Low4/32]FPF-14+6/32:cross        R=  +6.5  p =  2.8e-5   mildly suspicious
  ...and 530 test result(s) without anomalies

rng=RNG_stdin64, seed=0x64fd6e66
length= 2 megabytes (2^21 bytes), time= 6.2 seconds
  no anomalies in 585 test result(s)

rng=RNG_stdin64, seed=0x64fd6e66
length= 4 megabytes (2^22 bytes), time= 12.4 seconds
  Test Name                         Raw       Processed     Evaluation
  [Low4/64]FPF-14+6/32:cross        R=  +6.0  p =  7.7e-5   unusual          
  ...and 640 test result(s) without anomalies

rng=RNG_stdin64, seed=0x64fd6e66
length= 8 megabytes (2^23 bytes), time= 22.8 seconds
  no anomalies in 694 test result(s)

rng=RNG_stdin64, seed=0x64fd6e66
length= 16 megabytes (2^24 bytes), time= 41.6 seconds
  no anomalies in 747 test result(s)

rng=RNG_stdin64, seed=0x64fd6e66
length= 32 megabytes (2^25 bytes), time= 77.3 seconds
  no anomalies in 796 test result(s)

rng=RNG_stdin64, seed=0x64fd6e66
length= 64 megabytes (2^26 bytes), time= 147 seconds
  no anomalies in 843 test result(s)

rng=RNG_stdin64, seed=0x64fd6e66
length= 128 megabytes (2^27 bytes), time= 280 seconds
  no anomalies in 891 test result(s)

rng=RNG_stdin64, seed=0x64fd6e66
length= 256 megabytes (2^28 bytes), time= 544 seconds
  no anomalies in 938 test result(s)

rng=RNG_stdin64, seed=0x64fd6e66
length= 512 megabytes (2^29 bytes), time= 1066 seconds
  no anomalies in 985 test result(s)

rng=RNG_stdin64, seed=0x64fd6e66
length= 1 gigabyte (2^30 bytes), time= 2106 seconds
  Test Name                         Raw       Processed     Evaluation
  [Low1/64]Gap-16:B                 R=  -5.6  p =1-7.8e-5   unusual          
  ...and 1036 test result(s) without anomalies
rng=RNG_stdin64, seed=0x64fd6e66
length= 2 gigabytes (2^31 bytes), time= 4204 seconds
  no anomalies in 1092 test result(s)

rng=RNG_stdin64, seed=0x64fd6e66
length= 4 gigabytes (2^32 bytes), time= 8436 seconds
  no anomalies in 1155 test result(s)

rng=RNG_stdin64, seed=0x64fd6e66
length= 8 gigabytes (2^33 bytes), time= 16916 seconds
  no anomalies in 1224 test result(s)

rng=RNG_stdin64, seed=0x64fd6e66
length= 16 gigabytes (2^34 bytes), time= 33704 seconds
  no anomalies in 1302 test result(s)

rng=RNG_stdin64, seed=0x64fd6e66
length= 32 gigabytes (2^35 bytes), time= 67929 seconds
  no anomalies in 1358 test result(s)

rng=RNG_stdin64, seed=0x64fd6e66
length= 64 gigabytes (2^36 bytes), time= 136044 seconds
  no anomalies in 1434 test result(s)

And it's still going strong past the point where the JAX one failed. Now, to be fair, this definitely created a duplicate entropy pool state somewhere along the way, which would have seeded an identical copy of the BitGenerator if it were used. Only the 32-bit pool plays a role in the seeding of the BitGenerator, not the MCG. If we actually drew a lot of numbers from each BitGenerator, PractRand would have found that out. It remains separately important that the generated seeds going into the number-generating API calls have an acceptable size to avoid collisions.

Currently, SeedSequence is written statefully. It keeps track of the number of spawned children requested of it so that repeated calls return freshly independent SeedSequences. It is also somewhat wasteful in keeping an explicit representation of the path down the split tree from the root entropy. Neither of these implementation choices are necessary to the algorithm. Here is an implementation with an idempotent .split() method.
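The idempotent-split idea can be sketched on top of NumPy's public SeedSequence API by deriving children purely from (entropy, spawn_key) rather than from a mutable child counter. This is just a sketch of the idea, not the implementation referenced above:

```python
import numpy as np

def split(ss, num=2):
    # Children are a pure function of (entropy, spawn_key), so calling
    # split() twice on the same parent yields identical children,
    # unlike the stateful ss.spawn() which advances a counter.
    return [np.random.SeedSequence(entropy=ss.entropy,
                                   spawn_key=ss.spawn_key + (i,))
            for i in range(num)]

root = np.random.SeedSequence(12345)
a, b = split(root)
a2, _ = split(root)  # idempotent: a2 is the same child as a
```

The caller is then responsible for not reusing the same child twice, which is exactly the discipline JAX's API already demands of its users.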

Now maybe you don't want to do all that. Maybe you don't like the expense of all of those multiplications; maybe you don't like that extra MCG hanging around taking up space. You have a keyed bijection already: threefry2x32(). But to get a bijection, you have to reverse the roles of the old prng_key and the counter block. Instead of using the old prng_key as the cipher_key and the counter block as the plain_text, flip it.

new_prng_key = threefry2x32(counter, old_prng_key)

Now there is a bijection between the two prng_keys because that's what a block cipher does. Running the repeated left(left(left(...))) iteration as a PRNG is still going strong long past where the original failed.

PractRand output for a ThreeFry2x32 bijection scheme
❯ ./jax_keys.py --variant bijective
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
RNG_test using PractRand version 0.93
RNG = RNG_stdin64, seed = 0x7a16bac8
test set = expanded, folding = extra

rng=RNG_stdin64, seed=0x7a16bac8
length= 8 megabytes (2^23 bytes), time= 2.5 seconds
  no anomalies in 694 test result(s)

rng=RNG_stdin64, seed=0x7a16bac8
length= 16 megabytes (2^24 bytes), time= 7.6 seconds
  Test Name                         Raw       Processed     Evaluation
  BCFN_FF(2+1,13-5,T)               R= +12.1  p =  6.5e-5   unusual          
  ...and 746 test result(s) without anomalies

rng=RNG_stdin64, seed=0x7a16bac8
length= 32 megabytes (2^25 bytes), time= 15.2 seconds
  no anomalies in 796 test result(s)

rng=RNG_stdin64, seed=0x7a16bac8
length= 64 megabytes (2^26 bytes), time= 27.5 seconds
  no anomalies in 843 test result(s)

rng=RNG_stdin64, seed=0x7a16bac8
length= 128 megabytes (2^27 bytes), time= 49.1 seconds
  no anomalies in 891 test result(s)

rng=RNG_stdin64, seed=0x7a16bac8
length= 256 megabytes (2^28 bytes), time= 89.2 seconds
  no anomalies in 938 test result(s)

rng=RNG_stdin64, seed=0x7a16bac8
length= 512 megabytes (2^29 bytes), time= 166 seconds
  no anomalies in 985 test result(s)

rng=RNG_stdin64, seed=0x7a16bac8
length= 1 gigabyte (2^30 bytes), time= 318 seconds
  Test Name                         Raw       Processed     Evaluation
  [Low1/16]Gap-16:B                 R=  -4.8  p =1-3.8e-4   unusual          
  ...and 1035 test result(s) without anomalies

rng=RNG_stdin64, seed=0x7a16bac8
length= 2 gigabytes (2^31 bytes), time= 616 seconds
  Test Name                         Raw       Processed     Evaluation
  [Low4/64]FPF-14+6/4:all           R=  -4.8  p =1-2.8e-4   unusual          
  ...and 1091 test result(s) without anomalies

rng=RNG_stdin64, seed=0x7a16bac8
length= 4 gigabytes (2^32 bytes), time= 1211 seconds
  no anomalies in 1155 test result(s)

rng=RNG_stdin64, seed=0x7a16bac8
length= 8 gigabytes (2^33 bytes), time= 2405 seconds
  no anomalies in 1229 test result(s)

rng=RNG_stdin64, seed=0x7a16bac8
length= 16 gigabytes (2^34 bytes), time= 4777 seconds
  no anomalies in 1303 test result(s)

rng=RNG_stdin64, seed=0x7a16bac8
length= 32 gigabytes (2^35 bytes), time= 9522 seconds
  Test Name                         Raw       Processed     Evaluation
  [Low4/64]BCFN_FF(2+1,13-0,T)      R=  +9.5  p =  1.3e-4   unusual          
  ...and 1359 test result(s) without anomalies

rng=RNG_stdin64, seed=0x7a16bac8
length= 64 gigabytes (2^36 bytes), time= 18960 seconds
  Test Name                         Raw       Processed     Evaluation
  [Low1/32]BCFN_FF(2+0,13-0,T)      R=  +8.5  p =  4.4e-4   unusual          
  [Low4/32]FPF-14+6/32:all          R=  +5.5  p =  1.3e-4   unusual          
  ...and 1429 test result(s) without anomalies

rng=RNG_stdin64, seed=0x7a16bac8
length= 128 gigabytes (2^37 bytes), time= 37487 seconds
  no anomalies in 1506 test result(s)

rng=RNG_stdin64, seed=0x7a16bac8
length= 256 gigabytes (2^38 bytes), time= 74451 seconds
  no anomalies in 1567 test result(s)

Now, that's neat (I think). But we don't just do iterations of left(). We mix up left() and right() (and more). Each of those are different bijections. There is still a chance that a right() call will give you an earlier-seen state. While the structure provided by bijections lets you have a better safety factor on parts of the split tree (important parts, IMO), it doesn't help much for arbitrarily mixed operations. With the statistics of random bijections, I think you could probably derive more accurate bounds, but even with SeedSequence, we tell you to rely on the Birthday Bound in any case. But you still get the statistical benefits of using bijections.

Conclusion

I want to close with a quote from the conclusion of Evaluation of Splittable Pseudo-Random Generators:

Conventional wisdom in computational statistics and Monte Carlo simulation indicate that one PRNG cannot suit every purpose. Different applications have different priorities, and calling for different trade-offs between sequence length, distribution, running time, and theoretical guarantees.

The JAX style of PRNG is a very specific flavor of splittability and pure functionalism. There are splittable PRNGs that are not functional, and there are pure functional APIs that don't force explicit splitting (and of course, useful PRNG schemes that are neither). JAX has a specific profile of needs that make this specific flavor of PRNG very appealing for its use cases. However, JAX's profile is not universal and should not define the API standard, which has a broad scope. For that matter, the profile behind the current design of np.random is not universal either.

I am perfectly content with this functionality not being standardized by the Array API.

@leofang
Contributor

leofang commented Jun 6, 2022

If I understand Robert's post above correctly, you're trying to convince us this is one of the rare cases where API design is tightly coupled to the underlying algorithm.

I'd like to add one more argument based on Robert's point (against random number standardization): In practice, the PRNG implementations available to different devices (CPUs, NVIDIA GPUs, AMD GPUs, Intel GPUs, Google TPUs, ...) are very likely different. (@emcastillo from CuPy/PFN has first-hand experience for this pain, as NumPy and CuPy have totally different PRNGs, and in this case CuPy isn't strictly speaking a drop-in replacement of NumPy.)

It then follows that the API standardization must be either PRNG-ignorant (= just give me some random numbers, I don't care how you generate them), or exclude PRNGs from the standard (= define an API that accepts vendor-specific PRNGs, but leave out which PRNGs the standard must cover). Either assumption must hold in order to proceed.

@jakirkham
Member Author

It's probably also worthwhile to tie this back to how downstream users would leverage this API.

Going back to scikit-learn, many APIs take a RandomState object (for example). This of course is referencing the old NumPy API, but that's not really the point (as NumPy's new API is also object-based). Instead the point is scikit-learn depends on having an object to pass state around and generate random numbers from different distributions (reusing the example above). Though this isn't unique to scikit-learn, statsmodels also does this (for example).
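To make the downstream pattern concrete, here is a rough sketch of the kind of signature these libraries use (fit_something is a hypothetical function; np.random.default_rng already normalizes None, an int seed, or an existing Generator, much as scikit-learn's check_random_state does for RandomState):

```python
import numpy as np

def fit_something(X, random_state=None):
    # Accepts None, an int seed, or a Generator instance; the object
    # then threads the state through multiple distribution calls.
    rng = np.random.default_rng(random_state)
    noise = rng.normal(size=len(X))   # one distribution method...
    idx = rng.permutation(len(X))     # ...then another, same state
    return [X[i] + noise[i] for i in idx]
```

The library never manages raw PRNG state itself; the object carries it between calls, which is exactly the data flow a purely functional API would push back onto the library author.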

Given the current usage of the object API by these libraries, it is relatively straightforward to see how they would adopt an object API if added to the spec.

What is less clear is how they would leverage a functional API. If we are interested in evaluating that option, it would be worth seeing how that API would be used in a downstream library. I should add that we probably want to look at something a bit more general than JAX's current API, given it is tied to a specific class of PRNGs, as pointed out in a few places in this issue already.

@rkern

rkern commented Jun 6, 2022

If I understand Robert's post above correctly, you're trying to convince us this is one of the rare cases where API design is tightly coupled to the underlying algorithm.

It's not coupled one-to-one, but that choice of API does eliminate a wide swathe of PRNG algorithm choices, and there are only some tricksy options left over.

What I am particularly arguing against is the claim that if we standardize on JAX's style of stateless API that we can just build a stateful API on top of it again for those people that like that API. While this is true, it is not usefully true. The reconstructed stateful API (built on top of the standardized JAX-style stateless API) does not restore the full range of PRNG options that we had before. And I think that there are good reasons to avoid the constraints of the JAX-style stateless API in the cases where we're not getting the benefit of JAX's other capabilities along with it.

The key issue is not so much the raw algorithm that is used inside of the distribution methods (uniform(), normal(), etc.) to compute a bunch of random numbers from a passed-in state. The API standard should not constrain that. I don't think anyone is suggesting that the Array API should try to guarantee that one could get the same numbers from each implementation.

The key issue (so to speak) is how to handle the data flow of the PRNG state. JAX's needs place strong requirements on this data flow in order to get corresponding benefits. Other environments don't have those benefits and thus don't have those requirements. Embedding those strong requirements in the standard imposes those costs on everyone. It's not just the syntax sugar of how the Generator API looks.

@rkern

rkern commented Jun 8, 2022

Another way to think about it is that all PRNG schemes have a certain finite amount of safety margin, and it is a consumable resource. Just serially drawing arrays of numbers from it consumes a small amount of that safety margin. Splits, no matter how implemented, consume a lot of that safety margin.

My general recommendation is to only split when you have to, at the place where you need parallelism. In JAX programs, that's everywhere, because JAX is awesome at enabling incredible amounts of parallelism without having to explicitly code it. A splitting-central PRNG scheme fits very well in that environment. The parallelism pays back what you spent in terms of the lowered safety margin. But for most of the other Array API implementations, forcing that amount of splitting is forcing everyone to live with that impaired safety margin for no corresponding benefit.

The issue is not merely the ergonomics of the API. Simply wrapping a stateful API around a splitting-central stateless API does not restore the safety margin built into the more typical stateful PRNG implementations.

@rgommers
Member

rgommers commented Jun 8, 2022

One other thing to point out: the JAX docs comparing to NumPy are wildly outdated/unfair, they use the non-recommended (global state) way of using the legacy API.

I'm happy to update these if you think it would be helpful. In my experience, the global default rng is how most users use numpy's random APIs, despite more recent changes to recommendations in numpy's docs, so it's a useful way to introduce how JAX differs. As for the JAX discussion, whether the seed is global or a mutated rng object, I think the bulk of the discussion (regarding side effects, implicit vs. explicit statefulness / pure functions) still holds true.

Thanks @jakevdp. I think it would be nice to either update it, or at least add a note saying that the section talks about legacy APIs and that numpy has had a default_rng/Generator API that has parallel capabilities and statistical properties & performance similar to (or even better than, xref Robert's analysis above) JAX since 2019. A sentence like "The Mersenne Twister PRNG is also known to have a number of problems, it has a large 2.5Kb state size, which leads to problematic initialization issues. It fails modern BigCrush tests, and is generally slow." looks like it is solely designed to say "see how much better JAX is than NumPy here".

@jakevdp

jakevdp commented Jun 8, 2022

OK, thanks @rgommers, I'll look into changing that discussion. For what it's worth, our intent there was not to say "JAX is better than numpy", but rather to explain why you need to think about PRNGs differently in JAX than numpy (and this is true regardless of whether you're using an explicit or implicit mutable state). But I see how one could read it and come away with a different impression.

@rkern

rkern commented Jun 8, 2022

I do think it is worth talking about SeedSequence spawning there too (even if you do have to address the legacy np.random.* stuff as well). It is a place where modern numpy is quite similar to JAX. The recommended ways to achieve parallelism are basically the same in both; the major difference is when and how much parallelism is enforced by the API (and enabled by the rest of JAX).
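For comparison, the modern NumPy spawning pattern looks like this (following the numpy.random parallel-generation docs):

```python
import numpy as np

# One root SeedSequence; spawn independent children for parallel workers.
root = np.random.SeedSequence(2022)
child_seqs = root.spawn(4)
streams = [np.random.default_rng(s) for s in child_seqs]

# Each stream draws independently. As in JAX, you split where you need
# parallelism; here it is opt-in rather than enforced by the API.
draws = [rng.uniform(size=3) for rng in streams]
```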

The other thing I would remove is the discussion about the sequential-equivalent guarantee. We have none. If there is a place in our docs that says otherwise, please let me know so I can expunge it. Something like that might have crept into some of the examples, which were mostly written by third-party contributors who might have assumed that the behavior they observed was guaranteed for some reason. Most methods do happen to behave that way because it's generally the easiest thing to implement with a serial PRNG, but we do not guarantee anything about it. For example, if we had a distribution where there was an algorithm that worked well in bulk but had a costly one-time setup and another that worked better for small numbers at a time, we reserve the right to implement a switchover point.

@jakevdp

jakevdp commented Jun 8, 2022

Hi @rkern - thanks for those comments. Just to be clear, the doc you linked to is not really "documentation" per se, but rather a years-old design doc meant to lay out the motivation for JAX's initial PRNG design (which has evolved since then, and will continue to evolve). For that reason, I don't think I will be updating it, but rather will add a disclaimer at the top making its intent more clear. How does that sound?

@rkern

rkern commented Jun 8, 2022

There might be a couple of other opportunities to standardize something, but it's not particularly clear to me what it would actually enable across implementations as different as numpy and JAX, say.

So for example, we could say, informationally, that there are 3 basic flavors of PRNG state flow that an implementation could have, but the standard doesn't specify any one or any of the details about each.

  1. Stateful Generator with distribution methods (x = rng.uniform(0, 1)) (implemented by numpy, dask, maybe others)
  2. JAX-like splitting-central pure functional (key, subkey = split(key); x = uniform(subkey, 0, 1), etc.) (implemented by JAX)
  3. Copy-and-return-iterated pure functional (rng, x = uniform(rng, 0, 1)) (not implemented by anyone that I'm aware of)

But it could standardize on a basic list of methods/functions that an implementation ought to have, along with the semantics of the other arguments not related to the PRNG state flow, and allowing for extra implementation-specific arguments (dask adds chunks=, for example). In particular, in the transition from RandomState.randint() to Generator.integers(), we fixed a number of corner cases that had made it hard to specify the full range of np.uint64, for example. It would be a really good idea for all of the implementations to also implement that logic and not suffer from the naive implementation mistakes I made in randint().
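As an example of those fixed corner cases, Generator.integers with endpoint=True can request the full np.uint64 range, which the exclusive high bound of randint made awkward to express:

```python
import numpy as np

rng = np.random.default_rng(0)

# endpoint=True makes the upper bound inclusive, so the full uint64
# range [0, 2**64 - 1] can be requested directly, with no overflow
# tricks around an exclusive high bound.
x = rng.integers(0, 2**64 - 1, size=1000, dtype=np.uint64, endpoint=True)
```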

This is not likely to actually enable one to write significant backend-agnostic code. So the Array API standard might not be the best place for that. Maybe a SPEC is a better instrument? The Array API does have the advantage of having convened the right people.

@rkern

rkern commented Jun 8, 2022

@jakevdp Sure, but it's also in the tutorial.

@jakevdp
Copy link

jakevdp commented Jun 8, 2022

Thanks for pointing that out. To be honest, it's news to me that numpy does not provide a sequential-equivalent guarantee in its pseudo-random values. I've spent the better part of the last couple decades assuming it did (I hope none of my code actually depends on that assumption...)

@jakevdp
Copy link

jakevdp commented Jun 8, 2022

Is there a place in the numpy docs that mentions the lack of a sequential equivalent guarantee in pseudo-random numbers? I'd like to link to it in the discussion. If not, I can link to this thread.

@edloper
Copy link

edloper commented Jun 8, 2022 via email

@rkern
Copy link

rkern commented Jun 8, 2022

The "Compatibility Guarantee" in the RandomState docstring has this language:

A fixed bit generator using a fixed seed and a fixed series of calls to RandomState methods using the same parameters, ...

This language was specifically about compatibility across different versions of numpy (though of course, it also holds true inside of a process). That's the only guarantee we've ever provided (and even that we disclaim for Generator, though it would make sense to add one about same-build and same-process reproducibility).

It seems reasonable to explicitly disclaim that not every combination that someone might think ought to be equivalent under some theory is actually implemented to be so. People do expect it from time to time and "confirm" it to themselves when they try something out, though we've never promised it.

@jakevdp
Copy link

jakevdp commented Jun 8, 2022

Thanks for the pointers - in that case I think I will link to this thread, because it seems more direct than linking to a long doc and pointing out the omission of the topic in question.

@seberg
Copy link
Contributor

seberg commented Jun 8, 2022

Are there actually large reasons why the NumPy API (with .spawn() also living on the rng object itself and a default_rng() function) is not a potential way forward?

It seems that there are the 3 "ways" of using a PRNG that Robert also listed above:

  1. Sequential drawing.
  2. "Lazy" sequential drawing (rng, data = uniform(rng, size=1000) would be a "stateless" spelling of that. I.e. a way that allows data to be evaluated at a later time, but avoids splitting.)
    • For an end-user/library API (i.e. what we are discussing) it may not matter if the provider (JAX) actually has to do splitting to achieve this?
  3. Flexible splitting/spawning for parallelism, that does not come with the limits of 2.
    • A constraint here is maybe that NumPy guarantees that splitting does not depend on the RNG state as modified by sequential draws. But I expect that this would be easy to guarantee in all cases (if desired).

Everyone provides 3, some PRNGs (as NumPy) can probably not provide 2 but don't need it (NumPy never evaluates lazily – it might only be a curious addition for parallelization). JAX needs 2 (but is currently "missing" it – i.e. is using the less optimal "always spawn" scheme).

But I don't think it has to be user-facing API? If you write:

data1 = rng.uniform()
data2 = rng.uniform()

in JAX or NumPy, whether that draws the numbers directly sequentially, or does so lazily (with some advancing or splitting scheme) hardly matters? What matters is that the user facing API is "sequential", because writing it differently would be bad for many PRNG schemes.

What I don't quite understand yet is why a stateful API that provides:

  1. Sequential draws as the normal way to get random numbers
  2. A way to spawn off new streams for parallelism

is not a reasonable start. It would be bad for NumPy to provide a JAX-style stateless API, but why can't JAX provide a sequential, stateful API (even if RNG generation is not necessarily sequential internally)?
After all, just because the API is "sequential" does not mean it has to guarantee that:

concat([rng.uniform(size=10), rng.uniform(size=20)])

gives the same as:

concat([rng.uniform(size=20), rng.uniform(size=10)])

NumPy will do this (often or always?) but it doesn't seem necessary for the end-user API?

The difficulty I see a bit is that some implementations may look like they provide guarantees like the above concatenation and others will not provide those same guarantees.
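The spawning path mentioned above can be sketched with NumPy's own API. `SeedSequence.spawn` has been available since NumPy 1.17, and it is what `Generator.spawn` (NumPy 1.25+) wraps; notably, spawning draws on the seed sequence, not on the generator's sequential draw state:

```python
import numpy as np

seq = np.random.SeedSequence(12345)

# Derive independent entropy streams for parallel workers; spawning does
# not depend on how many values have already been drawn from any rng.
child_seqs = seq.spawn(3)
rngs = [np.random.default_rng(s) for s in child_seqs]
draws = [rng.uniform(size=4) for rng in rngs]
```

A standard could describe this spawn operation abstractly, leaving each implementation free to realize it via seed-sequence hashing (NumPy) or key splitting (JAX).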

@froystig
Copy link

froystig commented Jun 8, 2022

It's a welcome surprise to have this thorough review of the JAX PRNG! Thanks @rkern for the really nice detailed look in particular.

I'm only arriving to this thread now and still catching up, but as a quick note for now: we happen to be actively working on some changes to the threefry_random_bits implementation, most immediately in how it lays out the counter values from iota, with the aim of making it more friendly to partitioning and to JAX's experimental support of dynamic array shapes. I don't think the immediate-term changes will affect many of @rkern's overall takeaways, but they may address this remark in part:

In particular, instead of incrementing the counter in 64-bit blocks, it creates a 32-bit counter array, splits it in half to use the first half as the upper 32-bit word and the second half as the lower 32-bit word. That's a bit wasteful (you could instead just use zeros(rem//2) and lax.iota(rem//2) for rem requested words instead), but they are valid inputs. The two arrays which are output (being separately the upper and lower words of the cipher_text blocks) are then concatenated together end-to-end instead of reassembling those words. I assume this is done because it's cheaper to concatenate than to interleave again. This is fine as far as it goes; it doesn't hugely affect the quality of numbers coming out of one jax.random.uniform() call, for example. But it does wreck the bijection properties, which will become important when we talk about implementation choices for split().

We indeed do not need to lay out the counter values exactly in the way that we do, so you'll see changes to that end. (I've observed in other contexts over time that the particular split/concat scheme we use isn't ideal, but I only picked up work on this again last week.)

So far, so good. Despite that quirk, I have only one qualm with the core PRNG scheme for drawing bits in the jax.random functions: 64 bits is quite small. I consider 128-bit PRNGs to be de rigueur today, particularly in contexts where we are talking about large amounts of parallelism (and the JAX API enforces a large amount of parallelism).

As of recently, I'd say that the key size—and in fact the entire hash function or base generator—is not an essential or timeless choice for us. It is what we implemented initially, and it remains the default for now. But as of google/jax#6899 we have a means of replacing the underlying bit generator. We use this internally to experiment with and offer other bit generators. I opened google/jax#7676 a while back precisely to track the introduction of a 128-bit generator. We also have plans to make it possible for users to plug in an arbitrary generator of their own, and to involve differently-backed PRNGs in the same process (all tracked at google/jax#9263). We might always choose to change our default away from the current threefry2x32 hash at some point.

So, where possible in the current discussion, it may be useful to assume that the base generator can change to meet the needs of the machine, application, etc.

@rkern
Copy link

rkern commented Jun 8, 2022

@froystig That's all good to hear! Like I said in my introduction, I didn't consider any of those details as immutable black marks against JAX or the overall design, but they did serve as a way to talk about the similarities with SeedSequence, the kinds of improvements that are left in the design space (bijections as far as the eye can see!), and what limits are still fundamental to a splitting-centric design (the Birthday Bound). And I'm a long-winded jerk.

@rkern
Copy link

rkern commented Jun 8, 2022

I don't know all of JAX's rewriting capabilities, but I don't think it can handle whatever it needs to do to make the Generator-like API work with the same reordering benefits.

data1 = rng.uniform()
data2 = rng.uniform()

What it would have to do here is recognize that rng is a Generator, so the special case applies, and magically under the covers transform the code to do a couple of splits up at the top before any of the uniform() calls.

And even if JAX did have technical capability of doing so, I'm not sure that they want their users to mix mental models like that. It's hard enough to teach people to use one consistently.

@froystig
Copy link

froystig commented Jun 8, 2022

they did serve as a way to talk about the similarities with SeedSequence, the kinds of improvements that are left in the design space (bijections as far as the eye can see!), and what limits are still fundamental to a splitting-centric design (the Birthday Bound).

@rkern Absolutely – I understood that and it makes sense. I only meant to share some related thoughts and work that's in progress, mostly as an aside. Your comments are super valuable, and this thread is a rare opportunity for us on JAX. Thank you!

@kgryte kgryte added the RFC Request for comments. Feature requests and proposed changes. label Jun 22, 2022
@wangpengmit
Copy link

wangpengmit commented Sep 3, 2022

I'm the author of tf.random.Generator (link) and several docs (e.g. this and this). TF's RNG currently uses stateless RNGs (similar to JAX's) as the underlying engine (code), with a mutable counter (implemented as a tf.Variable) to advance the RNG state (code). tf.random.Generator is basically stateless RNGs + a mutable counter. It works well in TF because TF is friendlier to mutability than JAX.

Unlike JAX, tf.random.Generator supports splitting but doesn't perform it unless asked to (code). In common (sequential) cases it just increments the counter (code). I feel splitting is more wasteful of the (key, counter) space than counter increment, but I'm not sure. Counter increment is of course computationally cheaper than splitting; the latter caused some speed problems in some TF models.

TF only supports counter-based RNG algorithms and has no plan to support non-counter-based ones.
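The "stateless hash plus mutable counter" pattern described above can be sketched in plain Python. This is a toy illustration using blake2b as a stand-in keyed hash, not TF's actual Philox-based implementation, and `CounterGenerator` is a name invented for the sketch:

```python
import hashlib


class CounterGenerator:
    """Toy counter-based generator: state = (fixed key, mutable counter)."""

    def __init__(self, key):
        self.key = key    # fixed for the generator's lifetime
        self.counter = 0  # the only mutable state; bumped once per word

    def random_bits(self, n):
        # Hash (key, counter) to produce each 64-bit word; sequential use
        # only ever advances the counter.
        out = []
        for _ in range(n):
            digest = hashlib.blake2b(
                self.counter.to_bytes(16, "little"),
                key=self.key.to_bytes(8, "little"),
                digest_size=8,
            ).digest()
            out.append(int.from_bytes(digest, "little"))
            self.counter += 1
        return out

    def split(self):
        # Splitting burns a draw to derive a fresh key, consuming more of
        # the (key, counter) space than a plain counter increment would.
        (new_key,) = self.random_bits(1)
        return CounterGenerator(new_key)
```

One nice property of the pure counter-increment path is sequential equivalence: drawing 4 words at once yields the same stream as two draws of 2, since both just sweep counters 0 through 3.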

@rkern
Copy link

rkern commented Sep 3, 2022

Yes, I can confirm that splitting is more profligate of the state space than the counter increment, especially when you constrain yourself to 64-bit keys, if I'm reading the code correctly. I think the same math from my analysis of JAX's PRNG splitting applies (the only real difference is that tf.random.Generator will generate the new keys at whatever counter it is currently on rather than JAX always starting at 0, but that doesn't affect the non-invertible mapping argument).

@ntessore
Copy link

ntessore commented Mar 3, 2023

From a user perspective, specifically someone writing libraries that want to support arbitrary array types, I don't see the problem with carrying the random state separately from the random namespace. It is not the slickest interface, but it can be implemented today:

# file: random_numpy.py

def standard_normal(rng, shape=()):
    return rng, rng.standard_normal(size=shape)

def poisson(rng, lam, shape=()):
    return rng, rng.poisson(lam, size=shape)


# file: random_jax.py

from jax import random

def standard_normal(key, shape=()):
    key, subkey = random.split(key)
    return key, random.normal(subkey, shape)

def poisson(key, lam, shape=()):
    key, subkey = random.split(key)
    return key, random.poisson(subkey, lam, shape)


# file: test.py

def random_namespace(rs):
    import sys
    if 'numpy' in sys.modules:
        import numpy as np
        if isinstance(rs, np.random.Generator):
            return __import__('random_numpy')
    if 'jax' in sys.modules:
        import jax
        if isinstance(rs, jax.Array):
            return __import__('random_jax')

def my_function(lam, rs):
    random = random_namespace(rs)
    rs, rv1 = random.standard_normal(rs, shape=(4,))
    rs, rv2 = random.poisson(rs, lam, shape=(4,))
    return rv1 + rv2

import numpy as np
rng = np.random.default_rng()
print(my_function(1.0, rng))
# [ 1.25035394 -0.11511349  1.87203598  2.55088409]

import jax
key = jax.random.PRNGKey(42)
print(my_function(1.0, key))
# [0.43244982 0.8856307  0.06793922 1.4641162 ]

For anything more complicated than sampling random variates, particularly anything that requires knowledge of the implementation details of the random state, we accept that we will always have to handle each framework separately.

@jakevdp
Copy link

jakevdp commented Mar 5, 2023

The main difference is that rng in numpy and key in JAX are fundamentally different kinds of objects. In the numpy example, the call to my_function has the side-effect of mutating rng in-place, so that subsequent calls will return different random values. In JAX, functions are pure, and so the call to my_function cannot mutate key, and subsequent calls with the same key will produce identical values to the first call:

import numpy as np
rng = np.random.default_rng()
print(my_function(1.0, rng))
# [ 1.25035394 -0.11511349  1.87203598  2.55088409]
print(my_function(1.0, rng))
# [3.15990919 1.07582056 1.09202392 0.33987543]

import jax
key = jax.random.PRNGKey(42)
print(my_function(1.0, key))
# [-0.5675502   1.8856307   0.06793922  3.464116  ]
print(my_function(1.0, key))
# [-0.5675502   1.8856307   0.06793922  3.464116  ]

@ntessore
Copy link

ntessore commented Mar 5, 2023

Yes, there is probably no universal API for situations where it matters whether or not we are dealing with a stateful RNG or a stateless key.

My point is that it does not preclude us from having a universal (functional) API that can generate random numbers. In my experience, that's also the only situation where you want such an API. For everything more complicated, you need special cases anyway.

Edit: FWIW, here is how I handle your specific example in my wrapper.

# file: random_numpy.py
def split(rng):
    return rng, rng

# file: random_jax.py
def split(key):
    return random.split(key)

# in implementation, rand is rng or key
rand, subrand = random.split(rand)
my_function(1.0, subrand)
rand, subrand = random.split(rand)
my_function(1.0, subrand)

@kgryte
Copy link
Contributor

kgryte commented Jun 29, 2023

Given lack of ecosystem agreement, we are not likely to forge a path forward for standardizing a PRNG API in the array API specification at this time. As such, I will go ahead and close this issue.

We can reopen/revisit if and when we have greater community consensus.
