In [None]:
import genjax
import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.animation as animation
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np

# Fixed PRNG key so every run of the notebook draws the same samples.
key = jax.random.PRNGKey(0)

# Plot styling: drop the top/right axis spines, and raise the size cap (in MB)
# for animations embedded as HTML in the notebook.
plt.rcParams.update(
    {
        "axes.spines.right": False,
        "axes.spines.top": False,
        "animation.embed_limit": 25,
    }
)

In [None]:
# Read the logo and turn it into a binary float mask.
im = mpimg.imread("../../../docs/assets/img/logo.png")

# A pixel is "dark" when the larger of its first two channels is below 0.9.
# The mask keeps the complement: 1.0 for non-dark pixels, 0.0 for dark ones.
im = im[:, :, :2].max(axis=2) < 0.9
im = (~im).astype(float)
height, width = im.shape

# Show the resulting mask.
plt.imshow(im, cmap="gray")
plt.show()

In [None]:
# Prior over particle positions: a row coordinate "x" ~ Uniform(0, height) and a
# column coordinate "y" ~ Uniform(0, width), each followed by a Gaussian draw
# (std 1.0) centred on it and recorded at "obs_x" / "obs_y".
@genjax.gen
def prior_model():
    row = genjax.uniform(0.0, float(height)) @ "x"
    col = genjax.uniform(0.0, float(width)) @ "y"
    _obs_row = genjax.normal(row, 1.0) @ "obs_x"
    _obs_col = genjax.normal(col, 1.0) @ "obs_y"
    return row, col


# Draw the initial particle population from the prior.
n_samples = 5000
jax_im = jnp.array(im.astype(float))  # mask as a JAX array; indexed by the likelihood
batched_prior_model = prior_model.repeat(n=n_samples)

key, subkey = jax.random.split(key)
tr = batched_prior_model.simulate(subkey, ())
# Pull the per-particle "x" (row) and "y" (column) choices out of the batched
# trace and pair them up column-wise as [x, y] positions.
xs_init = tr.get_choices()[..., "x"]
ys_init = tr.get_choices()[..., "y"]
zs_init = np.stack([xs_init, ys_init], axis=1)

In [None]:
# Random-walk proposal over particle positions, written as a genjax generative
# function: each coordinate is perturbed by Gaussian noise with std 0.5.
# Editing it is meant to change how the particles move while converging.
# NOTE(review): the tempered-SMC cell builds its own blackjax random walk and
# does not call `proposal`; confirm whether this is meant to drive the sampler.
@genjax.gen
def proposal(z):
    cur_x, cur_y = z
    new_x = genjax.normal(cur_x, 0.5) @ "x"
    new_y = genjax.normal(cur_y, 0.5) @ "y"
    # Inactive variants (each would replace the two draws above; they need
    # jnp.cos / jnp.sin rather than bare cos / sin):
    #   v2, rotated mean with theta = 0.3:
    #     mx = cur_x * jnp.cos(theta) - cur_y * jnp.sin(theta)
    #     my = cur_x * jnp.sin(theta) + cur_y * jnp.cos(theta)
    #     new_x = genjax.normal(mx, 0.5) @ "x"
    #     new_y = genjax.normal(my, 0.5) @ "y"
    #   v3, inward spiral: rotate as in v2 and additionally scale the mean by
    #     an inward coefficient (the original sketch used (x + y) / 2).
    return new_x, new_y

In [None]:
# Figure shared by the particle animations: dark theme, no axes, and limits that
# match the image so particle coordinates map directly onto pixels.
plt.style.use("dark_background")
# The HTML embed size cap (in MB) is an rcParam; assigning an attribute on the
# `animation` module has no effect.
plt.rcParams["animation.embed_limit"] = 25
fig, ax = plt.subplots()
fig.tight_layout()

ax.set_axis_off()
ax.set_xlim(0, width)
ax.set_ylim(0, height)

# Initial frame: prior samples plotted as (column, height - row) so the image is
# drawn upright; each marker has area 1000 / n_samples.
scat = ax.scatter(ys_init, height - xs_init, s=1000 * 1 / n_samples)

In [None]:
# Dry run of the particle animation on two dummy frames (the prior samples and a
# copy shifted by +1), before any inference has been run.
n_frames = 2
samples = jnp.array([zs_init, zs_init + 1])
# One weight vector per frame, shape (n_frames, n_samples): `animate` reads
# weights[i], and Collection.set_sizes needs a per-point array, not a scalar.
weights = np.ones((n_frames, n_samples)) / n_samples


def animate(i):
    """Update the scatter plot to frame ``i``.

    Particles are (row, column) pairs: the column goes on the horizontal axis and
    the row is flipped (``height - row``) so the picture is drawn upright. Marker
    areas are proportional to the particle weights of that frame.
    """
    scat.set_offsets(np.c_[samples[i, :, 1], height - samples[i, :, 0]])
    scat.set_sizes(1000 * weights[i])
    return (scat,)


# Build the animation (frames are rendered when it is displayed or saved).
ani = animation.FuncAnimation(
    fig, animate, repeat=True, frames=n_frames, blit=True, interval=100
)

In [None]:
import blackjax
import blackjax.smc.resampling as resampling


# Flat log-prior: every position gets log-density 0; the support (the image
# rectangle) is enforced by `log_likelihood` instead.
def prior_logpdf(z):
    """Return the constant log-prior ``0.0``; the position ``z`` is ignored."""
    del z  # unused: the prior is flat
    return 0.0


# Hard-constraint log-likelihood: 0 on allowed (non-zero) pixels of the mask and
# -inf outside the image or on zero pixels.
def log_likelihood(z, image=None):
    """Log-likelihood of a particle at position ``z`` under the image mask.

    Args:
        z: pair ``(x, y)``; ``x`` indexes rows and ``y`` indexes columns of
            ``image`` (matching ``prior_model``, which draws ``x`` over the image
            height and ``y`` over its width).
        image: 2-D array of pixel values. Defaults to the module-level ``jax_im``.

    Returns:
        ``0.0`` when ``(floor(x), floor(y))`` is a pixel inside ``image`` whose
        value is non-zero, ``-inf`` otherwise (outside the image, on a zero pixel,
        or for non-finite coordinates).
    """
    if image is None:
        image = jax_im
    x, y = z
    n_rows, n_cols = image.shape
    # Bounds are checked on the float coordinates: floor(x) lies in [0, n_rows)
    # exactly when x does, and NaN/inf compare False, so they count as outside.
    in_bounds = (x >= 0) & (x < n_rows) & (y >= 0) & (y < n_cols)
    # Clip so the pixel lookup is always a valid index; out-of-bounds positions
    # are masked by `in_bounds` below regardless of the pixel read.
    row = jnp.clip(jnp.floor(x).astype(jnp.int32), 0, n_rows - 1)
    col = jnp.clip(jnp.floor(y).astype(jnp.int32), 0, n_cols - 1)
    pixel = image[row, col]
    # Select instead of multiplying by -inf: -inf * 0 is NaN, not 0.
    return jnp.where((pixel != 0) & in_bounds, 0.0, -jnp.inf)


# Temperature schedule: 150 values spaced log-uniformly from 1e-3 up to 1.0, so
# the final step uses the untempered likelihood.
n_temperatures = 150
lambda_schedule = np.logspace(-3, 0, n_temperatures)

# MCMC move used inside each SMC step: a random walk with normal increments and a
# fixed scale on both coordinates.
# NOTE(review): this does not use the genjax `proposal` defined earlier, so edits
# to `proposal` do not change this sampler as written.
scale = 0.5  # Random-walk scale, applied to both (x, y) coordinates
normal = blackjax.mcmc.random_walk.normal(scale * jnp.ones((2,)))

# Additive-step random-walk kernel, its state initializer, and the extra
# parameters handed to the kernel (presumably passed as random_step=normal by
# tempered_smc — confirm against the installed blackjax version).
rw_kernel = blackjax.additive_step_random_walk.build_kernel()
rw_init = blackjax.additive_step_random_walk.init
rw_params = {"random_step": normal}

# Tempered SMC sampler. Positional order: log-prior, log-likelihood, MCMC kernel,
# MCMC init, MCMC parameters, resampling function — TODO confirm this matches the
# pinned blackjax `tempered_smc` signature. num_mcmc_steps=5 MCMC moves per SMC step.
tempered = blackjax.tempered_smc(
    prior_logpdf,
    log_likelihood,
    rw_kernel,
    rw_init,
    rw_params,
    resampling.systematic,
    num_mcmc_steps=5,
)

# Initial SMC state built from the prior samples.
initial_smc_state = tempered.init(zs_init)

In [None]:
# Drive an SMC kernel over a whole temperature schedule with `jax.lax.scan`.
def smc_inference_loop(loop_key, smc_kernel, init_state, schedule):
    """Run the tempered SMC algorithm.

    Args:
        loop_key: PRNG key; step ``i`` uses ``jax.random.fold_in(loop_key, i)``.
        smc_kernel: callable ``(key, state, temperature) -> (state, info)``.
        init_state: SMC state before the first step.
        schedule: array of temperatures, one per step.

    Returns:
        The state after every step, stacked along a new leading axis
        (the per-step ``info`` outputs are discarded).
    """

    def _step(carry, temperature):
        step_idx, current = carry
        step_key = jax.random.fold_in(loop_key, step_idx)
        updated, step_info = smc_kernel(step_key, current, temperature)
        return (step_idx + 1, updated), (updated, step_info)

    initial_carry = (0, init_state)
    _final_carry, history = jax.lax.scan(_step, initial_carry, schedule)
    states_per_step, _infos = history
    return states_per_step


# Run tempered SMC over the whole schedule; the result stacks one SMC state per
# temperature.
blackjax_samples = smc_inference_loop(
    loop_key=key,
    smc_kernel=tempered.step,
    init_state=initial_smc_state,
    schedule=lambda_schedule,
)

In [None]:
# Per-step particle weights and positions from the SMC run, indexed [step, particle, ...].
weights = np.array(blackjax_samples.weights)
samples = np.array(blackjax_samples.particles)


# Optional: show the current temperature. With blit=True, `animate` must also
# return this artist.
# temp = ax.text(0.9, 0.9, r'$\lambda$: 0', horizontalalignment='center', verticalalignment='center', transform=ax.transAxes, fontsize=15)


def animate(i):
    """Draw SMC step ``i``: particle positions, with marker areas set by weight.

    Particles are (row, column) pairs: the column goes on the horizontal axis and
    the row is flipped with ``height - row`` (the image height, not a hardcoded
    size) so the picture is drawn upright, matching the axis limits.
    """
    scat.set_offsets(np.c_[samples[i, :, 1], height - samples[i, :, 0]])
    scat.set_sizes(1000 * weights[i])
    # temp.set_text(r'$\lambda$: {:.1e}'.format(lambda_schedule[i]))
    return (scat,)


# One frame per SMC step actually recorded.
ani = animation.FuncAnimation(
    fig, animate, repeat=True, frames=len(samples), blit=True, interval=100
)

# To export as a GIF:
# writer = animation.PillowWriter(fps=20,
#                                 metadata=dict(artist='Me'),
#                                 bitrate=1800)
# ani.save('scatter.gif', writer=writer)