In [None]:

import warnings
# Silence the "np.find_common_type is deprecated" DeprecationWarning. It is caused by the
# pandas version in use, which is needed by other packages, so it cannot be fixed here.
# Only DeprecationWarnings whose originating module name matches "pandas.core.algorithms"
# are ignored; all other warnings are still shown.
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pandas.core.algorithms")

In [None]:
import os
import pandas as pd
import logging
import copy
import matplotlib
import seaborn as sns
import matplotlib.pyplot as plt

from zero_shot_validation_scripts.utils import TABSAP_WELLSTUDIED_COLORMAPPING, PANCREAS_ORDER, SUFFIX_PREFIX_DICT

from zero_shot_validation_scripts.dataset_preparation import load_and_preprocess_dataset


In [None]:
#### Parameters ####
# Every value below is read from the `snakemake` object, which is not defined in this
# notebook (presumably injected by Snakemake's notebook integration -- confirm).

# Wildcards: the dataset to analyse and the obs column holding the labels of interest.
dataset_name, metadata_col = (
    snakemake.wildcards.dataset,
    snakemake.wildcards.metadata_col,
)

# Directory into which all figures of this notebook are written.
result_dir = snakemake.output.plot_dir

# Apply the matplotlib style sheet provided as a workflow input.
matplotlib.style.use(snakemake.input.mpl_style)


In [None]:
result_metrics_dict = dict()

# Ensure the output directory exists before anything is written to it.
os.makedirs(str(result_dir), exist_ok=True)

#### Load data
# Each obsm key maps to a (file path, key-within-file) pair that is handed to the loader.
# A geneformer embedding ("X_geneformer") is not wired up yet (TODO).
obsm_sources = {
    "X_cellwhisperer": (snakemake.input.processed_dataset, "transcriptome_embeds"),
}
adata = load_and_preprocess_dataset(
    dataset_name=dataset_name,
    read_count_table_path=snakemake.input.raw_read_count_table,
    obsm_paths=obsm_sources,
)
logging.info(f"Data loaded and preprocessed. Shape: {adata.shape}")

In [None]:
# Read the per-cell prediction table; its first column becomes the index.
predictions = pd.read_csv(snakemake.input.predictions, index_col=0)

# Keep only cells whose label is neither missing nor the literal string "nan",
# then attach the prediction columns by index.
adata_no_nans = adata[
    ~(adata.obs[metadata_col].isna() | (adata.obs[metadata_col] == "nan"))
].copy()
adata_no_nans.obs = adata_no_nans.obs.join(predictions)

# TODO call it adata as well?

In [None]:
# Build one text prompt per unique label, in the order returned by unique().
if metadata_col not in SUFFIX_PREFIX_DICT:
    logging.warning(f"Label column {metadata_col} not found in SUFFIX_PREFIX_DICT, continuing without prefix/suffix")
    text_list = adata_no_nans.obs[metadata_col].unique().tolist()
else:
    # Wrap every label in the prefix/suffix configured for this column.
    prefix, suffix = SUFFIX_PREFIX_DICT[metadata_col]
    text_list = []
    for x in adata_no_nans.obs[metadata_col].unique().tolist():
        text_list.append(f"{prefix}{x}{suffix}")

In [None]:
#### Plot the confidence distributions
# NOTE: from here on `adata` is the NaN-filtered object (an alias of `adata_no_nans`,
# not a copy), so the helper below adds a column to the obs of both names.
adata = adata_no_nans


def _plot_label_score_histogram(obs, label_col, term, matching_label, ax):
    """Draw the score histogram of one term on `ax`, split by label match.

    Parameters
    ----------
    obs : pandas.DataFrame
        Per-cell table containing `label_col` and the column ``score_for_{term}``.
        The boolean column ``label_matches_term`` is (over)written in place.
    label_col : str
        Column of `obs` holding the label of every cell.
    term : str
        Text whose score column (``score_for_{term}``) is plotted.
    matching_label : object
        Label value that counts as a match for `term`.
    ax : matplotlib.axes.Axes
        Axes to draw on.

    Returns
    -------
    pandas.DataFrame
        Independent copy with the columns ``score`` and ``label_matches_term``.
    """
    score_col = f"score_for_{term}"
    obs["label_matches_term"] = obs[label_col] == matching_label
    sns.histplot(
        data=obs,
        x=score_col,
        hue="label_matches_term",
        ax=ax,
        bins=20,
        stat="density",
        common_norm=False,
        palette={True: "coral", False: "silver"},
        legend=False,
    )
    # Every axis gets its own legend naming the label it highlights.
    ax.legend(
        title="Cell type",
        labels=[matching_label, "other"],
        loc="lower right",
        ncol=1,
    )
    hist_df = obs[[score_col, "label_matches_term"]].copy()
    hist_df.columns = ["score", "label_matches_term"]
    return hist_df


hist_dfs_all_terms = {"unnormed": [], "normed": []}
try:  # plotting is best-effort: failures are logged and the notebook continues
    # Computed once instead of on every loop iteration.
    unique_labels = adata.obs[metadata_col].unique().tolist()
    n_labels = len(unique_labels)
    # `text_list` was built from this same unique() order (see the cell that defines it),
    # so zipping pairs every label with the text whose score column belongs to it.
    label_to_term = dict(zip(unique_labels, text_list))

    # One subplot per label; skipped when there are too many labels for a single figure.
    if 0 < n_labels < 1000:
        # squeeze=False keeps `axes` two-dimensional even when there is a single label.
        fig, axes = plt.subplots(
            n_labels,
            1,
            sharex=True,
            sharey=False,
            figsize=(8, 2 * n_labels),
            squeeze=False,
        )
        for ax, matching_label, term in zip(axes[:, 0], unique_labels, text_list):
            hist_df = _plot_label_score_histogram(
                adata.obs, metadata_col, term, matching_label, ax
            )
            hist_dfs_all_terms["unnormed"].append(hist_df)

            # z-normalize against the cells whose label does NOT match the term
            reference_scores = hist_df.loc[~hist_df["label_matches_term"], "score"]
            hist_score_normed = hist_df.copy()
            hist_score_normed["score"] = (
                hist_score_normed["score"] - reference_scores.mean()
            ) / reference_scores.std()
            hist_dfs_all_terms["normed"].append(hist_score_normed)

        axes[-1, 0].set_xlabel("Cellwhisperer score for the label")
        fig.savefig(
            f"{result_dir}/confidence_distribution_{metadata_col}_per_label.pdf"
        )
        plt.show()
        plt.close(fig)

    # Pooled distribution over all labels, raw and z-normalized.
    for norm in ["unnormed", "normed"]:
        if not hist_dfs_all_terms[norm]:
            # Nothing was collected (per-label section skipped); pd.concat would raise.
            logging.warning(
                f"No per-label scores collected, skipping the '{norm}' all-labels plot"
            )
            continue
        # The cell index repeats once per term and is irrelevant for the histogram.
        hist_df_all_terms = pd.concat(hist_dfs_all_terms[norm], ignore_index=True)
        fig, ax = plt.subplots()
        sns.histplot(
            data=hist_df_all_terms,
            x="score",
            hue="label_matches_term",
            ax=ax,
            bins=20,
            stat="density",
            common_norm=False,
            palette={True: "coral", False: "silver"},
            legend=True,
        )
        ax.set_xlabel(
            f"{'Normalized c' if norm=='normed' else 'C'}ellwhisperer score for the label"
        )
        ax.set_ylabel("Density")
        ax.get_legend().set_title("Cell type equals label")
        fig.savefig(
            f"{result_dir}/confidence_distribution_{metadata_col}_all_labels.{norm}.pdf"
        )
        plt.show()
        plt.close(fig)

    # Some specific examples
    if "tabula_sapiens" in dataset_name:
        selected_labels = [
            "cardiac muscle cell",
            "alveolar fibroblast",
            "thymocyte",
            "erythrocyte",
        ]
        # Only labels present in this label column can be plotted.
        available_labels = [
            label for label in selected_labels if label in label_to_term
        ]
        if len(available_labels) < len(selected_labels):
            logging.warning(
                f"Selected labels not found in '{metadata_col}' and skipped: "
                f"{sorted(set(selected_labels) - set(available_labels))}"
            )
        if available_labels:
            # The number of axes follows the number of plotted labels (it used to be
            # hard-coded to 3 for 4 terms, which raised an IndexError).
            fig, axes = plt.subplots(
                len(available_labels),
                1,
                sharex=True,
                sharey=False,
                figsize=(8, 2 * len(available_labels)),
                squeeze=False,
            )
            for ax, matching_label in zip(axes[:, 0], available_labels):
                # Highlight the cells of the selected label itself and use the same
                # score column the per-label plot uses for it.
                _plot_label_score_histogram(
                    adata.obs,
                    metadata_col,
                    label_to_term[matching_label],
                    matching_label,
                    ax,
                )
            axes[-1, 0].set_xlabel("Cellwhisperer score for the label")
            fig.savefig(
                f"{result_dir}/confidence_distribution_{metadata_col}_per_label.SELECTED_TERMS.pdf"
            )
            plt.show()
            plt.close(fig)

except Exception as e:
    # Keep going, but record the full traceback and do not leak half-drawn figures.
    logging.exception(
        f"Got the following error during plotting of confidence distributions (continueing): {e}"
    )
    plt.close("all")


In [None]:
# The 'confidence' (magnitude of the score) is not a good confidence measure, so it is
# no longer stored. These plots are therefore only produced when the column still exists.
if "confidence_cellwhisperer" in adata.obs.columns:
    # Distribution of the confidence scores, separately for correct and incorrect
    # predictions: once as a kernel density estimate, once as a histogram.
    confidence_plots = (
        (sns.kdeplot, f"{result_dir}//confidence_distribution_{metadata_col}.pdf"),
        (sns.histplot, f"{result_dir}//confidence_distribution_{metadata_col}_hist.pdf"),
    )
    for draw_distribution, output_path in confidence_plots:
        draw_distribution(
            data=adata.obs,
            x="confidence_cellwhisperer",
            hue="correct_prediction",
            common_norm=False,
        )
        plt.savefig(output_path)
        plt.close()
