In [None]:
"""
Description:
    CHMC Implementation with AVF: FPI
    USE THE CORRECT ENVIRONMENT:  CHMC_FALL_2025
    YYYY-MM-DD
    
Author: John Gallagher
Created: 2025-09-28
Last Modified: 2025-10-09
Version: 1.0.0

"""
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
import time

jax.config.update("jax_enable_x64", True)

@jit
def gauss_ndimf_jax(x, precision_matrix = None, cov=None, dim = 2):
    """Unnormalised n-dimensional Gaussian density ``exp(-x^T P x)``.

    ``P`` is ``precision_matrix`` when given, otherwise ``inv(cov)`` when ``cov``
    is given, otherwise the identity of size ``len(x)``. The exponent has no
    factor of 1/2 and no normalising constant is applied, so the distribution's
    covariance is ``inv(2 P)``. The ``dim`` argument is ignored; the dimension
    always comes from ``len(x)``.

    Raises:
        MultipleMatrices: if both ``precision_matrix`` and ``cov`` are supplied.
    """
    class MultipleMatrices(Exception):
        """Both a precision matrix and a covariance matrix were supplied."""

    has_precision = precision_matrix is not None
    has_cov = cov is not None
    if has_precision and has_cov:
        raise MultipleMatrices("Please supply either a Precision Matrix or a Covariance Matrix")
    if not has_precision:
        # Fall back to the covariance when supplied, otherwise to the identity.
        precision_matrix = jnp.linalg.inv(cov) if has_cov else jnp.eye(len(x))
    quad_form = x @ precision_matrix @ x
    return jnp.exp(-quad_form)

@jit
def qex(qp):
    """Return the position block ``q``: the first ``dim`` entries of ``qp``.

    ``dim`` is the module-level global, read when the function is traced.
    """
    positions = qp[:dim]
    return positions
@jit
def pex(qp):
    """Return the momentum block ``p``: every entry of ``qp`` from index ``dim`` on.

    ``dim`` is the module-level global, read when the function is traced.
    """
    momenta = qp[dim:]
    return momenta

@jit
def draw_p(qp, key):
    """Keep the positions of ``qp`` and replace its momentum with a fresh draw.

    The momentum is a standard-normal vector of length ``dim`` (module global).
    That matches the kinetic energy in ``hamiltionian`` only when ``Mass_inv``
    is the identity.

    Returns:
        ``(new_qp, None)``, a ``(carry, output)`` pair in the form
        ``jax.lax.scan`` expects.
    """
    momentum = jax.random.normal(key, shape=(dim,))
    return jnp.concatenate([qex(qp), momentum]), None
@jit
def leapfrog(qp):
    """Integrate Hamilton's equations for ``T`` leapfrog steps of size ``tau``.

    Each step is a half step in q, a full step in p, then another half step in
    q. ``tau``, ``T``, ``gradH_p`` and ``gradH_q`` are module globals read at
    trace time.

    Returns:
        The concatenated ``(q, p)`` state after the last step.
    """
    def one_step(state, unused):
        q, p = qex(state), pex(state)
        q_mid = q + 0.5 * tau * gradH_p(q, p)
        p_next = p - tau * gradH_q(q_mid, p)
        q_next = q_mid + 0.5 * tau * gradH_p(q_mid, p_next)
        return jnp.concatenate([q_next, p_next]), unused

    final_state, _ = jax.lax.scan(one_step, qp, xs=None, length=T)
    return final_state

@jit
def midpointFPI(qp):
    """Advance ``qp`` by ``T`` implicit-midpoint steps of size ``tau``.

    Each step solves ``y = x0 + tau * J grad H((x0 + y) / 2)`` for ``y``.
    The solve applies Newton's method to ``F(y) = y - G(y)``. Iteration stops
    once ``||F(y)|| <= tol`` or after ``max_iter`` iterations. The step count
    ``T`` matches ``leapfrog``, so both integrators cover the same trajectory
    length.

    Reads the module globals ``tau``, ``T``, ``tol``, ``max_iter``, ``J_H`` and
    ``grad_xH``.

    Returns:
        The concatenated ``(q, p)`` state after the last step.
    """
    def midpoint_step(x0):
        """Solve one implicit-midpoint step starting from ``x0``."""
        def G(y):
            midpoint = 0.5 * (x0 + y)
            return x0 + tau * J_H(grad_xH(midpoint))

        def F(y):
            return y - G(y)

        def newton_step(y):
            jacF = jax.jacobian(F)(y)
            # Newton update from the current iterate y (not from x0).
            return y - jnp.linalg.solve(jacF, F(y))

        def cond(carry):
            i, y = carry
            err = jnp.linalg.norm(F(y))
            return (err > tol) & (i < max_iter)

        def body(carry):
            i, y = carry
            return i + 1, newton_step(y)

        _, y_out = jax.lax.while_loop(cond, body, (0, x0))
        return y_out

    def scan_step(state, _):
        return midpoint_step(state), None

    qp_final, _ = jax.lax.scan(scan_step, qp, xs=None, length=T)
    return qp_final

@jit
def accept(delta, key):
    """Metropolis accept/reject test.

    Returns a boolean scalar that is True with probability ``min(1, exp(delta))``.
    ``delta`` is therefore the log acceptance ratio. For HMC that ratio should
    be ``H(current) - H(proposal)``.
    """
    threshold = jnp.minimum(1., jnp.exp(delta))
    draw = jax.random.uniform(key, shape=())
    return draw <= threshold
@jit
def hmc_kernel(carry, key):
    """Run one HMC transition, shaped for ``jax.lax.scan``.

    The momentum of ``carry`` (a concatenated ``(q, p)`` state) is resampled.
    The result is integrated with the module-level ``jit_integrator`` and then
    passed through a Metropolis accept/reject test.

    Returns:
        ``(qp_out, qp_out)``: the new carry and the per-step output collected by
        scan. On rejection, ``qp_out`` is the pre-integration state (the old q
        with the fresh p).
    """
    # Use independent keys for the momentum draw and the accept/reject uniform.
    # Reusing one key would make the two draws dependent.
    key_momentum, key_accept = jax.random.split(key)
    qp0, _ = draw_p(carry, key_momentum)
    qp_star = jit_integrator(qp0)
    # Log Metropolis ratio: log[pi(qp*) / pi(qp0)] = H(qp0) - H(qp*).
    # accept() accepts with probability min(1, exp(first argument)).
    log_ratio = jit_H(qp0) - jit_H(qp_star)
    is_accepted = accept(log_ratio, key_accept)
    qp_out = jnp.where(is_accepted, qp_star, qp0)
    return qp_out, qp_out
@jit
def hmc_sampler(initial_sample, keys):
    """Run one ``hmc_kernel`` transition per key in ``keys``.

    The chain starts from ``initial_sample``. Returns the stacked states after
    each transition, one row per key; the initial state is not included.
    """
    final_state, chain = jax.lax.scan(hmc_kernel, initial_sample, keys)
    return chain


# Function handles into mechanics of HMC Sampler:
def hamiltionian(q, p):
    """Hamiltonian ``H(q, p) = 0.5 p^T M^{-1} p - log target(q)``.

    ``Mass_inv`` and ``target`` are module globals.
    """
    kinetic = 0.5 * (p @ Mass_inv @ p)
    potential = -jnp.log(target(q))
    return kinetic + potential
def xhamiltionian(qp):
    """Evaluate ``hamiltionian`` on a concatenated ``(q, p)`` state vector."""
    return hamiltionian(qex(qp), pex(qp))

def J_H(gH):
    """Apply the symplectic matrix ``J = [[0, I], [-I, 0]]`` to ``gH``.

    For ``gH = (dH/dq, dH/dp)`` this returns ``(dH/dp, -dH/dq)``, the
    right-hand side of Hamilton's equations. J is never formed explicitly.
    ``dim`` is the module global.
    """
    dH_dq, dH_dp = gH[:dim], gH[dim:]
    return jnp.concatenate([dH_dp, -dH_dq])

# def J_simplec(n):
#     zero_zero = jnp.zeros((n,n))
#     zero_one = jnp.eye((n))
#     one_zero = -jnp.eye((n))
#     one_one = zero_zero
#     return jnp.block([[zero_zero, zero_one],[one_zero, one_one]])
# I don't know how jit syntax works yet so I just directly did it here.
target = jit(gauss_ndimf_jax)
grad_target = jit(jax.grad(target))
gradH_p = jit(jax.grad(hamiltionian, argnums=1))
gradH_q = jit(jax.grad(hamiltionian, argnums=0))
jit_H = jit(xhamiltionian)
grad_xH = jax.jit(jax.grad(xhamiltionian))
# Integrator used by hmc_kernel. hmc_kernel reads it as a global when
# hmc_sampler is first traced, so it must be bound before the first call below.
jit_integrator = jit(leapfrog)

# Set parameters
key = jax.random.PRNGKey(1)
dim = 3
num_samples = 1
mainnum_samples = 10000
# Give each consumer its own key; reusing one key for several draws makes
# those draws statistically dependent.
key_init, key_start, key_main = jax.random.split(key, 3)
keys_start = jax.random.split(key_start, num_samples)
keys_main = jax.random.split(key_main, mainnum_samples)
qp_init = jax.random.normal(key_init, shape=(2*dim,))
Mass_inv = jnp.eye(dim)
tau = 0.2
T = 1  # number of integrator steps per trajectory
tol = 1e-4  # Newton residual tolerance used by midpointFPI
max_iter = 1000  # Newton iteration cap used by midpointFPI
# compile + first run. jit specialises on input shapes, so warm up with the
# same keys array as the main run; otherwise the main run would recompile.
start = time.time()
sample_LF = hmc_sampler(qp_init, keys_main).block_until_ready()
end = time.time()
print("1st run:", end-start)
# main run (already compiled)
start = time.time()
sample_LF = hmc_sampler(qp_init, keys_main).block_until_ready()
end = time.time()
print(f"{mainnum_samples} runs: {end - start:.2f} \n 1 run:  {(end-start)/mainnum_samples}")


In [None]:
# Let H = 0.5*(p.T@Mass_inv@p) + 0.5*(q.T@Sigma_inv@q)
# F(y) = y - x - tau*J@gradH(0.5*(y+x))
# J_F(y) = I - 0.5*tau*J@HessH(0.5*(y+x))
# With the state ordered (q, p): HessH = blockdiag(Sigma_inv, Mass_inv) and
# J = [[0, I], [-I, 0]], so J@HessH = [[0, Mass_inv], [-Sigma_inv, 0]].
# The target exp(-q.T@q) has no 1/2 in its exponent, so Sigma_inv = 2*I.
Sigma_inv = 2.0 * jnp.eye(dim)
zero_block = jnp.zeros((dim, dim))
J_HessH = jnp.block([[zero_block, Mass_inv], [-Sigma_inv, zero_block]])
J_F = jnp.eye(2*dim) - 0.5*tau*J_HessH

In [None]:
# Shape of the position block (first `dim` columns) of the leapfrog chain.
np.shape(sample_LF[:, :dim])

In [None]:
jit_integrator = jit(midpointFPI)
# hmc_sampler/hmc_kernel read `jit_integrator` as a global at trace time, and
# jit caches the traced program per input shape. Without clearing the caches,
# the calls below would silently reuse the leapfrog program compiled earlier.
jax.clear_caches()
# compile + first run, with the same input shapes as the main run
start = time.time()
sample_FPI = hmc_sampler(qp_init, keys_main).block_until_ready()
end = time.time()
print("1st run:", end-start)
# main run (already compiled); block so the timing includes the computation
start = time.time()
sample_FPI = hmc_sampler(qp_init, keys_main).block_until_ready()
end = time.time()
print(f"{mainnum_samples} runs: {end - start} \n 1 run:  {(end-start)/mainnum_samples}")


In [None]:
# Count distinct position vectors (first `dim` columns) in each chain. A
# rejected proposal keeps the previous q, so repeated rows indicate rejections.
unique_samples_LF = jnp.unique(sample_LF[:, :dim], axis=0)
num_unique_LF = len(unique_samples_LF)
print(f"Number of unique samples (Leapfrog): {num_unique_LF}")

unique_samples_FPI = jnp.unique(sample_FPI[:, :dim], axis=0)
num_unique_FPI = len(unique_samples_FPI)
print(f"Number of unique samples (Midpoint FPI): {num_unique_FPI}")

# For reference: the total chain length.
print(f"Total number of samples: {mainnum_samples}")

In [None]:
# (number of samples, 2*dim) for the FPI chain.
np.shape(sample_FPI)

In [None]:
# Alternative kept for reference: apply each integrator once to a batch of
# freshly drawn (q, p) states with vmap, instead of comparing full chains.
# jit_integrator = jit(leapfrog)
# inits, _ = jax.vmap(draw_p, in_axes=(None, 0))(qp_init, keys_main)
# sample_LF = jax.vmap(jit_integrator)(inits)
# jit_integrator = jit(midpointFPI)
# sample_FPI = jax.vmap(jit_integrator)(inits)

# Row-wise Euclidean distance between the FPI and leapfrog chains
# (full state rows: positions and momenta).
sample_diff = sample_FPI - sample_LF
norm_diff = jnp.linalg.norm(sample_diff, axis=1)

print("Mean norm difference:", norm_diff.mean())
print("Max norm difference:", norm_diff.max())
print("Min norm difference:", norm_diff.min())

In [None]:
fig, ax1 = plt.subplots()
# Joint histogram of the first two position coordinates of the FPI chain.
image = ax1.hist2d(np.array(sample_FPI[:,0]), np.array(sample_FPI[:,1]), bins=50, density=True)[3]
fig.colorbar(image, ax=ax1, label='Density')
ax1.set_xlabel('q[0]')
ax1.set_ylabel('q[1]')
ax1.set_title('FPI Accepted Samples: joint histogram of q[0], q[1]')
plt.show()

In [None]:
def gauss_f(x):
    """Unnormalised 1-D Gaussian ``exp(-x**2)``; works on scalars and NumPy arrays."""
    squared = x ** 2
    return np.exp(-squared)

fig, ax1 = plt.subplots()

# Histogram of the q[1] marginal of the FPI chain.
ax1.hist(np.array(sample_FPI[:,1]), bins=50, density=True, label='FPI samples (q[1])')
ax1.set_xlabel('x-Value')
ax1.set_ylabel('Density')
ax1.set_title('FPI Resulting Accepted Distribution vs Target Function')

# Target marginal: exp(-x^2) integrates to sqrt(pi), so dividing by sqrt(pi)
# gives a density comparable to the normalised histogram.
x_vals = np.linspace(-3, 3, 10000)
y_vals = gauss_f(x_vals) / np.sqrt(np.pi)
ax1.plot(x_vals, y_vals, 'r-', label='Target Function')
ax1.set_ylim(0, 0.6)
ax1.legend()

plt.show()

In [None]:
fig, ax1 = plt.subplots()

# Joint histogram of the first two position coordinates of the leapfrog chain.
image = ax1.hist2d(np.array(sample_LF[:,0]), np.array(sample_LF[:,1]), bins=50, density=True)[3]
fig.colorbar(image, ax=ax1, label='Density')
ax1.set_xlabel('q[0]')
ax1.set_ylabel('q[1]')
ax1.set_title('Leapfrog Accepted Samples: joint histogram of q[0], q[1]')

plt.show()

In [None]:
# Compare positions only (first `dim` columns); momenta are redrawn by
# hmc_kernel at every transition, so only q is comparable across chains.
sample_diff = sample_FPI[:, :dim] - sample_LF[:, :dim]

# norm of the difference along row
norm_diff = jnp.linalg.norm(sample_diff, axis=1)

# mean,max,min
print("Mean norm difference:", jnp.mean(norm_diff))
print("Max norm difference:", jnp.max(norm_diff))
print("Min norm difference:", jnp.min(norm_diff))

In [None]:
# Row-wise Euclidean norm: maps jnp.linalg.norm over the leading axis.
norm = jax.jit(jax.vmap(jnp.linalg.norm, in_axes=0))

In [None]:
# NOTE(review): `[:, :1000]` slices columns. With dim = 3 the sample arrays have
# 2*dim = 6 columns, so this keeps every column of every row. If the first 1000
# samples were intended, the slice should be `[:1000]`.
err = norm(jnp.subtract(sample_FPI[:, :1000], sample_LF[:, :1000]))

In [None]:
"""
Description:
    Sandbox2 for working through poster materials.
    USE THE CORRECT ENVIRONMENT:  CHMC_FALL_2025

Author: John Gallagher
Created: 2025-09-21
Last Modified: 2025-09-21
Version: 1.0.0

"""
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import jit
import time
def gauss_ndimf_jax(x, precision_matrix = None, cov=None, dim = 2):
    """n-Dim Gaussian target distribution, unnormalised: ``exp(-x^T P x)``.

    ``P`` is ``precision_matrix`` when given, otherwise ``inv(cov)`` when ``cov``
    is given, otherwise the identity of size ``len(x)``. The ``dim`` argument is
    ignored.

    Raises:
        MultipleMatrices: if both ``precision_matrix`` and ``cov`` are supplied.
    """
    # Defined here so the raise below does not depend on a name from another
    # cell. Previously the name was undefined, so the raise hit a NameError.
    class MultipleMatrices(Exception):
        """Both a precision matrix and a covariance matrix were supplied."""

    dim = len(x)
    # dealing with getting a precision matrix or cov matrix
    if precision_matrix is not None and cov is not None:
        raise MultipleMatrices("Please supply either a Precision Matrix or a Covariance Matrix")
    if precision_matrix is None and cov is not None:
        precision_matrix = jnp.linalg.inv(cov)
    if precision_matrix is None and cov is None:
        precision_matrix = jnp.eye(dim)
    return jnp.exp(-x.dot(precision_matrix.dot(x)))
def hamiltionian(q, p, Mass_inv, target):
    """Return ``0.5 p^T M^{-1} p - log(target)``.

    ``target`` is the already-evaluated target density (a value, not a
    function). ``q`` does not appear in the expression.
    """
    kinetic = 0.5 * (p @ Mass_inv @ p)
    return kinetic - jnp.log(target)

# Jitted target density and its gradient.
target = jit(gauss_ndimf_jax)
grad_target = jax.grad(target)
# Partial derivatives of hamiltionian(q, p, Mass_inv, target_value).
# NOTE(review): hamiltionian receives the target *value* and never uses q, so
# grad_H_q is identically zero. The potential-energy gradient has to be taken
# through target_func(q) itself.
grad_H_q, grad_H_p = (jax.grad(hamiltionian, argnums=n) for n in (0, 1))


def leapfrog(state, tau, T, Mass_inv, target_func):
    """Symplectic integrator: Leapfrog.

    Integrates Hamilton's equations for
    ``H(q, p) = 0.5 p^T M^{-1} p - log target_func(q)`` over total time ``T``
    with step ``tau``, i.e. ``round(T / tau)`` steps. Each step is a half step
    in q, a full step in p, then a half step in q.

    Args:
        state: ``[q, p]`` starting position and momentum.
        tau: step size (Python number; used to compute the step count).
        T: total integration time (Python number).
        Mass_inv: inverse mass matrix.
        target_func: callable returning the unnormalised target density at q.

    Returns:
        ``[q, p]`` after the final step.
    """
    # round() instead of int(): T / tau can land just below an integer
    # (e.g. 0.3 / 0.1 == 2.9999999999999996), which int() would truncate.
    N = int(round(T / tau))
    q, p = state

    def dH_dp(q, p):
        return jax.grad(hamiltionian, argnums=1)(q, p, Mass_inv, target_func(q))

    def dH_dq(q, p):
        # Differentiate through target_func(q). hamiltionian only receives the
        # target value, so differentiating it w.r.t. its q argument gives zeros.
        return jax.grad(lambda qq: hamiltionian(qq, p, Mass_inv, target_func(qq)))(q)

    def step(carry, _):
        q, p = carry
        q = q + 0.5*tau*dH_dp(q, p)  # q_half
        p = p - tau*dH_dq(q, p)  # p_full
        q = q + 0.5*tau*dH_dp(q, p)  # q_full
        return [q, p], None

    (q, p), _ = jax.lax.scan(step, [q, p], xs=None, length=N)
    return [q, p]


def hmc_kernel_jax(carry, key, tau, T, Mass_inv, target_func):
    """One HMC transition, shaped for ``jax.lax.scan``.

    Not wrapped in ``jit``. ``target_func`` is a function, and ``T / tau`` must
    be concrete to fix the step count, so neither can be a traced jit argument.
    The kernel is traced and compiled as part of the enclosing
    ``jax.lax.scan`` instead.

    Args:
        carry: current position ``q``.
        key: PRNG key for this transition.
        tau, T, Mass_inv, target_func: as in ``leapfrog``.

    Returns:
        ``(q_next, q_next)``: the new carry and the per-step output that scan
        stacks.
    """
    q_current = carry
    dim = q_current.shape[0]
    key_draw_p, key_accept = jax.random.split(key)

    # Draw momentum
    p_current = jax.random.normal(key_draw_p, shape=(dim,))

    # Integrate
    q_star, p_star = leapfrog([q_current, p_current], tau, T, Mass_inv, target_func)

    # Calculate change in Hamiltonian
    delta_H = (hamiltionian(q_star, p_star, Mass_inv, target_func(q_star))
               - hamiltionian(q_current, p_current, Mass_inv, target_func(q_current)))

    # Acceptance step: accept with probability min(1, exp(-delta_H)).
    alpha = jnp.minimum(1., jnp.exp(-delta_H))
    u = jax.random.uniform(key_accept, shape=())
    is_accepted = u <= alpha

    q_next = jnp.where(is_accepted, q_star, q_current)
    return q_next, q_next


def sample_jax(initial_sample, tau, T, num_samps, key, Mass_inv, target_func):
    """
    Perform Metropolis-Hastings (HMC) sampling using JAX.

    Args:
        initial_sample: starting position, shape ``(dim,)``.
        tau, T, Mass_inv, target_func: as in ``leapfrog``.
        num_samps: number of HMC transitions.
        key: PRNG key; split into one key per transition.

    Returns:
    jnp.ndarray: array of shape ``(num_samps + 1, dim)``: the initial sample
    followed by the state after each transition.
    """
    keys = jax.random.split(key, num_samps)

    # lax.scan requires a fixed-shape carry, so the chain is collected from
    # the per-step outputs rather than by growing an array inside the carry.
    _, chain = jax.lax.scan(
        lambda carry, k: hmc_kernel_jax(carry, k, tau, T, Mass_inv, target_func),
        initial_sample,
        xs=keys,
    )
    return jnp.vstack([initial_sample.reshape(1, -1), chain])