In [1]:
import cmdstanpy
import pandas as pd
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import os
from cmdstanpy import cmdstan_path, CmdStanModel

from sklearn.metrics import mean_squared_error
from math import sqrt

from tqdm import tqdm
# import logging, sys
# logging.disable(sys.maxsize)

In [2]:
# Point cmdstanpy at the local CmdStan installation.
# `!export ...` would run in a throw-away subshell, so the variable would never
# reach this kernel process; setting it through os.environ makes it visible to
# cmdstanpy (and to any child processes started from this notebook).
os.environ["CMDSTAN"] = '/Users/meenaljhajharia/cmdstan'

In [3]:
# Switch plotting to the "arviz-darkgrid" style sheet for the figures in this notebook.
az.style.use("arviz-darkgrid")

In [4]:
def get_rmse_leapfrog(file, alpha, K, return_idata=False, n_draws=1000):
    """Sample a Stan simplex model and trace running RMSE against leapfrog cost.

    The model in ``file`` is compiled (with ``STAN_THREADS``) and sampled with
    data ``K`` and a length-``K`` vector whose entries all equal ``alpha``.
    The draws of ``x[2]`` are compared with the mean of a symmetric
    Dirichlet(alpha) distribution, ``alpha / (K * alpha) = 1 / K``.
    NOTE(review): this assumes the Stan program places a Dirichlet(alpha)
    distribution on ``x`` -- confirm against the .stan files.

    Parameters
    ----------
    file : str
        Path to the Stan program.
    alpha : float
        Concentration value repeated ``K`` times to build the data vector.
    K : int
        Dimension passed to the model; ``x[2]`` must exist, so ``K >= 2``.
    return_idata : bool, optional
        Accepted for backward compatibility; currently ignored.
    n_draws : int, optional
        Number of leading rows of ``fit.draws_pd()`` to use (default 1000).
        Clamped to the number of draws actually available.

    Returns
    -------
    x : pandas.DataFrame
        Cumulative sum of ``n_leapfrog__`` over the draws used.
    y : list of float
        ``y[i]`` is the RMSE of the first ``i + 1`` draws of ``x[2]`` against
        the reference mean.
    """
    model = CmdStanModel(stan_file=file, cpp_options={'STAN_THREADS':'true'})
    alpha_vec = np.repeat(alpha, K)

    # Reference value for every component of a symmetric Dirichlet: 1 / K.
    # (Previously this was overwritten with a hard-coded 0.1, which is only
    # right for K = 10.)
    true_mean = alpha_vec[0] / np.sum(alpha_vec)

    fit = model.sample(data=dict(K=K, alpha=alpha_vec))
    dataframe_draws = fit.draws_pd()

    # Never ask for more draws than the fit produced.
    n_used = max(0, min(n_draws, len(dataframe_draws)))
    pred = dataframe_draws['x[2]'].to_numpy()[:n_used]
    leapfrog = dataframe_draws[['n_leapfrog__']][:n_used]

    # Running RMSE in one pass: sqrt(cumulative squared error / draw count).
    squared_error = (pred - true_mean) ** 2
    y = np.sqrt(np.cumsum(squared_error) / np.arange(1, n_used + 1)).tolist()

    x = np.cumsum(leapfrog)

    return x, y

In [5]:
def get_plot(ax, alpha, K=10):
    """Draw one RMSE-vs-cumulative-leapfrog curve per simplex transform on ``ax``.

    Each Stan program listed below is run through ``get_rmse_leapfrog`` with
    the given ``alpha`` and ``K``; the resulting curves are plotted on ``ax``
    with their transform name as label, and the axes title records the
    ``alpha`` / ``K`` setting.
    """
    # (legend label, Stan program) pairs, in plotting order.
    # The "Softmax Augmented" variant
    # (transforms/simplex-softmax-augmented/simplex-softmax-augmented.stan)
    # is intentionally left out, as in the earlier revision of this cell.
    transforms = (
        ("Stan Transform",
         '/Users/meenaljhajharia/cmdstan/transforms/simplex-stan/simplex-stan.stan'),
        ("Stickbreaking",
         '/Users/meenaljhajharia/cmdstan/transforms/simplex-stickbreaking/simplex-stickbreaking.stan'),
        ("Softmax",
         '/Users/meenaljhajharia/cmdstan/transforms/simplex-softmax/simplex-softmax.stan'),
    )

    # Run every sampler first, then plot, so no drawing happens mid-sampling.
    curves = [
        (label, get_rmse_leapfrog(stan_file, alpha=alpha, K=K))
        for label, stan_file in transforms
    ]

    for label, (steps, rmse) in curves:
        ax.plot(steps, rmse, label=label)

    ax.set_title(f'alpha={alpha}, K={K}')

In [None]:
# Grid of experiments: one row per alpha, one column per K.
plt.rcParams["figure.figsize"] = (20,10)
alphas = [10,1,0.1]
# NOTE(review): 1000 is listed twice, so two columns repeat the same setting;
# possibly [10, 100, 1000] was intended -- confirm before changing.
Ks=[10,1000,1000]

fig, axes = plt.subplots(len(alphas), len(Ks), squeeze=False)

fig.supxlabel('Cumulative Leapfrog Steps')
fig.supylabel('Root Mean Squared Error')

# Each (alpha, K) pair gets its own panel. Zipping the flattened axes with Ks
# inside the alpha loop would reuse the first row for every alpha and leave
# the remaining rows empty.
for row, alpha in enumerate(tqdm(alphas)):
    for ax, K in zip(axes[row], Ks):
        get_plot(ax, alpha, K)

# All panels share the same curves, so take the legend entries from one panel
# and pass handles and labels together.
handles, labels = axes[0, 0].get_legend_handles_labels()
fig.legend(handles, labels, bbox_to_anchor=(0.6, -0.05));


  0%|                                                     | 0/3 [00:00<?, ?it/s]INFO:cmdstanpy:found newer exe file, not recompiling
INFO:cmdstanpy:CmdStan start processing


chain 1 |          | 00:00 Status

chain 2 |          | 00:00 Status

chain 3 |          | 00:00 Status

chain 4 |          | 00:00 Status

                                                                                                                                                                                                                                                                                                                                

INFO:cmdstanpy:CmdStan done processing.





INFO:cmdstanpy:found newer exe file, not recompiling
INFO:cmdstanpy:CmdStan start processing


chain 1 |          | 00:00 Status

chain 2 |          | 00:00 Status

chain 3 |          | 00:00 Status

chain 4 |          | 00:00 Status

                                                                                                                                                                                                                                                                                                                                

INFO:cmdstanpy:CmdStan done processing.





INFO:cmdstanpy:found newer exe file, not recompiling
INFO:cmdstanpy:CmdStan start processing


chain 1 |          | 00:00 Status

chain 2 |          | 00:00 Status

chain 3 |          | 00:00 Status

chain 4 |          | 00:00 Status