In [22]:
import genjax
from genjax import EditRequest, Pytree, Address, Argdiffs, Trace, Weight, Retdiff, Update, ChoiceMap
from genjax.typing import PRNGKey, FloatArray

from jax import vmap
import jax.random as jrand
import jax.numpy as jnp
import jax.tree_util as jtu
from jax.scipy.special import logsumexp

import matplotlib.pyplot as plt

In [None]:
def run_mh(mcmc_edit_request, tr):
    new_tr, w, *_ = tr.edit(mcmc_edit_request)
    # do accept reject
    # ...
    return final_trace

In [None]:
new_trace = old_trace.mh(mcmc_edit_request)

## What are K and L here?

In [160]:
# x: X, joint space is X x Cat(n) x X
# K(i, x' | x)
# L(i, x | x') -- we'll only use L to evaluate the density of (i, x). 

def K(x : continuous):
    grid = grid_around(x)

def L(...):
    pass

NameError: name 'continuous' is not defined

In [None]:
# x : X
# pair of things we care about for proper weighting: P, Q
# proper weighting criteron: E_Q[w(x) * f(x)] = \int P * f
# w some function
#
# (x, w) ~ χ(dx, dw), s.t. Eχ[w * f(x)] = \int P * f

## Code

In [180]:
@Pytree.dataclass(match_args=True)
class GridRejuvenation(EditRequest):
    addr: Address
    grid_radius: FloatArray
    N: int = Pytree.static(default=100)

    def grid_around(self, v):
        grid_containing_center = jnp.hstack([jnp.linspace(v - self.grid_radius, v + self.grid_radius, N), v]).sort()
        idx = jnp.floor(self.N / 2).astype(int)
        return grid_containing_center, idx

    def edit(self,
             key: PRNGKey,
             tr: Trace,
             argdiffs: Argdiffs
            ) -> tuple[Trace, Weight, Retdiff, EditRequest]:
        choices = tr.get_choices()
        v = choices[self.addr]
        grid, old_latent_idx = self.grid_around(v)

        def _updates(key, v):
            chm = ChoiceMap.entry(v, self.addr)
            request = Update(chm)
            return tr.edit(key, request, argdiffs)

        # Forward move.
        key, sub_key = jrand.split(key)
        sub_keys = jrand.split(sub_key, len(grid))
        new_trs, ws, retdiffs, bwd_requests = vmap(_updates)(sub_keys, grid)
        
        ## TODO: confirm below...
        ## Turns out we can use the weights for this... 
        ## if we used the scores, the parts that aren't changed would cancel, 
        ## but the weight is exactly the change in density for the parts that are changed.
        
        idx = genjax.categorical.sample(key, ws)
        new_tr, w, retdiff, bwd_request = jtu.tree_map(lambda v: v[idx], 
                    (new_trs, ws, retdiffs, bwd_requests)
        ) 
        # TODO: this could be right? ... check with Jay!
        #
        # the weight `w` here is the P' / P ratio in SMCP3... why?
        # because that's exactly what Update computes
        # when Q doesn't have to propose anything.
        # P' / P = new_tr.get_score() - old_trace.get_score()
        # E_[w] = P' / 
        # w = (P' / P) * (1 / Q)
        # => E[w] = P' / P
        # (x, w_) p.w. for P under some Q_, then x', w_ * w
        # (x', w_ * w) p.w. for P' under K from the move.
        # for that to be true, we need to have E[w] = P' / P
        
        # What is the density K(i | x)?
        K_density = genjax.categorical.logpdf(idx, ws) # TODO: hope this is numerically fine TFP probability.

        L_density = 1 / len(grid)

        # TODO: smart L
        #
        # Now, we need to construct a backwards proposal:
        # we'll use the same gridding function, centered at the new latent.
        # discard = bwd_request.constraint
        # discarded_value = discard[self.addr]
        # backwards_grid, new_latent_idx = self.grid_around(discarded_value)
        # _, bws, *_ = vmap(_updates)(sub_keys, backwards_grid)
        # normalized_bws = bws - logsumexp(bws)
        # Now, we need to find the index of the discarded latent.
        # The discarded latent was at old floor(N / 2) in the old grid, and then
        # we sampled a new grid point *.
        #
        #            old   *
        # [0.0, 0.2, 0.4, 0.8, 1.0]   -> * at idx 3.
        #                                We shift by (idx *) - (idx old).
        #
        #       old   *
        # [0.2, 0.4, 0.8, 1.0, 1.2]   -> * at idx 3.
        #                                To get old, we take (idx *) and subtract (

        # Final SMCP3 weight computation.
        final_weight = w + L_density - K_density
        
        return new_tr, final_weight, retdiff, bwd_request

# tester rolls

In [181]:
original_array = jnp.array([-2, -1, 0, 1, 2])
original_array

Array([-2, -1,  0,  1,  2], dtype=int32)

In [182]:
center = jnp.floor(len(original_array) / 2).astype(int)
original_array[center]

Array(0, dtype=int32)

In [183]:
select_idx = 1
roll_shift = center - select_idx
rolled = jnp.roll(jnp.array([-2, -1, 0, 1, 2]), roll_shift)
rolled

Array([ 2, -2, -1,  0,  1], dtype=int32)

In [184]:
jnp.roll(rolled, -roll_shift)

Array([-2, -1,  0,  1,  2], dtype=int32)

## tester grid rolls

In [185]:
original = jnp.arange(0.3 - 0.1, 0.3 + 0.1, step=0.01)
original[jnp.floor(len(original) / 2).astype(int)]

Array(0.2999999, dtype=float32)

In [186]:
def grid_around(v, radius, step_size):
    grid_containing_center = jnp.arange(v - radius, v + radius, step_size)
    center = jnp.floor(len(grid_containing_center) / 2).astype(int)
    return grid_containing_center, center

In [187]:
original, center = grid_around(0.3, 0.5, 0.01)

In [188]:
original[center]

Array(0.30000025, dtype=float32)

In [189]:
select_idx = 34
roll_shift = center - select_idx

In [190]:
new = 0.3 + original[select_idx]
new

Array(0.44000018, dtype=float32)

In [191]:
backwards, center = grid_around(new, 0.5, 200)
jnp.roll(backwards, -roll_shift)[center]

Array(-0.05999982, dtype=float32)

## tester boy

In [192]:
@genjax.gen
def model():
    v = genjax.normal(0.0, 1.0) @ "v"
    y = genjax.normal(v, 1.0) @ "y"
    return y

In [193]:
key = jrand.key(1)
tr, _ = model.importance(key, ChoiceMap.kw(y=3.0), ())

In [194]:
tr.get_choices()["y"]

3.0

In [195]:
tr.get_choices()["v"]

Array(0.7158846, dtype=float32)

In [196]:
# Gibbs says:
# say you have P(v, x, y) -- we could observe y. 
# Note: P(v, x, y) \propto P(v, x | y), \propto P(v | x, y), \propto P(y | x, v)
#
# Start with (v_0, x_0, y) ... where does this come from? (v_0, x_0) comes from some Q(\cdot; y)
# 1. Sample v_1 ~ P(v | x_0, y)
# 2. Sample x_1 ~ P(x | v_1, y)
#     ...

In [197]:
request = GridRejuvenation(
    "v", # addr I want to rejuvenate at,
    0.3, # auxiliary data related to the grid
)

In [199]:
# TODO: API
# Rejuvenation(addr, RejuvenationObject)
#  Example: Rejuvenation(addr, Grid)
#  ...

In [200]:
new_tr, w, *_ = request.edit(key, tr, ())
new_tr.get_choices()

Static({
  'v': Choice(v=<jax.Array(0.850386, dtype=float32)>),
  'y': Choice(v=<jax.Array(3., dtype=float32, weak_type=True)>),
})

In [201]:
w

Array(7.02363, dtype=float32)

In [None]:
# The desired workflow:
#
# tr, w = gen_fn.p_importance(...)
# tr.visualize(DimensionPlot(axis=0, {"x" => Histogram(...), "pose" => 3DRender(...)}))
# request = RejuvenationGibbs("x", ...)
# new_tr, _ = request.edit(tr, ...)
# new_tr.visualize(DimensionPlot(axis=0, {"x" => Histogram(...), "pose" => 3DRender(...)}))

# Trace bucket list (things todo to make trace really nice to use and a true zero-cost abstraction):
# * Tracediff "incremental trace monoid"
# * Correct batch semantics -> 
#       (a) POPL: we have a plan for how some of the GFI works on traces in batch.
#          (i) We need to extend POPL treatment to edit ^ also tied to incremental trace monoid stuff.
#       (b) Say we get a batch semantics that we're happy with -- I think we should know: where does resampling go?
# * Trace visualization -- we probably want a interface to Huebert's stuff that looks like `trace.visualize({"x" => Histogram(...), ("y", "z") => SomeOtherPlot})`
#   (a) We want to make traces more readable as plaintext.
#   (b) As easy as it is to Python `print` a trace as plaintext, it should be easy to generate a "tree-like" GenStudio object to visually explore the trace.
#
# Limitations of current implementation with respect to the above bucket list:
# (b) Memory: updating a trace with a big footprint can be costly:
#     (i) Say we do GridRejuvenation on pose in Gen3D -- this makes a grid, and then renders a change
#         we get back a batched trace (batched across the grid) containing copies of the ground truth image.
#    # solution: implement Mathieu's design for out_axis option in vmap combinator      
    # idea for solution: have a batch-size option in vmap combinator (maybe call map instead of vmap behindn the scene if option provided)

In [None]:
# George's proposal -- how do I couple exact inference logic into my models?
# he was concerned about conjugate updates.
#
# request = ConjugateUpdateGibbs("x", ...)
# soln:
#   conjugate updates are not something that we should _change the GFI_ for
#   conjugate updates should be offered by a GenFn language which is "Markov blanket aware"
#   GenFn knows about its P distribution.
#   it uses its knowledge to expose ConjugateUpdate("x", ...)
#   this is offered as an EditRequest, and can then be used compositional with `edit`.
#
# This special language could offer analysis of its code, its GenFns could provide slices of themselves via
# an interface.
#
# How does this thing interact with the GFI? We don't change the GFI, this language exposes the use of these capabilities via
# EditRequest.
#
# Construct a generative function via blocks...
#   "Everything is a combinator"
#   gen_fn = PushforwardByKernel(Input(), Lambda(Input(), ...)
#   gen_fn.get_markov_blanket("x") -> George's thing...
#
# (very similar Colin's block demo language) 
# Pedagogy: teach GenLM to generate generative functions from this AST grammar
# write a WCFG for GenLM, and be able to give to student's that know nothing about GenJAX
# and they can use it to make generative functions.
# 
# For these generative functions -- analysis lets us have good conjugate updates, we also have Gibbs and enumerative Gibbs
# A (new student) asks for a generative function in this "blocks language" (in natural language to GenLM), 
# and we give them back a generative function, and some inference logic.
#
# Better starting place: can we teach GenLM to make "blocks" generative functions from natural language.

# Pedagogy: getting GenLM to write programs in a GenJAX "sub-language" with good inference baked in. 
# ... what about high school students?
# ... what about stupid VCs?
# 
# GenLM / LLM reads the GenJAX documentation and cookbook and then we plug it in...
#   * how do we imagine the in-context-on-genjax LLM to help?
#     * fancy autocomplete -- this totally makes sense to me today, and it makes sense how it would work
#     * help you iterate on inference -- this is more murky
#     * I'd love if we had a gold standard set of inference "APIs": GridRejuvenation, Gibbs, ...
#     * Ask the LLM: what should I try? And then it says GridRejuvenation...