How do I use conditionals in JAX?

In [None]:
# In pure Python, we can use usual conditionals
def simple_cond_python(p):
    """Return ``2 * p`` when ``p`` is positive, otherwise ``-p``.

    Plain Python control flow: the comparison ``p > 0`` is evaluated eagerly
    and exactly one of the two expressions is computed.
    """
    return 2 * p if p > 0 else -p


print(simple_cond_python(0.3))
print(simple_cond_python(-0.4))

In [None]:
# In pure JAX, we write conditionals with jax.lax.cond as follows
import jax


def simple_cond_jax(p):
    """JAX version of ``simple_cond_python`` built on ``jax.lax.cond``.

    The predicate ``p > 0`` and the two branch functions are handed to
    ``jax.lax.cond``, which applies one of them to the operand ``p``.
    """

    def double(x):
        return 2 * x

    def negate(x):
        return -x

    return jax.lax.cond(p > 0, double, negate, p)


print(simple_cond_jax(0.3))
print(simple_cond_jax(-0.4))

In [None]:
# Compiled (jitted) JAX code is usually much faster than the equivalent pure-Python code


def python_loop(x):
    """Apply a data-dependent update to ``x`` 40000 times in pure Python.

    Each step doubles the value while it is below 100.0 and otherwise
    subtracts 97.0; the final value is returned.
    """
    for _ in range(40000):
        x = 2 * x if x < 100.0 else x - 97.0
    return x


@jax.jit
def jax_loop(x):
    """JIT-compiled counterpart of ``python_loop``.

    ``jax.lax.fori_loop`` drives the 40000 iterations and ``jax.lax.cond``
    performs the data-dependent choice inside each iteration.
    """

    def body(_, value):
        # Double below 100.0, otherwise subtract 97.0.
        return jax.lax.cond(
            value < 100.0,
            lambda v: 2 * v,
            lambda v: v - 97.0,
            value,
        )

    return jax.lax.fori_loop(0, 40000, body, x)


%timeit python_loop(1.0)
# Get the JIT time out of the way
jax_loop(1.0)
%timeit jax_loop(1.0)

In [None]:
# One restriction is that both branches must return outputs with the same pytree structure
def failing_simple_cond_1(p):
    """Show ``jax.lax.cond`` rejecting branches with mismatched pytree outputs.

    The true branch returns a 2-tuple while the false branch returns a single
    value; the demo below expects this to raise ``TypeError``.
    """

    def as_pair(x):
        return (x, x)

    def negate(x):
        return -x

    return jax.lax.cond(p > 0, as_pair, negate, p)


try:
    print(failing_simple_cond_1(0.3))
except TypeError as e:
    print(e)


# The other restriction is that the branch outputs must have the same types (shapes and dtypes)
def failing_simple_cond_2(p):
    """Show ``jax.lax.cond`` rejecting branches whose outputs differ in type.

    The true branch returns ``2 * p`` (a float for float ``p``), while the
    false branch returns the integer constant ``7``. The demo below expects
    this to raise ``TypeError``.
    """

    def double(x):
        return 2 * x

    def constant_seven(_):
        return 7

    return jax.lax.cond(p > 0, double, constant_seven, p)


try:
    print(failing_simple_cond_2(0.3))
except TypeError as e:
    print(e)

In [None]:
# In GenJAX, the syntax is a bit different again.
# Just as JAX has the primitive `jax.lax.cond`, which builds a conditional by "composing" two functions that act as branches, GenJAX has a combinator, `cond_combinator`, that "composes" two generative functions.
import jax
import jax.numpy as jnp
from genjax import gen, cond_combinator, bernoulli


# We can first define the two branches as generative functions
@gen
def branch_1(p):
    """Sample ``bernoulli(p)`` at address ``"v1"`` and return the draw."""
    return bernoulli(p) @ "v1"


@gen
def branch_2(p):
    """Sample ``bernoulli(-p)`` at address ``"v2"`` and return the draw."""
    return bernoulli(-p) @ "v2"


# Then we use the combinator to compose them
@gen
def cond_model(p):
    """Choose between ``branch_1`` and ``branch_2`` with ``cond_combinator``.

    The predicate is ``p > 0`` cast to ``int32``. Both branches receive the
    argument tuple ``(p,)``, and the combinator's choices are recorded under
    the address ``"cond"``.
    """
    # NOTE(review): confirm which branch cond_combinator selects for
    # pred == 0 vs pred == 1 in the GenJAX version in use.
    predicate = jnp.int32(p > 0)
    shared_args = (p,)
    combined = cond_combinator(branch_1, branch_2)
    return combined(predicate, shared_args, shared_args) @ "cond"


key = jax.random.PRNGKey(314159)
# JIT-compile the model's simulate method; it is called as (key, args_tuple).
jitted = jax.jit(cond_model.simulate)
# With p = 0.0, cond_model computes pred = jnp.int32(0.0 > 0), i.e. 0.
tr = jitted(key, (0.0,))
print(tr.get_sample())
# NOTE(review): this reads branch_1's "v1" address although p > 0 is False;
# confirm how cond_combinator maps pred to a branch in the GenJAX version used.
print(tr.get_sample()[("cond", "v1")])
# print(tr.get_sample()[("cond", "v2")]) # This will fail because the key is not in the trace, as the branch was not taken

In [None]:
# Note that you may be able to write the following, but in general it will not do what you want!
# TODO: find a way to make it fail to better show the point.
from genjax import gen
from genjax import bernoulli
import jax


@gen
def simple_cond_genjax(p):
    # Anti-pattern shown for contrast with cond_combinator: the traced-address
    # calls (`@ "v1"`, `@ "v2"`) sit inside plain lambdas handed to
    # jax.lax.cond, instead of in generative functions composed by GenJAX.
    # NOTE(review): this may run, but it is not expected to produce the
    # intended trace in general.
    branch_1 = lambda p: bernoulli(p) @ "v1"
    branch_2 = lambda p: bernoulli(-p) @ "v2"
    # Branch on p > 0 with JAX's own conditional primitive.
    cond = jax.lax.cond(p > 0, branch_1, branch_2, p)
    return cond


key = jax.random.PRNGKey(314159)
# Same key, two inputs: p = 0.3 makes the predicate true, p = -0.4 makes it false.
tr1 = simple_cond_genjax.simulate(key, (0.3,))
tr2 = simple_cond_genjax.simulate(key, (-0.4,))
print(tr1.get_retval())
print(tr2.get_retval())

In [None]:
# When there are more than two branches, JAX provides the `jax.lax.switch` function
def simple_switch_jax(p):
    """Pick one of three branches with ``jax.lax.switch`` based on ``|p|``.

    The branch index is ``floor(|p|)`` modulo the number of branches:
    0 doubles ``p``, 1 negates it, 2 returns it unchanged.
    """
    branches = [
        lambda x: 2 * x,  # index 0: double
        lambda x: -x,  # index 1: negate
        lambda x: x,  # index 2: identity
    ]
    index = jnp.floor(jnp.abs(p)).astype(jnp.int32) % len(branches)
    return jax.lax.switch(index, branches, p)


print(simple_switch_jax(0.3))
print(simple_switch_jax(1.1))
print(simple_switch_jax(2.3))

In [None]:
# Likewise, in GenJAX we can use the `switch_combinator` if we have more than two branches
# We can first define three branches as generative functions
from genjax import switch_combinator, normal


@gen
def branch_1(p):
    """Sample ``normal(p, 1.0)`` at address ``"v1"`` and return the draw."""
    return normal(p, 1.0) @ "v1"


@gen
def branch_2(p):
    """Sample ``normal(-p, 1.0)`` at address ``"v2"`` and return the draw."""
    return normal(-p, 1.0) @ "v2"


@gen
def branch_3(p):
    """Sample ``normal(p * p, 1.0)`` at address ``"v3"`` and return the draw."""
    return normal(p * p, 1.0) @ "v3"


# Then we use the combinator to compose them
@gen
def switch_model(p):
    """Dispatch to one of three branches with ``switch_combinator``.

    The index is ``floor(|p|) % 3``, each branch receives ``(p,)``, and the
    combinator's choices are recorded under the address ``"s"``.
    """
    branch_index = jnp.floor(jnp.abs(p)).astype(jnp.int32) % 3
    shared_args = (p,)
    combined = switch_combinator(branch_1, branch_2, branch_3)
    return combined(branch_index, shared_args, shared_args, shared_args) @ "s"


key = jax.random.PRNGKey(0)
jitted = jax.jit(switch_model.simulate)
# p = 0.0 -> index floor(0.0) % 3 == 0 -> read address ("s", "v1")
tr = jitted(key, (0.0,))
print(tr.get_sample()[("s", "v1")])
# p = 1.1 -> index floor(1.1) % 3 == 1 -> read address ("s", "v2")
tr = jitted(key, (1.1,))
print(tr.get_sample()[("s", "v2")])
# p = 2.2 -> index floor(2.2) % 3 == 2 -> read address ("s", "v3")
tr = jitted(key, (2.2,))
print(tr.get_sample()[("s", "v3")])


# We can rewrite the above a bit more elegantly using the *args syntax
@gen
def switch_model_v2(p):
    """Same dispatch as ``switch_model``, written with ``*`` unpacking.

    The branches and their per-branch argument tuples are kept in sequences
    and splatted into ``switch_combinator`` and its call; choices are
    recorded under the address ``"switch"``.
    """
    branches = [branch_1, branch_2, branch_3]
    per_branch_args = [(p,)] * len(branches)
    index = jnp.floor(jnp.abs(p)).astype(jnp.int32) % 3
    return switch_combinator(*branches)(index, *per_branch_args) @ "switch"


# Wrap simulate in jax.jit, as done for switch_model, so `jitted` really is
# a compiled function rather than the plain simulate method.
jitted = jax.jit(switch_model_v2.simulate)
# p = 0.0 -> index 0 -> read address ("switch", "v1")
tr = jitted(key, (0.0,))
print(tr.get_sample()[("switch", "v1")])