In [None]:
"""Plot a heatmap of the confusion matrix in 2 versions: normalized and not normalized."""

In [None]:
import warnings

# The pandas version required by other packages still triggers
# "DeprecationWarning: np.find_common_type is deprecated"; silence it.
warnings.filterwarnings(
    action="ignore",
    category=DeprecationWarning,
    module="pandas.core.algorithms",
)

In [None]:
import pandas as pd
import logging
import numpy as np
import copy
import matplotlib
import torch
import seaborn as sns
import scanpy as sc
import matplotlib.pyplot as plt
from cellwhisperer.validation.zero_shot.single_cell_annotation import (
    get_performance_metrics_transcriptome_vs_text
)
from cellwhisperer.utils.model_io import load_cellwhisperer_model
from zero_shot_validation_scripts.utils import TABSAP_WELLSTUDIED_COLORMAPPING, PANCREAS_ORDER, SUFFIX_PREFIX_DICT

from zero_shot_validation_scripts.dataset_preparation import load_and_preprocess_dataset

# Makes PDFs of scatter plots much smaller in size but still high-quality.
sc.set_figure_params(dpi_save=500, vector_friendly=True)


In [None]:
#### Parameters ####

# Wildcards supplied by the Snakemake rule: which dataset to evaluate and which
# obs column holds the ground-truth labels.
dataset_name, metadata_col = (
    snakemake.wildcards.dataset,
    snakemake.wildcards.metadata_col,
)

# Apply the matplotlib style sheet provided as a rule input.
matplotlib.style.use(snakemake.input.mpl_style)

In [None]:
# Load the CellWhisperer model together with its text and transcriptome processors.
_loaded_model_parts = load_cellwhisperer_model(
    model_path=snakemake.input.model, eval=True
)
pl_model_cellwhisperer = _loaded_model_parts[0]
text_processor_cellwhisperer = _loaded_model_parts[1]
cellwhisperer_transcriptome_processor = _loaded_model_parts[2]

# The metrics helper is given the wrapped `.model` attribute rather than the outer object.
cellwhisperer_model = pl_model_cellwhisperer.model

In [None]:
result_metrics_dict = dict()

#### Load data
# Both embeddings are read from the same processed dataset file; each entry maps
# the target obsm key to a (file path, key) pair.
_processed_dataset_path = snakemake.input.processed_dataset
_obsm_paths = {
    "X_cellwhisperer": (_processed_dataset_path, "transcriptome_embeds"),
    "X_geneformer": (_processed_dataset_path, "transcriptome_features"),
}
adata = load_and_preprocess_dataset(
    dataset_name=dataset_name,
    read_count_table_path=snakemake.input.raw_read_count_table,
    obsm_paths=_obsm_paths,
)
logging.info(f"Data loaded and preprocessed. Shape: {adata.shape}")

In [None]:
# Load predictions
# predictions = pd.read_csv(snakemake.input.predictions, index_col=0) # PP: This does not seem necessary

# Keep only cells with a usable label: drop missing values as well as the
# literal string "nan".
adata_no_nans = adata[
    ~(adata.obs[metadata_col].isna()) & ~(adata.obs[metadata_col] == "nan")
].copy()
# adata_no_nans.obs = adata_no_nans.obs.join(predictions) # PP: This does not seem necessary

# One text query per distinct label, in order of first appearance.
unique_labels = adata_no_nans.obs[metadata_col].unique().tolist()

if snakemake.params.use_prefix_suffix_version and metadata_col in SUFFIX_PREFIX_DICT:
    # Wrap every label in the prefix/suffix configured for this label column.
    prefix, suffix = SUFFIX_PREFIX_DICT[metadata_col]
    text_list = [f"{prefix}{x}{suffix}" for x in unique_labels]
else:
    # Covers both remaining cases so that `text_list` is always defined:
    #  * the prefix/suffix version is disabled (previously left `text_list`
    #    undefined when the column *was* present in SUFFIX_PREFIX_DICT), and
    #  * the column has no entry in SUFFIX_PREFIX_DICT (warn, as before).
    if metadata_col not in SUFFIX_PREFIX_DICT:
        logging.warning(f"Label column {metadata_col} not found in SUFFIX_PREFIX_DICT, continuing without prefix/suffix")
    text_list = unique_labels

In [None]:
#### Get classification performance metrics for cellwhisperer

# Map each distinct label to its position in the unique-label list. `text_list`
# is built from the same `unique()` call, so these positions are also indices
# into `text_list`. Building the mapping once avoids recomputing
# `unique().tolist()` and doing a linear `.index()` search for every cell.
label_to_text_idx = {
    label: idx
    for idx, label in enumerate(adata_no_nans.obs[metadata_col].unique().tolist())
}
correct_text_idx_per_transcriptome = [
    label_to_text_idx[x] for x in adata_no_nans.obs[metadata_col].values
]

# NOTE(review): the shuffled copy is not used anywhere in the visible notebook.
# It is kept because removing the shuffle would change NumPy's global RNG
# state — confirm nothing downstream relies on it, then delete.
shuffled_text_idx_per_transcriptome = copy.copy(correct_text_idx_per_transcriptome)
np.random.shuffle(shuffled_text_idx_per_transcriptome)

# Score the precomputed transcriptome embeddings against the label texts.
(
    performance_metrics,
    performance_metrics_per_label_df,
) = get_performance_metrics_transcriptome_vs_text(
    model=cellwhisperer_model,
    transcriptome_input=torch.tensor(adata_no_nans.obsm["X_cellwhisperer"], device=cellwhisperer_model.device),
    text_list_or_text_embeds=text_list,
    average_mode=None,
    grouping_keys=None,
    transcriptome_processor=cellwhisperer_transcriptome_processor,
    batch_size=32,
    score_norm_method=None,
    correct_text_idx_per_transcriptome=correct_text_idx_per_transcriptome,
)

# Persist the overall metrics and the per-label table as CSV.
pd.Series(performance_metrics).to_csv(snakemake.output.performance_metrics)
performance_metrics_per_label_df.to_csv(snakemake.output.performance_metrics_per_metadata)

In [None]:
## Plot the confusion matrix
# A fixed label order exists only for the pancreas dataset and for the
# "well_studied_celltypes" datasets when the label column is "celltype";
# otherwise the order of the metrics table is kept.
order = None
if dataset_name == "pancreas":
    order = PANCREAS_ORDER
elif "well_studied_celltypes" in dataset_name and metadata_col == "celltype":
    order = list(TABSAP_WELLSTUDIED_COLORMAPPING.keys())

# Work on a copy so the table with the full query texts stays available.
performance_metrics_per_label_df_wo_prefix_suffix = performance_metrics_per_label_df.copy()
if snakemake.params.use_prefix_suffix_version and metadata_col in SUFFIX_PREFIX_DICT:

    def _without_prefix_suffix(name):
        """Return ``name`` with every occurrence of the prefix and suffix removed."""
        return name.replace(prefix, "").replace(suffix, "")

    # Row labels and column names both embed the wrapped label text.
    performance_metrics_per_label_df_wo_prefix_suffix.index = list(
        map(_without_prefix_suffix, performance_metrics_per_label_df.index.values)
    )
    performance_metrics_per_label_df_wo_prefix_suffix.columns = list(
        map(_without_prefix_suffix, performance_metrics_per_label_df.columns.values)
    )

In [None]:
# Figure title: macro-averaged ROC-AUC rounded to two decimals.
title = f"$\\text{{ROC-AUC}}_{{macro}}={round(float(performance_metrics['rocauc_macroAvg']),2)}$"
performance_metrics_per_label_df=performance_metrics_per_label_df_wo_prefix_suffix

# The confusion matrix is stored in the per-label table as one
# "n_samples_predicted_as_<label>" column per predicted label.
count_column_prefix = "n_samples_predicted_as_"
confusion_matrix = performance_metrics_per_label_df[
    [
        col
        for col in performance_metrics_per_label_df
        if col.startswith(count_column_prefix)
    ]
]
if snakemake.params.normed:
    # Row-normalise: each row then holds fractions of that true class.
    confusion_matrix = confusion_matrix.div(
        confusion_matrix.sum(axis=1), axis=0
    )
confusion_matrix.columns = [
    col.replace(count_column_prefix, "") for col in confusion_matrix.columns
]
if order is not None:
    # Apply the predefined label order to columns and rows.
    confusion_matrix = confusion_matrix[order]
    confusion_matrix = confusion_matrix.loc[order]
confusion_matrix.to_excel(snakemake.output.confusion_matrix_table, index=True)

# Create the figure at its final size (about half an inch per class, at least
# 5 x 5 inches) so that the layout computed below matches the saved canvas.
figure_side = max(5, len(confusion_matrix.index) // 2)
fig, ax = plt.subplots(figsize=(figure_side, figure_side))
sns.heatmap(
    confusion_matrix,
    cmap="Blues",
    annot=False,
    square=True,
    cbar_kws={"shrink": 0.7},
    vmin=0,
    vmax=1 if snakemake.params.normed else None,
    ax=ax,
)

# Put one tick at the centre of every heatmap cell so that no label is skipped.
ax.set_yticks([pos + 0.5 for pos in range(len(confusion_matrix.index))])
ax.set_yticklabels(confusion_matrix.index)
ax.set_xticks([pos + 0.5 for pos in range(len(confusion_matrix.columns))])
ax.set_xticklabels(confusion_matrix.columns, rotation=45, ha="right")
ax.set_xlabel("Best-matching keyword")
ax.set_ylabel("True class")

cbar = ax.collections[0].colorbar
cbar.set_label("Fraction of cells in true class" if snakemake.params.normed else "Number of cells")

# Mark the diagonal with boxes around the cells.
for diag in range(min(confusion_matrix.shape)):
    ax.add_patch(
        plt.Rectangle((diag, diag), 1, 1, fill=False, edgecolor="grey", lw=2)
    )

ax.set_title(title)

# Fit the layout only after the figure has its final size, labels and title;
# previously it was computed on a 10 x 10 canvas that was resized afterwards.
fig.tight_layout()

fig.savefig(snakemake.output.confusion_matrix_plot)
plt.show()
plt.close(fig)
