In [None]:
import jax
import jax.numpy as jnp

import genjax

In [None]:
# Display `genjax.incremental` (presumably a module) to inspect what it exposes.
genjax.incremental

In [None]:
# NOTE(review): identical to an earlier cell that also displays
# `genjax.incremental`; consider deleting this duplicate.
genjax.incremental

In [None]:
# One wart arises immediately:
# * the signatures of `update` and `regenerate` mean that the concerns of
#   "doing a Monte Carlo update" are tied to incremental computation in Gen.
#
# This is what Feras really hates about Gen.

# class YourGenFn(GenerativeFunction):
#     # Non-incremental
#     def edit(..., args: Tuple):
#         pass

#     # Incremental tracediff (monoid) stuff
#     #
#     # This still doesn't solve "separation of concerns": we're computing
#     # the "incremental R-N (Radon-Nikodym) derivative" of a generative
#     # function ... and maybe, within that concept, we can't disentangle
#     # "incremental changes to values" from "incremental changes to traces".
#     #
#     # Distinguish between two types of "diffs":
#     # * Tracediff -- the monoid of tangents on traces
#     # * Valuediff -- a monoid of tangents on values
#     #   (this is where UnknownChange and NoChange live).
#     # Related to: static incremental lambda calculus

#     # The score in a trace is a w such that ...
#     # What it guarantees depends on how/when it was computed: usually it's
#     # either the `simulate` or the `assess` guarantee, and it's unclear
#     # whether that is consistent -- vs. `edit`, which has a very
#     # consistent guarantee propagation.

#     # score <- simulate(...).get_score()
#     # score == log p(chm; x)
#     #    where chm ~ p(\cdot; x)
#     #
#     # when there is untraced randomness r:
#     # log_score == log p(chm, r; x) - log q(r; x)
#     # guarantee: E_{r ~ p(r | chm; x)}[1 / score] = 1 / p(chm; x)
#     # where score = exp(log_score) = p(chm, r; x) / q(r; x). Proof:
#     # E_{r ~ p(r | chm; x)}[1 / score]
#     #   = \int p(r | chm; x) q(r; x) / p(chm, r; x) dr
#     #   = \int p(r | chm; x) q(r; x) / (p(chm; x) p(r | chm; x)) dr
#     #   = (1 / p(chm; x)) \int q(r; x) dr = 1 / p(chm; x)
#     #
#     #
#     # tr, _ <- importance(..., chm) -- chm is the constraint passed to importance
#     # score = tr.get_score()
#     # score == log p(chm'; x)
#     # where chm' ~ q(\cdot; chm, x)
#     #
#     # here gen_fn has q(\cdot; chm, x) as its internal proposal
#     # (cf. SIRCombinator(gen_fn))
#     def edif(..., diffs: Diff)
#         -> tuple[td: Tracediff, retdiff: Diff, ...]:
#         pass

#     def update(..., args: Diff)
#         -> tuple[..., retval: Diff]:
#         pass

#     def regenerate(..., args: Diff)
#         -> tuple[..., retval: Diff]:
#         pass

In [None]:
# Display the incremental interpreter from `genjax.core.interpreters`.
genjax.core.interpreters.incremental  # program transformation

In [None]:
def fn(x, y):
    """Return the square of ``x + y``.

    Kept as a tiny scalar example so its jaxpr (an add followed by a
    square) is easy to read.
    """
    total = x + y
    return total**2

In [None]:
# Trace `fn` (the `(x + y) ** 2` definition in the cell above) on two scalar
# float inputs and display the resulting jaxpr.
jaxpr = jax.make_jaxpr(fn)(3.0, 3.0)
jaxpr

In [None]:
# genjax.core.interpreters.incremental(fn)(
#     None,
#     (SomePytree(3.0), ),
#     (SomePytree(genjax.incremental.NoChange), )
# )

In [None]:
# Question: for jvp(vmap(fn)), how many times do you make a jaxpr?

In [None]:
# It depends on how vmap and jvp are implemented as transformations:
# * vmap(fn) -> Jaxpr; this Jaxpr will be different from `make_jaxpr(fn)`.
# * jvp(vmap(fn)) -> you're making 2 Jaxprs if jvp also requires a Jaxpr as input.

In [None]:
def fn(x, y):
    """Add ``y`` to ``x``, then square the result slice-by-slice with ``jax.vmap``.

    ``jax.vmap`` maps over the leading axis of ``x + y`` (default
    ``in_axes=0``), so the inner function is traced on one slice whose rank
    is one less than that of ``x + y``.
    """
    shifted = x + y

    def square(elem):
        return elem**2

    return jax.vmap(square)(shifted)

In [None]:
# x has shape (5,); vmap maps over axis 0, so the squaring lambda is traced
# on a single element.
jax.make_jaxpr(fn)(jnp.ones(5), 1.0)  # lambda z: z ** 2 will get shape ()

In [None]:
# Changing only the mapped-axis length (10 instead of 5) does not change the
# per-element shape the lambda sees.
jax.make_jaxpr(fn)(jnp.ones(10), 1.0)  # lambda z: z ** 2 will get shape ()

In [None]:
# x has shape (5, 5); vmap strips only the leading axis, so each traced
# slice is a row.
jax.make_jaxpr(fn)(jnp.ones((5, 5)), 1.0)  # lambda z: z ** 2 will get shape (5,)

In [None]:
def g(x):
    """Return ``x`` squared.

    NOTE(review): a later cell redefines ``g`` as a generative function,
    which shadows this plain-Python version.
    """
    return x**2


def fn(x):
    """Apply ``g`` to the square of ``x``.

    With the squaring ``g`` defined in this cell, the result is ``x**4``;
    this illustrates a nested call that tracing flattens into one jaxpr.
    """
    squared = x**2
    return g(squared)

In [None]:
# Display `genjax.trace`, the primitive used below to make an addressed
# random choice inside a generative function.
genjax.trace

In [None]:
@genjax.gen
def g(x):
    """Sample ``genjax.normal`` with args ``(x, 1.0)`` at address ``("x",)``
    and return the square of the sampled value."""
    sample = genjax.trace(("x",), genjax.normal, (x, 1.0))  # GenJAX primitive
    return sample**2


@genjax.gen
def f(x):
    """Trace a vmapped ``g`` over ``x`` at address ``("v",)`` and return
    ``(v * 3) ** 2``.

    ``g.vmap(in_axes=(0,))`` presumably mirrors ``jax.vmap`` and maps ``g``
    over the leading axis of its single argument ``x``.
    """
    mapped_g = g.vmap(in_axes=(0,))
    v = genjax.trace(("v",), mapped_g, (x,))  # GenJAX primitive
    return (v * 3) ** 2

In [None]:
# Trace `f.source` on a length-5 vector and display its jaxpr.
# NOTE(review): `f.source` is presumably the undecorated Python function
# behind the `@genjax.gen` wrapper -- confirm against GenJAX.
jax.make_jaxpr(f.source)(
    jnp.ones(5),
)

In [None]:
# Trace `f.simulate` with a fixed PRNG key (seed 1) and the argument tuple
# `(jnp.ones(5),)`, and display the resulting jaxpr.
jax.make_jaxpr(f.simulate)(jax.random.key(1), (jnp.ones(5),))