In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display
from sklearn.linear_model import LinearRegression

import rp2
from rp2 import hagai_2018
from rp2.paths import get_txburst_results_csv_path

# Project-specific sanity check from the rp2 package; presumably verifies the data paths /
# environment this notebook relies on before anything is loaded (TODO confirm what it checks).
rp2.check_environment()

In [None]:
def load_txburst_params(species, index_columns, count_type):
    """Load txburst burst-kinetics estimates and annotate them with validity flags.

    Reads the CSV located by ``get_txburst_results_csv_path`` and adds:

    * ``valid_bf_point`` / ``valid_bs_point``: the point estimate is present (not NaN).
    * ``valid_bf_interval`` / ``valid_bs_interval``: both interval bounds are present.
    * ``valid_points``: both point estimates are present.
    * ``valid_intervals``: both intervals are present.
    * ``valid_params``: ``valid_points`` and ``valid_intervals``.
    * ``k_burstiness``: the ratio ``k_off / k_on``.

    The ``replicate`` column is cast to ``str``.

    Args:
        species: Species name, forwarded to the results-path helper.
        index_columns: Condition index columns, forwarded to the results-path helper.
        count_type: Count scaling (e.g. "median", "umi") selecting the results file.

    Returns:
        The annotated ``pandas.DataFrame``.
    """
    csv_path = get_txburst_results_csv_path(species, index_columns, count_type=count_type)
    params_df = pd.read_csv(csv_path)
    params_df["replicate"] = params_df.replicate.astype(str)

    for param in ("bf", "bs"):
        has_point = params_df[f"{param}_point"].notna()
        has_interval = params_df[f"{param}_lower"].notna() & params_df[f"{param}_upper"].notna()
        params_df[f"valid_{param}_point"] = has_point
        params_df[f"valid_{param}_interval"] = has_interval

    params_df["valid_points"] = params_df["valid_bs_point"] & params_df["valid_bf_point"]
    params_df["valid_intervals"] = params_df["valid_bs_interval"] & params_df["valid_bf_interval"]
    params_df["valid_params"] = params_df["valid_points"] & params_df["valid_intervals"]
    params_df["k_burstiness"] = params_df.k_off / params_df.k_on

    return params_df


notebook_species = "mouse"
notebook_count_type = "median"  # count scaling used for the main analysis
compare_count_types = ["umi"]  # extra scalings loaded only for comparison plots
condition_index_columns = ["replicate", "treatment", "time_point"]

# Burst-parameter tables keyed by count type, main count type first. Built with dict() over
# a generator so no loop variable is left behind in the notebook namespace.
txburst_results_map = dict(
    (count_type, load_txburst_params(notebook_species, condition_index_columns, count_type))
    for count_type in [notebook_count_type, *compare_count_types]
)

txburst_params_df = txburst_results_map[notebook_count_type]

In [None]:
# Load counts with the same scaling as the main txburst results.
counts_adata = hagai_2018.load_counts(notebook_species, scaling=notebook_count_type)
# Keep only genes that have txburst estimates. The .copy() gives an independent object rather
# than a view of the full dataset (assumes AnnData indexing semantics — TODO confirm).
counts_adata = counts_adata[:, txburst_params_df.gene.unique()].copy()
# Per-condition count statistics (mean, variance, ... — exact columns defined by the project helper).
counts_stats_df = hagai_2018.calculate_counts_condition_stats(counts_adata)

# Gene symbols indexed like counts_adata.var; used later to label per-gene tables.
gene_symbols = counts_adata.var.symbol

In [None]:
# Key identifying one gene in one experimental condition.
gene_index_columns = ["gene"] + condition_index_columns

# One row per (gene, condition): txburst parameters joined with the per-condition count statistics.
# DataFrame.join defaults to a left join, so conditions absent from counts_stats_df get NaN statistics.
# NOTE(review): load_txburst_params casts `replicate` to str; confirm counts_stats_df uses the same
# dtype, otherwise the join keys will not match.
condition_info_df = txburst_params_df.set_index(gene_index_columns).join(counts_stats_df.set_index(gene_index_columns)).reset_index()

# Sorted distinct treatments; used to lay out one panel per treatment.
all_treatments = condition_info_df.treatment.sort_values().unique().tolist()

In [None]:
def calculate_linear_regression(df, x_var, y_var):
    """Fit an ordinary least-squares line ``y_var ~ x_var`` over the rows of ``df``.

    Rows where either column is missing are ignored; scikit-learn rejects NaN input, and
    upstream filters (e.g. ``valid_intervals``) do not guarantee both values are present.
    With fewer than two complete rows a line is not identifiable, so every statistic is
    returned as NaN instead of raising or reporting a degenerate fit.

    Args:
        df: DataFrame containing both columns.
        x_var: Name of the predictor column.
        y_var: Name of the response column.

    Returns:
        dict with keys "slope", "intercept" and "r2" (coefficient of determination of the
        fit, evaluated on the same rows).
    """
    complete = df.loc[:, [x_var, y_var]].dropna()
    if len(complete) < 2:
        return {"slope": np.nan, "intercept": np.nan, "r2": np.nan}

    # scikit-learn wants a 2-D feature matrix of shape (n_samples, 1) and a 1-D target.
    # Positional selection keeps this correct even if x_var and y_var name the same column.
    lr_x = complete.iloc[:, [0]].to_numpy(dtype=float)
    lr_y = complete.iloc[:, 1].to_numpy(dtype=float)
    lr = LinearRegression().fit(lr_x, lr_y)

    return {
        "slope": float(lr.coef_[0]),
        "intercept": float(lr.intercept_),
        "r2": lr.score(lr_x, lr_y),
    }


def create_gene_regression(condition_df):
    """Regress per-condition ``variance`` on ``mean`` for one gene's rows.

    Returns the regression statistics as a ``pandas.Series`` so that
    ``groupby(...).apply`` assembles them into DataFrame columns.
    """
    return pd.Series(calculate_linear_regression(condition_df, x_var="mean", y_var="variance"))


def create_gene_info(treatments):
    """Per-gene variance-vs-mean regressions over the conditions of the given treatments.

    Args:
        treatments: Treatment labels to keep from ``condition_info_df``.

    Returns:
        DataFrame indexed by gene, one column per regression statistic.
    """
    in_group = condition_info_df["treatment"].isin(treatments)
    return condition_info_df[in_group].groupby("gene").apply(create_gene_regression)


# Named treatment subsets over which per-gene regressions are computed. Every group includes
# "unst" (presumably the unstimulated control — TODO confirm).
treatment_groups = {
    "LPS only": ["unst", "lps"],
    "PIC only": ["unst", "pic"],
    "LPS and PIC": ["unst", "lps", "pic"],
}

# Per-gene variance-vs-mean regression table for each treatment group, keyed by group name.
gene_info_map = {name: create_gene_info(treatments) for name, treatments in treatment_groups.items()}

In [None]:
def make_scale_selector(default="linear"):
    """Dropdown for an axis scale ("linear" or "log"), preset to ``default``."""
    scale_options = ["linear", "log"]
    return widgets.Dropdown(options=scale_options, value=default)


def make_treatment_group_selector(default="LPS and PIC"):
    """Dropdown over the names in ``treatment_groups``, preset to ``default``.

    Uses ``value=`` like the other selector factories rather than setting the ``label``
    trait. For plain string options the two coincide, but ``value`` is the documented
    way to choose the initial selection.

    Args:
        default: Treatment-group name selected initially; must be a key of ``treatment_groups``.
    """
    return widgets.Dropdown(
        options=list(treatment_groups),
        value=default,
    )


def make_condition_colour_selector(default="time_point"):
    """Dropdown choosing the condition column used as the plot hue; ``None`` disables colouring."""
    colour_columns = ["replicate", "treatment", "time_point", None]
    return widgets.Dropdown(options=colour_columns, value=default)


def make_valid_point_flag_selector():
    """Dropdown choosing which boolean validity column filters conditions before plotting."""
    flag_columns = ["valid_points", "valid_intervals"]
    return widgets.Dropdown(options=flag_columns)

In [None]:
@widgets.interact(treatment_group=make_treatment_group_selector(), valid_flag=make_valid_point_flag_selector(), y_scale=make_scale_selector("log"))
def plot_count_type_comparison(treatment_group, valid_flag, y_scale):
    """Grouped bar chart of how many genes have N valid conditions, per count type.

    For every count type in ``txburst_results_map``, conditions are restricted to the
    treatments of ``treatment_group``. For each gene, the conditions where ``valid_flag``
    is true are counted, and the distribution of those per-gene counts is drawn as bars.
    """
    treatments = treatment_groups[treatment_group]
    n_count_types = len(txburst_results_map)
    # Split a fixed 0.8-wide slot per x position between all count types, so bars never
    # overlap however many count types are compared, and each group is centred on its tick.
    bar_width = 0.8 / n_count_types

    _, ax = plt.subplots()
    for i, (count_type, burst_params_df) in enumerate(txburst_results_map.items()):
        burst_params_subset = burst_params_df.loc[burst_params_df.treatment.isin(treatments)]
        counts = burst_params_subset.groupby("gene")[valid_flag].agg(np.count_nonzero)
        count_frequencies = counts.value_counts().sort_index()
        offset = (i - (n_count_types - 1) / 2) * bar_width
        ax.bar(
            count_frequencies.index + offset,
            count_frequencies.values,
            width=bar_width,
            label=count_type,
        )
    ax.set_title(f"{treatment_group}: genes by number of conditions with {valid_flag}")
    ax.set_xlabel("No. conditions")
    ax.set_ylabel("No. genes")
    ax.set_yscale(y_scale)
    ax.legend(loc="upper right")
    plt.show()

In [None]:
@widgets.interact(y_scale=make_scale_selector())
def plot_per_gene_condition_frequencies(y_scale):
    """Per treatment, show how many of each gene's conditions pass two flags.

    For every treatment in ``all_treatments``, this counts per gene the conditions flagged
    ``keep`` and ``valid_intervals``, then plots how many genes reach each count.
    ``keep`` is not created in this notebook; presumably it comes from
    ``hagai_2018.calculate_counts_condition_stats`` (TODO confirm).
    """
    n_treatments = len(all_treatments)
    # squeeze=False always returns a 2-D array of Axes, so a single treatment still yields
    # something iterable for the loop below (a lone Axes object is not iterable).
    _, axes = plt.subplots(
        ncols=n_treatments, figsize=(n_treatments * 5, 5), sharey=True, squeeze=False
    )
    for treatment, ax in zip(all_treatments, axes.flat):
        treatment_conditions = condition_info_df.loc[condition_info_df.treatment == treatment]
        counts = (
            treatment_conditions.groupby("gene")[["keep", "valid_intervals"]]
            .agg(np.count_nonzero)
            .melt()
        )
        sns.countplot(
            x="value",
            hue="variable",
            ax=ax,
            data=counts,
        )
        ax.set_title(treatment)
        ax.set_xlabel("No. conditions")
        ax.set_ylabel("No. genes")
        ax.set_yscale(y_scale)
        ax.legend(loc="upper right")
    plt.tight_layout()
    plt.show()

In [None]:
@widgets.interact(valid_flag=make_valid_point_flag_selector(), y_scale=make_scale_selector("log"))
def plot_condition_fit_successes(valid_flag, y_scale):
    """Compare count statistics of conditions with a valid vs. invalid txburst fit.

    Draws one box plot per statistic, split into conditions where ``valid_flag`` is
    True and where it is False. Missing values are dropped so they do not turn the box
    statistics into NaN.
    """
    stat_columns = ["mean", "min", "max", "std_dev", "n_barcodes"]  # avoid shadowing builtin `vars`
    n_stats = len(stat_columns)
    is_valid = condition_info_df[valid_flag]
    _, axes = plt.subplots(1, n_stats, figsize=(4 * n_stats, 4))
    for column, ax in zip(stat_columns, axes):
        values = condition_info_df[column]
        ax.boxplot([values[is_valid].dropna(), values[~is_valid].dropna()])
        # Label the boxes separately: boxplot's `labels=` keyword was renamed `tick_labels`
        # in Matplotlib 3.9 (old name deprecated), and set_xticklabels works on all versions.
        ax.set_xticklabels(["True", "False"])
        ax.set_xlabel(valid_flag)
        ax.set_ylabel(column)
        ax.set_yscale(y_scale)
    plt.tight_layout()
    plt.show()

In [None]:
@widgets.interact(valid_flag=make_valid_point_flag_selector(), scale=make_scale_selector())
def plot_burst_size_vs_frequency(valid_flag, scale):
    """Joint scatter of burst size vs. burst frequency with marginal histograms.

    Only conditions where ``valid_flag`` is true are shown. On a log scale the histogram
    bins are geometrically spaced over the positive values, because ``np.geomspace``
    cannot span zero or negative values. If no usable range exists (no data, or a single
    distinct value), Matplotlib chooses the bins.
    """
    x_var = "bs_point"
    y_var = "bf_point"
    n_bins = 40
    use_log = scale == "log"

    def _bin_edges(values):
        """Histogram bin edges for ``values``, or ``n_bins`` when no valid range exists."""
        values = values.dropna()
        if use_log:
            # Values outside the edges are simply not counted by the histogram.
            values = values[values > 0]
        if values.empty or values.min() == values.max():
            return n_bins
        spacing = np.geomspace if use_log else np.linspace
        return spacing(values.min(), values.max(), n_bins)

    condition_info_subset = condition_info_df.loc[condition_info_df[valid_flag]]
    x_bins = _bin_edges(condition_info_subset[x_var])
    y_bins = _bin_edges(condition_info_subset[y_var])

    grid = sns.JointGrid(
        x=x_var,
        y=y_var,
        data=condition_info_subset,
    )
    grid.plot_joint(plt.scatter, color="black", edgecolor="black")
    grid.ax_marg_x.hist(condition_info_subset[x_var], bins=x_bins)
    grid.ax_marg_y.hist(condition_info_subset[y_var], bins=y_bins, orientation="horizontal")
    grid.ax_joint.set_xscale(scale)
    grid.ax_joint.set_yscale(scale)
    grid.ax_marg_x.set_xscale(scale)
    grid.ax_marg_y.set_yscale(scale)
    plt.show()

In [None]:
@widgets.interact(valid_flag=make_valid_point_flag_selector(), colour_by=make_condition_colour_selector())
def plot_burst_param_pairs(valid_flag, colour_by):
    """Pair plot of the kinetic rates and burst parameters over conditions passing ``valid_flag``.

    ``colour_by`` names the condition column used as hue; ``None`` disables colouring.
    """
    sns.pairplot(
        vars=["k_on", "k_off", "k_syn", "bs_point", "bf_point"],
        hue=colour_by,
        data=condition_info_df.loc[condition_info_df[valid_flag]],
    )
    # Explicit show, as in the other plotting callbacks, so the figure renders in this widget's output.
    plt.show()

In [None]:
@widgets.interact(valid_flag=make_valid_point_flag_selector(), colour_by=make_condition_colour_selector())
def plot_condition_params_vs_burst_params(valid_flag, colour_by):
    """Plot burst size/frequency against per-condition count statistics and burstiness.

    Rows are conditions passing ``valid_flag``; ``colour_by`` names the hue column
    (``None`` disables colouring).
    """
    sns.pairplot(
        x_vars=["mean", "variance", "k_burstiness"],
        y_vars=["bs_point", "bf_point"],
        hue=colour_by,
        data=condition_info_df.loc[condition_info_df[valid_flag]],
    )
    # Explicit show, as in the other plotting callbacks, so the figure renders in this widget's output.
    plt.show()

In [None]:
@widgets.interact(treatment_group=make_treatment_group_selector(), y_scale=make_scale_selector("log"))
def plot_regression_histograms(treatment_group, y_scale):
    """Histogram of each per-gene regression statistic for the chosen treatment group."""
    gene_info_df = gene_info_map[treatment_group]
    n_columns = len(gene_info_df.columns)
    # squeeze=False keeps `axes` an array even when there is only one statistic column.
    _, axes = plt.subplots(1, n_columns, figsize=(4 * n_columns, 4), squeeze=False)
    for column, ax in zip(gene_info_df.columns, axes.flat):
        ax.hist(gene_info_df[column].to_numpy(), bins=30)
        ax.set_xlabel(column)
        ax.set_yscale(y_scale)
        if column == "intercept":
            # Tilt the intercept panel's tick labels (presumably because they overlap).
            # The panel is selected by name rather than by position in the column order.
            for label in ax.get_xticklabels():
                label.set_rotation(20)
                label.set_ha("right")
    plt.tight_layout()
    plt.show()

In [None]:
@widgets.interact(treatment_group=make_treatment_group_selector(), sort_by=["slope", "intercept", "r2"], ascending=True)
def display_regression_list(treatment_group, sort_by, ascending):
    """Display a treatment group's per-gene regression table, with gene symbols, sorted by ``sort_by``."""
    regression_table = gene_info_map[treatment_group].copy()
    symbols = gene_symbols[regression_table.index]
    regression_table.insert(0, "symbol", symbols)
    sorted_table = regression_table.sort_values(by=sort_by, ascending=ascending)
    display(sorted_table)

In [None]:
def create_modulation_stats(condition_df):
    """Summarise how one gene's burst parameters vary across its conditions.

    Args:
        condition_df: One gene's rows, with ``bs_point``, ``bf_point`` and ``mean`` columns.

    Returns:
        Series with the number of conditions, the coefficient of variation
        (sample std with ddof=1, divided by the mean) of burst size and burst frequency,
        and the slopes of burst size and burst frequency regressed on ``mean``.
    """
    def _coefficient_of_variation(values):
        return np.std(values, ddof=1) / np.mean(values)

    size_cv = _coefficient_of_variation(condition_df.bs_point)
    frequency_cv = _coefficient_of_variation(condition_df.bf_point)

    size_fit = calculate_linear_regression(condition_df, "mean", "bs_point")
    frequency_fit = calculate_linear_regression(condition_df, "mean", "bf_point")

    stats = {
        "n_conditions": len(condition_df),
        "bs_cv": size_cv,
        "bf_cv": frequency_cv,
        "bs_slope": size_fit["slope"],
        "bf_slope": frequency_fit["slope"],
    }
    return pd.Series(data=stats)


@widgets.interact(treatment_group=make_treatment_group_selector(), valid_flag=make_valid_point_flag_selector())
def plot_slope_vs_cv(treatment_group, valid_flag):
    """Compare per-gene variability of burst frequency vs. burst size.

    Restricts to the treatments of ``treatment_group`` and to conditions where
    ``valid_flag`` is true, and keeps genes with at least ``min_valid`` such conditions.
    It then:

    1. scatters burst-size CV against burst-frequency CV with a y = x reference line,
    2. prints how many genes have a CV ratio (bf_cv / bs_cv) above 1,
    3. pair-plots the CVs and burst-parameter slopes against the variance-vs-mean slope,
    4. displays up to 20 genes with ratio > 1, largest ratio first.
    """
    min_valid = 9  # minimum valid conditions per gene for its CV / slope estimates

    treatments = treatment_groups[treatment_group]
    condition_info_subset = condition_info_df.loc[condition_info_df[valid_flag] & condition_info_df.treatment.isin(treatments)]

    condition_counts = condition_info_subset.gene.value_counts()
    accepted_genes = condition_counts[condition_counts >= min_valid]

    condition_info_subset = condition_info_subset.loc[condition_info_subset.gene.isin(accepted_genes.index)]

    if condition_info_subset.empty:
        # With no qualifying gene, create_modulation_stats is never called, so the
        # bs_cv / bf_cv columns used below would not exist. Report this instead of crashing.
        print(f"No genes with at least {min_valid} {valid_flag} conditions in {treatment_group!r}.")
        return

    modulation_stats_df = condition_info_subset.groupby("gene").apply(create_modulation_stats)
    modulation_stats_df["cv_ratio"] = modulation_stats_df.bf_cv / modulation_stats_df.bs_cv

    sns.scatterplot(
        x="bs_cv",
        y="bf_cv",
        data=modulation_stats_df,
    )
    # y = x reference line: genes above it vary more in frequency than in size.
    ab_point = [0, modulation_stats_df.loc[:, ["bs_cv", "bf_cv"]].max().max()]
    plt.plot(ab_point, ab_point, ":")
    plt.show()

    over_line = modulation_stats_df.cv_ratio > 1
    print(f"{np.count_nonzero(over_line.values):,} with ratio > 1 (versus {len(over_line) - np.count_nonzero(over_line.values)})")

    gene_info_df = gene_info_map[treatment_group]
    modulation_stats_df = modulation_stats_df.join(gene_info_df)

    sns.pairplot(
        vars=["slope", "bs_cv", "bf_cv", "bs_slope", "bf_slope"],
        data=modulation_stats_df,
    )
    plt.show()

    modulation_stats_df.insert(0, "symbol", gene_symbols[modulation_stats_df.index])
    display(modulation_stats_df.loc[over_line].sort_values(by="cv_ratio", ascending=False)[:20])

In [None]:
@widgets.interact(treatment_group=make_treatment_group_selector(), valid_flag=make_valid_point_flag_selector())
def plot_mean_and_slope(treatment_group, valid_flag):
    """Pair plot of per-condition statistics and burst parameters against each gene's slope.

    Only conditions from the treatments of ``treatment_group`` are included. These are the
    same conditions the per-gene variance-vs-mean slope in ``gene_info_map`` was fitted on
    (``create_gene_info`` filters by the same treatments), so rows from other treatments
    are not mixed in.
    """
    gene_info_df = gene_info_map[treatment_group]
    treatments = treatment_groups[treatment_group]
    group_conditions = condition_info_df.loc[condition_info_df.treatment.isin(treatments)]

    joined_df = gene_info_df.reset_index("gene").merge(group_conditions, on="gene")

    sns.pairplot(
        vars=["mean", "variance", "slope", "bs_point", "bf_point"],
        data=joined_df.loc[joined_df[valid_flag]],
    )
    plt.show()