How do I use conditionals in Jax?

In [1]:
import genjax
import jax
from genjax import bernoulli

# In pure Python, we use usual conditionals
@interpreted_gen_fn
def simple_switch(p):
    if p > 0:
        v = bernoulli(p) @ "v"
        return v
    else:
        v = bernoulli(-p) @ "v"
        return v
    
key = jax.random.PRNGKey(314159)
tr1 = simple_switch.simulate(key, (0.3,))
tr2 = simple_switch.simulate(key, (-0.4,))
print(tr1)
print()
print(tr2)
# But this will be very slow
%timeit simple_switch.simulate(key, (0.3,))

NameError: name 'interpreted_gen_fn' is not defined

In [2]:
# In pure Jax, we write conditionals with jax.lax.cond as follows
def simple_cond(p):
    pred = p > 0
    branch_1 = lambda p: p
    branch_2 = lambda p: -p
    arg_of_cond = p
    cond_res = jax.lax.cond(pred, branch_1, branch_2, arg_of_cond)
    return cond_res

print(simple_cond(0.3))
print(simple_cond(-0.4))

0.3
0.4


In [4]:
# The restriction is that both branches should have the same return type
def failing_simple_cond(p):
    pred = p > 0
    branch_1 = lambda p: (p,p)
    branch_2 = lambda p: -p
    arg_of_cond = p
    cond_res = jax.lax.cond(p > 0, branch_1, branch_2, arg_of_cond)
    return cond_res

try: 
    print(failing_simple_cond(0.3))
except:
    print("TypeError: true_fun and false_fun output must have same type structure, got PyTreeDef((*, *)) and PyTreeDef(*).")

#TODO: add counter-example with same pytree structure but different type float vs int

TypeError: true_fun and false_fun output must have same type structure, got PyTreeDef((*, *)) and PyTreeDef(*).


In [10]:
# In GenJax, the syntax is a bit different still.
# Similarly to Jax having a custom primitive jax.lax.cond that "composes" two
# functions, GenJax has a custom combinator that "composes" two generative 
# functions.
import genjax
import jax
from genjax import static_gen_fn

# We first define the two branches as generative functions
@static_gen_fn
def branch_1(p):
    v = bernoulli(p) @ "v1"
    return v

@static_gen_fn
def branch_2(p):
    v = bernoulli(-p) @ "v2"
    return v

# Then we use the combinator to compose them
switch = genjax.switch_combinator(branch_1, branch_2)
key = jax.random.PRNGKey(314159)
jitted = jax.jit(switch.simulate)
tr = jitted(key, (0,))
v1 = tr.get_sample()["v1"]
v2 = tr.get_sample()["v2"]

FileNotFoundError: [Errno 2] No such file or directory

In [7]:
# Note that it may be possible to write the following down, but this will not give you what you want in general!
from genjax import static_gen_fn
from genjax import bernoulli
import jax

@static_gen_fn
def simple_switch_genjax(p):
    branch_1 = lambda p: bernoulli(p) @ "v1"
    branch_2 = lambda p: bernoulli(-p) @ "v2"
    cond = jax.lax.cond(p > 0, branch_1, branch_2, p)
    return cond

key = jax.random.PRNGKey(314159)
tr1 = simple_switch_genjax.simulate(key, (0.3,))
tr2 = simple_switch_genjax.simulate(key, (-0.4,))
print(tr1.get_retval())
print(tr2.get_retval())

1
1
