In [None]:
import pymc as pm
import pytensor.tensor as pt
import numpy as np
import pandas as pd
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns

from utils import matplotlib_style
# Project plotting style: `cor` is looked up by color name and `pal` by
# integer position in the plotting cells of this notebook
cor, pal = matplotlib_style()

# Global seed so that random draws in this notebook are reproducible
RANDOM_SEED = 42
rng = np.random.default_rng(seed=RANDOM_SEED)

# Single-molecule mRNA FISH inference

In this notebook, we will explore the validity of our statistical model using
single-molecule mRNA FISH data from [this paper from the Elowitz
lab](https://doi.org/10.1016/j.molcel.2014.06.029).

Let us begin by importing the data.

In [None]:
# Load the table of mRNA counts; lines starting with '#' in the CSV are
# skipped as comments
df_counts = pd.read_csv(
    '../data/singer_transcript_counts.csv',
    comment='#',
)

# Quick look at the first rows
df_counts.head(n=5)

We can see that the data has counts for four different genes: `Rex1`, `Rest`,
`Nanog`, and `Prdm14`. Let's plot the individual ECDFs for each of these genes.

In [None]:
# Single panel that overlays the ECDF of every gene
fig, ax = plt.subplots(1, 1, figsize=(1.5, 1.5))

# Each column of the counts table is one gene
genes = df_counts.columns

# Loop through each gene, giving each one its own palette color
for color_idx, gene_name in enumerate(genes):
    sns.ecdfplot(
        data=df_counts,
        x=gene_name,
        ax=ax,
        color=pal[color_idx],
        label=gene_name,
        lw=1,
    )

# Axis labels and legend
ax.set_xlabel('mRNA count')
ax.set_ylabel('ECDF')
ax.legend()

This seems like a very good dataset to test the assumptions of our model, as
there is clear variability between the different genes.

## Negative Binomial-Dirichlet-Multinomial model (`scrappy`)

Our main full model assumes that the joint distribution over all genes' mRNA
counts is given by the product of independent negative binomial distributions
for each gene, with the strong assumption that all genes share the same $p$
parameter (the probability of success in each trial). Therefore, our inference
is over the parameters
$$
\underline{r} = [r_1, r_2, \cdots, r_G],
\tag{1}
$$
where $G$ is the number of genes, and $p$, the probability of success in each
trial. By Bayes' theorem, we have
$$
\pi(\underline{r}, p | \underline{\underline{M}}) \propto 
\pi(\underline{\underline{M}} | \underline{r}, p) \pi(\underline{r}, p),
\tag{2}
$$
where
$$
\underline{\underline{M}} = \begin{bmatrix} 
\lvert & \lvert & \cdots & \lvert \\
\underline{m}^{(1)} & \underline{m}^{(2)} & \cdots & \underline{m}^{(C)} \\
\lvert & \lvert & \cdots & \lvert 
\end{bmatrix},
\tag{3}
$$
is the data matrix with $C$ cells and each column $\underline{m}^{(c)}$ is the
transcriptional profile of cell $c$, i.e., the mRNA counts for each gene in
cell $c$.

On top of assuming each gene is drawn from an independent negative binomial, we
assume each cell is independent, allowing us to write the likelihood as
$$
\pi(\underline{\underline{M}} | \underline{r}, p) =
\prod_{c=1}^C \prod_{g=1}^G \pi(m_g^{(c)} | r_g, p),
\tag{4}
$$
where $m_g^{(c)}$ is the mRNA count of gene $g$ in cell $c$. Furthermore, we 
assume that each parameter is independent and that all $r_g$ are drawn from the
same prior distribution. Therefore, we can write the prior as
$$
\pi(\underline{r}, p) = \pi(p) \pi(r)^G,
\tag{5}
$$
where $\pi(r)$ is the shared prior for all $r_g$.

Putting all of these together, we can write the posterior as
$$
\pi(\underline{r}, p | \underline{\underline{M}}) \propto
\pi(p) \pi(r)^G \prod_{c=1}^C \prod_{g=1}^G \pi(m_g^{(c)} | r_g, p).
\tag{6}
$$

As for the functional forms, we established that the likelihood is a negative
binomial distribution. It can be shown that the joint distribution of negative
binomials with shared $p$ parameter is given by the product of a negative
binomial for the total mRNA count and a Dirichlet-multinomial for the partition
of the total mRNA count among the different genes. Therefore, we can write the
likelihood for each cell as
$$
\begin{aligned}
M^{(c)} | r_o, p &\sim \text{NegBinom}(r_o, p),\\
\underline{m}^{(c)} | M^{(c)},\underline{r} &\sim 
\text{DirMult}(M^{(c)}, \underline{r}), 
\end{aligned}
\tag{7}
$$
where $M^{(c)}$ is the total mRNA count in cell $c$ and $r_o = \sum_{g=1}^G
r_g$.

For the priors, we know that $p \in [0, 1]$ and that the $r_g$ are positive. We
can therefore choose a Beta prior for $p$ and a Gamma prior for each $r_g$. This
leads to
$$
p \sim \text{Beta}(\alpha_p, \beta_p),
\tag{8}
$$
and 
$$
r \sim \text{Gamma}(\alpha_r, \beta_r).
$$


Let's write this model in `PyMC`. First we define the data.

In [None]:
# Number of genes, i.e., the number of columns in the counts table
n_genes = genes.size

# Per-cell expression profiles: one row per cell, one column per gene
m_cells = df_counts.to_numpy()

# Total mRNA count of each cell, i.e., the sum across each row
M_cells = df_counts.sum(axis='columns')

Now we define a `pm.Model`.

In [None]:
# Set model: negative binomial for the total count per cell and
# Dirichlet-multinomial for its partition among genes (eq. 7)
with pm.Model() as scFISH_negbin_dirmult:
    # Shared success probability p with a uniform Beta(1, 1) prior
    p = pm.Beta('p', alpha=1, beta=1)
    # One Gamma(2, 2)-distributed r_g per gene; the shape comes from n_genes
    # so it stays in sync with the data definitions
    r_vec = pm.Gamma('r_vec', alpha=2, beta=2, shape=n_genes)

    # r_o = sum_g r_g is the negative binomial parameter of the total count
    r_o = pm.math.sum(r_vec)

    # Likelihood for the total observed count per cell: M ~ NegBinom(r_o, p).
    # PyMC's (n, p) parametrization has pmf C(y + n - 1, y) p^n (1 - p)^y;
    # passing `n` explicitly avoids mixing it with the (mu, alpha) form.
    M = pm.NegativeBinomial("M", n=r_o, p=p, observed=M_cells)

    # Partition of each cell's total among genes: DirMult(M, r_vec)
    m_vec = pm.DirichletMultinomial(
        "counts", n=M, a=r_vec, observed=m_cells
    )

Having defined the model, let's generate prior predictive samples.

In [None]:
# Draw 100 prior / prior-predictive samples from the NB-DM model
with scFISH_negbin_dirmult:
    prior_pred_check_dm = pm.sample_prior_predictive(
        draws=100,
        random_seed=rng,
    )

Let's contrast the prior predictive samples with the data for the total counts.

In [None]:
# Prior predictive total counts of the NB-DM model, pooled over chains/draws
M_prior_dm = prior_pred_check_dm.prior_predictive.M.values.flatten()

# Unit-width integer bins [k, k + 1) for k = 0..max(M_cells). The "+ 2" makes
# the largest observed total get its own bin instead of falling outside the
# last edge (numpy closes only the final bin on the right).
bins = np.arange(0, max(M_cells) + 2)

# Initialize figure: histogram (left) and ECDF (right)
fig, ax = plt.subplots(1, 2, figsize=(3, 1.5))

# Histograms of the real data and of the prior predictive total counts
ax[0].hist(M_cells, bins=bins, alpha=0.75, label='data', density=True)
ax[0].hist(M_prior_dm, bins=bins, alpha=0.75, label='PPC', density=True)

# Log scale on the y axis to show the tails
ax[0].set_yscale('log')

# ECDFs of the real data and of the prior predictive total counts
sns.ecdfplot(M_cells, ax=ax[1], label='data')
sns.ecdfplot(M_prior_dm, ax=ax[1], label='PPC')

# Add legend
ax[1].legend()

# Add axis labels
ax[0].set_xlabel('total counts')
ax[1].set_xlabel('total counts')
ax[0].set_ylabel('density')
ax[1].set_ylabel('ECDF')

plt.tight_layout()

Now, we can sample from the posterior using the NUTS sampler.

In [None]:
# Perform MCMC sampling: 4 chains x 4000 draws after 1000 tuning steps each.
# Passing RANDOM_SEED makes the chains reproducible across notebook re-runs.
with scFISH_negbin_dirmult:
    trace = pm.sample(
        4000, tune=1000, chains=4, cores=4, random_seed=RANDOM_SEED
    )

Let's take a look at the trace for each of the chains. For this, we will use the
convenient `plot_trace` function from `ArviZ`.

In [None]:
# Trace plots per chain; compact=False draws each vector component separately
az.plot_trace(trace, compact=False)
plt.tight_layout()

In general, what we look for are traces that are stable and do not show any
obvious pathologies. These traces look reasonably stable, so we can proceed to
examine the posterior distributions.

In [None]:
# Marginal posterior distributions of p and of each r_g
az.plot_posterior(trace, var_names=['p', 'r_vec'])

Let's look at a corner plot of the parameters.

In [None]:
# Corner plot: pairwise scatter of the parameters with marginals
axes = az.plot_pair(
    trace,
    var_names=['p', 'r_vec'],
    kind="scatter",
    marginals=True,
)

Let's now sample from the posterior predictive distribution.

In [None]:
# Posterior predictive draws for the NB-DM model; seeded so the retrodictive
# checks below are reproducible
with scFISH_negbin_dirmult:
    post_pred_check_dm = pm.sample_posterior_predictive(
        trace, random_seed=RANDOM_SEED
    )

Having these posterior predictive (retrodictive) checks in place allows us to
compare the model's predictions with the observed data. Let's plot the ECDFs for
the total mRNA count.

In [None]:
# Re-seed so the same posterior predictive draws are picked on every run
rng = np.random.default_rng(42)

# Single panel for the total-count ECDFs
fig, ax = plt.subplots(1, 1, figsize=(1.5, 1.5))

# Posterior predictive total counts, indexed (chain, draw, cell) as in
# ArviZ InferenceData
M_ppc = post_pred_check_dm.posterior_predictive.M.values

# How many simulated data sets to overlay
n_samples = 200

# Random (chain, draw) pairs, sampled with replacement
chain_idx = rng.choice(np.arange(M_ppc.shape[0]), size=n_samples)
draw_idx = rng.choice(np.arange(M_ppc.shape[1]), size=n_samples)

# One faint gray ECDF per selected simulated data set
for chain, draw in zip(chain_idx, draw_idx):
    sns.ecdfplot(M_ppc[chain, draw, :], ax=ax, color='gray', alpha=0.1)

# Observed total counts on top
sns.ecdfplot(M_cells, ax=ax, label='data')

ax.set_xlabel('total counts')
ax.set_ylabel('ECDF')

plt.tight_layout()

Now, let's plot the ECDFs for the mRNA counts of each gene.

In [None]:
# Re-seed so the same posterior predictive draws are picked on every run
rng = np.random.default_rng(42)

# 2 x 2 grid with one panel per gene, addressed through a flat array
fig, axes = plt.subplots(2, 2, figsize=(3, 3))
axes = axes.flatten()

# How many simulated data sets to overlay per gene
n_samples = 200

# Chain/draw extents are read off the total-count variable M
M_shape = post_pred_check_dm.posterior_predictive.M.values.shape
chain_idx = rng.choice(np.arange(M_shape[0]), size=n_samples)
draw_idx = rng.choice(np.arange(M_shape[1]), size=n_samples)

# Simulated per-gene counts; the last axis indexes genes
counts_ppc = post_pred_check_dm.posterior_predictive.counts.values

for i, ax in enumerate(axes):
    # Faint gray ECDFs of the simulated counts for gene i
    for chain, draw in zip(chain_idx, draw_idx):
        sns.ecdfplot(
            counts_ppc[chain, draw, :, i],
            ax=ax,
            color='gray',
            alpha=0.1,
        )
    # Observed counts for gene i on top
    sns.ecdfplot(m_cells[:, i], ax=ax, label='data')
    ax.set_xlabel('counts')
    ax.set_ylabel('ECDF')
    ax.set_title(genes[i])

plt.tight_layout()

The model is not completely able to capture the variability in the data. This
is likely a consequence of the strong assumption that all genes share the same
$p$ parameter.

However, this needs to be contrasted with the state-of-the-art model for
single-cell RNA-seq data.

## Poisson-Multinomial model (`sanity`)

The method published by [Breda et
al.](https://www.nature.com/articles/s41587-021-00875-x), referred to as
`sanity`, is considered the state-of-the-art for single-cell RNA-seq data 
Bayesian analysis. The novelty of this method is to bring the power of the 
Bayesian paradigm to large-scale single-cell RNA-seq datasets. However, because
of the tremendous dimensionality of the parameter space, `sanity` uses several
approximations that make it computationally feasible, albeit at the cost of
losing the ability to properly describe the data.

To show this, we will attempt to fit the basis of the `sanity` model without the 
approximations needed for computational feasibility on this sc-FISH dataset. 
Although `sanity`'s model has sc-RNAseq data in mind, the generative model they
define should be applicable regardless of how gene expression is measured.

The basic structure of the model consists of having each mRNA count be drawn
from a Poisson distribution whose parameter
$$
\lambda_g = \alpha_g \langle M \rangle,
\tag{9}
$$
is the product of a gene-specific parameter $\alpha_g$ and the mean number of
total mRNA counts in a cell $\langle M \rangle$. The $\alpha_g$ parameters must
satisfy the constraint that
$$
\sum_{g=1}^G \alpha_g = 1, \quad \alpha_g \geq 0 \;\; \forall \, g,
\tag{10}
$$
i.e., they must form a probability simplex. This parameterization normalizes
the gene expression levels, since the quantity that matters for any experiment
is the relative expression level.

With these two parameters in the model, by Bayes' theorem, we have
$$
\pi(\underline{\alpha}, \langle M \rangle | \underline{\underline{M}}) \propto
\pi(\underline{\underline{M}} | \underline{\alpha}, \langle M \rangle)
\pi(\underline{\alpha}, \langle M \rangle).
\tag{11}
$$
As before, we assume that both genes and cells are independent, so the
likelihood can be expressed as
$$
\pi(\underline{\underline{M}} | \underline{\alpha}, \langle M \rangle) =
\prod_{c=1}^C \prod_{g=1}^G \pi(m_g^{(c)} | \alpha_g, \langle M \rangle).
\tag{12}
$$
For the prior, we cannot assume that the $\alpha_g$ are independent, as they
must add up to one. However, we can assume that the mean total mRNA count is
independent of the $\alpha_g$. Therefore, we can write the prior as
$$
\pi(\underline{\alpha}, \langle M \rangle) = 
\pi(\langle M \rangle) \pi(\underline{\alpha}).
\tag{13}
$$

For the functional forms, we know that the likelihood is a Poisson distribution,
i.e.,
$$
m_g^{(c)} | \alpha_g, \langle M \rangle \sim 
\text{Poisson}(\alpha_g \langle M \rangle).
\tag{14}
$$
The natural choice for the prior of the $\alpha_g$ is the Dirichlet
distribution, i.e.,
$$
\underline{\alpha} \sim \text{Dirichlet}(\underline{\beta}),
\tag{15}
$$
where $\underline{\beta}$ is a vector of concentration parameters. Finally, for
the prior of the mean total mRNA count, we can choose a strictly positive
distribution, such as the Gamma distribution or, as we will do below, a
lognormal distribution, i.e.,
$$
\langle M \rangle \sim \text{LogNormal}(\mu, \sigma).
\tag{16}
$$

One can show that the joint distribution of independent Poisson distributions
for each gene can be expressed as the product of a Poisson distribution for the
total mRNA and a multinomial distribution for the partition of the total mRNA
into the different genes. Therefore, the likelihood for each cell can be written
as
$$
\begin{aligned}
M^{(c)} | \langle M \rangle &\sim \text{Poisson}(\langle M \rangle),\\
\underline{m}^{(c)} | M^{(c)}, \underline{\alpha} &\sim
\text{Multinomial}(M^{(c)}, \underline{\alpha}).
\end{aligned}
\tag{17}
$$

We can now write the model in `PyMC`.

In [None]:
# Set model: Poisson for the total count per cell and multinomial for its
# partition among genes (eq. 17)
with pm.Model() as scFISH_poisson_multinomial:
    # Mean total mRNA count per cell, <M>, with a LogNormal prior (eq. 16).
    # The PyMC name 'r_o' is kept because later cells refer to it by name.
    r_o = pm.LogNormal('r_o', mu=2, sigma=2.5)

    # Relative expression levels alpha_g on the simplex (eq. 15), with a flat
    # Dirichlet(1, ..., 1) prior and one entry per gene
    alpha_vec = pm.Dirichlet('alpha_vec', a=np.ones(n_genes))

    # Likelihood for the total observed count per cell: M ~ Poisson(<M>)
    M = pm.Poisson("M", mu=r_o, observed=M_cells)

    # Multinomial partition of each cell's total among genes
    m_vec = pm.Multinomial(
        "counts", n=M, p=alpha_vec, observed=m_cells
    )

Having defined the model, let's generate prior predictive samples.

In [None]:
# Draw 100 prior / prior-predictive samples from the Poisson-Multinomial model
with scFISH_poisson_multinomial:
    prior_pred_check_pm = pm.sample_prior_predictive(
        draws=100,
        random_seed=rng,
    )

Let's contrast the prior predictive samples with the data for the total counts.

In [None]:
# Prior predictive total counts of the Poisson-Multinomial model, pooled
M_prior_pm = prior_pred_check_pm.prior_predictive.M.values.flatten()

# Unit-width integer bins [k, k + 1) for k = 0..max(M_cells). The "+ 2" makes
# the largest observed total get its own bin instead of falling outside the
# last edge (numpy closes only the final bin on the right).
bins = np.arange(0, max(M_cells) + 2)

# Initialize figure: histogram (left) and ECDF (right)
fig, ax = plt.subplots(1, 2, figsize=(3, 1.5))

# Histograms of the real data and of the prior predictive total counts
ax[0].hist(M_cells, bins=bins, alpha=0.75, label='data', density=True)
ax[0].hist(M_prior_pm, bins=bins, alpha=0.75, label='PPC', density=True)

# Log scale on the y axis to show the tails
ax[0].set_yscale('log')

# ECDFs of the real data and of the prior predictive total counts
sns.ecdfplot(M_cells, ax=ax[1], label='data')
sns.ecdfplot(M_prior_pm, ax=ax[1], label='PPC')

# Set ylim
ax[1].set_ylim(0, 1.05)
# Add legend
ax[1].legend()

# Add axis labels
ax[0].set_xlabel('total counts')
ax[1].set_xlabel('total counts')
ax[0].set_ylabel('density')
ax[1].set_ylabel('ECDF')

plt.tight_layout()

Now, we can sample from the posterior using the NUTS sampler.

In [None]:
# Perform MCMC sampling: 4 chains x 4000 draws after 1000 tuning steps each.
# Passing RANDOM_SEED makes the chains reproducible across notebook re-runs.
with scFISH_poisson_multinomial:
    trace_pm = pm.sample(
        4000, tune=1000, chains=4, cores=4, random_seed=RANDOM_SEED
    )

Let's take a look at the trace for each of the chains.

In [None]:
# Trace plots per chain; compact=False draws each vector component separately
az.plot_trace(trace_pm, compact=False)
plt.tight_layout()

Again, the traces look reasonably stable, so we can proceed to examine the
posterior distributions.

In [None]:
# Corner plot: pairwise scatter of <M> ('r_o') and alpha with marginals
axes = az.plot_pair(
    trace_pm,
    var_names=['r_o', 'alpha_vec'],
    kind="scatter",
    marginals=True,
)

Let's now sample from the posterior predictive distribution.

In [None]:
# Posterior predictive draws for the Poisson-Multinomial model; seeded so the
# retrodictive checks below are reproducible
with scFISH_poisson_multinomial:
    post_pred_check_pm = pm.sample_posterior_predictive(
        trace_pm, random_seed=RANDOM_SEED
    )

Having these posterior predictive (retrodictive) checks in place allows us to
compare the model's predictions with the observed data. Let's plot the ECDFs for
the total mRNA count.

In [None]:
# Re-seed so the same posterior predictive draws are picked on every run
rng = np.random.default_rng(42)

# Single panel for the total-count ECDFs
fig, ax = plt.subplots(1, 1, figsize=(1.5, 1.5))

# Posterior predictive total counts, indexed (chain, draw, cell) as in
# ArviZ InferenceData
M_ppc = post_pred_check_pm.posterior_predictive.M.values

# How many simulated data sets to overlay
n_samples = 200

# Random (chain, draw) pairs, sampled with replacement
chain_idx = rng.choice(np.arange(M_ppc.shape[0]), size=n_samples)
draw_idx = rng.choice(np.arange(M_ppc.shape[1]), size=n_samples)

# One faint gray ECDF per selected simulated data set
for chain, draw in zip(chain_idx, draw_idx):
    sns.ecdfplot(M_ppc[chain, draw, :], ax=ax, color='gray', alpha=0.1)

# Observed total counts on top
sns.ecdfplot(M_cells, ax=ax, label='data')

ax.set_xlabel('total counts')
ax.set_ylabel('ECDF')

plt.tight_layout()

Now, let's plot the ECDFs for the mRNA counts of each gene.

In [None]:
# Re-seed so the same posterior predictive draws are picked on every run
rng = np.random.default_rng(42)

# 2 x 2 grid with one panel per gene, addressed through a flat array
fig, axes = plt.subplots(2, 2, figsize=(3, 3))
axes = axes.flatten()

# How many simulated data sets to overlay per gene
n_samples = 200

# Chain/draw extents are read off the total-count variable M
M_shape = post_pred_check_pm.posterior_predictive.M.values.shape
chain_idx = rng.choice(np.arange(M_shape[0]), size=n_samples)
draw_idx = rng.choice(np.arange(M_shape[1]), size=n_samples)

# Simulated per-gene counts; the last axis indexes genes
counts_ppc = post_pred_check_pm.posterior_predictive.counts.values

for i, ax in enumerate(axes):
    # Faint gray ECDFs of the simulated counts for gene i
    for chain, draw in zip(chain_idx, draw_idx):
        sns.ecdfplot(
            counts_ppc[chain, draw, :, i],
            ax=ax,
            color='gray',
            alpha=0.1,
        )
    # Observed counts for gene i on top
    sns.ecdfplot(m_cells[:, i], ax=ax, label='data')
    ax.set_xlabel('counts')
    ax.set_ylabel('ECDF')
    ax.set_title(genes[i])

plt.tight_layout()

This is obviously a terrible fit. The Poisson model is not able to capture the
overdispersion in the data. To emphasize this even more, let's plot both
posterior predictive distributions for the total mRNA count.

In [None]:
# Re-seed so the same posterior predictive draws are picked on every run
rng = np.random.default_rng(42)

# Single panel comparing both models' total-count ECDFs
fig, ax = plt.subplots(1, 1, figsize=(1.5, 1.5))

# How many simulated data sets to overlay per model
n_samples = 200

# Total-count posterior predictive draws of both models, (chain, draw, cell)
M_ppc_dm = post_pred_check_dm.posterior_predictive.M.values
M_ppc_pm = post_pred_check_pm.posterior_predictive.M.values

# The same (chain, draw) picks are used for both models. They are sized from
# the NB-DM trace, so both traces need the same number of chains and draws
# (4 x 4000 each here).
chain_idx = rng.choice(np.arange(M_ppc_dm.shape[0]), size=n_samples)
draw_idx = rng.choice(np.arange(M_ppc_dm.shape[1]), size=n_samples)

# NB-DM draws in pale blue first...
for chain, draw in zip(chain_idx, draw_idx):
    sns.ecdfplot(
        M_ppc_dm[chain, draw, :],
        ax=ax,
        color=cor['pale_blue'],
        alpha=0.1,
    )

# ...then the Poisson-Multinomial draws in pale red
for chain, draw in zip(chain_idx, draw_idx):
    sns.ecdfplot(
        M_ppc_pm[chain, draw, :],
        ax=ax,
        color=cor['pale_red'],
        alpha=0.1,
    )

# Observed total counts on top
sns.ecdfplot(M_cells, ax=ax, label='data')

# Empty proxy lines so each model gets a single legend entry
ax.plot([], [], color=cor['pale_blue'], label='NegBin-DirMult')
ax.plot([], [], color=cor['pale_red'], label='Poiss-Mult')
ax.legend(loc='lower right', fontsize=4)

ax.set_xlabel('total counts')
ax.set_ylabel('ECDF')

plt.tight_layout()

Let's do the same for the individual genes.

In [None]:
# Re-seed so the same posterior predictive draws are picked on every run
rng = np.random.default_rng(42)

# 2 x 2 grid with one panel per gene, addressed through a flat array
fig, axes = plt.subplots(2, 2, figsize=(3, 3))
axes = axes.flatten()

# How many simulated data sets to overlay per gene and model
n_samples = 200

# Shared (chain, draw) picks, sized from the NB-DM trace. The
# Poisson-Multinomial trace is indexed with the same picks, so both traces
# need the same number of chains and draws (4 x 4000 each here).
M_shape = post_pred_check_dm.posterior_predictive.M.values.shape
chain_idx = rng.choice(np.arange(M_shape[0]), size=n_samples)
draw_idx = rng.choice(np.arange(M_shape[1]), size=n_samples)

# Simulated per-gene counts of both models; the last axis indexes genes
counts_dm = post_pred_check_dm.posterior_predictive.counts.values
counts_pm = post_pred_check_pm.posterior_predictive.counts.values

for i, ax in enumerate(axes):
    for chain, draw in zip(chain_idx, draw_idx):
        # NB-DM draw (pale blue) followed by Poisson-Multinomial draw (pale red)
        sns.ecdfplot(
            counts_dm[chain, draw, :, i],
            ax=ax,
            color=cor['pale_blue'],
            alpha=0.1,
        )
        sns.ecdfplot(
            counts_pm[chain, draw, :, i],
            ax=ax,
            color=cor['pale_red'],
            alpha=0.1,
        )
    # Observed counts for gene i on top
    sns.ecdfplot(m_cells[:, i], ax=ax, label='data')
    ax.set_xlabel('counts')
    ax.set_ylabel('ECDF')
    ax.set_title(genes[i])
    # Empty proxy lines so each model gets a single legend entry
    ax.plot([], [], color=cor['pale_blue'], label='NegBin-DirMult')
    ax.plot([], [], color=cor['pale_red'], label='Poiss-Mult')
    ax.legend(loc='lower right', fontsize=4)

plt.tight_layout()

The negative binomial-Dirichlet-multinomial model is able to much better capture
the over-dispersion in the data. The challenge is now to work on the
computational efficiency of the model to make it feasible for large-scale
inference.