In [None]:
from pathlib import Path

import numpy as np
import pandas as pd

from rp2 import create_folder, working_directory, hagai_2018, notebooks
from rp2.paths import get_scripts_path, get_txburst_results_csv_path, get_data_path

# Set up this notebook's environment. The call returns a handle for this
# notebook (used later for intermediate paths) and one for the
# "Data_Processing" dependency (used later to read its output files).
nb_env, data_proc_nb = notebooks.initialise_environment("Burst_Model_Fitting", dependencies=["Data_Processing"])

In [None]:
# Species whose gene list drives the analysis; it is also the default species
# of BurstParameterSet.
study_species = "mouse"

# The gene list file holds one gene ID per line. splitlines() copes with a
# trailing newline (and "\r\n" line endings), and blank lines are skipped so
# that an empty string is never treated as a gene ID.
analysis_gene_ids = [
    gene_id
    for gene_id in data_proc_nb.access_path(f"all_analysis_genes-species={study_species}.txt").read_text().splitlines()
    if gene_id.strip()
]
print(f"Wish to fit burst parameters for {len(analysis_gene_ids):,} genes")

In [None]:
def load_mouse_orthologues():
    """Load the mouse orthologue table, keeping only unambiguous, complete rows.

    Returns:
        A DataFrame indexed by the first file column ("mouse_gene") with the
        columns "pig_gene", "rabbit_gene" and "rat_gene". Rows with any missing
        value are dropped, and every mouse gene that occurs more than once in
        the remaining index is removed entirely.
    """
    orthologue_table = pd.read_table(
        get_data_path("BioMart", "mouse_orthologues.tsv"),
        names=["mouse_gene", "pig_gene", "rabbit_gene", "rat_gene"],
        index_col=0,
    )
    complete_rows = orthologue_table.dropna(axis=0)
    # keep=False flags every occurrence of a repeated mouse gene, not just the extras.
    is_repeated = complete_rows.index.duplicated(keep=False)
    return complete_rows.loc[~is_repeated]


# Orthologue lookup table indexed by mouse gene ID; used below to translate the
# mouse analysis genes into the gene IDs of the other species.
orthologues_df = load_mouse_orthologues()

In [None]:
class BurstParameterSet:
    """Describes one burst-parameter fitting job.

    Attributes are stored as given:
        gene_list: genes to select from the loaded counts.
        index_columns: observation columns that define one fitting condition.
        count_type: count scaling passed to the counts loader.
        species: species whose counts are loaded (defaults to ``study_species``).
    """

    def __init__(self, gene_list, index_columns, count_type, species=study_species):
        self.gene_list = gene_list
        self.index_columns = index_columns
        self.count_type = count_type
        self.species = species

    def load_counts(self):
        """Load this species' counts and return a copy restricted to ``gene_list``."""
        all_counts = hagai_2018.load_counts(self.species, scaling=self.count_type)
        return all_counts[:, self.gene_list].copy()

    @property
    def counts_key(self):
        """Key identifying the underlying counts: "<species>-<count_type>"."""
        return "-".join([self.species, self.count_type])

    @property
    def results_path(self):
        """Path of the CSV file holding the fitted results for this parameter set."""
        return get_txburst_results_csv_path(self.species, self.index_columns, self.count_type)


# Every fitting job to run, in order.
txburst_param_sets = []

# Study species, fitted per replicate, once for each count scaling.
for count_type in ["umi", "median"]:
    per_replicate_set = BurstParameterSet(
        analysis_gene_ids,
        ["replicate", "treatment", "time_point"],
        count_type,
    )
    txburst_param_sets.append(per_replicate_set)

# Other species, fitted per replicate on the orthologues of the analysis genes.
is_analysis_gene = orthologues_df.index.isin(analysis_gene_ids)
for species in ["pig", "rabbit", "rat"]:
    orthologue_ids = orthologues_df.loc[is_analysis_gene, f"{species}_gene"]
    orthologue_set = BurstParameterSet(
        orthologue_ids,
        ["replicate", "treatment", "time_point"],
        "median",
        species=species,
    )
    txburst_param_sets.append(orthologue_set)

# Study species again, this time with replicates pooled (no "replicate" column).
pooled_set = BurstParameterSet(
    analysis_gene_ids,
    ["treatment", "time_point"],
    "umi",
)
txburst_param_sets.append(pooled_set)

In [None]:
def write_counts(counts_adata, output_path):
    """Write the counts to ``counts.csv`` inside ``output_path``.

    The table returned by ``counts_adata.to_df()`` is transposed before being
    written, and the index column of the CSV is labelled "gene".
    """
    transposed_counts = counts_adata.to_df().transpose()
    transposed_counts.to_csv(output_path.joinpath("counts.csv"), index_label="gene")


def run_txburst_fitting(count_files_path, execute_commands=True):
    txburst_script_path = get_scripts_path("txburst")

    with working_directory(count_files_path):
        for full_csv_file_path in count_files_path.glob("*.csv"):
            csv_file_path = full_csv_file_path.name
            ml_file_path = Path(full_csv_file_path.stem + "_ML.pkl")
            pl_file_path = Path(full_csv_file_path.stem + "_PL.pkl")

            txburst_ml_script_path = txburst_script_path.joinpath("txburstML.py")
            txburst_pl_script_path = txburst_script_path.joinpath("txburstPL.py")

            common_options = "--njobs 4"
            commands = []
            if not pl_file_path.exists():
                commands.append(f"{txburst_ml_script_path} {common_options} {csv_file_path}")
            if not ml_file_path.exists():
                commands.append(f"{txburst_pl_script_path} {common_options} --file {csv_file_path} --MLFile {ml_file_path}")

            for cmd in commands:
                print("Executing:", cmd)
                if execute_commands:
                    %run {cmd}


def load_txburst_pkl_file(pkl_path):
    """Read a pickled txburst result table and move its "gene" index level into a column."""
    fitted = pd.read_pickle(pkl_path)
    return fitted.reset_index("gene")
    

def explode_txburst_columns(df, column_id, new_columns):
    """Unpack a column of sequences into one scalar column per element, in place.

    Each entry of ``df[column_id]`` is expected to be a sequence with
    ``len(new_columns)`` items. The items are written to the columns named in
    ``new_columns`` and the packed source column is removed.

    Args:
        df: frame to modify in place.
        column_id: label of the column holding the packed values.
        new_columns: names for the unpacked columns, in element order.

    Returns:
        None; ``df`` is modified in place.
    """
    if df.empty:
        # Nothing to unpack: only make the schema match the non-empty case.
        for c in new_columns:
            df[c] = None
        # The packed column may be absent from an empty frame, hence "ignore".
        df.drop(columns=column_id, inplace=True, errors="ignore")
        return

    # Build the unpacked frame on df's own index. Column assignment aligns on
    # index labels, so a default RangeIndex would produce NaN for any df whose
    # index is not exactly 0..n-1.
    expanded = pd.DataFrame(df.loc[:, column_id].to_list(), index=df.index)
    df[new_columns] = expanded
    df.drop(columns=column_id, inplace=True)


def collate_txburst_results(pkl_files_path):
    """Combine the ML and PL txburst outputs of one folder into a single table.

    Reads ``counts_ML.pkl`` and ``counts_PL.pkl`` from ``pkl_files_path``.
    Column 1 of the ML table becomes "keep"; column 0 of both tables is
    unpacked into "k_on", "k_off" and "k_syn"; columns 1 and 2 of the PL table
    are unpacked into the "bf_*" and "bs_*" columns (presumably burst frequency
    and burst size estimates with bounds -- confirm against the txburst docs).

    Returns:
        A DataFrame with one row per gene found in either table, with "gene"
        as an ordinary column.
    """
    ml_fits, pl_fits = (
        load_txburst_pkl_file(pkl_files_path.joinpath(f"counts_{suffix}.pkl"))
        for suffix in ["ML", "PL"]
    )

    ml_fits = ml_fits.rename(columns={1: "keep"})

    # Column 0 holds the three kinetic parameters in both tables.
    explode_txburst_columns(ml_fits, 0, ["k_on", "k_off", "k_syn"])
    explode_txburst_columns(pl_fits, 0, ["k_on", "k_off", "k_syn"])

    explode_txburst_columns(pl_fits, 1, ["bf_point", "bf_lower", "bf_upper"])
    explode_txburst_columns(pl_fits, 2, ["bs_point", "bs_lower", "bs_upper"])

    ml_by_gene = ml_fits.set_index("gene")
    pl_by_gene = pl_fits.set_index("gene")

    # The outer join keeps genes present in only one of the two tables; update()
    # then overwrites matching cells with the non-missing values from the ML table.
    combined = pl_by_gene.join(ml_by_gene.keep, how="outer")
    combined.update(ml_by_gene)

    return combined.reset_index()


def get_genes_with_results(txburst_df, index_values):
    """Return the "gene" column of the rows matching every (column, value) pair.

    Args:
        txburst_df: results table containing a "gene" column.
        index_values: iterable of ``(column, value)`` pairs; a row is kept only
            if it equals ``value`` in ``column`` for all pairs.
    """
    matching_rows = txburst_df
    for column, value in index_values:
        is_match = matching_rows[column] == value
        matching_rows = matching_rows.loc[is_match]
    return matching_rows.gene


# Loaded counts, keyed by BurstParameterSet.counts_key, so parameter sets that
# share a key reuse the same loaded object.
counts_adata_map = {}

for txburst_param_set in txburst_param_sets:
    print(f"Processing {txburst_param_set.results_path}")

    if txburst_param_set.counts_key in counts_adata_map:
        counts_adata = counts_adata_map[txburst_param_set.counts_key]
    else:
        print(f'  Loading "{txburst_param_set.counts_key}" counts')
        counts_adata = txburst_param_set.load_counts()
        counts_adata_map[txburst_param_set.counts_key] = counts_adata

    # Resume from previously saved results when available, so only the genes
    # that are still missing for a condition get fitted.
    if txburst_param_set.results_path.exists():
        results_df = pd.read_csv(txburst_param_set.results_path, float_precision="round_trip")
    else:
        # NOTE(review): "replicate" is part of this schema even for parameter
        # sets whose index_columns do not include it -- confirm this is intended.
        results_df = pd.DataFrame(columns=["gene", "replicate", "treatment", "time_point", "k_on", "k_off", "k_syn", "bf_point", "bf_lower", "bf_upper", "bs_point", "bs_lower", "bs_upper", "keep"])

    # Condition columns are compared as strings when looking up existing results.
    for index_columns in txburst_param_set.index_columns:
        results_df[index_columns] = results_df[index_columns].astype(str)

    sort_columns = [c for c in ["gene", "replicate", "time_point", "treatment"] if c in results_df.columns]

    full_gene_set = set(txburst_param_set.gene_list)
    print(f"  {len(full_gene_set):,} genes per condition")

    for index_values, index_df in counts_adata.obs.groupby(txburst_param_set.index_columns):
        genes_with_results = get_genes_with_results(results_df, zip(txburst_param_set.index_columns, index_values))
        required_genes = list(full_gene_set.difference(genes_with_results))
        print(f"  {len(required_genes):,} genes required for condition {index_values}")
        if len(required_genes) == 0:
            continue

        # Each condition gets its own freshly created working folder.
        tmp_folder_name = "-".join(f"{n}={v}" for n, v in zip(txburst_param_set.index_columns, index_values))
        tmp_path = nb_env.get_intermediate_path("txburst", txburst_param_set.results_path.stem, tmp_folder_name)
        create_folder(tmp_path, create_clean=True)

        counts_subset = counts_adata[index_df.index, required_genes]
        write_counts(counts_subset, tmp_path)

        run_txburst_fitting(tmp_path)

        txburst_df = collate_txburst_results(tmp_path)

        # Tag the new rows with their condition, placed right after the "gene" column.
        for loc, (column, value) in enumerate(zip(txburst_param_set.index_columns, index_values), start=1):
            txburst_df.insert(loc, column, value)

        # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
        # Results are saved after every condition so an interrupted run can resume.
        results_df = pd.concat([results_df, txburst_df])
        results_df = results_df.sort_values(by=sort_columns)
        results_df.to_csv(txburst_param_set.results_path, index=False)

    n_results = len(results_df)
    n_keep = np.count_nonzero(results_df.keep)
    # Avoid a ZeroDivisionError when there are no results at all.
    keep_percentage = (n_keep / n_results) * 100 if n_results else 0
    print(f"  Summary:")
    print(f"    {n_results} results")
    print(f"    {n_keep} ({keep_percentage:.0f}%) flagged to keep")

print("All done")