In [1]:
# Feb 16, 2026
# import numpy as np
import jax
from jax import jit

import jax.random as jr
import jax.numpy as jnp
from datatypes import QP
from target import gen_perturb_precision
from hamiltonian import gaussian_hamiltonian
from sampler import hmc_sampler, chmc_sampler, extract_positions

import time

import numpy as np

In [2]:
# initial key
key = jax.random.PRNGKey(1)

# dimensions
dim = 4
Mass_inv = jnp.eye(dim)

target_mat = gen_perturb_precision(dim)
H = gaussian_hamiltonian(target_mat, mass_inv=Mass_inv)
H_flat = lambda qp_flat: H(QP.from_array(qp_flat))

# Set parameters
tau = 0.05
T = 1.0
N = int(jnp.ceil(T / tau))
tol = 1e-3
max_iter = 2


initnum_samples = 1
mainnum_samples = 1000
keys_init = jr.split(key, initnum_samples)
keys_main = jr.split(key, mainnum_samples)
qp_init = jr.normal(key, shape=(2 * dim,))

init_sample = [qp_init, 1, False]

# ========================================
# HMC
# ========================================

# compile
start = time.time()
jhmc_sampler = jit(hmc_sampler, static_argnums=(2, 3, 4))
sample_hmc = jhmc_sampler(init_sample, keys_init, H_flat, tau, N)
end = time.time()
print("1st run:", end - start)
# main run
start = time.time()
sample_hmc = jhmc_sampler(init_sample, keys_main,  H_flat, tau, N)
end = time.time()
print(
    f"Main run:\n {mainnum_samples} runs: {end - start:.2f} \n 1 run:  {(end-start)/mainnum_samples}"
)

1st run: 0.16547870635986328
Main run:
 1000 runs: 0.16 
 1 run:  0.00015618419647216796


In [3]:
# ========================================
# CHMC
# ========================================
start = time.time()
jchmc_sampler = jit(chmc_sampler, static_argnums=(2,3,4,5,6))
sample_chmc = jchmc_sampler(init_sample, keys_init, H_flat, tau, N, tol, max_iter)
end = time.time()
print("CHMC 1st run:", end - start)
start = time.time()
sample_chmc = jchmc_sampler(init_sample, keys_main, H_flat, tau, N, tol, max_iter)
end = time.time()
print(
    f"CHMC Main run:\n {mainnum_samples} runs: {end - start:.2f} \n 1 run:  {(end-start)/mainnum_samples: .3e}"
)

CHMC 1st run: 0.5075011253356934
CHMC Main run:
 1000 runs: 0.25 
 1 run:   2.463e-04


In [4]:
dims = [4, 8, 16]
lens = [100, 1_000, 5_000, 10_000]

for dim in dims:

    Mass_inv = jnp.eye(dim)
    target_mat = gen_perturb_precision(dim)
    H = gaussian_hamiltonian(target_mat, mass_inv=Mass_inv)
    H_flat = lambda qp_flat: H(QP.from_array(qp_flat))
    qp_init = jr.normal(key, shape=(2 * dim,))
    init_sample = [qp_init, 1, False]
    jchmc_sampler = jit(chmc_sampler, static_argnums=(2,3,4,5,6))
    for mainnum_samples in lens:
        keys_main = jr.split(key, mainnum_samples)
        start = time.time()
        sample_chmc = jchmc_sampler(init_sample, keys_main, H_flat, tau, N, tol, max_iter)
        end = time.time()
        print(f"CHMC Main run:\n {mainnum_samples} runs: {end - start:.2f} \n 1 run:  {(end-start)/mainnum_samples: .3e}")

CHMC Main run:
 100 runs: 0.25 
 1 run:   2.513e-03
CHMC Main run:
 1000 runs: 0.25 
 1 run:   2.509e-04
CHMC Main run:
 5000 runs: 0.25 
 1 run:   4.974e-05
CHMC Main run:
 10000 runs: 0.26 
 1 run:   2.555e-05
CHMC Main run:
 100 runs: 0.67 
 1 run:   6.697e-03
CHMC Main run:
 1000 runs: 0.26 
 1 run:   2.619e-04
CHMC Main run:
 5000 runs: 0.29 
 1 run:   5.762e-05
CHMC Main run:
 10000 runs: 0.30 
 1 run:   2.961e-05
CHMC Main run:
 100 runs: 4.11 
 1 run:   4.115e-02
CHMC Main run:
 1000 runs: 0.30 
 1 run:   2.996e-04
CHMC Main run:
 5000 runs: 0.33 
 1 run:   6.658e-05
CHMC Main run:
 10000 runs: 0.36 
 1 run:   3.606e-05
