In [None]:
import pymc as pm
import pytensor.tensor as pt
import numpy as np
import pandas as pd
import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import glob

from utils import matplotlib_style
# Apply the project's matplotlib style; `cor` is a name -> color lookup used
# by the plots in this notebook (e.g. cor['pale_blue'])
cor, pal = matplotlib_style()

# Fixed seed for the NumPy generator that drives the random subsampling below
RANDOM_SEED = 42
rng = np.random.default_rng(seed=RANDOM_SEED)

# Single-cell RNA-seq "MNIST" dataset

In this notebook, we perform Bayesian inference on a [recently
published](https://www.biorxiv.org/content/10.1101/2021.12.08.471773v1) dataset
designed to mimic the MNIST dataset in the single-cell RNA-seq space. The details
of the dataset do not matter much for the purpose of this notebook; the key is
to get a sense of the type of data we are dealing with.

The dataset is provided in `h5ad` format and can be downloaded from [this Zenodo
repository](https://zenodo.org/records/7795653). Let's start by listing the
available files in the dataset.

In [None]:
# List all h5ad files in the data directory. glob returns them in
# filesystem-dependent order, so sort the list to make `files[0]` refer to the
# same file on every machine.
files = sorted(glob.glob("../data/scmark_v2/scmark_v2/*.h5ad"))

# Fail early with a clear message rather than an IndexError when loading
assert files, "no .h5ad files found in ../data/scmark_v2/scmark_v2/"

files

Since we have no specific interest in the dataset itself, any of these files will
do for the purpose of this notebook. Let's load the first one in the listing.

In [None]:
# Read the first file of the listing from disk
h5ad_path = files[0]
data = sc.read_h5ad(h5ad_path)

data

The dataset contains 10,003 cells and 33,515 unique genes. Let's extract the
count matrix into a pandas DataFrame with one row per cell and one column per
gene.

In [None]:
# Extract the counts into a pandas DataFrame: one row per cell
# (`data.obs.index`) and one column per gene (named by `data.var.gene`).
# Only sparse matrices have `.toarray()`; a dense `X` is used as-is instead of
# raising AttributeError.
# NOTE: densifying the full cells x genes matrix needs a lot of memory.
counts_matrix = (
    data.X.toarray() if hasattr(data.X, "toarray") else np.asarray(data.X)
)

df_counts = pd.DataFrame(
    counts_matrix,
    columns=data.var.gene,
    index=data.obs.index
)

# Drop the extra reference so a copied dense matrix is not kept alive
del counts_matrix

print(df_counts.shape)
df_counts.head()

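Let's look at a few example ECDFs from the dataset. We select genes whose mean
counts span several orders of magnitude: among the genes with a mean count above
one, sorted by mean count, we take logarithmically spaced ranks and plot the
ECDFs of the corresponding genes.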

In [None]:
# Define number of genes to select
n_genes = 9

# Compute the mean expression of each gene and sort them (highest first)
df_mean = df_counts.mean().sort_values(ascending=False)

# Keep only genes with mean expression strictly greater than 1
df_mean = df_mean[df_mean > 1]
# log10(len - 1) below is undefined for fewer than two genes
assert len(df_mean) > 1, "need at least two genes with mean count > 1"

# Generate logarithmically spaced ranks between 1 and len(df_mean) - 1.
# `dtype=int` truncates, so for short rankings neighbouring values can
# collapse onto the same index.
log_indices = np.logspace(
    0, np.log10(len(df_mean) - 1), num=n_genes, dtype=int
)
# A repeated index would select the same gene twice (duplicate columns
# downstream, where n_genes distinct genes are assumed), so fail loudly.
assert len(np.unique(log_indices)) == n_genes, (
    f"log-spaced indices are not unique: {log_indices}; "
    "lower n_genes or relax the mean-expression filter"
)

# Select genes using the logarithmically spaced indices
genes = df_mean.iloc[log_indices].index

df_mean[genes]

Let's now plot the ECDFs for these genes.

In [None]:
# Initialize figure
fig, ax = plt.subplots(1, 1, figsize=(2, 1.5))

# One shade of blue per gene; computed once instead of on every iteration
gene_colors = sns.color_palette('Blues', n_colors=n_genes)

# Loop through each gene (ordered from highest to lowest mean count)
for gene, color in zip(genes, gene_colors):
    # Plot the ECDF of this gene's counts, labelled by its rounded mean count
    sns.ecdfplot(
        data=df_counts,
        x=gene,
        ax=ax,
        color=color,
        label=np.round(df_mean[gene], 0).astype(int),
        lw=1
    )

# Set x-axis to log scale
ax.set_xscale('log')

# Add axis labels
ax.set_xlabel('UMI count')
ax.set_ylabel('ECDF')

# Add legend (trailing `;` suppresses the Legend object repr)
ax.legend(loc='lower right', fontsize=4, title=r"$\langle U \rangle$");

The selected genes cover a wide range of expression levels, so they seem like a
good representative sample of the data.

## Negative Binomial-Dirichlet-Multinomial model

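In this model, the total UMI count of the selected genes in each cell follows a
negative binomial distribution with parameters $p$ and $r_o = \sum_i r_i$, and,
given that total, the counts of the individual genes follow a
Dirichlet-multinomial distribution with concentration parameters $r_i$. This is
equivalent to modelling each gene as an independent negative binomial with its
own $r_i$ and a shared $p$. Let's write this model in `PyMC`. First we define
the data.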

In [None]:
# Counts of the selected genes per cell (rows: cells, columns: `genes`).
# Both likelihoods below are defined on integer counts, while the matrix read
# from the h5ad file may be float-typed, so convert explicitly after checking
# that the values really are whole numbers.
u_cells_float = df_counts[genes].to_numpy()
assert np.all(u_cells_float == np.round(u_cells_float)), (
    "expected raw integer UMI counts"
)
u_cells = u_cells_float.astype(np.int64)

# Total counts per cell, derived from `u_cells` so that every row of `u_cells`
# sums exactly to the matching entry of `U_cells` (the Dirichlet-multinomial
# only supports count vectors that sum to n)
U_cells = u_cells.sum(axis=1)

Now we define a `pm.Model`.

In [None]:
# Set model
with pm.Model() as scmark_negbin_dirmult:
    # Define prior on p
    p = pm.Beta('p', alpha=1, beta=1)
    # Define prior on all r parameters
    r_vec = pm.Gamma('r', alpha=2, beta=2, shape=n_genes)

    # Sum of r parameters
    r_o = pm.math.sum(r_vec)

    # Likelihood for Total observed counts
    U = pm.NegativeBinomial("U", p=p, alpha=r_o, observed=U_cells)

    # Use Dirichlet-Multinomial distribution for observed counts
    u_vec = pm.DirichletMultinomial(
        "umi_counts", n=U, a=r_vec, observed=u_cells
    )

Now, we can sample from the posterior using the NUTS sampler.

In [None]:
# Perform MCMC sampling with 4 chains
with scmark_negbin_dirmult:
    trace = pm.sample(1000, tune=4000, chains=4, cores=4)

Let's look at the traces to make sure everything looks good.

In [None]:
# Trace plot of p and of every component of r (compact=False draws each
# component separately), then tighten the current figure's layout
az.plot_trace(trace, compact=False)
plt.gcf().tight_layout()

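Visually, all chains appear to have converged; `az.summary(trace)` reports
$\hat{R}$ and effective sample sizes for a quantitative check. Let's look at the
corner plot.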

In [None]:
# Corner plot: pairwise scatter of the posterior samples of p and r, with the
# marginal distributions along the diagonal
axes = az.plot_pair(trace, kind="scatter", var_names=["p", "r"], marginals=True)

Let's now sample from the posterior predictive distribution.

In [None]:
# Draw posterior predictive samples of U and umi_counts; seeded so the
# replicated data (and the plots built from it) are reproducible
with scmark_negbin_dirmult:
    post_pred_check_dm = pm.sample_posterior_predictive(
        trace, random_seed=RANDOM_SEED
    )

These posterior predictive (retrodictive) samples allow us to compare the
model's predictions with the observed data. Let's plot the ECDFs of the total
UMI count per cell, summed over the selected genes.

In [None]:
# Initialize figure
fig, ax = plt.subplots(1, 1, figsize=(1.5, 1.5))

# Define number of posterior predictive samples to plot
n_samples = 200

# Posterior predictive total counts; the first two axes are (chain, draw) and
# the last axis runs over cells
U_ppc = post_pred_check_dm.posterior_predictive.U.values

# Pick random (chain, draw) pairs
x_idx = rng.choice(U_ppc.shape[0], size=n_samples)
y_idx = rng.choice(U_ppc.shape[1], size=n_samples)

# Plot the ECDF of the total counts in each selected predictive draw
for chain, draw in zip(x_idx, y_idx):
    sns.ecdfplot(
        U_ppc[chain, draw, :],
        ax=ax,
        color=cor['pale_blue'],
        alpha=0.1
    )

# Plot ECDF of the observed total counts on top
sns.ecdfplot(
    U_cells,
    ax=ax,
    label='data',
)

# Set x-axis to log scale
ax.set_xscale('log')

# Label axis
ax.set_xlabel('total counts')
ax.set_ylabel('ECDF')

# The 'data' label was never shown: add a legend, with an empty proxy line
# standing in for the many posterior predictive curves
ax.plot([], [], color=cor['pale_blue'], label='NegBin-DirMult')
ax.legend(loc='lower right', fontsize=4)

plt.tight_layout()

In [None]:
# Initialize a 3 x 3 grid with one panel per selected gene
fig, axes = plt.subplots(3, 3, figsize=(5, 5))

# Flatten axes
axes = axes.flatten()

# Define number of posterior predictive samples to plot
n_samples = 200

# Posterior predictive per-gene counts of the NegBin-DirMult model; indexed
# below as [chain, draw, cell, gene]
umi_ppc_dm = post_pred_check_dm.posterior_predictive.umi_counts.values

# Pick random (chain, draw) pairs from the array that is actually plotted
x_idx = rng.choice(umi_ppc_dm.shape[0], size=n_samples)
y_idx = rng.choice(umi_ppc_dm.shape[1], size=n_samples)

# Loop through each gene; zip stops at whichever of axes/genes is shorter
for (i, (ax, gene)) in enumerate(zip(axes, genes)):
    # Loop through samples
    for j in range(n_samples):
        # ECDF of this gene's counts in one posterior predictive draw
        sns.ecdfplot(
            umi_ppc_dm[x_idx[j], y_idx[j], :, i],
            ax=ax,
            color=cor['pale_blue'],
            alpha=0.1
        )
    # ECDF of the observed counts for this gene
    sns.ecdfplot(
        u_cells[:, i],
        ax=ax,
        label='data',
    )
    # Label axis
    ax.set_xlabel('counts')
    ax.set_ylabel('ECDF')
    # Set title
    ax.set_title(gene)
    # Set x-axis to log scale
    ax.set_xscale('log')

    # Empty proxy line so the legend also names the predictive curves
    ax.plot([], [], color=cor['pale_blue'], label='NegBin-DirMult')

    # Add legend
    ax.legend(loc='lower right', fontsize=4)

# Hide panels left over when fewer genes than panels were selected
for ax in axes[len(genes):]:
    ax.set_visible(False)

plt.tight_layout()

This is not an ideal fit. Let's compare it with a model in which each gene is
fit to a completely independent negative binomial distribution with its own
$p$ and $r$ parameters.

In [None]:
# Set model
with pm.Model() as scmark_negbins:
    # Define prior on p
    p_vec = pm.Beta('p', alpha=1, beta=1, shape=n_genes)
    # Define prior on all r parameters
    r_vec = pm.Gamma('r', alpha=2, beta=2, shape=n_genes)

    # Use Negative Binomial distribution for observed counts
    u_vec = pm.NegativeBinomial(
        "umi_counts",
        p=p_vec,
        alpha=r_vec,
        shape=(len(u_cells), n_genes),
        observed=u_cells
    )

Let's now sample from the posterior of this model.

In [None]:
# Perform MCMC sampling with 4 chains
with scmark_negbins:
    trace_nb = pm.sample(1000, tune=4000, chains=4, cores=4)

In [None]:
# Plot trace
az.plot_trace(trace_nb, compact=False)

plt.tight_layout()

In [None]:
# Plot corner plot
axes = az.plot_pair(
    trace_nb, var_names=['p', 'r'], kind="scatter", marginals=True
)

In [None]:
# Draw posterior predictive samples of umi_counts; seeded so the replicated
# data (and the plots built from it) are reproducible
with scmark_negbins:
    post_pred_check_nb = pm.sample_posterior_predictive(
        trace_nb, random_seed=RANDOM_SEED
    )

In [None]:
# Initialize a 3 x 3 grid with one panel per selected gene
fig, axes = plt.subplots(3, 3, figsize=(5, 5))

# Flatten axes
axes = axes.flatten()

# Define number of posterior predictive samples to plot
n_samples = 200

# Posterior predictive per-gene counts of the independent-NegBins model;
# indexed below as [chain, draw, cell, gene]
umi_ppc_nb = post_pred_check_nb.posterior_predictive.umi_counts.values

# Pick random (chain, draw) pairs
x_idx = rng.choice(umi_ppc_nb.shape[0], size=n_samples)
y_idx = rng.choice(umi_ppc_nb.shape[1], size=n_samples)

# Loop through each gene; zip stops at whichever of axes/genes is shorter
for (i, (ax, gene)) in enumerate(zip(axes, genes)):
    # Loop through samples
    for j in range(n_samples):
        # ECDF of this gene's counts in one posterior predictive draw
        sns.ecdfplot(
            umi_ppc_nb[x_idx[j], y_idx[j], :, i],
            ax=ax,
            color=cor['pale_blue'],
            alpha=0.1
        )
    # ECDF of the observed counts for this gene
    sns.ecdfplot(
        u_cells[:, i],
        ax=ax,
        label='data',
    )
    # Label axis
    ax.set_xlabel('counts')
    ax.set_ylabel('ECDF')
    # Set title
    ax.set_title(gene)
    # Set x-axis to log scale
    ax.set_xscale('log')

    # Empty proxy line so the legend also names the predictive curves
    ax.plot([], [], color=cor['pale_blue'], label='NegBins')

    # Add legend
    ax.legend(loc='lower right', fontsize=4)

# Hide panels left over when fewer genes than panels were selected
for ax in axes[len(genes):]:
    ax.set_visible(False)

plt.tight_layout()

In [None]:
# Overlay the per-gene posterior predictive ECDFs of both models on the data
# (blue: NegBin-DirMult, red: independent NegBins)

# Initialize a 3 x 3 grid with one panel per selected gene
fig, axes = plt.subplots(3, 3, figsize=(5, 5))

# Flatten axes
axes = axes.flatten()

# Define number of posterior predictive samples to plot per model
n_samples = 200

# Posterior predictive per-gene counts of both models; indexed below as
# [chain, draw, cell, gene]
umi_ppc_dm = post_pred_check_dm.posterior_predictive.umi_counts.values
umi_ppc_nb = post_pred_check_nb.posterior_predictive.umi_counts.values

# Pick random (chain, draw) pairs separately for each model so every index is
# valid for the array it is used on, even if the two traces differ in size
x_idx_dm = rng.choice(umi_ppc_dm.shape[0], size=n_samples)
y_idx_dm = rng.choice(umi_ppc_dm.shape[1], size=n_samples)
x_idx_nb = rng.choice(umi_ppc_nb.shape[0], size=n_samples)
y_idx_nb = rng.choice(umi_ppc_nb.shape[1], size=n_samples)

# Loop through each gene; zip stops at whichever of axes/genes is shorter
for (i, (ax, gene)) in enumerate(zip(axes, genes)):
    # Loop through samples, alternating between the two models
    for j in range(n_samples):
        # ECDF of this gene's counts in one NegBin-DirMult predictive draw
        sns.ecdfplot(
            umi_ppc_dm[x_idx_dm[j], y_idx_dm[j], :, i],
            ax=ax,
            color=cor['pale_blue'],
            alpha=0.1
        )
        # ECDF of this gene's counts in one independent-NegBins predictive draw
        sns.ecdfplot(
            umi_ppc_nb[x_idx_nb[j], y_idx_nb[j], :, i],
            ax=ax,
            color=cor['pale_red'],
            alpha=0.1
        )
    # ECDF of the observed counts for this gene
    sns.ecdfplot(
        u_cells[:, i],
        ax=ax,
        label='data',
    )
    # Label axis
    ax.set_xlabel('counts')
    ax.set_ylabel('ECDF')
    # Set title
    ax.set_title(gene)
    # Set x-axis to log scale
    ax.set_xscale('log')

    # Empty proxy lines so the legend shows one entry per model
    ax.plot([], [], color=cor['pale_blue'], label='NegBin-DirMult')
    ax.plot([], [], color=cor['pale_red'], label='NegBins')

    # Add legend
    ax.legend(loc='lower right', fontsize=4)

# Hide panels left over when fewer genes than panels were selected
for ax in axes[len(genes):]:
    ax.set_visible(False)

plt.tight_layout()