In [None]:
from pathlib import Path

In [None]:
import anndata
import pandas as pd
from IPython.core.display import display

from rp2 import create_folder, get_data_path, get_output_path, get_scripts_path, working_directory, GeneSymbolMap

In [None]:
study_species = "mouse"

# Processed UMI counts from ArrayExpress E-MTAB-6754; the file name is keyed by species.
umi_count_path = get_data_path("ArrayExpress", f"E-MTAB-6754.processed.2.{study_species}.h5ad")
umi_count_ad = anndata.read_h5ad(umi_count_path)

In [None]:
# BioMart gene table for the study species. Column names are supplied explicitly, so pandas
# reads the file as headerless; the "id" column becomes the index.
gene_symbols_path = get_data_path("BioMart", f"{study_species}_genes.tsv")
gene_symbols_df = pd.read_table(gene_symbols_path, names=["id", "symbol", "description"], index_col=0)
symbol_map = GeneSymbolMap(gene_symbols_df)

In [None]:
maximum_responsive_gene_padj = 0.01

phagocyte_genes_df = pd.read_excel(
    get_data_path("hagai_2018", "41586_2018_657_MOESM4_ESM.xlsx"),
    # NOTE(review): "diveregnce" presumably mirrors the workbook's own sheet name; don't "fix" it here.
    sheet_name="phagocytes_FC_diveregnce",
)

# A gene is kept when its `{study_species}_padj` value is strictly below the cutoff;
# NaN values compare False and are therefore excluded.
species_padj_column = f"{study_species}_padj"
is_responsive = phagocyte_genes_df[species_padj_column] < maximum_responsive_gene_padj
responsive_phagocyte_genes = phagocyte_genes_df.loc[is_responsive, "gene"]
print(f"{len(responsive_phagocyte_genes):,} responsive phagocyte genes")

In [None]:
gene_sets = {
    f"{study_species}_responsive_genes": responsive_phagocyte_genes
}

txburst_output_path = get_output_path("txburst")

# Export one CSV per gene set and (replicate, treatment, time_point) condition, with genes
# as rows and cells as columns; these files are the inputs globbed by the txburst cell.
for gene_set_name, genes in gene_sets.items():
    gene_set_output_path = txburst_output_path.joinpath(gene_set_name)
    # NOTE(review): an existing folder is treated as complete, so an export interrupted
    # midway is never regenerated; delete the folder to redo it.
    if gene_set_output_path.exists():
        print("Skipping:", gene_set_output_path.name)
        continue

    create_folder(gene_set_output_path)

    # observed=True: when these obs columns are categorical (string columns usually are after an
    # h5ad round trip), observed=False (the default before pandas 3.0) also yields every unobserved
    # category combination, writing empty CSVs for conditions with no cells. It has no effect on
    # non-categorical columns.
    condition_groups = umi_count_ad.obs.groupby(["replicate", "treatment", "time_point"], observed=True)
    for (replicate, treatment, time_point), df in condition_groups:
        csv_path = gene_set_output_path.joinpath(f"{study_species}_umi_{replicate}_{treatment}_{time_point}.csv")
        print("Writing:", csv_path.name)
        # Cells of this condition x genes of this set; to_df() is cells x genes, so .T puts genes on rows.
        subset_ad = umi_count_ad[df.index.values, genes]
        subset_ad.to_df().T.to_csv(csv_path, index_label="gene")

In [None]:
txburst_script_path = get_scripts_path("txburst")

run_txburst = True

for sub_folder in filter(Path.is_dir, txburst_output_path.iterdir()):
    output_csv_path = sub_folder.with_suffix(".csv")
    if output_csv_path.exists():
        print("Skipping:", sub_folder.name)
        continue

    with working_directory(sub_folder):
        for full_csv_file_path in sub_folder.glob("*.csv"):
            csv_file_path = full_csv_file_path.name
            ml_file_path = Path(full_csv_file_path.stem + "_ML.pkl")
            pl_file_path = Path(full_csv_file_path.stem + "_PL.pkl")

            txburst_ml_script_path = txburst_script_path.joinpath("txburstML.py")
            txburst_pl_script_path = txburst_script_path.joinpath("txburstPL.py")

            commands = []
            if not pl_file_path.exists():
                commands.append(f"{txburst_ml_script_path} --njobs 4 {csv_file_path}")
            if not ml_file_path.exists():
                commands.append(f"{txburst_pl_script_path} --njobs 4 --file {csv_file_path} --MLFile {ml_file_path}")

            for cmd in commands:
                print("Executing:", cmd)
                if run_txburst:
                    %run {cmd}

    txburst_df = []

    for pl_path in sub_folder.glob("*_PL.pkl"):
        replicate, treatment, time_point = pl_path.stem.split("_")[2:5]
        pl_df = pd.read_pickle(pl_path)

        condition_df = pd.DataFrame(data={
            "gene": pl_df.index,
            "replicate": replicate,
            "treatment": treatment,
            "time_point": time_point,
        })
        condition_df[["k_on", "k_off", "k_syn"]] = pd.DataFrame(pl_df.iloc[:, 0].to_list())
        condition_df[["bf_point", "bf_lower", "bf_upper"]] = pd.DataFrame(pl_df.iloc[:, 1].to_list())
        condition_df[["bs_point", "bs_lower", "bs_upper"]] = pd.DataFrame(pl_df.iloc[:, 2].to_list())

        txburst_df.append(condition_df)

    txburst_df = pd.concat(txburst_df, ignore_index=True).sort_values(by=["gene", "replicate", "time_point", "treatment"])
    txburst_df.to_csv(output_csv_path)
    display(symbol_map.added_to(txburst_df))

    totals_df = txburst_df.gene.value_counts().sort_values(ascending=False).to_frame("total")
    totals_df.index.name = "gene"
    display(symbol_map.added_to(totals_df))