In [1]:
import re
_INDEX_EXTRACTOR = re.compile(r"([A-Za-z0-9_]+)\[?([0-9, ]*)\]?")

In [2]:
import copy
import re

import numpy as np
import dms_stan as dms
import dms_stan.model.components as dms_components

import itertools

First we define our model. For demo purposes, we will do an exponential growth model:

In [3]:
# Normalized time
time = np.linspace(0, 1, 5)

# Define the model
model = dms.model.ExponentialGrowthBinomialModel(
    t = time[:, None],
    counts = np.random.randint(0, 100, (5, 100)),
    log_A = dms_components.Normal(mu=0.0, sigma=0.01, shape=(100,)),
    r = dms_components.Normal(mu=0.0, sigma=5.0, shape=(100,)),
    sigma = dms_components.HalfNormal(sigma=0.03, shape=())
)

In [4]:
model.constants_dict

{'counts_N': <dms_stan.model.components.constants.Constant at 0x253beec8440>,
 'log_A_mu': <dms_stan.model.components.constants.Constant at 0x253bf0e2090>,
 'log_A_sigma': <dms_stan.model.components.constants.Constant at 0x253bed0cb00>,
 'log_theta_unorm_mean_t': <dms_stan.model.components.constants.Constant at 0x253bf1eb290>,
 'r_mu': <dms_stan.model.components.constants.Constant at 0x253bf1d4260>,
 'r_sigma': <dms_stan.model.components.constants.Constant at 0x253becadf70>,
 'sigma_mu': <dms_stan.model.components.constants.Constant at 0x253bb5abd70>,
 'sigma_sigma': <dms_stan.model.components.constants.Constant at 0x253bf1eaf00>}

In [5]:
stan_model = dms.model.stan.StanModel(model)

In [6]:
stan_model.steps

{'data': ['array[5,100] int<lower=0> counts',
  'array[5,1] int<lower=0> counts_N',
  'real log_A_mu',
  'real<lower=0> log_A_sigma',
  'real r_mu',
  'real<lower=0> r_sigma',
  'real sigma_mu',
  'real<lower=0> sigma_sigma'],
 'parameters': ['array[5] vector[100] log_theta_unorm',
  'vector[100] log_A',
  'vector[100] r',
  'real<lower=0.0> sigma'],
 'transformed parameters': ['array[5] vector[100]<lower=0.0> theta',
  'array[5] vector[100]<upper=0.0> log_theta',
  'array[5] vector[100] log_theta_unorm_mean'],
 'model': '',
 'generated quantities': []}

In [7]:
# The number of levels is given by the dimensionality of the observable. We create
# lists for each level. Different variables will be defined and accessed depending
# on the level to which they belong. Note that we assume the last level is vectorized
n_levels = model.observables[0].stan_code_level
model_levels = [[] for _ in range(n_levels)]
transformed_data_levels = copy.deepcopy(model_levels)

# Get allowed index variable names for each level
allowed_index_names = tuple(
    char
    for char in dms.defaults.DEFAULT_INDEX_ORDER
    if {char, char.upper()}.isdisjoint(stan_model.all_varnames)
)

# There should be enough index names to cover all levels
if len(allowed_index_names) < (n_levels - 1):
    raise ValueError(
        f"Not enough index names ({len(allowed_index_names)}) to cover {n_levels} levels"
    )

# Observable is automatically the last level
model_levels[-1].append(model.observables[0].get_target_incrementation(allowed_index_names))

# # Loop over model parameters beginning from the observable
# for _, parent, _ in model.observables[0].recurse_parents():

    # If a named

# We are working up to the `get_stan_transformation` and `get_stan_distribution` methods

# Example:
model.counts.get_indexed_varname(allowed_index_names)

IndexError: list index out of range

In [None]:
model.observables[0].parameters["N"].get_indexed_varname(allowed_index_names)

ValueError: DMS Stan variable name not set. This is only set when the parameteris used in a DMS Stan model.

In [None]:
model.observables[0].get_target_incrementation(allowed_index_names)

ValueError: DMS Stan variable name not set. This is only set when the parameteris used in a DMS Stan model.

In [None]:
model.counts.get_indexed_varname(allowed_index_names)

('counts[1,j,k]', True)