In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display

from rp2 import load_biomart_gene_symbols_df, hagai_2018
from rp2.paths import get_output_path

In [None]:

# Species to analyse. It is passed to the gene-symbol and UMI-count loaders
# below and is also part of the txburst CSV file name.
species = "mouse"

In [None]:
# Gene annotation table for the chosen species. Later cells look rows up by
# the count data's variable names and read the `symbol` column.
gene_info_df = load_biomart_gene_symbols_df(species)

In [None]:
# UMI counts for the chosen species.
# NOTE(review): presumably an AnnData object -- later cells index it as
# [observations, variables] and use .obs, .var_names and .X; confirm against
# rp2.hagai_2018.load_umi_count.
umi_counts_ad = hagai_2018.load_umi_count(species)

In [None]:
# Restrict the count data to the genes returned by load_lps_responsive_genes()
# (selected along the second axis). `umi_counts_ad` is rebound to the reduced
# copy, so every later cell sees only these genes.
all_gene_ids = hagai_2018.load_lps_responsive_genes()
umi_counts_ad = umi_counts_ad[:, all_gene_ids].copy()

In [None]:
# Columns that together identify one (gene, condition) row; used as the join
# key between the UMI condition statistics and the txburst parameters.
index_columns = ["gene", "replicate", "treatment", "time_point"]

def read_txburst_params_csv():
    """Load the txburst parameter CSV for the configured ``species``.

    The file ``<species>_responsive_genes.csv`` is resolved through
    ``get_output_path("txburst", ...)``. The ``replicate`` column is cast to
    ``str`` (presumably so it matches the replicate labels of the frame it is
    joined with -- confirm), and the result is indexed by ``index_columns``.

    Returns:
        pandas.DataFrame indexed by ``index_columns``.
    """
    csv_path = get_output_path("txburst", f"{species}_responsive_genes.csv")
    return (
        pd.read_csv(csv_path)
        .assign(replicate=lambda params_df: params_df["replicate"].astype(str))
        .set_index(index_columns)
    )


# Condition statistics computed from the UMI counts, with the txburst
# parameters attached by a left join on `index_columns`: every statistics row
# is kept, and rows without a match in the CSV get NaN parameters.
all_condition_stats_df = (
    hagai_2018.calculate_umi_condition_stats(umi_counts_ad)
    .set_index(index_columns)
    .join(read_txburst_params_csv(), how="left")
    .reset_index()
)

In [None]:
def create_df_row_mask(df, column_values):
    """Build a boolean row mask from per-column constraints.

    Args:
        df: DataFrame to test.
        column_values: Mapping of column name to constraint. A ``list`` or
            ``tuple`` means "value is one of these"; anything else is
            compared with ``==``.

    Returns:
        Boolean Series on ``df.index`` that is True where every constraint
        holds (all True when ``column_values`` is empty).
    """
    def column_matches(column_name, wanted):
        column = df[column_name]
        return column.isin(wanted) if isinstance(wanted, (list, tuple)) else column == wanted

    row_mask = pd.Series(True, index=df.index)
    for column_name, wanted in column_values.items():
        row_mask &= column_matches(column_name, wanted)
    return row_mask


def filter_df_rows(df, column_values):
    """Return the rows of ``df`` that satisfy every constraint in ``column_values``.

    ``column_values`` maps a column name to a required value, or to a list or
    tuple of accepted values; the row mask comes from ``create_df_row_mask``.
    """
    return df.loc[create_df_row_mask(df, column_values)]


# Gene symbols for the genes present in the count data, sorted by symbol and
# indexed by the count data's variable names.
all_genes = gene_info_df.loc[umi_counts_ad.var_names].symbol.sort_values()

# Sorted distinct values of each condition column; these fill the selectors.
all_replicates, all_treatments, all_time_points = (
    umi_counts_ad.obs[condition_column].sort_values().unique().tolist()
    for condition_column in ("replicate", "treatment", "time_point")
)


def _make_condition_selector(description, options, rows):
    """Build a multi-select widget that starts with every option selected."""
    return widgets.SelectMultiple(
        description=description,
        options=options,
        value=options,
        rows=rows,
    )


# Options are (label, value) pairs: the symbol is displayed, the index entry
# of `all_genes` is the widget value.
gene_selector = widgets.Select(
    description="Gene:",
    options=list(zip(all_genes.values, all_genes.index)),
    label="Pfn1",
    rows=4,
)
replicates_selector = _make_condition_selector("Replicate:", all_replicates, 3)
treatments_selector = _make_condition_selector("Treatments:", all_treatments, 3)
time_points_selector = _make_condition_selector("time points:", all_time_points, len(all_time_points))

# One output area per tab of the UI.
mv_plot_output, dist_plot_output, bp_output, bp_plot_output = (widgets.Output() for _ in range(4))


def update_ui():
    """Redraw every output tab for the current widget selection.

    Reads the selected gene and the selected replicates, treatments and time
    points from the selector widgets, clears the four output widgets and
    refills them:

    * mean-variance scatter plot with a regression line,
    * one scatter plot per burst parameter against the mean,
    * histogram of per-cell UMI counts, one series per selected condition,
    * table of the columns from ``k_on`` onwards for the selected conditions.

    A plot tab is left empty when the selection matches no rows for it.
    """
    selected_gene_id = gene_selector.value
    selected_conditions = {
        "replicate": (replicates_selector.value),
        "treatment": (treatments_selector.value),
        "time_point": (time_points_selector.value),
    }

    mv_plot_output.clear_output()
    dist_plot_output.clear_output()
    bp_output.clear_output()
    bp_plot_output.clear_output()

    condition_stats_subset = filter_df_rows(all_condition_stats_df, {**selected_conditions, "gene": selected_gene_id})
    if not condition_stats_subset.empty:
        with mv_plot_output:
            ax = sns.scatterplot(
                x="mean",
                y="variance",
                hue="replicate",
                style="treatment",
                size="time_point",
                data=condition_stats_subset,
            )
            plt.xlim(left=0)
            plt.ylim(bottom=0)
            sns.regplot(
                x="mean",
                y="variance",
                truncate=False,
                scatter=False,
                data=condition_stats_subset,
                ax=ax,
            )
            plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
            plt.show()

        bp_plot_params = ["bs_point", "bf_point", "k_off", "k_syn"]
        with bp_plot_output:
            fig, axes = plt.subplots(1, len(bp_plot_params), figsize=(len(bp_plot_params) * 4, 4))
            for plot_param, ax in zip(bp_plot_params, axes.flat):
                sns.scatterplot(
                    x="mean",
                    y=plot_param,
                    hue="time_point",
                    style="replicate",
                    data=condition_stats_subset,
                    ax=ax,
                )
                ax.set_xlim(left=0)
                ax.set_ylim(bottom=0)
                # The burst parameters come from a left join, so they can be
                # NaN for every selected row; seaborn may then draw nothing
                # and attach no legend, in which case get_legend() is None.
                subplot_legend = ax.get_legend()
                if subplot_legend is not None:
                    subplot_legend.remove()
            # Single shared legend, built from the last subplot's entries.
            plt.legend(*ax.get_legend_handles_labels(), loc="upper left", bbox_to_anchor=(1, 1))
            plt.tight_layout()
            plt.show()

    umi_counts_obs_subset = filter_df_rows(umi_counts_ad.obs, selected_conditions)
    if not umi_counts_obs_subset.empty:
        with dist_plot_output:
            counts_list = []
            labels_list = []
            # observed=True skips category combinations with no cells when the
            # condition columns are categorical; it has no effect otherwise.
            condition_groups = umi_counts_obs_subset.groupby(
                ["replicate", "treatment", "time_point"],
                observed=True,
            )
            for (replicate, treatment, time_point), obs_group in condition_groups:
                gene_counts = umi_counts_ad[obs_group.index, selected_gene_id].X
                # Densify sparse storage; dense arrays pass straight through.
                if hasattr(gene_counts, "toarray"):
                    gene_counts = gene_counts.toarray()
                # Keep one flat 1-D array per condition. Conditions usually
                # contain different numbers of cells, so stacking them into a
                # single array is either impossible (ragged) or, when sizes
                # happen to match, makes hist() read columns as datasets.
                counts_list.append(np.asarray(gene_counts).ravel())
                labels_list.append(f"{treatment} R{replicate} T{time_point}")
            plt.hist(
                counts_list,
                label=labels_list,
            )
            plt.xlabel("UMI count")
            plt.ylabel("Cells")
            plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
            plt.show()

    # Table view: one row per selected condition, columns from "k_on" onwards.
    txburst_params_subset = condition_stats_subset.drop(columns="gene").set_index(["replicate", "treatment", "time_point"]).loc[:, "k_on":]
    with bp_output:
        display(txburst_params_subset)


def event_handler(event):
    """Refresh the UI when a widget's ``value`` trait changes.

    ``event`` is the change dictionary passed by ``observe``. Anything other
    than a ``"change"`` notification for the ``"value"`` trait is ignored.
    """
    is_value_change = event["type"] == "change" and event["name"] == "value"
    if is_value_change:
        update_ui()


# Redraw whenever any selector changes; event_handler ignores notifications
# that are not changes of the `value` trait.
for selector in (gene_selector, replicates_selector, treatments_selector, time_points_selector):
    selector.observe(event_handler)

# One tab per output area, in the same order as the titles below.
tab_widget = widgets.Tab()
tab_widget.children = [mv_plot_output, dist_plot_output, bp_output, bp_plot_output]
for tab_index, tab_title in enumerate([
    "Mean-variance plot",
    "UMI distribution",
    "Burst parameters",
    "Burst parameter plots",
]):
    tab_widget.set_title(tab_index, tab_title)

# Layout: two columns of selectors on top, the tabbed outputs underneath.
ui_container = widgets.VBox([
    widgets.HBox([
        widgets.VBox([gene_selector, replicates_selector]),
        widgets.VBox([treatments_selector, time_points_selector]),
    ]),
    tab_widget,
])

display(ui_container)

# Draw the initial state; later redraws are triggered by the observers.
update_ui()