In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display
from scipy import stats, special

from rp2 import load_biomart_gene_symbols_df, hagai_2018
from rp2.paths import get_output_path

In [None]:
# Species analysed throughout: passed to both loaders and used to filter cells.
species: str = "mouse"

In [None]:
# Gene-symbol table for `species`. Later cells use its index as the gene ID and
# its `symbol` column for symbol lookups and gene-ID -> symbol mapping.
gene_symbols_df = load_biomart_gene_symbols_df(species)

In [None]:
# UMI counts as an AnnData-like object (rows = cells with `.obs` metadata,
# columns = gene IDs). describe_gene_condition filters its obs by
# species/replicate/treatment/time_point and slices a single gene column.
umi_adata = hagai_2018.load_umi_count(species)

In [None]:
# Fitted transcriptional-bursting parameters; later cells read the columns
# gene, replicate, treatment, time_point, k_on, k_off and k_syn. The file name
# follows `species` so changing the species cell cannot silently load the
# mouse fits alongside another species' UMI counts.
txburst_params_df = pd.read_csv(get_output_path("txburst", f"{species}_responsive_genes.csv"))

In [None]:
def poisson_beta_pmf(k, k_on, k_off, k_syn, n_roots=50):
    """Evaluate the Poisson-Beta mixture PMF at count(s) ``k``.

    P(k) = E_p[Poisson(k; k_syn * p)] with p ~ Beta(k_on, k_off), integrated
    with ``n_roots``-point Gauss-Jacobi quadrature.

    Parameters
    ----------
    k : int or array-like of int
        Count(s) at which to evaluate the PMF (flattened to 1-D).
    k_on, k_off : float
        Beta shape parameters; must be positive.
    k_syn : float
        Poisson mean when p == 1.
    n_roots : int
        Number of quadrature nodes.

    Returns
    -------
    numpy.ndarray
        1-D array with one probability per entry of ``k``.

    Raises
    ------
    AssertionError
        If ``k_on`` or ``k_off`` is not positive, or a quadrature node's
        Poisson mean is >= 1e6.
    """
    assert k_on > 0, f"k_on must be positive, got {k_on}"
    assert k_off > 0, f"k_off must be positive, got {k_off}"

    # The Jacobi weight (1 - x)^(k_off - 1) * (1 + x)^(k_on - 1) on [-1, 1] is,
    # under p = (x + 1) / 2, the unnormalized Beta(k_on, k_off) density.
    roots, weights = special.roots_jacobi(n_roots, alpha=k_off - 1, beta=k_on - 1)
    mus = k_syn * (roots + 1) / 2
    assert np.max(mus) < 1e6

    ks = np.asarray(k).reshape(-1, 1)
    gs = np.sum(weights * stats.poisson.pmf(ks, mus), axis=1)
    # The weights sum to 2^(k_on + k_off - 1) * B(k_on, k_off), the Beta
    # normalizing constant in these coordinates. Dividing by that sum, rather
    # than evaluating 1 / beta(...) * 2 ** (...), avoids overflow/underflow
    # (inf * 0 = NaN) when k_on + k_off is large.
    return gs / np.sum(weights)

In [None]:
def _condition_mask(frame, replicates, treatments, time_points):
    """Boolean Series: rows of ``frame`` whose replicate, treatment and
    time_point are all among the requested values."""
    return (
        frame.replicate.isin(replicates)
        & frame.treatment.isin(treatments)
        & frame.time_point.isin(time_points)
    )


def describe_gene_condition(gene_symbol, replicates, treatments, time_points):
    """Summarize one gene in one condition and plot its UMI distribution
    against the fitted Poisson-Beta model.

    Reads the notebook globals ``species``, ``gene_symbols_df``,
    ``umi_adata`` and ``txburst_params_df``.

    Parameters
    ----------
    gene_symbol : str
        Symbol looked up in ``gene_symbols_df.symbol``.
    replicates, treatments, time_points : list
        Allowed values of the ``replicate`` / ``treatment`` / ``time_point``
        columns, applied to both the cell metadata and the parameter table.

    Returns
    -------
    None
        Results are printed, plotted and displayed. Returns early (after a
        message) if the symbol is unknown or ambiguous, if the selection has no
        fitted parameters, or if no cells match.
    """
    gene_symbols_view = gene_symbols_df.loc[gene_symbols_df.symbol == gene_symbol]
    if len(gene_symbols_view) == 0:
        # Without this guard `.index[0]` below raises IndexError.
        print(f"Gene symbol {gene_symbol} not found")
        return
    if len(gene_symbols_view) > 1:
        print(f"Gene symbol {gene_symbol} is not unique:")
        for gene_id in gene_symbols_view.index:
            print(f"  {gene_id}")
        return

    gene_id = gene_symbols_view.index[0]
    print(f"{gene_symbol} ({gene_id})")

    obs = umi_adata.obs
    cell_mask = (obs.species == species) & _condition_mask(obs, replicates, treatments, time_points)
    umi_matrix = umi_adata[cell_mask.to_numpy(), gene_id].X
    # X may be sparse or dense. The sparse `.A` shortcut and the `np.int` alias
    # are avoided (both removed in recent SciPy/NumPy). ravel() rather than
    # squeeze() keeps a 1-D array even when exactly one cell is selected.
    if hasattr(umi_matrix, "toarray"):
        umi_matrix = umi_matrix.toarray()
    umi_counts = np.asarray(umi_matrix).ravel().astype(int)

    txburst_params_view = txburst_params_df.loc[txburst_params_df.gene == gene_id]
    print(f"  {len(txburst_params_view)} conditions (total for gene)")

    txburst_params_view = txburst_params_view.loc[
        _condition_mask(txburst_params_view, replicates, treatments, time_points)
    ]
    if len(txburst_params_view) == 0:
        print("  Selected condition has no data")
        return
    assert len(txburst_params_view) == 1

    n_cells = len(umi_counts)
    print(f"  {n_cells:,} cells")
    if n_cells == 0:
        # Nothing to histogram; np.max would raise on an empty array.
        return
    max_count = int(umi_counts.max())

    k_on, k_off, k_syn = txburst_params_view.squeeze()[["k_on", "k_off", "k_syn"]]

    # Integer-aligned bins starting at 0, each spanning `bin_step` whole counts
    # (about 50 bins at most). Float bins over [min, max] put an uneven number
    # of integer counts into each bar, which breaks the overlay scaling below.
    bin_step = max(1, int(np.ceil((max_count + 1) / 50)))
    bin_values, bin_edges, _ = plt.hist(
        umi_counts, bins=np.arange(0, max_count + bin_step + 1, bin_step), color="k"
    )
    # Equals n_cells * bin_step; scaling the PMF by it converts probabilities
    # into expected cells per bar.
    hist_area = np.sum(np.diff(bin_edges) * bin_values)

    # Evaluate the model at every count 0..max_count (inclusive) and draw each
    # point at the centre of that count's unit interval [k, k + 1).
    pmf_in = np.arange(max_count + 1)
    pmf_out = poisson_beta_pmf(pmf_in, k_on, k_off, k_syn)
    plt.plot(pmf_in + 0.5, hist_area * pmf_out, "g--", linewidth=2)

    plt.xlabel("UMI count")
    plt.ylabel("Cells")
    plt.show()

    display(txburst_params_view.loc[:, "k_on":])

In [None]:
def create_dropdown(description, series, selection):
    return widgets.Dropdown(
        description=description,
        options=series.sort_values().unique().tolist(),
        value=selection,
    )

# One dropdown per selector; the last argument is the initial selection.
# Gene options are the symbols of genes that have fitted parameters. "unst" is
# excluded from the treatment list because update_ui always passes it together
# with the chosen treatment.
gene_symbol_dropdown = create_dropdown("Gene:", txburst_params_df.gene.map(gene_symbols_df.symbol), "Pfn1")
replicate_dropdown = create_dropdown("Replicate:", umi_adata.obs.replicate, "1")
treatment_dropdown = create_dropdown("Treatment:", umi_adata.obs.treatment[umi_adata.obs.treatment != "unst"], "lps")
time_point_dropdown = create_dropdown("Time:", umi_adata.obs.time_point, "6")

# Output area that update_ui() renders into.
ui_output = widgets.Output()

# Layout: a 2x2 grid of selectors (gene/replicate on the left,
# treatment/time point on the right) above the output area.
ui_container = widgets.VBox(
    children=[
        widgets.HBox(
            children=[
                widgets.VBox(children=[gene_symbol_dropdown, replicate_dropdown]),
                widgets.VBox(children=[treatment_dropdown, time_point_dropdown]),
            ]
        ),
        ui_output,
    ]
)

def update_ui():
    """Re-render ``ui_output`` for the selection currently shown in the dropdowns.

    The unstimulated ("unst") treatment is always passed alongside the
    selected treatment; replicate and time point are single selections.
    """
    # wait=True postpones the clear until new output arrives, so the output
    # area does not collapse (making the page jump) on every selection change.
    ui_output.clear_output(wait=True)
    with ui_output:
        describe_gene_condition(
            gene_symbol=gene_symbol_dropdown.value,
            replicates=[replicate_dropdown.value],
            treatments=["unst", treatment_dropdown.value],
            time_points=[time_point_dropdown.value],
        )

def event_handler(event):
    """Widget observer callback: refresh the UI only on ``value`` changes."""
    # Short-circuit keeps "name" unread unless this is a "change" event.
    is_value_change = event["type"] == "change" and event["name"] == "value"
    if is_value_change:
        update_ui()

# Re-render whenever a selector's value changes. Subscribing to "value" only
# skips the other traits (e.g. index/label) that event_handler would ignore anyway.
gene_symbol_dropdown.observe(event_handler, names="value")
replicate_dropdown.observe(event_handler, names="value")
treatment_dropdown.observe(event_handler, names="value")
time_point_dropdown.observe(event_handler, names="value")

# Show the controls, then render the initial selection.
display(ui_container)
update_ui()