In [1]:
import numpy as np
import dms_stan as dms
import dms_stan.model.components as dms_components

First, we define our model. For demonstration purposes, we will use an exponential growth model:

In [3]:
# Normalized time: 5 evenly spaced timepoints in [0, 1], as a column vector
time = np.linspace(0, 1, 5)[:, None]

# Define the model
model = dms.model.ExponentialGrowthBinomialModel(
    t=time,
    counts=np.random.randint(0, 100, (5, 100)),  # random placeholder counts, shape (timepoints, 100)
    log_A=dms_components.parameters.Normal(mu=0.0, sigma=0.01, shape=(100,)),
    r=dms_components.parameters.Normal(mu=0.0, sigma=5.0, shape=(100,)),
    sigma=dms_components.parameters.HalfNormal(sigma=0.03, shape=(1,)),
)

The parameter distributions passed to the model define its prior. We can perform an interactive prior predictive check as follows:

In [4]:
model.prior_predictive()

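
To make the idea of a prior predictive check concrete, here is a minimal NumPy sketch. It assumes, hypothetically, that each column's log-abundance follows `log_A + r * t`; the notebook does not show how dms_stan links these parameters to the binomial counts or where `sigma` enters, so both are left out. The sketch draws `log_A` and `r` from the same priors passed to the model and reports how wide the prior spread is at each timepoint.

```python
import numpy as np


def sample_prior_log_abundance(t, n_draws, n_variants=100, seed=0):
    """Draw prior log-abundance trajectories.

    ASSUMPTION: log_abundance[v](t) = log_A[v] + r[v] * t. This is an illustrative
    form, not necessarily the one dms_stan implements.
    """
    rng = np.random.default_rng(seed)
    # Same priors as the model definition: log_A ~ Normal(0, 0.01), r ~ Normal(0, 5)
    log_A = rng.normal(0.0, 0.01, size=(n_draws, 1, n_variants))
    r = rng.normal(0.0, 5.0, size=(n_draws, 1, n_variants))
    # t has shape (n_timepoints, 1), so the result broadcasts to
    # (n_draws, n_timepoints, n_variants)
    return log_A + r * t


time = np.linspace(0, 1, 5)[:, None]
draws = sample_prior_log_abundance(time, n_draws=1000)

# Prior standard deviation at each timepoint, pooling draws and variants
spread = draws.std(axis=(0, 2))
print(np.round(spread, 2))
```

Under this assumed form, the spread grows from about 0.01 at t = 0 to about 5 at t = 1, because the prior on `r` (sigma = 5.0) dominates once t > 0. Spotting whether that much spread is plausible is exactly what a prior predictive check is for.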

Viewing all of the samples pooled together isn't very informative. Let's look at the distributions at each timestep instead:

In [5]:
model.prior_predictive(initial_view="theta", independent_dim=0)

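
Conceptually, choosing an independent dimension means summarizing the samples separately for each index along that axis while pooling every other axis. The helper below is a hypothetical NumPy sketch of that idea, not dms_stan's implementation. It assumes a sample array with a leading draw axis, like the `(draws, timepoints, columns)` layout used later with `independent_dim=-2`.

```python
import numpy as np


def per_index_quantiles(samples, independent_dim, qs=(0.05, 0.5, 0.95)):
    """Quantiles of `samples` computed separately for each index along one axis.

    All other axes (draws, variants, ...) are pooled into a single distribution.
    Hypothetical helper for illustration only.
    """
    # Bring the independent axis to the front, then flatten all remaining axes
    moved = np.moveaxis(samples, independent_dim, 0)
    flat = moved.reshape(moved.shape[0], -1)
    # Result has shape (len(qs), n_indices)
    return np.quantile(flat, qs, axis=1)


rng = np.random.default_rng(1)
time = np.linspace(0, 1, 5)[:, None]
# Toy samples with shape (draws, timepoints, variants); the spread grows with time
samples = rng.normal(0.0, 5.0, size=(1000, 1, 100)) * time
summary = per_index_quantiles(samples, independent_dim=-2)
print(summary.shape)  # (3, 5)
```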

We also know the actual timepoints, so it's useful to label each distribution with its time value:

In [6]:
model.prior_predictive(
    initial_view="theta", independent_dim=0, independent_labels=time.flatten()
)


Any changes we make to the priors in the interactive view are applied to the model as well, which lets us set up the prior interactively.

Now that we have the prior, we can find the maximum a posteriori (MAP) estimate of the parameters. Under the hood, this is done with PyTorch by minimizing the negative log posterior, that is, the negative log-likelihood of the data plus the negative log-density of the prior we just defined.

In [7]:
# Sample data from the model for demo purposes. In real applications, you would
# provide your own data.
data = model.draw(1)["counts"][0]

# Fit the model to the data
map_ = model.approximate_map(counts=data)

Epochs:  18%|█▊        | 17856/100000 [00:30<02:18, 592.30it/s, loss=-1695.97]
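
The objective being minimized is the negative log posterior: the negative log-likelihood plus the negative log-prior. The sketch below shows the same idea on a much smaller, hypothetical problem, a single binomial proportion with a Normal prior on its logit. It uses plain gradient descent with a hand-derived gradient in NumPy rather than PyTorch autograd, so it mirrors the objective, not dms_stan's actual optimizer.

```python
import numpy as np


def map_logit_binomial(successes, trials, prior_mu=0.0, prior_sigma=1.0,
                       lr=0.01, n_steps=5000):
    """MAP estimate of z = logit(p) for a Binomial likelihood with a Normal prior on z.

    Minimizes -log Binomial(successes | trials, sigmoid(z)) - log Normal(z | prior_mu, prior_sigma)
    by gradient descent (constants dropped, since they do not affect the minimizer).
    """
    z = 0.0
    for _ in range(n_steps):
        p = 1.0 / (1.0 + np.exp(-z))
        # Gradient of the negative log-likelihood with respect to z
        grad_nll = -(successes - trials * p)
        # Gradient of the negative log-prior with respect to z
        grad_nlp = (z - prior_mu) / prior_sigma**2
        z -= lr * (grad_nll + grad_nlp)
    return z


z_map = map_logit_binomial(successes=30, trials=100, prior_mu=0.0, prior_sigma=1.0)
p_map = 1.0 / (1.0 + np.exp(-z_map))
print(round(p_map, 3))
```

With 30 successes in 100 trials, the MAP proportion lands slightly above the maximum-likelihood value of 0.30, because the prior pulls the logit toward 0 (p = 0.5). Widening `prior_sigma` weakens that pull.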


Now let's draw samples of the counts from the distribution parameterized by the MAP estimate:

In [8]:
posterior_samples = map_["distributions"]["counts"].sample([100])
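
Assuming `map_["distributions"]["counts"]` behaves like a `torch.distributions.Distribution`, `.sample([100])` prepends a sample axis of length 100 to the distribution's batch shape, giving an array shaped `(100, timepoints, columns)`. That is why the plotting calls below use `independent_dim=-2` for the time axis. The NumPy sketch below mimics that shape convention with a binomial draw; the total count and probabilities are made-up stand-ins for the fitted parameters.

```python
import numpy as np

rng = np.random.default_rng(2)
n_timepoints, n_variants = 5, 100

# Hypothetical stand-ins for fitted parameters (not taken from the model above)
total_counts = np.full((n_timepoints, 1), 1000)
probs = rng.uniform(0.0, 0.2, size=(n_timepoints, n_variants))

# Prepend the sample shape to the batch shape, as torch's .sample(sample_shape) does
sample_shape = (100,)
samples = rng.binomial(total_counts, probs, size=sample_shape + probs.shape)

print(samples.shape)  # (100, 5, 100)
# Axis -2 indexes timepoints, regardless of how many sample axes are prepended
print(samples.shape[-2])  # 5
```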

We can plot the distribution of count trajectories we expect under the MAP fit:

In [9]:
dms.plotting.plot_distribution(
    samples=posterior_samples,
    independent_dim=-2,
    independent_labels=time.flatten(),
    paramname="counts",
)
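
One common way to summarize a distribution of trajectories is a pointwise band: at each timepoint, take the median and a central interval across draws. Whether `dms.plotting.plot_distribution` draws exactly these quantiles is not shown here, so treat the sketch below as an illustration of the summary, using toy samples shaped like `posterior_samples`.

```python
import numpy as np


def trajectory_bands(samples, lower=5.0, upper=95.0):
    """Pointwise percentile bands over the leading (draw) axis.

    samples: array of shape (n_draws, n_timepoints, n_variants).
    Returns (low, median, high), each of shape (n_timepoints, n_variants).
    """
    low, median, high = np.percentile(samples, [lower, 50.0, upper], axis=0)
    return low, median, high


rng = np.random.default_rng(3)
# Toy samples shaped like posterior_samples: (draws, timepoints, variants)
samples = rng.poisson(lam=50.0, size=(100, 5, 100))
low, median, high = trajectory_bands(samples)
print(low.shape, median.shape, high.shape)  # (5, 100) (5, 100) (5, 100)
```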

We can also overlay the observed counts (here, the data simulated from the model above):

In [10]:
dms.plotting.plot_distribution(
    samples=posterior_samples,
    independent_dim=-2,
    independent_labels=time.flatten(),
    paramname="counts",
    overlay=data,
)
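
Beyond eyeballing the overlay, a simple numeric check is interval coverage: the fraction of observed counts that fall inside each cell's central 90% predictive interval. For a well-calibrated fit, roughly 90% should fall inside (for discrete counts, often a bit more). The function below is a hypothetical helper, not part of dms_stan; the usage draws both the predictive samples and the "observed" data from the same toy binomial, so coverage should be high.

```python
import numpy as np


def interval_coverage(samples, observed, lower=5.0, upper=95.0):
    """Fraction of observed values inside the pointwise [lower, upper] percentile interval.

    samples: (n_draws, *shape) predictive draws; observed: array of shape `shape`.
    """
    low, high = np.percentile(samples, [lower, upper], axis=0)
    inside = (observed >= low) & (observed <= high)
    return inside.mean()


rng = np.random.default_rng(4)
probs = rng.uniform(0.0, 0.2, size=(5, 100))
# Toy predictive draws and toy "observed" counts from the same binomial
samples = rng.binomial(1000, probs, size=(100, 5, 100))
observed = rng.binomial(1000, probs)
coverage = interval_coverage(samples, observed)
print(f"{coverage:.2f}")
```

Keep in mind that samples drawn at a single MAP point ignore parameter uncertainty, so their intervals can be narrower than a full posterior predictive and coverage on real data may come out lower.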