In [1]:
import genjax
import jax
import numpy as np

# Fixed seed so that Restart & Run All reproduces the same samples.
SEED = 314159
key = jax.random.PRNGKey(SEED)

# genjax pretty-printing setup; the returned object is kept as `console`.
# NOTE(review): `show_locals=True` presumably shows local variables in rendered
# tracebacks — confirm against genjax.pretty.
console = genjax.pretty(show_locals=True)

In [2]:
@genjax.gen
def model(v):
    """Two-step generative model over addresses ``"x"`` and ``"v"``.

    Draws ``x`` from ``genjax.Normal(v, 1.0)`` at address ``"x"``, then draws
    ``y`` from ``genjax.Normal(x, 3.0)`` at address ``"v"``, and returns
    ``x**2 + y``.

    The sampled value is bound to ``y`` instead of rebinding the argument
    ``v``, so the argument and the random choice no longer share a name. The
    trace address is still ``"v"``, so traces and choice maps are unchanged.
    """
    x = genjax.Normal(v, 1.0) @ "x"
    y = genjax.Normal(x, 3.0) @ "v"
    return x**2 + y

In [3]:
# Bare last expression: the notebook renders the generative function's repr.
model

In [4]:
# Wrap `model` with genjax.adev.lang and display the result.
# NOTE(review): presumably this lifts the generative function into genjax's
# ADEV program form, which exposes simulate / grad_estimate — confirm.
prog = genjax.adev.lang(model)
prog

In [5]:
# Run `prog` forward with argument tuple (1.0,). The call returns a pair whose
# first element is rebound to `key` (presumably an advanced PRNG key — confirm),
# so later cells do not reuse the same randomness.
key, v = prog.simulate(key, (1.0,))
v

In [6]:
# Estimate the derivative of `prog` at argument (1.0,).
# NOTE(review): the second tuple is presumably the input tangent (1.0,) —
# confirm the argument order against genjax.adev.
key, out_tangents = prog.grad_estimate(key, (1.0,), (1.0,))
out_tangents

In [7]:
# Build a module whose body adds a parameter named "x" (value 5.0) to its input.
# NOTE(review): the trailing `(3.0)` call presumably initialises/traces the module
# with an example input — confirm against genjax.module.
mod = genjax.module(lambda v: v + genjax.param(5.0, name="x"))(3.0)

In [8]:
# Apply the constructed module to the input 3.0 and display the result.
mod(3.0)

In [9]:
# Inspect the stored value of the parameter registered under the name "x".
mod.params["x"].value

In [10]:
# Call `prog` directly (rather than via .simulate) at argument 3.0; the first
# returned element is rebound to `key`, the second is displayed.
key, v = prog(key, 3.0)
v

In [11]:
# Forward-mode derivative of `prog` w.r.t. both of its arguments.
# `jax.random.PRNGKey` returns an integer (uint32) key array, and jax.jvp
# requires the tangent of an integer primal to have dtype float0. Passing
# `key` itself as the tangent raises a TypeError. Instead, use a zero float0
# tangent of the key's shape and differentiate through the float argument
# with unit tangent 1.0.
key_tangent = np.zeros(key.shape, dtype=jax.dtypes.float0)
jax.jvp(prog, (key, 3.0), (key_tangent, 1.0))

In [None]:
def f(v):
    """Evaluate `prog` at the notebook's global `key`, as a function of `v` only.

    `key` is looked up when `f` is called (not when it is defined), so the
    PRNG key never enters jax.jvp as a differentiated input.
    """
    return prog(key, v)


# Push the unit tangent 1.0 through `f` at v = 2.0. The primal output pair is
# unpacked, rebinding `key` and `v`; the tangent output is displayed.
(key, v), tangents = jax.jvp(f, (2.0,), (1.0,))
tangents