In [68]:
import arviz as az
import cmdstanpy
import numpy as np

from dms_stan.datasets.trpb import load_trpb_dataset, TrpBGrowthModel
from dms_stan.model.stan.stan_results import SampleResults
from dms_stan.model.components import Normal

In [2]:
# Load the merged-replicate TrpB dataset from its CSV file
trpb_data = load_trpb_dataset(
    "~/GitRepos/DMSStan/raw_data/trpb/3-site_merged_replicates/LibI/20230926/LibI_merged_AAs.csv"
)

# The growth model is built from every loaded field, "times" included
model = TrpBGrowthModel(**trpb_data)

# Drop "times" once the model exists: the remaining entries are reused further
# down as the observed data (and to derive the "<name>_ppc" variable names)
del trpb_data["times"]

# Re-load a previously generated set of Stan CSV outputs (wildcard path)
fit = cmdstanpy.from_csv("/home/bwittmann/GitRepos/DMSStan/flip3/trpB/tmp/model-20250221204218_*")

In [42]:
# Build the Stan representation of the model; per the log output below, this
# writes model.stan under "tmp" and compiles it to an executable
stan_model = model.to_stan(output_dir="tmp")

22:12:22 - cmdstanpy - INFO - compiling stan file /home/bwittmann/GitRepos/DMSStan/tmp/model.stan to exe file /home/bwittmann/GitRepos/DMSStan/tmp/model


22:12:37 - cmdstanpy - INFO - compiled model executable: /home/bwittmann/GitRepos/DMSStan/tmp/model


In [None]:
# Mapping whose keys are (dimension index, dimension size) pairs and whose
# values are dimension names; later cells use it to label array axes
name_mapper = model.get_dimname_map()

In [57]:
# Show which variables are present in the loaded Stan fit
fit.stan_variables().keys()

dict_keys(['log_A', 'r_mean', 'r_std', 'r_raw', 'r', 'theta_t0', 'theta_tg0', 'starting_counts_ppc', 'timepoint_counts_ppc'])

In [69]:
# Names of the variables for which Stan actually generated samples
sampled_varnames = set(fit.stan_variables().keys())

# Results of this cell:
#   varname_to_named_shape: Stan variable name -> list of dimension names,
#       ordered from the leading axis to the trailing axis
#   dummies: every placeholder ("dummy_<i>") name that had to be invented for
#       an axis without a registered name
varname_to_named_shape = {}
dummies = set()

# Process all variables
for varname in stan_model.program.all_varnames:

    # Stan-friendly variable name: "." becomes "__" (e.g. "r_mean__beta")
    stan_varname = varname.replace(".", "__")

    # Get the model component
    model_component = model[varname]

    # Name each dimension. `name_mapper` is looked up with the axis index
    # counted from the trailing axis, so the shape is walked back-to-front and
    # `named_shape` is temporarily stored in that reversed order.
    named_shape = [None] * model_component.ndim
    for dimind, dimsize in enumerate(model_component.shape[::-1]):

        # An axis without a registered name must be a singleton; give it a
        # placeholder name derived from its (reversed) position
        if (dimname := name_mapper.get((dimind, dimsize))) is None:
            assert dimsize == 1, (
                f"No dimension name registered for axis {dimind} (counted from "
                f"the end) of '{varname}' with size {dimsize}; only singleton "
                "axes may be unnamed"
            )
            dimname = f"dummy_{dimind}"
        named_shape[dimind] = dimname

    # Scalars that were not sampled are given a single placeholder dimension;
    # sampled scalars keep an empty dimension list
    if model_component.ndim == 0 and stan_varname not in sampled_varnames:
        named_shape = ["dummy_0"]

    # Record any placeholder names so coordinates can be created for them
    dummies.update(name for name in named_shape if name.startswith("dummy_"))

    # Flip back to leading-axis-first order and store
    named_shape = named_shape[::-1]
    varname_to_named_shape[stan_varname] = named_shape

    # Observables also have posterior predictive samples ("<name>_ppc").
    # Store an independent copy so that editing one entry's dims can never
    # silently change another's.
    if model_component.observable:
        varname_to_named_shape[f"{stan_varname}_ppc"] = list(named_shape)

    # Non-hyperparameter Normal components also get a "<name>_raw" entry
    # (the non-centered version); again stored as an independent copy
    if isinstance(model_component, Normal) and not model_component.is_hyperparameter:
        varname_to_named_shape[f"{stan_varname}_raw"] = list(named_shape)


In [70]:
# Inspect the mapping from Stan variable name to its list of dimension names
varname_to_named_shape

{'theta_tg0': ['c', 'b', 'a'],
 'theta_tg0__dist1__dist1__t': ['dummy_2', 'b', 'dummy_0'],
 'r_mean': ['a'],
 'r_std': [],
 'timepoint_counts': ['c', 'b', 'a'],
 'timepoint_counts_ppc': ['c', 'b', 'a'],
 'r_mean__beta': ['dummy_0'],
 'theta_tg0__dist1__dist1': ['c', 'b', 'a'],
 'r_std__mu': ['dummy_0'],
 'starting_counts__N': ['dummy_0'],
 'starting_counts': ['a'],
 'starting_counts_ppc': ['a'],
 'r': ['c', 'dummy_1', 'a'],
 'r_raw': ['c', 'dummy_1', 'a'],
 'theta_t0': ['a'],
 'log_A__mu': ['dummy_0'],
 'theta_t0__dist1': ['a'],
 'theta_tg0__dist1': ['c', 'b', 'a'],
 'log_A__sigma': ['dummy_0'],
 'r_std__sigma': ['dummy_0'],
 'log_A': ['a'],
 'timepoint_counts__N': ['c', 'b', 'dummy_0']}

In [71]:
# Coordinates for every named dimension: integer positions 0..size-1
coords = dict(
    (dim_label, np.arange(dim_length))
    for (_, dim_length), dim_label in name_mapper.items()
)

# Placeholder dimensions are singletons, so each gets a lone 0 coordinate
# (added after, and taking precedence over, the named dimensions)
coords.update((dummy_label, np.array([0])) for dummy_label in dummies)

In [72]:
# Posterior predictive variables are named "<observed name>_ppc"
ppc_varnames = [f"{obs_name}_ppc" for obs_name in trpb_data]

# Bundle the fit, the observed/constant data and the dimension labelling
test = az.from_cmdstanpy(
    fit,
    posterior_predictive=ppc_varnames,
    observed_data=trpb_data,
    constant_data=stan_model.autogathered_data,
    dims=varname_to_named_shape,
    coords=coords,
)

In [73]:
# Display the object returned by az.from_cmdstanpy above
test

In [40]:
# Inspect the coordinate arrays assigned to each dimension name
coords

{'a': array([   0,    1,    2, ..., 9118, 9119, 9120]),
 'b': array([0, 1, 2, 3, 4]),
 'c': array([0, 1]),
 'dummy_2': array([0]),
 'dummy_0': array([0]),
 'dummy_1': array([0])}