In [28]:
import genjax
import jax
import jax.numpy as jnp
from genjax import vi
import jax.tree_util as jtu
from genjax import ChoiceMap as Chm

global_args = (jnp.ones(10, dtype=float), )

key = jax.random.PRNGKey(314159)

In [29]:
@genjax.gen
def scratch_model(v):
    global_theta = genjax.normal(0.0, 1.0) @ "theta"

    @genjax.gen
    def submodel(global_theta, v):
        z = genjax.normal(global_theta, v) @ "z"
        x = genjax.normal(z, 3.0) @ "x"

    return submodel.vmap(in_axes=(None, 0))(global_theta, v) @ "local"

In [30]:
tr = scratch_model.simulate(key, global_args)
tr.get_choices()

Static({
  'local': Indexed(
    c=Static({
      'x': Choice(
        v=# jax.Array float32(10,) ≈-0.43 ±2.5 [≥-4.0, ≤3.6] nonzero:10
          Array([-2.7239337,  3.5539687,  1.236969 , -2.260299 ,  3.1946638,
                  1.424921 , -2.282737 , -1.604079 , -0.8603655, -4.0168695],      dtype=float32)
        ,
      ),
      'z': Choice(
        v=# jax.Array float32(10,) ≈-0.17 ±1.0 [≥-1.9, ≤1.2] nonzero:10
          Array([-1.056706  ,  0.7125043 ,  0.92436844, -1.9272702 ,  0.45415905,
                 -0.6104184 ,  0.18161982, -1.4966843 , -0.1501036 ,  1.2360518 ],      dtype=float32)
        ,
      ),
    }),
    addr=<jax.Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)>,
  ),
  'theta': Choice(v=<jax.Array(-0.10823099, dtype=float32)>),
})

In [31]:
from genjax import ChoiceMapBuilder as C
chm = C["local", jnp.arange(10), "x"].set(jnp.ones(10))
chm

Static({
  'local': Indexed(
    c=Static({
      'x': Choice(
        v=<jax.Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)>,
      ),
    }),
    addr=<jax.Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)>,
  ),
})

In [32]:
scratch_model.importance(key, chm, global_args)

(StaticTrace(
   gen_fn=scratch_model,
   args=(
     <jax.Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)>,
   ),
   retval=None,
   addresses=AddressVisitor(visited=[('theta',), ('local',)]),
   subtraces=[
     DistributionTrace(
       gen_fn=ExactDensityFromCallables(sampler=Closure(dyn_args=(), fn=<function tfp_distribution.<locals>.sampler at 0x143420ae0>), logpdf_evaluator=Closure(dyn_args=(), fn=<function tfp_distribution.<locals>.logpdf at 0x143420b80>)),
       args=(0.0, 1.0),
       value=<jax.Array(-0.10823099, dtype=float32)>,
       score=<jax.Array(-0.9247955, dtype=float32)>,
     ),
     VmapTrace(gen_fn=VmapCombinator(gen_fn=StaticGenerativeFunction(source=Closure(dyn_args=(), fn=<function scratch_model.<locals>.submodel at 0x17e277740>)), in_axes=(None, 0)), inner=StaticTrace(gen_fn=StaticGenerativeFunction(source=Closure(dyn_args=(), fn=<function scratch_model.<locals>.submodel at 0x17e277740>)), args=(<jax.Array float32(10,) ≈-0.11 ±7.5e-09 [≥-0.11

## Various little test snippets

In [33]:
from genjax import vi

# We add params here because our VI code in the library works with GenSP concepts,
# and what's going to happen is that the "proposal" (here, that's our scratch family)
# gets to access the target of inference (meaning: the original generative function, 
# the observation constraints, and the arguments to the gen fn).
@genjax.gen
def scratch_model(params, v):
    z = genjax.normal(v, 1.0) @ "z"
    x = genjax.normal(z, 0.05) @ "x"
    
# this will be our variational family --
# the variational family is a parametrized distribution, which
# we use to search over in variational inference.
@genjax.gen
def scratch_family(tgt: genjax.Target):
    (params, v) = tgt.args
    (x_shifted, ) = params
    chm = tgt.constraint
    x = chm["x"]
    
    # TODO: one can totally change this logic.
    z = vi.normal_reparam(x_shifted * x, v) @ "z"

chm = C["x"].set(1.0)
init_params = (1.0, )

grad_estimator = vi.ELBO(
    scratch_family.marginal(), # this is a "SampleDistribution"
    lambda *args: genjax.Target(scratch_model, (args, 1.0), chm),
)
grad_estimator(key, init_params)

(Array(-43.400658, dtype=float32, weak_type=True),)

In [34]:
from genjax import vi

# We add params here because our VI code in the library works with GenSP concepts,
# and what's going to happen is that the "proposal" (here, that's our scratch family)
# gets to access the target of inference (meaning: the original generative function, 
# the observation constraints, and the arguments to the gen fn).
@genjax.gen
def scratch_model(params, v):
    @genjax.gen
    def submodel(v):
        z = genjax.normal(v, 1.0) @ "z"
        x = genjax.normal(z, 0.05) @ "x"
        
    return submodel.vmap(in_axes=(0, ))(v) @ "local"
    
# this will be our variational family --
# the variational family is a parametrized distribution, which
# we use to search over in variational inference.
@genjax.gen
def scratch_family(tgt: genjax.Target):
    (params, v) = tgt.args
    (x_shifted, ) = params
    chm = tgt.constraint
    x = chm["local", ..., "x"]

    @genjax.gen
    def submodel(x, x_shifted, v):
        z = vi.normal_reparam(x_shifted * x, v) @ "z"
        
    return submodel.vmap(in_axes=(0, None, 0))(x, x_shifted, v) @ "local" 

chm = C["local", jnp.array([0]), "x"].set(jnp.array([1.0], dtype=float))
init_params = (1.0, )

grad_estimator = vi.ELBO(
    scratch_family.marginal(), # this is a "SampleDistribution"
    lambda *args: genjax.Target(scratch_model, (args, jnp.array([1.0])), chm),
)
grad_estimator(key, init_params)

(Array(-43.400658, dtype=float32, weak_type=True),)

## VI for real

In [64]:
# Example WITHOUT a scan
#
# We add params here because our VI code in the library works with GenSP concepts,
# and what's going to happen is that the "proposal" (here, that's our scratch family)
# gets to access the target of inference (meaning: the original generative function, 
# the observation constraints, and the arguments to the gen fn).
@genjax.gen
def scratch_model(params, v):
    global_theta = genjax.normal(0.0, 10.0) @ "theta"
    
    @genjax.gen
    def submodel(global_theta, v):
        z = genjax.normal(global_theta, v) @ "z"
        x = genjax.normal(z, 0.05) @ "x"

    return submodel.vmap(in_axes=(None, 0))(global_theta, v) @ "local"
    
# this will be our variational family --
# the variational family is a parametrized distribution, which
# we use to search over in variational inference.
@genjax.gen
def scratch_family(tgt: genjax.Target):
    (params, v) = tgt.args
    data = tgt.constraint
    (θ_p, x_shifted) = params
    global_theta = vi.normal_reparam(θ_p, 0.2) @ "theta"
    
    # This is our observation.
    # this will be like your "ys" in your model
    # with the lgssm likelihood.
    x = data["local", ..., "x"]
    
    @genjax.gen
    def submodel(global_theta, x, x_shifted, v):
        # TODO: one can totally change this logic.
        z = vi.normal_reparam(global_theta, v) @ "z"

    return submodel.vmap(in_axes=(None, 0, None, 0))(
        global_theta, x, x_shifted, v
    ) @ "local"

chm = C["local", jnp.arange(10), "x"].set(jnp.ones(10, dtype=float))
init_params = (1.0, 0.2)

grad_estimator = vi.ELBO(
    scratch_family.marginal(), # this is a "SampleDistribution"
    lambda *args: genjax.Target(scratch_model, (args, 0.3 * jnp.ones(10, dtype=float)), chm),
)
jitted = jax.jit(grad_estimator)

# Sample training loop:
print(init_params)
for i in range(100):
    params_grads = jitted(key, init_params)
    init_params = jtu.tree_map(lambda v, g: v - 1e-4 * g, init_params, params_grads)
    print(init_params)

(1.0, 0.2)
(Array(1.0216452, dtype=float32, weak_type=True), Array(0.2, dtype=float32, weak_type=True))
(Array(1.0346323, dtype=float32, weak_type=True), Array(0.2, dtype=float32, weak_type=True))
(Array(1.0424246, dtype=float32, weak_type=True), Array(0.2, dtype=float32, weak_type=True))
(Array(1.0471, dtype=float32, weak_type=True), Array(0.2, dtype=float32, weak_type=True))
(Array(1.0499051, dtype=float32, weak_type=True), Array(0.2, dtype=float32, weak_type=True))
(Array(1.0515882, dtype=float32, weak_type=True), Array(0.2, dtype=float32, weak_type=True))
(Array(1.052598, dtype=float32, weak_type=True), Array(0.2, dtype=float32, weak_type=True))
(Array(1.0532039, dtype=float32, weak_type=True), Array(0.2, dtype=float32, weak_type=True))
(Array(1.0535675, dtype=float32, weak_type=True), Array(0.2, dtype=float32, weak_type=True))
(Array(1.0537857, dtype=float32, weak_type=True), Array(0.2, dtype=float32, weak_type=True))
(Array(1.0539166, dtype=float32, weak_type=True), Array(0.2, dt

## VI (closer to the real model)

In [83]:
# Example with a scan
#
# We add params here because our VI code in the library works with GenSP concepts,
# and what's going to happen is that the "proposal" (here, that's our scratch family)
# gets to access the target of inference (meaning: the original generative function, 
# the observation constraints, and the arguments to the gen fn).
@genjax.gen
def scratch_model(params, v):
    global_theta = genjax.normal(0.0, 10.0) @ "theta"
    
    @genjax.gen
    def submodel(global_theta, v):
        z = genjax.normal(global_theta, v) @ "z"

        def _inner(carry, scanned_in):
            z = carry
            v = scanned_in
            return z, z + v
            
        _, z = jax.lax.scan(_inner, z, jnp.zeros(10))
        
        x = genjax.normal(jnp.sum(z) / 10.0, 0.1) @ "x"

    return submodel.vmap(in_axes=(None, 0))(global_theta, v) @ "local"
    
# this will be our variational family --
# the variational family is a parametrized distribution, which
# we use to search over in variational inference.
@genjax.gen
def scratch_family(tgt: genjax.Target):
    (params, v) = tgt.args
    data = tgt.constraint
    (θ_p, x_shifted) = params
    global_theta = vi.normal_reparam(θ_p, 0.2) @ "theta"
    
    # This is our observation.
    # this will be like your "ys" in your model
    # with the lgssm likelihood.
    x = data["local", ..., "x"]
    
    @genjax.gen
    def submodel(global_theta, x, x_shifted, v):
        # TODO: one can totally change this logic.
        z = vi.normal_reparam(global_theta, v) @ "z"

    return submodel.vmap(in_axes=(None, 0, None, 0))(
        global_theta, x, x_shifted, v
    ) @ "local"

chm = C["local", jnp.arange(10), "x"].set(jnp.ones(10, dtype=float))
init_params = (1.0, 0.2)

grad_estimator = vi.ELBO(
    scratch_family.marginal(), # this is a "SampleDistribution"
    lambda *args: genjax.Target(scratch_model, (args, 0.3 * jnp.ones(10, dtype=float)), chm),
)
jitted = jax.jit(grad_estimator)
jitted_vmap = jax.jit(jax.vmap(grad_estimator, in_axes=(0, None)))

# Sample training loop:
print(init_params)
key = jax.random.PRNGKey(314159)
for i in range(100):
    key, sub_key = jax.random.split(key)
    sub_keys = jax.random.split(sub_key, 64)
    params_grads = jitted_vmap(sub_keys, init_params)
    params_grads = jtu.tree_map(lambda v: jnp.mean(v, axis=0), params_grads)
    init_params = jtu.tree_map(lambda v, g: v - 1e-5 * g, init_params, params_grads)
    print(init_params)

(1.0, 0.2)
(Array(1.0003084, dtype=float32), Array(0.2, dtype=float32))
(Array(1.0004424, dtype=float32), Array(0.2, dtype=float32))
(Array(1.0011479, dtype=float32), Array(0.2, dtype=float32))
(Array(1.0001657, dtype=float32), Array(0.2, dtype=float32))
(Array(1.000024, dtype=float32), Array(0.2, dtype=float32))
(Array(0.99993193, dtype=float32), Array(0.2, dtype=float32))
(Array(0.99958915, dtype=float32), Array(0.2, dtype=float32))
(Array(0.99945563, dtype=float32), Array(0.2, dtype=float32))
(Array(0.99961483, dtype=float32), Array(0.2, dtype=float32))
(Array(0.99998873, dtype=float32), Array(0.2, dtype=float32))
(Array(0.99898916, dtype=float32), Array(0.2, dtype=float32))
(Array(0.9995703, dtype=float32), Array(0.2, dtype=float32))
(Array(0.99862176, dtype=float32), Array(0.2, dtype=float32))
(Array(0.9984746, dtype=float32), Array(0.2, dtype=float32))
(Array(0.9981054, dtype=float32), Array(0.2, dtype=float32))
(Array(0.99866647, dtype=float32), Array(0.2, dtype=float32))
(Array