In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display

import rp2
from rp2 import hagai_2018, txburst
from rp2.paths import get_txburst_results_csv_path

# Presumably validates the rp2 runtime environment (data locations, dependencies) before
# any data is loaded -- NOTE(review): confirm against rp2.check_environment.
rp2.check_environment()

In [None]:
species = "mouse"

# Columns that identify one experimental condition, and those that identify a
# (gene, condition) pair in the per-gene statistics / txburst tables.
condition_index_columns = ["replicate", "treatment", "time_point"]
gene_index_columns = ["gene", *condition_index_columns]

In [None]:
# Burst-parameter fits produced by txburst, one row per (gene, condition).
txburst_params_df = pd.read_csv(get_txburst_results_csv_path(species, condition_index_columns))
# Replicate labels are stored as strings -- presumably so they match the replicate keys
# of the count statistics used in joins below (NOTE(review): confirm dtype there).
txburst_params_df = txburst_params_df.assign(replicate=lambda df: df["replicate"].astype(str))

# Sorted, de-duplicated ids of every gene that has a txburst fit.
analysis_gene_ids = txburst_params_df["gene"].drop_duplicates().sort_values().tolist()
print(f"{len(analysis_gene_ids):,} genes have been processed by txburst")

In [None]:
# Raw UMI counts and median-scaled counts for the same cells.
umi_counts_adata, ms_counts_adata = hagai_2018.load_counts(species, scaling=["umi", "median"])
print(f"Total of {umi_counts_adata.n_vars:,} genes")

# Restrict both matrices to the genes that txburst processed (materialized, not views).
umi_counts_adata, ms_counts_adata = (
    counts_adata[:, analysis_gene_ids].copy()
    for counts_adata in (umi_counts_adata, ms_counts_adata)
)
print(f"Reduced to subset of {umi_counts_adata.n_vars:,}")

In [None]:
def create_expression_data(counts_adata):
    """Bundle a counts AnnData with its per-condition statistics.

    Returns a dict with ``"counts"`` (the input object itself) and ``"stats"`` (the
    frame returned by ``hagai_2018.calculate_counts_condition_stats`` for it).
    """
    condition_stats = hagai_2018.calculate_counts_condition_stats(counts_adata)
    return {"counts": counts_adata, "stats": condition_stats}


# Count type shown in the UI -> {"counts": ..., "stats": ...}.
expression_data_map = {
    count_type: create_expression_data(counts_adata)
    for count_type, counts_adata in (
        ("UMI", umi_counts_adata),
        ("Median scaled", ms_counts_adata),
    )
}

In [None]:
def create_df_row_mask(df, column_values):
    """Build a boolean mask selecting the rows of ``df`` that satisfy every column filter.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to filter.
    column_values : Mapping[str, object]
        Maps a column name to its required value. A list-like value (list, tuple, set,
        numpy array, pandas Index/Series, ...) matches any of its elements via
        ``Series.isin``. Any other value, including a string, must compare equal.

    Returns
    -------
    pandas.Series
        Boolean Series aligned to ``df.index``. It is all True when ``column_values`` is empty.
    """
    mask = pd.Series(True, index=df.index, dtype=bool)
    for column_name, column_value in column_values.items():
        # is_list_like excludes strings, so a single string label is still an equality test.
        if pd.api.types.is_list_like(column_value):
            mask &= df[column_name].isin(column_value)
        else:
            mask &= df[column_name] == column_value
    return mask


def filter_df_rows(df, column_values):
    """Return the rows of ``df`` that pass every filter in ``column_values``.

    The filter semantics are those of ``create_df_row_mask``.
    """
    return df.loc[create_df_row_mask(df, column_values)]


def plot_expression_distribution(expression_adata, burst_param_subset, max_bins=50):
    """Plot one condition's count histogram overlaid with its fitted txburst PMF.

    Parameters
    ----------
    expression_adata : AnnData
        Cells x one gene. ``.X`` may be a scipy sparse matrix or a dense array.
    burst_param_subset : pandas.DataFrame or pandas.Series
        Burst parameters for the condition, containing ``k_on``, ``k_off`` and ``k_syn``.
        The Poisson-beta overlay is drawn only when this holds exactly one row of finite
        parameters. For example, it is skipped when the txburst fit is missing (NaN).
    max_bins : int, default 50
        Upper limit on the number of histogram bins.
    """
    x = expression_adata.X
    # ``.A`` was removed from scipy sparse matrices (and never existed on ndarrays);
    # ravel keeps a 1-D array even for a single cell.
    counts = np.ravel(x.toarray() if hasattr(x, "toarray") else np.asarray(x))
    if counts.size == 0:
        print("No data available")
        return

    max_count = counts.max()
    mean_count = counts.mean()

    # Anchor the bins at zero and extend them to upper_bin + 1. With unit-width bins, count k
    # then falls in [k, k + 1), including the maximum count, which would otherwise share the
    # closed last bin. This also guarantees at least one bin when every count is zero.
    upper_bin = int(np.ceil(max_count))
    n_bins = min(max_bins, upper_bin + 1)
    bin_values, bin_edges, _ = plt.hist(
        counts, bins=n_bins, range=(0, upper_bin + 1), color="orange"
    )
    # Cells x bin width: scales the PMF to expected cells per bin.
    hist_area = np.sum(np.diff(bin_edges) * bin_values)

    if isinstance(burst_param_subset, pd.Series):
        burst_param_subset = burst_param_subset.to_frame().T
    burst_params = burst_param_subset.reindex(columns=["k_on", "k_off", "k_syn"]).to_numpy(dtype=float)
    if burst_params.shape[0] == 1 and np.isfinite(burst_params).all():
        k_on, k_off, k_syn = burst_params[0]
        # Evaluate the PMF at every integer count up to and including the maximum.
        pmf_in = np.arange(upper_bin + 1)
        pmf_out = txburst.poisson_beta_pmf(pmf_in, k_on, k_off, k_syn)
        # Place the mass of count k at the centre of [k, k + 1), matching the bins above.
        plt.plot(pmf_in + 0.5, hist_area * pmf_out, "g--", linewidth=2)

    plt.axvline(x=mean_count, linestyle=":")
    plt.xlabel("Count")
    plt.ylabel("Cells")
    plt.show()


def plot_expression_distributions(expression_adata, counts_obs_subset, burst_param_subset):
    """Plot one gene's count distributions for the selected conditions.

    With a single condition, this delegates to ``plot_expression_distribution`` (histogram
    plus fitted PMF). With several conditions, it draws a grouped histogram next to
    per-condition KDEs.

    Parameters
    ----------
    expression_adata : AnnData
        Cells x one gene. ``.X`` may be scipy sparse or dense.
    counts_obs_subset : pandas.DataFrame
        ``obs`` rows of the selected cells. Must contain ``condition_index_columns``.
    burst_param_subset : pandas.DataFrame
        txburst parameters for the selected conditions. Only used in the single-condition case.
    """
    if counts_obs_subset.empty:
        print("No data available")
        return

    # observed=True: with categorical obs columns, skip condition combinations with no cells.
    condition_groups = counts_obs_subset.groupby(condition_index_columns, observed=True)
    if condition_groups.ngroups == 1:
        _, group_df = next(iter(condition_groups))
        plot_expression_distribution(expression_adata[group_df.index], burst_param_subset)
        return

    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    counts_list = []
    labels_list = []
    for (replicate, treatment, time_point), obs_group in condition_groups:
        group_x = expression_adata[obs_group.index].X
        # ``.A`` was removed from scipy sparse matrices; this also accepts dense arrays and
        # keeps a 1-D array for single-cell groups.
        group_counts = np.ravel(group_x.toarray() if hasattr(group_x, "toarray") else np.asarray(group_x))
        label = f"{treatment} R{replicate} T{time_point}"
        counts_list.append(group_counts)
        labels_list.append(label)
        # bw_method=1 is what seaborn (>= 0.11) maps the deprecated ``bw=1`` argument to.
        sns.kdeplot(group_counts, label=label, bw_method=1, ax=ax2)

    ax1.hist(counts_list, label=labels_list)

    for ax in (ax1, ax2):
        ax.set_xlabel("Count")
        ax.set_xlim(left=0)
        ax.legend(loc="upper left", bbox_to_anchor=(1, 1))

    ax1.set_ylabel("Cells")
    ax2.set_ylabel("Density (kde)")
    # KDE tails extend beyond the data; align the KDE panel with the histogram's range.
    ax2.set_xlim(right=ax1.get_xlim()[1])

    plt.tight_layout()
    plt.show()


all_genes = umi_counts_adata.var.symbol.sort_values()
all_replicates = umi_counts_adata.obs.replicate.sort_values().unique().tolist()
all_treatments = umi_counts_adata.obs.treatment.sort_values().unique().tolist()
all_time_points = umi_counts_adata.obs.time_point.sort_values().unique().tolist()
all_expression_types = sorted(expression_data_map.keys())

# Gene symbol to pre-select on first render. It is only requested when it is among the
# txburst-processed genes: ipywidgets rejects a label that is not one of the options
# (TraitError). Omitting it leaves the Select with its default selection.
default_gene_symbol = "Pfn1"
gene_selector_initial = (
    {"label": default_gene_symbol} if (all_genes == default_gene_symbol).any() else {}
)

gene_selector = widgets.Select(
    description="Gene:",
    # Display gene symbols; the selected value is the gene id (the var index).
    options=list(zip(all_genes.values, all_genes.index)),
    rows=4,
    **gene_selector_initial,
)
replicates_selector = widgets.SelectMultiple(
    description="Replicate:",
    options=all_replicates,
    value=all_replicates,
    rows=3,
)
treatments_selector = widgets.SelectMultiple(
    description="Treatments:",
    options=all_treatments,
    value=all_treatments,
    rows=3,
)
time_points_selector = widgets.SelectMultiple(
    description="Time points:",
    options=all_time_points,
    value=all_time_points,
    rows=len(all_time_points),
)
expression_type_selector = widgets.Select(
    description="Count type:",
    options=all_expression_types,
    value="UMI",
    rows=len(all_expression_types),
)

# One output area per tab of the explorer.
mv_plot_output = widgets.Output()
dist_plot_output = widgets.Output()
bp_output = widgets.Output()
bp_plot_output = widgets.Output()


def _select_condition_stats(stats_df, row_filter):
    """Filter per-condition count stats and left-join them with the txburst parameters.

    Adds ``k_burstiness`` (k_off / k_on) and ``m_burstiness`` (variance / mean) columns.
    """
    # Filter before joining. Only the selected gene/conditions are needed, and a left join
    # preserves left rows, so this matches joining the whole table and filtering after --
    # without re-joining every gene on each widget event.
    stats_subset = filter_df_rows(stats_df, row_filter)
    joined = (
        stats_subset.set_index(gene_index_columns)
        .join(txburst_params_df.set_index(gene_index_columns), how="left")
        .reset_index()
    )
    joined["k_burstiness"] = joined.k_off / joined.k_on
    joined["m_burstiness"] = joined["variance"] / joined["mean"]
    return joined


def _plot_mean_variance(condition_stats):
    """Scatter per-condition mean vs variance, with a linear regression line."""
    ax = sns.scatterplot(
        x="mean",
        y="variance",
        hue="replicate",
        style="treatment",
        size="time_point",
        data=condition_stats,
    )
    plt.xlim(left=0)
    plt.ylim(bottom=0)
    sns.regplot(
        x="mean",
        y="variance",
        truncate=False,
        scatter=False,
        data=condition_stats,
        ax=ax,
    )
    plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
    plt.show()


def _plot_burst_params(condition_stats, plot_params=("bs_point", "bf_point", "k_off", "k_syn", "k_burstiness")):
    """Scatter each burst parameter against the mean count, one panel per parameter."""
    _, axes = plt.subplots(1, len(plot_params), figsize=(len(plot_params) * 4, 4), squeeze=False)
    for plot_param, ax in zip(plot_params, axes.flat):
        sns.scatterplot(
            x="mean",
            y=plot_param,
            hue="time_point",
            style="replicate",
            data=condition_stats,
            ax=ax,
        )
        ax.set_xlim(left=0)
        ax.set_ylim(bottom=0)
        # Drop per-panel legends; one shared legend goes on the last panel below.
        ax_legend = ax.get_legend()
        if ax_legend is not None:
            ax_legend.remove()
    ax.legend(*ax.get_legend_handles_labels(), loc="upper left", bbox_to_anchor=(1, 1))
    plt.tight_layout()
    plt.show()


def update_ui():
    """Re-render every output tab for the current widget selections.

    Reads the gene, replicate/treatment/time-point and count-type selectors. It joins the
    selected gene's per-condition count statistics with its txburst parameters, then redraws
    the mean-variance plot, the count distributions, the burst-parameter table and the
    burst-parameter plots.
    """
    selected_gene_id = gene_selector.value
    selected_conditions = {
        "replicate": replicates_selector.value,
        "treatment": treatments_selector.value,
        "time_point": time_points_selector.value,
    }
    expression_data = expression_data_map[expression_type_selector.value]

    condition_stats_subset = _select_condition_stats(
        expression_data["stats"], {**selected_conditions, "gene": selected_gene_id}
    )

    for output in (mv_plot_output, dist_plot_output, bp_output, bp_plot_output):
        output.clear_output()

    if not condition_stats_subset.empty:
        with mv_plot_output:
            _plot_mean_variance(condition_stats_subset)
        with bp_plot_output:
            _plot_burst_params(condition_stats_subset)

    gene_expression_adata = expression_data["counts"][:, selected_gene_id]
    counts_obs_subset = filter_df_rows(gene_expression_adata.obs, selected_conditions)
    # NOTE(review): the "k_on": column slice relies on k_on being the first burst-parameter
    # column after the join. It keeps every later column too, including both burstiness ratios.
    txburst_params_subset = (
        condition_stats_subset.drop(columns="gene")
        .set_index(condition_index_columns)
        .loc[:, "k_on":]
    )

    with dist_plot_output:
        plot_expression_distributions(gene_expression_adata, counts_obs_subset, txburst_params_subset)

    with bp_output:
        display(txburst_params_subset)


def event_handler(event):
    """Widget observer: refresh the UI on ``value`` changes and ignore every other event."""
    is_value_change = event["type"] == "change" and event["name"] == "value"
    if is_value_change:
        update_ui()


# Re-render whenever any selector's value changes (event_handler ignores other traits anyway).
for selector in (
    gene_selector,
    replicates_selector,
    treatments_selector,
    time_points_selector,
    expression_type_selector,
):
    selector.observe(event_handler, names="value")

# One tab per output area, in display order.
tab_titles = [
    "Mean-variance plot",
    "Count distribution",
    "Burst parameters",
    "Burst parameter plots",
]
tab_widget = widgets.Tab()
tab_widget.children = [mv_plot_output, dist_plot_output, bp_output, bp_plot_output]
for tab_index, tab_title in enumerate(tab_titles):
    tab_widget.set_title(tab_index, tab_title)

# Selectors laid out in a row above the tabs.
selector_row = widgets.HBox([
    widgets.VBox([gene_selector, replicates_selector]),
    widgets.VBox([treatments_selector, time_points_selector]),
    expression_type_selector,
])
ui_container = widgets.VBox([selector_row, tab_widget])

display(ui_container)

# Draw the initial plots for the default selections.
update_ui()