In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap
from jax import random
from jax.tree_util import register_pytree_node
import jax
import jax.tools

from jax import random

### Generating initial key

In [2]:
key = random.PRNGKey(0)



In [3]:
random.uniform(key)

DeviceArray(0.41845703, dtype=float32)

In [5]:
random.uniform(key) # same result because it's a pure random, and we used the same key

DeviceArray(0.41845703, dtype=float32)

### Splitting key to pass to subcomputations

In [10]:
key = random.PRNGKey(0)
key

DeviceArray([0, 0], dtype=uint32)

In [14]:
subkey1, subkey2 = random.split(key)
subkey1, subkey2

(array([4146024105,  967050713], dtype=uint32),
 array([2718843009, 1272950319], dtype=uint32))

In [15]:
x = random.uniform(subkey1)
y = random.uniform(subkey2)

x,y

(DeviceArray(0.5572065, dtype=float32), DeviceArray(0.10536897, dtype=float32))

In [19]:
### Using random with nested functions
class Point:
    
    def __init__(self, x, y):
        self.x = x
        self.y = y
        
    def __repr__(self):
        return "Point(x={}, y={})".format(self.x, self.y)
        

def random_point(r_key):
    k1, k2 =  random.split(r_key)
    return Point(random.uniform(k1), random.uniform(k1))


def random_point_array(size, r_key):
    result = []
    for i in range(size):
         # explicit handling of key, to avoid passing the same random (== same random result)
        r_key, split_key = random.split(r_key)
        result.append(random_point(split_key))
    return result
    
key = random.PRNGKey(0)
random_point_array(4, key)

[Point(x=0.9411180019378662, y=0.9411180019378662),
 Point(x=0.2546273469924927, y=0.2546273469924927),
 Point(x=0.6630877256393433, y=0.6630877256393433),
 Point(x=0.06263375282287598, y=0.06263375282287598)]