In [None]:
import warnings
# avoid DeprecationWarning: np.find_common_type is deprecated due to pandas version (needed by other packages)
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pandas.core.algorithms")

In [None]:
import os
from pathlib import Path
import pandas as pd
import logging
import numpy as np
from collections import defaultdict
import copy
import matplotlib
import torch

from cellwhisperer.config import get_path
from cellwhisperer.utils.inference import score_transcriptomes_vs_texts
from cellwhisperer.validation.integration.functions import eval_scib_metrics
from cellwhisperer.validation.zero_shot.single_cell_annotation import (
    get_performance_metrics_transcriptome_vs_text
)
from cellwhisperer.utils.model_io import load_cellwhisperer_model
from server.common.colors import CSS4_NAMED_COLORS 
from zero_shot_validation_scripts.utils import umap_on_embedding, prepare_integration_df,TABSAP_WELLSTUDIED_COLORMAPPING, PANCREAS_ORDER, SUFFIX_PREFIX_DICT

from zero_shot_validation_scripts.dataset_preparation import load_and_preprocess_dataset
from zero_shot_validation_scripts.plotting import (
    plot_embeddings_with_scores,
    plot_cellwhisperer_predictions_on_umap,
    plot_confusion_matrix,
    plot_term_search_result,
    plot_confidence_distributions,
    plot_integration_metrics,
)



In [None]:
#### Parameters ####
# Every run-specific value below is injected by Snakemake via the `snakemake` object.

matplotlib.style.use(snakemake.input.mpl_style)

# Input checkpoint and output directory
ckpt_file_path = snakemake.input.model
result_dir = snakemake.output.output_directory

# Dataset to evaluate (taken from the rule wildcard)
dataset_name = snakemake.wildcards.dataset

## Embedding types ("analysis types") to compare: CellWhisperer vs. the base transcriptome model
analysis_types = ["cellwhisperer", snakemake.params.transcriptome_model_name]

# Metadata columns used as ground-truth labels for zero-shot prediction
label_cols = snakemake.params.metadata_cols_per_dataset[dataset_name]

# If True, labels are wrapped in the prompt prefix/suffix from SUFFIX_PREFIX_DICT
use_prefix_suffix_version = True


In [None]:
#### Load model
# The loader returns the Lightning module plus the text and transcriptome processors;
# eval=True is passed through to load the model for inference.
loaded_cellwhisperer = load_cellwhisperer_model(model_path=ckpt_file_path, eval=True)
pl_model_cellwhisperer = loaded_cellwhisperer[0]
text_processor_cellwhisperer = loaded_cellwhisperer[1]
cellwhisperer_transcriptome_processor = loaded_cellwhisperer[2]
# The wrapped model is what all scoring calls below use
cellwhisperer_model = pl_model_cellwhisperer.model

# Integration metrics keyed by (dataset_name, analysis_type); filled in the next cell
result_metrics_dict = {}
Path(f"{result_dir}").mkdir(parents=True, exist_ok=True)

#### Load data
adata = load_and_preprocess_dataset(
    dataset_name=dataset_name,
    read_count_table_path=snakemake.input.raw_read_count_table,
    processed_data_path=snakemake.input.processed_dataset,
    transcriptome_model_name=snakemake.params.transcriptome_model_name,
)
logging.info(f"Data loaded and preprocessed. Shape: {adata.shape}")

In [None]:
#### Create embeddings and calculate metrics
# For each embedding type: build a neighbour graph + UMAP on adata.obsm[f"X_{analysis_type}"]
# and compute scIB integration metrics (cell type as label, "batch" as batch key).
for analysis_type in analysis_types:
    logging.info(f"Starting with {analysis_type}")

    # Calculate UMAPs based on the embeddings
    adata = umap_on_embedding(
        adata,
        embedding_key=f"X_{analysis_type}",
        neighbors_key=f"neighbors_{analysis_type}",
        umap_key=f"X_umap_on_neighbors_{analysis_type}",
    )

    # Calculate integration metrics
    result_metrics_dict[(dataset_name, analysis_type)] = eval_scib_metrics(
        adata,
        label_key="celltype",
        batch_key="batch",
        embedding_key=f"X_{analysis_type}",
    )

    logging.info(f"Finished with {analysis_type}")

# One CSS4 named colour per cell type. `i % len(...)` cycles through the palette for any
# number of cell types (a single subtraction would only wrap once and then IndexError).
# The colour list is built once instead of once per cell type.
css4_colors = list(CSS4_NAMED_COLORS.values())
celltype_palette = {
    celltype: css4_colors[i % len(css4_colors)]
    for i, celltype in enumerate(adata.obs.celltype.unique())
}
if "tabula_sapiens" in dataset_name:
    # update the celltype palette with the well-studied cell types (fixed curated colours)
    celltype_palette.update(TABSAP_WELLSTUDIED_COLORMAPPING)

# Plot the embeddings generated by the different methods, colored by celltype and batch
plot_embeddings_with_scores(
    adata=adata,
    analysis_types=analysis_types,
    result_metrics_dict=result_metrics_dict,
    dataset_name=dataset_name,
    result_dir=result_dir,
    celltype_plot_palette=celltype_palette,
)

In [None]:
# Integration metrics are only saved/plotted when the data contain more than one batch
if adata.obs.batch.nunique() > 1:
    #### Plot and Save integration metrics
    integration_scores_df = prepare_integration_df(result_metrics_dict)
    integration_scores_df.to_csv(f"{result_dir}/metrics_MS_zero_shot.csv")
    plot_integration_metrics(integration_scores_df, result_dir)

# Label -> colour mapping for the prediction UMAPs. Tabula Sapiens uses the curated
# well-studied palette; otherwise reuse the colours stored in adata.uns
# (NOTE(review): presumably written there by the embedding plots above — confirm).
color_mapping = (
    copy.copy(TABSAP_WELLSTUDIED_COLORMAPPING)
    if "tabula_sapiens" in dataset_name
    else dict(zip(adata.obs["celltype"].cat.categories, adata.uns["celltype_colors"]))
)
# Batch colours are always added on top
color_mapping.update(zip(adata.obs["batch"].cat.categories, adata.uns["batch_colors"]))

In [None]:
#### Predict the labels using cellwhisperer
# For each label column: score every cell against one text per unique label value,
# take the best-scoring text as the zero-shot prediction, plot predictions and
# confidences, and compute classification metrics / a confusion matrix.


def _strip_affixes(text, prefix, suffix):
    """Remove `prefix` from the start and `suffix` from the end of `text` (non-strings are returned unchanged)."""
    if not isinstance(text, str):
        return text
    if prefix and text.startswith(prefix):
        text = text[len(prefix):]
    if suffix and text.endswith(suffix):
        text = text[: len(text) - len(suffix)]
    return text


for label_col in label_cols:
    # Drop cells whose label is missing (real NaN or the literal string "nan")
    adata_no_nans = adata[
        ~(adata.obs[label_col].isna()) & ~(adata.obs[label_col] == "nan")
    ].copy()

    # Unique labels in order of first appearance; text i in `text_list` belongs to label i.
    # Computed once here instead of once per cell inside the comprehensions below.
    unique_labels = adata_no_nans.obs[label_col].unique().tolist()

    # `prefix`/`suffix`/`text_list` are assigned on every path (previously, with
    # use_prefix_suffix_version=False and label_col in SUFFIX_PREFIX_DICT, none of them
    # was assigned and a stale value from the previous iteration could be reused).
    if use_prefix_suffix_version and label_col in SUFFIX_PREFIX_DICT:
        prefix, suffix = SUFFIX_PREFIX_DICT[label_col]
        text_list = [f"{prefix}{x}{suffix}" for x in unique_labels]
    else:
        if label_col not in SUFFIX_PREFIX_DICT:
            logging.warning(f"Label column {label_col} not found in SUFFIX_PREFIX_DICT, continuing without prefix/suffix")
        prefix, suffix = "", ""
        text_list = list(unique_labels)

    # Ground-truth text index per cell, computed before the label column may be
    # re-categorised below, so it always matches the order of `text_list`.
    label_to_text_idx = {label: idx for idx, label in enumerate(unique_labels)}
    correct_text_idx_per_transcriptome = [
        label_to_text_idx[x] for x in adata_no_nans.obs[label_col].values
    ]

    # Same embedding tensor is used for scoring and for the performance metrics
    transcriptome_input = torch.tensor(
        adata_no_nans.obsm["X_cellwhisperer"], device=cellwhisperer_model.device
    )
    scores, _ = score_transcriptomes_vs_texts(
        model=cellwhisperer_model,
        logit_scale=cellwhisperer_model.discriminator.temperature.exp(),
        transcriptome_input=transcriptome_input,
        text_list_or_text_embeds=text_list,
        average_mode=None,
        grouping_keys=None,
        transcriptome_processor=cellwhisperer_transcriptome_processor,
        batch_size=32,
        score_norm_method=None,
    )
    scores = scores.T  # n_cells * n_text
    predicted_labels = [unique_labels[idx] for idx in scores.argmax(axis=1).tolist()]
    for text_idx, term in enumerate(text_list):
        adata_no_nans.obs[f"score_for_{term}"] = scores[:, text_idx].tolist()
    adata_no_nans.obs["predicted_labels_cellwhisperer"] = predicted_labels
    adata_no_nans.obs["confidence_cellwhisperer"] = scores.max(axis=1).values.tolist()
    adata_no_nans.obs["correct_prediction"] = adata_no_nans.obs["predicted_labels_cellwhisperer"] == adata_no_nans.obs[label_col]

    #### Plot the confidence distributions
    plot_confidence_distributions(adata_no_nans, result_dir, dataset_name, text_list,
                                  label_col=label_col)

    #### Plot the cellwhisperer predicted labels
    # Put them in correct order
    if "well_studied_celltypes" in dataset_name and label_col == "celltype":
        adata_no_nans.obs["celltype"] = pd.Categorical(
            values=adata_no_nans.obs["celltype"],
            categories=list(TABSAP_WELLSTUDIED_COLORMAPPING.keys()),
            ordered=True)
        adata_no_nans.obs["predicted_labels_cellwhisperer"] = pd.Categorical(
            values=adata_no_nans.obs["predicted_labels_cellwhisperer"],
            categories=adata_no_nans.obs["celltype"].cat.categories)

    if (
        "tabula_sapiens" in dataset_name
        and label_col == "celltype"
        and "well_studied_celltypes" not in dataset_name
    ):
        # extra predictions for the TabSap dataset: Only predict for the well-studied cell types
        # This allows plotting them on the same UMAP as the full dataset
        wellstudied_mask = adata_no_nans.obs["celltype"].isin(
            TABSAP_WELLSTUDIED_COLORMAPPING.keys()
        ).to_numpy()
        adata_wellstudied = adata_no_nans[wellstudied_mask].copy()
        # Candidate texts restricted to well-studied labels (text i <-> unique_labels[i])
        wellstudied_text_idx = [
            idx for idx, label in enumerate(unique_labels)
            if label in TABSAP_WELLSTUDIED_COLORMAPPING
        ]
        scores_wellstudied = scores[wellstudied_mask, :][:, wellstudied_text_idx]
        # Map back to the raw label directly instead of string-stripping prefix/suffix
        adata_wellstudied.obs["predicted_labels_cellwhisperer"] = [
            unique_labels[wellstudied_text_idx[idx]]
            for idx in scores_wellstudied.argmax(axis=1).tolist()
        ]
        plot_cellwhisperer_predictions_on_umap(
            adata=adata_wellstudied,
            result_dir=result_dir,
            label_col=label_col,
            color_mapping=color_mapping,  # label_col is always "celltype" in this branch
            background_adata=adata_no_nans[~wellstudied_mask],
        )
    else:
        plot_cellwhisperer_predictions_on_umap(
            adata=adata_no_nans,
            result_dir=result_dir,
            label_col=label_col,
            color_mapping=color_mapping if label_col == "celltype" else None,
        )

    #### Get classification performance metrics for cellwhisperer
    (
        performance_metrics,
        performance_metrics_per_label_df,
    ) = get_performance_metrics_transcriptome_vs_text(
        model=cellwhisperer_model,
        transcriptome_input=transcriptome_input,
        text_list_or_text_embeds=text_list,
        average_mode=None,
        grouping_keys=None,
        transcriptome_processor=cellwhisperer_transcriptome_processor,
        batch_size=32,
        score_norm_method=None,
        correct_text_idx_per_transcriptome=correct_text_idx_per_transcriptome,
    )
    # NOTE(review): "macrovag" looks like a typo of "macroavg"; kept because downstream
    # consumers of this file name may depend on it.
    pd.Series(performance_metrics).to_csv(
        f"{result_dir}/performance_metrics_cellwhisperer.{label_col}_as_label.macrovag.csv"
    )
    performance_metrics_per_label_df.to_csv(
        f"{result_dir}/performance_metrics_cellwhisperer.{label_col}_as_label.per_{label_col}.csv"
    )

    ## Plot the confusion matrix
    if dataset_name == "pancreas":
        order = PANCREAS_ORDER
    elif "well_studied_celltypes" in dataset_name and label_col == "celltype":
        order = list(TABSAP_WELLSTUDIED_COLORMAPPING.keys())
    else:
        order = None

    # Show raw labels (without prompt prefix/suffix) on the confusion-matrix axes
    performance_metrics_per_label_df_wo_prefix_suffix = performance_metrics_per_label_df.copy()
    if prefix or suffix:
        performance_metrics_per_label_df_wo_prefix_suffix.index = [
            _strip_affixes(x, prefix, suffix)
            for x in performance_metrics_per_label_df.index.values
        ]
        performance_metrics_per_label_df_wo_prefix_suffix.columns = [
            _strip_affixes(x, prefix, suffix)
            for x in performance_metrics_per_label_df.columns.values
        ]
    try:
        title = f"$\\text{{ROC-AUC}}_{{macro}}={round(float(performance_metrics['rocauc_macroAvg']),2)}$"
        plot_confusion_matrix(
            performance_metrics_per_label_df=performance_metrics_per_label_df_wo_prefix_suffix,
            result_dir=result_dir,
            label_col=label_col,
            order=order,
            title=title
        )
    except ValueError as e:
        print(f"Got the following error during plotting of confusion matrix (continueing): {e}")

    del adata_no_nans

In [None]:
## Term search in tabula sapiens
# Score every cell against free-text queries (wrapped in the celltype prompt template)
# and compare with a binary flag marking cells whose annotation contains a cell-type string.
if dataset_name == "tabula_sapiens":

    prefix, suffix = SUFFIX_PREFIX_DICT["celltype"]

    # query term -> substring searched for in the annotated cell type
    terms_celltype_dict = {
        "red blood cells": "erythrocyte",
        "erythrocyte": "erythrocyte",
        "Natural killer cells": "nk cell",
        "T cells": "t cell",
        "B cells": "b cell",
        "blood platelets": "platelet",
    }

    # The embeddings are identical for every query; build the tensor once
    transcriptome_input = torch.tensor(
        adata.obsm["X_cellwhisperer"], device=cellwhisperer_model.device
    )

    for term, celltype in terms_celltype_dict.items():
        scores, _ = score_transcriptomes_vs_texts(
            model=cellwhisperer_model,
            logit_scale=cellwhisperer_model.discriminator.temperature.exp(),
            transcriptome_input=transcriptome_input,
            text_list_or_text_embeds=[f"{prefix}{term}{suffix}"],
            average_mode=None,
            grouping_keys=None,
            transcriptome_processor=cellwhisperer_transcriptome_processor,
            batch_size=32,
            score_norm_method=None,
        )
        scores = scores.T  # n_cells * n_text (a single text here)
        adata.obs[f"score_for_{term}"] = scores[:, 0].tolist()
        # Literal (non-regex) substring match; unannotated (NaN) cells count as 0
        # instead of making astype(int) fail.
        adata.obs[f"label contains '{celltype}'"] = (
            adata.obs["celltype"].str.contains(celltype, regex=False, na=False).astype(int)
        )

        plot_term_search_result(term, celltype, adata, result_dir, prefix, suffix)