In [1]:
import arviz
import jax
import numpy as np

from app.inference import run_mcmc, evaluate_model
from app.model import linear_model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Synthetic linear-regression data -------------------------------------
# Ground-truth coefficients: only the first two features carry signal,
# the last two have zero weight.
w = np.array([3.5, -1.5,  0.0, 0.0])
sigma = 0.5  # standard deviation of the additive Gaussian noise

D = w.shape[0]  # number of features
N = 100         # number of observations

# Seed the legacy global NumPy RNG; the design matrix is drawn first and the
# noise second, so the order of the two randn calls must not change.
np.random.seed(0)
X_ = np.random.randn(N, D)
noise_ = np.random.randn(N)
y_ = sigma * noise_ + X_.dot(w)

# Hand the host arrays to JAX (placed on the default device).
X, y = jax.device_put(X_), jax.device_put(y_)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:
# Fit `linear_model` to the full synthetic data set with the project's MCMC
# helper and keep the result in `mcmc`; later cells read it via
# `mcmc.get_samples()` and pass it to `arviz.waic`.
mcmc = run_mcmc(linear_model, X, y)
# Print the per-parameter summary table (mean, std, quantiles, n_eff, r_hat)
# together with the divergence count as a quick convergence check.
mcmc.print_summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
     sigma      0.53      0.04      0.52      0.46      0.58   1364.29      1.00
      w[0]      3.43      0.05      3.43      3.35      3.51   1220.64      1.00
      w[1]     -1.56      0.06     -1.56     -1.64     -1.46    916.44      1.00
      w[2]     -0.01      0.06     -0.01     -0.11      0.09   1100.94      1.00
      w[3]     -0.01      0.05     -0.01     -0.09      0.08   1218.64      1.00

Number of divergences: 0


In [4]:
# Score the full model on the data it was fitted to, using the posterior draws
# held by `mcmc`.
# NOTE(review): the displayed value (~0.795) is 1/N of the negative-log WAIC
# that ArviZ reports for the same fit (79.55 with N = 100), so this looks like
# a per-observation WAIC — confirm against app.inference.evaluate_model.
evaluate_model(linear_model, X, y, mcmc.get_samples())

0.7954699710394434

In [5]:
# Cross-check against the ArviZ implementation of WAIC.
# ArviZ does not appear to divide WAIC by N (the number of data points), so
# its estimate is expected to be about N times the evaluate_model value.
arviz.waic(mcmc, scale="negative_log")

See http://arxiv.org/abs/1507.04544 for details


Computed from 1000 posterior samples and 100 observations log-likelihood matrix.

           Estimate       SE
-elpd_waic    79.55     7.12
p_waic         4.73        -


In [6]:
# Model selection: refit the model on the first k feature columns for
# k = D, D-1, ..., 1 and report the WAIC of each candidate.
# NOTE(review): lower values presumably indicate a better model here (the
# score looks like a negative-log-scale WAIC) — confirm in app.inference.
#
# The loop deliberately uses its own names (`X_sub`, `mcmc_sub`) instead of
# rebinding the notebook-global `mcmc`: overwriting it would leave `mcmc`
# pointing at the 1-feature fit, so re-running the earlier evaluation cells
# would silently combine the full `X` with the wrong posterior.
# Iterating over `range(D, 0, -1)` rather than a hard-coded count keeps the
# sweep in sync with the number of features defined in the data cell.
for n_features in range(D, 0, -1):
    X_sub = X[:, :n_features]
    mcmc_sub = run_mcmc(linear_model, X_sub, y)
    waic = evaluate_model(linear_model, X_sub, y, mcmc_sub.get_samples())
    print(f"WAIC for {n_features} dimensions: {waic}")

WAIC for 4 dimensions: 0.7954699710394434
WAIC for 3 dimensions: 0.7834370291883676
WAIC for 2 dimensions: 0.7765665572323642
WAIC for 1 dimensions: 1.9153631023793407
