In [None]:
def model(time, rv_err, y=None, period_min=1.0, period_max=10.0):
    """NumPyro model for a radial-velocity curve with an extra jitter term.

    Parameters
    ----------
    time : array-like
        Observation times, passed straight through to ``velocity``.
    rv_err : array-like
        Per-observation measurement uncertainties on the radial velocity.
    y : array-like or None
        Observed radial velocities. With ``None`` the observation site is
        sampled instead of conditioned (prior predictive).
    period_min, period_max : float
        Bounds of the log-uniform period prior. Presumably in the same units
        as ``time``; confirm against ``velocity``. Defaults reproduce the
        original hard-coded range [1, 10].

    Latent sites: ``logPeriod``, ``eccentricity``, the LaTeX-named omega and
    phi_0 phases, ``logJitter``, ``K`` and the LaTeX-named systemic velocity
    v_0. Deterministic sites: ``period`` and ``jitter``.
    """
    # Uniform prior on log(period), i.e. log-uniform on the period itself.
    logPeriod_min = jnp.log(period_min)
    logPeriod_max = jnp.log(period_max)

    # Normal prior on log(jitter), i.e. log-normal on the jitter.
    mean_s = 0.0
    std_s = 1.0

    # Beta(a, b) eccentricity prior. NOTE(review): these values look like the
    # Kipping (2013) beta-distribution fit to RV exoplanet eccentricities; confirm.
    a = 0.867
    b = 3.03

    # Priors
    logPeriod = numpyro.sample("logPeriod", dist.Uniform(logPeriod_min, logPeriod_max))
    eccentricity = numpyro.sample("eccentricity", dist.Beta(a, b))
    omega = numpyro.sample("$\\omega$", dist.Uniform(0.0, 2*jnp.pi))
    phi0 = numpyro.sample("$\\phi_0$", dist.Uniform(0.0, 2*jnp.pi))
    logJitter = numpyro.sample("logJitter", dist.Normal(mean_s, std_s))
    K = numpyro.sample("K", dist.Uniform(0.0, 100.0))
    v0 = numpyro.sample("$v_0$", dist.Uniform(-100.0, 100.0))

    # Expose the linear-scale quantities in the posterior samples as well.
    period = numpyro.deterministic("period", jnp.exp(logPeriod))
    jitter = numpyro.deterministic("jitter", jnp.exp(logJitter))

    # Model radial-velocity curve at the observation times.
    rv = velocity(time, period, eccentricity, omega, phi0, K, v0)

    # Measurement error and jitter are independent noise sources, so their
    # variances add: sigma = sqrt(rv_err**2 + jitter**2). Summing the standard
    # deviations linearly (rv_err + jitter) overstates the total noise.
    sigma = jnp.hypot(jnp.asarray(rv_err), jitter)

    # Likelihood.
    # NOTE(review): the site name "raidalVelocity" is misspelled. It is kept
    # unchanged so any downstream code keyed on it (e.g. predictive draws) still works.
    numpyro.sample("raidalVelocity", dist.Normal(rv, sigma), obs=y)

In [None]:
# Run NUTS on the observed radial velocities. The PRNG seed is fixed so the
# chain is reproducible under Restart & Run All.
MCMC_SEED = 6
MCMC_NUM_WARMUP = 500
MCMC_NUM_SAMPLES = 1000

# Global alias for the observation times, kept in case later cells use it.
# NOTE(review): this name shadows the stdlib `time` module if it is imported.
time = time_obs

sampler = infer.MCMC(
    infer.NUTS(model),
    num_warmup=MCMC_NUM_WARMUP,
    num_samples=MCMC_NUM_SAMPLES,
    num_chains=1,
    progress_bar=True,
)

sampler.run(jax.random.PRNGKey(MCMC_SEED), time, rv_err, y=rv_obs)


In [None]:
# Corner plot of the posterior. The log-parameterised sites are dropped in
# favour of their deterministic exp() counterparts ("period", "jitter").
# A chained, non-mutating drop keeps the cell idiomatic (no inplace=True).
samples_df = pd.DataFrame(sampler.get_samples()).drop(columns=["logJitter", "logPeriod"])

g = sns.PairGrid(samples_df, corner=True)
# Lower panels are layered, last drawn on top: raw draws, then a 2-D
# histogram whose sparsest cells (pthresh) stay transparent, then KDE contours.
g.map_lower(sns.scatterplot, s=5, color="black")
g.map_lower(sns.histplot, bins=30, pthresh=0.005, cmap="mako")
g.map_lower(sns.kdeplot, levels=5, color="white", linewidths=1)
# Diagonal: 1-D marginal density of each parameter.
g.map_diag(sns.kdeplot)
plt.show()