In [219]:
import jax
from scipy import stats
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax.random as random

In [220]:
# set the seed for jax.random
key = random.PRNGKey(1)

def random_PowerLaw(N, alpha, M_min, M_max):
    """
    Draw random samples from a power-law defined over M_min, M_max.
        dN/dM = Z x M ** -alpha
    INPUTS
    ------
    N: int
        number of samples.
    alpha: float
        power-law index.
    M_min: float
        lower bound of mass interval.
    M_max: float
        upper bound of mass interval.
    OUTPUTS
    -------
    masses: ndarray[float, ndim=1]
        list of N masses drawn from the power-law
    """
    beta = 1. - alpha
    x = np.random.uniform(0., 1., N)
    if beta == 0:
        y = M_min * np.exp(np.log(M_max / M_min) * x)
    else:
        y = ((M_max ** beta - M_min ** beta) * x + M_min ** beta) ** (1. / beta)
    
    return y

# Generate toy data.
Nstars = int(1e5)
alpha_theory  = 2.35
M_min  = 1.0
M_max  = 100.0
Masses = random_PowerLaw(Nstars, alpha_theory, M_min, M_max)
LogM   = np.log(Masses)
D      = np.mean(LogM) * Nstars

theta0 = np.asarray([3.0])

def negative_logLikelihood(theta):
    """
    Define logarithmic likelihood function.
    theta: ndarray[float, ndim=1]
        array of fit params
    D: float
        data
    N: int
        number of data points
    M_min: float
        lower limit of mass interval
    M_max: float
        upper limit of mass interval
    """
    # extract alpha
    alpha = theta[-1]
    beta = 1.0 - alpha

    # Compute normalisation constant.
    if beta == 0:
        c = 1. / np.log(M_max / M_min)
    else:
        c = beta / (M_max ** beta - M_min ** beta)

    # return log likelihood.
    return -(Nstars * jnp.log(c) - alpha * D)

def grad_logLikelihood(theta):
    """Define gradient of log-likelihood
    d lnL / d alpha
    theta: ndarray[float, ndim=1]
        array of fit params
    D: float
        data
        normalization constant, sum_n log(M_n)
    N: int
        number of data points
    M_min: float
        lower limit of mass interval
    M_max: float
        upper limit of mass interval
    """
    alpha = theta[0]  # extract alpha
    beta = 1.0 - alpha

        #Convert limits from M to logM.
    logMmin = np.log(M_min)
    logMmax = np.log(M_max)

    # Compute normalisation constant.
    if beta != 0:
        grad = logMmin * M_min ** beta - logMmax * M_max ** beta
        grad = 1.0 + grad * beta / (M_max ** beta - M_min ** beta)
        grad = -D - Nstars * grad / beta
    else:
        # If alpha == 1, the normalization factor is the logarithmic term
        grad = -D + Nstars * np.log(M_max / M_min)

    return np.array([-grad])


In [221]:
def hamiltonian_monte_carlo(
    n_samples, negative_log_prob, initial_position, 
    path_len, step_size, key
):
    # autograd magic
    dVdq = grad_logLikelihood

    # collect all our samples in a list
    samples = [initial_position]

    # Keep a single object for momentum resampling
    momentum = stats.norm(0, 1)
    size = (n_samples,) + initial_position.shape[:1]
    
    for p0 in momentum.rvs(size=size):
        # Integrate over our path to get a new position and momentum
        q_new, p_new = leapfrog(
            samples[-1], p0, dVdq, path_len=path_len, step_size=step_size
        )

        if np.any(np.isnan(q_new)) or np.any(np.isnan(p_new)):
            print(f"NaN encountered in leapfrog update, skipping step")
            continue  # Skip this update and retry


        # Check Metropolis acceptance criterion
        start_log_p = negative_log_prob(samples[-1]) - np.sum(momentum.logpdf(p0))
        new_log_p = negative_log_prob(q_new) - np.sum(momentum.logpdf(p_new))

        # Calculate the acceptance probability
        accept_prob = min(1, np.exp(new_log_p - start_log_p))

        # Generate a random number between 0 and 1 using jax.random
        u = random.uniform(key)

        # Accept or reject the new sample based on the acceptance probability
        if u <= accept_prob:
            samples.append(q_new)
        else:
            samples.append(samples[-1])

        print(np.array(samples).shape)

    return np.array(samples)

def leapfrog(q, p, dVdq, path_len, step_size):
    # Half-step for momentum
    print(dVdq(q))
    p -= step_size * dVdq(q) / 2.0

    # Full steps for position and momentum
    for _ in range(int(path_len / step_size) - 1):
        q += step_size * p
        p -= step_size * dVdq(q)

    # Final half-step for momentum
    q += step_size * p
    p -= step_size * dVdq(q) / 2.0

    print(f"Final q: {q}, Final p: {p}")

    return q, -p

In [222]:
# 初始化参数
n_samples = 1000                         # 采样次数
path_len = 1e-5                             # 轨迹长度
step_size = 1e-3                         # 步长

# 运行HMC采样
samples = hamiltonian_monte_carlo(
    n_samples=n_samples,
    negative_log_prob=negative_logLikelihood,
    initial_position=theta0,
    path_len=path_len,
    step_size=step_size,
    key=key
)

[23465.25835819]
Final q: [2.98783192], Final p: [-23.74900304]
(2, 1)
[23161.84906424]
Final q: [2.97506027], Final p: [-24.19139583]
(3, 1)
[22839.49818393]
Final q: [2.96452728], Final p: [-21.8182974]
(4, 1)
[22570.60139859]
Final q: [2.95224506], Final p: [-23.40897624]
(5, 1)
[22253.51097705]
Final q: [2.94164465], Final p: [-21.58876888]
(6, 1)
[21976.730017]
Final q: [2.93003254], Final p: [-22.44720016]
(7, 1)
[21670.17239301]
Final q: [2.91816994], Final p: [-22.53925452]
(8, 1)
[21353.31598954]
Final q: [2.90692192], Final p: [-21.77271004]
(9, 1)
[21049.37808153]
Final q: [2.89545921], Final p: [-21.83075662]
(10, 1)
[20736.0793827]
Final q: [2.88607831], Final p: [-19.61938189]
(11, 1)
[20476.96685264]
Final q: [2.87487929], Final p: [-21.28121462]
(12, 1)
[20164.38776104]
Final q: [2.86476974], Final p: [-20.04910987]
(13, 1)
[19879.13764134]
Final q: [2.8541097], Final p: [-20.44761416]
(14, 1)
[19575.14098525]
Final q: [2.84511279], Final p: [-18.65489552]
(15, 1)
[1931

In [225]:
np.mean(samples, axis=0)

array([2.33704874])