In [1]:
import holoviews as hv
import numpy as np
import scistanpy as ssp

hv.extension('bokeh')


This notebook contains an example workflow for modeling deep mutational scanning data. It covers the major steps of defining and using a model in SciStanPy.

To begin, let's just create some example data:

In [2]:
def sample_data():
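    """
    Generate sample data for deep mutational scanning analysis.

    Returns:
        A dictionary with the following keys:
        INPUT_COUNTS: Input read counts, shape (3 replicates, 10 variants).
        LOG_INPUT_FREQS: True log-frequencies of the input variants.
        OUTPUT_COUNTS: Output read counts after selection, shape (3, 10).
        LOG_OUTPUT_FREQS: True log-frequencies of the variants after selection.
        LOG_ENRICHMENT_FACTORS: True log-enrichment factors for each variant.
    """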
    # Sample input counts
    rng = np.random.default_rng(1025)
    input_freqs = rng.dirichlet(np.ones(10))
    log_input_freqs = np.log(input_freqs)
    input_counts = np.stack([rng.multinomial(10000, input_freqs)
                             for _ in range(3)])

    # Sample enrichment factors
    log_enrichment_factors = np.log(rng.exponential(0.1, size=(10,)))

    # Generate output counts after selection
    log_output_freqs = log_input_freqs + log_enrichment_factors
    log_output_freqs -= np.log(np.sum(np.exp(log_output_freqs))) # Normalize
    output_counts = np.stack([rng.multinomial(10000, np.exp(log_output_freqs))
                              for _ in range(3)])

    return {
        "INPUT_COUNTS": input_counts,
        "LOG_INPUT_FREQS": log_input_freqs,
        "OUTPUT_COUNTS": output_counts,
        "LOG_OUTPUT_FREQS": log_output_freqs,
        "LOG_ENRICHMENT_FACTORS": log_enrichment_factors
    }

SAMPLE_DATA = sample_data()

The above simulates a deep mutational scanning experiment with 10 variants (so, 10 sampled log-enrichment factors). Sequencing was performed in triplicate at both the beginning and the end of selection (so, 3 draws of 10,000 reads each from a multinomial distribution).
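The simulation draws enrichment factors with `rng.exponential(0.1, ...)`. NumPy parametrizes the exponential by its scale, so this corresponds to a rate of 10, the same value used for `beta` in the model that follows. Taking the log of an exponential variable with rate beta gives a Gumbel-type (minimum) distribution with CDF F(y) = 1 - exp(-beta * e^y). The short NumPy check below compares that analytic CDF to the ECDF of simulated log-draws; it is an illustrative sketch, not part of SciStanPy.

```python
import numpy as np

def log_exponential_cdf(y, beta):
    """Analytic CDF of Y = log(X) where X ~ Exponential(rate=beta)."""
    return 1.0 - np.exp(-beta * np.exp(y))

rng = np.random.default_rng(0)
beta = 10.0
# NumPy's exponential takes the scale, i.e. 1 / rate
samples = np.log(rng.exponential(1.0 / beta, size=200_000))

# Compare the empirical CDF of the log-draws to the analytic CDF on a grid
grid = np.linspace(-8.0, 0.0, 50)
empirical = np.array([(samples <= y).mean() for y in grid])
analytic = log_exponential_cdf(grid, beta)
max_gap = np.max(np.abs(empirical - analytic))
print(f"max |ECDF - CDF| = {max_gap:.4f}")
```

With this many draws the maximum gap should be on the order of a few thousandths, consistent with the analytic form.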

We can now model the process using SciStanPy:

In [3]:
# All models inherit from `ssp.Model`
class DMSModel(ssp.Model):

    # Define the structure of the model in the `__init__` method
    def __init__(self, input_counts, output_counts):

        # We're going to register default data for the input and output counts.
        # This isn't necessary, but means you won't need to pass the observables
        # into later methods.
        super().__init__(
            default_data={"input_counts": input_counts, "output_counts": output_counts}
        )

        # We now define our priors. Let's assume that we expect our enrichment
        # ratios to follow an exponential distribution. The log-enrichment factors
        # will then follow an exponential-exponential (Gumbel) distribution.
        # Note: We define 10 independent log-enrichment factors using the "shape"
        # argument.
        # Note: The "beta" parameter here is the inverse of the scale parameter
        # in Numpy/Scipy.
        self.log_enrichment = ssp.parameters.ExpExponential(beta=10.0, shape=(10,))

        # We reason that the input and output counts are multinomially distributed
        # with unknown frequencies, which are the values we want to infer. To
        # handle potentially small values, we will use an Exp-Dirichlet prior to
        # model the log-frequencies.
        self.log_input_freqs = ssp.parameters.ExpDirichlet(alpha=1.0, shape=(10,))

        # From the log-input frequencies and log-enrichment factors, we can define
        # a transformation that takes us to the output frequencies. We're in log
        # space, so this is just addition followed by normalization. Note that,
        # currently, all reductions and normalizations are performed over the last
        # axis (this cannot be changed yet).
        self.log_output_freqs = ssp.operations.normalize_log(
            self.log_input_freqs + self.log_enrichment
        )

        # Finally, we can model our observed counts at both the beginning and end
        # as multinomially distributed. Note that the name of the observable must
        # match the name we used when registering default data. If not registering
        # default data, you will need to provide the observables as keyword arguments
        # in the relevant functions (again, with matching names).
        # Note: We are using an alternate parametrization of the multinomial distribution
        # here to keep in log space.
        # Note: NumPy broadcasting rules apply, so the below will use the same
        # 10 log-frequencies for all 3 replicates. This is also why we need
        # `keepdims=True` when summing the counts to get `N`: shapes (3, 1) and
        # (10,) broadcast to (3, 10), while (3,) and (10,) do not broadcast.
        self.input_counts = ssp.parameters.MultinomialLogTheta(
            log_theta=self.log_input_freqs,
            N=input_counts.sum(axis=-1, keepdims=True),
            shape=(3, 10),
        )
        self.output_counts = ssp.parameters.MultinomialLogTheta(
            log_theta=self.log_output_freqs,
            N=output_counts.sum(axis=-1, keepdims=True),
            shape=(3, 10),
        )

That's it! The model is defined. SciStanPy models are ordinary Python classes, so they can be extended in all the usual ways, for example by building class hierarchies that greatly reduce duplicated code.
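The broadcasting note in the model comments (why `N` is computed with `keepdims=True`) can be checked in plain NumPy. The sketch below uses hypothetical placeholder counts and uniform log-frequencies with the same shapes as in this notebook; it does not call SciStanPy.

```python
import numpy as np

counts = np.arange(30).reshape(3, 10)   # hypothetical counts: 3 replicates x 10 variants
log_theta = np.log(np.full(10, 0.1))    # 10 log-frequencies shared by all replicates

n_keep = counts.sum(axis=-1, keepdims=True)  # shape (3, 1): one total per replicate
n_flat = counts.sum(axis=-1)                 # shape (3,)

# (3, 1) and (10,) broadcast to (3, 10): each replicate's N pairs with all 10 variants
print(np.broadcast_shapes(n_keep.shape, log_theta.shape))

# (3,) and (10,) are incompatible trailing dimensions, so broadcasting fails
try:
    np.broadcast_shapes(n_flat.shape, log_theta.shape)
    flat_broadcasts = True
except ValueError:
    flat_broadcasts = False
print("(3,) with (10,) broadcasts:", flat_broadcasts)
```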

Now, what can we do with our model? Let's create an instance of it and test out some SciStanPy operations. First up, let's do a prior predictive check:

In [4]:
# Build an instance
EXAMPLE_MODEL = DMSModel(
    input_counts=SAMPLE_DATA["INPUT_COUNTS"],
    output_counts=SAMPLE_DATA["OUTPUT_COUNTS"]
)

# Run a prior predictive check
EXAMPLE_MODEL.prior_predictive()


The `prior_predictive` function brings up an interactive dashboard that lets you test out the effects of different values for hyperparameters on model observables and parameters. By default, updating any parameters with the sliders (and subsequently clicking "update model") will also update their values in the model. This allows you to explore model hyperparameter values interactively before moving on to fitting a model.
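Conceptually, each "experiment" in a prior predictive check draws parameters from the priors and then simulates observables from them. The following is a hypothetical NumPy re-implementation of the `DMSModel` generative process (the function names `log_normalize` and `prior_predictive` are illustrative, and this is not how SciStanPy implements its dashboard); it assumes `ExpExponential(beta)` means the log of an exponential with rate `beta` and `ExpDirichlet(alpha)` means the log of a Dirichlet draw, as the model comments describe.

```python
import numpy as np

def log_normalize(x):
    """Subtract logsumexp over the last axis so that exp(x) sums to 1."""
    m = x.max(axis=-1, keepdims=True)
    return x - (m + np.log(np.exp(x - m).sum(axis=-1, keepdims=True)))

def prior_predictive(n_draws=100, beta=10.0, alpha=1.0, n_variants=10,
                     n_reps=3, n_reads=10_000, seed=0):
    """Hypothetical NumPy version of the DMSModel generative process."""
    rng = np.random.default_rng(seed)
    # log-enrichment ~ log(Exponential(rate=beta)); NumPy's argument is the scale 1/beta
    log_enrichment = np.log(rng.exponential(1.0 / beta, size=(n_draws, n_variants)))
    # log-input frequencies ~ log(Dirichlet(alpha))
    log_input = np.log(rng.dirichlet(np.full(n_variants, alpha), size=n_draws))
    # Output log-frequencies: add in log space, then renormalize over variants
    log_output = log_normalize(log_input + log_enrichment)
    # Simulate n_reps sequencing replicates of n_reads reads for each prior draw
    input_counts = np.stack([rng.multinomial(n_reads, np.exp(li), size=n_reps)
                             for li in log_input])
    output_counts = np.stack([rng.multinomial(n_reads, np.exp(lo), size=n_reps)
                              for lo in log_output])
    return {"log_enrichment": log_enrichment, "log_input": log_input,
            "log_output": log_output, "input_counts": input_counts,
            "output_counts": output_counts}

draws = prior_predictive()
print(draws["output_counts"].shape)  # (100, 3, 10)
```

Changing `beta` or `alpha` here plays the same role as moving the hyperparameter sliders: it changes what counts the model considers plausible before seeing data.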

Depending on the parameter selected, you may want to increase the value for "Number of Experiments"--this is the number of draws made from the model to build the figure. The default ECDF view flattens the displayed array before calculation, which obscures any relationships within or between variables. You can additionally choose values for "Group By", which will plot separate lines (or violins) for each grouping dimension; any constants with appropriate dimensionality can also be selected as "Independent Variable" to plot relationships between components.
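To see why the flattened ECDF hides structure, compare one ECDF over all values with one ECDF per group along the last axis. The draws below are hypothetical (each of 10 variants gets its own mean); the `ecdf` helper is illustrative and not the dashboard's code.

```python
import numpy as np

def ecdf(values):
    """Return the sorted values and the ECDF evaluated at each of them."""
    x = np.sort(np.ravel(values))
    y = np.arange(1, x.size + 1) / x.size
    return x, y

rng = np.random.default_rng(1)
# Hypothetical draws: 200 experiments x 10 variants, variant k centered at k
draws = rng.normal(loc=np.arange(10), scale=1.0, size=(200, 10))

# Flattened view: one curve pooling all 2,000 values; variant identity is lost
x_all, y_all = ecdf(draws)

# Grouped view: one curve per variant, which keeps between-variant differences visible
per_variant = [ecdf(draws[:, k]) for k in range(draws.shape[1])]
```

Increasing the number of experiments (the first axis) adds points to every curve, which smooths each ECDF.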

Let's say we are happy with our hyperparameter selection. We're not yet ready to commit to full MCMC sampling, but we want to get an estimate for our parameter values. We can compute a maximum likelihood estimate (MLE) using PyTorch as the backend:

In [5]:
MLE = EXAMPLE_MODEL.mle(lr=0.01)

Epochs:   8%|▊         | 8270/100000 [00:30<05:41, 268.57it/s, -log pdf/pmf=245.26] 
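The progress bar above ends at 8,270 of 100,000 epochs, well short of the cap, because the optimizer stopped once the loss plateaued. As an illustration of that kind of gradient-based fit with patience-style early stopping, here is a minimal NumPy sketch for the simplest piece of the model: multinomial frequencies parametrized by unconstrained logits. It is not SciStanPy's PyTorch code; the function `fit_frequencies`, the learning rate, and the `tol` improvement threshold are assumptions made for this example.

```python
import numpy as np

def log_softmax(z):
    """Numerically stable log-softmax of a 1-D array."""
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

def fit_frequencies(counts, lr=1.0, max_steps=100_000, patience=10, tol=1e-12):
    """Gradient-descent MLE of multinomial frequencies with patience-based early stopping."""
    target = counts / counts.sum()          # observed read fractions
    z = np.zeros_like(target)               # unconstrained logits
    best, stale, step = np.inf, 0, 0
    for step in range(1, max_steps + 1):
        log_p = log_softmax(z)
        loss = -(target * log_p).sum()      # per-read negative log-likelihood (up to a constant)
        if loss < best - tol:               # counts as an improvement
            best, stale = loss, 0
        else:
            stale += 1
            if stale >= patience:           # no improvement for `patience` steps: stop
                break
        z -= lr * (np.exp(log_p) - target)  # gradient of the loss w.r.t. the logits
    return np.exp(log_softmax(z)), step

counts = np.arange(1, 11) * 100.0           # hypothetical pooled counts for 10 variants
p_fit, n_steps = fit_frequencies(counts)
print(n_steps, np.round(p_fit, 4))
```

For this toy problem the fitted frequencies converge to the observed read fractions and the loop exits long before `max_steps`.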


By default, maximum likelihood estimation runs for up to 100,000 steps, stopping early if the loss (negative log likelihood) has not decreased for 10 steps. The returned results object exposes a lot of additional functionality, which is covered in detail in the documentation. For the purposes of this example notebook, however, we'll look at just two features: extracting maximum likelihood estimates and bootstrapping observations:

In [6]:
MLE_ESTIMATES = {k: v.mle for k, v in MLE.model_varname_to_mle.items()}
MLE_ESTIMATES

{'log_enrichment': array([-0.92663609, -3.09016203, -2.57271679, -2.7671538 , -2.67913775,
        -3.40916818, -4.39909115, -2.60754036, -0.87086089, -3.94263251]),
 'log_input_freqs': array([-2.31168775, -1.04707487, -2.53073436, -2.33456603, -2.91141387,
        -2.95025464, -2.54946068, -6.09458604, -2.15806421, -2.64736437]),
 'input_counts': None,
 'output_counts': None}

A dictionary linking model variable names to their MLE results can be accessed through the `model_varname_to_mle` property; each entry's `.mle` attribute holds the estimate (`None` for observables). Observables do not have an MLE: they are ground truth, so there is nothing to estimate.
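For intuition about what is being estimated, note that for a multinomial likelihood alone (no priors), the MLE of frequencies shared across replicates has a closed form: the pooled counts divided by the total number of reads. SciStanPy's objective also involves the priors and the enrichment factors, so its estimates need not match this exactly. The sketch below uses hypothetical simulated counts and checks that the closed form beats 1,000 random alternative frequency vectors.

```python
import numpy as np

rng = np.random.default_rng(3)
true_p = np.linspace(1, 10, 10)
true_p /= true_p.sum()
counts = rng.multinomial(10_000, true_p, size=3)   # 3 replicates x 10 variants

def multinomial_loglik(counts, p):
    """Log-likelihood without the multinomial coefficient, summed over replicates."""
    return float((counts * np.log(p)).sum())

# Closed-form MLE: pool the replicates, then divide by the total read count
p_hat = counts.sum(axis=0) / counts.sum()
ll_hat = multinomial_loglik(counts, p_hat)

# No randomly drawn frequency vector should achieve a higher likelihood
alternatives = rng.dirichlet(np.ones(10), size=1_000)
ll_alt = np.array([multinomial_loglik(counts, q) for q in alternatives])
```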

Let's compare the estimates in `MLE_ESTIMATES` to the known values from our simulated experiment:

In [7]:
hv.Scatter(
    data={
        "x": SAMPLE_DATA["LOG_ENRICHMENT_FACTORS"],
        "y": MLE_ESTIMATES["log_enrichment"],
    }
)

In [8]:
hv.Scatter(
    data={"x": SAMPLE_DATA["LOG_INPUT_FREQS"], "y": MLE_ESTIMATES["log_input_freqs"]}
)

As we'd expect, given that we know the exact generative process, the MLEs align well with the true values. In practice, of course, we do not know the generative process: our model encodes what we believe it to be, and after fitting it we must evaluate how well the fit model describes the data. One simple way to evaluate goodness of fit is to bootstrap samples from the MLE and perform a posterior predictive check on the results:

In [9]:
INFERENCE_OBJ = MLE.get_inference_obj() # Bootstrapping
INFERENCE_OBJ.run_ppc() # Checking fit


The above example is contrived, so the plots look a little too pristine compared to what you'd get in a real-world setting. Here's how to interpret them, though:

1. The first plot shows each observation's true rank (x-axis) against its values bootstrapped from the model. The x-axis carries no inferential meaning here: points are ordered by true rank only to simplify presentation when there are thousands or more observed points. The y-axis, however, shows the distribution of bootstrapped values (grey) and the associated observed values (gold). A well-fit model will have most gold dots falling within the shaded regions.
2. The next plot uses the same x-axis, but now shows the quantile of each true observation relative to its bootstrapped distribution. This figure is also designed for thousands or more data points, so it aggregates nearby points into hexagonal bins. It is not very informative at the scale of this example; for larger datasets, however, you want to see quantiles spread uniformly between 0 and 1 at all x-values, with the median line (grey) running down the center.
3. The final plot is a quantile-quantile plot: effectively, an ECDF of the observables' quantiles relative to the bootstrapped samples. In a sufficiently large sample from a perfectly calibrated model, every quantile should be equally represented, so a perfectly calibrated (and well-fit) model has a diagonal ECDF, indicated in the figure by the dashed line. The "Absolute Deviance" annotation gives the absolute area between the observed and ideal ECDF curves and can be used to compare how well calibrated and fit one model is relative to another.
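The bootstrap-and-quantile logic behind these plots can be sketched in NumPy. This is a hypothetical illustration, not SciStanPy's implementation: it assumes the quantile of an observation is the fraction of bootstrapped values less than or equal to it, uses pooled frequencies as a stand-in for the MLE, and approximates the absolute deviance by averaging |ECDF - diagonal| over an evenly spaced grid on [0, 1].

```python
import numpy as np

rng = np.random.default_rng(7)
n_reads, n_variants, n_boot = 10_000, 10, 1_000
true_p = rng.dirichlet(np.ones(n_variants))
observed = rng.multinomial(n_reads, true_p, size=3)      # (3, 10) hypothetical data

# Stand-in for the MLE: pooled frequencies (closed form for the multinomial likelihood)
p_hat = observed.sum(axis=0) / observed.sum()

# Parametric bootstrap: simulate n_boot replicate datasets from the fitted model
boot = rng.multinomial(n_reads, p_hat, size=(n_boot, 3))  # (n_boot, 3, 10)

# Quantile of each observation within its own bootstrap distribution
quantiles = (boot <= observed).mean(axis=0)               # (3, 10)

# ECDF of the quantiles on a grid, compared with the ideal diagonal
grid = np.linspace(0.0, 1.0, 1001)
flat = np.sort(quantiles.ravel())
ecdf = np.searchsorted(flat, grid, side="right") / flat.size
abs_deviance = float(np.mean(np.abs(ecdf - grid)))        # approximate area on [0, 1]
print(f"absolute deviance ~ {abs_deviance:.3f}")
```

With only 30 observations the ECDF is coarse, which is why the text above notes that these summaries are most informative for large datasets.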

Note that the plots above are interactive! Use the dropdown to update them for different observables.

Now, one final note on the `INFERENCE_OBJ` variable. It exposes a special property `inference_obj` which is an ArviZ `InferenceData` instance holding all bootstrapped data. Use it to plug directly into the ArviZ ecosystem:

In [10]:
INFERENCE_OBJ.inference_obj  # ArviZ InferenceData instance

Alright, we're satisfied with our maximum likelihood estimation and ready to move on to full Markov chain Monte Carlo (MCMC) sampling with Stan. Just call the method below and SciStanPy will:

1. Convert your SciStanPy model into Stan code.
2. Compile that code.
3. Run that code.
4. Organize and return the results.

In [11]:
HMC_RES = EXAMPLE_MODEL.mcmc()

14:02:40 - cmdstanpy - INFO - compiling stan file /tmp/tmp2an6vi4v/model.stan to exe file /tmp/tmp2an6vi4v/model


14:02:50 - cmdstanpy - INFO - compiled model executable: /tmp/tmp2an6vi4v/model
--- Translating Stan model to C++ code ---
bin/stanc --filename-in-msg=model.stan --warn-pedantic --O1 --include-paths=/home/bwittmann/GitRepos/SciStanPy/scistanpy/model/stan --o=/tmp/tmp2an6vi4v/model.hpp /tmp/tmp2an6vi4v/model.stan
    no prior is provided, or the prior(s) depend on data variables. In the
    later case, this may be a false positive.
    prior is provided, or the prior(s) depend on data variables. In the later
    case, this may be a false positive.

--- Compiling C++ code ---
g++ -std=c++17 -pthread -D_REENTRANT -Wno-sign-compare -Wno-ignored-attributes -Wno-class-memaccess     -DSTAN_THREADS -I stan/lib/stan_math/lib/tbb_2020.3/include    -O3 -I src -I stan/src -I stan/lib/rapidjson_1.1.0/ -I lib/CLI11-1.9.1/ -I stan/lib/stan_math/ -I stan/lib/stan_math/lib/eigen_3.4.0 -I stan/lib/stan_math/lib/boost_1.87.0 -I stan/lib/stan_math/lib/sundials_6.1.1/include -I stan/lib/stan_math/lib/sundi

chain 1 |          | 00:00 Status

chain 2 |          | 00:00 Status

chain 3 |          | 00:00 Status

chain 4 |          | 00:00 Status

14:02:51 - cmdstanpy - INFO - CmdStan done processing.

Converting CSV to NetCDF: 100%|██████████| 4/4 [00:16<00:00,  4.03s/it]


You can ignore the warnings about parameters not having priors. They are a consequence of how data is passed to the model: the priors depend on data variables, which is the "false positive" case the warning itself mentions.

Now that we have the results, let's run some diagnostics:

In [12]:
_ = HMC_RES.diagnose()

Sample diagnostic tests results' summaries:
-------------------------------------------
0 of 4000 (0.00%) samples had a low energy.
0 of 4000 (0.00%) samples reached the maximum tree depth.
0 of 4000 (0.00%) samples diverged.

R_hat diagnostic tests results' summaries:
------------------------------------------
0 of 10 (0.00%) r_hats tests failed for log_enrichment.
0 of 10 (0.00%) r_hats tests failed for log_input_freqs.
0 of 10 (0.00%) r_hats tests failed for log_output_freqs.

Ess_bulk diagnostic tests results' summaries:
---------------------------------------------
0 of 10 (0.00%) ess_bulks tests failed for log_enrichment.
0 of 10 (0.00%) ess_bulks tests failed for log_input_freqs.
0 of 10 (0.00%) ess_bulks tests failed for log_output_freqs.

Ess_tail diagnostic tests results' summaries:
---------------------------------------------
0 of 10 (0.00%) ess_tails tests failed for log_enrichment.
0 of 10 (0.00%) ess_tails tests failed for log_input_freqs.
0 of 10 (0.00%) ess_tails tests

The diagnostic tests and their meanings are described in greater detail in the full documentation. For our purposes here, what matters is that they all passed! 

We stored the output of the function in a throwaway variable, as we don't need it here. However, it contains the indices of failed samples and variables, where applicable. Also note that the `SampleResults` object has additional functionality to help diagnose failed samples when they're present, via the `plot_sample_failure_quantile_traces` and `plot_variable_failure_quantile_traces` methods. See the full documentation for details on these methods.
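For intuition about one of these checks, here is a minimal NumPy sketch of the classic (non-rank-normalized) split-R-hat for a single scalar parameter. Current Stan and ArviZ report a rank-normalized variant, and the exact diagnostics and thresholds SciStanPy applies are not assumed here; `split_rhat` is an illustrative helper.

```python
import numpy as np

def split_rhat(draws):
    """Classic split-R-hat for one scalar parameter; draws has shape (n_chains, n_draws)."""
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split every chain in two so that drift within a chain shows up as between-chain disagreement
    split = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = split.shape[1]
    within = split.var(axis=1, ddof=1).mean()       # W: mean within-chain variance
    between = n * split.mean(axis=1).var(ddof=1)    # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n   # pooled variance estimate
    return float(np.sqrt(var_plus / within))

rng = np.random.default_rng(42)
mixed = rng.normal(size=(4, 1_000))                      # 4 chains sampling the same target
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])   # one chain stuck in a different region
rhat_mixed = split_rhat(mixed)
rhat_stuck = split_rhat(stuck)
print(f"mixed: {rhat_mixed:.3f}, stuck: {rhat_stuck:.3f}")
```

Well-mixed chains give a value very close to 1, while a single disagreeing chain inflates it well above common cutoffs such as 1.01.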

Finally, let's run a posterior predictive check on our Stan samples. Note that, unlike the bootstrapped MLE samples, these samples should be representative of the full posterior, not just the neighborhood of the MLE. The same workflow as for evaluating the MLE applies here, however:

In [13]:
HMC_RES.run_ppc()


As before, the fit looks reasonable. Also as before, note that we can access an underlying ArviZ `InferenceData` object for further analysis. It will have some additional fields compared to the MLE-associated one, reflecting the richer information content of HMC samples:

In [14]:
HMC_RES.inference_obj