In [None]:
from collections import namedtuple
from pathlib import Path

import pandas as pd

from rp2 import create_folder, working_directory, hagai_2018, notebooks
from rp2.paths import get_scripts_path, get_txburst_results_csv_path

# Set up this notebook's environment; outputs of the "Data_Processing" notebook are reached via `data_proc_nb`.
nb_env, data_proc_nb = notebooks.initialise_environment("Burst_Model_Fitting", dependencies=["Data_Processing"])

In [None]:
study_species = "mouse"

# One gene per line. splitlines() also copes with "\r\n" endings, and blank entries are dropped:
# plain split("\n") yields a trailing "" whenever the file ends with a newline, and that would be
# requested as a gene name when subsetting the UMI counts below.
gene_list_text = data_proc_nb.access_path(f"all_analysis_genes-species={study_species}.txt").read_text()
full_gene_list = [gene.strip() for gene in gene_list_text.splitlines() if gene.strip()]
print(f"Wish to fit burst parameters for {len(full_gene_list):,} genes")

In [None]:
# Load the annotated UMI counts for the study species and keep only the analysis genes; .copy() gives an
# independent object rather than (presumably) a view onto the full dataset.
umi_counts_ad = hagai_2018.load_umi_counts_with_additional_annotation(study_species)[:, full_gene_list].copy()

In [None]:
# gene_list: genes to fit; index_columns: obs columns defining one fitting condition;
# results_path: CSV accumulating the fitted parameters for this set.
BurstParameterSet = namedtuple("BurstParameterSet", "gene_list index_columns results_path")

# First set: conditions split by replicate as well as treatment/time point.
# Second set: replicates combined, conditions defined by treatment/time point only.
txburst_param_sets = [
    BurstParameterSet(
        gene_list=full_gene_list,
        index_columns=["replicate", "treatment", "time_point"],
        results_path=get_txburst_results_csv_path(study_species),
    ),
    BurstParameterSet(
        gene_list=full_gene_list,
        index_columns=["treatment", "time_point"],
        results_path=get_txburst_results_csv_path(study_species, combined_replicates=True),
    ),
]

In [None]:
def write_umi_counts(umi_counts, output_path):
    """Write ``umi_counts`` to ``output_path/umi_counts.csv`` with one row per gene.

    ``umi_counts`` must provide ``to_df()`` returning a cells-by-genes DataFrame (as the
    AnnData subsets passed in by this notebook do). The table is transposed so that each
    row is a gene, and the index column is labelled ``gene``.
    """
    destination = output_path.joinpath("umi_counts.csv")
    genes_by_cells = umi_counts.to_df().transpose()
    genes_by_cells.to_csv(destination, index_label="gene")


def run_txburst_fitting(umi_files_path, execute_commands=True, n_jobs=4):
    """Run txburst's ML step and then its PL step on every ``*.csv`` file in ``umi_files_path``.

    For each ``<name>.csv`` the existence checks below assume the ML script (``txburstML.py``)
    writes ``<name>_ML.pkl`` and the PL script (``txburstPL.py``) writes ``<name>_PL.pkl``.
    A step is skipped when its own output already exists, so an interrupted run can be
    resumed. The PL step is given the ML pickle via ``--MLFile``, so ML is always queued first.

    Parameters
    ----------
    umi_files_path : path-like
        Folder containing the UMI-count CSVs. Commands are run inside
        ``working_directory(umi_files_path)``.
    execute_commands : bool
        If False, only print the commands (dry run).
    n_jobs : int
        Value passed to the txburst scripts' ``--njobs`` option (default 4, as before).
    """
    # Resolve up front: the CSVs are listed *inside* working_directory(), where a relative
    # path would no longer point at the right folder.
    umi_files_path = Path(umi_files_path).resolve()

    # Loop-invariant: the txburst script locations.
    txburst_script_path = get_scripts_path("txburst")
    txburst_ml_script_path = txburst_script_path.joinpath("txburstML.py")
    txburst_pl_script_path = txburst_script_path.joinpath("txburstPL.py")

    with working_directory(umi_files_path):
        for full_csv_file_path in sorted(umi_files_path.glob("*.csv")):
            csv_file_path = full_csv_file_path.name
            # Relative names: checked/created in the current (umi_files_path) directory.
            ml_file_path = Path(full_csv_file_path.stem + "_ML.pkl")
            pl_file_path = Path(full_csv_file_path.stem + "_PL.pkl")

            # Gate each step on its *own* output. (The checks used to be swapped, so a run with
            # an existing ML pickle but no PL pickle re-ran ML and never produced the PL results.)
            commands = []
            if not ml_file_path.exists():
                commands.append(f"{txburst_ml_script_path} --njobs {n_jobs} {csv_file_path}")
            if not pl_file_path.exists():
                commands.append(
                    f"{txburst_pl_script_path} --njobs {n_jobs} --file {csv_file_path} --MLFile {ml_file_path}"
                )

            for cmd in commands:
                print("Executing:", cmd)
                if execute_commands:
                    # Explicit form of the `%run {cmd}` line magic (get_ipython is provided by the
                    # IPython kernel); keeps this function valid plain Python.
                    get_ipython().run_line_magic("run", cmd)


def load_txburst_pkl_file(pkl_path):
    """Load a pickled txburst result table and turn its ``gene`` index level into a column.

    NOTE: unpickling can execute arbitrary code; only use this on pickles produced by this
    notebook's own txburst runs.
    """
    results = pd.read_pickle(pkl_path)
    return results.reset_index(level="gene")
    

def explode_txburst_columns(df, column_id, new_columns):
    """Unpack the sequence-valued column ``column_id`` of ``df`` into ``new_columns``, in place.

    Each cell of ``column_id`` must hold a sequence with one entry per name in ``new_columns``
    (in this notebook: the tuple-valued columns of the txburst result pickles). The source
    column is dropped afterwards. ``df`` is modified in place and ``None`` is returned.

    An empty ``df`` gets the new columns (filled with ``None``) and also loses ``column_id``,
    so the resulting schema is the same whether or not there were any rows. Previously the
    empty case kept ``column_id``, which leaked a stray integer column into the results CSV.
    """
    if df.empty:
        for c in new_columns:
            df[c] = None
    else:
        # Build the unpacked frame on df's own index so rows line up even when df does not
        # carry a default 0..n-1 RangeIndex (otherwise assignment aligns to NaN).
        df[new_columns] = pd.DataFrame(df.loc[:, column_id].to_list(), index=df.index)

    # errors="ignore" only matters for an empty frame that never had column_id at all.
    df.drop(columns=column_id, inplace=True, errors="ignore")


def collate_txburst_results(pkl_files_path):
    """Combine the txburst ML and PL result pickles in ``pkl_files_path`` into one table.

    Reads ``umi_counts_ML.pkl`` then ``umi_counts_PL.pkl`` (via ``load_txburst_pkl_file``) and
    unpacks their sequence-valued columns into named ones:

    * ML: column 1 is renamed ``keep``; column 0 -> ``k_on``/``k_off``/``k_syn``.
    * PL: column 0 -> ``k_on``/``k_off``/``k_syn``,
      column 1 -> ``bf_point``/``bf_lower``/``bf_upper``,
      column 2 -> ``bs_point``/``bs_lower``/``bs_upper``.

    The PL table is outer-joined with the ML ``keep`` column on ``gene`` (genes present in
    either file are kept), and every non-missing ML value then overwrites the matching PL
    value via ``DataFrame.update``.

    Returns a DataFrame whose first column is ``gene``.
    """
    ml_results = load_txburst_pkl_file(pkl_files_path.joinpath("umi_counts_ML.pkl"))
    pl_results = load_txburst_pkl_file(pkl_files_path.joinpath("umi_counts_PL.pkl"))

    ml_results = ml_results.rename(columns={1: "keep"})
    explode_txburst_columns(ml_results, 0, ["k_on", "k_off", "k_syn"])

    # PL pickle layout: source column id -> names of the values it holds.
    pl_layout = {
        0: ["k_on", "k_off", "k_syn"],
        1: ["bf_point", "bf_lower", "bf_upper"],
        2: ["bs_point", "bs_lower", "bs_upper"],
    }
    for column_id, new_columns in pl_layout.items():
        explode_txburst_columns(pl_results, column_id, new_columns)

    ml_by_gene = ml_results.set_index("gene")
    combined = pl_results.set_index("gene").join(ml_by_gene["keep"], how="outer")
    combined.update(ml_by_gene)
    return combined.reset_index()


def get_genes_with_results(txburst_df, index_values):
    """Return the ``gene`` column of the rows of ``txburst_df`` matching every condition.

    ``index_values`` is an iterable of ``(column, value)`` pairs. A row is selected when
    ``txburst_df[column] == value`` holds for all pairs; with no pairs, every row is selected.
    """
    keep = pd.Series(True, index=txburst_df.index)
    for column_name, wanted in index_values:
        keep = keep & (txburst_df[column_name] == wanted)
    return txburst_df.loc[keep, "gene"]



for txburst_param_set in txburst_param_sets:
    print(f"Processing {txburst_param_set.results_path}")

    # Results are accumulated across runs in a CSV; on the very first run it does not exist yet,
    # so start from an empty table holding just the key columns instead of crashing.
    if Path(txburst_param_set.results_path).exists():
        results_df = pd.read_csv(txburst_param_set.results_path, float_precision="round_trip")
    else:
        results_df = pd.DataFrame(columns=["gene", *txburst_param_set.index_columns])
    # Compare condition values as strings (CSV round-trips may have parsed them as numbers).
    for index_column in txburst_param_set.index_columns:
        results_df[index_column] = results_df[index_column].astype(str)

    sort_columns = [c for c in ["gene", "replicate", "time_point", "treatment"] if c in results_df.columns]

    full_gene_set = set(txburst_param_set.gene_list)
    print(f"  {len(full_gene_set):,} genes per condition")

    # observed=True: if the obs columns are categorical, only visit combinations that actually
    # contain cells (otherwise empty conditions would be "fitted"); no effect for other dtypes.
    condition_groups = umi_counts_ad.obs.groupby(txburst_param_set.index_columns, observed=True)
    for index_values, index_df in condition_groups:
        condition = list(zip(txburst_param_set.index_columns, index_values))

        genes_with_results = get_genes_with_results(results_df, condition)
        # sorted() for a deterministic gene order (set iteration order varies between sessions).
        required_genes = sorted(full_gene_set.difference(genes_with_results))
        print(f"  {len(required_genes):,} genes required for condition {index_values}")
        if not required_genes:
            continue

        tmp_folder_name = "-".join(f"{n}={v}" for n, v in condition)
        tmp_path = nb_env.get_intermediate_path("txburst", txburst_param_set.results_path.stem, tmp_folder_name)
        create_folder(tmp_path, create_clean=True)

        umi_count_subset = umi_counts_ad[index_df.index, required_genes]
        write_umi_counts(umi_count_subset, tmp_path)

        run_txburst_fitting(tmp_path)

        txburst_df = collate_txburst_results(tmp_path)

        # Insert the condition columns right after "gene"; str() matches the astype(str) applied
        # to results_df above so the combined columns share one type and can be sorted.
        for loc, (column, value) in enumerate(condition, start=1):
            txburst_df.insert(loc, column, str(value))

        # DataFrame.append was removed in pandas 2.0; concat is the supported equivalent.
        # The index is not written to disk, so it is simply renumbered.
        results_df = pd.concat([results_df, txburst_df], ignore_index=True)
        results_df = results_df.sort_values(by=sort_columns)
        # Save after every condition so finished fits survive an interruption.
        results_df.to_csv(txburst_param_set.results_path, index=False)

print("All done")