In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.linear_model import LinearRegression

import rp2
from rp2 import hagai_2018
from rp2.paths import get_txburst_results_csv_path

# Run the rp2 package's environment check before any analysis cell executes
# (what exactly it verifies is defined inside rp2, not in this notebook).
rp2.check_environment()

In [None]:
def load_txburst_params(species):
    """Load the txburst results table for ``species`` and add derived columns.

    The CSV located by ``get_txburst_results_csv_path(species)`` is read into a
    DataFrame, its ``replicate`` column is converted to strings, and the
    following columns are appended:

    * ``valid_{bf,bs}_point`` -- the matching ``*_point`` value is not null.
    * ``valid_{bf,bs}_interval`` -- both ``*_lower`` and ``*_upper`` are not null.
    * ``valid_points`` / ``valid_intervals`` -- the flag holds for ``bs`` and ``bf``.
    * ``valid_params`` -- ``valid_points`` and ``valid_intervals`` both hold.
    * ``k_burstiness`` -- ``k_off`` divided by ``k_on``.

    Returns the augmented DataFrame.
    """
    table = pd.read_csv(get_txburst_results_csv_path(species))
    table["replicate"] = table.replicate.astype(str)

    for prefix in ("bf", "bs"):
        point_present = table[f"{prefix}_point"].notna()
        bounds_present = table[f"{prefix}_lower"].notna() & table[f"{prefix}_upper"].notna()
        table[f"valid_{prefix}_point"] = point_present
        table[f"valid_{prefix}_interval"] = bounds_present

    both_points = table["valid_bs_point"] & table["valid_bf_point"]
    both_intervals = table["valid_bs_interval"] & table["valid_bf_interval"]
    table["valid_points"] = both_points
    table["valid_intervals"] = both_intervals
    table["valid_params"] = both_points & both_intervals
    table["k_burstiness"] = table["k_off"] / table["k_on"]

    return table


# Species whose txburst results and UMI counts this notebook analyses.
notebook_species = "mouse"

# Burst-parameter table (with the validity flags added by load_txburst_params)
# that the remaining cells build on.
txburst_params_df = load_txburst_params(notebook_species)

In [None]:
# Load the UMI counts for the notebook species (loader lives in rp2.hagai_2018).
umi_counts_adata = hagai_2018.load_umi_counts_with_additional_annotation(notebook_species)
# Keep only the entries along the second axis that are named in the txburst
# "gene" column (presumably genes of an AnnData object -- confirm), and copy
# the selection.
umi_counts_adata = umi_counts_adata[:, txburst_params_df.gene.unique()].copy()
# Statistics computed from the subsetted counts; a later cell joins them to the
# txburst table on gene / replicate / treatment / time_point.
umi_stats_df = hagai_2018.calculate_umi_condition_stats(umi_counts_adata)

# "symbol" column of the .var table; used further down to label genes.
gene_symbols = umi_counts_adata.var.symbol

In [None]:
# Key columns shared by the txburst parameter table and the UMI statistics table.
index_columns = ["gene", "replicate", "treatment", "time_point"]

# Attach the UMI statistics to every txburst row by joining on the key columns
# (DataFrame.join keeps all rows of the left-hand table), then turn the keys
# back into ordinary columns.
condition_info_df = (
    txburst_params_df
    .set_index(index_columns)
    .join(umi_stats_df.set_index(index_columns))
    .reset_index()
)

# Distinct treatment labels in sorted order.
all_treatments = condition_info_df["treatment"].sort_values().unique().tolist()

In [None]:
def create_gene_regression(condition_df):
    """Fit a linear regression of ``variance`` on ``mean`` for one set of rows.

    Parameters
    ----------
    condition_df : DataFrame with ``mean`` and ``variance`` columns.

    Returns
    -------
    pandas.Series with the fitted ``slope``, ``intercept`` and the ``r2``
    score of the fit evaluated on the same rows.
    """
    mean_variance = condition_df.loc[:, ["mean", "variance"]].to_numpy()
    # Keep both arrays two-dimensional, shape (n_rows, 1), so that coef_ and
    # intercept_ each hold exactly one element for .item().
    means = mean_variance[:, :1]
    variances = mean_variance[:, 1:]

    model = LinearRegression()
    model.fit(means, variances)

    summary = {
        "slope": model.coef_.item(),
        "intercept": model.intercept_.item(),
        "r2": model.score(means, variances),
    }
    return pd.Series(data=summary)


def create_gene_info(treatments):
    """Build a per-gene mean/variance regression table for a set of treatments.

    Parameters
    ----------
    treatments : list-like of treatment labels; only rows of
        ``condition_info_df`` whose ``treatment`` is in this collection are used.

    Returns
    -------
    DataFrame indexed by ``gene`` with the columns produced by
    ``create_gene_regression`` (``slope``, ``intercept``, ``r2``).
    """
    condition_info_subset = condition_info_df.loc[condition_info_df.treatment.isin(treatments)]
    # Select the two columns the regression reads before calling apply, so the
    # callback is not run over the grouping column (deprecated behaviour that
    # warns on pandas 2.2+); the result is unchanged.
    regression_inputs = condition_info_subset.groupby("gene")[["mean", "variance"]]
    gene_info_df = regression_inputs.apply(create_gene_regression)
    return gene_info_df


# Named treatment subsets; every group includes the "unst" rows.
treatment_groups = {
    "LPS only": ["unst", "lps"],
    "PIC only": ["unst", "pic"],
    "LPS and PIC": ["unst", "lps", "pic"],
}

# One per-gene regression table per treatment group, keyed by the group name.
gene_info_map = {
    group_name: create_gene_info(group_treatments)
    for group_name, group_treatments in treatment_groups.items()
}

In [None]:
@widgets.interact(y_scale=["linear", "log"])
def plot_per_gene_condition_frequencies(y_scale):
    """Plot, per treatment, how many genes have a given number of flagged conditions.

    For every treatment one panel is drawn. Within a panel the rows of
    ``condition_info_df`` are grouped by gene, the non-zero entries of the
    ``keep`` and ``valid_intervals`` columns are counted per gene, and the
    distribution of those counts is shown as a count plot (one hue per flag).

    Parameters
    ----------
    y_scale : scale name passed to ``Axes.set_yscale`` (``"linear"`` or ``"log"``).
    """
    n_treatments = len(all_treatments)
    # squeeze=False always returns a 2-D array of Axes, so iterating also works
    # when there is only a single treatment (otherwise a bare Axes is returned).
    _, axes = plt.subplots(
        ncols=n_treatments,
        figsize=(n_treatments * 5, 5),
        sharey=True,
        squeeze=False,
    )
    for treatment, ax in zip(all_treatments, axes.ravel()):
        treatment_rows = condition_info_df.loc[condition_info_df.treatment == treatment]
        counts = (
            treatment_rows
            .groupby("gene")[["keep", "valid_intervals"]]
            .agg(np.count_nonzero)
            .melt()
        )
        sns.countplot(
            x="value",
            hue="variable",
            ax=ax,
            data=counts,
        )
        ax.set_title(treatment)
        ax.set_xlabel("No. conditions")
        ax.set_ylabel("No. genes")
        ax.set_yscale(y_scale)
        ax.legend(loc="upper right")
    plt.tight_layout()
    plt.show()

In [None]:
@widgets.interact(
    valid_flag=["valid_points", "valid_intervals"],
    # "options" was previously misspelled, leaving the combobox without any
    # choices; ensure_option restricts the value to the listed scales so that
    # arbitrary typed text never reaches Axes.set_yscale.
    y_scale=widgets.Combobox(options=["linear", "log"], value="log", ensure_option=True),
)
def plot_condition_fit_successes(valid_flag, y_scale):
    """Compare condition statistics between rows that pass and fail a validity flag.

    One box-plot panel is drawn per statistic (``mean``, ``min``, ``max``,
    ``std_dev``, ``n_barcodes``); each panel shows the rows of
    ``condition_info_df`` where ``valid_flag`` is true next to those where it
    is false.

    Parameters
    ----------
    valid_flag : name of the boolean column used to split the rows.
    y_scale : scale name passed to ``Axes.set_yscale``.
    """
    # Named stat_columns rather than "vars" to avoid shadowing the builtin.
    stat_columns = ["mean", "min", "max", "std_dev", "n_barcodes"]
    n_columns = len(stat_columns)
    is_valid = condition_info_df[valid_flag]
    _, axes = plt.subplots(1, n_columns, figsize=(4 * n_columns, 4))
    for column, ax in zip(stat_columns, axes):
        ax.boxplot(
            [condition_info_df.loc[is_valid, column],
             condition_info_df.loc[~is_valid, column]],
            labels=["True", "False"],
        )
        ax.set_ylabel(column)
        ax.set_yscale(y_scale)
    plt.tight_layout()
    plt.show()

In [None]:
@widgets.interact(valid_flag=["valid_points", "valid_intervals"])
def plot_burst_size_vs_frequency(valid_flag):
    """Draw a seaborn joint plot of ``bf_point`` (y) against ``bs_point`` (x).

    Only rows of ``condition_info_df`` whose ``valid_flag`` column is true are
    plotted.
    """
    selected_rows = condition_info_df.loc[condition_info_df[valid_flag]]
    sns.jointplot(data=selected_rows, x="bs_point", y="bf_point")
    plt.show()

In [None]:
@widgets.interact(
    valid_flag=["valid_points", "valid_intervals"],
    colour_by=["replicate", "treatment", "time_point", None],
)
def plot_burst_param_pairs(valid_flag, colour_by):
    """Pair plot of the burst parameters for rows passing ``valid_flag``.

    Parameters
    ----------
    valid_flag : name of the boolean column used to filter ``condition_info_df``.
    colour_by : column passed to seaborn as ``hue``; ``None`` disables colouring.
    """
    burst_columns = ["k_on", "k_off", "k_syn", "bs_point", "bf_point"]
    selected_rows = condition_info_df.loc[condition_info_df[valid_flag]]
    sns.pairplot(data=selected_rows, vars=burst_columns, hue=colour_by)

In [None]:
@widgets.interact(
    valid_flag=["valid_points", "valid_intervals"],
    colour_by=["replicate", "treatment", "time_point"],
)
def plot_condition_params_vs_burst_params(valid_flag, colour_by):
    """Plot ``bs_point`` / ``bf_point`` against ``mean`` and ``k_burstiness``.

    Parameters
    ----------
    valid_flag : name of the boolean column used to filter ``condition_info_df``.
    colour_by : column passed to seaborn as ``hue``.
    """
    selected_rows = condition_info_df.loc[condition_info_df[valid_flag]]
    sns.pairplot(
        data=selected_rows,
        x_vars=["mean", "k_burstiness"],
        y_vars=["bs_point", "bf_point"],
        hue=colour_by,
    )



In [None]:
@widgets.interact(treatment_group=gene_info_map.keys(), y_scale=["linear", "log"])
def plot_regression_histograms(treatment_group, y_scale):
    """Histogram every column of the regression table of one treatment group.

    Parameters
    ----------
    treatment_group : key into ``gene_info_map`` selecting the table to plot.
    y_scale : scale name passed to ``Axes.set_yscale``.
    """
    regression_df = gene_info_map[treatment_group]
    column_names = regression_df.columns
    n_panels = len(column_names)
    _, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4))

    for ax, name in zip(axes, column_names):
        ax.hist(regression_df[name].values, bins=30)
        ax.set_xlabel(name)
        ax.set_yscale(y_scale)

    # Only the second panel gets tilted, right-aligned tick labels.
    for tick_label in axes[1].get_xticklabels():
        tick_label.set_rotation(20)
        tick_label.set_ha("right")

    plt.tight_layout()
    plt.show()

In [None]:
@widgets.interact(
    treatment_group=gene_info_map.keys(),
    sort_by=["slope", "intercept", "r2"],
    ascending=True,
)
def display_regression_list(treatment_group, sort_by, ascending):
    """Display the regression table of one treatment group, sorted by a column.

    A ``symbol`` column, looked up from ``gene_symbols`` by the table's index,
    is placed first. The stored table in ``gene_info_map`` is left untouched
    because a copy is modified.

    Parameters
    ----------
    treatment_group : key into ``gene_info_map``.
    sort_by : column to sort on.
    ascending : sort direction flag passed to ``DataFrame.sort_values``.
    """
    table = gene_info_map[treatment_group].copy()
    symbols = gene_symbols[table.index]
    table.insert(0, "symbol", symbols)
    ordered = table.sort_values(by=sort_by, ascending=ascending)
    display(ordered)