In [None]:
from pathlib import Path

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from IPython.display import display

from rp2 import create_folder, working_directory, create_gene_symbol_map, hagai_2018
from rp2.paths import get_output_path, get_scripts_path

In [None]:
# Species to analyse; the name is also embedded in the gene-set and output
# file names built further down.
study_species = "mouse"

# UMI counts for the chosen species. Later cells index this object by
# (observations, genes) and use its `.obs` table and `.to_df()`, so it is
# presumably an AnnData object — confirm against `hagai_2018.load_umi_count`.
umi_count_ad = hagai_2018.load_umi_count(study_species)

In [None]:
# Gene symbol lookup for the study species; its `added_to` method is applied to
# result tables before they are displayed (presumably adding a symbol column —
# confirm against `create_gene_symbol_map`).
symbol_map = create_gene_symbol_map(study_species)

In [None]:
# Genes the UMI counts are restricted to before txburst fitting (see `gene_sets`).
# NOTE(review): loaded without passing `study_species` — presumably the gene
# identifiers match the study species; confirm.
responsive_phagocyte_genes = hagai_2018.load_lps_responsive_genes()
print(f"{len(responsive_phagocyte_genes):,} responsive phagocyte genes")

In [None]:
def write_umi_subsets(output_path, obs_groups, genes=None):
    """Write one UMI count CSV per combination of observation groups.

    The observations of the notebook-level ``umi_count_ad`` are grouped by
    ``obs_groups``; each group is written to
    ``<study_species>_umi_<group values joined by "_">.csv`` with genes as rows
    (the table is transposed before writing). The grouping column names are
    stored tab-separated in ``groups.txt`` so the results can be collated later.

    Parameters
    ----------
    output_path : pathlib.Path
        Folder to write into. It is (re)created clean, so existing content is
        discarded.
    obs_groups : list of str
        Columns of ``umi_count_ad.obs`` to group the observations by.
    genes : sequence of str, optional
        Genes (variables) to keep. ``None`` keeps every gene. Previously this
        was read from a notebook-level loop variable named ``genes``; it is now
        an explicit argument so the function no longer depends on hidden state.
    """
    create_folder(output_path, create_clean=True)

    # NOTE(review): assumes `umi_count_ad` accepts a full slice on the gene axis
    # (AnnData does) — confirm if a different container is used.
    gene_selector = slice(None) if genes is None else genes

    for group_values, df in umi_count_ad.obs.groupby(obs_groups):
        # Depending on the pandas version a single grouping column may yield a
        # scalar key; joining a bare string would split it into characters.
        if not isinstance(group_values, tuple):
            group_values = (group_values,)
        # Group values are not guaranteed to be strings (e.g. integer replicates).
        suffix = "_".join(map(str, group_values))
        csv_path = output_path.joinpath(f"{study_species}_umi_{suffix}.csv")
        print("Writing:", csv_path.name)
        subset_ad = umi_count_ad[df.index.values, gene_selector]
        subset_ad.to_df().T.to_csv(csv_path, index_label="gene")

    output_path.joinpath("groups.txt").write_text("\t".join(obs_groups))


# Named gene sets to export; each name becomes an output sub-folder.
gene_sets = {
    f"{study_species}_responsive_genes": responsive_phagocyte_genes,
}

txburst_output_path = get_output_path("txburst")
create_folder(txburst_output_path)

for gene_set_name, genes in gene_sets.items():
    # Each gene set is exported twice: split per replicate, and with the
    # replicates pooled (grouped by treatment and time point only).
    all_sets = (
        (gene_set_name, ["replicate", "treatment", "time_point"]),
        (gene_set_name + "_combined_replicates", ["treatment", "time_point"]),
    )

    for set_name, umi_groups in all_sets:
        gene_set_output_path = txburst_output_path.joinpath(set_name)
        # An existing folder is treated as already exported and left untouched.
        if gene_set_output_path.exists():
            print("Skipping:", set_name)
            continue

        write_umi_subsets(gene_set_output_path, umi_groups, genes=genes)

In [None]:
def run_txburst_fitting(umi_files_path, execute_commands=True):
    txburst_script_path = get_scripts_path("txburst")

    with working_directory(umi_files_path):
        for full_csv_file_path in umi_files_path.glob("*.csv"):
            csv_file_path = full_csv_file_path.name
            ml_file_path = Path(full_csv_file_path.stem + "_ML.pkl")
            pl_file_path = Path(full_csv_file_path.stem + "_PL.pkl")

            txburst_ml_script_path = txburst_script_path.joinpath("txburstML.py")
            txburst_pl_script_path = txburst_script_path.joinpath("txburstPL.py")

            commands = []
            if not pl_file_path.exists():
                commands.append(f"{txburst_ml_script_path} --njobs 4 {csv_file_path}")
            if not ml_file_path.exists():
                commands.append(f"{txburst_pl_script_path} --njobs 4 --file {csv_file_path} --MLFile {ml_file_path}")

            for cmd in commands:
                print("Executing:", cmd)
                if execute_commands:
                    %run {cmd}


def load_txburst_pkl_files(pkl_paths, condition_names):
    """Load several txburst result pickles into one long data frame.

    The condition values of each file are recovered from its name: the last
    three characters of the stem (the ``_ML`` / ``_PL`` marker) are dropped and
    the trailing ``len(condition_names)`` underscore-separated tokens are taken
    as the values, in the order given by ``condition_names``.

    Parameters
    ----------
    pkl_paths : iterable of pathlib.Path
        Pickled data frames whose index is named ``gene``.
    condition_names : list of str
        Names of the condition columns to add to every row.

    Returns
    -------
    pandas.DataFrame
        All files stacked with a fresh integer index; ``gene`` becomes a
        regular column and each condition is a constant column per file.
    """
    n_conditions = len(condition_names)

    def load_single(path):
        # Condition values are the trailing tokens of the marker-free stem.
        tokens = path.stem[:-3].split("_")[-n_conditions:]
        frame = pd.read_pickle(path).reset_index("gene")
        for name, value in zip(condition_names, tokens):
            frame[name] = value
        return frame

    frames = [load_single(path) for path in pkl_paths]
    return pd.concat(frames, ignore_index=True)


def collate_txburst_results(pkl_files_path):
    """Combine the ML and PL txburst results of one folder into a single table.

    The condition column names are read from ``groups.txt`` in the folder.
    Sequence-valued result columns are spread over named scalar columns:
    column 0 of both result kinds becomes ``k_on`` / ``k_off`` / ``k_syn``,
    column 1 of the ML results becomes ``keep``, and columns 1 and 2 of the PL
    results become the ``bf_*`` and ``bs_*`` point/lower/upper columns
    (presumably burst frequency and burst size — confirm against txburst).

    The PL table is outer-joined with the ML ``keep`` flag on gene plus
    conditions, and non-missing ML values then overwrite the matching columns.

    Parameters
    ----------
    pkl_files_path : pathlib.Path
        Folder containing ``groups.txt`` and the ``*_ML.pkl`` / ``*_PL.pkl`` files.

    Returns
    -------
    pandas.DataFrame
        One row per gene and condition, sorted by whichever of ``gene``,
        ``replicate``, ``time_point`` and ``treatment`` are present.
    """
    condition_groups = pkl_files_path.joinpath("groups.txt").read_text().split("\t")

    def unpack(frame, source_column, target_columns):
        # Each cell of `source_column` holds a sequence; spread it over new columns.
        frame[target_columns] = pd.DataFrame(frame.loc[:, source_column].to_list())

    ml_df = load_txburst_pkl_files(pkl_files_path.glob("*_ML.pkl"), condition_groups)
    pl_df = load_txburst_pkl_files(pkl_files_path.glob("*_PL.pkl"), condition_groups)

    kinetic_columns = ["k_on", "k_off", "k_syn"]
    unpack(ml_df, 0, kinetic_columns)
    unpack(pl_df, 0, kinetic_columns)

    ml_df = ml_df.rename(columns={1: "keep"}).drop(columns=[0])

    unpack(pl_df, 1, ["bf_point", "bf_lower", "bf_upper"])
    unpack(pl_df, 2, ["bs_point", "bs_lower", "bs_upper"])
    pl_df = pl_df.drop(columns=[0, 1, 2])

    key_columns = ["gene", *condition_groups]
    ml_indexed = ml_df.set_index(key_columns)
    pl_indexed = pl_df.set_index(key_columns)

    # The outer join keeps rows that only have an ML result; update() then
    # fills/overwrites the shared columns with the non-missing ML values.
    merged = pl_indexed.join(ml_indexed.keep, how="outer")
    merged.update(ml_indexed)
    merged = merged.reset_index()

    preferred_order = ("gene", "replicate", "time_point", "treatment")
    sort_columns = [name for name in preferred_order if name in merged.columns]
    return merged.sort_values(by=sort_columns)


# Fit (or reload) the txburst results of every UMI sub-folder and summarise them.
# iterdir() yields entries in arbitrary order; sorting keeps the order of the
# displayed tables and figures reproducible between runs.
for sub_folder in sorted(filter(Path.is_dir, txburst_output_path.iterdir())):
    # The collated table is cached next to the folder as "<folder name>.csv".
    collated_csv_path = sub_folder.with_suffix(".csv")
    if collated_csv_path.exists():
        print("Skipping:", sub_folder.name)
        txburst_df = pd.read_csv(collated_csv_path)
    else:
        run_txburst_fitting(sub_folder)
        txburst_df = collate_txburst_results(sub_folder)
        txburst_df.to_csv(collated_csv_path, index=False)

    # Only rows flagged `keep` are shown and counted.
    txburst_df = txburst_df.loc[txburst_df.keep]

    display(symbol_map.added_to(txburst_df))

    # Number of kept conditions per gene; value_counts() already returns the
    # counts in descending order, so no further sorting is needed.
    totals_df = txburst_df.gene.value_counts().to_frame("n_conditions")
    totals_df.index.name = "gene"
    display(symbol_map.added_to(totals_df))

    # One figure per sub-folder, titled so the plots can be told apart.
    fig, ax = plt.subplots()
    sns.countplot(
        x="n_conditions",
        data=totals_df,
        ax=ax,
    )
    ax.set_title(sub_folder.name)
    ax.set_xlabel("n_conditions (kept conditions per gene)")
    ax.set_ylabel("number of genes")
    plt.show()