In [1]:
import cmdstanpy
import pandas as pd
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import os
from cmdstanpy import cmdstan_path, CmdStanModel

In [2]:
az.style.use("arviz-darkgrid")

In [3]:
simplex_stan = os.path.join('transforms/simplex-stan/simplex-stan.stan')
simplex_stickbreaking = os.path.join('transforms/simplex-stickbreaking/simplex-stickbreaking.stan')

In [4]:
def calc_rhat_mixed_chains(path_1, path_2, variables, data, force_compile=False):
    
    """
    This function evaluates rhat for 4 chains from 2 different models, as in 2 chains from each.
    
    Parameters
    ----------
    
    path_1: str
        path for the first model
        
    path_2: str
        path for the second model
    
    variables: list
        Variables to evaluate rhat for
    
    data: dict
        Dictionary of data
    
    force_compile: Bool
        Whether to recompile or not
    """
    variables = ' '.join(variables)
    #Build stan model
    file_1 = os.path.join(path_1)
    file_2 = os.path.join(path_2)
    model_1 = CmdStanModel(stan_file=file_1, cpp_options={'STAN_THREADS':'true'})
    model_2 = CmdStanModel(stan_file=file_2, cpp_options={'STAN_THREADS':'true'})
    
    #Recompile
    if force_compile is True:
        model_1.compile(force=True)
        model_2.compile(force=True)
    
    #Fit stan model
    fit_1 = model_1.sample(data=dict(K=10))
    fit_2 = model_2.sample(data=dict(K=10))
    
    #Convert to idata for arviz
    idata_1 = az.from_cmdstanpy(fit_1)
    idata_2 = az.from_cmdstanpy(fit_2)
    
    #Stack the samples according to chains
    stacked_1 = az.extract_dataset(idata_1)
    stacked_2 = az.extract_dataset(idata_2)
    
    #Concatenate chains and evaluate rhat
    chains = np.concatenate((stacked_1.sel(chain=1)[str(variables)], stacked_1.sel(chain=2)[str(variables)], stacked_2.sel(chain=1)[str(variables)], stacked_2.sel(chain=2)[str(variables)]), axis=1)
    rhat = az.rhat(chains, var_names=variables,method="rank")
    return rhat

In [5]:
calc_rhat_mixed_chains(path_1=simplex_stan,
                       path_2=simplex_stan,
                       variables=["x"],
                       data=dict(K=10))

INFO:cmdstanpy:found newer exe file, not recompiling
INFO:cmdstanpy:found newer exe file, not recompiling
INFO:cmdstanpy:CmdStan start processing


chain 1 |          | 00:00 Status

chain 2 |          | 00:00 Status

chain 3 |          | 00:00 Status

chain 4 |          | 00:00 Status

                                                                                                                                                                                                                                                                                                                                

INFO:cmdstanpy:CmdStan done processing.
INFO:cmdstanpy:CmdStan start processing





chain 1 |          | 00:00 Status

chain 2 |          | 00:00 Status

chain 3 |          | 00:00 Status

chain 4 |          | 00:00 Status

                                                                                                                                                                                                                                                                                                                                

INFO:cmdstanpy:CmdStan done processing.





1.000112883381856