In [1]:
import numpy as np

from app.stan.inference import run_mcmc, evaluate_model, make_prediction
from app.stan.model import get_model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Synthetic linear-regression data. Only the first two features carry signal;
# the last two true coefficients are exactly zero, which is what the
# model-selection cell later tries to recover.
w = np.array([3.5, -1.5, 0.0, 0.0])  # true regression coefficients
sigma = 0.5                          # std-dev of the additive Gaussian noise

N = 100         # number of observations
D = w.size      # number of features, taken from the true coefficient vector

# Legacy global seed: keeps the exact same draws (X first, then the noise).
np.random.seed(0)
X = np.random.randn(N, D)
y = X @ w + sigma * np.random.randn(N)

In [3]:
# Build the (project-defined) model and fit it on all D features of X.
linear_model = get_model()
# Four chains, so the R_hat column of the summary compares several chains.
mcmc = run_mcmc(linear_model, X, y, num_chains=4)
# NOTE(review): the summary also contains one logp[...] row per observation,
# so the displayed table is long — consider slicing it if the fit object's
# summary() returns a DataFrame (confirm against run_mcmc's return type).
mcmc.summary()

17:25:54 - cmdstanpy - INFO - compiling stan file /data/home/du/workspace/numpyro_linear_regression_waic/src/app/stan/model.stan to exe file /data/home/du/workspace/numpyro_linear_regression_waic/src/app/stan/model
17:26:21 - cmdstanpy - INFO - compiled model executable: /data/home/du/workspace/numpyro_linear_regression_waic/src/app/stan/model
17:26:21 - cmdstanpy - INFO - CmdStan start processing
17:26:21 - cmdstanpy - INFO - Chain [1] start processing
17:26:21 - cmdstanpy - INFO - Chain [2] start processing
17:26:21 - cmdstanpy - INFO - Chain [3] start processing
17:26:21 - cmdstanpy - INFO - Chain [4] start processing
17:26:22 - cmdstanpy - INFO - Chain [1] done processing
17:26:22 - cmdstanpy - INFO - Chain [4] done processing
17:26:22 - cmdstanpy - INFO - Chain [2] done processing
17:26:22 - cmdstanpy - INFO - Chain [3] done processing


Unnamed: 0,Mean,MCSE,StdDev,5%,50%,95%,N_Eff,N_Eff/s,R_hat
lp__,-85.010500,0.034815,1.600810,-88.156300,-84.674800,-83.053300,2114.14,2288.03,1.000850
w[1],3.424320,0.000726,0.049929,3.342750,3.423440,3.506520,4723.24,5111.74,0.999924
w[2],-1.557500,0.000861,0.054606,-1.647080,-1.559210,-1.467980,4020.61,4351.31,0.999670
w[3],-0.014894,0.000869,0.060309,-0.114398,-0.014557,0.083934,4815.63,5211.72,0.999974
w[4],-0.009083,0.000724,0.053545,-0.099930,-0.008251,0.077756,5473.88,5924.11,0.999574
...,...,...,...,...,...,...,...,...,...
logp[96],-0.384072,0.002334,0.144335,-0.660608,-0.356488,-0.206780,3825.10,4139.72,1.000140
logp[97],-0.672388,0.002965,0.207382,-1.058130,-0.641468,-0.386968,4893.60,5296.11,0.999392
logp[98],-0.311268,0.001228,0.078727,-0.446517,-0.306971,-0.189385,4112.06,4450.29,1.000270
logp[99],-1.043010,0.003469,0.228425,-1.464140,-1.025960,-0.701853,4335.74,4692.35,0.999823


In [4]:
# Score the 4-chain, all-feature fit; this is the same quantity the
# model-selection cell reports as "WAIC".
evaluate_model(mcmc)

0.795507695137393

In [5]:
# model selection
results = []
for i in range(4):
    XX = X[:, :D-i]
    mcmc_ = run_mcmc(linear_model, XX, y)
    waic = evaluate_model(mcmc_)
    results.append((i, waic))

for i, waic in results:
    print(f"WAIC for {D-i} dimensions: {waic}")

17:26:23 - cmdstanpy - INFO - CmdStan start processing
17:26:23 - cmdstanpy - INFO - Chain [1] start processing


17:26:24 - cmdstanpy - INFO - Chain [1] done processing
17:26:24 - cmdstanpy - INFO - CmdStan start processing
17:26:24 - cmdstanpy - INFO - Chain [1] start processing
17:26:24 - cmdstanpy - INFO - Chain [1] done processing
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'model.stan', line 23, column 12 to column 54)
Consider re-running with show_console=True if the above output is unclear!
17:26:24 - cmdstanpy - INFO - CmdStan start processing
17:26:24 - cmdstanpy - INFO - Chain [1] start processing
17:26:25 - cmdstanpy - INFO - Chain [1] done processing
Exception: normal_lpdf: Scale parameter is 0, but must be positive! (in 'model.stan', line 23, column 12 to column 54)
Consider re-running with show_console=True if the above output is unclear!
17:26:25 - cmdstanpy - INFO - CmdStan start processing
17:26:25 - cmdstanpy - INFO - Chain [1] start processing
17:26:25 - cmdstanpy - INFO - Chain [1] done processing


WAIC for 4 dimensions: 0.7973426224066149
WAIC for 3 dimensions: 0.7877531351930877
WAIC for 2 dimensions: 0.7781081652866775
WAIC for 1 dimensions: 1.9152551051974886


In [6]:
# Posterior predictions from the 4-chain, all-feature fit (`mcmc`, not the
# subset fits from the model-selection loop).
# NOTE(review): this predicts on the same X used for fitting (no held-out data),
# and returns a 2-D array whose axis layout (draws x observations?) is not
# visible here — confirm in make_prediction.
make_prediction(linear_model, X, mcmc)

17:26:25 - cmdstanpy - INFO - Chain [1] start processing
17:26:25 - cmdstanpy - INFO - Chain [2] start processing
17:26:25 - cmdstanpy - INFO - Chain [3] start processing
17:26:25 - cmdstanpy - INFO - Chain [4] start processing
17:26:25 - cmdstanpy - INFO - Chain [2] done processing
17:26:25 - cmdstanpy - INFO - Chain [1] done processing
17:26:25 - cmdstanpy - INFO - Chain [3] done processing
17:26:25 - cmdstanpy - INFO - Chain [4] done processing


array([[ 5.54901  ,  6.36067  ,  0.0634991, ...,  5.49057  ,  0.729398 ,
         0.877588 ],
       [ 4.93151  ,  8.01321  , -1.3038   , ...,  5.07612  ,  1.13686  ,
         0.708401 ],
       [ 5.75272  ,  8.49084  , -1.01292  , ...,  5.09816  ,  1.64969  ,
         2.36296  ],
       ...,
       [ 5.44669  ,  7.33665  , -1.2899   , ...,  5.71663  ,  1.60873  ,
         0.364744 ],
       [ 5.88514  ,  8.44136  , -1.11645  , ...,  6.53558  ,  1.2513   ,
         1.32404  ],
       [ 5.59862  ,  8.39457  , -0.92521  , ...,  5.19721  ,  2.06629  ,
         1.32276  ]])