In [None]:
"""
Description:
    CHMC Implementation with AVF: FPI
    USE THE CORRECT ENVIRONMENT:  CHMC_FALL_2025
    YYYY-MM-DD

Author: John Gallagher
Created: 2025-09-28
Last Modified: 2025-10-16
Version: 1.0.0

"""

import jax
import jax.numpy as jnp
from jax import jit
import time

jax.config.update("jax_enable_x64", True)


def gauss_ndimf_jax(x, precision_matrix=None, cov=None, dim=2):
    """
    Unnormalized zero-mean n-dim Gaussian density, exp(-0.5 * x^T P x).

    Parameters
    ----------
    x : array, shape (n,)
        Point at which to evaluate the density.
    precision_matrix : array, shape (n, n), optional
        Precision matrix P. Mutually exclusive with ``cov``.
    cov : array, shape (n, n), optional
        Covariance matrix; the quadratic form uses cov^{-1} via a linear solve.
        Mutually exclusive with ``precision_matrix``.
    dim : int, optional
        Ignored; the dimension is always taken from ``len(x)``. Kept only for
        backward compatibility of the signature.

    Returns
    -------
    Scalar exp(-0.5 * x^T P x). With neither matrix supplied, P = I.
    The normalizing constant det(P)^{1/2} (2 pi)^{-n/2} is deliberately omitted.

    Raises
    ------
    MultipleMatrices (a ValueError subclass)
        If both ``precision_matrix`` and ``cov`` are supplied.
    """

    # Error Classes
    class MultipleMatrices(ValueError):
        pass

    # dealing with getting a precision matrix or cov matrix
    if precision_matrix is not None and cov is not None:
        raise MultipleMatrices(
            "Please supply either a Precision Matrix or a Covariance Matrix"
        )
    if precision_matrix is not None:
        quad = x @ precision_matrix @ x
    elif cov is not None:
        # Solve cov @ z = x rather than forming inv(cov) explicitly: cheaper and
        # numerically more stable.
        quad = x @ jnp.linalg.solve(cov, x)
    else:
        # Identity precision: x^T I x == x . x, without materializing an n x n eye.
        quad = x @ x
    return jnp.exp(-0.5 * quad)


def qex(qp):
    """
    Position part of a state vector qp = [q, p].

    Returns the first len(qp) // 2 entries (for odd lengths the extra entry
    belongs to the momentum half).
    """
    return qp[slice(0, len(qp) // 2)]


def pex(qp):
    """
    Momentum part of a state vector qp = [q, p].

    Returns every entry from index len(qp) // 2 onward.
    """
    split_at = len(qp) // 2
    momentum = qp[split_at:]
    return momentum


def J_sym(vec):
    """
    Apply the canonical symplectic matrix J = [[0, I], [-I, 0]] to vec.

    For vec = [a, b] (halves split at len(vec) // 2) this returns [b, -a].
    Applied to grad H = [dH/dq, dH/dp] it gives [dH/dp, -dH/dq] = [qdot, pdot].
    """
    half = len(vec) // 2
    upper, lower = vec[:half], vec[half:]
    return jnp.concatenate((lower, -upper))


def qJ_sym(vec):
    """
    Position-only part of J @ vec.

    For vec = [a, b] returns [b, 0]. Applied to grad H this is [dH/dp, 0], i.e.
    only the q-update (qdot = dH/dp) of Hamilton's equations; p is left unchanged.
    """
    half = len(vec) // 2
    no_momentum_change = jnp.zeros(half)
    return jnp.concatenate((vec[half:], no_momentum_change))


def pJ_sym(vec):
    """
    Momentum-only part of J @ vec.

    For vec = [a, b] returns [0, -a]. Applied to grad H this is [0, -dH/dq], i.e.
    only the p-update (pdot = -dH/dq) of Hamilton's equations; q is left unchanged.
    """
    half = len(vec) // 2
    no_position_change = jnp.zeros(half)
    return jnp.concatenate((no_position_change, -vec[:half]))


def draw_p(qp, key):
    """
    Momentum refresh: keep q from qp and draw a fresh p ~ N(0, I).

    Parameters
    ----------
    qp : array, shape (2 * d,)
        Current state [q, p]; the incoming momentum is discarded.
    key : jax PRNG key
        Key used for the momentum draw.

    Returns
    -------
    tuple
        ``(jnp.concatenate([q, p_new]), None)`` where p_new has q's shape.
    """
    # Same split as qex: the first len(qp) // 2 entries are q.
    q = qp[: len(qp) // 2]
    # Size p from q itself rather than from the module-level `dim`, which the
    # parameter-sweep loops rebind and which need not match qp.
    p = jax.random.normal(key, shape=q.shape)
    return jnp.concatenate([q, p]), None


def gen_leapfrog(gradH, tau, N):
    """
    Build a fixed-length leapfrog integrator for a Hamiltonian on qp = [q, p].

    Each of the N steps is drift-kick-drift:
      1. q += (tau / 2) * dH/dp   (via qJ_sym)
      2. p -= tau * dH/dq         (via pJ_sym)
      3. q += (tau / 2) * dH/dp   (via qJ_sym)

    Parameters
    ----------
    gradH : callable
        Gradient of the Hamiltonian with respect to the full state qp.
    tau : float
        Step size.
    N : int
        Number of steps.

    Returns
    -------
    callable
        ``leapfrog(qp) -> qp`` advancing the state by N steps.
    """
    half_tau = 0.5 * tau

    def drift(state):
        return state + half_tau * qJ_sym(gradH(state))

    def kick(state):
        return state + tau * pJ_sym(gradH(state))

    def one_step(state, unused):
        state = drift(state)
        state = kick(state)
        state = drift(state)
        return state, unused

    def leapfrog(qp):
        final_state, _ = jax.lax.scan(one_step, qp, xs=None, length=N)
        return final_state

    return leapfrog

def gen_midpointFPI(gradH, tau, N, tol, maxIter, solve=jnp.linalg.solve):
    """
    Build an implicit-midpoint integrator for Hamiltonian dynamics.

    Each of the N steps solves for y = qp(t + tau) in

        y = x0 + tau * J_sym(gradH(0.5 * (x0 + y)))

    by Newton's method on F(y) = y - G(y). The iteration starts at y = x0 and
    stops once ||F(y)|| <= tol or after maxIter Newton updates.

    Parameters
    ----------
    gradH : callable
        Gradient of the Hamiltonian with respect to the full state qp.
    tau : float
        Step size.
    N : int
        Number of midpoint steps.
    tol : float
        Residual-norm tolerance for the Newton solve.
    maxIter : int
        Maximum number of Newton updates per step.
    solve : callable, optional
        Linear solver ``solve(A, b)`` used for the Newton correction;
        defaults to ``jnp.linalg.solve``.

    Returns
    -------
    callable
        ``midpointFPI_T(qp) -> qp`` advancing the state by N midpoint steps.
    """
    def midpointFPI(qp, _):
        """One implicit-midpoint step from x0 = qp (a ``jax.lax.scan`` body)."""
        x0 = qp

        def G(y):
            """G(y) = x0 + tau * J_sym(gradH(0.5 * (x0 + y)))."""
            midpoint = 0.5 * (x0 + y)
            return x0 + tau * J_sym(gradH(midpoint))

        def F(y):
            return y - G(y)

        jacF = jax.jacobian(F)

        def newton_step(y):
            # Newton update y <- y - J_F(y)^{-1} F(y). The correction must be
            # applied to the current iterate y, not to the step's start x0.
            return y - solve(jacF(y), F(y))

        def cond(carry):
            i, y = carry
            err = jnp.linalg.norm(F(y))
            return (err > tol) & (i < maxIter)

        def body_step(carry):
            i, y = carry
            return [i + 1, newton_step(y)]

        _, qp_out = jax.lax.while_loop(cond, body_step, [0, x0])
        # Only the final state is used, so don't stack per-step states in scan.
        return qp_out, None

    def midpointFPI_T(qp):
        qp_out, _ = jax.lax.scan(midpointFPI, qp, xs=None, length=N)
        return qp_out

    return midpointFPI_T

def accept(delta, key):
    """
    Metropolis test: return a boolean that is True with probability min(1, exp(delta)).

    ``delta`` is the log acceptance ratio; ``key`` drives the single uniform draw.
    """
    threshold = jnp.minimum(1.0, jnp.exp(delta))
    return jax.random.uniform(key, shape=()) <= threshold


def gen_hmc_kernel(H, tau, N):
    """
    Build one leapfrog-HMC transition for use as a ``jax.lax.scan`` body.

    Parameters
    ----------
    H : callable
        Hamiltonian H(qp) of the concatenated state qp = [q, p].
    tau : float
        Leapfrog step size.
    N : int
        Number of leapfrog steps per proposal.

    Returns
    -------
    callable
        ``hmc_kernel(carry_in, key) -> (carry_out, carry_out)``. Carries are
        ``[qp, deltaH, is_accepted]``, where ``deltaH = H(initial) - H(proposal)``
        is the log Metropolis ratio (same convention as ``gen_chmc_kernel``).
    """
    gradH = jax.grad(H)
    integrator = gen_leapfrog(gradH, tau, N)

    def hmc_kernel(carry_in, key):
        carry, _, _ = carry_in
        # Independent keys for the momentum refresh and the accept/reject draw.
        # JAX keys must not be reused; sharing one can correlate the two draws.
        key_p, key_u = jax.random.split(key)
        qp0, _ = draw_p(carry, key_p)
        qp_star = integrator(qp0)
        deltaH = H(qp0) - H(qp_star)  # -(final - init) = init - final
        is_accepted = accept(deltaH, key_u)
        qp_out = jnp.where(is_accepted, qp_star, qp0)
        carry_out = [qp_out, deltaH, is_accepted]
        return carry_out, carry_out

    return hmc_kernel
def gen_chmc_kernel(H, tau, N, tol, maxIter, solve=jnp.linalg.solve):
    """
    Build one implicit-midpoint HMC transition for use as a ``jax.lax.scan`` body.

    Parameters
    ----------
    H : callable
        Hamiltonian H(qp) of the concatenated state qp = [q, p].
    tau : float
        Step size of the midpoint integrator.
    N : int
        Number of midpoint steps per proposal.
    tol : float
        Newton residual tolerance, forwarded to ``gen_midpointFPI``.
    maxIter : int
        Newton iteration cap, forwarded to ``gen_midpointFPI``.
    solve : callable, optional
        Linear solver ``solve(A, b)``, forwarded to ``gen_midpointFPI``.

    Returns
    -------
    callable
        ``chmc_kernel(carry_in, key) -> (carry_out, carry_out)``. Carries are
        ``[qp, deltaH, is_accepted]`` with ``deltaH = H(initial) - H(proposal)``.
    """
    gradH = jax.grad(H)
    # Forward the caller-supplied linear solver to the integrator.
    integrator = gen_midpointFPI(gradH, tau, N, tol, maxIter, solve=solve)

    def chmc_kernel(carry_in, key):
        carry, _, _ = carry_in
        # Independent keys for the momentum refresh and the accept/reject draw.
        # JAX keys must not be reused; sharing one can correlate the two draws.
        key_p, key_u = jax.random.split(key)
        qp0, _ = draw_p(carry, key_p)
        qp_star = integrator(qp0)
        deltaH = H(qp0) - H(qp_star)  # -(final - init) = init - final
        is_accepted = accept(deltaH, key_u)
        qp_out = jnp.where(is_accepted, qp_star, qp0)
        carry_out = [qp_out, deltaH, is_accepted]
        return carry_out, carry_out

    return chmc_kernel

def hmc_sampler(initial_sample, keys, H, tau, N):
    """
    Run a leapfrog-HMC chain with one transition per entry of ``keys``.

    inputs: initial_sample, keys, H, tau, N
      initial_sample : initial carry [qp, deltaH, accepted]
      keys           : PRNG keys, scanned over (one per transition)
      H              : Hamiltonian of the state qp = [q, p]
      tau, N         : leapfrog step size and number of steps per proposal

    Returns the per-transition carries stacked along a new leading axis,
    i.e. [states, deltaHs, accepted].
    """
    step = gen_hmc_kernel(H, tau, N)
    _final_carry, history = jax.lax.scan(step, initial_sample, keys)
    return history
def chmc_sampler(initial_sample, keys, H, tau, N, tol, maxIter, solve=jnp.linalg.solve):
    """
    Run an implicit-midpoint HMC chain with one transition per entry of ``keys``.

    inputs: initial_sample, keys, H, tau, N, tol, maxIter, solve=jnp.linalg.solve
      initial_sample : initial carry [qp, deltaH, accepted]
      keys           : PRNG keys, scanned over (one per transition)
      H              : Hamiltonian of the state qp = [q, p]
      tau, N         : midpoint step size and number of steps per proposal
      tol, maxIter   : Newton tolerance and iteration cap per midpoint step
      solve          : linear solver ``solve(A, b)`` used by the Newton update

    Returns the per-transition carries stacked along a new leading axis,
    i.e. [states, deltaHs, accepted].
    """
    # Forward the caller-supplied solver.
    chmc_kernel = gen_chmc_kernel(H, tau, N, tol, maxIter, solve=solve)
    _, samples = jax.lax.scan(chmc_kernel, initial_sample, xs=keys)
    return samples

def gen_hamiltonian(Mass_inv, target):
    """
    Build H(qp) = 0.5 * p^T Mass_inv p - log(target(q)) for qp = [q, p].

    ``target`` is an (unnormalized) density of q; its negative log is the
    potential energy and the quadratic form in p is the kinetic energy.
    """
    def hamiltonian(qp):
        momentum = pex(qp)
        kinetic = 0.5 * jnp.sum(momentum @ Mass_inv @ momentum)
        potential = -jnp.log(target(qex(qp)))
        return kinetic + potential

    return hamiltonian


# def J_H(gH):
#     """Same operation as Symplectic Jacobian"""
#     return jnp.concatenate([gH[dim:], -gH[:dim]])


dim = 1000
Mass_inv = jnp.eye(dim)
target = gauss_ndimf_jax
hamiltonian = gen_hamiltonian(Mass_inv, target)
grad_target = jit(jax.grad(target))
jit_H = jit(hamiltonian)
gradH = jax.jit(jax.grad(hamiltonian))

# jit_integrator = jax.jit(leapfrog)
# jit_integrator = jit(midpointFPI)

# Set parameters
key = jax.random.PRNGKey(1)
# Give the initial state and the chain their own keys so no key is consumed twice.
key_init, key_chain = jax.random.split(key)

initnum_samples = 1
mainnum_samples = 1000
# keys_start is kept for reference; the warm-up below uses keys_main (see why there).
keys_start = jax.random.split(key_chain, initnum_samples)
keys_main = jax.random.split(key_chain, mainnum_samples)
qp_init = jax.random.normal(key_init, shape=(2 * dim,))

# Structure of carry
# init_sample: [Array: sample, float: deltaH, bool: Accepted]
init_sample = [qp_init, 1, False]

tau = 0.5
T = 1
# Number of leapfrog steps so that N * tau >= T.
N = int(jnp.ceil(T/tau))
tol = 1e-4
max_iter = 100

jhmc_sampler = jit(hmc_sampler, static_argnums=(2, 3, 4))

# compile (+ first run). The warm-up must use the same argument shapes as the
# timed run: a keys array of a different length is a different trace and would
# make the "main" call recompile.
start = time.perf_counter()
sample_FPI = jax.block_until_ready(
    jhmc_sampler(init_sample, keys_main, hamiltonian, tau, N)
)
end = time.perf_counter()
print("1st run:", end - start)

# main run. JAX dispatches asynchronously, so wait for the result before
# stopping the clock; otherwise only the dispatch time is measured.
start = time.perf_counter()
sample_FPI = jax.block_until_ready(
    jhmc_sampler(init_sample, keys_main, hamiltonian, tau, N)
)
end = time.perf_counter()
print(
    f"Main run:\n {mainnum_samples} runs: {end - start:.2f} \n 1 run:  {(end-start)/mainnum_samples}"
)

In [None]:
import numpy as np
# Leapfrog-HMC sweep over step sizes (tau_set) and dimensions (dims).
ndims = 9
# dims = [4, 8, 16, ..., 1024]: ndims powers of two from 2**2 to 2**10.
dims = np.logspace(2,10, ndims,base=2, dtype=int)
# Mass_inv = jnp.eye(dim)
# target = gauss_ndimf_jax
# hamiltonian = gen_hamiltonian(Mass_inv, target)
# grad_target = jit(jax.grad(target))
# jit_H = jit(hamiltonian)
# gradH = jax.jit(jax.grad(hamiltonian))

# jit_integrator = jax.jit(leapfrog)
# jit_integrator = jit(midpointFPI)

# Set parameters
key = jax.random.PRNGKey(1)

initnum_samples = 1
mainnum_samples = 1000
keys_start = jax.random.split(key, initnum_samples)
# One PRNG key per HMC transition; the same keys are reused for every (tau, dim) pair.
# NOTE(review): `key` itself is also used below for qp_init; JAX advises never
# reusing a key — consider splitting it first.
keys_main = jax.random.split(key, mainnum_samples)
# qp_init = jax.random.normal(key, shape=(2 * dim,))

# Structure of carry
# init_sample: [Array: sample, float: deltaH, bool: Accepted]
# NOTE(review): qp_init here is whatever an earlier cell left behind; init_sample
# is rebuilt with a correctly sized qp_init inside the loop below.
init_sample = [qp_init, 1, False]

# NOTE(review): tol, max_iter and tau are not used by the leapfrog loop below
# (it uses each entry of tau_set instead).
tol = 1e-4
max_iter = 100
tau = 0.2
T = 1

# tol = 1e-4
# max_iter = 100
numtaus = 7
# taufinal / tauinit only feed the commented-out linspace grid below.
taufinal =0.5
tauinit = 0.1
# Step sizes 0.2 / 2**k for k = 0..6, i.e. 0.2 down to 0.003125.
tau_set = 0.2*1/jnp.logspace(0,6,numtaus,base=2)
# tau_set = jnp.linspace(tauinit, taufinal, numtaus)
# [Array: sample, float: deltaH, bool: Accepted]
# samples_taus = np.zeros((mainnum_samples, 2*dim, numtaus, ndims))
# Result arrays indexed as (sample, tau index, dim index).
samples_deltaHs = np.zeros((mainnum_samples,numtaus, ndims))
samples_accepted = np.zeros((mainnum_samples,numtaus, ndims))
hmc_samples = []
for dim in dims:
    hmc_samples.append(np.zeros((mainnum_samples, dim)))

# NOTE(review): jhmc_sampler is built but the loop calls the un-jitted hmc_sampler.
# A fresh `hamiltonian` closure per iteration passed as a static arg would likely
# force a recompile each time anyway.
jhmc_sampler = jit(hmc_sampler, static_argnums=(2,3,4))
for i, taus in enumerate(tau_set):
    for j, dim in enumerate(dims):
        # Number of leapfrog steps so that N * taus >= T.
        N = int(jnp.ceil(T/taus))
        # if dim == 4:
        print(f'taus: {taus}, dim: {dim}, N: {N}')
        Mass_inv = jnp.eye(dim)
        target = gauss_ndimf_jax
        hamiltonian = gen_hamiltonian(Mass_inv, target)
        grad_target = jit(jax.grad(target))
        jit_H = jit(hamiltonian)
        gradH = jax.jit(jax.grad(hamiltonian))
        qp_init = jax.random.normal(key, shape=(2 * dim,))
        init_sample = [qp_init, 1, False]
        # hmc_sampler returns [states, deltaHs, accepted] stacked over keys_main.
        # hmc_samples[j] is replaced by the full (mainnum_samples, 2*dim) state
        # array (q and p) and is overwritten for every tau, so only the last tau's
        # states survive.
        hmc_samples[j], samples_deltaHs[:,i, j], samples_accepted[:,i, j] = hmc_sampler(init_sample, keys_main,  hamiltonian, taus, N)

In [None]:
# Scratch: preview the step-size grid 0.4 / 2**k for k = 0..6.
0.4*1/np.logspace(0,6,7,base=2)

In [None]:
# Spacing of a linear tau grid from tauinit to taufinal with numtaus points.
# NOTE(review): this rebinds numtaus / taufinal / tauinit, which later cells read
# (e.g. the histogram cells); run order matters.
numtaus = 5
taufinal =0.5
tauinit = 0.1
dtau = (taufinal-tauinit)/(numtaus -1)
print(dtau)

In [None]:
# Summary statistics of leapfrog Delta H and the accept count for every (tau, dim) pair.
print(f'(num samples, num tau, numdims)\n {samples_deltaHs.shape}')
for i, tau_i in enumerate(tau_set):
    for j, dim_j in enumerate(dims):
        dH = samples_deltaHs[:, i, j]
        n_acc = samples_accepted[:, i, j].sum()
        print(f'tau: {tau_i:.3f}, dim = {dim_j}\n min: {dH.min():.6f} max: {dH.max():.6f} mean: {dH.mean():.7f} std: {dH.std():.6f} #Accepts: {n_acc}')

In [None]:
# Sanity check: samples_deltaHs is (num samples, num tau, num dims).
print(samples_deltaHs.shape)
print(tau_set.shape, dims.shape)

In [None]:
import matplotlib.pyplot as plt
plt.rcParams["text.usetex"] = True
# squeeze=False keeps axs 2-D even when tau_set has a single entry.
fig, axs = plt.subplots(len(tau_set), squeeze=False)
fig.set_size_inches(8, 20)

# index: (mainnum_samples, numtaus, ndims)
# One panel per step size; within a panel, one Delta H histogram per dimension.
for i, tau_i in enumerate(tau_set):
    ax = axs[i, 0]
    for j, dim_j in enumerate(dims):
        ax.hist(samples_deltaHs[:, i, j], bins=10, density=True, label=f'{dim_j}-dim')
    ax.set_xlim(-0.5, 0.5)
    ax.set_ylabel('Density')
    # N as used by the sampler for this tau (not a separate linear dtau grid).
    n_steps = int(jnp.ceil(T / tau_i))
    ax.set_title(rf'$\tau$: {tau_i:.5f}, T={T}  $\Longrightarrow$   N = {n_steps}')
    ax.legend()
fig.suptitle(r'Histogram of LF: $\Delta H$', y=0.91)


In [None]:
def alphaex(deltaH):
    """Metropolis acceptance probability min(1, exp(deltaH)), elementwise.

    NOTE(review): this assumes deltaH is the log acceptance ratio
    H(initial) - H(proposal); confirm the sign convention of the sampler that
    produced the data.
    """
    return jnp.minimum(1., jnp.exp(deltaH))
# NOTE(review): valphaex is unused; alphaex is already elementwise.
valphaex = jax.vmap(alphaex)
alphas = alphaex(samples_deltaHs)
# Mean over samples -> shape (num tau, num dims).
meanalphas = alphas.mean(axis=0)
meanalphas.shape
tau_set.shape

In [None]:
import os

import matplotlib.pyplot as plt


def alphaex(deltaH):
    """Metropolis acceptance probability min(1, exp(deltaH)), elementwise.

    NOTE(review): assumes deltaH is the log acceptance ratio
    H(initial) - H(proposal); confirm the sampler's sign convention.
    """
    return jnp.minimum(1., jnp.exp(deltaH))
valphaex = jax.vmap(alphaex)
# alphaex is elementwise, so apply it to the whole (sample, tau, dim) array;
# averaging over samples leaves meanalphas with shape (num tau, num dims).
alphas = alphaex(samples_deltaHs)
meanalphas = alphas.mean(axis=0)
# intersection = np.interp(dims,dims, meanalphas)
for onetau, tau_i in enumerate(tau_set):
    plt.semilogx(dims, meanalphas[onetau, :], marker='*', label=f'$\\tau  = {tau_i:.5f}$', base=2)
# Horizontal reference line at a mean acceptance of 0.98.
plt.semilogx(dims, np.ones(len(dims))*0.98, base=2, marker='*', color='red')
plt.ylabel(r'Mean$(\alpha)$')
plt.xlabel('Dimension')
plt.legend()
plt.title('LF Acceptance rate vs dimension')
# savefig does not create missing directories.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/LF_accept(tau)_v_hidim.png')

In [None]:
import matplotlib.pyplot as plt
plt.rcParams["text.usetex"] = True
# Show the first numtaus step sizes, capped at the number actually sampled so the
# panel count can't exceed the data. squeeze=False keeps axs 2-D for one panel.
n_panels = min(numtaus, len(tau_set))
fig, axs = plt.subplots(n_panels, squeeze=False)
fig.set_size_inches(8, 20)

# index: (mainnum_samples, numtaus, ndims)
# Titles and labels come from the tau_set / dims that were actually sampled.
for i in range(n_panels):
    ax = axs[i, 0]
    tau_i = tau_set[i]
    for j, dim_j in enumerate(dims):
        ax.hist(samples_deltaHs[:, i, j], bins=10, density=True, label=f'{dim_j}-dim')
    ax.set_xlim(-0.5, 0.5)
    ax.set_ylabel('Density')
    n_steps = int(jnp.ceil(T / tau_i))
    ax.set_title(rf'$\tau$: {tau_i:.5f}, T={T}  $\Longrightarrow$   N = {n_steps}')
    ax.legend()
fig.suptitle(r'Histogram of LF: $\Delta H$', y=0.91)


In [None]:
import numpy as np
# Implicit-midpoint (CHMC) sweep over step sizes (tau_set) and dimensions (dims).
ndims = 5
# dims = [4, 8, 16, 32, 64]: ndims powers of two from 2**2 to 2**6.
dims = np.logspace(2,6, ndims,base=2, dtype=int)
# Mass_inv = jnp.eye(dim)
# target = gauss_ndimf_jax
# hamiltonian = gen_hamiltonian(Mass_inv, target)
# grad_target = jit(jax.grad(target))
# jit_H = jit(hamiltonian)
# gradH = jax.jit(jax.grad(hamiltonian))

# jit_integrator = jax.jit(leapfrog)
# jit_integrator = jit(midpointFPI)

# Set parameters
key = jax.random.PRNGKey(1)

initnum_samples = 1
mainnum_samples = 10000
keys_start = jax.random.split(key, initnum_samples)
# One PRNG key per transition; the same keys are reused for every (tau, dim) pair.
keys_main = jax.random.split(key, mainnum_samples)
# qp_init = jax.random.normal(key, shape=(2 * dim,))

# Structure of carry
# init_sample: [Array: sample, float: deltaH, bool: Accepted]
# NOTE(review): qp_init here is left over from an earlier cell; init_sample is
# rebuilt with a correctly sized qp_init inside the loop below.
init_sample = [qp_init, 1, False]

tol = 1e-4
# maxIter = 1 allows at most one Newton update per midpoint step (the while-loop
# condition is i < maxIter starting from i = 0).
# NOTE(review): the commented value below is 100 — confirm 1 is intended.
maxIter = 1
tau = 0.2
T = 1

# tol = 1e-4
# maxIter = 100
numtaus = 5
taufinal =0.4
tauinit = 0.1
# tau_set = [0.1, 0.175, 0.25, 0.325, 0.4].
tau_set = jnp.linspace(tauinit, taufinal, numtaus)
# [Array: sample, float: deltaH, bool: Accepted]
# samples_taus = np.zeros((mainnum_samples, 2*dim, numtaus, ndims))
# Result arrays indexed as (sample, tau index, dim index).
chmc_deltaHs = np.zeros((mainnum_samples,numtaus, ndims))
chmc_accepted = np.zeros((mainnum_samples,numtaus, ndims))
chmc_samples = []
for dim in dims:
    chmc_samples.append(np.zeros((mainnum_samples, dim)))

# NOTE(review): jchmc_sampler is built but unused; the loop calls chmc_sampler directly.
jchmc_sampler = jit(chmc_sampler, static_argnums=(2,3,4,5,6))
for i, taus in enumerate(tau_set):
    for j, dim in enumerate(dims):
        # Number of midpoint steps so that N * taus >= T.
        N = int(jnp.ceil(T/taus))
        if dim == 4:
            print(f'taus: {taus}, dim: {dim}, N: {N}')
        Mass_inv = jnp.eye(dim)
        target = gauss_ndimf_jax
        hamiltonian = gen_hamiltonian(Mass_inv, target)
        grad_target = jit(jax.grad(target))
        jit_H = jit(hamiltonian)
        gradH = jax.jit(jax.grad(hamiltonian))
        qp_init = jax.random.normal(key, shape=(2 * dim,))
        init_sample = [qp_init, 1, False]
        # chmc_sampler returns [states, deltaHs, accepted] stacked over keys_main.
        # chmc_samples[j] is replaced by the full (mainnum_samples, 2*dim) state
        # array and is overwritten for every tau, so only the last tau survives.
        chmc_samples[j], chmc_deltaHs[:,i, j], chmc_accepted[:,i, j] = chmc_sampler(init_sample, keys_main,  hamiltonian, taus, N, tol, maxIter)
        # initial_sample, keys, H, tau, N, tol, maxIter, solve=jnp.linalg.solve

In [None]:
# Summary statistics of CHMC Delta H and the accept count for every (tau, dim) pair.
# Labels come from the tau_set / dims actually sampled, not reconstructed formulas.
print(f'(num samples, num tau, numdims)\n {chmc_deltaHs.shape}')
for i, tau_i in enumerate(tau_set):
    for j, dim_j in enumerate(dims):
        dH = chmc_deltaHs[:, i, j]
        n_acc = chmc_accepted[:, i, j].sum()
        print(f'tau: {tau_i:.3f}, dim = {dim_j}\n min: {dH.min():.6f} max: {dH.max():.6f} mean: {dH.mean():.7f} std: {dH.std():.6f} #Accepts: {n_acc}')

In [None]:
import matplotlib.pyplot as plt
plt.rcParams["text.usetex"] = True
# squeeze=False keeps axs 2-D even when tau_set has a single entry.
fig, axs = plt.subplots(len(tau_set), squeeze=False)
fig.set_size_inches(8, 20)

# index: (mainnum_samples, numtaus, ndims)
# Titles and labels come from the tau_set / dims actually passed to chmc_sampler.
for i, tau_i in enumerate(tau_set):
    ax = axs[i, 0]
    for j, dim_j in enumerate(dims):
        ax.hist(chmc_deltaHs[:, i, j], bins=10, density=True, label=f'{dim_j}-dim')
    ax.set_xlim(-0.5, 0.5)
    ax.set_ylabel('Density')
    n_steps = int(jnp.ceil(T / tau_i))
    ax.set_title(rf'$\tau$: {tau_i:.3f}, T={T}  $\Longrightarrow$   N = {n_steps}')
    ax.legend()
fig.suptitle(r'Histogram of CHMC (implicit midpoint): $\Delta H$', y=0.91)


In [None]:
# States stored for dims[0]: (mainnum_samples, 2 * dims[0]) — q columns, then p columns.
chmc_samples[0].shape


In [None]:
# Marginal of column 2 of the dims[0] states from the last tau in the sweep.
# The first dims[0] columns are q (see qex), so with dims[0] = 4 this is q[2].
plt.hist(chmc_samples[0][:,2],bins=50,density=True)

In [None]:
# Joint histogram of state columns 2 and 1 (q[2] vs q[1] when dims[0] = 4).
plt.hist2d(chmc_samples[0][:,2], chmc_samples[0][:,1],bins=25)

In [None]:
# Joint histogram of state columns 0 and 1 (q[0] vs q[1] when dims[0] = 4).
plt.hist2d(chmc_samples[0][:,0], chmc_samples[0][:,1],bins=25)

In [None]:
def scatter_hist(x, y, ax, ax_histx, ax_histy):
    """
    Draw a scatter plot of (x, y) on `ax` with marginal histograms.

    The histogram of x is drawn on `ax_histx` and the histogram of y on
    `ax_histy` (horizontal). Both use the same symmetric bins of width 0.25
    spanning [-lim, lim], where lim is the smallest multiple of 0.25 strictly
    greater than max(|x|, |y|).
    """
    # no labels
    # Hide tick labels on the marginal axes that face the main scatter plot.
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)

    # the scatter plot:
    ax.scatter(x, y)

    # now determine nice limits by hand:
    binwidth = 0.25
    xymax = max(np.max(np.abs(x)), np.max(np.abs(y)))
    lim = (int(xymax/binwidth) + 1) * binwidth

    # Shared bin edges so both marginals are directly comparable.
    bins = np.arange(-lim, lim + binwidth, binwidth)
    ax_histx.hist(x, bins=bins)
    ax_histy.hist(y, bins=bins, orientation='horizontal')