In [None]:
"""
Description:
    CHMC Implementation with AVF: FPI
    USE THE CORRECT ENVIRONMENT:  CHMC_FALL_2025
    YYYY-MM-DD

Author: John Gallagher
Created: 2025-09-28
Last Modified: 2025-10-09
Version: 1.0.0

"""

import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
import time

jax.config.update("jax_enable_x64", True)


@jit
def gauss_ndimf_jax(x, precision_matrix=None, cov=None, dim=2):
    """Quadratic potential of a zero-mean n-dimensional Gaussian.

    Returns ``0.5 * x^T P x`` where ``P`` is the precision matrix. This is the
    negative log-density without its normalisation constant; it is NOT the
    density itself (no ``exp`` is applied).

    Args:
        x: 1-D point at which to evaluate the potential.
        precision_matrix: optional precision matrix ``P``.
        cov: optional covariance matrix; inverted to obtain ``P``.
        dim: ignored; kept for call compatibility. The dimension is always
            taken from ``len(x)``.

    Returns:
        Scalar ``0.5 * x @ P @ x``. ``P`` is the identity when neither matrix
        is supplied.

    Raises:
        ValueError: if both ``precision_matrix`` and ``cov`` are supplied.
    """
    if precision_matrix is not None and cov is not None:
        raise ValueError(
            "Please supply either a Precision Matrix or a Covariance Matrix"
        )

    dim = len(x)
    if precision_matrix is None:
        if cov is None:
            precision_matrix = jnp.eye(dim)
        else:
            precision_matrix = jnp.linalg.inv(cov)
    return 0.5 * (x @ precision_matrix @ x)


def qex(qp):
    """
    Return the position part ``q`` of a stacked ``[q, p]`` state vector,
    i.e. the first ``len(qp) // 2`` entries.
    """
    half = len(qp) // 2
    position = qp[:half]
    return position


def pex(qp):
    """
    Return the momentum part ``p`` of a stacked ``[q, p]`` state vector,
    i.e. everything from index ``len(qp) // 2`` onward.
    """
    half = len(qp) // 2
    momentum = qp[half:]
    return momentum


def J_sym(vec):
    """
    Apply the symplectic matrix J = [[0, I], [-I, 0]] to ``vec``.

    The vector is split at ``len(vec) // 2`` into a head and a tail; the
    result is ``[tail, -head]``.
    """
    half = len(vec) // 2
    head = vec[:half]
    tail = vec[half:]
    return jnp.concatenate([tail, -head])


def qJ_sym(vec):
    """
    Position-only half of the symplectic map.

    Splits ``vec`` at ``len(vec) // 2`` and returns ``[tail, 0]``: the second
    half of ``vec`` followed by ``len(vec) // 2`` zeros.
    """
    half = len(vec) // 2
    tail = vec[half:]
    padding = jnp.zeros(half)
    return jnp.concatenate([tail, padding])


def pJ_sym(vec):
    """
    Momentum-only half of the symplectic map.

    Splits ``vec`` at ``len(vec) // 2`` and returns ``[0, -head]``:
    ``len(vec) // 2`` zeros followed by the negated first half of ``vec``.
    """
    half = len(vec) // 2
    padding = jnp.zeros(half)
    negated_head = -vec[:half]
    return jnp.concatenate([padding, negated_head])


def draw_p(qp, key):
    """Resample the momentum of a state vector from a standard normal.

    Args:
        qp: stacked ``[q, p]`` state vector; only ``q`` is kept.
        key: JAX PRNG key used for the momentum draw.

    Returns:
        ``(new_qp, None)`` where ``new_qp`` is ``q`` followed by a fresh
        standard-normal momentum of the same length as ``q``.
    """
    q = qex(qp)
    # Size the momentum from q itself rather than from the module-level `dim`,
    # which is rebound in later cells and may not match the state passed in.
    p = jax.random.normal(key, shape=(len(q),))
    return jnp.concatenate([q, p]), None


def gen_leapfrog(gradH, tau, N):
    """
    Build an integrator that advances a ``qp`` state by ``N`` steps of size ``tau``.

    Each step applies, in order: a half-step using ``qJ_sym(gradH(.))``, a full
    step using ``pJ_sym(gradH(.))``, and a second half-step using
    ``qJ_sym(gradH(.))``. The steps are chained with ``jax.lax.scan``.
    """
    half_tau = 0.5 * tau

    def leapfrog(qp):
        """Integrate ``qp`` for ``N`` steps and return the final state."""

        def lf_step(state, unused_x):
            after_first_half = state + half_tau * qJ_sym(gradH(state))
            after_full = after_first_half + tau * pJ_sym(gradH(after_first_half))
            after_second_half = after_full + half_tau * qJ_sym(gradH(after_full))
            return after_second_half, unused_x

        final_state, _ = jax.lax.scan(lf_step, qp, xs=None, length=N)
        return final_state

    return leapfrog


def midpointFPI(qp, tol=1e-3, max_iter=10, solve=jnp.linalg.solve):
    """Implicit-midpoint step solved by Newton iteration.

    Finds ``y`` such that

        y = x0 + tau * J_sym(gradH(0.5 * (x0 + y)))

    with ``x0 = qp``. Reads ``tau``, ``gradH`` and ``J_sym`` from module scope,
    so they must be defined before this function is traced.

    Args:
        qp: current state vector ``x0``; also the initial Newton guess.
        tol: stop once ``||F(y)|| <= tol``.
        max_iter: maximum number of Newton updates.
        solve: linear solver called as ``solve(jacobian, residual)``.

    Returns:
        The last Newton iterate.
    """
    x0 = qp

    def G(y):
        """Right-hand side of the midpoint rule evaluated at candidate ``y``."""
        midpoint = 0.5 * (x0 + y)
        return x0 + tau * J_sym(gradH(midpoint))

    def F(y):
        """Residual whose root is the midpoint-rule update."""
        return y - G(y)

    jacF = jax.jacobian(F)

    def newton_step(y):
        # Update from the current iterate y. Subtracting the correction from
        # x0 instead only coincides with Newton's method on the first
        # iteration, so it cannot converge for a non-affine residual.
        return y - solve(jacF(y), F(y))

    def cond(carry):
        i, y = carry
        err = jnp.linalg.norm(F(y))
        return (err > tol) & (i < max_iter)

    def body_step(carry):
        i, y = carry
        return [i + 1, newton_step(y)]

    _, qp_out = jax.lax.while_loop(cond, body_step, [0, qp])
    return qp_out


def accept(delta, key):
    """
    Accept/reject test: draw ``u ~ Uniform[0, 1)`` from ``key`` and return
    whether ``u <= min(1, exp(delta))``.
    """
    uniform_draw = jax.random.uniform(key, shape=())
    acceptance_prob = jnp.minimum(1.0, jnp.exp(delta))
    return uniform_draw <= acceptance_prob


def hmc_kernel(carry_in, key):
    """One HMC transition, shaped as a ``jax.lax.scan`` body.

    Uses the module-level ``jit_integrator`` and ``jit_H``.

    Args:
        carry_in: ``[qp, deltaH, accepted]``; only ``qp`` is read.
        key: PRNG key for this transition.

    Returns:
        ``(carry_out, carry_out)`` with ``carry_out = [qp_out, deltaH,
        is_accepted]``. ``qp_out`` is the proposal when accepted, otherwise
        the pre-integration state (with its freshly drawn momentum).
    """
    qp_prev, _, _ = carry_in
    # Independent sub-keys: reusing one key for both draws would make the
    # accept uniform a deterministic function of the sampled momentum.
    key_momentum, key_accept = jax.random.split(key)
    qp0, _ = draw_p(qp_prev, key_momentum)
    qp_star = jit_integrator(qp0)
    deltaH = jit_H(qp0) - jit_H(qp_star)  # -(final - init) = init - final
    is_accepted = accept(deltaH, key_accept)
    qp_out = jnp.where(is_accepted, qp_star, qp0)
    carry_out = [qp_out, deltaH, is_accepted]
    return carry_out, carry_out


def hmc_sampler(initial_sample, keys):
    """
    Run ``hmc_kernel`` once per entry of ``keys`` via ``jax.lax.scan``,
    starting from ``initial_sample``, and return the stacked per-step outputs.
    """
    scan_result = jax.lax.scan(hmc_kernel, initial_sample, xs=keys)
    stacked_outputs = scan_result[1]
    return stacked_outputs


def hamiltionian(qp):
    """Hamiltonian ``H(q, p) = 0.5 * p^T M^{-1} p + U(q)``.

    Reads ``Mass_inv`` and ``target`` from module scope.

    ``target`` is bound to ``gauss_ndimf_jax``, which returns the quadratic
    form ``0.5 * q^T P q``. That value already is the potential ``U(q)`` (the
    negative log-density up to a constant), so it is added directly. Taking
    ``-log`` of it would define an improper, non-Gaussian target.
    """
    q, p = qex(qp), pex(qp)
    kinetic = 0.5 * (p @ Mass_inv @ p)
    potential = target(q)
    return kinetic + potential


def J_H(gH):
    """Same operation as the symplectic Jacobian: ``[gH_tail, -gH_head]``.

    The split point is ``len(gH) // 2``. It is taken from the argument rather
    than from the module-level ``dim``, which is rebound in later cells and
    would silently produce a wrong split when stale.
    """
    half = len(gH) // 2
    return jnp.concatenate([gH[half:], -gH[:half]])


# Jitted building blocks used by hamiltionian / hmc_kernel through module scope.
target = jit(gauss_ndimf_jax)
grad_target = jit(jax.grad(target))
jit_H = jit(hamiltionian)
gradH = jax.jit(jax.grad(hamiltionian))

# Set parameters
key = jax.random.PRNGKey(1)
dim = 2
initnum_samples = 1
mainnum_samples = 10000
# One sub-key per purpose. Drawing the initial state and both key batches from
# the same `key` reuses randomness, which JAX's PRNG design forbids.
key_init, key_start, key_main = jax.random.split(key, 3)
keys_start = jax.random.split(key_start, initnum_samples)
keys_main = jax.random.split(key_main, mainnum_samples)
qp_init = jax.random.normal(key_init, shape=(2 * dim,))

# Structure of carry
# init_sample: [Array: sample, float: deltaH, bool: Accepted]
init_sample = [qp_init, 1, False]
Mass_inv = jnp.eye(dim)
tau = 0.2
T = 1
tol = 1e-4
max_iter = 100

# jit_integrator = jax.jit(leapfrog)
# Bind tol / max_iter explicitly; otherwise midpointFPI silently runs with its
# own keyword defaults and the values set above have no effect.
jit_integrator = jit(partial(midpointFPI, tol=tol, max_iter=max_iter))

# compile
# NOTE: keys_start and keys_main have different lengths, so the main run below
# is traced and compiled again; this first call only warms up shared pieces.
start = time.time()
jhmc_sampler = jit(hmc_sampler)
# block_until_ready: JAX dispatches asynchronously, so without it the timer
# would stop before the computation has finished.
sample_FPI = jax.block_until_ready(jhmc_sampler(init_sample, keys_start))
end = time.time()
print("1st run:", end - start)
# main run
start = time.time()
sample_FPI = jax.block_until_ready(jhmc_sampler(init_sample, keys_main))
end = time.time()
print(
    f"{mainnum_samples} runs: {end - start:.2f} \n 1 run:  {(end-start)/mainnum_samples}"
)

In [None]:
# Inspect the shape of the key batch produced by jax.random.split for the first run.
keys_start.shape

In [None]:
def gen_hmc_kernel(H, tau, N):
    """Build an HMC transition kernel that integrates with leapfrog.

    Args:
        H: Hamiltonian, a scalar function of the stacked ``qp`` state.
        tau: leapfrog step size.
        N: number of leapfrog steps per proposal.

    Returns:
        ``hmc_kernel(carry_in, key)``, usable as a ``jax.lax.scan`` body. It
        returns ``(carry_out, carry_out)`` with
        ``carry_out = [qp_out, deltaH, is_accepted]``.
    """
    gradH = jax.grad(H)
    integrator = gen_leapfrog(gradH, tau, N)

    def hmc_kernel(carry_in, key):
        qp_prev, _, _ = carry_in
        # Separate sub-keys so the accept uniform is independent of the
        # momentum draw.
        key_momentum, key_accept = jax.random.split(key)
        qp0, _ = draw_p(qp_prev, key_momentum)
        qp_star = integrator(qp0)
        deltaH = H(qp0) - H(qp_star)  # -(final - init) = init - final
        is_accepted = accept(deltaH, key_accept)
        qp_out = jnp.where(is_accepted, qp_star, qp0)
        carry_out = [qp_out, deltaH, is_accepted]
        return carry_out, carry_out

    return hmc_kernel
def hmc_sampler(initial_sample, keys, H=None, tau=None, T=None):
    """Run an HMC chain with one transition per entry of ``keys``.

    This definition replaces the earlier two-argument ``hmc_sampler``. To keep
    two-argument calls working, ``H``, ``tau`` and ``T`` are optional: when all
    three are omitted the module-level ``hmc_kernel`` is used.

    Args:
        initial_sample: initial carry ``[qp, deltaH, accepted]``.
        keys: batch of PRNG keys, one per transition.
        H: Hamiltonian; when given, a leapfrog kernel is built from it.
        tau: leapfrog step size (required with ``H``).
        T: total integration time; the step count is ``ceil(T / tau)``
            (required with ``H``).

    Returns:
        Stacked per-step outputs ``[samples, deltaH, accepted]``.

    Raises:
        ValueError: if only some of ``H``, ``tau``, ``T`` are supplied.
    """
    supplied = [arg is not None for arg in (H, tau, T)]
    if not any(supplied):
        kernel = hmc_kernel
    elif not all(supplied):
        raise ValueError("H, tau and T must be supplied together")
    else:
        # scan needs a concrete Python int for `length`; np.ceil keeps this
        # out of JAX tracing.
        N = int(np.ceil(T / tau))
        kernel = gen_hmc_kernel(H, tau, N)
    _, samples = jax.lax.scan(kernel, initial_sample, xs=keys)
    return samples

# Smoke test: run the leapfrog-based sampler over keys_start with tau = T = 0.2.
hmc_sampler([qp_init, 1, False], keys_start, hamiltionian, 0.2, 0.2)

In [None]:
# keys_start holds a single key (initnum_samples = 1). Index 1 is out of
# bounds, and JAX clamps such an index instead of raising, so use index 0.
qpstart = draw_p(qp_init, keys_start[0])

In [None]:
# draw_p returns (state, None); show the state vector.
qpstart[0]

In [None]:
# Single leapfrog step (tau = 0.2, N = 1) applied to the freshly drawn state.
lftest = gen_leapfrog(gradH, 0.2, 1)
lftest(qpstart[0])

In [None]:
# Call one leapfrog-HMC kernel directly, outside of scan.
N = jnp.ceil(0.2/0.2).astype(int)
kerstep = gen_hmc_kernel(hamiltionian, 0.2, N)
# The kernel discards the 2nd and 3rd carry slots, so pass explicit
# placeholders instead of IPython's `_` (last cell output), which depends on
# execution history.
kerstep([qpstart[0], 0.0, False], key)

In [None]:
def gen_explicit(dim2n):
    """Return the block matrix ``[[I, -I], [I, I]]`` with ``n = dim2n // 2``.

    Args:
        dim2n: full state dimension (2n). An odd value is floored, giving a
            ``2 * (dim2n // 2)`` square matrix.

    Returns:
        Array of shape ``(2n, 2n)``.
    """
    # n must be computed before it is used to size the identity block.
    n = dim2n // 2
    nI = jnp.eye(n)
    # The rows of the block layout need a separating comma.
    out = jnp.block([[nI, -nI], [nI, nI]])
    return out

In [None]:
import numpy as np
dim = 1000
initnum_samples = 1
keys_start = jax.random.split(key, initnum_samples)


# Structure of carry
# init_sample: [Array: sample, float: deltaH, bool: Accepted]

Mass_inv = jnp.eye(dim)
# tau = 0.2
T = 2  # total integration time; the leapfrog step count is ceil(T / tau)
tol = 1e-4
max_iter = 10


steps = 11
# Sweep `steps` step sizes so that later cells indexing collection_tau[10]
# have an entry to read.
tau_set = jnp.linspace(0.1, 0.2, steps)
key2 = jax.random.PRNGKey(1)
# Split before drawing so the key used for qp_init is not also the parent of
# the per-tau keys below.
key2, key_init = jax.random.split(key2)
qp_init = jax.random.normal(key_init, shape=(2 * dim,))
init_sample = [qp_init, 1, False]
mainnum_samples = 1000
tau_samples = np.zeros((mainnum_samples, dim, len(tau_set)))
# collection_tau[i] is the sampler output for tau_set[i]:
#   [Array: sample, float: deltaH, bool: Accepted]
collection_tau = []
tau_acceptance_rate = np.zeros(len(tau_set))
key_chain = []
for i, t in enumerate(tau_set):
    print(t)
    tau = float(t)
    key1, key2 = jax.random.split(key2)
    key_chain.append([key1, key2])
    keys_main = jax.random.split(key1, mainnum_samples)
    # Leapfrog HMC via the (H, tau, T) form of hmc_sampler.
    result = hmc_sampler(init_sample, keys_main, hamiltionian, tau, T)
    collection_tau.append(result)
    # Store only the position half of each sampled state.
    tau_samples[:, :, i] = np.asarray(result[0][:, :dim])
    # Number of accepted proposals for this tau.
    tau_acceptance_rate[i] = jnp.sum(result[2])


In [None]:
# Type of slot 2 of the second sweep entry (presumably the accepted flags,
# per the [sample, deltaH, Accepted] layout; verify).
type(collection_tau[1][2])

In [None]:
# Acceptance fraction per step size. Expects tau_acceptance_rate (accepted
# count per tau) to already exist in the kernel namespace.
plt.plot(tau_set, tau_acceptance_rate/mainnum_samples, marker = '*')
plt.title("tau vs acceptance rate for samples")

In [None]:
# Scratch check: summing a boolean array counts its True entries.
np.sum(jnp.array([True, False]))

In [None]:
# Sum of slot 2 for the second sweep entry (presumably the number of accepted
# proposals; verify against the carry layout).
jnp.sum(collection_tau[1][2])

In [None]:
# Histogram of slot 1 (presumably deltaH) for sweep entry 10.
plt.hist(collection_tau[10][1],bins=40, density=True)

In [None]:
# Histogram of slot 1 (presumably deltaH) for sweep entry 2.
plt.hist(collection_tau[2][1],bins=40, density=True)

In [None]:
# Histogram of slot 1 (presumably deltaH) for sweep entry 3.
plt.hist(collection_tau[3][1],bins=40, density=True)

In [None]:
# Histogram of slot 1 (presumably deltaH) for sweep entry 10 (repeats an earlier cell).
plt.hist(collection_tau[10][1],bins=40, density=True)

In [None]:
# Histogram of slot 1 (presumably deltaH) for sweep entry 0.
plt.hist(collection_tau[0][1],bins=40, density=True)

In [None]:
# Histogram of sample_LF[1] (the deltaH slot of the sampler output).
# NOTE(review): sample_LF is first assigned in a later cell, so this cell
# fails on a fresh top-to-bottom run.
plt.hist(np.array(sample_LF[1]), density=True)
plt.title('Histogram of deltaH sample_LF')

In [None]:
#change rngkey

key = jax.random.PRNGKey(4)
dim = 30
initnum_samples = 1
mainnum_samples = 1000
# One sub-key per purpose instead of reusing `key` for all three draws.
key_init, key_start, key_main = jax.random.split(key, 3)
keys_start = jax.random.split(key_start, initnum_samples)
keys_main = jax.random.split(key_main, mainnum_samples)
qp_init = jax.random.normal(key_init, shape=(2*dim,))
# Rebuild the carry: the init_sample left over from an earlier cell holds a
# state of a different dimension.
init_sample = [qp_init, 1, False]
Mass_inv = jnp.eye(dim)
tau = 0.2
T = 1
tol = 1e-4
max_iter = 1000

#set integrator
# There is no module-level `leapfrog`; the integrator is built by gen_leapfrog.
jit_integrator = jax.jit(gen_leapfrog(gradH, tau, int(np.ceil(T / tau))))

# compile
start=time.time()
# The (H, tau, T) form of hmc_sampler is the leapfrog sampler. It is not
# wrapped in jit here because it derives the scan length from T / tau, which
# must stay a concrete value.
jhmc_sampler = partial(hmc_sampler, H=hamiltionian, tau=tau, T=T)
# block_until_ready so the timer covers the actual computation.
sample_LF = jax.block_until_ready(jhmc_sampler(init_sample, keys_start))
end = time.time()
print("1st run:", end-start)
# main run
start = time.time()
sample_LF = jax.block_until_ready(jhmc_sampler(init_sample, keys_main))
end = time.time()
print(f"{mainnum_samples} runs: {end - start:.2f} \n 1 run:  {(end-start)/mainnum_samples}")

In [None]:
# np.savetxt('sample_LF.csv', sample_LF, delimiter=',')
# np.savetxt('sample_FPI.csv', sample_FPI, delimiter=',')

In [None]:
# Histogram of sample_LF[1], the deltaH slot of the leapfrog sampler output.
plt.hist(np.array(sample_LF[1]), density=True)
plt.title('Histogram of deltaH sample_LF')

In [None]:
# (removed incomplete statement `plt.` -- it is a SyntaxError and stops Run All)

In [None]:
def plot2vecHist(sample, col1, col2, numbins=50, label="Sample_LF"):
    """Plot a 2-D density histogram of two columns of the sampled states.

    Args:
        sample: sampler output; ``sample[0]`` is indexed as a 2-D array of
            states (one row per draw).
        col1: column plotted on the x-axis.
        col2: column plotted on the y-axis.
        numbins: number of bins per axis.
        label: prefix used in the figure title, so the same function can be
            reused for different samplers.
    """
    fig, ax1 = plt.subplots()
    states = sample[0]
    ax1.hist2d(np.array(states[:, col1]), np.array(states[:, col2]), bins=numbins, density=True)
    ax1.set_xlabel(f'Dim_{col1} Values')
    ax1.set_ylabel(f'Dim_{col2} Values')
    ax1.set_title(f'{label} x:Dim_{col1}, y:Dim_{col2}, histogram Accepted Distribution')
    plt.show()
# Joint histogram of state columns 2 and 3 for the leapfrog samples.
plot2vecHist(sample_LF, 2,3)

In [None]:
# sample_LF is the list [samples, deltaH, accepted]; the state array is slot 0.
sample_LF[0][:,2].shape

In [None]:
# sample_LF is the list [samples, deltaH, accepted]; the state array is slot 0.
sample_LF[0][:,3].shape

In [None]:
import matplotlib.pyplot as plt


fig, ax1 = plt.subplots()

# 2-D density histogram of state columns 2 and 3 of the leapfrog samples.
ax1.hist2d(np.array(sample_LF[0][:,2]),np.array(sample_LF[0][:,3]), bins=50, density = True)
# Both axes are sample values; the title previously contained unformatted
# `{col1}` / `{col2}` placeholders.
ax1.set_xlabel('Dim_2 Values')
ax1.set_ylabel('Dim_3 Values')
ax1.set_title('LF x:Dim_2, y:Dim_3, histogram Accepted Distribution')

# Plot target distribution
# # ax2 = ax1.twinx()
# x_vals = np.linspace(-3, 3, 10000)
# y_vals = [gauss_f(x)/(np.pi)**0.5 for x in x_vals]
# ax1.plot(x_vals, y_vals, 'r-', label='Target Function')
# ax1.set_ylabel('Target Function Value')
# ax1.set_ylim(0, 0.6)

plt.show()

In [None]:
# Shape of the first `dim` columns of the state array (slot 0 of the sampler output).
sample_LF[0][:,:dim].shape

In [None]:
# Switch the integrator read by the module-level hmc_kernel to midpoint FPI.
# midpointFPI takes a single step of size `tau` per proposal.
jit_integrator = jit(midpointFPI)
# The scan carry must be the [qp, deltaH, accepted] list, not the bare state.
init_sample = [qp_init, 1, False]
# Scan the module-level hmc_kernel directly: by this point `hmc_sampler` has
# been redefined with a different signature.
fpi_sampler = jit(lambda init, keys: jax.lax.scan(hmc_kernel, init, xs=keys)[1])
# compile
start=time.time()
# block_until_ready so the timer covers the actual computation.
sample_FPI = jax.block_until_ready(fpi_sampler(init_sample, keys_start))
end = time.time()
print("1st run:", end-start)
# main run
start = time.time()
sample_FPI = jax.block_until_ready(fpi_sampler(init_sample, keys_main))
end = time.time()
print(f"{mainnum_samples} runs: {end - start} \n 1 run:  {(end-start)/mainnum_samples}")


In [None]:
# The sampler output is the list [samples, deltaH, accepted]; the state array
# is slot 0, and its first `dim` columns are the positions q.

# Calculate the number of unique samples for Leapfrog
unique_samples_LF = jnp.unique(sample_LF[0][:,:dim], axis=0)
num_unique_LF = unique_samples_LF.shape[0]
print(f"Number of unique samples (Leapfrog): {num_unique_LF}")

# Calculate the number of unique samples for Midpoint FPI
unique_samples_FPI = jnp.unique(sample_FPI[0][:,:dim], axis=0)
num_unique_FPI = unique_samples_FPI.shape[0]
print(f"Number of unique samples (Midpoint FPI): {num_unique_FPI}")

# Compare the number of unique samples to the total number of samples
print(f"Total number of samples: {mainnum_samples}")

In [None]:
# sample_FPI is a list (no .shape); inspect the state array in slot 0.
sample_FPI[0].shape

In [None]:
# jit_integrator = jit(leapfrog)
# # draw qp states
# inits, _ = jax.vmap(draw_p, in_axes=(None, 0))(qp_init, keys_main)
# vmap_integrator = jax.vmap(jit_integrator)
# sample_LF = vmap_integrator(inits)

# # Midpoint FPI
# jit_integrator = jit(midpointFPI)
# vmap_integrator = jax.vmap(jit_integrator)
# sample_FPI = vmap_integrator(inits)

# Both sampler outputs are lists [samples, deltaH, accepted]; lists cannot be
# subtracted, so compare the state arrays in slot 0.
sample_diff = sample_FPI[0] - sample_LF[0]

# norm of the difference along row
norm_diff = jnp.linalg.norm(sample_diff, axis=1)

# mean,max,min
print("Mean norm difference:", jnp.mean(norm_diff))
print("Max norm difference:", jnp.max(norm_diff))
print("Min norm difference:", jnp.min(norm_diff))

In [None]:
# Scratch: string dump of the whole FPI sampler output.
str(sample_FPI)

In [None]:
def plot2vecHist(sample, col1, col2, numbins=50, label="Sample_FPI"):
    """Plot a 2-D density histogram of two columns of the sampled states.

    Args:
        sample: sampler output; ``sample[0]`` is indexed as a 2-D array of
            states (one row per draw).
        col1: column plotted on the x-axis.
        col2: column plotted on the y-axis.
        numbins: number of bins per axis.
        label: prefix used in the figure title, so the same function can be
            reused for different samplers.
    """
    fig, ax1 = plt.subplots()
    states = sample[0]
    ax1.hist2d(np.array(states[:, col1]), np.array(states[:, col2]), bins=numbins, density=True)
    ax1.set_xlabel(f'Dim_{col1} Values')
    ax1.set_ylabel(f'Dim_{col2} Values')
    ax1.set_title(f'{label} x:Dim_{col1}, y:Dim_{col2}, histogram Accepted Distribution')
    plt.show()
# Joint histogram of state columns 4 and 6 for the midpoint-FPI samples.
plot2vecHist(sample_FPI, 4,6)

In [None]:
def gauss_f(x):
    """Unnormalised 1-D Gaussian bump: returns ``exp(-x**2)``."""
    squared = x**2
    return np.exp(-squared)

fig, ax1 = plt.subplots()

# Plot histogram of accepted samples
# sample_FPI is the list [samples, deltaH, accepted]; column 3 of the state
# array in slot 0 is plotted.
ax1.hist(np.array(sample_FPI[0][:,3]), bins=50, density = True)
ax1.set_xlabel('x-Value')
ax1.set_title('FPI Resulting Accepted Distribution vs Target Function')

# Plot target distribution
# ax2 = ax1.twinx()
# NOTE(review): the overlay is exp(-x^2)/sqrt(pi), a Gaussian with variance
# 1/2 -- confirm this matches the precision matrix used by the sampler.
x_vals = np.linspace(-3, 3, 10000)
y_vals = [gauss_f(x)/(np.pi)**0.5 for x in x_vals]
ax1.plot(x_vals, y_vals, 'r-', label='Target Function')
# The y-label was previously set twice; only the final value is kept.
ax1.set_ylabel('Target Function Value')
ax1.set_ylim(0, 0.6)

plt.show()

In [None]:
fig, ax1 = plt.subplots()

# Plot histogram of accepted samples
# sample_LF is the list [samples, deltaH, accepted]; columns 0 and 1 of the
# state array in slot 0 are plotted.
ax1.hist2d(np.array(sample_LF[0][:,0]),np.array(sample_LF[0][:,1]), bins=50, density = True)
ax1.set_xlabel('Dim_0 Values')
ax1.set_ylabel('Dim_1 Values')
# The data plotted here is sample_LF, so the title is labelled LF.
# NOTE(review): if sample_FPI was intended, change the data rather than the title.
ax1.set_title('LF Resulting Accepted Distribution')

# Plot target distribution
# # ax2 = ax1.twinx()
# x_vals = np.linspace(-3, 3, 10000)
# y_vals = [gauss_f(x)/(np.pi)**0.5 for x in x_vals]
# ax1.plot(x_vals, y_vals, 'r-', label='Target Function')
# ax1.set_ylabel('Target Function Value')
# ax1.set_ylim(0, 0.6)

plt.show()

In [None]:
# Compare the first three state columns; the state arrays are slot 0 of each
# sampler output list.
sample_diff = sample_FPI[0][:,:3] - sample_LF[0][:,:3]

# norm of the difference along row
norm_diff = jnp.linalg.norm(sample_diff, axis=1)

# mean,max,min
print("Mean norm difference:", jnp.mean(norm_diff))
print("Max norm difference:", jnp.max(norm_diff))
print("Min norm difference:", jnp.min(norm_diff))

In [None]:
# Jitted norm mapped over the leading axis (one norm per row of a 2-D input).
norm = jit(jax.vmap(jnp.linalg.norm))

In [None]:
# Row-wise norm of the FPI/LF difference over (up to) the first 1000 state
# columns; the state arrays are slot 0 of each sampler output list.
err = norm(sample_FPI[0][:,:1000] - sample_LF[0][:,:1000])