In [1]:
import numpy as np

import jax
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import *

import pymc3 as pm
import theano
import theano.tensor as tt

import arviz as az

numpyro.enable_x64(True)
numpyro.set_host_device_count(4)

In [2]:
print("numpyro version: ", numpyro.__version__)
print("jax version: ", jax.__version__)
print("pymc3 version: ", pm.__version__)
print("theano version: ", theano.__version__)

numpyro version:  0.6.0
jax version:  0.2.10
pymc3 version:  3.9.0
theano version:  1.0.5


In [3]:
A = np.load("A.npy")
P2Y = np.load("P2Y.npy")
fo = np.load("f.npy")
ferr = np.load("ferr.npy")

n = P2Y.shape[1]

In [4]:
# PyMC3 model
with pm.Model() as model_pm:
    PositiveNormal = pm.Bound(pm.Normal, lower=0.0)
    fs = PositiveNormal("fs", mu=0.0, sigma=1.0)
    p_scale = PositiveNormal("p_scale", mu=0.0, sigma=1e-04)
    p = pm.Exponential("p", 1 / p_scale, shape=(n,))

    x = tt.dot(P2Y, p)
    pm.Deterministic("x", x)

    f = tt.dot(A, x[:, None]).flatten() + fs

    pm.Normal("obs", mu=f, sigma=ferr, observed=fo)

  ret += x.c_compile_args()


In [5]:
init_vals = {"p": 1e-06 * np.random.rand(n), "fs": 0.99999, "p_scale": 2e-05}

nwarmup = 1000
nsamples = 500
nchains = 2

with model_pm:
    samples_pm = pm.sample(
        tune=nwarmup,
        draws=nsamples,
        cores=nchains,
        chains=nchains,
        target_accept=0.95,
        start=init_vals,
        init="adapt_diag",
    )
    samples_pm_az = az.from_pymc3(samples_pm)

Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [p, p_scale, fs]


Sampling 2 chains for 1_000 tune and 500 draw iterations (2_000 + 1_000 draws total) took 257 seconds.


In [6]:
# Numpyro model
def model_num():
    fs = numpyro.sample("fs", dist.HalfNormal(scale=1.0))
    p_scale = numpyro.sample("p_scale", dist.HalfNormal(scale=1e-04))
    p = numpyro.sample("p", dist.Exponential(1 / p_scale).expand([n]))

    x = jnp.dot(P2Y, p)
    numpyro.deterministic("x", x)

    f = jnp.dot(A, x[:, None]).flatten() + fs

    numpyro.sample("obs", dist.Normal(f, ferr), obs=fo)


init_vals = {
    "p": samples_pm["p"][-1],
    "fs": samples_pm["fs"][-1],
    "p_scale": samples_pm["p_scale"][-1],
}
nuts_kernel = NUTS(
    model_num,
    dense_mass=False,
    init_strategy=init_to_value(values=init_vals),
    target_accept_prob=0.95,
)
mcmc = MCMC(nuts_kernel, num_warmup=nwarmup, num_samples=nsamples, num_chains=nchains)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key)
samples_num_az = az.from_numpyro(mcmc)

  0%|          | 0/1500 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

In [7]:
az.summary(samples_pm_az, var_names=["p_scale"])

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_mean,ess_sd,ess_bulk,ess_tail,r_hat
p_scale,0.0,0.0,0.0,0.0,0.0,0.0,1179.0,1153.0,1196.0,880.0,1.0


In [8]:
az.summary(samples_num_az, var_names=["p_scale"])

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_mean,ess_sd,ess_bulk,ess_tail,r_hat
p_scale,0.0,0.0,0.0,0.0,0.0,0.0,3.0,3.0,3.0,28.0,1.77
