# Tutorial #2: Randomness in PCX

This is a short notebook showing some details of how randomness is implemented in PCX. JAX provides its own stateless random utilities, on top of which we build a simple interface: `pcx.RandomKeyGenerator`. PCX also offers a globally instantiated `pcx.RandomKeyGenerator`, `px.RKG`, which is used whenever no alternative is provided.
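To make the idea concrete, here is a minimal sketch in plain JAX of how a stateful key generator can be layered on top of JAX's stateless keys. `TinyKeyGenerator` is a hypothetical class written only for illustration. It is not the actual `pcx.RandomKeyGenerator`, whose internals may differ.

```python
import jax
import jax.numpy as jnp


class TinyKeyGenerator:
    """Hypothetical stand-in for a stateful key generator (NOT the pcx class).

    It stores a JAX PRNG key and, on every call, splits it: one half becomes
    the new internal state, the other half is returned to the caller.
    """

    def __init__(self, seed: int = 0):
        self.key = jax.random.PRNGKey(seed)

    def seed(self, seed: int) -> None:
        # Reset the internal state to a fresh key derived from `seed`.
        self.key = jax.random.PRNGKey(seed)

    def __call__(self):
        # Advance the state and hand out a subkey that is never reused.
        self.key, subkey = jax.random.split(self.key)
        return subkey


rkg_a = TinyKeyGenerator(0)
rkg_b = TinyKeyGenerator(0)

# Same seed -> same sequence of subkeys -> same random numbers.
x_a = jax.random.normal(rkg_a(), (3,))
x_b = jax.random.normal(rkg_b(), (3,))
print(bool(jnp.array_equal(x_a, x_b)))  # True

# Each call advances the state, so consecutive draws differ.
y_a = jax.random.normal(rkg_a(), (3,))
print(bool(jnp.array_equal(x_a, y_a)))  # False
```

The stateless JAX functions stay pure. Only the small wrapper object carries state, which is why transformations such as `vmap` need to know how to handle it.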

In [None]:
import pcx as px
import pcx.nn as pxnn
import jax
import jax.numpy as jnp

# By default, px.RKG is initialised with the system time.
# We set both the global and a custom rkg seed to 0 and show their usage.
px.RKG.seed(0)
custom_RKG = px.RandomKeyGenerator(0)
layer_default = pxnn.Linear(8, 8, True) # by default uses px.RKG
layer_custom = pxnn.Linear(8, 8, True, rkg=custom_RKG)

assert jnp.all(layer_default.nn.weight == layer_custom.nn.weight), "Both RKGs are seeded with 0, so the initial weights should match."

# Note that pcx functions accept a `pcx.RandomKeyGenerator`, while jax functions require a key,
# which can be obtained by calling the generator:
a_key = px.RKG()

Since `px.RKG` is globally accessible, it can also be used inside pcx transformations. This, however, requires its state to be transformed accordingly as well. PCX does this automatically by adding the RKG to the transformation's keyword arguments and applying the relevant transformation to it:
- with `vmap`, the state is split into `n` different states (one per entry of the vmapped dimension), which are mapped over that dimension. At the end of the function, all vmapped states but one are discarded, and the remaining one becomes the new `px.RKG` state.

If a different behaviour is needed, you can always pass your own `px.RandomKeyGenerator` as a keyword argument and apply the desired transformations to it.
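The `vmap` behaviour described above can be approximated in plain JAX by threading the key through the function explicitly. The sketch below uses a hypothetical helper, `vmap_with_key`, that splits the key into one key per batch element, maps over them, and keeps only the first updated key as the new state. It illustrates the idea and is not how pcx implements it internally; `add_noise` is likewise a hypothetical example function.

```python
import jax
import jax.numpy as jnp


def vmap_with_key(fn, key, xs):
    """Hypothetical helper: vmap `fn` over `xs`, giving each element its own key.

    The key state is split into n = xs.shape[0] keys, mapped over the batch
    dimension, and afterwards only one of the updated keys is kept.
    """
    n = xs.shape[0]
    keys = jax.random.split(key, n)       # one key per batch element

    def body(k, x):
        k, subkey = jax.random.split(k)   # advance the per-element key
        return fn(subkey, x), k           # return output and updated key

    outs, new_keys = jax.vmap(body)(keys, xs)
    return outs, new_keys[0]              # discard all but one key


def add_noise(k, x):
    # Add uniform noise in [-1, 1) drawn with key k.
    return x + jax.random.uniform(k, x.shape, minval=-1.0, maxval=1.0)


key = jax.random.PRNGKey(0)
xs = jnp.ones((10, 1))
out1, key = vmap_with_key(add_noise, key, xs)   # state advances after the call
out2, key = vmap_with_key(add_noise, key, xs)   # so a second call differs
```

Because the returned key replaces the old one, two consecutive calls produce different results, much like the two `vsum` calls in the next cell.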

In [None]:
import pcx.functional as pxf

@pxf.jit()
@pxf.vmap(in_axes=(0, None, None), out_axes=0)
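def vsum(a, min_val, max_val):
    # Inside the vmap, each entry of `a` draws from its own split of px.RKG's state.
    a = a + jax.random.uniform(px.RKG(), a.shape, minval=min_val, maxval=max_val)

    # Re-seed the global RKG; after the call, one of the vmapped states becomes px.RKG's state.
    px.RKG.seed(0)

    return a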

a = jnp.ones((10, 1))

a_1 = vsum(a, -1.0, 1.0)
a_2 = vsum(a, -1.0, 1.0)

assert jnp.any(a_1 != a_2), "The two arrays should be different since vsum changes the state of the RKG."

key = px.RKG.key.get()
assert jnp.all(key == 0), "The key should be 0, as set inside the vsum function"

print("All good!")


Note how the following cell uses the same key for every entry along the vmapped dimension, because we do not vmap the custom `px.RandomKeyGenerator`.
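In plain JAX, the same effect comes from passing the key with `in_axes=None` to `jax.vmap`: the single key is broadcast to every batch element, so each one draws exactly the same numbers. `add_noise` is a hypothetical helper used only for this illustration.

```python
import jax
import jax.numpy as jnp


def add_noise(k, x):
    # Add uniform noise in [-1, 1) drawn with key k.
    return x + jax.random.uniform(k, x.shape, minval=-1.0, maxval=1.0)


key = jax.random.PRNGKey(0)
xs = jnp.ones((10, 1))

# in_axes=(None, 0): the key is NOT mapped, so every row sees the same key
# and, since all rows of xs are equal, every output row is identical.
same = jax.vmap(add_noise, in_axes=(None, 0))(key, xs)
print(bool(jnp.all(same == same[0])))  # True
```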

In [None]:
@pxf.jit()
@pxf.vmap({'rkg': None}, in_axes=(0, None, None), out_axes=0)
def sum_custom(a, min_val, max_val, *, rkg):
    return a + jax.random.uniform(rkg(), a.shape, minval=min_val, maxval=max_val)

a = jnp.ones((10, 1))
a_ = sum_custom(a, -1.0, 1.0, rkg=custom_RKG)

print("All entries of a_ should be the same:")
print(a_)

# Since we use a custom rkg and we do not batch over it, the key state is shared
# over the vmap dimension and all the values produced are the same.
#
# NOTE: this is probably not something you would normally need, so think carefully
# about it if you find yourself using it. For standard use cases, simply rely on
# the provided default RKG.
assert jnp.all(a_ == a_[0]), "All the entries in a_ should be the same."

print("All good!")

To vmap a custom RKG, we need to explicitly split its key state over the vmap dimension before the call and merge it back afterwards.
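For comparison, here is a plain-JAX sketch of the explicit split-then-map pattern, again with a hypothetical `add_noise` helper. One design choice differs from the pcx cell: instead of splitting the whole state into per-element keys and keeping one of them afterwards, it splits off one extra key up front to serve as the new state, so the state stays a single key and no merge step is needed. This is a common JAX idiom, not a description of pcx internals.

```python
import jax
import jax.numpy as jnp


def add_noise(k, x):
    # Add uniform noise in [-1, 1) drawn with key k.
    return x + jax.random.uniform(k, x.shape, minval=-1.0, maxval=1.0)


state = jax.random.PRNGKey(0)
xs = jnp.ones((10, 1))

# Split: one fresh key for the new state plus one key per batch element.
keys = jax.random.split(state, xs.shape[0] + 1)
state, per_row_keys = keys[0], keys[1:]

# Map over the keys (in_axes=0) as well as over the data,
# so each row gets its own noise.
noisy = jax.vmap(add_noise, in_axes=(0, 0))(per_row_keys, xs)

# `state` is still a single key, ready for the next call.
print(noisy.shape)  # (10, 1)
```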

In [None]:
@pxf.jit()
@pxf.vmap({'rkg': 0}, in_axes=(0, None, None), out_axes=0)
def vsum_custom(a, min_val, max_val, *, rkg):
    return a + jax.random.uniform(rkg(), a.shape, minval=min_val, maxval=max_val)

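a = jnp.ones((10, 1))
# Split the key state into one key per entry along the vmapped dimension.
custom_RKG.key.set(custom_RKG.key.split(len(a)))
a_ = vsum_custom(a, -1.0, 1.0, rkg=custom_RKG)
# Merge: keep a single key as the new state of custom_RKG.
custom_RKG.key.set(custom_RKG.key[0])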

print("All entries of a_ should now look random:")
print(a_)