What are some sharp edges of JAX that are good to know before really getting started?

In [None]:
from genjax import bernoulli, gen
import jax

# 1] Jax expects arrays/tuples everywhere


@gen
def f(p):
    """Sample one Bernoulli choice at address "v" and return it.

    ``p`` is forwarded unchanged to ``bernoulli``.
    """
    return bernoulli(p) @ "v"


key = jax.random.PRNGKey(0)

# `simulate` takes the model arguments as a tuple. Each call in the loop passes
# something else, so it fails; we print the error instead of stopping the cell.
#   - First way of failing: a bare float.
#   - Second way of failing: a list instead of a tuple.
#   - Third way of failing: `(0.5)` is just the float 0.5 -- parentheses alone
#     do not build a tuple.
for bad_args in (0.5, [0.5], (0.5)):
    try:
        f.simulate(key, bad_args)
    except Exception as e:
        print(e)

# Correct way: a one-element tuple needs the trailing comma.
f.simulate(key, (0.5,))

In [None]:
# 2] We rely on TensorFlow Probability, and it sometimes does surprising things.

# The Bernoulli distribution uses logits instead of probabilities.
from genjax import bernoulli, gen
from genjax import ChoiceMapBuilder as C
import jax.numpy as jnp


@gen
def g(p):
    """Return a Bernoulli draw at address "v".

    ``bernoulli`` is parameterized by logits, so ``p`` is a logit, not a
    probability.
    """
    return bernoulli(p) @ "v"


key = jax.random.PRNGKey(0)
arg = (3.0,)  # 3.0 is not a valid probability, but it is a valid logit
keys = jax.random.split(key, 30)
# Draw 30 samples of "v", one per split key.
print(jnp.array([g.simulate(subkey, arg).get_sample()["v"] for subkey in keys]))

# Values strictly greater than 0 are considered to be the value True.
# This means that observing that the value of "v" is 3 will be considered possible, while intuitively "v" should only have support on 0 and 1.
chm = C["v"].set(3)
print()
# Intuitively this should be -inf, since 3 is outside the support {0, 1}.
print(g.assess(chm, (0.5,))[0], end="\n\n")

# Alternatively, we can use the flip function which uses probabilities instead of logits.
from genjax import flip


@gen
def h(p):
    """Return a coin flip at address "v".

    ``flip`` is parameterized by a probability (unlike ``bernoulli``, which
    takes a logit).
    """
    return flip(p) @ "v"


key = jax.random.PRNGKey(0)
arg = (0.3,)  # 0.3 is a valid probability, which is what `flip` expects
keys = jax.random.split(key, 30)
# Draw 30 samples of "v", one per split key.
print(jnp.array([h.simulate(subkey, arg).get_sample()["v"] for subkey in keys]))
print()

# The categorical distribution also uses logits instead of probabilities.
from genjax import categorical


@gen
def i(p):
    """Return a categorical draw at address "v".

    ``categorical`` takes logits, so ``p`` holds one logit per category.
    """
    return categorical(p) @ "v"


key = jax.random.PRNGKey(0)
arg = ([3.0, 1.0, 2.0],)  # a list of 3 logits, one per category
keys = jax.random.split(key, 30)
# Draw 30 category indices, one per split key.
print(jnp.array([i.simulate(subkey, arg).get_sample()["v"] for subkey in keys]))

In [None]:
# 3] Jax code can be compiled for better performance.
from jax import jit


# jit is the way to force Jax to compile the code.
# It can be used as a decorator
@jit
def f_v1(p):
    """Return ``p * p`` (both branches agree) through a jitted ``lax.cond``.

    ``lax.cond`` expects a scalar boolean predicate, so the sum is compared
    against 0 explicitly instead of passing the raw float sum. The result is
    returned so callers can use it and wait on it when timing.
    """
    return jax.lax.cond(p.sum() > 0, lambda p: p * p, lambda p: p * p, p)


# Or as a function: jit(fn) returns a compiled version of an ordinary
# Python function.
def _f_v2_uncompiled(p):
    """Return ``p * p`` (both branches agree) behind a ``lax.cond``; jitted below."""
    return jax.lax.cond(p.sum() > 0, lambda p: p * p, lambda p: p * p, p)


f_v2 = jit(_f_v2_uncompiled)


# Baseline
def f_v3(p):
    """Return ``p * p`` through ``lax.cond`` without compiling the whole function.

    The predicate is compared against 0 explicitly, because ``lax.cond``
    expects a scalar boolean.
    """
    return jax.lax.cond(p.sum() > 0, lambda p: p * p, lambda p: p * p, p)


# Notice that the first and second have the same performance while the third is 10k times slower.
arg = jax.numpy.eye(1000)
%timeit f_v1(arg)
%timeit f_v2(arg)
%timeit f_v3(arg)

In [None]:
# 4] Going from Python to Jax
# For loops
def python_loop(x):
    """Double ``x`` one hundred times with an ordinary Python loop."""
    remaining = 100
    while remaining:
        x = 2 * x
        remaining -= 1
    return x


def jax_loop(x):
    """JAX version of ``python_loop``: double ``x`` 100 times with ``lax.fori_loop``.

    The loop body receives the iteration index (unused here) and the carried
    value, and returns the new carried value.
    """
    return jax.lax.fori_loop(0, 100, lambda i, x: 2 * x, x)


# Conditional statements
def python_cond(x):
    """Square ``x`` elementwise when its sum is positive; otherwise return it as is."""
    return x * x if x.sum() > 0 else x


def jax_cond(x):
    """JAX version of ``python_cond`` using ``lax.cond``.

    ``lax.cond`` takes a scalar boolean predicate, so the ``> 0`` comparison
    from the Python version is written out explicitly. Passing the raw sum
    would not express "the sum is positive".
    """
    return jax.lax.cond(x.sum() > 0, lambda x: x * x, lambda x: x, x)


# While loops
def python_while(x):
    """Keep squaring ``x`` elementwise for as long as its sum stays positive."""
    while True:
        if not x.sum() > 0:
            return x
        x = x * x


def jax_while(x):
    """JAX version of ``python_while`` using ``lax.while_loop``.

    Squares ``x`` elementwise while its sum is positive, then returns it.
    """
    return jax.lax.while_loop(lambda x: x.sum() > 0, lambda x: x * x, x)

In [None]:
# 5] In Jax (and GenJax), jit at the outermost level, and in particular not inside a for loop.


def innocent_jittable(x):
    """Multiply ``x`` by 1.1 one hundred times inside a ``lax.fori_loop``."""
    def scale(_, carry):
        return 1.1 * carry

    return jax.lax.fori_loop(0, 100, scale, x)


def turtle_compilation_speed(x):
    for i in range(100):
        x = jit(innocent_jittable)(x)
    return x


%timeit turtle_compilation_speed(1)


# the same but non-jitted
def human_speed_no_compile(x):
    for i in range(100):
        x = innocent_jittable(x)
    return x


%timeit human_speed_no_compile(1.0)

# jit at the outer most level
jitted = jit(innocent_jittable)


def hare_compilation_speed(x):
    for i in range(100):
        x = jitted(x)
    return x


# get rid compile time with a first run
jitted(1.0)
%timeit hare_compilation_speed(1.0)

# We can see that the fastest way is to jit the outermost function, which makes it faster than the non-jitted version by a factor of ~7, and that jitting inside a for loop is an extra 500x slower than the non-jitted version.

In [None]:
# 6] Is my code compiling, or is it stuck at trace time?

import multiprocessing
import time
import subprocess

# In Jax, the first time you run a function, it is traced, which produces a Jaxpr: a representation of the computation that Jax can optimize.

# So, to debug whether a function is actually running, once Python has accepted it you can first check whether it traces, before running it on data.


# This is done by calling make_jaxpr on the function. If it returns a Jaxpr, then the function is traced and ready to be run on data.
def im_fine(x):
    """Return ``x`` squared; tracing it finishes immediately."""
    squared = x * x
    return squared


print(jax.make_jaxpr(im_fine)(1.0), end="\n\n")


def i_wont_be_so_fine(x):
    """Square ``x`` while it stays positive, using ``lax.while_loop``.

    Tracing succeeds, but running it never terminates for ``x >= 1``
    (1 * 1 stays 1, and larger values grow to inf, which stays positive).
    """
    return jax.lax.while_loop(lambda v: v > 0, lambda v: v * v, x)


print(jax.make_jaxpr(i_wont_be_so_fine)(1.0), end="\n\n")


# Try running the function for up to 8 seconds in a separate process.
def run_process(timeout=8.0):
    """Run ``i_wont_be_so_fine(1.0)`` in a child process for at most ``timeout`` seconds.

    If the child is still alive once the timeout expires, print a message and
    terminate it. The child is joined in every case so it is always reaped.

    NOTE(review): with the "spawn" start method the child must be able to
    import ``i_wont_be_so_fine``. Functions defined in a notebook's
    ``__main__`` usually cannot be imported, which is presumably why the code
    below runs an external script instead.
    """
    ctx = multiprocessing.get_context("spawn")
    p = ctx.Process(target=i_wont_be_so_fine, args=(1.0,))
    p.start()
    # Wait up to `timeout` seconds, returning early if the child finishes.
    p.join(timeout)
    if p.is_alive():
        print("I'm still running")
        p.terminate()
    p.join()


import sys

# Run the demo script in a fresh interpreter. `sys.executable` is the Python
# running this notebook, whereas a bare "python" may resolve to a different
# interpreter (or to none) on PATH.
# NOTE(review): the script path is relative to the current working directory,
# presumably the repository root -- confirm.
result = subprocess.run(
    [sys.executable, "genjax/docs/sharp-edges-notebooks/basics/script.py"],
    capture_output=True,
    text=True,
)

# Print the output
print(result.stdout)
# Surface failures (e.g. a missing script) instead of silently printing nothing.
if result.returncode != 0:
    print(result.stderr)

In [None]:
# 7] Using random keys for generative functions

# In GenJax, we use explicit random keys to generate random numbers. This is done by splitting a key into multiple keys, and using them to generate random numbers.
from genjax import bernoulli, gen, beta


@gen
def beta_bernoulli_process(u):
    """Draw ``p ~ Beta(1, u)``, then a coin flip ``v`` that is True with probability ``p``.

    Beta concentration parameters must be strictly positive, so the first one
    cannot be 0. ``p`` is a probability, so it goes to ``flip`` (which takes
    probabilities) rather than ``bernoulli`` (which takes logits).
    NOTE(review): 1.0 was chosen as a neutral first concentration; confirm
    the intended prior.
    """
    p = beta(1.0, u) @ "p"
    v = flip(p) @ "v"
    return v


key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 20)
# Compile `simulate` once, then reuse it for each of the 20 split keys.
jitted = jit(beta_bernoulli_process.simulate)
print(jnp.array([jitted(subkey, (0.5,)).get_sample()["v"] for subkey in keys]))

In [None]:
# 8] Jax uses 32-bit floats by default

import numpy as np
import jax.numpy as jnp
from jax import random

# Ask for float64 samples; with 64-bit mode off (the default) JAX falls back
# to 32-bit floats, hence the "surprise".
x = random.uniform(random.key(0), shape=(1000,), dtype=jnp.float64)
print("surprise surprise: ", x.dtype)
print()

# A common TypeError occurs when one uses np (NumPy) instead of jnp (JAX's version of NumPy): NumPy uses 64-bit floats by default, while JAX uses 32-bit floats by default.

# Requesting float64 explicitly on its own gives a UserWarning (64-bit mode is
# off by default, so JAX falls back to float32).
jnp.array([1, 2, 3], dtype=np.float64)

# Captured by the jitted function below, this float64 NumPy array is truncated
# to 32-bit floats; depending on the JAX version this may also emit a UserWarning.
innocent_looking_array = np.array([1.0, 2.0, 3.0], dtype=np.float64)


@jax.jit
def innocent_looking_function(x):
    """Return ``x * x`` if ``x`` sums to a positive value, else the captured array.

    ``lax.cond`` expects a scalar boolean predicate, so the float sum is
    compared against 0 explicitly. Both branches must return the same type;
    with float inputs they do, because the captured array becomes float32.
    """
    return jax.lax.cond(
        x.sum() > 0, lambda x: x * x, lambda x: innocent_looking_array, x
    )


# Named `inputs` to avoid shadowing the builtin `input`.
inputs = jnp.array([1.0, 2.0, 3.0])
innocent_looking_function(inputs)

try:
    # This actually raises a TypeError: with integer inputs the first branch
    # returns int32 while the captured array becomes float32, and both
    # branches of lax.cond must return identical types.
    innocent_looking_array = np.array([1, 2, 3], dtype=np.float64)

    @jax.jit
    def innocent_looking_function(x):
        return jax.lax.cond(
            x.sum() > 0, lambda x: x * x, lambda x: innocent_looking_array, x
        )

    # Named `inputs` to avoid shadowing the builtin `input`.
    inputs = jnp.array([1, 2, 3])
    innocent_looking_function(inputs)
except TypeError as e:
    print(e)

In [None]:
# 9] Beware of running out of memory (OOM) on the GPU, which happens sooner than you might think.

# Here's a simple HMM model that can be run on the GPU.
# By simply changing N from 300 to 1000, the code will typically run out of memory on the GPU as it will take ~300GB of memory.

import genjax
import jax
from jax import numpy as jnp
from jax import jit

N = 300  # state dimension
n_repeats = 100  # number of independent HMM runs batched by `repeat` below
variance = jnp.eye(N)  # identity covariance used at every transition
initial_state = jax.random.normal(jax.random.PRNGKey(0), shape=(N,))


@genjax.scan_combinator(max_length=100)
@genjax.gen
def hmm(x, c):
    """One HMM step: draw the next state from ``mv_normal(x, variance)``.

    Returns ``(new_state, None)``. ``c`` is unused (the callers below pass
    ``None`` for it).
    """
    return genjax.mv_normal(x, variance) @ "new_x", None


key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
jitted = jit(hmm.repeat(num_repeats=n_repeats).simulate)
trace = jitted(key, (initial_state, None))
%timeit jitted(subkey, (initial_state, None))

# If you are running out of memory, you can try de-batching one of the
# computations, or using a smaller batch size.
# For instance, in this example we can de-batch the repeat combinator, which
# reduces the memory usage by a factor of 100, at the cost of some performance.
jitted = jit(hmm.simulate)


def hmm_debatched(key, initial_state):
    """Run ``n_repeats`` independent HMM simulations one after another.

    Returns a dict mapping each run index ``0 .. n_repeats - 1`` to its trace.
    """
    subkeys = jax.random.split(key, n_repeats)
    return {
        i: jitted(subkeys[i], (initial_state, None)) for i in range(n_repeats)
    }


key = jax.random.PRNGKey(0)
# About 4x slower on arm64 cpu and 40x on a google colab gpu
%timeit hmm_debatched(key, initial_state)

In [None]:
# 10] Sometimes tracing can be slow, and you may be able to speed it up.

# TODO: why is nested vmap so slow to trace?
# TODO: printing returning traced values or nothing when function is jitted.