In [1]:
from dms_stan.datasets.trpb import (
    TrpBExponentialGrowth,
    TrpBSigmoidGrowthInitParam,
    TrpBSigmoidGrowth,
    load_trpb_dataset
)
from dms_stan.model.stan.stan_results import SampleResults

# Input CSV shared by every model built below. The path starts with "~", so it
# is resolved relative to the current user's home directory.
SOURCE_FILE = (
    "~/GitRepos/DMSStan/raw_data/trpb/3-site_merged_replicates/"
    "LibI/20230926/LibI_merged_AAs.csv"
)

Prior predictive check for the TrpB exponential growth model:

In [2]:
# Build the exponential-growth model from the shared source file.
EXP_MODEL = TrpBExponentialGrowth.from_data_file(SOURCE_FILE)
# Prior predictive check; as the last expression of the cell, its return value
# is what the notebook renders below.
EXP_MODEL.prior_predictive()

BokehModel(combine_events=True, render_bundle={'docs_json': {'4e9d1007-6f9a-4111-abb5-15b6841b4dcc': {'version…

Now a slightly more expressive model, sigmoid growth parametrized using initial abundances:

In [3]:
# Build the sigmoid-growth model parametrized by initial abundances from the
# same source file used for the exponential model.
SIG_INIT_MODEL = TrpBSigmoidGrowthInitParam.from_data_file(SOURCE_FILE)
# Prior predictive check; rendered as the cell output.
SIG_INIT_MODEL.prior_predictive()

BokehModel(combine_events=True, render_bundle={'docs_json': {'73566234-618c-47de-87e4-7b42200de422': {'version…

Slightly more expressive again: Sigmoid growth with variable growth rates and inflection points, but assuming identical maximum abundances for all variants.

In [4]:
# Build the full sigmoid-growth model from the same source file as the other
# two models.
SIG_MODEL = TrpBSigmoidGrowth.from_data_file(SOURCE_FILE)
# Prior predictive check; rendered as the cell output.
SIG_MODEL.prior_predictive()

BokehModel(combine_events=True, render_bundle={'docs_json': {'0131c58e-10ba-4cda-8472-a9440e7ad4a0': {'version…

# MAP

Now that we've selected our priors, we're ready to identify the maximum a posteriori (MAP) estimate for each model.

In [5]:
# NOTE(review): the MAP fit for the exponential model is disabled, so EXP_MAP is
# never defined in a fresh kernel. Uncomment both lines to run it.
# EXP_MAP = EXP_MODEL.approximate_map(early_stop=10, device=0, seed=1025)
# EXP_MAP.plot_loss_curve()

We can plot the posterior predictive checks for the MAP:

In [6]:
# NOTE(review): disabled. This needs EXP_MAP, whose assignment is itself
# commented out in the MAP cell for the exponential model.
# EXP_MAP.get_inference_obj(batch_size=50).run_ppc(logy_ppc_samples=True)

Same for the abundance-initialized sigmoid model:

In [None]:
# Fit the MAP approximation for the abundance-initialized sigmoid model, using
# the same seed (1025) as the other fits in this notebook.
# NOTE(review): device=0 presumably selects the first accelerator and
# early_stop=10 presumably sets an early-stopping patience — confirm against
# approximate_map before relying on either.
SIG_INIT_MAP = SIG_INIT_MODEL.approximate_map(early_stop=10, device=0, seed=1025)
# Plot the optimization loss curve; rendered as the cell output.
SIG_INIT_MAP.plot_loss_curve()

Epochs: 100%|██████████| 10/10 [00:00<00:00, 56.51it/s, -log pdf/pmf=69385087.44]


In [None]:
# Build an inference object from the MAP result, then run the posterior
# predictive checks on it.
# NOTE(review): logy_ppc_samples=True presumably puts the PPC sample axis on a
# log scale — confirm against run_ppc.
SIG_INIT_INF_OBJ = SIG_INIT_MAP.get_inference_obj(batch_size=50)
SIG_INIT_INF_OBJ.run_ppc(logy_ppc_samples=True)

And for the full sigmoid model:

In [None]:
# Fit the MAP approximation for the full sigmoid model with the same settings
# used for the abundance-initialized sigmoid model (early_stop=10, device=0,
# seed=1025).
SIG_MAP = SIG_MODEL.approximate_map(early_stop=10, device=0, seed=1025)
# Plot the optimization loss curve; rendered as the cell output.
SIG_MAP.plot_loss_curve()

Epochs:   0%|          | 1/100000 [00:00<1:30:38, 18.39it/s, -log pdf/pmf=69708513.99]

Epochs:  28%|██▊       | 28113/100000 [11:13<28:43, 41.72it/s, -log pdf/pmf=477823.05] 


In [None]:
# Posterior predictive checks for the full sigmoid model's MAP result.
# NOTE(review): the output saved with this cell is `KeyError: 'A__dist1'`, so
# the cell failed the last time it was run — TODO investigate before relying
# on a Restart & Run All.
SIG_MAP.get_inference_obj(batch_size=50).run_ppc(logy_ppc_samples=True)

KeyError: 'A__dist1'

# MCMC

Finally, we will use Stan to sample from the posterior. Sampling is likely to take some time with these models, so we're going to compile an object that will allow us to run sampling outside of the notebook:

In [None]:
def _mcmc_options(output_dir, **overrides):
    """Return the keyword arguments for one delayed ``mcmc`` call.

    Every model gets the same seed, the same C++ options and
    ``delay_run=True``; only the output directory and any per-model
    ``overrides`` (appended after the shared options) differ. A new
    ``cpp_options`` dict is built on every call so no dict is shared
    between models.
    """
    return {
        "output_dir": output_dir,
        "cpp_options": {"STAN_THREADS": True},
        "seed": 1025,
        "delay_run": True,
        **overrides,
    }


# NOTE(review): the output saved with this cell is
# `AssertionError: Node encountered twice in tree` — TODO confirm whether it
# still reproduces.
# The exponential model keeps the default warmup length; both sigmoid models
# ask for 2000 warmup iterations.
EXP_MODEL.mcmc(**_mcmc_options("./exponential"))
SIG_INIT_MODEL.mcmc(**_mcmc_options("./sigmoid_init", iter_warmup=2000))
SIG_MODEL.mcmc(**_mcmc_options("./sigmoid", iter_warmup=2000))

AssertionError: Node encountered twice in tree

Now run the diagnostics on the saved samples and report the results:

In [None]:
# NOTE(review): disabled, and the path is an absolute, user-specific location.
# While this line stays commented out, nothing in the notebook defines `samples`.
# samples = SampleResults.from_disk("/home/bwittmann/GitRepos/DMSStan/flip3/trpB/sigmoid/model-20250410192656-20250410192733_arviz.nc", skip_fit=True)

In [None]:
# NOTE(review): disabled; it needs `samples`, whose only assignment is the
# commented-out SampleResults.from_disk call. The text saved under this cell
# presumably comes from an earlier run — confirm it is still current.
# _ = samples.diagnose()

Sample diagnostic tests results' summaries:
-------------------------------------------
0 of 4000 (0.00%) samples had a low energy.
0 of 4000 (0.00%) samples reached the maximum tree depth.
0 of 4000 (0.00%) samples diverged.

R_hat diagnostic tests results' summaries:
------------------------------------------
13 of 9121 (0.14%) r_hats tests failed for A__dist1.
56 of 9121 (0.61%) r_hats tests failed for r_mean.
1 of 1 (100.00%) r_hats tests failed for r_std.
78 of 18242 (0.43%) r_hats tests failed for r_raw.
13 of 9121 (0.14%) r_hats tests failed for A.
24 of 18242 (0.13%) r_hats tests failed for r.
16 of 9121 (0.18%) r_hats tests failed for raw_abundances_t0.
192 of 91210 (0.21%) r_hats tests failed for raw_abundances_tg0.
16 of 9121 (0.18%) r_hats tests failed for theta_t0.
186 of 91210 (0.20%) r_hats tests failed for theta_tg0.

Ess_bulk diagnostic tests results' summaries:
---------------------------------------------
0 of 9121 (0.00%) ess_bulks tests failed for A__dist1.
0 of 91

In [None]:
# NOTE(review): disabled; it needs `samples`, whose only assignment is the
# commented-out SampleResults.from_disk call. The plot saved under this cell
# cannot be regenerated while this line stays commented out.
# samples.run_ppc(logy_ppc_samples=True)

BokehModel(combine_events=True, render_bundle={'docs_json': {'4e828aeb-29f7-4996-b9c2-578105eef1e9': {'version…

In [None]:
# NOTE(review): disabled; it needs `samples`, whose only assignment is the
# commented-out SampleResults.from_disk call.
# samples.inference_obj