In [None]:
"""
Description:
    CHMC Implementation with AVF: FPI
    USE THE CORRECT ENVIRONMENT:  CHMC_FALL_2025

Author: John Gallagher
Created: 2025-09-28
Last Modified: 2025-09-28
Version: 1.0.0

"""
import numpy as np
import matplotlib.pyplot as plt
import distributions as dist
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
import time


@jit
def draw_p(key1, carry_in):
    """Replace the momentum of a ``(q, p)`` pair with a fresh normal draw.

    Args:
        key1: JAX PRNG key consumed by the draw.
        carry_in: two-element sequence ``(q, p)``. Only ``q`` is used; the
            incoming ``p`` is discarded.

    Returns:
        A tuple ``([q, p_new], None)``. ``p_new`` is sampled with
        ``jax.random.normal`` and has the same shape as ``q``; the second
        element is always ``None``.
    """
    position, _stale_momentum = carry_in
    fresh_momentum = jax.random.normal(key1, position.shape)
    return [position, fresh_momentum], None

@jit
def leapfrog(q, p):
    """Advance ``(q, p)`` by ``T`` explicit steps of size ``tau``.

    Every step is a half update of ``q``, a full update of ``p`` evaluated at
    the half-updated ``q``, and a second half update of ``q``. The step size
    ``tau``, the step count ``T`` and the gradient handles ``jit_gradH_p`` /
    ``jit_gradH_q`` are module-level names looked up when the function is
    traced.

    Args:
        q: starting position.
        p: starting momentum.

    Returns:
        Tuple ``(q_final, p_final)`` after ``T`` steps.
    """
    half_tau = 0.5 * tau

    def _single_step(state, _unused):
        pos, mom = state
        pos_mid = pos + half_tau * jit_gradH_p(pos, mom)
        mom_next = mom - tau * jit_gradH_q(pos_mid, mom)
        pos_next = pos_mid + half_tau * jit_gradH_p(pos_mid, mom_next)
        return [pos_next, mom_next], None

    final_state, _ = jax.lax.scan(_single_step, [q, p], xs=None, length=T)
    q_final, p_final = final_state
    return q_final, p_final

@jit
def midpointFPI_blocked(q, p):
    """Take one implicit-midpoint step, solved with a block-diagonal Newton loop.

    The new state ``(q1, p1)`` is the root of

        F_q = q1 - q - tau * dH/dp(mid)
        F_p = p1 - p + tau * dH/dq(mid),     mid = ((q1 + q) / 2, (p1 + p) / 2)

    which uses the same sign convention as ``leapfrog`` (q moves along
    +dH/dp, p along -dH/dq). Only the diagonal Jacobian blocks dF_q/dq and
    dF_p/dp are used, so ``q`` and ``p`` are corrected by two independent
    linear solves per iteration.

    The step size ``tau``, the tolerance ``tol``, the iteration cap
    ``max_iter`` and the gradient handles ``gradH_p`` / ``gradH_q`` are
    module-level names looked up when the function is traced.

    Args:
        q: starting position (1-D array).
        p: starting momentum (1-D array).

    Returns:
        ``[q_out, p_out]``, the last iterate. The loop stops as soon as
        ``||F_q|| + ||F_p|| <= tol`` or after ``max_iter`` iterations,
        whichever comes first; convergence is not otherwise reported.
    """
    q0, p0 = q, p

    def F(q, p):
        # Residual of the implicit midpoint rule at the trial end point (q, p).
        midq, midp = 0.5 * (q + q0), 0.5 * (p + p0)
        res_q = q - q0 - tau * gradH_p(midq, midp)
        res_p = p - p0 + tau * gradH_q(midq, midp)
        return [res_q, res_p]

    def newton_step(q, p):
        # jax.jacobian returns one Jacobian per output of F; keep only the
        # diagonal blocks.
        jacF_q = jax.jacobian(F, argnums=0)(q, p)[0]  # dF_q/dq
        jacF_p = jax.jacobian(F, argnums=1)(q, p)[1]  # dF_p/dp
        F_q, F_p = F(q, p)
        # Correct the CURRENT iterate. Subtracting from (q0, p0) instead would
        # make a root of F map back to the starting point.
        qout = q - jnp.linalg.solve(jacF_q, F_q)
        pout = p - jnp.linalg.solve(jacF_p, F_p)
        return [qout, pout]

    def cond(carry):
        i, [q, p] = carry
        res_q, res_p = F(q, p)
        err = jnp.linalg.norm(res_q) + jnp.linalg.norm(res_p)
        return (err > tol) & (i < max_iter)

    def body_step(carry):
        i, [q, p] = carry
        return i + 1, newton_step(q, p)

    _, [qout, pout] = jax.lax.while_loop(cond, body_step, (0, [q, p]))
    return [qout, pout]

@jit
def accept(key, delta):
    """Draw a Metropolis accept/reject decision.

    Args:
        key: JAX PRNG key consumed by the uniform draw.
        delta: scalar compared through ``exp(-delta)``.

    Returns:
        Boolean scalar: ``True`` when a uniform draw on [0, 1) is less than or
        equal to ``min(1, exp(-delta))``.
    """
    ratio = jnp.exp(-delta)
    accept_prob = jnp.minimum(1., ratio)
    draw = jax.random.uniform(key, shape=())
    return draw <= accept_prob

def hmc_kernel(key, last_sample):
    """Run one HMC transition: resample momentum, integrate, accept or reject.

    Relies on the module-level ``draw_p``, ``jit_integrator``, ``jit_H`` and
    ``accept``.

    No wall-clock timing is done in here: ``hmc_sampler`` runs this function
    under ``jax.lax.scan``, where Python-level ``time.time()`` / ``print``
    calls execute while the function is being traced rather than once per
    sample, so they cannot measure the integration cost.

    Args:
        key: JAX PRNG key for this transition.
        last_sample: pair ``[q, p]``; ``p`` is replaced by a fresh draw before
            integrating.

    Returns:
        Tuple ``(next_key, [q_out, p_out])``. On acceptance the pair is the
        integrator output; on rejection it is the pre-integration state
        ``(q, freshly drawn p)``, so the returned pair is always one
        consistent phase-space point.
    """
    # Same split sequence as before, so the random stream is unchanged.
    key0, key1 = jax.random.split(key)
    key2, key3 = jax.random.split(key0)

    last_sample, _ = draw_p(key1, last_sample)
    q, p = last_sample

    q_star, p_star = jit_integrator(q, p)
    deltaH = jit_H(q_star, p_star) - jit_H(q, p)
    is_accepted = accept(key2, deltaH)

    q_out = jnp.where(is_accepted, q_star, q)
    p_out = jnp.where(is_accepted, p_star, p)
    return key3, [q_out, p_out]

def hmc_sampler(key, initial_sample, num_samples):
    """Generate ``num_samples`` successive states by scanning ``hmc_kernel``.

    Args:
        key: JAX PRNG key for the first transition; each transition hands the
            key it returns to the next one.
        initial_sample: starting state passed to the first ``hmc_kernel`` call.
        num_samples: number of transitions (the ``length`` of the scan).

    Returns:
        The per-transition states collected by ``jax.lax.scan``, i.e. every
        state returned by ``hmc_kernel`` stacked along a new leading axis.
    """
    def _advance(state, _unused):
        rng, current = state
        rng, following = hmc_kernel(rng, current)
        return [rng, following], following

    start_state = [key, initial_sample]
    _final_state, samples = jax.lax.scan(
        _advance, start_state, xs=None, length=num_samples
    )
    return samples


# Function handles into mechanics of HMC Sampler:
def hamiltionian(q, p):
    """Return the quadratic form in ``p`` minus the log of ``target(q)``.

    ``Mass_inv`` and ``target`` are module-level names resolved at call time.

    Args:
        q: position passed to ``target``.
        p: momentum used in ``0.5 * p @ Mass_inv @ p``.

    Returns:
        ``0.5 * (p @ Mass_inv @ p) - log(target(q))``.
    """
    kinetic = 0.5 * (p @ Mass_inv @ p)
    log_density = jnp.log(target(q))
    return kinetic - log_density
# Derivative and jitted handles are built once at module level; the
# integrators and the kernel look these names up when they are traced.
target = dist.gauss_ndimf_jax
grad_target = jax.grad(target)
gradH_p = jax.grad(hamiltionian, argnums=1)
gradH_q = jax.grad(hamiltionian, argnums=0)
jit_H = jax.jit(hamiltionian)
jit_target = jax.jit(target)
jit_grad_target = jax.jit(grad_target)
jit_gradH_p = jax.jit(gradH_p)
jit_gradH_q = jax.jit(gradH_q)

# Integrator used by hmc_kernel. Swap the two lines to benchmark leapfrog.
# jit_integrator = jax.jit(leapfrog)
# The midpoint integrator defined in this notebook is `midpointFPI_blocked`;
# a bare `midpointFPI` does not exist and raised NameError here.
jit_integrator = jit(midpointFPI_blocked)

key = jax.random.PRNGKey(0)
dim = 1000
# `key` is consumed by this split; only the sub-keys are used afterwards.
keyq, keyp, key0 = jax.random.split(key, 3)
q0a = jax.random.normal(keyq, shape=(dim,))
p0b = jax.random.normal(keyp, shape=(dim,))
Mass_inv = jnp.eye(len(q0a))
tau = 0.2       # integrator step size
T = 5           # leapfrog steps per proposal (read by `leapfrog`)
tol = 1e-1      # residual tolerance of the midpoint solver
max_iter = 10   # iteration cap of the midpoint solver

# compile
# JAX dispatches asynchronously, so block on the result before stopping the
# clock; otherwise only the dispatch time is measured.
# NOTE(review): the warm-up scans 1 step while the timed run scans 100; the
# scan is presumably compiled again for the new length, so the second timing
# likely still includes compilation - TODO confirm.
start = time.time()
sample = jax.block_until_ready(hmc_sampler(key0, [q0a, p0b], 1))
end = time.time()
print("1st run:", end-start)
# main run
start = time.time()
samples = jax.block_until_ready(hmc_sampler(key0, [q0a, p0b], 100))
end = time.time()
print("100 runs: ", end - start, "\n 1 run", (end-start)/100)


In [None]:
def midpointFPI_fullJac(q, p):
    """Take one implicit-midpoint step, solved by Newton with the full Jacobian.

    ``q`` and ``p`` are stacked into one vector ``qp`` of length ``2 * n`` and
    the root of

        F(qp) = [ q1 - q - tau * dH/dp(mid),
                  p1 - p + tau * dH/dq(mid) ],  mid = ((q1 + q) / 2, (p1 + p) / 2)

    is sought with ``qp <- qp - J(qp)^{-1} F(qp)``, where ``J`` is the dense
    ``(2n, 2n)`` Jacobian of ``F``. The sign convention matches ``leapfrog``
    (q moves along +dH/dp, p along -dH/dq).

    The step size ``tau``, the tolerance ``tol``, the iteration cap
    ``max_iter`` and the gradient handles ``gradH_p`` / ``gradH_q`` are
    module-level names.

    Args:
        q: starting position (1-D array of length ``n``).
        p: starting momentum (1-D array of length ``n``).

    Returns:
        ``[q_out, p_out]``, the last iterate. The loop stops as soon as
        ``||F_q|| + ||F_p|| <= tol`` or after ``max_iter`` iterations,
        whichever comes first; convergence is not otherwise reported.
    """
    n = q.shape[0]
    # jnp.concatenate takes a sequence of arrays; a second positional argument
    # would be interpreted as the axis.
    qp0 = jnp.concatenate([q, p])
    q_start, p_start = qp0[:n], qp0[n:]

    def F(qp):
        # Residual of the implicit midpoint rule at the stacked trial point.
        q_new, p_new = qp[:n], qp[n:]
        midq, midp = 0.5 * (q_new + q_start), 0.5 * (p_new + p_start)
        res_q = q_new - q_start - tau * gradH_p(midq, midp)
        res_p = p_new - p_start + tau * gradH_q(midq, midp)
        return jnp.concatenate([res_q, res_p])

    def newton_step(qp):
        jacF = jax.jacobian(F)(qp)
        # Correct the CURRENT iterate, not the starting point.
        return qp - jnp.linalg.solve(jacF, F(qp))

    def cond(carry):
        i, qp = carry
        res = F(qp)
        # Same error measure as the blocked variant: sum of the two half norms.
        err = jnp.linalg.norm(res[:n]) + jnp.linalg.norm(res[n:])
        return (err > tol) & (i < max_iter)

    def body_step(carry):
        i, qp = carry
        return i + 1, newton_step(qp)

    _, qp_out = jax.lax.while_loop(cond, body_step, (0, qp0))
    return [qp_out[:n], qp_out[n:]]

In [None]:
import time

import jax
import jax.random as jr

# Benchmark: draw `num_samples` standard-normal vectors of length
# `sample_dim`, one PRNG key per vector, with a jitted + vmapped sampler.
key = jr.key(0)
num_samples = 10000
sample_dim = 1000
jit_split = jax.jit(jr.split)

# The output shape is fixed inside the mapped function. Passing `shape=` as a
# keyword through vmap/jit does not work: vmap tries to map keyword arguments
# over axis 0, and the shape must be a static value rather than a traced one.
vmap_normal = jax.jit(jax.vmap(lambda subkey: jr.normal(subkey, (sample_dim,))))
subkeys = jr.split(key, num_samples)

# The first call includes tracing and compilation. block_until_ready() makes
# the timer wait for the asynchronously dispatched computation to finish.
start_time = time.time()
samples = vmap_normal(subkeys).block_until_ready()
end_time = time.time()
print(f"first-compiled and vectorized generation took: {end_time - start_time:.4f} seconds")

# Same function and same input shape, so this call reuses the compiled code.
start_time = time.time()
samples = vmap_normal(subkeys).block_until_ready()
end_time = time.time()
print(f"JIT-compiled and vectorized generation took: {end_time - start_time:.4f} seconds")


In [None]:
# Scratch check of basic slicing: the integers 0..5, then the first three.
a = np.arange(0, 6)
a[0:3]


In [None]:
# Leapfrog benchmark. The handles are rebuilt here so the cell can be re-run
# on its own after the definitions cell.
target = dist.gauss_ndimf_jax
grad_target = jax.grad(target)
gradH_p = jax.grad(hamiltionian, argnums=1)
gradH_q = jax.grad(hamiltionian, argnums=0)
jit_H = jax.jit(hamiltionian)
jit_target = jax.jit(target)
jit_grad_target = jax.jit(grad_target)
jit_gradH_p = jax.jit(gradH_p)
jit_gradH_q = jax.jit(gradH_q)

# Integrator used by hmc_kernel. Swap the two lines to benchmark the midpoint
# solver instead.
jit_integrator = jax.jit(leapfrog)
# jit_integrator = jit(midpointFPI_blocked)

key = jax.random.PRNGKey(0)
# `key` is consumed by this split; only the sub-keys are used afterwards.
keyq, keyp, key0 = jax.random.split(key, 3)
q0a = jax.random.normal(keyq, shape=(1000,))
p0b = jax.random.normal(keyp, shape=(1000,))
Mass_inv = jnp.eye(len(q0a))
tau = 0.2       # integrator step size
T = 5           # leapfrog steps per proposal
tol = 1e-1      # residual tolerance of the midpoint solver
max_iter = 10   # iteration cap of the midpoint solver

# compile
# JAX dispatches asynchronously, so block on the result before stopping the
# clock; otherwise only the dispatch time is measured.
# NOTE(review): the warm-up scans 1 step while the timed run scans 1000; the
# scan is presumably compiled again for the new length, so the second timing
# likely still includes compilation - TODO confirm.
start = time.time()
sample_lf = jax.block_until_ready(hmc_sampler(key0, [q0a, p0b], 1))
end = time.time()
print("1st run:", end-start)
# main run
start = time.time()
samples_lf = jax.block_until_ready(hmc_sampler(key0, [q0a, p0b], 1000))
end = time.time()
print("1000 runs: ", end - start, "\n 1 run", (end-start)/1000)

In [None]:
import numpy as np
def f(x):
    """Return the cosine of ``x`` (scalar or array-like), via ``np.cos``."""
    cosine = np.cos(x)
    return cosine


In [None]:
# Evaluate f at 1, i.e. cos(1).
f(1)

In [None]:
# Apply f three times in a row, starting from 0.8575532158463933
# (presumably a value pasted from an earlier f(...) output - TODO confirm).
f(f(f(0.8575532158463933)))

In [None]:
# Apply f once to 0.5403023058681398
# (presumably the displayed result of f(1) - TODO confirm).
f(0.5403023058681398)