In [4]:
!pip install numpyro

Collecting numpyro
  Downloading numpyro-0.19.0-py3-none-any.whl.metadata (37 kB)
Downloading numpyro-0.19.0-py3-none-any.whl (370 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/370.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m370.9/370.9 kB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpyro
Successfully installed numpyro-0.19.0


In [1]:
import jax

# Report the JAX build and the accelerator it has picked up.
jax_environment = (
    ("JAX version:", jax.__version__),
    ("Backend:", jax.default_backend()),
    ("Devices:", jax.devices()),
)
for label, value in jax_environment:
    print(label, value)

JAX version: 0.7.2
Backend: gpu
Devices: [CudaDevice(id=0)]


In [2]:
# Show the GPU attached to this runtime, with driver/CUDA versions and memory use.
!nvidia-smi

Wed Oct 29 18:34:31 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   46C    P0             26W /   70W |     110MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [5]:
import jax.numpy as jnp
import numpyro
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import StandardScaler

# Short alias for NumPyro's distributions module; the log-density code uses it.
tfd = numpyro.distributions

# Load the breast-cancer data and standardise every feature column.
dataset = load_breast_cancer()
scaler = StandardScaler()
features_scaled = scaler.fit_transform(dataset.data)

# Move both arrays to JAX as float32.
X = jnp.asarray(features_scaled.astype("float32"))
y = jnp.asarray(dataset.target.astype("float32"))

n_features = X.shape[1]
print("Dataset shape:", X.shape, y.shape)

Dataset shape: (569, 30) (569,)


In [6]:
def joint_log_prob(x, y, tau, lamb, beta):
    """Joint log-density of the logistic-regression model.

    The regression weights are ``tau * lamb * beta``: a global scale ``tau``,
    per-coefficient scales ``lamb`` and raw coefficients ``beta``.

    Args:
        x: Feature matrix; its last axis must match the weight vector.
        y: Binary labels scored under a Bernoulli likelihood.
        tau: Global scale, given a ``Gamma(0.5, 0.5)`` prior.
        lamb: Local scales, each given a ``Gamma(0.5, 0.5)`` prior.
        beta: Raw coefficients, each given a ``Normal(0, 1)`` prior.

    Returns:
        The sum of the three prior terms and the Bernoulli log-likelihood.
    """
    scale_prior = tfd.Gamma(0.5, 0.5)
    tau_term = scale_prior.log_prob(tau)
    lamb_term = scale_prior.log_prob(lamb).sum()
    beta_term = tfd.Normal(0.0, 1.0).log_prob(beta).sum()

    weights = tau * lamb * beta
    likelihood_term = tfd.Bernoulli(logits=x @ weights).log_prob(y).sum()

    # Same left-to-right accumulation order as a running `lp +=` total.
    return tau_term + lamb_term + beta_term + likelihood_term

In [7]:
from jax import random

key = random.key(0)  # initialize PRNG key
# One coefficient per feature column; n_features comes from X.shape[1], so the
# vector stays in sync with the data instead of relying on a literal 30.
beta = random.uniform(key, (n_features,), minval=0.0, maxval=1.0)

print(beta)

[0.947667   0.9785799  0.33229148 0.46866846 0.5698887  0.16550303
 0.3101946  0.68948054 0.74676657 0.17101455 0.9853538  0.02528262
 0.6400418  0.56269085 0.8992138  0.93453753 0.8341402  0.7256162
 0.5098531  0.02765214 0.03148878 0.9580188  0.5188192  0.79221416
 0.5522419  0.6113529  0.8931755  0.75499094 0.21164179 0.22934973]


In [8]:
# Sanity check: evaluate the joint log-density with unit scales and the random
# beta drawn above. lamb is passed as a scalar here, so it broadcasts over beta.
joint_log_prob(X, y, tau=1.0, lamb=1.0, beta=beta)

Array(-4869.8623, dtype=float32)

In [9]:
def unconstrained_joint_log_prob(x, y, z):
    """Joint log-density as a function of one unconstrained vector ``z``.

    ``z`` is laid out as ``[log tau, log lamb (one per column of x), beta]``.
    The two scale blocks are exponentiated before being passed on, and the
    log-determinant of that change of variables is added to the result.

    Args:
        x: Feature matrix; ``x.shape[-1]`` sets the length of the lamb block.
        y: Labels forwarded to ``joint_log_prob``.
        z: Flat parameter vector in unconstrained space.

    Returns:
        ``joint_log_prob`` at the constrained values plus the log-Jacobian term.
    """
    n_cols = x.shape[-1]
    pieces = jnp.split(z, [1, 1 + n_cols])
    log_tau = pieces[0].reshape(())
    log_lamb, beta = pieces[1], pieces[2]

    # d/du exp(u) = exp(u), so each exponentiated coordinate contributes u.
    log_det_jacobian = log_tau + log_lamb.sum()
    constrained_lp = joint_log_prob(
        x, y, jnp.exp(log_tau), jnp.exp(log_lamb), beta)
    return constrained_lp + log_det_jacobian

def target_log_prob(z):
    """Unconstrained joint log-density with the dataset (X, y) held fixed."""
    return unconstrained_joint_log_prob(X, y, z)

In [10]:
# One call returns both the log-density and its gradient with respect to z.
target_log_prob_and_grad = jax.value_and_grad(target_log_prob)

dim = 2 * n_features + 1  # tau + lamb + beta
z_init = jnp.zeros(dim)

logp, grad = target_log_prob_and_grad(z_init)
grad_norm = jnp.linalg.norm(grad)
print("Initial log-density:", float(logp))
print("Gradient L2 norm:", float(grad_norm))

Initial log-density: -465.95599365234375
Gradient L2 norm: 803.6372680664062


In [12]:
from functools import partial

In [14]:
# Rebind target_log_prob with functools.partial: X and y are fixed as the first
# two positional arguments, leaving the unconstrained vector z as the only input.
target_log_prob = partial(unconstrained_joint_log_prob, X, y)

In [15]:
# Rebuild the value-and-gradient function so it wraps the current target_log_prob.
target_log_prob_and_grad = jax.value_and_grad(target_log_prob)

In [18]:
# Evaluate the log-density and its full gradient vector at the all-zeros start.
tlp, tlp_grad = target_log_prob_and_grad(z_init)
print(tlp, "\n", tlp_grad)

-465.956 
 [   0.           0.           0.           0.           0.
    0.           0.           0.           0.           0.
    0.           0.           0.           0.           0.
    0.           0.           0.           0.           0.
    0.           0.           0.           0.           0.
    0.           0.           0.           0.           0.
    0.        -200.83615   -114.22049   -204.30441   -195.0466
  -98.642456  -164.11075   -191.57361   -213.6521     -90.92255
    3.5317173 -156.02263      2.2843075 -152.99834   -150.82368
   18.436588   -80.60622    -69.8029    -112.2554       1.7941825
  -21.450775  -213.60806   -125.69727   -215.38536   -201.88058
 -115.948044  -162.58789   -181.46356   -218.31577   -114.52559
  -89.099594 ]


## 3.2 Hamiltonian Monte Carlo

Now we are ready to move on to using HMC. Below, we implement a simple version of HMC using JAX. In particular, this implementation does not accept or adapt a mass matrix (Betancourt, 2018; Neal, 2011); the mass matrix is implicitly the identity matrix of the appropriate size. It also does not adapt the step size or the number of leapfrog steps. We use the variable `z` for the parameters, and `m` for the momentum.

In [19]:
def leapfrog_step(target_log_prob_and_grad, step_size, i, leapfrog_state):
    """Advance position and momentum by one leapfrog step.

    Args:
        target_log_prob_and_grad: Callable mapping a position ``z`` to a
            ``(log_prob, grad)`` pair.
        step_size: Integrator step size.
        i: Loop index. Unused; it is present so that, once the first two
            arguments are bound, the function has the ``(i, state)`` signature
            of a ``jax.lax.fori_loop`` body.
        leapfrog_state: Tuple ``(z, m, tlp, tlp_grad)`` holding the position,
            the momentum, and the log-density and gradient at ``z``.

    Returns:
        The tuple ``(z, m, tlp, tlp_grad)`` after the step. The arrays passed
        in ``leapfrog_state`` are never modified.
    """
    z, m, _, tlp_grad = leapfrog_state
    # Out-of-place updates: `+=` would write into the caller's buffers when the
    # state is NumPy-backed. For JAX arrays the two forms are equivalent.
    m = m + 0.5 * step_size * tlp_grad   # first momentum half-step
    z = z + step_size * m                # full position step
    tlp, tlp_grad = target_log_prob_and_grad(z)
    m = m + 0.5 * step_size * tlp_grad   # second half-step, gradient at new z
    return z, m, tlp, tlp_grad

The leapfrog_step function is a core part of the Hamiltonian Monte Carlo (HMC) algorithm. It simulates the movement of a particle in a potential energy field (defined by your target_log_prob_and_grad).

Here's a breakdown of what the code does:

`def leapfrog_step(target_log_prob_and_grad, step_size, i, leapfrog_state):` defines the function `leapfrog_step`, which takes the following arguments:
target_log_prob_and_grad: A function that returns both the log probability and its gradient with respect to the parameters (z).
step_size: The size of each step in the simulation.
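i: The current step index. It is not used in the body, but it must be present because `jax.lax.fori_loop` calls its body function with the loop index as the first argument and the carried state as the second.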
leapfrog_state: A tuple containing the current state of the simulation: (z, m, tlp, tlp_grad).
z: The parameters (position of the particle).
m: The momentum of the particle.
tlp: The target log probability at the current z.
tlp_grad: The gradient of the target log probability at the current z.
z, m, tlp, tlp_grad = leapfrog_state: This unpacks the leapfrog_state tuple into its individual components.
m += 0.5 * step_size * tlp_grad: This updates the momentum (m) by adding half of the gradient of the log probability multiplied by the step_size. This is the "half-step" for momentum at the beginning.
z += step_size * m: This updates the parameters (z) by adding the momentum multiplied by the step_size. This is the full step for position.
tlp, tlp_grad = target_log_prob_and_grad(z): After updating z, the target log probability (tlp) and its gradient (tlp_grad) are recalculated at the new position.
m += 0.5 * step_size * tlp_grad: This is the second "half-step" for momentum, using the gradient at the new position z. This symmetric update makes the leapfrog integrator more stable and reversible.
return z, m, tlp, tlp_grad: The function returns the updated state after the leapfrog step.
In essence, the leapfrog step simulates the physics of a particle (representing your parameters) moving in a landscape where the potential energy is the negative log probability of your model. It does this by alternately updating the momentum and position in a way that approximately conserves the total energy; the small discretization error that remains is what the accept/reject step of HMC later corrects. Keeping that error small is crucial for efficient exploration of the parameter space in HMC.

### Building a complete HMC step

Now we assemble the pieces into a single HMC iteration. The `hmc_step` function orchestrates the full Metropolis-adjusted leapfrog trajectory, and understanding its flow is key to grasping how HMC explores the posterior.

**What happens in one HMC step:**

1. **Sample fresh momentum.** We draw `m` from a standard normal distribution, independent of the current position `z`. This randomness injects energy into the system and ensures the chain explores different regions of the parameter space.

2. **Compute the starting energy.** The Hamiltonian (total energy) splits into kinetic energy $\frac{1}{2} \|m\|^2$ and potential energy $-\log p(z \mid x, y)$. Adding these gives us a baseline to compare against after we simulate the trajectory.

3. **Simulate Hamiltonian dynamics.** We call `leapfrog_step` repeatedly via `jax.lax.fori_loop`, which expresses the whole trajectory as a single JAX loop primitive rather than a Python `for` loop, so no Python-level overhead is paid per step once the function is compiled. Each leapfrog step updates position and momentum in lockstep, preserving volume and keeping the dynamics reversible—properties that help HMC achieve high acceptance rates even in high dimensions.

4. **Evaluate the new energy.** After `num_leapfrog_steps` iterations, we recompute kinetic plus potential energy at the proposed state `(new_z, new_m)`. Because the leapfrog integrator introduces small numerical errors, the energy usually drifts slightly.

5. **Accept or reject via Metropolis-Hastings.** We compare `energy - new_energy` (the log acceptance ratio). If the new energy is lower, we always accept. If it's higher, we accept with probability $\exp(\text{energy} - \text{new\_energy})$. This correction ensures our samples remain draws from the true posterior, not just from the approximate Hamiltonian flow.

6. **Return diagnostics.** Beyond the updated parameter vector `z`, we also return `is_accepted` (a boolean indicating whether we took the proposal) and `log_accept_ratio` (useful for tuning step size). High acceptance rates suggest we could take bigger steps; very low rates mean the integrator is drifting too far and we should shrink the step size.

**Why this matters for accelerators:** Because the trajectory loop is written with `jax.lax.fori_loop` (with `partial` binding the arguments that stay fixed across steps), the entire HMC step can be traced and compiled as one XLA computation, for example by wrapping it in `jax.jit`. The CPU or GPU then executes a static computation graph with no Python interpreter overhead, which is exactly what Hoffman et al. emphasize for modern hardware.

In [None]:
def hmc_step(target_log_prob_and_grad, num_leapfrog_steps, step_size, z, seed):
    """Run one Metropolis-adjusted HMC transition starting from ``z``.

    Args:
        target_log_prob_and_grad: Callable mapping a position to a
            ``(log_prob, grad)`` pair.
        num_leapfrog_steps: Number of leapfrog steps in the trajectory.
        step_size: Leapfrog step size.
        z: Current position.
        seed: JAX PRNG key; split once into a momentum key and an accept key.

    Returns:
        A pair ``(z, hmc_output)``. ``z`` is the proposal if it was accepted
        and the input position otherwise. ``hmc_output`` is a dict with keys
        ``"z"``, ``"is_accepted"`` and ``"log_accept_ratio"``.
    """
    def total_energy(momentum, log_prob):
        # Kinetic term with an identity mass matrix, minus the log-density.
        return 0.5 * jnp.square(momentum).sum() - log_prob

    momentum_key, accept_key = jax.random.split(seed)
    start_tlp, start_grad = target_log_prob_and_grad(z)
    start_m = jax.random.normal(momentum_key, z.shape)

    # Bind the arguments that stay fixed so the body has the (i, state) form.
    integrate = partial(leapfrog_step, target_log_prob_and_grad, step_size)
    end_z, end_m, end_tlp, _ = jax.lax.fori_loop(
        0, num_leapfrog_steps, integrate, (z, start_m, start_tlp, start_grad))

    log_accept_ratio = (
        total_energy(start_m, start_tlp) - total_energy(end_m, end_tlp))
    log_uniform = jnp.log(jax.random.uniform(accept_key, []))
    is_accepted = log_uniform < log_accept_ratio

    # Keep the proposal only when the Metropolis test passes.
    z = jnp.where(is_accepted, end_z, z)
    # hmc_output["z"] has shape [num_dimensions]
    hmc_output = {
        "z": z,
        "is_accepted": is_accepted,
        "log_accept_ratio": log_accept_ratio,
    }
    return z, hmc_output