In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display

import rp2.data
from rp2 import hagai_2018

# Project-level sanity check provided by rp2; what it verifies is defined inside rp2 and is not visible here.
rp2.check_environment()

In [None]:
# Dataset selection used by every loader below.
species = "mouse"
count_type = "median"

# Experimental factors that, together with the gene, identify one fitted condition;
# `index_columns` is the full key used to align the per-model result tables.
condition_columns = ["replicate", "treatment", "time_point"]
index_columns = ["gene", *condition_columns]

In [None]:
# Gene-symbol lookup for `species`. Further down, its `symbol` column is indexed by gene id
# to label the genes that every model has results for.
gene_symbol_df = rp2.load_biomart_gene_symbols_df(species)

In [None]:
def load_txburst_results(species, condition_columns, count_type):
    """Load the txburst fits and keep only usable conditions.

    A row is kept when both ``bs_point`` and ``bf_point`` are present (non-NaN)
    and its ``time_point`` is not ``"6A"``.

    Parameters
    ----------
    species, condition_columns, count_type
        Forwarded unchanged to ``rp2.data.load_txburst_results``.

    Returns
    -------
    pandas.DataFrame
        The filtered rows of the loaded results.
    """
    results = rp2.data.load_txburst_results(species, condition_columns, count_type)
    has_both_points = results[["bs_point", "bf_point"]].notna().all(axis=1)
    not_6a = results.time_point != "6A"
    return results.loc[has_both_points & not_6a]


def load_modified_txburst_results(species, condition_columns, count_type, two_alleles):
    """Load a modified-txburst fit, excluding time point ``"6A"``.

    Unlike ``load_txburst_results``, rows with missing point estimates are kept here;
    they are removed later by ``clean_model``.

    Parameters
    ----------
    species, condition_columns, count_type
        Forwarded unchanged to ``rp2.data.load_modified_txburst_results``.
    two_alleles : bool
        Forwarded as the ``two_alleles`` keyword to select the model variant.

    Returns
    -------
    pandas.DataFrame
        Loaded rows whose ``time_point`` is not ``"6A"``.
    """
    results = rp2.data.load_modified_txburst_results(
        species, condition_columns, count_type, two_alleles=two_alleles
    )
    keep_rows = results.time_point != "6A"
    return results.loc[keep_rows]


def load_ppfit_results(species, condition_columns, count_type):
    """Load the ppfit rate estimates and derive point estimates from them.

    The CSV located by ``rp2.paths.get_burst_model_csv_path("ppfit", ...)`` is read with
    every column in ``condition_columns`` parsed as ``category``. Three columns are added:

    * ``bs_point``   = k_syn / k_off
    * ``bf_point``   = 2 * k_on * k_off / (k_on + k_off) / k_deg
    * ``bf_point_2`` = 2 * k_on / k_deg

    (Presumably burst size / burst frequency, with the factor 2 accounting for two
    alleles and rates expressed relative to k_deg -- TODO confirm against the model.)

    Returns
    -------
    pandas.DataFrame
        The loaded table with the three derived columns appended.
    """
    csv_path = rp2.paths.get_burst_model_csv_path("ppfit", species, condition_columns, count_type)
    ppfit_df = pd.read_csv(csv_path, dtype=dict.fromkeys(condition_columns, "category"))

    k_on, k_off, k_deg = ppfit_df.k_on, ppfit_df.k_off, ppfit_df.k_deg
    ppfit_df["bs_point"] = ppfit_df.k_syn / k_off
    ppfit_df["bf_point"] = (2 * k_on * k_off) / (k_on + k_off) / k_deg
    ppfit_df["bf_point_2"] = (2 * k_on) / k_deg

    return ppfit_df


# Burst-model results keyed by model name; every downstream cell iterates over this dict.
# The txburst loaders drop time point "6A" (and `load_txburst_results` also drops rows missing a
# bs/bf point). The two modified-txburst entries differ only in the `two_alleles` flag.
# NOTE(review): `load_ppfit_results` does NOT filter out time point "6A" -- confirm that the ppfit
# CSV contains no such rows, or that keeping them is intended.
model_df_map = {
    "ppfit": load_ppfit_results(species, condition_columns, count_type),
    "txburst": load_txburst_results(species, condition_columns, count_type),
    "txburst_mod1": load_modified_txburst_results(species, condition_columns, count_type, two_alleles=False),
    "txburst_mod2": load_modified_txburst_results(species, condition_columns, count_type, two_alleles=True),
}

In [None]:
def clean_model(model_name, model_df):
    """Remove conditions whose burst-size or burst-frequency point estimate is unusable.

    Values of 0, +inf and -inf in ``bs_point`` / ``bf_point`` are converted to NaN.
    Then every row with a NaN in either column is dropped. A summary of what was
    found is printed.

    ``model_df`` is modified in place: the caller loops over ``model_df_map`` and
    relies on the stored frames being cleaned.

    Parameters
    ----------
    model_name : str
        Name used only in the printed summary.
    model_df : pandas.DataFrame
        Must contain ``gene``, ``bs_point`` and ``bf_point`` columns.
    """
    values_to_remove = [0, np.inf, -np.inf]
    columns_to_check = ["bs_point", "bf_point"]

    print(f'Model "{model_name}":')
    print(f"  Contains {len(model_df):,} conditions for {model_df.gene.nunique():,} genes before cleaning")

    for column in columns_to_check:
        for value in values_to_remove:
            n_values = np.count_nonzero(model_df[column] == value)
            if n_values > 0:
                print(f"    with {n_values:,} {column} values of {value}")
                # Assign back to the frame instead of `model_df[column].replace(..., inplace=True)`.
                # That form is chained assignment: pandas deprecates it, and under Copy-on-Write it
                # never reaches `model_df`, so the bad values would survive the dropna below.
                model_df[column] = model_df[column].replace(value, np.nan)

    # Called on the frame itself, so the in-place drop is visible to the caller.
    model_df.dropna(subset=columns_to_check, axis=0, how="any", inplace=True)
    print(f"  Contains {len(model_df):,} conditions for {model_df.gene.nunique():,} genes after cleaning")


# Clean every model's frame in place (clean_model mutates the DataFrame stored in model_df_map).
for model_name, model_df in model_df_map.items():
    clean_model(model_name, model_df)

In [None]:
def calculate_condition_stats(species, count_type, genes):
    """Compute per-condition count statistics for a subset of genes.

    The full count data for ``species`` / ``count_type`` is loaded with
    ``hagai_2018.load_counts`` and restricted to ``genes`` along the second axis.
    The subset is then passed to ``hagai_2018.calculate_counts_condition_stats``,
    whose result is returned unchanged.
    """
    all_counts = hagai_2018.load_counts(species, count_type)
    selected_counts = all_counts[:, genes]
    return hagai_2018.calculate_counts_condition_stats(selected_counts)


# Gene ids present in every model's results, as a list. pandas >= 2.0 raises TypeError when a
# `set` is used as an indexer. Sorting also makes the order independent of hash randomisation,
# so Restart & Run All is reproducible.
common_gene_ids = sorted(set.intersection(*[set(df.gene) for df in model_df_map.values()]))
# Series of gene symbols indexed by gene id, ordered by symbol. A stable sort keeps genes that
# share a symbol in a deterministic order.
common_genes = gene_symbol_df.symbol[common_gene_ids].sort_values(kind="stable")
condition_stats_df = calculate_condition_stats(species, count_type, common_genes.index)

In [None]:
def copy_df_with_column_prefix(df, prefix):
    """Return a copy of ``df`` with ``prefix_`` prepended to every non-key column.

    Columns listed in the module-level ``index_columns`` keep their names, so the
    result can still be aligned on those keys. The input frame is not modified.
    """
    def prefixed(column):
        if column in index_columns:
            return column
        return f"{prefix}_{column}"

    renamed = df.copy()
    renamed.columns = list(map(prefixed, renamed.columns))
    return renamed


# One row per (gene, condition): the count statistics, followed by every model's columns
# prefixed with the model name. Each join is a left join on the condition key, so the row
# set and order come from the sorted condition statistics.
model_info_df = (
    condition_stats_df
    .sort_values(by=["gene", "replicate", "time_point", "treatment"])
    .set_index(index_columns)
)
for model_name, model_df in model_df_map.items():
    model_info_df = model_info_df.join(
        copy_df_with_column_prefix(model_df, model_name).set_index(index_columns)
    )

model_info_df = model_info_df.reset_index()

In [None]:
def plot_model(gene_id, models, of_cutoff):
    """Compare the selected models' point estimates for one gene.

    One scatter plot is drawn per point type (``bs`` and ``bf``). Each plot shows every
    selected model's ``<model>_<point type>_point`` column against the condition ``mean``.
    A text panel beside the plots gives the number of non-NaN ``bs`` points per model.

    Parameters
    ----------
    gene_id
        Value of ``model_info_df.gene`` selecting the rows to plot.
    models : sequence of str
        Model names (keys of ``model_df_map``) to overlay. Only four colours are defined,
        so only the first four selected models are drawn.
    of_cutoff : float
        Rows whose ``ppfit_of`` exceeds this value have their ppfit points masked (set to
        NaN) before plotting.
    """
    model_names = list(models)
    point_types = ["bs", "bf"]
    # Marker fill colours, in selection order (markers always get a black edge).
    colours = ["black", "white", "red", "green"]
    plot_outputs = [widgets.Output() for _ in point_types]
    info_output = widgets.Output()

    # Work on a copy so masking ppfit points never touches the shared model_info_df.
    model_info_subset = model_info_df.loc[model_info_df.gene == gene_id].copy()
    exceeds_cutoff = model_info_subset.ppfit_of > of_cutoff
    ppfit_point_columns = [f"ppfit_{point_type}_point" for point_type in point_types]
    model_info_subset.loc[exceeds_cutoff, ppfit_point_columns] = np.nan

    for point_type, output in zip(point_types, plot_outputs):
        with output:
            _, ax = plt.subplots()

            for model_name, colour in zip(model_names, colours):
                model_info_subset.plot.scatter("mean", f"{model_name}_{point_type}_point", label=model_name, ax=ax, s=40, color=colour, edgecolor="black")

            ax.set_title(f"{gene_id} ({point_type})")
            ax.set_ylabel(point_type)
            # With no model selected there are no labelled artists. An unconditional legend()
            # would print a "No artists with labels found" warning into the widget output.
            if model_names:
                ax.legend()
            plt.show()

    with info_output:
        for model_name in model_names:
            print(f"{model_name} points:", model_info_subset[f"{model_name}_bs_point"].count())

    display(widgets.HBox(plot_outputs + [info_output]))


# Interactive explorer: re-runs plot_model whenever a control changes. Being the cell's last
# expression, the returned widget is what gets displayed.
widgets.interactive(
    plot_model,
    # (label, value) pairs: the list shows the gene symbol and passes the gene id to plot_model.
    gene_id=widgets.Select(options=list(zip(common_genes.values, common_genes.index)), rows=4),
    # Starts with the first two models of model_df_map ("ppfit" and "txburst") selected.
    models=widgets.SelectMultiple(options=model_df_map.keys(), index=[0, 1]),
    # FloatLogSlider's min/max are exponents (base 10 by default), i.e. a cutoff range of 1e-4 .. 1.
    of_cutoff=widgets.FloatLogSlider(value=0.01, min=-4, max=0),
)