In [None]:
"""
Description:
    CHMC Implementation with AVF: FPI
    USE THE CORRECT ENVIRONMENT:  CHMC_FALL_2025

Author: John Gallagher
Created: 2025-09-28
Last Modified: 2025-09-28
Version: 1.0.0

"""
import numpy as np
import matplotlib.pyplot as plt
import distributions as dist
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
import time


@jit
def draw_p(carry_in, key1):
    """Refresh the momentum of a ``[q, p]`` state.

    The incoming momentum is discarded and replaced with a standard-normal
    draw shaped like ``q``. Returns ``(new_carry, None)``, mirroring the
    ``(carry, output)`` signature of a ``lax.scan`` step.
    """
    position, _stale_momentum = carry_in
    fresh_momentum = jax.random.normal(key1, position.shape)
    return [position, fresh_momentum], None

@jit
def leapfrog(q, p):
    """Integrate Hamilton's equations for ``T`` steps of size ``tau``.

    Each step is a half drift in q (along dH/dp), a full kick in p (against
    dH/dq at the midpoint), then a second half drift in q using the new p.
    ``tau``, ``T``, ``jit_gradH_p`` and ``jit_gradH_q`` are read from module
    scope when this jitted function is traced.

    Returns:
        (q, p) after ``T`` steps.
    """
    half_tau = 0.5 * tau

    def drift_kick_drift(state, _unused):
        pos, mom = state
        pos_mid = pos + half_tau * jit_gradH_p(pos, mom)
        mom_next = mom - tau * jit_gradH_q(pos_mid, mom)
        pos_next = pos_mid + half_tau * jit_gradH_p(pos_mid, mom_next)
        return [pos_next, mom_next], None

    final_state, _ = jax.lax.scan(drift_kick_drift, [q, p], xs=None, length=T)
    return final_state[0], final_state[1]

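# Sketch (currently unused): fixed-point iteration x <- f(x), stopping once
# every component changes by less than `tol` or after `max_iters` steps.
# Presumably intended for the "AVF: FPI" integrator named in the module docstring.
# def fixed_point(f, x0, tol=1e-3, max_iters=20):
#     def cond_fun(carry):
#         x, i, converged = carry
#         return jnp.logical_and(i < max_iters, jnp.logical_not(converged))
#
#     def body_fun(val):
#         x, i, converged = val
#         x_new = f(x)
#         # while_loop needs a scalar predicate, so reduce the element-wise test.
#         converged = jnp.all(jnp.abs(x_new - x) < tol)
#         return (x_new, i + 1, converged)
#
#     x_final, iters, converged = jax.lax.while_loop(
#         cond_fun, body_fun, (x0, 0, False))
#     return x_final, iters, converged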
@jit
def accept(key, delta):
    """Metropolis accept test for an energy change ``delta`` (H_new - H_old).

    Accepts with probability ``min(1, exp(-delta))`` by comparing against a
    single uniform draw from ``key``. ``jnp.minimum`` propagates NaN, and any
    comparison with NaN is False, so a NaN ``delta`` is always rejected.

    Returns:
        Scalar boolean array; True means accept.
    """
    acceptance_prob = jnp.minimum(jnp.exp(-delta), 1.)
    threshold_draw = jax.random.uniform(key, shape=())
    return acceptance_prob >= threshold_draw

def hmc_kernel(key, last_sample, integrator=None, hamiltonian_fn=None, accept_fn=None):
    """One HMC transition: refresh momentum, integrate, Metropolis accept/reject.

    Args:
        key: PRNG key consumed by this transition.
        last_sample: ``[q, p]`` from the previous transition. Only ``q`` is
            used; ``p`` is redrawn by ``draw_p``.
        integrator: ``(q, p) -> (q*, p*)``. Defaults to module-level ``jit_integrator``.
        hamiltonian_fn: ``(q, p) -> H``. Defaults to module-level ``jit_H``.
        accept_fn: ``(key, delta_H) -> bool``. Defaults to module-level ``accept``.

    Returns:
        ``(next_key, [q_out, p_out])``. The pair is the proposal when it is
        accepted; otherwise it is the (q, freshly drawn p) the proposal started from.
    """
    # Defaults keep the original two-argument call working unchanged.
    integrator = jit_integrator if integrator is None else integrator
    hamiltonian_fn = jit_H if hamiltonian_fn is None else hamiltonian_fn
    accept_fn = accept if accept_fn is None else accept_fn

    key0, key1 = jax.random.split(key)
    key2, key3 = jax.random.split(key0)
    last_sample, _ = draw_p(last_sample, key1)
    q, p = last_sample
    q_star, p_star = integrator(q, p)
    deltaH = hamiltonian_fn(q_star, p_star) - hamiltonian_fn(q, p)
    is_accepted = accept_fn(key2, deltaH)
    # Select q and p together so a rejected step never pairs the old position
    # with the rejected proposal's momentum.
    q_out = jnp.where(is_accepted, q_star, q)
    p_out = jnp.where(is_accepted, p_star, p)
    return key3, [q_out, p_out]

def hmc_sampler(key, initial_sample, leapfrog, jit_H, accept, num_samples):
    """Run ``num_samples`` HMC transitions with ``jax.lax.scan``.

    Args:
        key: initial PRNG key.
        initial_sample: ``[q0, p0]``. ``p0`` is discarded by the first momentum refresh.
        leapfrog: integrator ``(q, p) -> (q*, p*)`` used for every proposal.
        jit_H: Hamiltonian ``(q, p) -> H`` used in the accept test.
        accept: Metropolis test ``(key, delta_H) -> bool``.
        num_samples: chain length. It is the ``lax.scan`` length, so it must be
            a Python int (static under ``jax.jit``).

    Returns:
        ``[qs, ps]``, each stacked with a leading axis of length ``num_samples``.
    """
    def hmc_step(carry, _):
        step_key, last_sample = carry
        # Forward the caller's callables; previously they were ignored and
        # the kernel always used the module-level globals.
        step_key, next_sample = hmc_kernel(
            step_key, last_sample,
            integrator=leapfrog, hamiltonian_fn=jit_H, accept_fn=accept)
        return (step_key, next_sample), next_sample

    _, samples = jax.lax.scan(hmc_step, (key, initial_sample), xs=None, length=num_samples)
    return samples





# Function handles into mechanics of HMC Sampler:
def hamiltionian(q, p):
    """Total energy H(q, p) = kinetic(p) + potential(q).

    Kinetic term: 0.5 * p^T M^{-1} p, using module-level ``Mass_inv``.
    Potential term: -log(target(q)), using module-level ``target``.
    Both globals are looked up when the function is called or traced.

    NOTE(review): the potential takes the log of a density *value*. In high
    dimension that value can underflow to 0, giving H = inf and NaN gradients.
    A log-density target would be more robust; confirm what ``target`` returns.
    """
    kinetic = 0.5 * (p @ Mass_inv @ p)
    potential = -jnp.log(target(q))
    return kinetic + potential

# NOTE(review): `hamiltionian` evaluates -log(target(q)) from a density value.
# If gauss_ndimf_jax returns an (un)normalised Gaussian density, that value
# underflows to 0 in JAX's default float32 for DIM = 1000 (|q|^2 is ~1000 for
# q ~ N(0, I)), which makes H infinite and its gradient NaN. Prefer a
# log-density target; confirm against distributions.py.
target = dist.gauss_ndimf_jax
grad_target = jax.grad(target)
gradH_p = jax.grad(hamiltionian, argnums=1)
gradH_q = jax.grad(hamiltionian, argnums=0)
jit_H = jax.jit(hamiltionian)
jit_target = jax.jit(target)
jit_grad_target = jax.jit(grad_target)
jit_gradH_p = jax.jit(gradH_p)
jit_gradH_q = jax.jit(gradH_q)
jit_integrator = jax.jit(leapfrog)

# --- Configuration ------------------------------------------------------------
DIM = 1000          # dimension of q and p
NUM_SAMPLES = 1000  # chain length for the timed run
tau = 0.2           # leapfrog step size
T = 5               # leapfrog steps per proposal
# `leapfrog` is jitted and reads tau/T as globals, so their values are baked in
# when it is first traced; later changes do not affect already-compiled calls.

key = jax.random.PRNGKey(0)
# Each consumer gets its own child key; the parent `key` is not used again
# after being split (JAX PRNG keys must not be reused).
keyq, keyp, key0 = jax.random.split(key, 3)
q0a = jax.random.normal(keyq, shape=(DIM,))
p0b = jax.random.normal(keyp, shape=(DIM,))
Mass_inv = jnp.eye(len(q0a))

# Jit the whole chain. The callables (args 2-4) and num_samples (arg 5) are
# static: the callables are hashable Python objects, and num_samples fixes the
# lax.scan length, so each distinct value compiles a separate executable.
hmc_sampler_jit = jax.jit(hmc_sampler, static_argnums=(2, 3, 4, 5))

# compile: warm up with the SAME static arguments as the timed run so that the
# timed call hits the compilation cache. JAX dispatch is asynchronous, so block
# until the result exists before stopping the clock.
start = time.time()
warmup_samples = jax.block_until_ready(
    hmc_sampler_jit(key0, [q0a, p0b], leapfrog, jit_H, accept, NUM_SAMPLES))
end = time.time()
print("1st run (compile + execute):", end - start)
# main run
start = time.time()
samples = jax.block_until_ready(
    hmc_sampler_jit(key0, [q0a, p0b], leapfrog, jit_H, accept, NUM_SAMPLES))
end = time.time()
print(f"{NUM_SAMPLES} runs: ", end - start, "\n 1 run", (end - start) / NUM_SAMPLES)


In [None]:
import numpy as np
def f(x):
    """Cosine map x -> cos(x), via NumPy (accepts scalars or arrays)."""
    cosine_of_x = np.cos(x)
    return cosine_of_x


In [None]:
# First iterate of the cosine map from x0 = 1, i.e. cos(1).
f(x=1)

In [None]:
# Fifth iterate of the cosine map from x0 = 1: cos(cos(1)) followed by three
# more applications of cos. It is computed from x0 instead of from a
# hand-copied iterate, so this cell does not depend on another cell's output.
f(f(f(f(f(1)))))

In [None]:
# Second iterate of the cosine map from x0 = 1, i.e. cos(cos(1)). Computed
# directly rather than from the hand-copied literal cos(1) = 0.5403023058681398.
f(f(1))