In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display, Latex

import rp2
from rp2 import load_biomart_gene_symbols_df, hagai_2018, txburst
from rp2.paths import get_output_path

# Sanity-check the rp2 setup before any data is loaded below
# (see rp2.check_environment for exactly what is verified).
rp2.check_environment()

In [None]:
# Species analysed by this notebook: used to load gene symbols and UMI counts,
# and to filter cells inside describe_gene_condition.
species = "mouse"

In [None]:
# Gene-symbol lookup table. Later cells read `.loc[gene_id].symbol`, so it is
# presumably indexed by gene ID with a `symbol` column — TODO confirm.
gene_symbols_df = load_biomart_gene_symbols_df(species)

In [None]:
# UMI counts for `species`. Used below as an AnnData-like object (`.obs` cell
# metadata, `[cells, gene_id]` indexing, `.X` count matrix) — TODO confirm type.
umi_adata = hagai_2018.load_umi_count(species)

In [None]:
# Per-gene/condition txburst parameters (columns used later include gene,
# replicate, treatment, time_point, k_on, k_off, k_syn). The file name is derived
# from `species` so that changing the species cell keeps inputs consistent;
# for species == "mouse" this is the same file as before.
txburst_params_df = pd.read_csv(get_output_path("txburst", f"{species}_responsive_genes.csv"))

In [None]:
def _dense_counts(matrix):
    """Flatten a (possibly scipy-sparse) single-gene count matrix to a 1-D int array.

    ``spmatrix.A`` is deprecated/removed in recent SciPy and ``np.int`` was
    removed in NumPy 1.24, so ``toarray()`` and builtin ``int`` are used instead.
    ``ravel`` (rather than ``squeeze``) keeps the result 1-D even when only a
    single cell is selected.
    """
    dense = matrix.toarray() if hasattr(matrix, "toarray") else np.asarray(matrix)
    return dense.ravel().astype(int)


def _isin_as_str(column, values):
    """Boolean mask of ``column`` entries equal to any of ``values``, compared as strings.

    NOTE(review): the selector widgets produce string values (e.g. "1", "6"),
    while ``pd.read_csv`` may infer an integer dtype for numeric-looking
    columns, in which case a plain ``isin`` would never match. Comparing string
    forms is a no-op when both sides are already strings.
    """
    return column.astype(str).isin([str(value) for value in values])


def _select_umi_counts(gene_id, replicates, treatments, time_points):
    """Return UMI counts of ``gene_id`` for the `species` cells in the selected condition."""
    obs = umi_adata.obs
    cell_mask = (
        (obs.species == species)
        & obs.replicate.isin(replicates)
        & obs.treatment.isin(treatments)
        & obs.time_point.isin(time_points)
    )
    # One combined mask and a single slice instead of a chain of AnnData views.
    return _dense_counts(umi_adata[cell_mask.to_numpy(), gene_id].X)


def _plot_count_distribution(umi_counts, k_on, k_off, k_syn):
    """Plot a histogram of ``umi_counts`` overlaid with the Poisson-beta PMF for the given parameters."""
    max_count = int(np.max(umi_counts))
    # Anchor the bins at 0 and cover [0, max_count + 1]: with at most 50 distinct
    # counts every integer k gets its own unit-width bin [k, k + 1), the top two
    # counts are not merged, and an all-zero gene still gets one bin (bins >= 1).
    n_bins = min(50, max_count + 1)
    fig, ax = plt.subplots()
    bin_values, bin_edges, _ = ax.hist(
        umi_counts, bins=n_bins, range=(0, max_count + 1), color="orange"
    )
    bin_width = bin_edges[1] - bin_edges[0]
    # Scale the PMF to "cells per bin": the histogram area is n_cells * bin_width.
    hist_area = np.sum(np.diff(bin_edges) * bin_values)

    # Evaluate the PMF at every observed count 0..max_count (inclusive), shifted by
    # half a bin so that, with unit-width bins, each point sits at its bin centre.
    pmf_in = np.arange(max_count + 1)
    pmf_out = txburst.poisson_beta_pmf(pmf_in, k_on, k_off, k_syn)
    ax.plot(pmf_in + bin_width / 2, hist_area * pmf_out, "g--", linewidth=2)

    ax.axvline(x=np.mean(umi_counts), linestyle=":")
    ax.set_xlabel("UMI count")
    ax.set_ylabel("Cells")
    plt.show()


def _display_dispersion_stats(umi_counts, k_on, k_off):
    """Render mean, variance, dispersion ratios and burstiness as LaTeX."""
    umi_mean = np.mean(umi_counts)
    umi_var = np.var(umi_counts, ddof=1)
    # An all-zero gene (mean 0), b_m == 1 or k_on == 0 yields inf/nan; show those
    # values instead of emitting RuntimeWarnings into the widget output.
    with np.errstate(divide="ignore", invalid="ignore"):
        bm = umi_var / umi_mean
        fm = 1 / (bm - 1)
        burstiness = k_off / k_on
    display(
        Latex(", ".join([
            f"$\\mu={umi_mean:.2f}$",
            f"$\\sigma^2={umi_var:.2f}$",
            f"$b_{{m}}=\\frac{{\\sigma^2}}{{\\mu}}={bm:.2f}$",
            f"$f_{{m}}=\\frac{{1}}{{b_{{m}}-1}}={fm:.2f}$",
        ])),
        Latex(f"burstiness$=\\frac{{k_{{off}}}}{{k_{{on}}}}={burstiness:.2f}$"),
    )


def describe_gene_condition(gene_id, replicates, treatments, time_points):
    """Print, plot and tabulate UMI counts and txburst parameters for one gene/condition.

    Parameters
    ----------
    gene_id : str
        Gene identifier: a row label of ``gene_symbols_df``, a variable of
        ``umi_adata`` and a value of ``txburst_params_df.gene``.
    replicates, treatments, time_points : list-like
        Values selecting cells (matched against ``umi_adata.obs`` columns) and
        txburst rows (matched against ``txburst_params_df`` columns).

    Prints a "no data" message and returns early when no txburst row matches
    the selection, and raises AssertionError if more than one row matches.
    Returns without plotting when the selection contains no cells.
    """
    gene_symbol = gene_symbols_df.loc[gene_id].symbol
    print(f"{gene_symbol} ({gene_id})")

    gene_params = txburst_params_df.loc[txburst_params_df.gene == gene_id]
    print(f"  {len(gene_params)} conditions (total for gene)")

    condition_params = gene_params.loc[
        _isin_as_str(gene_params.replicate, replicates)
        & _isin_as_str(gene_params.treatment, treatments)
        & _isin_as_str(gene_params.time_point, time_points)
    ]
    if len(condition_params) == 0:
        print("  Selected condition has no data")
        return
    assert len(condition_params) == 1, (
        f"expected 1 txburst row for {gene_id}, found {len(condition_params)}"
    )

    umi_counts = _select_umi_counts(gene_id, replicates, treatments, time_points)
    n_cells = len(umi_counts)
    print(f"  {n_cells:,} cells, {np.count_nonzero(umi_counts == 0):,} with no UMIs")
    if n_cells == 0:
        return  # nothing to plot; np.max would raise on an empty array

    k_on, k_off, k_syn = condition_params.iloc[0][["k_on", "k_off", "k_syn"]]
    _plot_count_distribution(umi_counts, k_on, k_off, k_syn)
    display(condition_params.loc[:, "k_on":])
    _display_dispersion_stats(umi_counts, k_on, k_off)

In [None]:
# Option lists for the selectors, computed up front from the loaded data.
# Genes with txburst parameters, as (symbol, gene_id) pairs ordered by symbol.
all_genes = gene_symbols_df.loc[txburst_params_df.gene.unique()].symbol.sort_values()
gene_options = list(zip(all_genes.values, all_genes.index))

cell_obs = umi_adata.obs
replicate_options = cell_obs.replicate.sort_values().unique().tolist()
# "unst" is excluded from the choices here; update_ui always adds it back.
treatment_options = (
    cell_obs.treatment
    .loc[cell_obs.treatment != "unst"]
    .sort_values()
    .unique()
    .tolist()
)
time_point_options = cell_obs.time_point.sort_values().unique().tolist()

# One Select per dimension, each preset to the default gene/condition.
gene_selector = widgets.Select(description="Gene:", options=gene_options, label="Pfn1", rows=4)
replicate_selector = widgets.Select(description="Replicate:", options=replicate_options, value="1", rows=3)
treatment_selector = widgets.Select(description="Treatment:", options=treatment_options, value="lps", rows=2)
time_point_selector = widgets.Select(description="Time:", options=time_point_options, value="6", rows=5)

ui_output = widgets.Output()

# Two columns of selectors above the output area.
left_column = widgets.VBox([gene_selector, replicate_selector])
right_column = widgets.VBox([treatment_selector, time_point_selector])
selector_grid = widgets.HBox([left_column, right_column])
ui_container = widgets.VBox([selector_grid, ui_output])


def update_ui():
    """Re-render ``ui_output`` for the currently selected gene and condition.

    The unstimulated treatment ("unst") is always passed alongside the selected
    treatment.
    """
    # wait=True defers clearing until new output arrives, so the output area does
    # not collapse and flicker while the next plot is being produced.
    ui_output.clear_output(wait=True)
    with ui_output:
        describe_gene_condition(
            gene_id=gene_selector.value,
            replicates=[replicate_selector.value],
            treatments=["unst", treatment_selector.value],
            time_points=[time_point_selector.value],
        )

def event_handler(event):
    """Widget observer: redraw only when a selector's ``value`` has changed.

    Any other notification (a different trait name or event type) is ignored.
    """
    is_value_change = event["type"] == "change" and event["name"] == "value"
    if is_value_change:
        update_ui()

# Redraw whenever any selector's value changes. Restricting observe() to the
# "value" trait means the handler is not invoked for every other trait
# notification the widgets emit.
for selector in (gene_selector, replicate_selector, treatment_selector, time_point_selector):
    selector.observe(event_handler, names="value")

display(ui_container)
update_ui()  # initial render for the default selection