In [1]:
import pandas as pd
import numpy as np
import os
import time
import json
import sys
import pickle
from cmdstanpy import CmdStanModel
import argparse
import arviz as az
from pathlib import Path
from tqdm import tqdm
from pathlib import Path
from scipy.stats import norm, entropy

In [None]:
# Location of the Stan program file that the next cell (re)writes and that is later
# handed to sample().
stan_filename='../stan_models/simplex/ALR_DirichletSymmetric.stan'

In [None]:
# Generate the Stan program as two `#include` lines: the symmetric-Dirichlet target
# density followed by the ALR simplex transform.
with open(stan_filename, 'w') as f:
    # A text-mode file already translates '\n' into the platform line ending, so
    # os.linesep must not be written here (it would become '\r\r\n' on Windows).
    # The `with` block closes the file; no explicit close() is needed.
    f.write(
        '#include ../../target_densities/DirichletSymmetric.stan\n'
        '#include ../../transforms/simplex/ALR.stan\n'
    )

In [None]:
# Settings for one small ALR run: 1000 concentration values of 0.1 each.
# `n_repeat` and both output paths are forwarded unchanged to sample().
n_repeat = 2
data = dict(alpha=[0.1] * 1000, N=1000)
output_file_name = '../output/simplex/ALR/samples_3_2.nc'
output_file_name_time = '../output/simplex/ALR/time_1_2.txt'

In [2]:
import sys
sys.path.insert(1, '../utils')
from sample import sample

In [None]:
# Run the project's sample() helper (imported from ../utils) on the generated Stan file.
# NOTE(review): presumably it stores the draws at `output_file_name` and timings at
# `output_file_name_time` — confirm against utils/sample.py.
sample(
    stan_filename, data, output_file_name, n_repeat, output_file_name_time,
    n_iter=1000, n_chains=4, show_progress=True, return_idata=False, inits=None,
)

In [None]:
# Read the stored NetCDF file back in; the cells below use its `posterior` and
# `sample_stats` groups.
idata = az.from_netcdf(output_file_name)

In [None]:
# Effective sample size of each component of `x`, computed separately for two groups
# of four consecutive chains (chains 0-3 and chains 4-7).
ess = {}
for dim in idata.posterior.x.x_dim_0.values:
    per_group = []
    for group in tqdm(range(2)):
        first = group * 4
        subset = idata.sel(chain=[first, first + 1, first + 2, first + 3], x_dim_0=dim)
        ess_result = az.ess(subset, var_names=['x'])
        # az.ess returns one entry per variable; only 'x' was requested.
        per_group.append(list(ess_result.values())[0].item())
    ess['x_' + str(dim)] = per_group

In [None]:
# Concentration vector rescaled to sum to one; the RMSE computation uses it as the
# reference value for every component.
true_var = np.divide(np.asarray(data['alpha']), sum(data['alpha']))

def cumulative_mean(x):
    """Return the running mean of ``x``: entry ``k`` is the mean of ``x[:k + 1]``."""
    running_total = np.cumsum(x)
    counts = np.arange(1, len(x) + 1)
    return np.divide(running_total, counts)

In [None]:
# For each component of `x`: average the draws over axis 0, take the running mean of
# that series, and record the RMSE of its first k entries against `true_var`
# for k = 1 .. number of entries.
rmse = {}
for dim in idata.posterior.x.x_dim_0.values:
    averaged = np.mean(idata.posterior['x'].sel(x_dim_0=dim), axis=0)
    running = cumulative_mean(averaged)
    rmse['x_' + str(dim)] = [
        np.sqrt(np.mean((true_var[dim] - running[:stop].values) ** 2))
        for stop in tqdm(range(1, len(running) + 1))
    ]

In [None]:
# Display the sampler's `n_steps` statistic (the batch cell further down pickles this
# same array under the file name prefix 'leapfrog').
idata.sample_stats.n_steps.values

In [None]:
# Summary table over everything stored in `idata`.
az.summary(idata)
# Alternative kept for reference: restrict the summary to the first 12 chains.
# az.summary(idata.sel(chain=[0,1,2,3,4,5,6,7,8,9,10,11]))

In [6]:
# Batch post-processing: for every simplex transform and every entry of `datajson`,
# regenerate the Stan program, load the stored samples and pickle the ESS, running
# RMSE, leapfrog counts and an ArviZ summary.
# NOTE(review): `datajson` and `output_dir` are not defined in the visible notebook;
# they must already exist in the kernel before this cell runs.
transforms = ['Stickbreaking', 'ALR',
    'AugmentedILR', 'HypersphericalAngular', 'HypersphericalLogit',
    'HypersphericalProbit', 'ProbitProduct']
n_repeat = 100
n_chains = 4  # chains per repeat; same value as n_chains in the disabled sample() call


def cumulative_mean(x):
    """Return the running mean of ``x``: entry ``k`` is the mean of ``x[:k + 1]``."""
    return np.divide(np.cumsum(x), np.arange(1, len(x) + 1))


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path`` with the highest available protocol."""
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)


# The outer loop variables have their own names. Previously the inner loops reused
# `i` and `j`, so by the time the pickle paths were formatted they held the last
# component index and the last draw count instead of the transform name and data key.
for transform in transforms:
    for data_key in list(datajson.keys()):

        stan_filename = f'../stan_models/simplex/{transform}_DirichletSymmetric.stan'

        # Text mode translates '\n' to the platform line ending, so os.linesep must
        # not be written here; the `with` block also closes the file on exit.
        with open(stan_filename, 'w') as f:
            f.write(
                '#include ../../target_densities/DirichletSymmetric.stan\n'
                f'#include ../../transforms/simplex/{transform}.stan\n'
            )
        output_file_name = f'../output/simplex/{transform}/samples_{data_key}_{n_repeat}.nc'
        alpha = datajson[data_key]
        data = {'alpha': alpha, 'N': len(alpha)}

        # Was hard-coded to the ALR directory for every transform.
        output_file_name_time = f'../output/simplex/{transform}/time_{data_key}_{n_repeat}.txt'

        # Sampling is disabled; the .nc files are expected to exist from an earlier run.
        # sample(stan_filename, data, output_file_name, n_repeat, output_file_name_time,
        #        n_iter=1000, n_chains=4, show_progress=True, return_idata=False, inits=None)

        idata = az.from_netcdf(output_file_name)
        component_ids = idata.posterior.x.x_dim_0.values
        result_prefix = f'{output_dir}/{transform}/DirichletSymmetric'
        result_suffix = f'{data_key}_{n_repeat}.pickle'

        # ESS of each component, one value per repeat (block of `n_chains` chains).
        ess = {}
        for dim in component_ids:
            ess_array = []
            for k in tqdm(range(n_repeat)):
                chain_ids = list(range(k * n_chains, (k + 1) * n_chains))
                ess_k = az.ess(idata.sel(chain=chain_ids, x_dim_0=dim), var_names=['x'])
                ess_array.append(list(ess_k.values())[0].item())
            ess['x_' + str(dim)] = ess_array
        _dump_pickle(ess, f'{result_prefix}/ess_{result_suffix}')

        # RMSE of the running mean (draws averaged over axis 0) against the normalised
        # alpha, evaluated on every prefix of the series.
        true_var = np.asarray(data['alpha']) / sum(data['alpha'])
        rmse = {}
        for dim in component_ids:
            pred_var = cumulative_mean(np.mean(idata.posterior['x'].sel(x_dim_0=dim), axis=0))
            rmse['x_' + str(dim)] = [
                np.sqrt(np.mean((true_var[dim] - pred_var[:stop].values) ** 2))
                for stop in tqdm(range(1, len(pred_var) + 1))
            ]
        _dump_pickle(rmse, f'{result_prefix}/rmse_{result_suffix}')

        _dump_pickle(idata.sample_stats.n_steps.values, f'{result_prefix}/leapfrog_{result_suffix}')

        # For N == 1000 only the first 12 chains are summarised.
        # NOTE(review): presumably to keep az.summary tractable — confirm.
        summary_input = idata.sel(chain=list(range(12))) if data['N'] == 1000 else idata
        _dump_pickle(az.summary(summary_input), f'{result_prefix}/summary_{result_suffix}')

                

lol


100%|███████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 162.52it/s]


lol


100%|███████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 260.86it/s]


lol


100%|███████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 272.12it/s]


lol


100%|███████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 248.45it/s]


lol


100%|███████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 298.03it/s]


lol


100%|███████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 248.85it/s]


lol


100%|███████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 294.37it/s]


lol


100%|███████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 276.03it/s]


lol


100%|███████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 252.78it/s]


lol


100%|███████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 291.61it/s]


{'x_0': [1771.5888451983462, 1418.1619638092793], 'x_1': [1910.715114270037, 1926.643326244528], 'x_2': [1417.2752431942467, 1879.7427323672657], 'x_3': [1862.851066242899, 1691.539702558564], 'x_4': [2293.3581148905464, 1822.5241210900967], 'x_5': [1797.409460265909, 1895.4812013628782], 'x_6': [1706.3828139476677, 2051.3897407542936], 'x_7': [1501.7713054890091, 1903.6849917407308], 'x_8': [1623.4450915347509, 1816.3731214855802], 'x_9': [423.4285097793711, 341.60133476642704]}


100%|███████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 15223.45it/s]
100%|███████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 15310.25it/s]
100%|███████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 14330.05it/s]
100%|███████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 14577.48it/s]
100%|███████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 13411.17it/s]
100%|███████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 15334.54it/s]
100%|███████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 15379.24it/s]
100%|███████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 14835.49it/s]
100%|███████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 14620.06it/s]
100%|███████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 15017.52it/s]


{'x_0': [0.09374695460536377, 0.07241920033475852, 0.0631665828362379, 0.05471242094652188, 0.048940465535903206, 0.04482644251934361, 0.041533758429638176, 0.03893069267759263, 0.03672159579794416, 0.034976630359683375, 0.03362544878775996, 0.03232047611769896, 0.03129522619599589, 0.030316737699684557, 0.029338705700852615, 0.02840847981313588, 0.027576862944323228, 0.02687849223250072, 0.026258332242995425, 0.02571888101734688, 0.025270528398127927, 0.024959474853385195, 0.02467380260280311, 0.024367679203334207, 0.02399037428994617, 0.023625126638776656, 0.02325120887529051, 0.022891703547713575, 0.022541606979202114, 0.022213363739749553, 0.021887863826375496, 0.02159433954027613, 0.021281234709092208, 0.02098995299722106, 0.0207217369880858, 0.02048247832712831, 0.020216181860991493, 0.019963257902943638, 0.0197096948779251, 0.019472367795019017, 0.019245111229273408, 0.019022804635696895, 0.018816851056002407, 0.018622949127631037, 0.018436303537362128, 0.018261880780253415, 0.0

In [None]:
import bridgestan as bs
from pathlib import Path
import os
from utils import *
import json
import pickle
import arviz as az
import numpy as np
from tqdm import tqdm


# Evaluate the gradient and Hessian of the model's log density at every stored draw
# of `y` using BridgeStan, and save both arrays to disk.
# NOTE(review): relies on `stan_filename`, `data` and `idata` left over from earlier
# cells, and on `pkey`, which is not defined in the visible notebook (possibly it
# comes from `from utils import *`) — confirm before a fresh-kernel run.
# NOTE(review): BridgeStan documents the model data argument as a JSON string or a
# file path; `data` is a dict in the visible cells — confirm it is accepted as-is.
n_total_chains = 400       # number of chains read from `idata` below
n_draws_per_chain = 1000   # 400 * 1000 = 400000 rows, the original buffer size
n_rows = n_total_chains * n_draws_per_chain

bsmodel = bs.StanModel.from_stan_file(stan_filename, data, stanc_args=["--include-paths='/mnt/home/mjhajaria/simplex-transforms/'"])
n = bsmodel.param_unc_num()
print(pkey, bsmodel.param_num(include_tp=True), n)
hessian = np.empty((n_rows, n, n))
grad = np.empty((n_rows, n))


for chain_idx in range(n_total_chains):
    # Use a separate name so the `data` dict passed to BridgeStan is not overwritten,
    # which keeps the cell re-runnable.
    draws = idata.posterior.sel(chain=[chain_idx]).y.values[0]

    # Each chain gets its own slice of the buffers. Indexing by the within-chain
    # position alone made every chain overwrite rows 0..len(draws)-1 and left the
    # remaining rows as uninitialised np.empty memory.
    offset = chain_idx * n_draws_per_chain
    for idx, row in enumerate(draws):
        theta = bsmodel.param_unconstrain(row)
        bsmodel.log_density_hessian(theta, out_grad=grad[offset + idx], out_hess=hessian[offset + idx])

# Only `np` is imported (the bare name `numpy` is not bound). Saving once after the
# loop avoids rewriting the full arrays after every chain.
np.save("hessian.npy", hessian)
np.save("grad.npy", grad)