In [None]:
import ipywidgets as widgets
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display

import rp2.data
from rp2 import hagai_2018


# Run rp2's environment check before any data is loaded by the cells below.
# NOTE(review): what exactly is validated (data paths, package versions, ...) lives in
# rp2 and is not visible here — confirm.
rp2.check_environment()

In [None]:
# Analysis settings: species and count type passed to every loader below.
species, count_type = "mouse", "median"

# Columns that together identify one experimental condition ...
index_columns = ["replicate", "treatment", "time_point"]
# ... and the per-gene key (gene + condition) used to align the models' results.
gene_index_columns = ["gene", *index_columns]

In [None]:
# Gene-symbol table for the chosen species, loaded via rp2's BioMart helper.
# Later cells look up its `symbol` column by gene id (so it is presumably indexed by
# gene id — confirm) and show the symbols as labels in the gene picker.
gene_symbol_df = rp2.load_biomart_gene_symbols_df(species)

In [None]:
def load_ppfit_results(species, index_columns, count_type):
    """Load the ppfit burst-model fits and append derived point estimates.

    Parameters
    ----------
    species, count_type : str
        Forwarded to ``rp2.paths.get_burst_model_csv_path`` to locate the CSV.
    index_columns : list of str
        Condition columns; forwarded to the path helper and read as ``category``.

    Returns
    -------
    pandas.DataFrame
        The CSV contents plus three computed columns:
        ``bs_point = k_syn / k_off``,
        ``bf_point = 2 * k_on * k_off / (k_on + k_off) / k_deg`` and
        ``bf_point_2 = 2 * k_on / k_deg``.
        Presumably burst size (bs) and burst frequency (bf) estimates — TODO confirm
        the factor-of-2 convention against ppfit's parameterisation. Zero
        denominators give inf/NaN (pandas division) rather than raising.
    """
    results_path = rp2.paths.get_burst_model_csv_path("ppfit", species, index_columns, count_type)
    fits = pd.read_csv(results_path, dtype=dict.fromkeys(index_columns, "category"))

    return fits.assign(
        bs_point=fits.k_syn / fits.k_off,
        bf_point=(2 * fits.k_on * fits.k_off) / (fits.k_on + fits.k_off) / fits.k_deg,
        bf_point_2=(2 * fits.k_on) / fits.k_deg,
    )


def calculate_condition_stats(species, count_type, genes):
    """Compute per-condition count statistics for a subset of genes.

    Loads the ``count_type`` counts for ``species`` with ``hagai_2018.load_counts``,
    keeps only ``genes`` (applied as the second axis of ``[:, genes]``; this notebook
    passes gene ids) and summarises them with
    ``hagai_2018.calculate_counts_condition_stats``. The result is returned unchanged;
    later cells treat it as a DataFrame with gene/condition columns and a ``mean``
    column.
    """
    all_counts = hagai_2018.load_counts(species, count_type)
    selected_counts = all_counts[:, genes]
    return hagai_2018.calculate_counts_condition_stats(selected_counts)

In [None]:
# Per-model fit results keyed by model name. Insertion order (txburst, then ppfit)
# fixes the order in which models are summarised, joined and plotted below.
model_df_map = dict(
    txburst=rp2.data.load_txburst_results(species, index_columns, count_type),
    ppfit=load_ppfit_results(species, index_columns, count_type),
)

In [None]:
# Report how many condition rows and distinct genes each model's results cover.
for model_name, model_df in model_df_map.items():
    print(f'Model "{model_name}":')
    condition_count, gene_count = len(model_df), model_df.gene.nunique()
    print(f"  Contains {condition_count:,} conditions for {gene_count:,} genes")

In [None]:
# Genes fitted by every model, as a gene-id -> symbol Series ordered by symbol.
# The intersection is turned into a sorted list before indexing: pandas does not
# accept a `set` as an indexer (deprecated in 1.5, TypeError since 2.0). Sorting the
# ids plus a stable sort also keeps the order of genes sharing a symbol reproducible.
genes_in_all_models = set.intersection(*(set(df.gene) for df in model_df_map.values()))
common_genes = gene_symbol_df.symbol.loc[sorted(genes_in_all_models)].sort_values(kind="stable")
condition_stats_df = calculate_condition_stats(species, count_type, common_genes.index)

In [None]:
def copy_df_with_column_prefix(df, prefix, keep_columns=None):
    """Return a copy of ``df`` with non-key columns renamed to ``f"{prefix}_{column}"``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to copy; it is not modified.
    prefix : str
        Prefix (e.g. a model name) prepended to every non-key column.
    keep_columns : iterable of column labels, optional
        Columns left unrenamed so they can serve as join keys. Defaults to the
        notebook-level ``gene_index_columns`` (the previous hard-wired behaviour).

    Returns
    -------
    pandas.DataFrame
        A renamed copy of ``df`` with the same data and column order.
    """
    if keep_columns is None:
        keep_columns = gene_index_columns
    keep = set(keep_columns)

    prefixed = df.copy()
    prefixed.columns = [c if c in keep else f"{prefix}_{c}" for c in prefixed.columns]
    return prefixed


# One row per (gene, replicate, treatment, time_point): the count statistics
# followed by every model's results, each under a "<model>_" column prefix.
# DataFrame.join defaults to a left join, so every condition in condition_stats_df
# is kept and a model's columns are NaN where that model has no matching row.
model_info_df = (
    condition_stats_df
    .sort_values(by=["gene", "replicate", "time_point", "treatment"])
    .set_index(gene_index_columns)
)
for model_name, model_df in model_df_map.items():
    model_info_df = model_info_df.join(
        copy_df_with_column_prefix(model_df, model_name).set_index(gene_index_columns)
    )

# Flatten the key back into ordinary columns (reassign rather than inplace=True).
model_info_df = model_info_df.reset_index()

In [None]:
def plot_model(gene_id, of_cutoff):
    """Plot each model's point estimates against mean expression for one gene.

    Draws one scatter plot per point type ("bs", "bf") side by side in an HBox of
    Output widgets. Each plot has one series per model in ``model_df_map``, read
    from the ``<model>_<point type>_point`` columns of ``model_info_df``. ppfit
    estimates for conditions whose ``ppfit_of`` exceeds ``of_cutoff`` are blanked
    (set to NaN) so they are not drawn.

    Parameters
    ----------
    gene_id
        Value matched against ``model_info_df.gene`` (the gene picker passes gene ids).
    of_cutoff : float
        Upper bound on ``ppfit_of`` for a ppfit estimate to be shown.

    The signature must stay exactly (gene_id, of_cutoff): ``widgets.interactive``
    creates a control for every parameter of this function.
    """
    point_types = ["bs", "bf"]
    # Fill colour per model, in model_df_map order. zip stops at the shorter list,
    # so only the first two models get plotted.
    model_colours = list(zip(model_df_map.keys(), ["black", "white"]))
    outputs = widgets.HBox([widgets.Output() for _ in point_types])

    gene_info = model_info_df.loc[model_info_df.gene == gene_id].copy()
    rejected = gene_info.ppfit_of > of_cutoff
    gene_info.loc[rejected, [f"ppfit_{point_type}_point" for point_type in point_types]] = np.nan

    for point_type, output in zip(point_types, outputs.children):
        with output:
            fig, ax = plt.subplots()
            for model_name, colour in model_colours:
                gene_info.plot.scatter(
                    "mean",
                    f"{model_name}_{point_type}_point",
                    label=model_name,
                    ax=ax,
                    color=colour,
                    edgecolor="black",
                )
            # Use the explicit Axes API rather than pyplot's implicit "current axes".
            ax.set_xlabel("mean")
            ax.set_ylabel(point_type)
            ax.set_title(f"{gene_id}: {point_type}")
            ax.legend()
            plt.show()

    display(outputs)


# Interactive explorer. The gene picker shows symbols but passes the gene id to
# plot_model. The log slider spans 10**-4 .. 10**0 for the ppfit `of` cutoff;
# continuous_update=False redraws only when the slider is released, instead of
# re-rendering both figures for every intermediate value while dragging.
widgets.interactive(
    plot_model,
    gene_id=widgets.Select(options=list(zip(common_genes.values, common_genes.index)), rows=4),
    of_cutoff=widgets.FloatLogSlider(value=0.01, min=-4, max=0, continuous_update=False),
)