In [29]:
import jax
import jax.numpy as jnp
import jax.random as jran
import numpy as np
from typing import Any, Callable, Sequence
import flax
import flax.linen as nn

In [30]:
# Allocate a 2x5 array of zeros and report the dtype JAX chose for it.
a = jnp.zeros(shape=(2, 5))
print(a.dtype)

float32


In [3]:
# Show the set of devices holding `a` (a single CPU device in the recorded output).
a.devices()

{CpuDevice(id=0)}

## Random numbers

JAX does not use a global random state. Instead, random functions explicitly consume the state, which is passed in as an argument and referred to as a ``key``.

In [4]:
# Build the root PRNG key from an integer seed, then display it.
key = jran.PRNGKey(seed=42)
key

Array([ 0, 42], dtype=uint32)

In [5]:
# `key` is never split or replaced here, so every pass draws the same value.
for _repeat in range(3):
    rv = jran.normal(key, shape=())
    print(rv)

-0.18471177
-0.18471177
-0.18471177


In [6]:
# Same experiment with a length-3 sample: an unchanged key reproduces the vector.
for _repeat in range(2):
    rvs = jran.normal(key, (3,))
    print(rvs)

[ 0.18693547 -1.2806505  -1.5593132 ]
[ 0.18693547 -1.2806505  -1.5593132 ]


Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated.

**Note:** Feeding the same key to different random functions can result in correlated outputs, which is generally undesirable.

In order to generate different and independent samples, you must ``split()`` the key yourself whenever you want to call a random function:

In [7]:
key = jran.PRNGKey(seed=42)

for _step in range(3):
    # Replace `key` with a fresh parent and peel off a single-use `subkey`.
    key, subkey = jran.split(key, num=2)
    rvs = jran.normal(subkey, (3,))
    print(rvs)
    # the updated `key` feeds the split in the next iteration

[-0.5675502   0.28439185 -0.9320608 ]
[ 0.67903334 -1.220606    0.94670606]
[-0.09680057  0.7366595   0.86116916]


If the number of iterations is fixed, do all the splits at once:

In [8]:
N = 3
key = jran.PRNGKey(42)
# Split once up front into N single-use subkeys. The loop variable has its own
# name so the parent `key` is not overwritten by an already-consumed subkey
# (reusing such a key later would reproduce the last draw).
for subkey in jran.split(key, N):
    rvs = jran.normal(subkey, shape=(3,))
    print(rvs)

[-0.04324572  0.00212434 -0.40485173]
[-1.0068504  -0.87616897 -0.6528091 ]
[-0.70704466  1.2879405  -0.4776387 ]


The best option when the iterations have no sequential dependence is to use ``vmap`` to vectorize the operation:

In [9]:
def f(key):
    """Draw a single normal sample from the given PRNG key."""
    sample = jran.normal(key)
    return sample

N = 10
key = jran.PRNGKey(42)
# Map `f` over the leading axis of the stacked subkeys: one draw per subkey.
rvs = jax.vmap(f, in_axes=0)(jran.split(key, N))
print(rvs)

[ 1.7917308   0.6962527  -0.3863588   0.6568204   1.5387199   0.08471087
 -0.05403972 -0.6987761  -1.7351557   1.9373399 ]


Note that in this case, we can vectorize ``jran.normal`` directly:

In [10]:
# `jran.normal` takes the key as its first argument, so it can be mapped over
# the subkeys directly, without a wrapper function.
rvs = jax.vmap(jran.normal, in_axes=0)(jran.split(key, N))
print(rvs)

[ 1.7917308   0.6962527  -0.3863588   0.6568204   1.5387199   0.08471087
 -0.05403972 -0.6987761  -1.7351557   1.9373399 ]


### Small benchmark

Use `block_until_ready()` in benchmarks to account for JAX’s asynchronous dispatch.

**Case 1a**

In [50]:
%%timeit
# Case 1a: draw N samples one at a time in a Python list comprehension
# (one `jran.normal` call per subkey), then stack them into a single array.
N = int(1e3)
key = jran.PRNGKey(42)
rvs = jnp.stack([jran.normal(key) for key in jran.split(key, N)]).block_until_ready()

37.8 ms ± 498 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


**Case 1b**

In [51]:
%%timeit

# Case 1b: same N subkeys as case 1a, but the draw is vectorized with `jax.vmap`.
N = int(1e3)
key = jran.PRNGKey(42)
rvs = jax.vmap(jran.normal)(jran.split(key, N)).block_until_ready()

346 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


**Case 2a**

In [52]:
%%timeit

# Case 2a: the vmap-over-subkeys approach scaled up to 1e6 samples.
N = int(1e6)
key = jran.PRNGKey(42)
rvs = jax.vmap(jran.normal)(jran.split(key, N)).block_until_ready()

12.8 ms ± 25.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


**Case 2b**

In [53]:
%%timeit

# Case 2b: a single key and a single call, asking `jran.normal` for all N
# samples at once via `shape` (no key splitting).
N = int(1e6)
key = jran.PRNGKey(42)
rvs = jran.normal(key, shape=(N,)).block_until_ready()

2.82 ms ± 27.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


**Case 2c**

In [54]:
%%timeit

# Case 2c: NumPy baseline for the same sample count; no `block_until_ready()`
# call is made on the NumPy result.
# NOTE(review): unlike the JAX cases (seeded with 42), this generator is
# created without a seed inside the timed body.
N = int(1e6)
rvs = np.random.default_rng().normal(size=N)

4.91 ms ± 74.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Gradients and autodiff

In [20]:
@jax.jit
def sigmoid(x):
    """Logistic sigmoid, mathematically 1 / (1 + exp(-x)).

    Delegates to ``jax.nn.sigmoid`` rather than spelling the formula out: with
    the literal expression, ``exp(-x)`` overflows to ``inf`` in float32 for
    very negative ``x`` and the autodiff gradient becomes ``nan`` (0 * inf).
    """
    return jax.nn.sigmoid(x)

@jax.jit
def grad_sigmoid_exact(x):
    """Closed-form derivative of the sigmoid: s(x) * (1 - s(x))."""
    s = sigmoid(x)
    return s * (1 - s)

# Derivative obtained by automatic differentiation, to compare with the closed form.
grad_sigmoid_jax = jax.jit(jax.grad(sigmoid))

In [21]:
# Evaluate at the origin, where the recorded output shows sigmoid = 0.5 and
# both the closed-form and the autodiff gradient equal to 0.25.
x = 0.0

# The `=` specifier echoes each expression's text next to its value.
print(f"{sigmoid(x)=}")
print(f"{grad_sigmoid_exact(x)=}")
print(f"{grad_sigmoid_jax(x)=}")

sigmoid(x)=Array(0.5, dtype=float32, weak_type=True)
grad_sigmoid_exact(x)=Array(0.25, dtype=float32, weak_type=True)
grad_sigmoid_jax(x)=Array(0.25, dtype=float32, weak_type=True)


### Benchmark jax vs exact and jitted vs not

In [39]:
# Benchmark input: 1000 normal draws from a fixed seed.
N = 1_000
key = jran.PRNGKey(seed=42)
x = jran.normal(key, shape=(N,))

**No jit**

In [40]:
def sigmoid(x):
    """Logistic function, evaluated eagerly (no jit) for the baseline timing."""
    denominator = 1 + jnp.exp(-x)
    return 1 / denominator

def grad_sigmoid_exact(x):
    """Hand-derived derivative of `sigmoid`, written in terms of its output."""
    sig = sigmoid(x)
    return sig * (1 - sig)

# Autodiff derivative, also left un-jitted for this baseline.
grad_sigmoid_jax = jax.grad(sigmoid)

In [41]:
# Un-jitted closed-form gradient, applied to the whole array `x` in one call.
%timeit grad_sigmoid_exact(x).block_until_ready()

17.2 µs ± 157 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [42]:
# Un-jitted autodiff gradient; `jax.vmap` maps it over the elements of `x`.
%timeit jax.vmap(grad_sigmoid_jax)(x).block_until_ready()

1.64 ms ± 15.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


**Jitted**

JAX uses eager computations by default; if you want lazy evaluation—what's sometimes called graph mode in other packages—you can specify this by wrapping your function in `jax.jit`.

Within a jit-compiled function, JAX replaces arrays with abstract tracers in order to determine the full sequence of operations in the function, and to send them all to XLA for compilation, where the operations may be rearranged or transformed by the compiler to make the overall execution more efficient.

In [43]:
@jax.jit
def sigmoid(x):
    """Logistic function, compiled with `jax.jit`."""
    denominator = 1 + jnp.exp(-x)
    return 1 / denominator

@jax.jit
def grad_sigmoid_exact(x):
    """Closed-form derivative of `sigmoid`, compiled with `jax.jit`."""
    sig = sigmoid(x)
    return sig * (1 - sig)

# Autodiff derivative of the jitted sigmoid, itself wrapped in `jax.jit`.
grad_sigmoid_jax = jax.jit(jax.grad(sigmoid))

In [44]:
# Jitted closed-form gradient, applied to the whole array `x` in one call.
%timeit grad_sigmoid_exact(x).block_until_ready()

2.87 µs ± 46.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [45]:
# Jitted autodiff gradient, mapped over the elements of `x` with `jax.vmap`.
# NOTE(review): the `vmap` wrapper is created and applied outside the `jit` on
# every timed call; jitting the vmapped function instead would presumably give
# a fairer comparison with the closed-form timing -- confirm.
%timeit jax.vmap(grad_sigmoid_jax)(x).block_until_ready()

185 µs ± 4.05 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [57]:
def sigmoid(x):
    """Logistic function as a plain Python function (not jitted itself)."""
    denominator = 1 + jnp.exp(-x)
    return 1 / denominator

# Variant: only the autodiff gradient is wrapped in `jax.jit`.
grad_sigmoid_jax = jax.jit(jax.grad(sigmoid))

In [58]:
# Variant where only the gradient is jitted; `sigmoid` is a plain function.
%timeit jax.vmap(grad_sigmoid_jax)(x).block_until_ready()

175 µs ± 1.35 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [48]:
@jax.jit
def sigmoid(x):
    """Logistic function, compiled with `jax.jit`."""
    denominator = 1 + jnp.exp(-x)
    return 1 / denominator

# Variant: the gradient wrapper itself is NOT jitted, only the function it differentiates.
grad_sigmoid_jax = jax.grad(sigmoid)

In [49]:
# Variant where only `sigmoid` is jitted; the `jax.grad` wrapper is not.
%timeit jax.vmap(grad_sigmoid_jax)(x).block_until_ready()

641 µs ± 1.79 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Convention

A useful pattern is to use `numpy` for operations that should be static (i.e. done at compile-time), and use `jax.numpy` for operations that should be traced (i.e. compiled and executed at run-time).

For this reason, a standard convention in JAX programs is to `import numpy as np` and `import jax.numpy as jnp` so that both interfaces are available for finer control over whether operations are performed in a static manner (with `numpy`, once at compile-time) or a traced manner (with `jax.numpy`, optimized at run-time).

## Pytrees