What are some sharp edges of JAX (and GenJAX) that are good to know before really getting started?

In [14]:
# 1] Arguments must be packed in a tuple.
#    GenJAX's `simulate(key, args)` type-checks `args` at call time and only
#    accepts a tuple, even when the generative function takes a single argument
#    (see the error messages printed below).
import jax
from genjax import bernoulli, gen

@gen
def f(p):
    v = bernoulli(p) @ "v"
    return v

key = jax.random.PRNGKey(0)

# Three common ways of failing: a bare float, a list, and `(0.5)`.
# The last one is only a parenthesised float: a one-element tuple needs a trailing comma.
for bad_args in (0.5, [0.5], (0.5)):
    try:
        f.simulate(key, bad_args)
    except Exception as e:  # deliberate: the printed error message is the demonstration
        print(e)

# Correct way: a one-element tuple.
# The resulting trace has score -0.474 = log(sigmoid(0.5)). 0.5 is interpreted
# as a logit, not a probability (see sharp edge #2).
f.simulate(key, (0.5,))

Method genjax._src.generative_functions.static.static_gen_fn.StaticGenerativeFunction.simulate() parameter args=0.5 violates type hint <class 'tuple'>, as float 0.5 not instance of tuple.
Method genjax._src.generative_functions.static.static_gen_fn.StaticGenerativeFunction.simulate() parameter args=[0.5] violates type hint <class 'tuple'>, as list [0.5] not instance of tuple.
Method genjax._src.generative_functions.static.static_gen_fn.StaticGenerativeFunction.simulate() parameter args=0.5 violates type hint <class 'tuple'>, as float 0.5 not instance of tuple.


StaticTrace(
  gen_fn=StaticGenerativeFunction(
    source=Closure(dyn_args=(), fn=<function f at 0x2c7bff420>),
  ),
  args=(0.5,),
  retval=<jax.Array(1, dtype=int32)>,
  addresses=AddressVisitor(visited=['v']),
  subtraces=[
    DistributionTrace(
      gen_fn=ExactDensity(
        sampler=Closure(dyn_args=(), fn=<function tfp_distribution.<locals>.sampler at 0x16dd92fc0>),
        logpdf_evaluator=Closure(dyn_args=(), fn=<function tfp_distribution.<locals>.logpdf at 0x16dd93060>),
      ),
      args=(0.5,),
      value=<jax.Array(1, dtype=int32)>,
      score=<jax.Array(-0.474077, dtype=float32)>,
    ),
  ],
  score=<jax.Array(-0.474077, dtype=float32)>,
)

In [83]:
# 2] GenJAX's distributions wrap TensorFlow Probability (TFP), which sometimes
#    has surprising conventions.
import jax
from genjax import bernoulli, gen

# `bernoulli` is parameterised by *logits*, not probabilities:
# P(v = 1) = sigmoid(logit).
@gen
def g(p):
    v = bernoulli(p) @ "v"
    return v

N_SAMPLES = 20
key = jax.random.PRNGKey(0)
arg = (3.0,)  # 3 is not a valid probability, but it is a valid logit: sigmoid(3) ≈ 0.95
keys = jax.random.split(key, N_SAMPLES)

# Simulate N_SAMPLES times. With P(v = 1) ≈ 0.95, nearly every draw is 1.
# The loop variable is `subkey`, so `key` still holds PRNGKey(0) after the loop.
for subkey in keys:
    print(g.simulate(subkey, arg).get_sample()["v"])

# NOTE(review): TFP distributions do not validate their inputs by default
# (`validate_args=False`). The observation here is that values strictly greater
# than 0 are treated as True. Observing v = 4 would then be scored as possible,
# even though v should only have support on {0, 1}.
# This cell does not demonstrate that directly; TODO confirm with an explicit
# scoring (e.g. `assess`) example.


1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
