In [1]:
# TODO: SMC over time & over space
# TODO: maybe SMCP3, though it may be better separate?

# present the HMM again, and add resampling step.

import genjax
import jax
from jax import numpy as jnp
from jax import jit
                                
# Simple scalar HMM: a normal random walk unrolled with the scan combinator
# (the multivariate version is the "v2" snippet further down in this cell).
@genjax.scan_combinator(max_length=100)
@genjax.gen
def hmm(args, c):
    """One scan step: sample `new_x` from `genjax.normal(x, y)` given the carry `(x, y)`.

    Returns the next carry `(new_x, y)` and `None` as the per-step output; the
    scanned-over input `c` is not used.
    """
    x, y = args
    new_x = genjax.normal(x, y) @ "new_x"
    return (new_x, y), None

initial_x = 0.0
# NOTE(review): `y` is passed as the second argument of `genjax.normal`
# (presumably the scale). Starting it at 0.0 is consistent with the all-zero
# samples recorded in this cell's output -- confirm a non-zero value was not intended.
initial_y = 0.0
key = jax.random.PRNGKey(0)
trace = hmm.simulate(key, ((initial_x, initial_y), None))
print(trace.get_sample()[...,"new_x"])

# v2 with mv normal
import genjax
import jax
from jax import numpy as jnp
from jax import jit

# 300x300 identity matrix, passed as the second argument of `genjax.mv_normal` below.
variance = jnp.eye(300)
# Random 300-dimensional starting state.
initial_state = jax.random.normal(jax.random.PRNGKey(0), (300,))

@genjax.scan_combinator(max_length=100)
@genjax.gen
def hmm(x, c):
    """One scan step: sample `new_x` from `genjax.mv_normal(x, variance)` and carry it forward.

    The scanned-over input `c` is not used; the per-step output is `None`.
    """
    new_x = genjax.mv_normal(x, variance) @ "new_x"
    return new_x, None

key = jax.random.PRNGKey(0)
# NOTE(review): `subkey` is never used below; `key` itself is passed to the jitted call.
key, subkey = jax.random.split(key)
# JIT-compile `simulate` of the chain wrapped in `repeat(num_repeats=100)`.
jitted = jit(hmm.repeat(num_repeats=100).simulate)
trace = jitted(key, (initial_state, None))
#print(trace.get_sample()[...,"new_x"])


# Scalar HMM again (same code as the first snippet of this cell). Re-defining
# `hmm` here shadows the multivariate `hmm` defined just above.
@genjax.scan_combinator(max_length=100)
@genjax.gen
def hmm(args, c):
    """One scan step: sample `new_x` from `genjax.normal(x, y)` given the carry `(x, y)`.

    Returns the next carry `(new_x, y)` and `None` as the per-step output.
    """
    x, y = args
    new_x = genjax.normal(x, y) @ "new_x"
    return (new_x, y), None

initial_x = 0.0
# NOTE(review): with `y` starting at 0.0 the printed samples are all zeros
# (see the recorded output below) -- confirm this is intended.
initial_y = 0.0
key = jax.random.PRNGKey(0)
trace = hmm.simulate(key, ((initial_x, initial_y), None))
print(trace.get_sample()[...,"new_x"])


[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]


In [15]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['animation.ffmpeg_path'] = '/opt/homebrew/bin/ffmpeg'
from matplotlib.animation import FuncAnimation
from matplotlib.patches import Rectangle
import matplotlib
matplotlib.rcParams['savefig.pad_inches'] = 0
import genjax
from genjax.inference.smc import *
from genjax import gen, JAXGenerativeFunction
from dataclasses import dataclass
from PIL import Image, ImageSequence
import copy

#@dataclass
class LabeledCategorical(JAXGenerativeFunction, ExactDensity):
    """Categorical distribution whose outcomes are entries of `labels` instead of indices."""

    def sample(self, key, probs, labels, **kwargs):
        """Draw an index from `Categorical(probs)` and return the label stored at that index."""
        drawn_index = tfd.Categorical(probs=probs).sample(seed=key)
        return labels[drawn_index]

    def logpdf(self, v, probs, labels, **kwargs):
        """Return the log of the summed probability of every entry of `labels` equal to `v`."""
        matches_v = labels == v
        return jnp.log(jnp.sum(probs * matches_v))

class UniformCategorical(JAXGenerativeFunction, ExactDensity):
    """Uniform distribution over the entries of `labels`."""

    def sample(self, key, labels, **kwargs):
        """Draw an index uniformly at random and return the label stored at that index."""
        n_labels = len(labels)
        uniform_probs = jnp.ones(n_labels) / n_labels
        drawn_index = tfd.Categorical(probs=uniform_probs).sample(seed=key)
        return labels[drawn_index]

    def logpdf(self, v, labels, **kwargs):
        """Return log(1 / len(labels)); the value `v` itself is not inspected."""
        n_labels = len(labels)
        uniform_logp = jnp.log(jnp.ones(n_labels) / n_labels)
        # Every entry is identical, so the first one is the log-density.
        return uniform_logp[0]
    
# Index-valued categorical wrapped from TFP, and the label-valued variant used throughout the model.
cat = TFPDistribution(lambda p: tfd.Categorical(probs=p))
labcat = LabeledCategorical()

# This works great. 
key = jax.random.PRNGKey(2)
key, subkey = jax.random.split(key, 2)
# Side length of the square scene; `positions` below enumerates 0 .. scene_dim - 1.
scene_dim = 20
# Widths handed to `discrete_norm` / `discrete_truncnorm`:
#   σ_pos        - position noise in `step_latent_model`
#   σ_pos_target - position noise around a detected target pixel in `target_proposal`
#   σ_vel        - velocity noise when no wall is hit (also scales the proposal widths)
#   σ_collision  - velocity noise when `velocity_transform` detects a wall collision
σ_pos = 0.5
σ_pos_target = 0.1
σ_vel = 0.3
σ_collision = 1.0
# Largest speed; `velocities` spans -maxvel .. maxvel in unit steps.
maxvel = 2.0
positions = jnp.arange(0, scene_dim).astype(jnp.float32)
# Support of the occluder position: 16 .. scene_dim - 2.
occ_positions = jnp.arange(16, scene_dim-1).astype(jnp.float32)
velocities = jnp.arange(-maxvel, maxvel+1).astype(jnp.float32)
# Support of the pixel colours, used in `obs_variables` of `particle_filter_snmc_physics`.
pixelmap_support = jnp.array([[0.0, 1.0, 2.0]])
# Window argument passed to `discrete_truncnorm` in `velocity_transform`.
truncwin = 4
# The occluder covers `occluder_width` consecutive values of the pixel y coordinate
# across the whole scene, i.e. `occluder_area` pixels when fully inside the scene.
occluder_width = 3
occluder_area = occluder_width * scene_dim
# First and last position; `velocity_transform` picks the one nearest to the target.
edges = jnp.array([positions[0], positions[-1]])
# Normalized indicator over scene_dim**2 slots that is non-zero on the first `n` of them.
first_n_probs = lambda n: normalize((jnp.arange(scene_dim*scene_dim) < n).astype(int))



""" Generative Functions for Model """

@gen
def init_latent_model():
    """Prior over the initial latent state.

    Samples the occluder position, the target position (x, y) and the target
    velocity (vx, vy) independently with `labcat` over their respective
    supports and returns them as the tuple `(occ, x, y, vx, vy)`.
    """
    # NOTE(review): `unicat` is not defined in the visible part of the notebook;
    # presumably it returns uniform probabilities over the given support -- confirm.
    occ = labcat(unicat(occ_positions), occ_positions) @ "occ"
    x = labcat(unicat(positions), positions) @ "x"
    y = labcat(unicat(positions), positions) @ "y"
    vx = labcat(unicat(velocities), velocities) @ "vx"
    vy = labcat(unicat(velocities), velocities) @ "vy"
    return (occ, x, y, vx, vy)

def velocity_transform(pos, vel):
    """Compute next-step velocity probabilities for one axis, reflecting at the scene edges.

    Returns a tuple `(velprobs, collision_detect, nearest_edge)`:
      * `velprobs` - output of `discrete_truncnorm` over `velocities`, centred on
        `vel` when no wall is hit and on `-vel` when one is;
      * `collision_detect` - `1.0` when `pos + vel` lies strictly between the first
        and last position, `-1.0` when it reaches or passes either of them;
      * `nearest_edge` - the entry of `edges` closest to `pos`.
    """
    lower_edge, upper_edge = positions[0], positions[-1]
    next_pos = pos + vel
    stays_inside = (lower_edge < next_pos) * (next_pos < upper_edge)
    hits_wall = (next_pos >= upper_edge) + (next_pos <= lower_edge)
    # Exactly one of the two indicators is set, so this is +1.0 or -1.0.
    collision_detect = stays_inside * 1.0 + hits_wall * -1.0

    # Tight noise for free motion, wider noise right after a bounce.
    vel_sigma = (collision_detect == 1) * σ_vel + (collision_detect == -1) * σ_collision
    velprobs = discrete_truncnorm(collision_detect * vel, vel_sigma, truncwin, velocities)

    nearest_edge = edges[jnp.abs(pos - edges).argmin()]
    return velprobs, collision_detect, nearest_edge
    

@gen
def step_latent_model(occₚ, xₚ, yₚ, vxₚ, vyₚ):
    """Transition model: sample the next latent state from the previous one.

    Arguments are the previous occluder position, target position and target
    velocity. Returns the new `(occ, x, y, vx, vy)` with `occ`, `x` and `y`
    cast to float.
    """
    velprobs_x, collision_x, edge_x = velocity_transform(xₚ, vxₚ)
    velprobs_y, collision_y, edge_y = velocity_transform(yₚ, vyₚ)
    # Velocities are sampled first; the new positions below depend on them.
    vx = labcat(velprobs_x, velocities) @ "vx"
    vy = labcat(velprobs_y, velocities) @ "vy"
    occ = labcat(discrete_norm(occₚ, σ_pos, occ_positions), occ_positions) @ "occ"
    # Position mean: with collision == 1 this is xₚ + vx; with collision == -1 the
    # offset from the nearest edge changes sign, i.e. (edge - (xₚ - edge)) + vx.
    x = labcat(discrete_norm((vx + edge_x) + collision_x * (xₚ - edge_x) , σ_pos,
                      positions), positions) @ "x" 
    y = labcat(discrete_norm((vy + edge_y) + collision_y * (yₚ - edge_y) , σ_pos, 
                      positions), positions) @ "y" 
    return (occ.astype(float), x.astype(float), y.astype(float), vx, vy)

""" Observation Model """

# Pixel colour labels; `render_pixel` puts most of its mass on 0 for empty
# pixels, 1 for occluder pixels and 2 for the target pixel.
pixcolors = jnp.arange(0, 3).astype(jnp.float32)
# Observation-noise level read by `render_pixel` (temporarily overridden by the
# simulation helpers through `global flip_prob`).
flip_prob = .005
x_img, y_img = jnp.meshgrid(positions, positions)
# (scene_dim**2, 2) table of pixel coordinates: column 0 is `y_img.ravel()` and
# column 1 is `x_img.ravel()`. Built with one stack call instead of zipping
# individual device scalars through a Python list.
image_pixels = jnp.stack((y_img.ravel(), x_img.ravel()), axis=1)

def renderer_input(occ_input, x, y, vx, vy):
    """Broadcast one latent state to per-pixel inputs for the mapped `render_pixel`.

    Returns `(occ, target_x, target_y, pix_x, pix_y)`, each with one entry per
    pixel of the scene. The velocities `vx` and `vy` are accepted but not used.
    """
    n_pixels = scene_dim*scene_dim

    def per_pixel(value):
        # Same construction for every latent: a ones-vector scaled by the value.
        return jnp.ones(n_pixels) * value

    pixel_columns = (image_pixels[:, 0], image_pixels[:, 1])
    return (per_pixel(occ_input), per_pixel(x), per_pixel(y)) + pixel_columns

@gen
def render_pixel(occ, target_x, target_y, pix_x, pix_y):
    """Sample the colour of a single pixel given the occluder and target positions.

    The pixel is classified as occluded, target or empty (exactly one of the
    three indicators is 1), and its colour is drawn from `pixcolors` with
    probability `1 - flip_prob` on the matching colour and `flip_prob / 2` on
    each of the other two. Returns the sampled colour.
    """
    # The occluder covers pixel y coordinates in [occ, occ + occluder_width).
    is_occluded = (occ <= pix_y) * (pix_y < occ + occluder_width)
    # if occluded, target not visible. if not, it can be. 
    is_target = (target_x == pix_x) * (target_y == pix_y) * (1 - is_occluded)
    # if not occluder or target, its empty
    is_empty = (1 - is_occluded) * (1 - is_target)
    # i think you don't need a switch here. you make 3 arrays and add them.
    # Colour order is [empty, occluder, target], i.e. pixcolors 0, 1, 2.
    occluded_probs = is_occluded * jnp.array([flip_prob / 2, 1-flip_prob, flip_prob / 2])
    target_probs = is_target * jnp.array([flip_prob / 2, flip_prob / 2, 1-flip_prob])
    empty_probs = is_empty * jnp.array([1-flip_prob, flip_prob / 2, flip_prob / 2])
    color_probs = occluded_probs + target_probs + empty_probs
    pixel_color = labcat(color_probs, pixcolors) @ "pixcolor"
    return pixel_color

# `render_pixel` mapped over axis 0 of all five per-pixel inputs.
render_map = genjax.vmap_combinator(render_pixel, in_axes=(0, 0, 0, 0, 0))

@gen
def render_image(occ_input, x, y, vx, vy):
    """Observation model: render every pixel of the scene for one latent state.

    The latent state is expanded to per-pixel inputs with `renderer_input` and
    the mapped pixel renderer is traced at address "img". Returns its value.
    """
    occ, target_x, target_y, pix_x, pix_y = renderer_input(occ_input, x, y, vx, vy)
    image = render_map(occ, target_x, target_y, pix_x, pix_y) @ "img" 
    return image

""" Renderer """

""" A wrapper for generating an observation from the model """ 
def run_mental_physics_from_prior(num_steps, noise_level, key, *init_state):
    """Simulate a ground-truth trajectory with rendered observations and animate it.

    Args:
        num_steps: number of `step_latent_model` transitions after the initial state.
        noise_level: value temporarily assigned to the global `flip_prob` while
            the frames are rendered.
        key: JAX PRNG key.
        *init_state: optional; when given, its first element is used as the
            initial latent state tuple, e.g. `(2., 4., 5., 1., 1.)`. Otherwise
            the state is sampled from `init_latent_model`.

    Returns:
        The animation produced by `animate_pf_results` (ground truth and
        observations only, no particles).
    """
    global flip_prob
    orig_flip_prob = flip_prob
    flip_prob = noise_level
    try:
        # Independent keys for every random draw: sampling a latent state and
        # rendering its image must not share a key.
        key, init_key, render_key = jax.random.split(key, 3)
        if init_state == ():
            state = init_latent_model.simulate(init_key, ()).get_retval()
        else:
            state = init_state[0]
        gt_latents = [state]
        observation_traces = [render_image.simulate(render_key, state)]
        for _ in range(num_steps):
            key, step_key, render_key = jax.random.split(key, 3)
            state = step_latent_model.simulate(step_key, state).get_retval()
            gt_latents.append(state)
            observation_traces.append(render_image.simulate(render_key, state))
    finally:
        # Restore the global noise level even if simulation fails part-way.
        flip_prob = orig_flip_prob
    gt_latents = np.array(gt_latents)
    obs_frames = [obs.get_retval() for obs in observation_traces]
    ani = animate_pf_results([], [], [],
                             [gt_latents[:, 0], gt_latents[:, 1], gt_latents[:, 2]] ,
                             obs_frames, False)
    return ani

def _render_custom_trajectory(subkey, target_x, target_y, occ):
    """Render one noise-free frame per trajectory point and assemble the ground truth.

    `flip_prob` is set to 0 while rendering and restored afterwards, even if
    rendering raises. Returns `(obs_traces, groundtruth)` where `groundtruth`
    is the list `[occ, target_x, target_y, target_vx, target_vy]`.
    """
    global flip_prob
    orig_flip_prob = flip_prob
    flip_prob = 0.000
    try:
        keys = jax.random.split(subkey, len(target_x))

        def render(x, y, oc, k):
            return render_image.simulate(k, (oc, x, y, 0.0, 0.0))

        obs_traces = jax.vmap(render)(target_x, target_y, occ, keys)
    finally:
        flip_prob = orig_flip_prob
    # NOTE(review): `target_vx` is derived from the differences of `target_y`
    # and `target_vy` from those of `target_x`; this looks swapped but is kept
    # as in the original -- confirm the intended convention.
    target_vx = jnp.hstack((jnp.array([-1]), jnp.diff(target_y)))
    target_vy = jnp.hstack((jnp.array([1]), jnp.diff(target_x)))
    groundtruth = [occ, target_x, target_y, target_vx, target_vy]
    return obs_traces, groundtruth


def make_custom_physics(subkey):
    """Hand-built trajectory: target moves through x = positions[0:5] at y = 10, occluder at 17.

    Returns `(obs_traces, groundtruth)` as produced by `_render_custom_trajectory`.
    """
    target_x = positions[0:5]
    # Size the other tracks from the trajectory itself so all lengths agree
    # regardless of `scene_dim`.
    n_frames = len(target_x)
    target_y = jnp.ones(n_frames) * 10
    occ = jnp.ones(n_frames) * 17
    return _render_custom_trajectory(subkey, target_x, target_y, occ)


def make_custom_physics2(subkey):
    """Same as `make_custom_physics` but the target moves from positions[4] down to positions[0].

    Returns `(obs_traces, groundtruth)` as produced by `_render_custom_trajectory`.
    """
    target_x = jnp.arange(positions[4], positions[0]-1, -1)
    n_frames = len(target_x)
    target_y = jnp.ones(n_frames) * 10
    occ = jnp.ones(n_frames) * 17
    return _render_custom_trajectory(subkey, target_x, target_y, occ)

def get_pixeldata_from_obs_trace(obs_traces):
    """Collect, for every observation trace, the payload stored under "pixcolor".

    For each trace the choices are indexed with `[:, "pixcolor"]`, flattened,
    and element `[0][1]` of the flattened result is kept. Returns a list with
    one entry per trace.
    """
    return [
        trace.get_choices()[:, "pixcolor"].flatten()[0][1]
        for trace in obs_traces
    ]


""" INFERENCE """
    
def find_target(img):
    """Return a 0/1 mask marking the entries of `img` equal to 2 (the target colour)."""
    return jnp.where(img == 2, 1, 0)

# this is basically a leading edge detector 
def find_occluder(img_pixels, img):
    """Return the smallest y coordinate among the occluder-coloured pixels of `img`.

    Args:
        img_pixels: array of shape (n_pixels, 2); column 1 holds each pixel's y coordinate.
        img: array of n_pixels colours; the value 1 marks an occluder pixel.

    Returns:
        The minimum y coordinate over all pixels with colour 1, or 0.0 when the
        image contains no such pixel.

    A masked minimum is used instead of gathering a fixed number of non-zero
    entries, so the result stays correct when fewer than `occluder_area`
    occluder pixels are visible (occluder partly outside the scene, or a
    noise-flipped pixel) and when the occluder starts at y == 0. The
    coordinates are read from the `img_pixels` argument.
    """
    is_red = img == 1
    pix_y = img_pixels[:, 1]
    leading_edge = jnp.min(jnp.where(is_red, pix_y, jnp.inf))
    return jnp.where(jnp.any(is_red), leading_edge, 0.0)

def find_nth_occurrence(target_boolarray, n):
    """Return the index of the n-th (0-based) entry of `target_boolarray` equal to 1.

    The index list is padded to the length of the mask itself (with zeros, the
    default fill of `jnp.where(..., size=...)`), so the function works for any
    image size rather than only for `scene_dim * scene_dim` pixels. If `n` is
    at least the number of matches, the padding value 0 is returned.
    """
    match_indices = jnp.where(target_boolarray == 1, size=target_boolarray.size)[0]
    return match_indices[n]

@gen
def no_target_proposal_init(occ, target_arr, x, y, vx, vy):
    """Initial position proposal used when no target pixel was detected.

    The incoming `x` and `y` are ignored and overwritten by fresh samples.
    Returns `(x, y, 0.0, 0.0)`; the two trailing values fill the slots that
    the other proposal branches use for `σ_x` and `σ_y`.
    """
    x = labcat(unicat(positions), positions) @ "x"
    # NOTE(review): `upweight_zone` is not defined in the visible code; presumably it
    # concentrates mass on the rows hidden by the occluder -- confirm.
    y = labcat(upweight_zone(occ, positions, occluder_width, 1-flip_prob), positions) @ "y" 
    return x, y, 0.0, 0.0

def proposal_collision_detector(p, v):
    """Pick the proposal width for one axis from the predicted next position `p + v`.

    Returns `3 * σ_vel` when `p + v` reaches or passes the first or last
    position, and `σ_vel` when it lies strictly between them.
    """
    σ_pos_tight = σ_vel
    σ_pos_wide = 3 * σ_vel
    landing = p + v
    lo, hi = positions[0], positions[-1]
    at_or_past_wall = (landing <= lo) + (landing >= hi)
    strictly_inside = (lo < landing) * (landing < hi)
    return at_or_past_wall * σ_pos_wide + strictly_inside * σ_pos_tight

@gen
def dynamics_proposal(occ, target_arr, xₚ, yₚ, vxₚ, vyₚ):
    """Propose the next target position from the previous position and velocity.

    `occ` and `target_arr` are accepted but not used in this branch. Returns
    `(x, y, σ_x, σ_y)`, where the widths come from `proposal_collision_detector`.
    """
    σ_x = proposal_collision_detector(xₚ, vxₚ)
    σ_y = proposal_collision_detector(yₚ, vyₚ)
    x = labcat(discrete_norm(xₚ+vxₚ, σ_x, positions), positions) @ "x" 
    y = labcat(discrete_norm(yₚ+vyₚ, σ_y, positions), positions) @ "y" 
    return x, y, σ_x, σ_y


@gen
def target_proposal(occ, target_arr, xₚ, yₚ, vxₚ, vyₚ):
    """Propose the target position from the target-coloured pixels of the image.

    An index `n` is sampled over the first `sum(target_arr)` slots, the n-th
    marked pixel of `target_arr` is looked up in `image_pixels`, and `x`, `y`
    are sampled around that pixel with width `σ_pos_target`. Returns
    `(x, y, σ_x, σ_y)` with the widths from `proposal_collision_detector`.
    """
    n = labcat(first_n_probs(jnp.sum(target_arr)), jnp.arange(scene_dim*scene_dim)) @ "n"
    target_pix = image_pixels[find_nth_occurrence(target_arr, n)]
    x = labcat(discrete_norm(target_pix[0], σ_pos_target, positions), positions) @ "x"
    y = labcat(discrete_norm(target_pix[1], σ_pos_target, positions), positions) @ "y"
    # The widths depend only on the previous state, not on the detected pixel.
    σ_x = proposal_collision_detector(xₚ, vxₚ)
    σ_y = proposal_collision_detector(yₚ, vyₚ)
    return x, y, σ_x, σ_y

# Two-branch position proposals: the first call argument selects the branch
# (0 -> first function, 1 -> `target_proposal`), as used in `init_proposal`
# and `step_proposal`.
xy_proposal_init = genjax.switch_combinator(no_target_proposal_init, target_proposal)
xy_proposal = genjax.switch_combinator(dynamics_proposal, target_proposal)

@gen
def init_proposal(img):
    """Data-driven proposal for the initial latent state given the first image.

    The occluder is proposed around the position detected by `find_occluder`,
    the target position by `xy_proposal_init` (branch 1 when at least one
    target-coloured pixel exists, branch 0 otherwise), and the velocities via
    `unicat`. Returns `(occ, x, y, vx, vy)`.
    """
    occ = labcat(discrete_norm(find_occluder(image_pixels, img), .1, occ_positions), occ_positions) @ "occ"
    target_candidates = find_target(img)
    # The four zeros stand in for the previous position and velocity, which do not exist yet.
    x, y, σ_x, σ_y = xy_proposal_init((jnp.sum(target_candidates) > 0).astype(int), occ, target_candidates, 0, 0, 0, 0) @ "target"
    vx = labcat(unicat(velocities), velocities) @ "vx"
    vy = labcat(unicat(velocities), velocities) @ "vy"
    return (occ, x, y, vx, vy)

# its sampling N even though prop_use is 0. 

@gen
def step_proposal(occₚ, xₚ, yₚ, vxₚ, vyₚ, img):
    """Data-driven proposal for the next latent state given the previous state and the new image.

    The occluder is proposed around the position detected by `find_occluder`.
    The target position comes from `xy_proposal`: the image-based branch is
    used when `(1 - flip_prob) / num_targets` rounds to 1, otherwise the
    dynamics-based branch. Velocities are proposed around the displacement
    between the new and the previous position. Returns `(occ, x, y, vx, vy)`.
    """
    occ = labcat(discrete_norm(find_occluder(image_pixels, img), .1, occ_positions), occ_positions) @ "occ"    
    target_candidates = find_target(img)
    num_targets = jnp.sum(target_candidates)
    # Guard the division: with no target pixel the unguarded expression is
    # inf * 0 = NaN, which then gets cast to an integer branch index.
    safe_num_targets = jnp.maximum(num_targets, 1)
    use_target_prob = jnp.where(num_targets > 0, (1 - flip_prob) / safe_num_targets, 0.0)
    proposal_use = jnp.round(use_target_prob).astype(int)
    x, y, σ_x, σ_y = xy_proposal(proposal_use, occ, target_candidates, xₚ, yₚ, vxₚ, vyₚ) @ "target"
    vx = labcat(discrete_norm(x-xₚ, σ_x, velocities), velocities) @ "vx"
    vy = labcat(discrete_norm(y-yₚ, σ_y, velocities), velocities) @ "vy"
    return (occ, x, y, vx, vy)

def translate_proposal_cm_to_model(cm):
    """Re-address a proposal choice map so it matches the model's flat addresses.

    The proposal stores the position under ("target", "x") and ("target", "y");
    the returned choice map holds them at "x" and "y", next to "vx", "vy" and
    "occ", which are copied over unchanged.
    """
    model_choices = {
        "x": cm[("target", "x")],
        "y": cm[("target", "y")],
        "vx": cm["vx"],
        "vy": cm["vy"],
        "occ": cm["occ"],
    }
    return genjax.choice_map(model_choices)

def maybe_resample(key, log_weights, ess_threshold):
    """Resample particle indices when the effective sample size drops below a threshold.

    Args:
        key: JAX PRNG key.
        log_weights: 1-D array of unnormalized log importance weights.
        ess_threshold: resample only when the effective sample size is below this value.

    Returns:
        `(particle_inds, log_total_weight, log_normalized_weights)`:
        `particle_inds` are multinomially resampled indices when ESS is below
        the threshold and `arange(n)` otherwise; `log_total_weight` is the
        logsumexp of `log_weights`; `log_normalized_weights` are the input
        weights normalized in log space (not reset after resampling).

    Indices are drawn directly from the log-weights (as logits), so very
    negative log-weights no longer underflow to a 0/0 probability vector.
    """
    n_particles = len(log_weights)
    _, subkey = jax.random.split(key, 2)
    resampled_inds = jax.random.categorical(subkey, log_weights, shape=(n_particles,))
    log_total_weight = jax.nn.logsumexp(log_weights)
    log_normalized_weights = log_weights - log_total_weight
    # ESS = 1 / sum(w_i ** 2) for normalized weights w_i, computed in log space.
    log_ess = - jax.nn.logsumexp(2 * log_normalized_weights)
    ess = jnp.exp(log_ess)
    particle_inds = jnp.where(ess < ess_threshold, resampled_inds, jnp.arange(n_particles))
    return particle_inds, log_total_weight, log_normalized_weights


def run_single_step_smc_jax(key, obs_trace, proposal, model, obs_model, cm_translator,
                            prevstate_and_score, n_particles):
    """Run one propose / weight / resample step of the particle filter for all particles.

    Args:
        key: JAX PRNG key for this step.
        obs_trace: trace of the observation at this step; its return value is fed to
            the proposal and its choices constrain `obs_model`.
        proposal: generative function proposing the new latent state.
        model: latent model scored against the proposed choices.
        obs_model: observation model scored against the observed choices.
        cm_translator: maps a proposal choice map to the model's addresses.
        prevstate_and_score: 5-tuple from the previous step; only its second
            element (the resampled parent traces) is used here.
        n_particles: number of particles.

    Returns:
        `(model_traces, next_parent_state, log_total_weight, log_norm_weights, p_q_obs)`,
        where `p_q_obs` is the tuple `(model_map[1], proposal_map[1], obs_map[1])`.
    """
    last_step_traces, resampled_traces, log_total_weight, log_norm_weights, pqobs = prevstate_and_score
    # NOTE(review): the same `subkeys` are passed to the proposal, the model and the
    # observation model below, and `key` is split again further down -- check
    # whether independent keys are needed.
    subkeys = jax.random.split(key, n_particles)
    # the only things in proposal_map are resampled_traces, obs_trace. its coming out nan.
    # obs_trace is just a physics_obs at slice 1.
    # Proposal arguments: the parent particle's return value plus the observed image.
    proposal_map = jax.vmap(
        lambda k, i: proposal.propose(
            k,
            resampled_traces.slice(i).get_retval() + (obs_trace.get_retval(),)))(subkeys, jnp.arange(n_particles))
    # Score the proposed choices under the latent model, given the parent's return value.
    model_map = jax.vmap(
        lambda k, i: model.importance(
            k,
            cm_translator(proposal_map[0].slice(i).get_choices()),
            resampled_traces.slice(i).get_retval()))(subkeys, jnp.arange(n_particles))
    
    # Score the observed choices under the observation model, given each particle's new latents.
    obs_map = jax.vmap(lambda k, i: obs_model.importance(
        k,
        obs_trace.get_choices(),
        (model_map[0].slice(i)["occ"],
         model_map[0].slice(i)["x"],
         model_map[0].slice(i)["y"],
         model_map[0].slice(i)["vx"],
         model_map[0].slice(i)["vy"])))(subkeys, jnp.arange(n_particles))

# maybe resample isn't properly implemented here for ess != inf.
# proposal map on first step returns all nans
    p_q_obs = (model_map[1], proposal_map[1], obs_map[1])
    # Per-particle score: model + observation - proposal.
    scores = model_map[1] + obs_map[1] - proposal_map[1]
    key, subkey = jax.random.split(key, 2)
    # Resample only when the effective sample size falls below a quarter of the particle count.
    resampled_particles, log_total_weight, log_norm_weights = maybe_resample(subkey, scores, n_particles / 4.)
#    resampled_particles, log_total_weight, log_norm_weights = maybe_resample(subkey, scores, jnp.inf)
    # Traces selected as parents for the next step.
    next_parent_state = jax.vmap(lambda i: model_map[0].slice(i))(resampled_particles)
    return (model_map[0], next_parent_state, log_total_weight, log_norm_weights, p_q_obs)

# The state returned at each step used to be the resampled state, which is not necessarily
# what you want: the resampled particles should probably only serve as the parents of the next
# step (this is already how it is done in snmc). Reporting the resampled state is exactly why
# the particle distribution looked so uniform, so implement it as a separate
# "next parent state". Should be fine!

@gen
def initializing_func():
    """Generative function that makes no choices and returns an empty tuple.

    `run_particle_filter` simulates it once per particle to build the parent
    traces for the very first step.
    """
    return ()

def run_particle_filter(obs_traces, n_particles, num_steps, pfkey):
    """Run the particle filter over the first `num_steps` observation frames.

    Step 0 uses the initial proposal and model, step 1 is run separately, and
    steps 2 .. num_steps - 1 are unrolled with `jax.lax.scan`.

    Args:
        obs_traces: stacked observation traces, indexable with `.slice(i)`.
        n_particles: number of particles.
        num_steps: total number of frames to filter (frames 0 and 1 must exist).
        pfkey: JAX PRNG key.

    Returns:
        `(init_states_and_scores, first_step_states_and_scores, unrolled_pf)`;
        the first two are outputs of `run_single_step_smc_jax`, the last is the
        `(final_carry, stacked_outputs)` pair returned by the scan.
    """
    # One independent key per consumer; no key is used for two different draws.
    particle_key, init_key, pf_key_firststep, scan_key = jax.random.split(pfkey, 4)
    subkeys = jax.random.split(particle_key, n_particles)
    init_traces = jax.vmap(lambda k: initializing_func.simulate(k, ()))(subkeys)
    init_states_and_scores = run_single_step_smc_jax(init_key, obs_traces.slice(0),
                                                     init_proposal,
                                                     init_latent_model, 
                                                     render_image,
                                                     translate_proposal_cm_to_model,
                                                     (init_traces, init_traces, 0, 0, ()), n_particles)
    # seems like you have to run the first step outside the loop b/c its pytree is different than the init prop.
    first_step_states_and_scores = run_single_step_smc_jax(pf_key_firststep, obs_traces.slice(1),
                                                           step_proposal,
                                                           step_latent_model, 
                                                           render_image,
                                                           translate_proposal_cm_to_model,
                                                           init_states_and_scores, n_particles)
    # this is very easy just incorporate scores into the end of states (i.e. make it a tuple).
    # make maybe resample return scores.

    def particle_step(states, obs_and_key):
        """Scan body: advance the filter by one frame and emit the new state."""
        obs, step_key = obs_and_key
        newstates = run_single_step_smc_jax(step_key, obs, step_proposal, step_latent_model,
                                            render_image, translate_proposal_cm_to_model, states, n_particles)
        return newstates, newstates

    unrolled_keys = jax.random.split(scan_key, num_steps - 2)
    unrolled_pf = jax.lax.scan(particle_step, first_step_states_and_scores,
                               (obs_traces.slice(jnp.arange(2, num_steps)), unrolled_keys))
    return init_states_and_scores, first_step_states_and_scores, unrolled_pf


# xy is the lower left corner of the rectangle. so just add .5 to get particles to the middle of the rect.

def animate_state_inference_from_snmc(pf_results, groundtruth, physics_obs, particles_to_animate):
    """Animate the selected particles of an snmc particle-filter run.

    For every frame of `physics_obs`, the "x", "y" and "occ" choices of the
    particles whose index is in `particles_to_animate` are gathered together
    with that step's `log_weights`, and everything is handed to
    `animate_pf_results`. Returns the resulting animation.
    """
    particles_per_step = pf_results[1]
    resampler_per_step = pf_results[2]
    x_inferences, y_inferences, occ_inferences, p_scores = [], [], [], []
    per_address = (("x", x_inferences), ("y", y_inferences), ("occ", occ_inferences))

    for frame in range(len(physics_obs)):
        selected_choicemaps = [
            particle.choicemap
            for idx, particle in enumerate(particles_per_step[frame])
            if idx in particles_to_animate
        ]
        frame_scores = resampler_per_step[frame].log_weights
        # note you would normally index the supports b/c
        # choicemap is an assembly index. but
        for address, collected in per_address:
            collected.append(np.array([cm[address] for cm in selected_choicemaps]))
        p_scores.append(frame_scores)

    return animate_pf_results(x_inferences, y_inferences, occ_inferences,
                              groundtruth, physics_obs, True, p_scores)


# get particle scores at each step here too. 
def animate_from_jax_pf(init_step, first_step, unrolled_state, groundtruth, physics_obs):
    """Animate the output of `run_particle_filter`.

    Args:
        init_step: result of the initial filter step.
        first_step: result of the first transition step.
        unrolled_state: `(final_carry, stacked_outputs)` pair returned by the scan
            over the remaining steps.
        groundtruth: `(occ, x, y)` ground-truth sequences for `animate_pf_results`.
        physics_obs: sequence of observed frames.

    Returns:
        The animation produced by `animate_pf_results`.
    """
    init_state = init_step[0]
    first_step_state = first_step[0]
    num_steps = len(physics_obs)
    scanned_states = unrolled_state[1][0]

    def make_state_array(index):
        """Stack element `index` of the latent return value across all steps."""
        first_2 = [init_state.get_retval()[index], first_step_state.get_retval()[index]]
        # NOTE(review): this reads num_steps - 1 slices from the scan output; if the
        # scan covered only steps 2 .. num_steps - 1 that is one more than it
        # holds -- confirm the intended range.
        rest = [scanned_states.slice(step).get_retval()[index] for step in range(0, num_steps-1)]
        return np.array(first_2 + rest)

    x_inferences = make_state_array(1)
    y_inferences = make_state_array(2)
    occluder_x_inferences = make_state_array(0)
    # Read the scores from the `unrolled_state` argument (not a global `unrolled_pf`).
    # NOTE(review): if the steps are outputs of `run_single_step_smc_jax`, index -1
    # is the `p_q_obs` tuple rather than `log_norm_weights` (index 3) -- confirm
    # which one the animation should use.
    scores = [init_step[-1], first_step[-1]] + list(unrolled_state[1][-1])
    anim = animate_pf_results(x_inferences, y_inferences, occluder_x_inferences, groundtruth, physics_obs, True, scores)
    return anim


def animate_pf_results(x_inferences, y_inferences, occluder_x_inferences,
                       groundtruth, physics_obs, color_by_particle, *scores):
    """Build a two-panel matplotlib animation of ground truth, observations and particles.

    The left panel shows the ground-truth occluder and target, the right panel
    the observed pixel image; particle positions (if any) are scattered on both.

    Args:
        x_inferences, y_inferences: per-frame arrays of particle positions, or
            empty lists when there are no particles to draw.
        occluder_x_inferences: accepted but not used in the visible code.
        groundtruth: sequence unpacked as `(occ_gt, x_gt, y_gt)`, each indexable by frame.
        physics_obs: sequence of per-frame pixel-colour arrays.
        color_by_particle: if true, particles get individual colours from
            `my_tab20`; otherwise they are drawn black.
        *scores: optional; when present, its first element is a per-frame
            sequence of log-scores used to set the particle transparency.

    Returns:
        The `FuncAnimation` object.
    """
    # Pixel colours 0, 1, 2 map to white (empty), light gray (occluder), black (target).
    colors = ['w', 'lightgray', 'k']
    occ_gt, x_gt, y_gt = groundtruth    
    fig, ax = plt.subplots(1, 2, figsize=(12, 8))
    alphas = []
    if scores != ():
        scores = scores[0]
        for s in scores:
            if np.sum(np.isfinite(s)) != 0:
                # Cube root of the normalized weights, used as per-particle alpha.
                alphas.append(np.cbrt(normalize(np.exp(s))))
            else:
                # No finite score at this frame: make every particle fully transparent.
                alphas.append(np.zeros(len(s)))
                print('alphas')
                print(alphas)
    elif jnp.shape(x_inferences) != (0,):
        alphas = [np.ones(len(x_inferences[0])) for x in x_inferences]
    ax[0].set_xlim([0, scene_dim])
    ax[0].set_ylim([0, scene_dim])
    ax[0].set_aspect('equal', 'box')
    ax[1].set_xlim([0, scene_dim])
    ax[1].set_ylim([0, scene_dim])
    ax[1].set_aspect('equal', 'box')
    # Animation function. Want to eventually animate the occluder as well.
    gt_target = Rectangle((x_gt[0], y_gt[0]), 1, 1, facecolor=colors[2], edgecolor='none')
    gt_occ = Rectangle((0, occ_gt[0]), scene_dim, occluder_width, facecolor=colors[1], edgecolor='none')
    # One unit square per observed pixel, coloured by the first frame.
    obs_rectangles = []
    for pix_ind, pix in enumerate(physics_obs[0].astype(int)):
        color = colors[pix]
        rect_x, rect_y = image_pixels[pix_ind]
        rect = Rectangle((rect_x, rect_y), 1, 1, facecolor=color, edgecolor='none')
        obs_rectangles.append(rect)
    ax[0].add_patch(gt_occ)
    ax[0].add_patch(gt_target)
    for rect in obs_rectangles:
        ax[1].add_patch(rect)
# 9 , 13, 16 are the colors of the plot
    # NOTE(review): `num_particles` is not a parameter of this function; it is read
    # from the enclosing (notebook-global) scope -- confirm it is defined before any call.
    cpal = my_tab20(num_particles)
    print(len(cpal))
#    cpal = [sns.color_palette("tab20b")[i] for i in [9, 13, 16]]
    if jnp.shape(x_inferences) != (0,):
        print(x_inferences)
        if not color_by_particle:
            particle_color = [(0, 0, 0) for x in x_inferences[0]]
        else:
            particle_color = [cpal[i] for i, x in enumerate(x_inferences[0])]
        alpha = .75
        # Offset by .5 so each particle sits in the middle of its unit pixel.
        inferences_ax0 = ax[0].scatter(x_inferences[0] + .5, y_inferences[0] + .5, color=particle_color, alpha=alpha)
        inferences_ax1 = ax[1].scatter(x_inferences[0] + .5, y_inferences[0] + .5, color=particle_color, alpha=alpha)
 
    def init_func():
        """No-op initializer passed to `FuncAnimation`."""
        pass

    ax[0].get_xaxis().set_visible(False)
    ax[1].get_xaxis().set_visible(False)
    ax[0].get_yaxis().set_visible(False)
    ax[1].get_yaxis().set_visible(False)


    def update(frame):
        """Redraw observation pixels, ground truth and particles for `frame`."""
        print(frame)
        obs_frame = physics_obs[frame].astype(int)
        for pix, rect in zip(obs_frame, obs_rectangles):
            rect.set_facecolor(colors[pix])
        gt_occ.set_xy((0, occ_gt[frame]))
        gt_target.set_xy((x_gt[frame], y_gt[frame]))
        if jnp.shape(x_inferences) != (0,):
            inferences_ax0.set_offsets(np.column_stack((x_inferences[frame] + .5, y_inferences[frame] + .5)))
            inferences_ax1.set_offsets(np.column_stack((x_inferences[frame] + .5, y_inferences[frame] + .5)))
            inferences_ax0.set_alpha(alphas[frame])
            inferences_ax1.set_alpha(alphas[frame])

    animation = FuncAnimation(fig, update, frames=len(physics_obs), init_func=init_func, interval=250)
#    animation.save("/Users/nightcrawler/animations/animation.mp4", writer='ffmpeg')
    return animation

# simply make ordering of variables based on dependencies.
# think about ordering in the test suite. will be easier there. 


def particle_filter_snmc_physics(num_particles, physics_frames):
    """Describe the model, proposal and observation variables and run the snmc particle filter.

    Each variable is described by a dict with the keys "variable", "parents",
    "support" and "switch" (the observation variable additionally carries
    "latent_transform_fn"). The first `num_snmc_steps` frames of
    `physics_frames` are wrapped in 1-tuples and passed as observations.
    Returns whatever `run_snmc_particle_filter` returns.
    """

    def variable_spec(name, parents, support, switch=()):
        # One description dict; key order matches the hand-written form.
        return {"variable" : name,
                "parents" : list(parents),
                "support" : support,
                "switch" : list(switch)}

    # here you have to add a subfield for whether its sampled in a switch.
    # if it is, have to note whether its in the model or proposal.
    # or you could standardize the name of the variable to "sw" or something like that. "ss" for switch sample. the outcome will always just be a Q of the variable having some probs, and a P of the same variable having probs. the only issue is extracting those probs. test this in the test first, because its less complicated.

    model_variables = [
        variable_spec('occ', [], occ_positions),
        variable_spec('x', ['vx'], positions),
        variable_spec('y', ['vy'], positions),
        variable_spec('vx', [], velocities),
        variable_spec('vy', [], velocities),
    ]

    # x and y are sampled inside the "target" switch of both proposals.
    target_switch = (("init_proposal", "target"), ("step_proposal", "target"))
    proposal_variables = [
        variable_spec('occ', [], positions),
        variable_spec('x', [], positions, target_switch),
        variable_spec('y', [], positions, target_switch),
        variable_spec('vx', ['x'], velocities),
        variable_spec('vy', ['y'], velocities),
    ]

    pixcolor_spec = variable_spec('pixcolor', [], pixelmap_support)
    pixcolor_spec["latent_transform_fn"] = renderer_input
    obs_variables = [pixcolor_spec]

    variables = [model_variables, proposal_variables, obs_variables]

    # provide a list of tuples for observations. 
    observations = [(frame, ) for frame in physics_frames][0:num_snmc_steps]
#    pf_results = initialization_test(variables, init_latent_model, init_proposal, render_map,
 #                                    10, 10, observations)
    return run_snmc_particle_filter(variables,
                                    init_latent_model,
                                    step_latent_model,
                                    init_proposal,
                                    step_proposal,
                                    render_map,
                                    10, num_particles, observations, "analog")



def exact_filtering(key, observation_traces):
    """Compute exact filtering posteriors with a discrete HMM built from the model.

    The first `num_snmc_steps` observation traces are converted to choice maps,
    a labeled HMM is built from the initial, transition and observation models
    over the latent supports, and its filtering posteriors are computed.
    `key` is accepted but not used.

    Returns:
        `(posterior_sequence, obs_choicemaps, variables, supports)`.
    """
    # (name, parents, support) for every latent variable of the model.
    latent_specs = [
        ('occ', [], occ_positions),
        ('x', ['vx'], positions),
        ('y', ['vy'], positions),
        ('vx', [], velocities),
        ('vy', [], velocities),
    ]
    model_variables = [
        {"variable" : name, "parents" : parents, "support" : support, "switch" : []}
        for name, parents, support in latent_specs
    ]

    print("making choicemaps")

    obs_choicemaps = [observation_traces.slice(step).get_choices()
                      for step in range(num_snmc_steps)]
    variables = [spec['variable'] for spec in model_variables]
    supports = [spec['support'] for spec in model_variables]

    print("making hmm")

    hmm = genjax_discrete_hmm.labeled_hmm_from_gen_fns_and_domains(
        init_latent_model,
        step_latent_model,
        render_image,
        variables,
        supports)

    # Compute P(x_t, y_t | obs_0, ..., obs_t) for each `t`
    print("made hmm")
    # posterior_sequence is a list of genjax_discrete_hmm.DiscreteDistribution objects
    # where posterior_sequence[t] is the posterior distribution at time over the space of
    # choicemaps on ("x", "y") at time t.
    posterior_sequence = hmm.filtering_posteriors(obs_choicemaps)
    return posterior_sequence, obs_choicemaps, variables, supports


def get_exact_posterior_at_timepoint(tp, posterior_seq, variables, supports):
    """Evaluate the exact posterior at time `tp` on the full grid of latent values.

    Args:
        tp: time index into `posterior_seq`.
        posterior_seq: sequence of posterior distributions exposing `get_prob(choice_map)`.
        variables: latent variable names, in the same order as `supports`.
        supports: one array of possible values per variable.

    Returns:
        The array produced by nesting one `jax.vmap` per support around the
        probability lookup. NOTE(review): the nesting below is hard-wired to
        exactly five supports (seven positional arguments) -- confirm before
        changing the number of latent variables.
    """
    
    def get_p_fullstate(posterior_sequence, t, *state_tuple):
        """Posterior probability at time `t` of one joint assignment of all variables."""
        cm = genjax.choice_map({ v : s for (v, s) in zip(variables, state_tuple) })
        return posterior_sequence[t].get_prob(cm)

# We can use `vmap` to simultaneously access the posterior probabilities of multiple possible latent values.
# Here, `get_p_xs_ys` takes in a list of `x` values and a list of `y` values,
# and returns a matrix of probabilities, where `M[i, j]` is the posterior probability of
# `x[i]` and `y[j]` at time `t`.

    # takes a tuple tup and an integer n as input. It rotates the elements of the tuple to the left by n positions and returns the rotated tuple. If n is negative, it will rotate the elements to the right.
    def rotate_tuple(tup, n):
        n = n % len(tup)  # Ensure rotation within the length of the tuple
        return tup[n:] + tup[:n]
    
    def generate_mapfn(t, mapfn, in_ax):
        """Wrap `mapfn` in `jax.vmap` repeatedly, moving the mapped argument one slot left each time.

        Recursion stops once the mapped axis sits at argument index 2, i.e. at
        the first support (indices 0 and 1 are the posterior sequence and the
        time index, which are never mapped).
        """
        new_mapfn = jax.vmap(mapfn, in_axes=in_ax)
        if in_ax[2] == 0:
            return new_mapfn
        else:
            return generate_mapfn(t, new_mapfn, rotate_tuple(in_ax, 1))

    # Start by mapping over the last support; the recursion adds one vmap per remaining support.
    posterior_matrix_generator = generate_mapfn(tp, get_p_fullstate,
                                                (None, None, None, None, None, None, 0))

    posterior_matrix = posterior_matrix_generator(posterior_seq, tp, *supports)
    return posterior_matrix

def plot_exact_filtering(exact_filtering_results, v_to_plot):
    """Plot the exact filtering marginal of one latent variable at every step.

    Args:
        exact_filtering_results: 4-tuple
            `(posterior_sequence, _, variables, supports)`; the second item
            is ignored here.
        v_to_plot: indexable whose first element is the name of the variable
            to plot (it must appear in `variables`).

    Returns:
        `probs`, a list with one entry per step (`num_snmc_steps` entries);
        each entry is a list holding, for every support value of the chosen
        variable, the posterior mass summed over all other variables.

    Side effects:
        Draws a scatter plot (step on the x-axis, support index on the
        y-axis, marker area proportional to probability) and shows it.
    """
    posterior_sequence, _, m_variables, supports = exact_filtering_results
    plot_index = m_variables.index(v_to_plot[0])
    n_values = len(supports[plot_index])
    probs = []
    for t in range(num_snmc_steps):
        posterior_matrix = get_exact_posterior_at_timepoint(t, posterior_sequence, m_variables, supports)
        step_probs = []
        for i in range(n_values):
            # Pin the plotted variable to its i-th support value and take
            # every value of the other variables, then sum that slab.
            # A tuple index is used instead of `[*args]`, which only parses
            # on Python 3.11+.
            slab = tuple(i if ind == plot_index else slice(None)
                         for ind in range(len(m_variables)))
            step_probs.append(jnp.sum(posterior_matrix[slab]))
        probs.append(step_probs)

    fig, ax = plt.subplots(1, 1)
    # One column of markers per STEP. The x positions are taken from the
    # position in `probs`; sizing them from len(probs[0]) (the number of
    # support values) silently dropped steps whenever the support was
    # smaller than the number of steps.
    for step, p in enumerate(probs):
        ax.scatter(np.full(len(p), step), np.arange(len(p)),
                   s=np.array(p) * 100, c='gray', edgecolors='none')
    ax.set_xlabel('step')
    ax.set_ylabel('variable value')
    plt.show()
    return probs

def test_particle_dist(init_step, first_step, unrolled_pf):
    x_init = list(init_step[0].get_choices()["x"])
    x_step1 = list(first_step[0].get_choices()["x"])
    x_rest = [list(unrolled_pf[1][0].get_choices()["x"][i]) for i in range(num_snmc_steps-2)]
    rws = 4
    cols = 5
    fig, ax = plt.subplots(rws, cols)
    all_x = [x_init, x_step1] + x_rest
    pl_count = 0
    cpal = my_tab20(100)
    for r in range(rws):
        for c in range(cols):
            if pl_count == num_snmc_steps:
                break
            ax[r, c].hist(all_x[pl_count], color=cpal[pl_count], bins=20)
            pl_count += 1
    plt.show()            
    return all_x

def test_particle_dist_snmc(pf_results, variable):
    all_v = [[] for i in range(num_particles)]
    for i in range(num_particles):
        full_validation = particle_validation(pf_results, i)
        for s in range(num_snmc_steps):
            all_v[s].append(full_validation["state"][s][variable])
    rws = 3
    cols = 2
    fig, ax = plt.subplots(rws, cols)
    pl_count = 0
    cpal = my_tab20(100)
    for r in range(rws):
        for c in range(cols):
            if pl_count == num_snmc_steps:
                break
            ax[r, c].hist(all_v[pl_count], color=cpal[pl_count], bins=20)
            ax[r, c].set_title("Step " + str(pl_count))
            pl_count += 1
    plt.tight_layout()
    plt.show()            
    return all_v
    
def filter_validation_set(pf_results, variable, step, query_value):
    """Collect the validation records of particles matching a state value.

    Calls `particle_validation(pf_results, i)` for every particle index in
    `range(num_particles)` and keeps, in particle order, the records whose
    `["state"][step][variable]` equals `query_value`.
    """
    per_particle = (particle_validation(pf_results, idx) for idx in range(num_particles))
    return [record for record in per_particle
            if record["state"][step][variable] == query_value]

def scatter_particle_weights(init_step, first_step, unrolled_pf):
    """Print and scatter-plot one `["total"][step]` entry per particle.

    For every particle index in `range(num_particles)` this reads
    `particle_validation(pf_results, i)["total"][step]`, prints the collected
    list, and plots it against the particle index.

    NOTE(review): none of the three parameters (`init_step`, `first_step`,
    `unrolled_pf`) is read. The body instead uses `pf_results` and `step`,
    which are neither parameters nor locals, so they are looked up as
    module-level names when the function is called. `step` is not assigned
    anywhere in the surrounding code shown here, so the call looks likely to
    raise NameError -- confirm, and either take the data from the parameters
    or use `scatter_particle_weights_snmc(pf_results, step)`, whose body is
    identical.
    """
    weights = []
    for i in range(num_particles):
        # `pf_results` is NOT a parameter here; see the note in the docstring.
        full_validation = particle_validation(pf_results, i)
        # `step` is NOT a parameter here either.
        weights.append(full_validation["total"][step])
    print(weights)
    plt.scatter(range(len(weights)), weights)
    plt.show()


def scatter_particle_weights_snmc(pf_results, step):
    """Print and scatter-plot each particle's `["total"]` entry at `step`.

    Args:
        pf_results: particle-filter results, passed through to
            `particle_validation(pf_results, i)`.
        step: index into each particle's `["total"]` sequence.

    The collected list (one entry per index in `range(num_particles)`) is
    printed, then plotted against the particle index.
    """
    weights = [particle_validation(pf_results, idx)["total"][step]
               for idx in range(num_particles)]
    print(weights)
    particle_axis = range(len(weights))
    plt.scatter(particle_axis, weights)
    plt.show()

# reset the alphas to represent probability, make scatterplots of the weights.

# Notebook-level settings read by the helper functions defined earlier.
num_snmc_steps = 5
num_particles = 40
m_variables = ["occ", "x", "y", "vx", "vy"]
ds = []


def _simulate_and_filter(make_physics):
    """Build one stimulus with `make_physics(subkey)` and run the SNMC filter on it.

    Every intermediate is returned so the caller can bind it to the same
    notebook-level names the rest of the notebook may rely on:
    `(observations, latents, latents as ndarray, per-column latent tuples,
    choice map built from column 1, first num_snmc_steps frames, filter results)`.
    """
    obs, latents = make_physics(subkey)
    latents_arr = np.array(latents)
    # One tuple per column of the ground-truth array.
    latent_tuples = [tuple(latents_arr[:, col]) for col in range(latents_arr.shape[1])]
    step1_choicemap = genjax.choice_map(dict(zip(m_variables, latent_tuples[1])))
    frames = [obs.slice(t).get_retval() for t in range(num_snmc_steps)]
    results = particle_filter_snmc_physics(num_particles, frames)
    return obs, latents, latents_arr, latent_tuples, step1_choicemap, frames, results


for i in range(3):
    # NOTE(review): `subkey` is not re-split inside this loop, so both stimuli
    # and all three repetitions receive the same key -- confirm this is intended.
    (physics_obs, ground_truth_latents, gt, gt_tuples,
     gt_choicemap_step1, physics_frames, pf_results) = _simulate_and_filter(make_custom_physics)

    # The shared names are rebound to the second stimulus, exactly as before;
    # only the filter results are kept apart (pf_results vs pf_results2).
    (physics_obs, ground_truth_latents, gt, gt_tuples,
     gt_choicemap_step1, physics_frames, pf_results2) = _simulate_and_filter(make_custom_physics2)

    labels, ds_indices = direction_selectivity(pf_results, pf_results2, 2, 'x')
    ds.append(ds_indices)

# Collapse the repetitions (axis 0) with the median and the mean, keeping only
# the (label, value) pairs whose value is finite.
med_ds = np.median(ds, axis=0)
ds_med = [pair for pair in zip(labels, med_ds) if np.isfinite(pair[1])]
mean_ds = np.mean(ds, axis=0)
ds_mean = [pair for pair in zip(labels, mean_ds) if np.isfinite(pair[1])]
plt.hist(np.abs([value for _, value in ds_mean]))
plt.hist(np.abs([value for _, value in ds_med]))

def plot_physics_series(pf_results, gt_latents, physics_frames, particle_ids, frames_to_plot):
    """Render the inference animation and lay selected frames out side by side.

    Args:
        pf_results, particle_ids: passed through to
            `animate_state_inference_from_snmc`.
        gt_latents: only its first three entries (`[0:3]`) are used.
        physics_frames: only its first `num_snmc_steps` entries are used.
        frames_to_plot: collection of frame indices of the saved GIF to show;
            one subplot column is created per entry.

    Returns:
        The cropped RGB arrays of the selected frames, in GIF order.

    Side effects:
        Writes `physics.gif` and `physics_whole_run.png` to the current
        working directory.
    """
    anim = animate_state_inference_from_snmc(pf_results, gt_latents[0:3], physics_frames[0:num_snmc_steps], particle_ids)
    anim.save('physics.gif')

    frames = []
    # squeeze=False keeps the axes container 2-D even for a single column, so
    # requesting exactly one frame no longer fails on `ax[...]`.
    phys_fig, axes = plt.subplots(1, len(frames_to_plot), squeeze=False)
    panels = axes[0]

    # Open the GIF in a context manager so the file handle is released once
    # the frames have been copied out.
    with Image.open('physics.gif') as gif_anim:
        for i, frame in enumerate(ImageSequence.Iterator(gif_anim)):
            if i in frames_to_plot:
                # Fixed pixel crop of the rendered frame (rows 300-1249,
                # columns 1300-2199); np.array copies the pixel data, so the
                # result stays valid after the file is closed.
                im = np.array(frame.convert('RGB'))[300:1250, 1300:2200, :]
                panel = panels[len(frames)]
                panel.imshow(im)
                panel.axis('off')
                panel.set(frame_on=False)
                frames.append(im)

    plt.subplots_adjust(wspace=0.001, hspace=0.001)
    phys_fig.savefig('physics_whole_run.png', dpi=2000)
    return frames

def choices(tr):
    """Return the trace's choices for each name in `m_variables` as a dict.

    Args:
        tr: object exposing `get_choices()`.

    Returns:
        `{name: cm[name]}` for every `name` in the notebook-level
        `m_variables`, where `cm` is the `.value` attribute of
        `tr.get_choices()` when that attribute can be read, and the choice
        map itself otherwise.
    """
    cm = tr.get_choices()
    try:
        cm = cm.value
    except Exception:
        # Best-effort unwrap: fall back to indexing the choice map directly.
        # Only ordinary errors are absorbed -- a bare `except:` here would
        # also swallow KeyboardInterrupt and SystemExit.
        pass
    return {v: cm[v] for v in m_variables}


ImportError: cannot import name 'JAXGenerativeFunction' from 'genjax' (/Users/matuot/repos/genjax-docs/genjax/src/genjax/__init__.py)