In [None]:
import numpy as np
import matplotlib.pyplot as plt
import distributions as dist
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
import time
jax.config.update("jax_enable_x64", True)

In [None]:
# Setting up the Hamiltonian
# `target` is the function whose negative logarithm is used as the potential
# energy term of the Hamiltonian defined below.
# NOTE(review): presumably an n-dimensional Gaussian density from the local
# `distributions` module -- confirm its parameters there.
target = dist.gauss_ndimf_jax

def hamiltionian(q, p):
    """Return H(q, p) = 0.5 * p^T Mass_inv p - log(target(q)).

    Args:
        q: position vector, passed to `target`.
        p: momentum vector, contracted with the module-level `Mass_inv`.

    Returns:
        The scalar value of the Hamiltonian (kinetic plus potential term).
    """
    kinetic = 0.5 * ((p @ Mass_inv) @ p)
    potential = -jnp.log(target(q))
    return kinetic + potential
def xhamiltionian(x):
    """Hamiltonian of a stacked phase-space vector ``x = [q, p]``.

    The first half of `x` is the position and the second half the momentum.
    This is the ordering used everywhere else in the notebook
    (``x0 = concat([q0a, p0b])``, `extract_q` / `extract_p`) and the one for
    which ``J_simplec @ grad(xhamiltionian)`` yields Hamilton's equations
    dq/dt = dH/dp, dp/dt = -dH/dq.  Unpacking in the opposite order only
    gives the same numbers when H happens to be symmetric in q and p.

    Args:
        x: 1-D array of even length ``2 * n``.

    Returns:
        The scalar ``hamiltionian(q, p)``.
    """
    # Split on the input's own length instead of the global `dim`, so the
    # function cannot silently disagree with the vector it is given.
    half = x.shape[0] // 2
    q, p = x[:half], x[half:]
    return hamiltionian(q, p)
# Setting up the gradients and Jacobian

# Target function: raw gradient plus compiled versions.
grad_target = jax.grad(target)
jit_target = jax.jit(target)
jit_grad_target = jax.jit(grad_target)

# Hamiltonian H(q, p): partial derivatives w.r.t. q (argument 0) and
# p (argument 1), followed by the compiled counterparts.
gradH_q = jax.grad(hamiltionian, argnums=0)
gradH_p = jax.grad(hamiltionian, argnums=1)
jit_H = jax.jit(hamiltionian)
jit_gradH_q = jax.jit(gradH_q)
jit_gradH_p = jax.jit(gradH_p)

# Gradient of the Hamiltonian written as a function of one stacked vector.
grad_xH = jax.grad(xhamiltionian)

# Phase-space dimension.  It must be defined before the block matrix below is
# built; the configuration block further down this cell assigns it too late
# for a fresh kernel (NameError on `dim`).  Keep this value in sync with that
# later assignment.
dim = 100

# Canonical symplectic matrix J = [[0, I], [-I, 0]].  For a state stacked as
# x = [q, p], dx/dt = J @ grad H(x) reads dq/dt = dH/dp, dp/dt = -dH/dq.
zero_zero = jnp.zeros((dim, dim))
zero_one = jnp.eye(dim)
one_zero = -jnp.eye(dim)
one_one = zero_zero
J_simplec = jnp.block([[zero_zero, zero_one],
                       [one_zero, one_one]])

def fpi_dense(qp):
    """One implicit-midpoint step of size `tau`, solved by Newton iteration.

    Solves ``F(y) = y - G(y) = 0`` with
    ``G(y) = x0 + tau * J_simplec @ grad_xH((x0 + y) / 2)`` using the dense
    Jacobian of F, starting from ``y = x0``.

    Args:
        qp: stacked phase-space vector ``x0`` (1-D array).

    Returns:
        The stacked phase-space vector after the step.  Iteration stops when
        ``||F(y)|| <= tol`` or after `max_iter` Newton updates.

    Reads the module-level `tau`, `tol`, `max_iter`, `J_simplec`, `grad_xH`.
    """
    x0 = qp

    def G(y):
        midpoint = 0.5 * (x0 + y)
        return x0 + tau * J_simplec @ grad_xH(midpoint)

    def F(y):
        return y - G(y)

    jacF = jax.jacobian(F)

    def newton_update(y):
        return y - jnp.linalg.solve(jacF(y), F(y))

    def cond(carry):
        y, i = carry
        # Convergence is judged on the residual, not on the size of the
        # iterate itself.  `&` is required: Python `and` cannot be applied to
        # the traced values inside lax.while_loop.
        return (jnp.linalg.norm(F(y)) > tol) & (i < max_iter)

    def body_step(carry):
        y, i = carry
        return newton_update(y), i + 1

    # A data-dependent stopping rule needs while_loop(cond, body, init);
    # lax.scan has a different signature and a fixed trip count.
    qp_out, _ = jax.lax.while_loop(cond, body_step, (qp, 0))
    return qp_out

# Setting up the integrators
@jit
def leapfrog(q, p):
    """Advance (q, p) by `T` leapfrog steps of size `tau`.

    Each step is a half step in position, a full step in momentum and a
    second half step in position, using the jitted partial derivatives of
    the Hamiltonian.

    Args:
        q: position vector.
        p: momentum vector.

    Returns:
        Tuple ``(q_final, p_final)`` after `T` steps.
    """
    def one_step(state, _unused):
        pos, mom = state
        pos_mid = pos + 0.5 * tau * jit_gradH_p(pos, mom)
        mom_next = mom - tau * jit_gradH_q(pos_mid, mom)
        pos_next = pos_mid + 0.5 * tau * jit_gradH_p(pos_mid, mom_next)
        return [pos_next, mom_next], None

    final_state, _ = jax.lax.scan(one_step, [q, p], xs=None, length=T)
    q_final, p_final = final_state
    return q_final, p_final

# Small helpers for packing / unpacking a stacked [q, p] vector.
qpcat = jit(jnp.concat)


def extract_q(qp):
    """Return the first `dim` entries of `qp` (the q part of a [q, p] stack)."""
    return qp[:dim]


def extract_p(qp):
    """Return the entries of `qp` from index `dim` onward (the p part)."""
    return qp[dim:]


jit_q = jit(extract_q)
jit_p = jit(extract_p)

@jit
def midpointFPI_dense_jac(q, p):
    """One implicit-midpoint step of size `tau`, solved by dense Newton iteration.

    The new state (q1, p1) is the root of the residual

        res_q = q1 - q - tau * dH/dp(mid)
        res_p = p1 - p + tau * dH/dq(mid),   mid = ((q1 + q) / 2, (p1 + p) / 2)

    which is the midpoint discretisation of dq/dt = dH/dp, dp/dt = -dH/dq and
    therefore uses the same sign convention as `leapfrog`.

    Args:
        q: initial position, 1-D array.
        p: initial momentum, 1-D array of the same length.

    Returns:
        ``[q1, p1]``.  Iteration stops when the residual norm is <= `tol` or
        after `max_iter` Newton updates.

    Reads the module-level `tau`, `tol`, `max_iter`, `gradH_p`, `gradH_q`.
    """
    n = q.shape[0]
    qp0 = jnp.concat([q, p])

    def F(qp):
        q_new, p_new = qp[:n], qp[n:]
        midq, midp = 0.5 * (q_new + q), 0.5 * (p_new + p)
        res_q = q_new - q - tau * gradH_p(midq, midp)
        res_p = p_new - p + tau * gradH_q(midq, midp)
        return jnp.concat([res_q, res_p])

    jacF = jax.jacobian(F)

    def cond(carry):
        i, qp = carry
        return (jnp.linalg.norm(F(qp)) > tol) & (i < max_iter)

    def body_step(carry):
        i, qp = carry
        # Newton update from the *current* iterate; restarting from qp0 on
        # every pass does not have the root of F as a fixed point.
        return i + 1, qp - jnp.linalg.solve(jacF(qp), F(qp))

    # No inner @jit needed: the outer @jit and while_loop already trace these.
    _, qp_out = jax.lax.while_loop(cond, body_step, (0, qp0))
    return [qp_out[:n], qp_out[n:]]


# Handle for the integrator in use.  `leapfrog` is already decorated with @jit
# where it is defined, so this wraps it in jax.jit a second time.
jit_integrator = jax.jit(leapfrog)
# jit_integrator = jit(midpointFPI)

# --- Problem configuration ---------------------------------------------------
dim = 100        # length of the position and momentum vectors
tau = 0.002      # integrator step size
T = 1            # number of leapfrog steps per call
tol = 1e-7       # residual tolerance for the implicit solvers
max_iter = 1000  # iteration cap for the implicit solvers

# --- Reproducible initial state ----------------------------------------------
key = jax.random.PRNGKey(0)
keyq, keyp, key0 = jax.random.split(key, 3)
q0a = jax.random.normal(keyq, shape=(dim,))
p0b = jax.random.normal(keyp, shape=(dim,))

# Inverse mass matrix: identity of size dim.
Mass_inv = jnp.eye(dim)
def midpointFPI_blocked(q, p):
    """One implicit-midpoint step of size `tau` with a block-diagonal Newton solve.

    Same residual as the dense solver,

        res_q = q1 - q0 - tau * dH/dp(mid)
        res_p = p1 - p0 + tau * dH/dq(mid),

    but each update only uses the diagonal Jacobian blocks d res_q / dq and
    d res_p / dp, so q and p are corrected by two separate linear solves.

    Args:
        q: initial position, 1-D array.
        p: initial momentum, 1-D array.

    Returns:
        ``[q1, p1]``.  Iteration stops when ``||res_q|| + ||res_p|| <= tol``
        or after `max_iter` updates.

    Reads the module-level `tau`, `tol`, `max_iter`, `gradH_p`, `gradH_q`.
    """
    q0, p0 = q, p

    def F(q, p):
        midq, midp = 0.5 * (q + q0), 0.5 * (p + p0)
        res_q = q - q0 - tau * gradH_p(midq, midp)
        # dp/dt = -dH/dq, hence the plus sign (matches `leapfrog`).
        res_p = p - p0 + tau * gradH_q(midq, midp)
        return [res_q, res_p]

    def newton_step(q, p):
        jacF_q = jax.jacobian(F, argnums=0)(q, p)[0]  # d res_q / dq
        jacF_p = jax.jacobian(F, argnums=1)(q, p)[1]  # d res_p / dp
        F_q, F_p = F(q, p)
        # Correct the current iterate; stepping from (q0, p0) every time
        # makes the iteration oscillate instead of converging.
        qout = q - jnp.linalg.solve(jacF_q, F_q)
        pout = p - jnp.linalg.solve(jacF_p, F_p)
        return [qout, pout]

    def cond(carry):
        i, [q, p] = carry
        res_q, res_p = F(q, p)
        err = jnp.linalg.norm(res_q) + jnp.linalg.norm(res_p)
        return (err > tol) & (i < max_iter)

    def body_step(carry):
        i, [q, p] = carry
        return i + 1, newton_step(q, p)

    _, [qout, pout] = jax.lax.while_loop(cond, body_step, (0, [q, p]))
    return [qout, pout]
# Run every integrator once from the same initial state and print its result.
for integrator in (midpointFPI_blocked, midpointFPI_dense_jac, leapfrog):
    print(integrator(q0a, p0b))

# # compile
# start=time.time()
# sample_lf = hmc_sampler(key, [q0a, p0b], 1)
# end = time.time()
# print("1st run:", end-start)
# # main run
# start = time.time()
# samples_lf = hmc_sampler(key, [q0a, p0b], 1000)
# end = time.time()
# print("1000 runs: ", end - start, "\n 1 run", (end-start)/1000)

In [None]:
# Initial phase-space point, stacked as [q, p].
x0 = jnp.concat([q0a, p0b])


def G(y):
    """Implicit-midpoint map: x0 + tau * J_simplec @ grad_xH at the midpoint of x0 and y."""
    mid = 0.5 * (x0 + y)
    return x0 + tau * J_simplec @ grad_xH(mid)
def F(y):
    """Residual of the fixed-point problem y = G(y)."""
    image = G(y)
    return y - image


# Dense Jacobian dF/dy, used by the Newton update.
jacF = jax.jacobian(F)

def fpi_step(y):
    """Apply one Newton update to F: y - jacF(y)^{-1} F(y)."""
    correction = jnp.linalg.solve(jacF(y), F(y))
    return y - correction

def cond(carry):
    """Loop predicate: keep iterating while F is not solved and the cap is not hit.

    Args:
        carry: pair ``(y, i)`` of the current iterate and the iteration count.

    Returns:
        Boolean (array) that is True while ``||F(y)|| > tol`` and
        ``i < max_iter``.
    """
    y, i = carry
    # Test the residual F(y), not the size of y itself, and combine with `&`
    # so the predicate also works on traced values (e.g. in lax.while_loop).
    return (jnp.linalg.norm(F(y)) > tol) & (i < max_iter)
@jit
def body_step(carry):
    """Loop body: one `fpi_step` on the iterate, iteration counter incremented by one."""
    current, count = carry
    return fpi_step(current), count + 1

def _newton_iterates(start, count):
    """Yield `count` successive `fpi_step` iterates, beginning from `start`."""
    y = start
    for _ in range(count):
        y = fpi_step(y)
        yield y


# Seven successive Newton updates starting from x0.
y1, y2, y3, y4, y5, y6, y7 = _newton_iterates(x0, 7)

In [None]:
# Result of each integrator from the same start, flattened into one vector.
qplf, qp_blocked, qp_dense = (
    jnp.concat(integrate(q0a, p0b))
    for integrate in (leapfrog, midpointFPI_blocked, midpointFPI_dense_jac)
)


In [None]:
# Pairwise Euclidean distances between the different integrator outputs.
for lhs, rhs in (
    (y7, qplf),
    (qp_blocked, qplf),
    (qp_blocked, y7),
    (qp_blocked, qp_dense),
    (qp_dense, qplf),
):
    print(jnp.linalg.norm(lhs - rhs))

In [None]:
# Length of the stacked leapfrog result.  `qpf` is not defined anywhere in the
# notebook; `qplf` (built from the leapfrog output above) is the closest name.
# NOTE(review): confirm this is the vector that was meant.
len(qplf)

In [None]:
def lf_step(carry_in, _):
    """Single leapfrog step with a debug print of the updated momentum.

    Same arithmetic as the step nested inside `leapfrog`: half step in
    position, full step in momentum, half step in position.

    Args:
        carry_in: pair ``[q, p]``.
        _: passed through unchanged as the second return value.

    Returns:
        ``([q_new, p_new], _)``.
    """
    pos, mom = carry_in
    pos_mid = pos + 0.5 * tau * jit_gradH_p(pos, mom)
    mom_next = mom - tau * jit_gradH_q(pos_mid, mom)
    print(mom_next)  # debugging aid: show the updated momentum
    pos_next = pos_mid + 0.5 * tau * jit_gradH_p(pos_mid, mom_next)
    return [pos_next, mom_next], _

In [None]:
# Display the initial position vector (shape (dim,), drawn with jax.random.normal).
q0a

In [None]:
# Sanity check: shape of dH/dp evaluated at the initial state.
gradH_p(q0a,p0b).shape

In [None]:
# Jitted concatenation, built the same way as `qpcat`.
cat = jit(jnp.concat)
# Stack the initial position and momentum into a single vector.
cat([q0a,p0b])

In [None]:
# Run the dense-Jacobian midpoint integrator from the initial state and display its output.
midpointFPI_dense_jac(q0a, p0b)

In [None]:
# Single debug leapfrog step.  The second argument is only passed through, so
# give it an explicit None instead of IPython's `_` (the previous cell output),
# whose value depends on which cell happened to run last.
lf_step([q0a, p0b], None)