In [None]:
"""
Module Name: integrators.py

Description:
    Collects distributions, Hamiltonians, and integrators (leapfrog and implicit midpoint) for use in numerical sampling and in analysing approximate convergence.

Author: John Gallagher
Created: 2025-06-03
Last Modified: 2025-09-16
Version: 1.0

To Do:
Incorporate grad information. 

"""
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import jit

# Float Config Check
def check_config():
    """Print whether JAX 64-bit (double) precision is currently enabled.

    Returns:
        None. Output goes to stdout.
    """
    x64_enabled = jax.config.jax_enable_x64
    # Module name matches the header of this file (was a stale copy-paste).
    print("In integrators.py:")
    print(f"64-bit precision enabled: {x64_enabled}")

"""
Hypothetical hamiltonian:
def hamiltonian(q, p, Mass_inv, target):
    return 0.5* p.dot(Mass_inv.dot(p)) - jnp.log(target(q))

q_dot = jax.grad(hamiltonian, argnums=1)
p_dot = jax.grad(hamiltonian, argnums=0)
    
"""


@jit
def qex(qp):
    """Return the position block q = qp[:dim] of a phase-space vector qp.

    Reads the module-level ``dim``. Under ``@jit`` the value is read while
    tracing, so it is fixed for each compiled input shape.
    """
    position_block = qp[:dim]
    return position_block
@jit
def pex(qp):
    """Return the momentum block p = qp[dim:] of a phase-space vector qp.

    Reads the module-level ``dim``. Under ``@jit`` the value is read while
    tracing, so it is fixed for each compiled input shape.
    """
    momentum_block = qp[dim:]
    return momentum_block

def hamiltonian(q, p):
    """Return the energy terms p**2 + q**2 of a quadratic Hamiltonian.

    The computation is elementwise: for array inputs the result has the
    broadcast shape of q and p. Sum it to get a scalar total energy, which
    jax.grad requires.
    """
    p_sq = p ** 2
    q_sq = q ** 2
    return p_sq + q_sq
def xhamiltonian(qp):
    """Return the scalar energy sum(p**2 + q**2) of a phase-space vector qp.

    qp is split into q = qex(qp) and p = pex(qp). The terms are summed so
    that jax.grad / jax.jacobian of this function receive a scalar output.
    """
    # Split the function argument (the old code read the not-yet-assigned
    # locals q and p and raised UnboundLocalError).
    q, p = qex(qp), pex(qp)
    return jnp.sum(p**2 + q**2)

def J_H(gH):
    """Apply the canonical symplectic matrix J = [[0, I], [-I, 0]] to gH.

    For gH = (dH/dq, dH/dp), stacked with the q-block first and split at the
    module-level ``dim``, this returns (dH/dp, -dH/dq).
    """
    upper = gH[:dim]
    lower = gH[dim:]
    return jnp.concatenate([lower, -upper])

# Gradients of the Hamiltonian.

# Full phase-space gradient of the scalar xhamiltonian (same shape as qp).
gradH = jit(jax.grad(xhamiltonian))
# NOTE(review): for a scalar-valued function jax.jacobian returns the same
# result as jax.grad, so jacH duplicates gradH. If a Hessian was intended,
# use jax.hessian(xhamiltonian) instead — confirm before changing.
jacH = jit(jax.jacobian(xhamiltonian))


def _total_hamiltonian(q, p):
    """Return the scalar energy sum(hamiltonian(q, p)); jax.grad needs a scalar."""
    # hamiltonian returns per-coordinate terms (p**2 + q**2) for array inputs;
    # the sum is a no-op when it already returns a scalar.
    return jnp.sum(hamiltonian(q, p))


# Partial gradients dH/dq and dH/dp taking (q, p) separately, as called by
# leapfrog_qp. (The previous commented-out versions differentiated the
# one-argument xhamiltonian with argnums=1, which is invalid.)
gradH_q = jit(jax.grad(_total_hamiltonian, argnums=0))
gradH_p = jit(jax.grad(_total_hamiltonian, argnums=1))
# Optional second derivatives, valid with the (q, p) signatures above:
# jacH_p = jit(jax.jacobian(gradH_p, argnums=1))
# jacH_q = jit(jax.jacobian(gradH_q, argnums=0))


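# NOTE: in this commented-out draft q is advanced with p_dot and p with q_dot,
# which is swapped relative to Hamilton's equations (dq/dt = dH/dp,
# dp/dt = -dH/dq). See leapfrog_qp below for the active implementation.
# def leapfrog(state, system_params, Ham_funcs, tau, T,  tol, maxIter):
#     "Symplectic integrator: Leapfrog"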
    
#     q, p = state
#     Mass_inv, target = system_params
#     q_dot, p_dot = Ham_funcs
#     q = q + 0.5*tau* p_dot(q, p, Mass_inv, target) # q_half
#     p = p - tau*q_dot(q, p, Mass_inv, target) # p_full
#     q = q + 0.5*tau* p_dot(q, p, Mass_inv, target) # q_full 
#     return [q, p]


def leapfrog_qp(qp):
    """
    Advance the phase-space state qp = [q, p] by T leapfrog steps of size tau.

    Each step is drift-kick-drift (position Verlet):
        q_mid  = q + (tau/2) * gradH_p(q, p)
        p_next = p - tau * gradH_q(q_mid, p)
        q_next = q_mid + (tau/2) * gradH_p(q_mid, p_next)

    Requires these module-level names:
    - qex, pex
    - tau
    - gradH_p
    - gradH_q
    - T

    Returns the concatenated [q, p] after the final step.
    """
    half_tau = 0.5 * tau

    def one_step(state, unused):
        q = qex(state)
        p = pex(state)
        q_mid = q + half_tau * gradH_p(q, p)
        p_next = p - tau * gradH_q(q_mid, p)
        q_next = q_mid + half_tau * gradH_p(q_mid, p_next)
        # No per-step outputs are collected; pass the (None) xs slot through.
        return jnp.concatenate([q_next, p_next]), unused

    final_state, _ = jax.lax.scan(one_step, qp, xs=None, length=T)
    return final_state

def midpointFPI_qp(qp):
    """
    Take one implicit-midpoint step of size tau from qp = [q, p].

    Solves y = x0 + tau * J_H(gradH((x0 + y) / 2)) for y by applying Newton's
    method to F(y) = y - G(y). Iteration stops once ||F(y)|| <= tol or after
    max_iter Newton steps.

    Requires these module-level names:
    - tau
    - gradH (gradient of the scalar phase-space Hamiltonian)
    - J_H
    - tol
    - max_iter

    Returns the final Newton iterate y.
    """
    x0 = qp

    def G(y):
        # Fixed-point map of the implicit midpoint rule.
        midpoint = 0.5 * (x0 + y)
        return x0 + tau * J_H(gradH(midpoint))

    def F(y):
        # Residual; the midpoint step is the root F(y) = 0.
        return y - G(y)

    def newton_step(y):
        # Newton update from the current iterate y (previously it restarted
        # from x0 every iteration, which is not Newton's method).
        jacF = jax.jacobian(F)(y)
        return y - jnp.linalg.solve(jacF, F(y))

    def cond(carry):
        i, y = carry
        err = jnp.linalg.norm(F(y))
        return (err > tol) & (i < max_iter)

    def body_step(carry):
        i, y = carry
        return (i + 1, newton_step(y))

    _, qp_out = jax.lax.while_loop(cond, body_step, (0, qp))
    return qp_out


# Experiment configuration: module-level globals read by the functions above.
qp0 = np.array([1., 2.])  # initial phase-space state; q and p each have length dim
tau = 0.2                 # integrator step size
T = 5                     # leapfrog steps per leapfrog_qp call
dim = 1                   # split index between q and p blocks (qex / pex / J_H)
# Newton-solver settings required by midpointFPI_qp.
tol = 1e-6                # residual-norm tolerance (kept above float32 roundoff)
max_iter = 50             # maximum Newton iterations per midpoint step


In [None]:
# Inspect dH/dq at the initial state; gradH_q takes (q, p) separately,
# matching how leapfrog_qp calls it.
gradH_q(qex(qp0), pex(qp0))

In [None]:
# State after T leapfrog steps of size tau, starting from qp0.
leapfrog_qp(qp=qp0)

In [None]:
# Repeatedly apply the T-step leapfrog map to build a 1000-state orbit;
# xy[k] is the state after k calls to leapfrog_qp.
xy = [qp0]
for i in range(1, 1000):
    qp = leapfrog_qp(xy[i - 1])
    xy.append(qp)  # was `xy.apped`, which raised AttributeError