In [1]:
import numpy as np
import dms_stan as dms

First, we define our model. For demonstration purposes, we will use an exponential growth model:

In [2]:
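# Normalized time: 5 evenly spaced timepoints in [0, 1]
time = np.linspace(0, 1, 5)

# Define the model. The placeholder counts have the 5 timepoints along axis 1
# and 100 entries along the last axis, matching the shape of log_A and r.
model = dms.model.ExponentialGrowthBinomialModel(
    t=time,
    counts=np.random.randint(0, 100, (1, 5, 100)),  # random placeholder counts
    log_A=dms.param.Normal(mu=0.0, sigma=0.01, shape=(100,)),
    r=dms.param.Normal(mu=0.0, sigma=5.0, shape=(100,)),
    sigma=dms.param.HalfNormal(sigma=0.03, shape=(1,)),
)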
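To build some intuition for what an exponential growth model implies, here is a minimal NumPy sketch. It assumes (this is not taken from the dms_stan source) that each variant's abundance grows as exp(log_A + r·t) and that abundances are normalized across variants to give proportions θ; the function name `growth_proportions` is hypothetical.

```python
import numpy as np

def growth_proportions(t, log_A, r):
    """Hypothetical forward model: proportions theta with shape (len(t), n_variants).

    Assumes abundance_i(t) = exp(log_A_i + r_i * t), normalized across variants.
    """
    t = np.asarray(t)[:, None]                          # (n_times, 1)
    log_abundance = log_A[None, :] + r[None, :] * t     # (n_times, n_variants)
    # Subtract the per-timepoint max before exponentiating for numerical stability
    log_abundance = log_abundance - log_abundance.max(axis=1, keepdims=True)
    abundance = np.exp(log_abundance)
    return abundance / abundance.sum(axis=1, keepdims=True)

time = np.linspace(0, 1, 5)
log_A = np.zeros(3)              # three hypothetical variants, equal starting abundance
r = np.array([-1.0, 0.0, 1.0])   # shrinking, flat, and growing variants
theta = growth_proportions(time, log_A, r)
print(theta.shape)  # (5, 3)
```

At t = 0 all three variants start at one third; by t = 1 the variant with r = 1 holds the largest share.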

This model defines a prior. We can perform an interactive prior predictive check as follows:

In [3]:
model.prior_predictive()

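A prior predictive check simulates data using only the prior. The sketch below does this by hand with NumPy, drawing log_A and r from the same priors declared above and pushing them through the hypothetical forward model from earlier. Two assumptions to flag: counts are drawn as Binomial(n_total, θ) with a made-up `n_total`, and `sigma` is left out because its role in the model isn't shown here.

```python
import numpy as np

rng = np.random.default_rng(0)
time = np.linspace(0, 1, 5)
n_draws, n_variants = 200, 100
n_total = 10_000  # hypothetical total reads per timepoint

# Draw parameters from the priors declared in the model (sigma omitted)
log_A = rng.normal(0.0, 0.01, size=(n_draws, n_variants))
r = rng.normal(0.0, 5.0, size=(n_draws, n_variants))

# Hypothetical forward model: exponential growth normalized across variants
log_abund = log_A[:, None, :] + r[:, None, :] * time[None, :, None]  # (n_draws, 5, n_variants)
log_abund -= log_abund.max(axis=-1, keepdims=True)
theta = np.exp(log_abund)
theta /= theta.sum(axis=-1, keepdims=True)

# Simulated counts, assuming counts ~ Binomial(n_total, theta)
prior_counts = rng.binomial(n_total, theta)  # (n_draws, 5, n_variants)

# Average share of the single most abundant variant at the last timepoint
print(theta[:, -1, :].max(axis=-1).mean())
```

A summary like the last line shows how concentrated the proportions become by the final timepoint under these priors, which is the kind of thing the interactive plots let us inspect visually.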

Looking at all of the values pooled together isn't very informative. Let's instead look at the distributions at each timestep:

In [4]:
model.prior_predictive(initial_view="theta", independent_dim=1)


We know the actual timepoints, though, so it's worth using them to label the timesteps:

In [5]:
model.prior_predictive(initial_view="theta", independent_dim=1, independent_labels=time)

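One way to think about a per-timestep view is to pool every value that shares an index along the chosen axis and summarize each pool separately. The NumPy sketch below does that; we assume (not confirmed from the dms_stan source) that this is roughly what `independent_dim` selects. The function `per_index_quantiles` and the Poisson draws are hypothetical stand-ins.

```python
import numpy as np

def per_index_quantiles(samples, independent_dim, q=(0.05, 0.5, 0.95)):
    """Quantiles of all values that share an index along `independent_dim`.

    Returns an array of shape (len(q), samples.shape[independent_dim]).
    """
    moved = np.moveaxis(samples, independent_dim, 0)  # independent axis first
    pooled = moved.reshape(moved.shape[0], -1)         # flatten all other axes
    return np.quantile(pooled, q, axis=1)

rng = np.random.default_rng(1)
time = np.linspace(0, 1, 5)
# Hypothetical draws with shape (n_draws, n_times, n_variants); the mean grows over time
samples = rng.poisson(lam=10 + 40 * time[None, :, None], size=(200, 5, 100))

summary = per_index_quantiles(samples, independent_dim=1)  # shape (3, 5)
# Pair each column with its timepoint, much as independent_labels=time does for the plot
for t, (lo, med, hi) in zip(time, summary.T):
    print(f"t={t:.2f}: median={med:.0f}, 90% interval=[{lo:.0f}, {hi:.0f}]")
```

Because the axis is moved with `np.moveaxis`, passing `independent_dim=-2` gives the same result for this 3-D array.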

As we modify the priors in the interactive plots, the model is updated in kind. This lets us set up our prior interactively.

Now that we have the prior, we can find the maximum a posteriori (MAP) estimate. Under the hood, this is done with PyTorch by minimizing the negative log posterior: the negative log likelihood of the data plus the negative log prior we just defined.

In [6]:
# Sample data from the model for demo purposes. In real applications, you would
# provide your own data.
data = model.draw(1)["counts"][0]

# Fit the model to the data by finding the MAP estimate
map_ = model.approximate_map(counts=data)

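To make "minimize the negative log posterior" concrete, here is a small self-contained PyTorch sketch. It is an illustrative re-implementation, not dms_stan's code: the softmax-over-variants growth model, the Binomial likelihood, `n_total`, and the simulated data are all assumptions, and `sigma` is omitted. The priors on log_A and r match the ones declared in the model.

```python
import torch

torch.manual_seed(0)
n_variants, n_total = 20, 10_000       # hypothetical problem size and read depth
t = torch.linspace(0, 1, 5)[:, None]   # (5, 1) normalized time

def proportions(log_A, r):
    # Hypothetical forward model: softmax over variants of log_A + r * t
    return torch.softmax(log_A + r * t, dim=-1)  # (5, n_variants)

# Simulate "observed" counts from known growth rates
true_r = torch.randn(n_variants)
counts = torch.distributions.Binomial(
    n_total, probs=proportions(torch.zeros(n_variants), true_r)
).sample()

# Priors matching those declared in the notebook (sigma omitted)
prior_log_A = torch.distributions.Normal(0.0, 0.01)
prior_r = torch.distributions.Normal(0.0, 5.0)

log_A = torch.zeros(n_variants, requires_grad=True)
r = torch.zeros(n_variants, requires_grad=True)
opt = torch.optim.Adam([log_A, r], lr=0.05)

losses = []
for step in range(2000):
    opt.zero_grad()
    likelihood = torch.distributions.Binomial(n_total, probs=proportions(log_A, r))
    log_lik = likelihood.log_prob(counts).sum()
    log_prior = prior_log_A.log_prob(log_A).sum() + prior_r.log_prob(r).sum()
    loss = -(log_lik + log_prior)  # negative unnormalized log posterior
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```

Dropping `log_prior` from the loss would turn this into maximum likelihood estimation; keeping it is what makes the result a MAP estimate.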


Now let's draw 100 samples of the counts from the distribution parameterized by the MAP estimate:

In [7]:
posterior_samples = map_["distributions"]["counts"].sample([100])
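Assuming `map_["distributions"]["counts"]` is a `torch.distributions` object (the notebook only says PyTorch is used under the hood), the `[100]` argument is a `sample_shape` that gets prepended to the distribution's batch shape. The stand-in distribution below is hypothetical and only mimics the (1, 5, 100) shape of the counts.

```python
import torch

# Hypothetical stand-in for map_["distributions"]["counts"]: a batch of Binomials
# with the same (1, 5, 100) batch shape as the counts array (values are arbitrary)
dist = torch.distributions.Binomial(total_count=1000, probs=torch.full((1, 5, 100), 0.01))
samples = dist.sample([100])
print(dist.batch_shape, samples.shape)  # torch.Size([1, 5, 100]) torch.Size([100, 1, 5, 100])
```

Counting from the end, the time axis is still second-to-last after the sample dimension is prepended, which is consistent with `independent_dim=-2` in the plotting cells.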

We can plot the distribution of count trajectories implied by the MAP estimate:

In [8]:
dms.plotting.plot_distribution(
    samples=posterior_samples,
    independent_dim=-2,  # time is the second-to-last axis
    independent_labels=time,
    paramname="counts",
)


We can also overlay the observed counts (here, the data we simulated from the model above):

In [9]:
dms.plotting.plot_distribution(
    samples=posterior_samples,
    independent_dim=-2,  # time is the second-to-last axis
    independent_labels=time,
    paramname="counts",
    overlay=data,  # observed counts drawn on top of the sampled trajectories
)

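If you want a similar view outside `dms.plotting`, a quantile band with an overlay can be sketched in matplotlib. This is not how `dms.plotting.plot_distribution` is implemented; the samples and observations below are hypothetical Poisson draws for a single variant.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs outside a notebook
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(2)
time = np.linspace(0, 1, 5)
# Hypothetical posterior samples for one variant: (n_samples, n_times)
samples = rng.poisson(lam=20 * np.exp(time), size=(100, 5))
observed = rng.poisson(lam=20 * np.exp(time))  # hypothetical observed counts

# 5th, 50th, and 95th percentiles at each timepoint
lo, med, hi = np.quantile(samples, [0.05, 0.5, 0.95], axis=0)

fig, ax = plt.subplots()
ax.fill_between(time, lo, hi, alpha=0.3, label="90% interval")
ax.plot(time, med, label="median")
ax.plot(time, observed, "o", label="observed")
ax.set_xlabel("normalized time")
ax.set_ylabel("counts")
ax.legend()
# In a notebook, displaying `fig` renders the plot
```

Observed points that fall well outside the band at some timepoints are a hint that the model or prior may need revisiting.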
