In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import rp2
from rp2 import hagai_2018
from rp2.paths import get_txburst_results_csv_path

# Run the project's environment check before doing any work.
# NOTE(review): presumably this verifies that required packages/data are
# available — confirm against the rp2 package.
rp2.check_environment()

In [None]:
species = "mouse"

# Read the burst-parameter results for the chosen species and store the
# replicate identifier as a string.
txburst_params_df = pd.read_csv(get_txburst_results_csv_path(species))
txburst_params_df["replicate"] = txburst_params_df["replicate"].astype(str)

# For burst frequency ("bf") and burst size ("bs"), flag the rows that have a
# point estimate and the rows that have both interval bounds.
for param in ("bf", "bs"):
    txburst_params_df[f"valid_{param}_point"] = txburst_params_df[f"{param}_point"].notna()
    txburst_params_df[f"valid_{param}_interval"] = (
        txburst_params_df[[f"{param}_lower", f"{param}_upper"]].notna().all(axis=1)
    )

# Combined flags: both point estimates present, both intervals present, and
# finally everything present.
txburst_params_df["valid_points"] = (
    txburst_params_df["valid_bs_point"] & txburst_params_df["valid_bf_point"]
)
txburst_params_df["valid_intervals"] = (
    txburst_params_df["valid_bs_interval"] & txburst_params_df["valid_bf_interval"]
)
txburst_params_df["valid_params"] = (
    txburst_params_df["valid_points"] & txburst_params_df["valid_intervals"]
)

# Ratio of the off-rate to the on-rate.
txburst_params_df["k_burstiness"] = txburst_params_df["k_off"].div(txburst_params_df["k_on"])

In [None]:
# Load the annotated UMI counts for `species` and compute per-condition
# statistics from them. `umi_stats_df` is joined to the burst parameters on
# gene/replicate/treatment/time_point in a later cell.
# NOTE(review): the `_adata` suffix suggests an AnnData object — confirm
# against rp2.hagai_2018.
umi_counts_adata = hagai_2018.load_umi_counts_with_additional_annotation(species)
umi_stats_df = hagai_2018.calculate_umi_condition_stats(umi_counts_adata)

In [None]:
index_columns = ["gene", "replicate", "treatment", "time_point"]

# Attach the UMI statistics to every burst-parameter row. Both frames are
# indexed by the same key columns; DataFrame.join defaults to a left join, so
# all burst-parameter rows are kept. The keys are restored as ordinary
# columns afterwards.
condition_info_df = (
    txburst_params_df
    .set_index(index_columns)
    .join(umi_stats_df.set_index(index_columns))
    .reset_index()
)

# Sorted list of the distinct treatments, used to lay out per-treatment plots.
all_treatments = condition_info_df["treatment"].sort_values().unique().tolist()

In [None]:
@widgets.interact(y_scale=["linear", "log"])
def plot_per_gene_condition_frequency(y_scale):
    """Plot, per treatment, how many genes have a given number of flagged conditions.

    For each treatment in ``all_treatments`` the rows of ``condition_info_df``
    are grouped by gene and the non-zero entries of the ``keep`` and
    ``valid_intervals`` columns are counted. One panel per treatment shows the
    distribution of those per-gene counts, with one hue per flag.

    Parameters
    ----------
    y_scale : str
        Scale passed to ``Axes.set_yscale`` for every panel ("linear" or "log").
    """
    n_treatments = len(all_treatments)
    if n_treatments == 0:
        # plt.subplots cannot build a grid with zero columns.
        print("No treatments available to plot.")
        return

    # squeeze=False always returns a 2-D array of Axes; without it a single
    # treatment yields a bare Axes object that cannot be zipped over.
    _, axes = plt.subplots(
        ncols=n_treatments,
        figsize=(n_treatments * 5, 5),
        sharey=True,
        squeeze=False,
    )
    for treatment, ax in zip(all_treatments, axes.flat):
        treatment_rows = condition_info_df.loc[condition_info_df.treatment == treatment]
        # melt() turns the two per-gene count columns into long form with the
        # default "variable"/"value" column names used below.
        counts = (
            treatment_rows
            .groupby("gene")[["keep", "valid_intervals"]]
            .agg(np.count_nonzero)
            .melt()
        )
        sns.countplot(
            x="value",
            hue="variable",
            ax=ax,
            data=counts,
        )
        ax.set_title(treatment)
        ax.set_xlabel("No. conditions")
        ax.set_ylabel("No. genes")
        ax.set_yscale(y_scale)
        ax.legend(loc="upper right")
    plt.tight_layout()
    plt.show()

In [None]:
@widgets.interact(
    valid_flag=["valid_points", "valid_intervals"],
    # ensure_option=True restricts the value to the listed options, so partially
    # typed text is never forwarded to set_yscale.
    y_scale=widgets.Combobox(options=["linear", "log"], value="log", ensure_option=True),
)
def plot_(valid_flag, y_scale):
    """Compare UMI summary statistics between rows with and without a validity flag.

    Draws one box plot panel per statistic column of ``condition_info_df``
    (``mean``, ``min``, ``max``, ``std_dev``, ``n_barcodes``). Each panel has
    two boxes: rows where ``valid_flag`` is True and rows where it is False.

    NOTE(review): a later cell in this notebook defines another function named
    ``plot_``, which rebinds this name once that cell runs.

    Parameters
    ----------
    valid_flag : str
        Name of the boolean column used to split the rows.
    y_scale : str
        Scale passed to ``Axes.set_yscale`` for every panel ("linear" or "log").
    """
    stat_columns = ["mean", "min", "max", "std_dev", "n_barcodes"]
    n_columns = len(stat_columns)
    is_valid = condition_info_df[valid_flag]

    _, axes = plt.subplots(1, n_columns, figsize=(4 * n_columns, 4))
    for column, ax in zip(stat_columns, axes):
        ax.boxplot(
            [
                # Drop missing values so that NaNs do not turn the box
                # statistics themselves into NaN.
                condition_info_df.loc[is_valid, column].dropna(),
                condition_info_df.loc[~is_valid, column].dropna(),
            ],
            labels=["True", "False"],
        )
        ax.set_ylabel(column)
        ax.set_yscale(y_scale)
    plt.tight_layout()
    plt.show()

In [None]:
@widgets.interact(
    valid_flag=["valid_points", "valid_intervals"],
)
def plot_burst_size_vs_frequency(valid_flag):
    """Show a joint plot of burst size against burst frequency point estimates.

    Only the rows of ``condition_info_df`` whose ``valid_flag`` column is True
    are plotted.

    Parameters
    ----------
    valid_flag : str
        Name of the boolean column used to select rows.
    """
    row_mask = condition_info_df[valid_flag]
    selected = condition_info_df.loc[row_mask]
    sns.jointplot(data=selected, x="bs_point", y="bf_point")
    plt.show()

In [None]:
@widgets.interact(valid_flag=["valid_points", "valid_intervals"], colour_by=["replicate", "treatment", "time_point", None])
def plot_burst_param_pairs(valid_flag, colour_by):
    """Show pairwise relationships between the burst parameters.

    Plots ``k_on``, ``k_off``, ``k_syn``, ``bs_point`` and ``bf_point`` against
    each other for the rows of ``condition_info_df`` whose ``valid_flag``
    column is True.

    Parameters
    ----------
    valid_flag : str
        Name of the boolean column used to select rows.
    colour_by : str or None
        Column passed to seaborn as ``hue``; ``None`` disables colouring.
    """
    sns.pairplot(
        vars=["k_on", "k_off", "k_syn", "bs_point", "bf_point"],
        hue=colour_by,
        data=condition_info_df.loc[condition_info_df[valid_flag]],
    )
    # Explicitly render the figure, matching the other interactive plot
    # functions in this notebook.
    plt.show()

In [None]:
@widgets.interact(valid_flag=["valid_points", "valid_intervals"], colour_by=["replicate", "treatment", "time_point"])
def plot_(valid_flag, colour_by):
    """Plot burst size and frequency against mean expression and burstiness.

    Draws a grid with ``mean`` and ``k_burstiness`` on the x axes and
    ``bs_point`` and ``bf_point`` on the y axes, using the rows of
    ``condition_info_df`` whose ``valid_flag`` column is True.

    NOTE(review): an earlier cell in this notebook also defines a function
    named ``plot_``; this definition rebinds that name.

    Parameters
    ----------
    valid_flag : str
        Name of the boolean column used to select rows.
    colour_by : str
        Column passed to seaborn as ``hue``.
    """
    sns.pairplot(
        x_vars=["mean", "k_burstiness"],
        y_vars=["bs_point", "bf_point"],
        hue=colour_by,
        data=condition_info_df.loc[condition_info_df[valid_flag]],
    )
    # Explicitly render the figure, matching the other interactive plot
    # functions in this notebook.
    plt.show()
