In [None]:
import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats, special

from rp2 import load_biomart_gene_symbols_df
from rp2.paths import get_data_path, get_output_path

In [None]:
# Species selector used by the data-loading cells and by the per-cell species
# filter in describe_gene_condition.
species = "mouse"

In [None]:
# Gene annotation table for the selected species. Later code looks genes up
# through its `symbol` column and takes the matching index label as the gene ID.
gene_symbols_df = load_biomart_gene_symbols_df(species)

In [None]:
# Expression data (AnnData) for the selected species, read from the processed
# E-MTAB-6754 file in the ArrayExpress data directory.
umi_adata = anndata.read_h5ad(
    get_data_path(
        "ArrayExpress",
        f"E-MTAB-6754.processed.2.{species}.h5ad",
    )
)

In [None]:
# Fitted txburst parameters (one row per gene/condition) for the selected
# species. The file name follows `species` so that this table always matches
# the expression data loaded above instead of being pinned to mouse.
# NOTE(review): assumes outputs are named "<species>_responsive_genes.csv" — confirm
# for species other than mouse.
txburst_params_df = pd.read_csv(get_output_path("txburst", f"{species}_responsive_genes.csv"))

In [None]:
def poisson_beta_pmf(k, k_on, k_off, k_syn, n_roots=50):
    """Evaluate the Poisson-beta probability mass function at the counts ``k``.

    The model is a Poisson distribution whose rate is ``k_syn * p`` with
    ``p ~ Beta(k_on, k_off)``. The integral over ``p`` is approximated with
    Gauss-Jacobi quadrature on [-1, 1] (substituting ``p = (x + 1) / 2``).

    Parameters
    ----------
    k : array_like
        Count value(s) at which to evaluate the PMF. Scalars, lists and
        arrays are accepted; the input is flattened.
    k_on, k_off : float
        Shape parameters of the beta distribution; both must be > 0.
    k_syn : float
        Scale applied to the beta-distributed variable to obtain the
        Poisson rate.
    n_roots : int, optional
        Number of quadrature nodes (default 50).

    Returns
    -------
    numpy.ndarray
        1-D array with one probability per entry of the flattened ``k``.

    Raises
    ------
    AssertionError
        If ``k_on`` or ``k_off`` is not positive, or if any quadrature-node
        Poisson rate reaches 1e6.
    """
    assert k_on > 0, f"k_on must be positive, got {k_on}"
    assert k_off > 0, f"k_off must be positive, got {k_off}"

    # Column vector so that counts broadcast against the row of quadrature nodes.
    counts = np.asarray(k).reshape(-1, 1)

    # Jacobi weight (1 - x)^alpha * (1 + x)^beta maps onto the beta density
    # p^(k_on - 1) * (1 - p)^(k_off - 1) under x = 2p - 1.
    roots, weights = special.roots_jacobi(n_roots, k_off - 1, k_on - 1)
    mus = k_syn * (roots + 1) / 2
    # Upper limit on the Poisson rates handled here; presumably the Poisson
    # PMF evaluation is not trusted beyond it — TODO confirm the rationale.
    assert np.max(mus) < 1e6, "Poisson rate too large (>= 1e6) for this evaluation"

    # Weighted sum over nodes = quadrature estimate of the integral, per count.
    gs = np.sum(weights * stats.poisson.pmf(counts, mus), axis=1)
    # 2 ** (1 - k_on - k_off) undoes the change of variables from [0, 1] to
    # [-1, 1]; 1 / B(k_on, k_off) normalises the beta density.
    probabilities = 1 / special.beta(k_on, k_off) * 2 ** (1 - k_on - k_off) * gs
    return probabilities

In [None]:
def describe_gene_condition(gene_symbol, replicates, treatments, time_points):
    """Plot observed UMI counts for one gene/condition against its fitted model.

    Looks up the gene ID for ``gene_symbol``, selects the cells matching the
    requested replicate/treatment/time-point labels (and the notebook-level
    ``species``), finds the single matching row of fitted txburst parameters,
    and draws a histogram of the observed counts overlaid with the expected
    number of cells per count under the Poisson-beta model.

    Relies on the notebook-level ``gene_symbols_df``, ``umi_adata``,
    ``txburst_params_df`` and ``species``.

    Parameters
    ----------
    gene_symbol : str
        Gene symbol; must match exactly one row of ``gene_symbols_df``.
    replicates, treatments, time_points : list
        Labels to keep, matched with ``isin`` against the ``replicate``,
        ``treatment`` and ``time_point`` columns of both the cell metadata
        and the parameter table.

    Raises
    ------
    AssertionError
        If the symbol does not map to exactly one gene ID, or the filters do
        not select exactly one row of fitted parameters.
    ValueError
        If no cells match the requested condition.
    """
    symbol_matches = gene_symbols_df.loc[gene_symbols_df.symbol == gene_symbol]
    assert len(symbol_matches) == 1, (
        f"expected exactly one gene ID for symbol {gene_symbol!r}, found {len(symbol_matches)}"
    )
    gene_id = symbol_matches.index[0]
    print(f"{gene_symbol} ({gene_id})")

    # Select the cells of interest with one combined mask, then the gene column.
    obs = umi_adata.obs
    cell_mask = (
        (obs.species == species)
        & obs.replicate.isin(replicates)
        & obs.treatment.isin(treatments)
        & obs.time_point.isin(time_points)
    )
    umi_adata_view = umi_adata[cell_mask.to_numpy()][:, gene_id].copy()

    # X may be sparse or dense; ravel (rather than squeeze) keeps a 1-D array
    # even when a single cell matches.
    expression = umi_adata_view.X
    if hasattr(expression, "toarray"):
        expression = expression.toarray()
    umi_counts = np.asarray(expression).ravel()

    txburst_params_view = txburst_params_df.loc[txburst_params_df.gene == gene_id]
    print(f"  {len(txburst_params_view)} conditions")

    # NOTE(review): the parameter table is read from CSV, so replicate and
    # time_point may have been parsed as integers — confirm they compare equal
    # to the labels passed in here.
    condition_mask = (
        txburst_params_view.replicate.isin(replicates)
        & txburst_params_view.treatment.isin(treatments)
        & txburst_params_view.time_point.isin(time_points)
    )
    txburst_params_view = txburst_params_view.loc[condition_mask]
    assert len(txburst_params_view) == 1, (
        f"expected exactly one fitted condition for {gene_symbol!r}, found {len(txburst_params_view)}"
    )

    n_cells = len(umi_counts)
    if n_cells == 0:
        raise ValueError(f"no cells match the requested condition for {gene_symbol!r}")
    max_count = int(np.max(umi_counts))

    k_on, k_off, k_syn = txburst_params_view.iloc[0][["k_on", "k_off", "k_syn"]]

    # Evaluate the model on every observed count value, 0..max_count inclusive.
    count_support = np.arange(max_count + 1)
    expected_cells = n_cells * poisson_beta_pmf(count_support, k_on, k_off, k_syn)

    # Unit-width bins centred on each integer count, so that bar heights
    # (cells per count) share a scale with n_cells * pmf.
    bin_edges = np.arange(max_count + 2) - 0.5

    fig, ax = plt.subplots()
    ax.hist(umi_counts, bins=bin_edges, color="k")
    ax.plot(count_support, expected_cells, "g--", linewidth=2)
    ax.set(
        title=f"{gene_symbol} ({gene_id})",
        xlabel="UMI count",
        ylabel="Number of cells",
    )
    plt.show()

In [None]:
# Example: Pfn1 in replicate 1 at time point 6, for the unst and lps treatments.
describe_gene_condition(
    "Pfn1",
    replicates=["1"],
    treatments=["unst", "lps"],
    time_points=["6"],
)