In [2]:
import os
import sys
from pathlib import Path

import arviz as az
import numpy as np
from cmdstanpy import CmdStanModel
from scipy.stats import norm, entropy, skewnorm

# Make the project-local 'utils' directory importable before importing local helpers.
sys.path.insert(1, 'utils')
from utils import get_kl_divergence, get_true_x
from rmse import rmse

In [3]:
# Select the target density and the simplex transform to evaluate. These names
# pick the Stan sources that are #include-d into the generated wrapper model.
evaluating_model = 'DirichletSymmetric'
transform_category = 'simplex'
transform = 'ALR'

# Generate a tiny wrapper .stan file that stitches the target density and the
# transform together, then compile it with threading enabled.
Path("stan_models").mkdir(parents=True, exist_ok=True)
stan_filename = f'stan_models/{transform}_{evaluating_model}.stan'

# Use '\n', not os.linesep: write_text opens the file in text mode, which
# already translates '\n' to the platform line ending. os.linesep would produce
# '\r\r\n' on Windows.
stan_source = (
    f'#include ../target_densities/{evaluating_model}.stan\n'
    f'#include ../transforms/{transform_category}/{transform}.stan\n'
)
# NOTE(review): rewriting the file on every run refreshes its mtime, which
# presumably is what makes cmdstanpy recompile each time (picking up edits to
# the included files). Confirm before adding any "skip if unchanged" logic.
Path(stan_filename).write_text(stan_source, encoding='utf-8')

model = CmdStanModel(stan_file=stan_filename, cpp_options={"STAN_THREADS": "true"})

06:18:36 - cmdstanpy - INFO - compiling stan file /Users/meenaljhajharia/cmdstan/transforms/stan_models/ALR_DirichletSymmetric.stan to exe file /Users/meenaljhajharia/cmdstan/transforms/stan_models/ALR_DirichletSymmetric
06:18:50 - cmdstanpy - INFO - compiled model executable: /Users/meenaljhajharia/cmdstan/transforms/stan_models/ALR_DirichletSymmetric


In [4]:
# Sampling configuration.
n_iter = 1000          # draws per chain (passed as iter_sampling)
sampler_seed = 1234    # fixed seed so the draws and the KL/RMSE metrics below are reproducible

# Concentration parameters: 1000 evenly spaced values from 0.1 to 100.
# (Earlier experiments also tried the reversed linspace and skew-normal draws.)
alpha = np.linspace(0.1, 100, 1000)

draws = model.sample(
    data={'alpha': alpha, 'N': len(alpha)},
    iter_sampling=n_iter,
    seed=sampler_seed,
)

06:18:54 - cmdstanpy - INFO - CmdStan start processing


chain 1 |          | 00:00 Status

chain 2 |          | 00:00 Status

chain 3 |          | 00:00 Status

chain 4 |          | 00:00 Status

                                                                                                                                                                                                                                                                                                                                

06:19:09 - cmdstanpy - INFO - CmdStan done processing.





	Chain 1 had 3 divergent transitions (0.3%)
	Chain 3 had 11 divergent transitions (1.1%)
	Chain 4 had 15 divergent transitions (1.5%)
	Use function "diagnose()" to see further information.


In [5]:
# KL-divergence diagnostic from the local utils helper; n_iter is the same
# per-chain draw count that was passed to model.sample.
kl_divergence = get_kl_divergence(draws, n_iter)
kl_divergence

1.6733222463175095

In [7]:
# Per-parameter posterior summary (Mean, MCSE, StdDev, quantiles, N_Eff, R_hat).
posterior_summary = draws.summary()
posterior_summary

Unnamed: 0_level_0,Mean,MCSE,StdDev,5%,50%,95%,N_Eff,N_Eff/s,R_hat
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
lp__,374.965000,0.677076,22.445100,338.024000,375.197000,411.104000,1098.92000,45.95490,1.002110
y[1],-14.381800,0.284721,9.719570,-34.155100,-11.421600,-5.108320,1165.35000,48.73290,1.003620
y[2],-9.864450,0.156622,5.221900,-20.179200,-8.371770,-4.629230,1111.60000,46.48510,1.004840
y[3],-8.128770,0.096717,3.587830,-15.136900,-7.145950,-4.336330,1376.14000,57.54770,1.001290
y[4],-7.138140,0.061410,2.616210,-12.150600,-6.518480,-4.066780,1814.96000,75.89860,1.001330
...,...,...,...,...,...,...,...,...,...
x[996],0.001994,0.000003,0.000203,0.001669,0.001988,0.002339,6399.49000,267.61600,0.999836
x[997],0.001989,0.000003,0.000196,0.001682,0.001984,0.002320,4993.02000,208.79900,0.999329
x[998],0.001993,0.000002,0.000192,0.001688,0.001987,0.002321,6886.68000,287.98900,0.999269
x[999],0.001997,0.000002,0.000196,0.001688,0.001993,0.002323,6825.96000,285.45000,0.999622


In [11]:
# Convert the CmdStanPy fit into an ArviZ InferenceData object so the
# posterior can be reduced over its (chain, draw) dimensions.
# (arviz is imported as `az` in the top import cell.)
idata = az.from_cmdstanpy(draws)

2.1223429693136105e-06 1.998001998001998e-06


In [19]:
# Compare the posterior means of x with the reference values from get_true_x.
# (rmse is imported in the top import cell.)
pred_x = idata.posterior.x.mean(dim=['draw', 'chain']).values

# The sampler data passes a per-component alpha vector (a non-constant
# linspace), so the reference is computed with the asymmetric Dirichlet model,
# even though evaluating_model is 'DirichletSymmetric'.
# NOTE(review): confirm the DirichletSymmetric target really treats alpha as a
# vector, and consider renaming evaluating_model to match.
reference_model = 'DirichletAsymmetric'
true_x = get_true_x(alpha=alpha, evaluating_model=reference_model)

# Fail loudly on a length/shape mismatch instead of letting rmse possibly
# broadcast the two arrays against each other.
assert np.shape(true_x) == np.shape(pred_x), (np.shape(true_x), np.shape(pred_x))

rmse(true_x, pred_x)

2.1361103907762476e-06