In [None]:
from pathlib import Path

In [None]:
import anndata
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.core.display import display

from rp2 import hagai_2018, create_folder, working_directory, create_gene_symbol_map
from rp2.paths import get_data_path, get_output_path, get_scripts_path

In [None]:
# Species analysed by this notebook; reused below when building input/output file names and gene-set names.
study_species = "mouse"

# Load the species-specific processed E-MTAB-6754 AnnData (.h5ad) file, located via get_data_path under "ArrayExpress".
# NOTE(review): the variable name suggests the matrix holds UMI counts — confirm against the file contents.
umi_count_ad = anndata.read_h5ad(get_data_path("ArrayExpress", f"E-MTAB-6754.processed.2.{study_species}.h5ad"))

In [None]:
# Species-specific gene symbol map; later cells call `symbol_map.added_to(...)` on tables before displaying them.
# NOTE(review): presumably this attaches readable gene symbols to gene-keyed tables — confirm in rp2.
symbol_map = create_gene_symbol_map(study_species)

In [None]:
# Gene list read through the hagai_2018 helper's LPS-responsive gene reader; it is used below as the gene set
# whose counts are exported for fitting.
responsive_phagocyte_genes = hagai_2018.read_lps_responsive_genes()
# Report how many genes were read (with a thousands separator).
print(f"{len(responsive_phagocyte_genes):,} responsive phagocyte genes")

In [None]:
def write_umi_subsets(output_path, obs_groups):
    """Write one CSV of counts per observation group, plus a ``groups.txt`` manifest.

    The observations of the notebook-level ``umi_count_ad`` are grouped by ``obs_groups``.
    For every group, the counts restricted to the notebook-level ``genes`` are written
    transposed, so the CSV index column (labelled ``gene``) holds the AnnData variable
    names and the remaining columns are the observations of that group.

    NOTE: besides its arguments this function reads three notebook globals:
    ``umi_count_ad``, ``study_species`` and ``genes`` (``genes`` is the loop variable of
    the calling cell), so it must be called while those are defined.

    Parameters
    ----------
    output_path : pathlib.Path
        Folder that receives the CSV files. It is passed to
        ``create_folder(..., create_clean=True)`` first.
        NOTE(review): presumably that empties an existing folder — confirm in rp2.
    obs_groups : list of str
        Names of the ``umi_count_ad.obs`` columns to group by. They are stored
        tab-separated in ``groups.txt`` so the file-name suffix can be decoded later.
    """
    create_folder(output_path, create_clean=True)

    for group_values, df in umi_count_ad.obs.groupby(obs_groups):
        # pandas may yield a bare scalar key when grouping by a single column; joining a
        # bare string would join its characters, so always work with a tuple of keys.
        if not isinstance(group_values, tuple):
            group_values = (group_values,)
        # Group keys are not guaranteed to be strings (e.g. numeric time points).
        # The suffix is "_"-separated and collate_txburst_results splits file names on
        # "_", so group values that themselves contain "_" will not round-trip.
        suffix = "_".join(map(str, group_values))
        csv_path = output_path.joinpath(f"{study_species}_umi_{suffix}.csv")
        print("Writing:", csv_path.name)
        subset_ad = umi_count_ad[df.index.values, genes]
        subset_ad.to_df().T.to_csv(csv_path, index_label="gene")

    output_path.joinpath("groups.txt").write_text("\t".join(obs_groups))


# Named gene sets to export; each name becomes the stem of the output folder names.
gene_sets = {f"{study_species}_responsive_genes": responsive_phagocyte_genes}

# Root folder for all txburst inputs and outputs.
txburst_output_path = get_output_path("txburst")
create_folder(txburst_output_path)

# NOTE: the loop variable must stay named `genes` because write_umi_subsets reads it as a global.
for gene_set_name, genes in gene_sets.items():
    # Every gene set is exported twice: split per replicate, and with replicates pooled.
    all_sets = {
        gene_set_name: ["replicate", "treatment", "time_point"],
        gene_set_name + "_combined_replicates": ["treatment", "time_point"],
    }

    for set_name, umi_groups in all_sets.items():
        gene_set_output_path = txburst_output_path.joinpath(set_name)
        # An existing folder is treated as already exported and is left untouched.
        if not gene_set_output_path.exists():
            write_umi_subsets(gene_set_output_path, umi_groups)
        else:
            print("Skipping:", set_name)

In [None]:
def run_txburst_fitting(umi_files_path, execute_commands=True):
    txburst_script_path = get_scripts_path("txburst")

    with working_directory(umi_files_path):
        for full_csv_file_path in umi_files_path.glob("*.csv"):
            csv_file_path = full_csv_file_path.name
            ml_file_path = Path(full_csv_file_path.stem + "_ML.pkl")
            pl_file_path = Path(full_csv_file_path.stem + "_PL.pkl")

            txburst_ml_script_path = txburst_script_path.joinpath("txburstML.py")
            txburst_pl_script_path = txburst_script_path.joinpath("txburstPL.py")

            commands = []
            if not pl_file_path.exists():
                commands.append(f"{txburst_ml_script_path} --njobs 4 {csv_file_path}")
            if not ml_file_path.exists():
                commands.append(f"{txburst_pl_script_path} --njobs 4 --file {csv_file_path} --MLFile {ml_file_path}")

            for cmd in commands:
                print("Executing:", cmd)
                if execute_commands:
                    %run {cmd}


def collate_txburst_results(pkl_files_path):
    """Combine every ``*_PL.pkl`` result in a folder into one long table.

    The folder's ``groups.txt`` lists, tab-separated, the names of the grouping columns.
    Their values are recovered from the trailing ``_``-separated tokens of each pickle's
    file name, after the ``_PL`` tag is removed.

    Each pickled frame contributes one row per index entry (stored as ``gene``). The
    entries of its first three positional columns are unpacked into ``k_on``/``k_off``/
    ``k_syn``, ``bf_point``/``bf_lower``/``bf_upper`` and ``bs_point``/``bs_lower``/
    ``bs_upper`` respectively.

    Returns the concatenated frame, sorted by whichever of ``gene``, ``replicate``,
    ``time_point`` and ``treatment`` are present, in that priority.
    """
    group_names = pkl_files_path.joinpath("groups.txt").read_text().split("\t")
    n_groups = len(group_names)

    # Positional column of the pickled frame -> names for the three values unpacked from it.
    unpacked_columns = (
        ["k_on", "k_off", "k_syn"],
        ["bf_point", "bf_lower", "bf_upper"],
        ["bs_point", "bs_lower", "bs_upper"],
    )

    def load_condition(pl_path):
        # Drop the trailing "_PL"; the last n_groups tokens are this file's group values.
        stem_without_tag = pl_path.stem[:-3]
        group_values = stem_without_tag.split("_")[-n_groups:]
        fit_df = pd.read_pickle(pl_path)

        frame = pd.DataFrame(data={"gene": fit_df.index})
        for name, value in zip(group_names, group_values):
            frame[name] = value
        for position, names in enumerate(unpacked_columns):
            frame[names] = pd.DataFrame(fit_df.iloc[:, position].to_list())
        return frame

    per_condition = [load_condition(pl_path) for pl_path in pkl_files_path.glob("*_PL.pkl")]
    collated = pd.concat(per_condition, ignore_index=True)

    preferred_order = ("gene", "replicate", "time_point", "treatment")
    sort_columns = [name for name in preferred_order if name in collated.columns]
    return collated.sort_values(by=sort_columns)


# For every gene-set folder: reuse the collated CSV when it exists, otherwise run the fitting and
# collate its results; then display the table and summarise how many rows each gene has.
# Folders are visited in sorted order so tables and figures appear in a reproducible sequence.
for sub_folder in sorted(filter(Path.is_dir, txburst_output_path.iterdir())):
    # The collated table is cached next to the folder as "<folder name>.csv".
    collated_csv_path = sub_folder.with_suffix(".csv")
    if collated_csv_path.exists():
        print("Skipping:", sub_folder.name)
        txburst_df = pd.read_csv(collated_csv_path)
    else:
        run_txburst_fitting(sub_folder)
        txburst_df = collate_txburst_results(sub_folder)
        txburst_df.to_csv(collated_csv_path, index=False)

    display(symbol_map.added_to(txburst_df))

    # Rows per gene, i.e. the number of result files (conditions) in which the gene appears.
    # value_counts() already returns the counts in descending order, so no further sort is needed.
    totals_df = txburst_df["gene"].value_counts().rename_axis("gene").to_frame("n_conditions")
    display(symbol_map.added_to(totals_df))

    # One figure per folder; the title identifies which gene set the distribution belongs to.
    fig, ax = plt.subplots()
    sns.countplot(
        x="n_conditions",
        data=totals_df,
        ax=ax,
    )
    ax.set_title(sub_folder.name)
    ax.set_ylabel("Number of genes")
    plt.show()