# Modeling microtubule catastrophe II

<hr>

In [1]:
# Colab setup ------------------
import os, sys, subprocess
if "google.colab" in sys.modules:
    cmd = "pip install --upgrade iqplot colorcet bebi103 arviz cmdstanpy watermark"
    process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    import cmdstanpy; cmdstanpy.install_cmdstan()
    data_path = "https://raw.githubusercontent.com/justinbois/learnbayes-livecode/main/"
else:
    data_path = "./"
# ------------------------------

import numpy as np
import pandas as pd

import tqdm

import cmdstanpy
import arviz as az

import bebi103

import iqplot

import bokeh.io
import bokeh.layouts
bokeh.io.output_notebook()

<hr>

We have thoroughly investigated the process by which microtubules undergo catastrophe using data from the [Gardner, Zanic, et al. paper](../protected/papers/gardner_2011.pdf). We used an Exponential, Gamma, Weibull, and our custom two-step distribution to model the catastrophe times. But we are unsatisfied with that because we expect an *integer* number $m$ of Poisson processes to arrive sequentially in order for catastrophe to occur. The Exponential and two-step models are special cases we have already considered, where $m = 1$ and $m = 2$, respectively. Now, we will consider a model for arbitrary $m$ and assess which $m$ gives the most plausible generative model.

Of course, to begin, we need to load in the data.

In [2]:
t = np.loadtxt(os.path.join(data_path, 'gardner_zanic_mt_catastrophe.csv'))
data = dict(t=t, N=len(t))

## The model for m-step catastrophe

The probability density function for the time $t$ to catastrophe triggered by the arrival of $m$ Poisson processes may be shown to be (after some real grunge)

\begin{align}
f(t\mid \tau_1, \tau_2, \ldots, \tau_m) = \sum_{j=1}^m \frac{\tau_j^{m-2}\,\mathrm{e}^{-t/\tau_j}}{\prod_{k=1,k\ne j}^m (\tau_j - \tau_k)}.
\end{align}

For clarity, the likelihoods for the first few $m$ are

\begin{align}
f(t\mid \tau_1) &= \frac{\mathrm{e}^{-t/\tau_1}}{\tau_1},\\[1em]
f(t\mid \tau_1, \tau_2) &=
\frac{\mathrm{e}^{-t/\tau_1}}{\tau_1 - \tau_2} + \frac{\mathrm{e}^{-t/\tau_2}}{\tau_2 - \tau_1}
= \frac{\mathrm{e}^{-t/\tau_2} - \mathrm{e}^{-t/\tau_1}}{\tau_2 - \tau_1}, \\[1em]
f(t\mid \tau_1, \tau_2, \tau_3) &=
\frac{\tau_1\,\mathrm{e}^{-t/\tau_1}}{(\tau_1 - \tau_2)(\tau_1-\tau_3)}
+\frac{\tau_2\,\mathrm{e}^{-t/\tau_2}}{(\tau_2 - \tau_1)(\tau_2-\tau_3)}
+\frac{\tau_3\,\mathrm{e}^{-t/\tau_3}}{(\tau_3 - \tau_1)(\tau_3-\tau_2)},\\[1em]
f(t\mid \tau_1, \tau_2, \tau_3, \tau_4) &=
\frac{\tau_1^2\,\mathrm{e}^{-t/\tau_1}}{(\tau_1 - \tau_2)(\tau_1-\tau_3)(\tau_1 - \tau_4)}
+\frac{\tau_2^2\,\mathrm{e}^{-t/\tau_2}}{(\tau_2 - \tau_1)(\tau_2-\tau_3)(\tau_2 - \tau_4)} \\[1em]
&\;\;\;\;\;+\frac{\tau_3^2\,\mathrm{e}^{-t/\tau_3}}{(\tau_3 - \tau_1)(\tau_3-\tau_2)(\tau_3 - \tau_4)}
+\frac{\tau_4^2\,\mathrm{e}^{-t/\tau_4}}{(\tau_4 - \tau_1)(\tau_4-\tau_2)(\tau_4 - \tau_3)}.
\end{align}
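
As a sanity check on the general expression, here is a minimal pure-Python sketch. The function name `multistep_pdf` and the time constants are illustrative assumptions, not fitted values. It evaluates the sum directly, compares it with the closed forms for $m = 1$ and $m = 2$, checks by crude numerical integration that the $m = 3$ density is normalized, and confirms that the value does not depend on how the $\tau$'s are labeled.

```python
import itertools
import math


def multistep_pdf(t, taus):
    """Evaluate the m-step PDF directly; the taus must all be distinct."""
    m = len(taus)
    total = 0.0
    for j, tau_j in enumerate(taus):
        denom = 1.0
        for k, tau_k in enumerate(taus):
            if k != j:
                denom *= tau_j - tau_k
        total += tau_j ** (m - 2) * math.exp(-t / tau_j) / denom
    return total


# Made-up time constants in minutes, for illustration only
t = 3.0
exponential = math.exp(-t / 4.0) / 4.0
two_step = (math.exp(-t / 5.0) - math.exp(-t / 2.0)) / (5.0 - 2.0)
print(multistep_pdf(t, [4.0]), exponential)
print(multistep_pdf(t, [2.0, 5.0]), two_step)

# Trapezoid rule on [0, 200] min: the m = 3 density should integrate to about 1
taus3 = [1.0, 2.0, 5.0]
dt = 0.01
pdf_vals = [multistep_pdf(i * dt, taus3) for i in range(20001)]
area = dt * (sum(pdf_vals) - 0.5 * (pdf_vals[0] + pdf_vals[-1]))
print(area)

# Relabeling the taus leaves the PDF unchanged
perm_vals = [multistep_pdf(t, list(p)) for p in itertools.permutations(taus3)]
print(max(perm_vals) - min(perm_vals))
```

Direct evaluation like this is fine for a quick check, but it is not numerically robust for large $t$ or for nearly equal $\tau$'s, where the terms of opposite sign nearly cancel.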

We will use the same priors as before for the $\tau$'s, noting that we will enforce ordering such that $\tau_1 < \tau_2 < \cdots < \tau_m$.

\begin{align}
&\log_{10}\tau_j \sim \text{Norm}(1/2, 0.75)\;\forall\;j\in[1,m],\\[1em]
&f(t_i\mid \tau_1, \tau_2, \ldots, \tau_m) = \sum_{j=1}^m \frac{\tau_j^{m-2}\,\mathrm{e}^{-t_i/\tau_j}}{\prod_{k=1,k\ne j}^m (\tau_j - \tau_k)}\;\forall i.
\end{align}
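
The generative story behind this model is simple: draw each $\tau_j$ from its prior, then add up $m$ exponentially distributed waiting times. Here is a minimal sketch using only the standard library. The function name `draw_catastrophe_times` and the seed are arbitrary choices for illustration.

```python
import random


def draw_catastrophe_times(m, n, rng):
    """Draw taus from the prior, then n catastrophe times given those taus."""
    # log10(tau_j) ~ Normal(0.5, 0.75); sorting imposes tau_1 < ... < tau_m
    taus = sorted(10 ** rng.gauss(0.5, 0.75) for _ in range(m))
    # random.expovariate is parametrized by the rate, which is 1 / tau_j
    times = [sum(rng.expovariate(1.0 / tau) for tau in taus) for _ in range(n)]
    return taus, times


rng = random.Random(3252)
taus, times = draw_catastrophe_times(m=3, n=20000, rng=rng)

# The mean of a sum of exponentials is the sum of the time constants
sample_mean = sum(times) / len(times)
print(sample_mean, sum(taus))
```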

## Prior predictive checks

This time, just for fun, for our prior predictive checks, we will use Stan to do the sampling.

In [3]:
multistep_ppc = """
data {
  int m;
  int N;
}


generated quantities {
  array[m] real tau;
  array[N] real t;
  
  for (i in 1:m) {
    tau[i] = 10 ^ normal_rng(0.5, 0.75);
  }
  
  for (i in 1:N) {
    t[i] = exponential_rng(1.0 / tau[1]);
    for (j in 2:m) {
      t[i] += exponential_rng(1.0 / tau[j]);
    }
  }
}
"""

Let's compile the model and draw our prior predictive samples for 1 ≤ _m_ ≤ 6.

In [4]:
with open("multistep_ppc.stan", "w") as f:
    f.write(multistep_ppc)
    
sm_prior_pred = cmdstanpy.CmdStanModel(stan_file="multistep_ppc.stan")

plots = []
for m in range(1, 7):
    data_ppc = dict(m=m, N=len(t))
    with bebi103.stan.disable_logging():
        samples_prior_pred = sm_prior_pred.sample(
            data=data_ppc,
            iter_sampling=1000,
            fixed_param=True,
            show_progress=False,
        )

    samples_prior_pred = az.from_cmdstanpy(
        prior=samples_prior_pred, prior_predictive="t"
    )

    plots.append(
        bebi103.viz.predictive_ecdf(
            samples_prior_pred.prior_predictive.t,
            x_axis_label="time (min)",
            title=f"m = {m}",
            frame_height=150,
            frame_width=200,
            x_axis_type='log',
        )
    )

bokeh.io.show(bokeh.layouts.gridplot(plots, ncols=2))

23:42:38 - cmdstanpy - INFO - compiling stan file /Users/bois/Dropbox/git/learnbayes-livecode/multistep_ppc.stan to exe file /Users/bois/Dropbox/git/learnbayes-livecode/multistep_ppc
23:42:50 - cmdstanpy - INFO - compiled model executable: /Users/bois/Dropbox/git/learnbayes-livecode/multistep_ppc


These prior predictive ECDFs look good, like they could encompass what we might see. Let's proceed to code up the model.

We have to be careful here in many respects.

1. We need to ensure that the $\tau$'s are sorted to maintain identifiability so that we do not get a label switching degeneracy.
2. The log likelihood is the logarithm of the (signed) sum of exponentials. We need to be careful about underflow or overflow when doing this, so we use the logsumexp trick.
3. We have to take care with signs to make sure we get the signs right in the sum and also do not take the logarithm of a negative number.
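
To make points 2 and 3 concrete, here is a minimal pure-Python sketch of a signed log-sum-exp. The function names are my own for illustration and the time constants are made up. With sorted $\tau$'s, $m - j$ of the factors $(\tau_j - \tau_k)$ in term $j$ are negative, so the sign of term $j$ is $(-1)^{m-j}$, and we can work with logarithms of absolute values while carrying the signs separately.

```python
import math


def signed_log_sum_exp(terms, signs):
    """log(sum_j signs[j] * exp(terms[j])), assuming the signed sum is positive."""
    max_term = max(terms)
    shifted = sum(s * math.exp(x - max_term) for s, x in zip(signs, terms))
    return max_term + math.log(shifted)


def catastrophe_logpdf(t, taus):
    """Log PDF for sorted, distinct taus (tau_1 < ... < tau_m)."""
    m = len(taus)
    if m == 1:
        return -math.log(taus[0]) - t / taus[0]
    terms, signs = [], []
    for j, tau_j in enumerate(taus):
        log_prod = sum(
            math.log(abs(tau_j - tau_k)) for k, tau_k in enumerate(taus) if k != j
        )
        terms.append((m - 2) * math.log(tau_j) - t / tau_j - log_prod)
        # With 0-based j, m - 1 - j of the factors (tau_j - tau_k) are negative
        signs.append(1.0 if (m - 1 - j) % 2 == 0 else -1.0)
    return signed_log_sum_exp(terms, signs)


taus = [1.0, 2.0, 5.0]  # made-up values in minutes

# Moderate t: agrees with evaluating the m = 3 sum directly
t = 3.0
direct = (
    1.0 * math.exp(-t / 1.0) / ((1.0 - 2.0) * (1.0 - 5.0))
    + 2.0 * math.exp(-t / 2.0) / ((2.0 - 1.0) * (2.0 - 5.0))
    + 5.0 * math.exp(-t / 5.0) / ((5.0 - 1.0) * (5.0 - 2.0))
)
print(catastrophe_logpdf(t, taus), math.log(direct))

# Very large t: every exponential underflows in the direct sum
t_large = 5000.0
naive_underflows = math.exp(-t_large / max(taus)) == 0.0
log_pdf_large = catastrophe_logpdf(t_large, taus)
print(naive_underflows, log_pdf_large)
```

At $t = 5000$ the slowest exponential already underflows to zero, so the logarithm of the direct sum is undefined, while the log-space version still returns a finite value dominated by the largest $\tau$.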

With these considerations in mind, I came up with the code below.

In [5]:
multistep_model = """
functions {
  real sign(int j, int m) {
    /* Sign of respective term in sum of PDF. If both m and j are
     * even or both are odd, then sign is positive. Otherwise, sign
     * is negative.
     */
    if ((m % 2) + (j % 2) == 1) return -1.0;

    return 1.0;
  }
  

  real signed_log_sum_exp(vector terms, row_vector signs) {
    // log-sum-exp trick with signed arguments.
    real max_term = max(terms);
    vector[size(terms)] adjusted_terms = terms - max_term;      

    return max_term + log(signs * exp(adjusted_terms));
  }
  

  real catastrophe_lpdf(real t, vector tau, row_vector signs, int m) {
    real log_pdf;
    real log_prod;
    vector[m] terms;
    
    // For case where m = 1, it's just Exponential
    if (m == 1) {
      log_pdf = exponential_lpdf(t | 1.0 / tau[1]);
    }

    // For all other cases, first compute summed terms of PDF    
    else {
      for (j in 1:m) {
        log_prod = 0.0;
        for (k in 1:m) {
          if (j != k) {
            log_prod += log(abs(tau[j] - tau[k]));
          }
        }
        terms[j] = (m - 2) * log(tau[j]) - t / tau[j] - log_prod;
      }
      
      // Compute log PDF using logsumexp trick 
      log_pdf = signed_log_sum_exp(terms, signs);
    }
    
    return log_pdf;
  }
}


data {
  int m;
  int N;
  array[N] real t;
}


transformed data {
  row_vector[m] signs;

  for (j in 1:m) {
    signs[j] = sign(j, m);
  }
}


parameters {
  ordered[m] log_tau;
}


transformed parameters {
  vector[m] tau = 10 ^ log_tau;
}

model {
  log_tau ~ normal(0.5, 0.75);
  
  for (i in 1:N) {
    target += catastrophe_lpdf(t[i] | tau, signs, m);
  }
}


generated quantities {
  array[N] real log_lik;
  array[N] real t_ppc;
  
  for (i in 1:N) {
    log_lik[i] = catastrophe_lpdf(t[i] | tau, signs, m);
    t_ppc[i] = exponential_rng(1.0 / tau[1]);

    for (j in 2:m) {
      t_ppc[i] += exponential_rng(1.0 / tau[j]);
    }
  }
}
"""

Let's compile the model.

In [6]:
with open("multistep.stan", "w") as f:
    f.write(multistep_model)
    
sm = cmdstanpy.CmdStanModel(stan_file='multistep.stan')

23:42:57 - cmdstanpy - INFO - compiling stan file /Users/bois/Dropbox/git/learnbayes-livecode/multistep.stan to exe file /Users/bois/Dropbox/git/learnbayes-livecode/multistep
23:43:11 - cmdstanpy - INFO - compiled model executable: /Users/bois/Dropbox/git/learnbayes-livecode/multistep


I will now sample out of this distribution for $1 \le m \le 6$. I did some tests with various models and found that I should raise `adapt_delta` to 0.975 (above Stan's default of 0.8) to deal with divergences, and that we should allow a somewhat larger maximum tree depth than the default of 10.

In [7]:
samples_dict = {}
for m in tqdm.tqdm(range(1, 7)):
    # Load pre-computed to save time; if not present, run and save
    try:
        samples = az.from_netcdf(os.path.join(data_path, f"samples_m{m}.nc"))
    except Exception:
        data = dict(N=len(t), t=t, m=m)
        with bebi103.stan.disable_logging():
            samples = sm.sample(
                data=data, adapt_delta=0.975, max_treedepth=15, show_progress=False
            )

        samples = az.from_cmdstanpy(
            posterior=samples, posterior_predictive="t_ppc", log_likelihood="log_lik"
        )

        # Try to save (can only save locally, not using Colab)
        try:
            samples.to_netcdf(os.path.join(data_path, f"samples_m{m}.nc"))
        except Exception:
            pass

    # Store results in a dictionary
    samples_dict[m] = samples

100%|█████████████████████████████████████████████| 6/6 [00:01<00:00,  3.71it/s]


Of course, we should check the diagnostics.

In [8]:
for m, samples in samples_dict.items():
    print(f"\n\nChecking diagnostics for m = {m}....")
    bebi103.stan.check_all_diagnostics(samples, max_treedepth=15)



Checking diagnostics for m = 1....
Effective sample size looks reasonable for all parameters.

Rhat looks reasonable for all parameters.

0 of 4000 (0.0%) iterations ended with a divergence.

0 of 4000 (0.0%) iterations saturated the maximum tree depth of 15.

E-BFMI indicated no pathological behavior.


Checking diagnostics for m = 2....
Effective sample size looks reasonable for all parameters.

Rhat looks reasonable for all parameters.

0 of 4000 (0.0%) iterations ended with a divergence.

0 of 4000 (0.0%) iterations saturated the maximum tree depth of 15.

E-BFMI indicated no pathological behavior.


Checking diagnostics for m = 3....
Effective sample size looks reasonable for all parameters.

Rhat looks reasonable for all parameters.

1 of 4000 (0.025%) iterations ended with a divergence.
  Try running with larger adapt_delta to remove divergences.

0 of 4000 (0.0%) iterations saturated the maximum tree depth of 15.

E-BFMI indicated no pathological behavior.


Checking diagnost

A couple warnings, but overall a clean bill of health!

Now, let's make posterior predictive check plots for each of these.

In [9]:
plots = []
for m, samples in samples_dict.items():
    t_ppc = samples.posterior_predictive.t_ppc.stack(
        {"sample": ("chain", "draw")}
    ).transpose("sample", "t_ppc_dim_0")

    plots.append(
        bebi103.viz.predictive_ecdf(
            t_ppc,
            data=t,
            diff="ecdf",
            x_axis_label="time to catastrophe (min)",
            title=f"m = {m}",
            frame_height=150,
            frame_width=200,
        )
    )

bokeh.io.show(bokeh.layouts.gridplot(plots, ncols=2))

These are interesting. For $m=1$ and $m=2$, the model fails pretty spectacularly to capture the real data set. There is essentially no difference between $m=3$ and any plot with $m > 3$. We can explore this a bit more when we plot CDFs of the marginalized posterior distributions for the τ's.

In [10]:
plots = []
for m, samples in samples_dict.items():
    p = iqplot.ecdf(
        samples.posterior.tau.sel(tau_dim_0=0).values.flatten(),
        frame_width=250,
        frame_height=150,
        x_axis_label="τ (min)",
        y_axis_label="ECDF of samples",
        title=f"m = {m}",
        style="staircase",
    )
    for i in range(2, m + 1):
        p = iqplot.ecdf(
            samples.posterior.tau.sel(tau_dim_0=i - 1).values.flatten(),
            p=p,
            style="staircase",
        )
    plots.append(p)

for i in range(1, len(plots)):
    plots[i].x_range = plots[0].x_range

bokeh.io.show(bokeh.layouts.gridplot(plots, ncols=2))

We see that we get three different $\tau$'s for the $m=3$ case. As $m$ grows, each successive model adds smaller and smaller $\tau$'s. This means that there are really only three slow processes that we can resolve, and everything else is fast. This is reminiscent of the notion of rate limiting steps in chemical kinetics. 
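
A quick way to see why fast steps are hard to resolve: the time to catastrophe is a sum of independent exponential waiting times, so its mean is $\sum_j \tau_j$ and its variance is $\sum_j \tau_j^2$. The sketch below uses made-up time constants (not the posterior estimates) to show how little an extra fast step changes these moments.

```python
import math


def mean_and_sd(taus):
    """Mean and standard deviation of a sum of independent exponentials."""
    return sum(taus), math.sqrt(sum(tau ** 2 for tau in taus))


slow = [1.0, 2.0, 5.0]      # illustrative slow steps, in minutes
with_fast = [0.05] + slow   # same steps plus one hypothetical fast step

mean3, sd3 = mean_and_sd(slow)
mean4, sd4 = mean_and_sd(with_fast)

# Relative change in the moments caused by the extra fast step
rel_mean = (mean4 - mean3) / mean3
rel_sd = (sd4 - sd3) / sd3
print(rel_mean, rel_sd)
```

With these numbers the mean shifts by well under one percent and the standard deviation by far less, so data of this kind carry very little information about the fast step.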

At this point, we can conclude that under this model, there are three rate limiting steps. There are still some issues with the posterior predictive check, in that there are uncomfortably many data points falling outside of what we might get from our posterior-informed generative model. Nonetheless, under this model, the $m = 3$ case is the clear winner.

Let's take a look at a corner plot for this case.

In [11]:
bokeh.io.show(
    bebi103.viz.corner(samples_dict[3], parameters=["tau[0]", "tau[1]", "tau[2]"])
)

Finally, just to put the icing on the cake, we can compute a model comparison using LOO and stacking.

In [12]:
az.compare(samples_dict, ic="loo", scale="deviance")

| m | rank | elpd_loo | p_loo | elpd_diff | weight | se | dse | warning | scale |
|---|------|----------|-------|-----------|--------|----|-----|---------|-------|
| 3 | 0 | 3609.822361 | 1.80216 | 0.0 | 1.0 | 43.847587 | 0.0 | False | deviance |
| 4 | 1 | 3611.839217 | 2.602685 | 2.016855 | 1.21608e-13 | 44.619828 | 2.444892 | False | deviance |
| 5 | 2 | 3613.60828 | 2.801736 | 3.785919 | 5.318485e-13 | 44.842419 | 3.522113 | False | deviance |
| 6 | 3 | 3615.192761 | 2.885204 | 5.370399 | 3.466387e-13 | 44.979185 | 4.360552 | False | deviance |
| 2 | 4 | 3659.499275 | 0.764064 | 49.676914 | 0.0 | 36.762501 | 11.922048 | False | deviance |
| 1 | 5 | 3941.992349 | 0.362171 | 332.169987 | 0.0 | 32.008015 | 26.346484 | False | deviance |


This further cements our conclusion that $m = 3$ is the best model. But we really did not need to do model comparison; it was apparent from the parameter estimation that there may be several fast steps, but there are three slow ones.
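
One rough way to read the comparison is to express each model's difference from the top-ranked model in units of the standard error of that difference (the `dse` column). This is a heuristic, not a formal test, and the numbers below are simply copied from the LOO output above.

```python
# (elpd_diff, dse) on the deviance scale, copied from the LOO comparison output
loo_table = {
    4: (2.016855, 2.444892),
    5: (3.785919, 3.522113),
    6: (5.370399, 4.360552),
    2: (49.676914, 11.922048),
    1: (332.169987, 26.346484),
}

# Difference from the m = 3 model in units of its standard error
ratios = {m: diff / dse for m, (diff, dse) in loo_table.items()}
for m, ratio in sorted(ratios.items()):
    print(f"m = {m}: {ratio:.2f} standard errors from m = 3")
```

By this measure the $m = 4$, 5, and 6 models sit roughly one standard error from $m = 3$, while $m = 2$ and $m = 1$ are worse by about 4 and 13 standard errors, which is consistent with extra fast steps changing very little.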

## Computing environment

In [13]:
%load_ext watermark
%watermark -v -p numpy,pandas,cmdstanpy,arviz,iqplot,bebi103,bokeh,jupyterlab
print("cmdstan   :", bebi103.stan.cmdstan_version())

Python implementation: CPython
Python version       : 3.11.4
IPython version      : 8.12.0

numpy     : 1.24.3
pandas    : 1.5.3
cmdstanpy : 1.1.0
arviz     : 0.16.1
iqplot    : 0.3.3
bebi103   : 0.1.14
bokeh     : 3.2.1
jupyterlab: 4.0.3

cmdstan   : 2.32.2
