In [None]:
import warnings
# Silence the pandas DeprecationWarning "np.find_common_type is deprecated" emitted from
# pandas.core.algorithms: the installed pandas version is pinned because other packages need it.
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pandas.core.algorithms")

In [None]:
import os
import pandas as pd
import logging
import copy
import matplotlib
import anndata
import scanpy as sc
import matplotlib.pyplot as plt

from server.common.colors import CSS4_NAMED_COLORS 
from zero_shot_validation_scripts.utils import TABSAP_WELLSTUDIED_COLORMAPPING, SUFFIX_PREFIX_DICT

from zero_shot_validation_scripts.dataset_preparation import load_and_preprocess_dataset

# scanpy figure parameters (vector_friendly=True, dpi_save=500) are applied in the parameters
# cell, after the project matplotlib style has been loaded. Calling sc.set_figure_params here as
# well was redundant: the later identical call resets every setting this one would make.

In [None]:
#### Parameters ####
# Load the project matplotlib style first, then apply scanpy's figure parameters on top of it.
# vector_friendly=True makes PDFs of scatter plots much smaller while staying high-quality.
matplotlib.style.use(snakemake.input.mpl_style)
sc.set_figure_params(vector_friendly=True, dpi_save=500)

# Output directory and wildcards provided by the Snakemake rule
result_dir = snakemake.output.result_dir
dataset_name = snakemake.wildcards.dataset
metadata_col = snakemake.wildcards.metadata_col

In [None]:
def plot_cellwhisperer_predictions_on_umap(
    adata: anndata.AnnData,
    result_dir: str,
    metadata_col="celltype",
    color_mapping=None,
    background_adata=None,
) -> None:
    """Plot ground-truth labels, CellWhisperer-predicted labels and batches on the UMAP.

    One figure is drawn per color key (``metadata_col``, ``"predicted_labels"``,
    ``"batch"``) and saved as
    ``{result_dir}/UMAP.{metadata_col}_as_label.{color}.png`` and ``.pdf``.

    Args:
        adata: Cells to plot. Must have ``obsm["X_umap_on_neighbors_cellwhisperer"]``
            and the obs columns ``metadata_col``, ``"predicted_labels"`` and ``"batch"``.
            When ``color_mapping`` is None, ``metadata_col`` must be categorical and
            ``adata.uns[f"{metadata_col}_colors"]`` is deleted and regenerated by scanpy.
        result_dir: Directory the figures are written to (must already exist).
        metadata_col: obs column holding the ground-truth labels.
        color_mapping: Optional ``{value: color}`` dict used as the scanpy palette.
            If None, it is built from scanpy's default colors for ``metadata_col``.
            The caller's dict is copied, never modified. Batch colors are added to
            the copy. Only cells whose ``metadata_col`` value is a key of the
            resulting mapping are plotted.
        background_adata: Optional cells plotted first (no color key, so in scanpy's
            default grey) underneath ``adata``.
    """
    embedding_basis = "X_umap_on_neighbors_cellwhisperer"
    colors_key = f"{metadata_col}_colors"

    with warnings.catch_warnings():
        # scanpy/anndata are noisy while plotting; keep the notebook output readable
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        warnings.filterwarnings("ignore", category=FutureWarning)
        warnings.filterwarnings("ignore", category=UserWarning)

        if color_mapping is None:
            if colors_key in adata.uns:
                del adata.uns[colors_key]
            # Plot once without showing it, only so that scanpy assigns
            # adata.uns[f"{metadata_col}_colors"] (and "batch_colors").
            sc.pl.embedding(
                adata,
                basis=embedding_basis,
                color=[metadata_col, "batch"],
                frameon=False,
                s=10,
                alpha=0.5,
                legend_fontsize=8,
                legend_loc="right margin",
                legend_fontoutline=2,
                show=False,
            )
            plt.close()

            color_mapping = dict(
                zip(adata.obs[metadata_col].cat.categories, adata.uns[colors_key])
            )
        else:
            # Work on a copy so the batch colors added below don't leak into the caller's dict.
            color_mapping = dict(color_mapping)

        # scanpy stores "<key>_colors" in the order of the categorical's categories, so the
        # batch values must be paired in that order (not in order of first appearance).
        batch_categories = adata.obs["batch"].astype("category").cat.categories
        batch_colors = adata.uns.get("batch_colors")
        if batch_colors is None:
            # No stored batch colors: fall back to scanpy's large default palette
            batch_colors = sc.pl.palettes.default_102
        color_mapping.update(dict(zip(batch_categories, batch_colors)))

        for color in [metadata_col, "predicted_labels", "batch"]:
            # A fresh figure per plot, so nothing is drawn into a stale open figure
            fig, ax = plt.subplots()
            if background_adata is not None:
                # Plot the background in grey
                sc.pl.embedding(
                    background_adata,
                    basis=embedding_basis,
                    frameon=False,
                    s=10,
                    alpha=0.3,
                    legend_fontsize=6,
                    show=False,
                    ncols=1,
                    ax=ax,
                )

            sc.pl.embedding(
                adata[adata.obs[metadata_col].isin(list(color_mapping.keys()))],
                basis=embedding_basis,
                color=color,
                frameon=False,
                s=10,
                alpha=0.5,
                legend_fontsize=6,
                show=False,
                palette=color_mapping,
                ncols=1,
                ax=ax,
            )
            fig.set_size_inches(10, 5)
            fig.subplots_adjust(right=0.55)  # leave room on the right for the legend
            for suffix in ["png", "pdf"]:
                fig.savefig(
                    f"{result_dir}/UMAP.{metadata_col}_as_label.{color}.{suffix}",
                    dpi=900 if suffix == "png" else None,
                )
            plt.show()
            plt.close(fig)



In [None]:

result_metrics_dict = {}
os.makedirs(result_dir, exist_ok=True)

#### Load data
# Each obsm_paths entry pairs an obsm key with (input file, inner key); how each pair is read
# is defined by load_and_preprocess_dataset.
# TODO: add "X_geneformer" once available; change naming relative to model name?
adata = load_and_preprocess_dataset(
    dataset_name=dataset_name,
    read_count_table_path=snakemake.input.raw_read_count_table,
    obsm_paths={
        "X_umap_on_neighbors_cellwhisperer": (snakemake.input.umap, "neighbors"),
        "X_cellwhisperer": (snakemake.input.processed_dataset, "transcriptome_embeds"),
    },
)

logging.info(f"Data loaded and preprocessed. Shape: {adata.shape}")

In [None]:
# Load predictions
predictions = pd.read_csv(snakemake.input.predictions, index_col=0)

# Drop unlabeled cells: real NaNs as well as the literal string "nan"
adata_no_nans = adata[
    ~(adata.obs[metadata_col].isna()) & ~(adata.obs[metadata_col] == "nan")
].copy()
adata_no_nans.obs = adata_no_nans.obs.join(predictions)

adata_no_nans.obs["batch"] = pd.Categorical(adata_no_nans.obs.batch)
adata_no_nans.obs[metadata_col] = pd.Categorical(adata_no_nans.obs[metadata_col])

# Texts the labels were scored against, optionally wrapped in a label-specific prefix/suffix.
# prefix/suffix default to "" so they are always defined for later cells.
prefix, suffix = "", ""
label_values = adata_no_nans.obs[metadata_col].unique().tolist()
if snakemake.params.use_prefix_suffix_version and metadata_col in SUFFIX_PREFIX_DICT:
    prefix, suffix = SUFFIX_PREFIX_DICT[metadata_col]
    text_list = [f"{prefix}{x}{suffix}" for x in label_values]
else:
    if metadata_col not in SUFFIX_PREFIX_DICT:
        logging.warning(f"Label column {metadata_col} not found in SUFFIX_PREFIX_DICT, continuing without prefix/suffix")
    # Previously text_list was left undefined when the prefix/suffix version was disabled
    text_list = label_values

In [None]:
# Colors for the ground-truth cell types (passed to the plotting function only when
# metadata_col == "celltype"), plus batch colors when available.
if "tabula_sapiens" in dataset_name:
    color_mapping = copy.copy(TABSAP_WELLSTUDIED_COLORMAPPING)
else:
    try:
        color_mapping = dict(
            zip(adata_no_nans.obs["celltype"].cat.categories, adata_no_nans.uns["celltype_colors"])
        )
    except (KeyError, AttributeError):
        # No "celltype" column / no stored colors (KeyError), or "celltype" is not categorical,
        # so the .cat accessor is unavailable (AttributeError, e.g. when metadata_col != "celltype")
        color_mapping = {}
try:
    # TODO batch_colors don't work for tab sap (and others)
    color_mapping.update(dict(zip(adata_no_nans.obs["batch"].cat.categories, adata_no_nans.uns["batch_colors"])))
except KeyError:
    # Best effort: without stored batch colors the plotting function assigns its own
    pass

In [None]:
#### Plot the cellwhisperer predicted labels
# Put them in correct order
if "well_studied_celltypes" in dataset_name and metadata_col == "celltype":
    # Order categories like TABSAP_WELLSTUDIED_COLORMAPPING; values outside it become NaN
    adata_no_nans.obs["celltype"] = pd.Categorical(
        values=adata_no_nans.obs["celltype"],
        categories=list(TABSAP_WELLSTUDIED_COLORMAPPING.keys()),
        ordered=True,
    )
    adata_no_nans.obs["predicted_labels"] = pd.Categorical(
        values=adata_no_nans.obs["predicted_labels"],
        categories=adata_no_nans.obs["celltype"].cat.categories,
    )

if (
    "tabula_sapiens" in dataset_name
    and metadata_col == "celltype"
    and "well_studied_celltypes" not in dataset_name
):
    # Extra predictions for the TabSap dataset: only predict among the well-studied cell types.
    # This allows plotting them on the same UMAP as the full dataset.
    wellstudied_labels = list(TABSAP_WELLSTUDIED_COLORMAPPING.keys())
    is_wellstudied = adata_no_nans.obs["celltype"].isin(wellstudied_labels)
    adata_wellstudied = adata_no_nans[is_wellstudied].copy()

    if snakemake.params.use_prefix_suffix_version and metadata_col in SUFFIX_PREFIX_DICT:
        wellstudied_texts = [f"{prefix}{x}{suffix}" for x in wellstudied_labels]
    else:
        wellstudied_texts = list(wellstudied_labels)
    # Score column -> plain label. An exact lookup instead of chained str.replace, which
    # would also strip prefix/suffix/"score_for_" substrings occurring inside a label.
    label_by_score_col = {
        "score_for_" + text: label for text, label in zip(wellstudied_texts, wellstudied_labels)
    }

    # Well-studied texts that occur in the data, in text_list order (text_list is unique)
    wellstudied_text_set = set(wellstudied_texts)
    textlist_idx_wellstudied = [i for i, x in enumerate(text_list) if x in wellstudied_text_set]
    textlist_wellstudied = ["score_for_" + text_list[i] for i in textlist_idx_wellstudied]
    scores_wellstudied = adata_no_nans.obs.loc[is_wellstudied, textlist_wellstudied]
    # Argmax over the restricted set of score columns, mapped back to the plain label
    predicted_labels_wellstudied = [
        label_by_score_col[textlist_wellstudied[i]]
        for i in scores_wellstudied.to_numpy().argmax(axis=1)
    ]
    adata_wellstudied.obs["predicted_labels"] = predicted_labels_wellstudied
    plot_cellwhisperer_predictions_on_umap(
        adata=adata_wellstudied,
        result_dir=result_dir,
        metadata_col=metadata_col,
        color_mapping=color_mapping,  # metadata_col == "celltype" in this branch
        background_adata=adata_no_nans[~is_wellstudied],
    )
else:
    plot_cellwhisperer_predictions_on_umap(
        adata=adata_no_nans,
        result_dir=result_dir,
        metadata_col=metadata_col,
        color_mapping=color_mapping if metadata_col == "celltype" else None,
    )