In [None]:
"""
Description:
    CHMC Implementation with AVF: FPI
    USE THE CORRECT ENVIRONMENT:  CHMC_FALL_2025

Author: John Gallagher
Created: 2025-09-28
Last Modified: 2025-10-04
Version: 1.0.0

This version uses q,p separated. 

"""

import numpy as np
import matplotlib.pyplot as plt
import distributions as dist
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
import time
jax.config.update("jax_enable_x64", True)

def qex(qp):
    """Return the position half ``q`` of a packed state ``qp = [q, p]``.

    The split point is ``len(qp) // 2``, taken from the array itself rather
    than the module-level ``dim``. The result is identical whenever
    ``len(qp) == 2 * dim``, and it stays correct if a later cell changes the
    problem size without reassigning ``dim``.
    """
    return qp[: qp.shape[0] // 2]


def pex(qp):
    """Return the momentum half ``p`` of a packed state ``qp = [q, p]``.

    Uses the same ``len(qp) // 2`` split point as ``qex``.
    """
    return qp[qp.shape[0] // 2 :]


def draw_p(qp, key):
    """Refresh the momentum: keep the position half of ``qp`` and draw a new ``p``.

    The new momentum is a standard-normal vector of length ``dim`` (a
    module-level global).
    NOTE(review): this equals p ~ N(0, M) only when the mass matrix is the
    identity. Confirm this before using a non-identity ``Mass_inv``.

    Args:
        qp: packed state ``[q, p]``.
        key: PRNG key used for the momentum draw.

    Returns:
        ``(qp_new, None)``, a ``(carry, output)`` pair in ``lax.scan`` style.
    """
    position = qex(qp)
    fresh_momentum = jax.random.normal(key, shape=(dim,))
    refreshed = jnp.concatenate([position, fresh_momentum])
    return refreshed, None

def leapfrog(qp):
    """Integrate Hamilton's equations for ``T`` steps of size ``tau``.

    Each step does three things:
    - a half-step in q using dH/dp;
    - a full step in p using dH/dq, evaluated at the half-step position;
    - a second half-step in q using the updated momentum.

    ``T``, ``tau``, ``gradH_p`` and ``gradH_q`` are read from module globals.

    Args:
        qp: packed initial state ``[q, p]``.

    Returns:
        The packed state after ``T`` steps.
    """
    def one_step(state, unused):
        q0, p0 = qex(state), pex(state)
        half_dt = 0.5 * tau
        q_mid = q0 + half_dt * gradH_p(q0, p0)
        p1 = p0 - tau * gradH_q(q_mid, p0)
        q1 = q_mid + half_dt * gradH_p(q_mid, p1)
        return jnp.concatenate([q1, p1]), unused

    final_state, _ = jax.lax.scan(one_step, qp, xs=None, length=T)
    return final_state

@jit
def midpointFPI(qp):
    """One implicit-midpoint step of size ``tau``, solved with Newton's method.

    With ``x0 = qp``, solves ``y = G(y) = x0 + tau * J_H(grad_xH((x0 + y) / 2))``
    for ``y`` by driving the residual ``F(y) = y - G(y)`` to zero.

    Reads the module-level globals ``tau``, ``tol``, ``max_iter`` (plus
    ``J_H`` and ``grad_xH``) at trace time. Because of ``@jit``, their values
    are baked into the compiled function. Later reassignments only take effect
    when the function is retraced, for example when it is called with a new
    input shape.

    Args:
        qp: packed state ``[q, p]``.

    Returns:
        The Newton iterate once ``||F(y)|| <= tol``, or after ``max_iter``
        iterations, whichever comes first.
    """
    x0 = qp

    def G(y):
        midpoint = 0.5 * (x0 + y)
        return x0 + tau * J_H(grad_xH(midpoint))

    def F(y):
        return y - G(y)

    def newton_step(y):
        # Newton update from the *current* iterate:
        #   y_{k+1} = y_k - J_F(y_k)^{-1} F(y_k).
        # Updating from x0 instead only coincides with Newton on the first step.
        # NOTE: jax.jacobian materialises a dense (2*dim, 2*dim) matrix here.
        jacF = jax.jacobian(F)(y)
        return y - jnp.linalg.solve(jacF, F(y))

    def cond(carry):
        i, y = carry
        err = jnp.linalg.norm(F(y))
        return (err > tol) & (i < max_iter)

    def body_step(carry):
        i, y = carry
        return i + 1, newton_step(y)

    _, qp_out = jax.lax.while_loop(cond, body_step, (0, qp))
    return qp_out

@jit
def accept(delta, key):
    """Metropolis accept/reject test for an energy change ``delta``.

    Accepts with probability ``min(1, exp(-delta))``.

    Args:
        delta: energy difference ``H(proposal) - H(current)``.
        key: PRNG key for the uniform draw.

    Returns:
        Boolean scalar array that is True when the proposal is accepted.
    """
    acceptance_prob = jnp.minimum(1., jnp.exp(-delta))
    return jax.random.uniform(key, shape=()) <= acceptance_prob

def hmc_kernel(carry, key):
    """One Metropolis-adjusted HMC transition, written as a ``lax.scan`` body.

    Steps:
    - refresh the momentum (``draw_p``);
    - integrate with ``jit_integrator``;
    - accept or reject on the change in ``jit_H``.

    Args:
        carry: current packed state ``[q, p]``. Only its ``q`` half matters,
            because ``draw_p`` replaces the momentum.
        key: PRNG key for this transition.

    Returns:
        ``(qp_out, qp_out)``: the new state, used both as the scan carry and
        as the per-step output. On rejection this is the pre-integration state
        with the refreshed momentum.
    """
    # Independent subkeys for the momentum draw and the accept/reject uniform.
    # Reusing one key for both makes the two random draws dependent.
    key_p, key_u = jax.random.split(key)
    qp0, _ = draw_p(carry, key_p)
    qp_star = jit_integrator(qp0)
    deltaH = jit_H(qp_star) - jit_H(qp0)
    is_accepted = accept(deltaH, key_u)
    qp_out = jnp.where(is_accepted, qp_star, qp0)
    return qp_out, qp_out

def hmc_sampler(initial_sample, keys):
    """Run an HMC chain with one ``hmc_kernel`` transition per PRNG key.

    Args:
        initial_sample: packed starting state ``[q, p]``.
        keys: stacked PRNG keys. ``lax.scan`` iterates over the leading axis,
            so the chain length is ``keys.shape[0]``.

    Returns:
        The stacked post-transition states, one per key.
    """
    _final_state, chain = jax.lax.scan(hmc_kernel, initial_sample, xs=keys)
    return chain


# Function handles into mechanics of HMC Sampler:
def hamiltionian(q,p):
    """Hamiltonian ``H(q, p) = 0.5 * p^T Mass_inv p - log(target(q))``.

    ``Mass_inv`` and ``target`` are module-level globals.
    NOTE(review): if ``target`` returns a normalised density (not a log
    density), it can underflow to 0 in high dimension. ``log`` would then give
    -inf and the gradients NaN. Confirm what ``dist.gauss_ndimf_jax`` returns;
    a log-density would be safer.

    Args:
        q: position vector.
        p: momentum vector.

    Returns:
        Scalar energy.
    """
    kinetic = 0.5 * (p @ Mass_inv @ p)
    log_density = jnp.log(target(q))
    return kinetic - log_density
def xhamiltionian(qp):
    """Hamiltonian evaluated on a packed state ``qp = [q, p]``.

    Unpacks with ``qex``/``pex`` and defers to ``hamiltionian(q, p)``, so
    ``jax.grad(xhamiltionian)`` yields the stacked gradient ``[dH/dq, dH/dp]``.
    """
    return hamiltionian(qex(qp), pex(qp))

def J_H(gH):
    """Apply the canonical symplectic matrix ``J = [[0, I], [-I, 0]]`` to ``gH``.

    Equivalent to ``J @ gH`` without building ``J``: for ``gH = [g_q, g_p]``
    it returns ``[g_p, -g_q]``.

    The half-length is taken from ``gH`` itself instead of the module-level
    ``dim``. The result is identical whenever ``len(gH) == 2 * dim``, and it
    stays correct if ``dim`` is stale.

    Args:
        gH: 1-D gradient ``[dH/dq, dH/dp]`` of even length.

    Returns:
        ``[dH/dp, -dH/dq]``.

    Raises:
        ValueError: if ``gH`` has odd length. The shape is static, so this
            check also works under ``jit``.
    """
    size = gH.shape[0]
    if size % 2:
        raise ValueError(f"J_H expects an even-length vector, got length {size}")
    n = size // 2
    return jnp.concatenate([gH[n:], -gH[:n]])

# def J_simplec(n):
#     zero_zero = jnp.zeros((n,n))
#     zero_one = jnp.eye((n))
#     one_zero = -jnp.eye((n))
#     one_one = zero_zero
#     return jnp.block([[zero_zero, zero_one],[one_zero, one_one]])
# J_simplec is kept for reference only: rather than building the block matrix (I wasn't sure how to jit it), J_H applies it directly by slicing.
# Wire the model into the sampler: the target density, plus the Hamiltonian
# derivatives that leapfrog / midpointFPI / hmc_kernel look up as globals.
target = dist.gauss_ndimf_jax
grad_target = jax.grad(target)
jit_target = jax.jit(target)
jit_grad_target = jax.jit(grad_target)

# Partial derivatives of H(q, p): argnums=0 -> dH/dq, argnums=1 -> dH/dp.
gradH_q = jax.grad(hamiltionian, argnums=0)
gradH_p = jax.grad(hamiltionian, argnums=1)
jit_gradH_q = jax.jit(gradH_q)
jit_gradH_p = jax.jit(gradH_p)

# Packed-state versions: H(qp) and its full gradient [dH/dq, dH/dp].
jit_H = jax.jit(xhamiltionian)
grad_xH = jax.jit(jax.grad(xhamiltionian))

# Integrator used by hmc_kernel. To use the implicit-midpoint solver instead:
# jit_integrator = jit(midpointFPI)
jit_integrator = jax.jit(leapfrog)

# Set parameters
key = jax.random.PRNGKey(1)
# Derive independent subkeys for the initial state and for the chain.
# A JAX key should not be reused after it has been split.
key_init, key_chain = jax.random.split(key)
dim = 10000             # dimension of q (and of p); the packed state has length 2*dim
num_samples = 1000      # number of HMC transitions (one PRNG key each)
keys = jax.random.split(key_chain, num_samples)
qp_init = jax.random.normal(key_init, shape=(2*dim,))
# NOTE: dense (dim, dim) matrix; with x64 enabled this is ~800 MB at dim = 10000.
Mass_inv = jnp.eye(dim)
tau = 0.02              # integrator step size
T = 1                   # leapfrog steps per proposal (scan length)
tol = 2e-4              # Newton residual tolerance in midpointFPI
max_iter = 100          # Newton iteration cap in midpointFPI

# First run. hmc_sampler is not jitted, so this call also includes tracing and
# compiling the scan body, not only the sampling itself.
start = time.time()
sample_LF = hmc_sampler(qp_init, keys)
# JAX dispatches work asynchronously. Wait for the result so the timer covers
# the computation and not just the dispatch.
sample_LF.block_until_ready()
end = time.time()
print("1st run:", end-start)

In [None]:
def midpointFPI(qp):
    """One implicit-midpoint step of size ``tau``, solved with Newton's method.

    NOTE(review): this redefines, and therefore shadows, the ``@jit``-decorated
    ``midpointFPI`` from the first cell. Consider keeping only one copy.

    With ``x0 = qp``, solves ``y = G(y) = x0 + tau * J_H(grad_xH((x0 + y) / 2))``
    for ``y`` by driving the residual ``F(y) = y - G(y)`` to zero. Reads the
    globals ``tau``, ``tol``, ``max_iter``, ``J_H`` and ``grad_xH``.

    Args:
        qp: packed state ``[q, p]``.

    Returns:
        The Newton iterate once ``||F(y)|| <= tol``, or after ``max_iter``
        iterations, whichever comes first.
    """
    x0 = qp

    def G(y):
        midpoint = 0.5 * (x0 + y)
        return x0 + tau * J_H(grad_xH(midpoint))

    def F(y):
        return y - G(y)

    def newton_step(y):
        # Newton update from the *current* iterate:
        #   y_{k+1} = y_k - J_F(y_k)^{-1} F(y_k).
        # Updating from x0 instead only coincides with Newton on the first step.
        # NOTE: jax.jacobian materialises a dense (2*dim, 2*dim) matrix here.
        jacF = jax.jacobian(F)(y)
        return y - jnp.linalg.solve(jacF, F(y))

    def cond(carry):
        i, y = carry
        err = jnp.linalg.norm(F(y))
        return (err > tol) & (i < max_iter)

    def body_step(carry):
        i, y = carry
        return i + 1, newton_step(y)

    _, qp_out = jax.lax.while_loop(cond, body_step, (0, qp))
    return qp_out

# x0 = qp_init
# def G(y):
#     midpoint = 0.5*(x0+y)
#     return x0 + tau*J_H(grad_xH(midpoint))

# def F(y):
#     return y-G(y)
# def newton_step(qp):
#     jacF = jax.jacobian(F)(qp)
#     qpout = x0 - jnp.linalg.solve(jacF, F(qp))
#     return qp
# def body_step(carry, _):
#     i, qp = carry
#     return [i + 1, newton_step(qp)], _

# [_ ,y1],_  = body_step([0,x0],_)
# [_, y2], _ = body_step([1,y1],_)
# [_, y3], _ = body_step([1,y2],_)

# Sanity check: one implicit-midpoint step vs. leapfrog from the same start
# point. Both use step size tau, but leapfrog runs T steps, so the comparison
# is only meaningful for T == 1.
# NOTE: midpointFPI's Newton solve builds a dense (2*dim, 2*dim) Jacobian,
# which is very large for big dim.
assert T == 1, "leapfrog runs T steps; compare against a single midpoint step only when T == 1"
x0 = qp_init
fpitest = midpointFPI(x0)
lftest = leapfrog(x0)
jnp.linalg.norm(fpitest - lftest)

In [None]:
# Smaller-scale benchmark of the leapfrog HMC sampler (dim = 1000).
# This cell rebinds module-level globals (dim, Mass_inv, tau, T, tol, ...)
# that the sampler functions read at trace time. Jitted functions only see the
# new values when they retrace; here that happens because the state shape
# changes.
target = dist.gauss_ndimf_jax
grad_target = jax.grad(target)
gradH_p = jax.grad(hamiltionian, argnums=1)
gradH_q = jax.grad(hamiltionian, argnums=0)
# hmc_kernel calls jit_H on the packed state, so wrap the packed-state
# Hamiltonian (hamiltionian itself takes q and p separately).
jit_H = jax.jit(xhamiltionian)
jit_target = jax.jit(target)
jit_grad_target = jax.jit(grad_target)
jit_gradH_p = jax.jit(gradH_p)
jit_gradH_q = jax.jit(gradH_q)
jit_integrator = jax.jit(leapfrog)
# jit_integrator = jit(midpointFPI)
key = jax.random.PRNGKey(0)
keyq, keyp, key0 = jax.random.split(key, 3)
q0a = jax.random.normal(keyq, shape=(1000,))
p0b = jax.random.normal(keyp, shape=(1000,))
# draw_p (and any helper that slices by dim) reads the global dim, so it must
# match the new problem size.
dim = len(q0a)
Mass_inv = jnp.eye(dim)
tau = 0.2
T = 5
tol = 1e-1
max_iter = 10
qp0 = jnp.concatenate([q0a, p0b])

# compile (a 1-transition chain, so the timing is dominated by tracing/compilation)
start = time.time()
sample_lf = hmc_sampler(qp0, jax.random.split(key0, 1))
sample_lf.block_until_ready()  # wait for async dispatch to finish
end = time.time()
print("1st run:", end-start)
# main run
start = time.time()
samples_lf = hmc_sampler(qp0, jax.random.split(key0, 1000))
samples_lf.block_until_ready()  # wait for async dispatch to finish
end = time.time()
print("1000 runs: ", end - start, "\n 1 run", (end-start)/1000)