In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display

from rp2 import load_biomart_gene_symbols_df, hagai_2018

In [None]:

# Species whose BioMart annotations and UMI counts are loaded in the next cells.
species: str = "mouse"

In [None]:
# Gene annotation table for `species`. The widget cell later indexes it with
# `umi_counts_ad.var_names` and reads its `symbol` column, so it is expected to
# be indexed by the same gene IDs as the count matrix — confirm if loader changes.
gene_info_df = load_biomart_gene_symbols_df(species)

In [None]:
# UMI count matrix for `species`. Presumably an AnnData object (later cells use
# `[:, gene_ids]` slicing, `.var_names`, and `.obs.replicate` / `.treatment` /
# `.time_point`) — NOTE(review): confirm against `hagai_2018.load_umi_count`.
umi_counts_ad = hagai_2018.load_umi_count(species)

In [None]:
# Restrict the count matrix to the LPS-responsive genes. `umi_counts_ad` is
# overwritten, so re-running this cell subsets the already-filtered object again.
# NOTE(review): assumes every returned ID is present in `umi_counts_ad.var_names`
# (otherwise the slice presumably raises) — confirm.
# `.copy()` presumably materialises the slice rather than keeping a view onto
# the full matrix.
all_gene_ids = hagai_2018.load_lps_responsive_genes()
umi_counts_ad = umi_counts_ad[:, all_gene_ids].copy()

In [None]:
# Per-condition summary statistics of the UMI counts. The plotting cell filters
# this frame on its `gene`, `replicate`, `treatment` and `time_point` columns
# and plots `mean` against `variance`.
all_condition_stats_df = hagai_2018.calculate_umi_condition_stats(umi_counts_ad)

In [None]:
# Option lists for the selector widgets, derived from the filtered count matrix.
all_genes = gene_info_df.loc[umi_counts_ad.var_names].symbol.sort_values()
all_replicates = umi_counts_ad.obs.replicate.sort_values().unique().tolist()
all_treatments = umi_counts_ad.obs.treatment.sort_values().unique().tolist()
all_time_points = umi_counts_ad.obs.time_point.sort_values().unique().tolist()

# Gene selected on first render. If this symbol is not among the filtered genes,
# fall back to the first (alphabetical) gene: passing an unknown `label` to
# `widgets.Select` raises a TraitError and would abort the whole cell.
DEFAULT_GENE_SYMBOL = "Pfn1"
# (label, value) pairs: the widget shows the symbol and yields the gene ID.
gene_options = list(zip(all_genes.values, all_genes.index))
# The widget stringifies option labels, so compare against str() labels.
gene_labels = [str(symbol) for symbol, _ in gene_options]
if DEFAULT_GENE_SYMBOL in gene_labels:
    default_gene_index = gene_labels.index(DEFAULT_GENE_SYMBOL)
else:
    default_gene_index = 0 if gene_labels else None

gene_selector = widgets.Select(
    description="Gene:",
    options=gene_options,
    index=default_gene_index,
    rows=4,
)
replicates_selector = widgets.SelectMultiple(
    description="Replicate:",
    options=all_replicates,
    value=all_replicates,
    rows=3,
)
treatments_selector = widgets.SelectMultiple(
    description="Treatments:",
    options=all_treatments,
    value=all_treatments,
    rows=3,
)
time_points_selector = widgets.SelectMultiple(
    description="Time points:",
    options=all_time_points,
    value=all_time_points,
    rows=len(all_time_points),
)

# Output area that `update_ui` renders the plot into.
mv_output = widgets.Output()


def update_ui():
    """Redraw the mean-variance plot for the current widget selection.

    Filters ``all_condition_stats_df`` to the selected gene, replicates,
    treatments and time points, scatters ``mean`` against ``variance`` for the
    matching rows and overlays a linear regression fit. Output goes to
    ``mv_output``, which is cleared first so only the latest render is shown.
    When the selection matches no rows, a short message is shown instead of
    leaving the output area blank.
    """
    mv_output.clear_output()
    with mv_output:
        selected_gene_id = gene_selector.value
        selected_replicates = replicates_selector.value
        selected_treatments = treatments_selector.value
        selected_time_points = time_points_selector.value

        condition_stats_subset = all_condition_stats_df.loc[
            (all_condition_stats_df.gene == selected_gene_id)
            & all_condition_stats_df.replicate.isin(selected_replicates)
            & all_condition_stats_df.treatment.isin(selected_treatments)
            & all_condition_stats_df.time_point.isin(selected_time_points)
        ]

        if condition_stats_subset.empty:
            # e.g. every option of one of the multi-selects was deselected
            print("No data for the current selection.")
            return

        # Draw on a fresh figure instead of pyplot's implicit current axes:
        # on backends that do not close figures on show (e.g. ipympl),
        # reusing gca() would pile every update onto the same axes.
        _, ax = plt.subplots()
        sns.scatterplot(
            x="mean",
            y="variance",
            hue="replicate",
            style="treatment",
            size="time_point",
            data=condition_stats_subset,
            ax=ax,
        )
        # Pin the lower limits *before* regplot: with truncate=False the fit
        # line is drawn across the current x-limits, so it starts at mean == 0.
        ax.set_xlim(left=0)
        ax.set_ylim(bottom=0)
        sns.regplot(
            x="mean",
            y="variance",
            truncate=False,
            scatter=False,
            data=condition_stats_subset,
            ax=ax,
        )
        # The selector's label is the gene symbol for the selected gene ID.
        ax.set_title(gene_selector.label)
        # Put the legend outside the axes so it never hides data points.
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
        plt.show()


def event_handler(event):
    """Re-render the plot when a selector's ``value`` trait changes.

    ``event`` is the traitlets change notification. Notifications of any other
    type, or for any trait other than ``value``, are ignored.
    """
    # `and` short-circuits, so "name" is only read for change notifications.
    is_value_change = event["type"] == "change" and event["name"] == "value"
    if is_value_change:
        update_ui()


# Re-render whenever a selector's value changes. Subscribing only to the
# "value" trait skips the callbacks for other traits (such as a selection
# widget's index/label) that change alongside it.
gene_selector.observe(event_handler, names="value")
replicates_selector.observe(event_handler, names="value")
treatments_selector.observe(event_handler, names="value")
time_points_selector.observe(event_handler, names="value")

# Layout: gene/replicate selectors and treatment/time-point selectors side by
# side, with the plot output underneath.
ui_container = widgets.VBox([
    widgets.HBox([
        widgets.VBox([
            gene_selector,
            replicates_selector,
        ]),
        widgets.VBox([
            treatments_selector,
            time_points_selector,
        ]),
    ]),
    mv_output,
])

display(ui_container)

# Draw the initial plot for the default selection.
update_ui()