In [2]:
# TODO: test whether this runs without the precomputed UMAP and embeddings

import warnings
# avoid DeprecationWarning: np.find_common_type is deprecated due to pandas version (needed by other packages)
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pandas.core.algorithms")

In [3]:
import os
from pathlib import Path
import pandas as pd
import logging
import numpy as np
from collections import defaultdict
import copy
import matplotlib
import torch

from cellwhisperer.config import get_path
from cellwhisperer.utils.inference import score_transcriptomes_vs_texts
from cellwhisperer.validation.integration.functions import eval_scib_metrics
from cellwhisperer.utils.model_io import load_cellwhisperer_model
from zero_shot_validation_scripts.utils import TABSAP_WELLSTUDIED_COLORMAPPING, PANCREAS_ORDER, SUFFIX_PREFIX_DICT

from zero_shot_validation_scripts.dataset_preparation import load_and_preprocess_dataset

  self.hub = sentry_sdk.Hub(client)


In [4]:
#### Parameters ####

ckpt_file_path=snakemake.input.model

metadata_col = snakemake.wildcards.metadata_col
dataset_name = snakemake.wildcards.dataset


In [5]:
#### Load model
# load_cellwhisperer_model returns (wrapper, text processor, transcriptome processor);
# eval=True is passed, presumably to prepare the model for inference — confirm in model_io.
pl_model_cellwhisperer, text_processor_cellwhisperer, cellwhisperer_transcriptome_processor = load_cellwhisperer_model(
    model_path=ckpt_file_path,
    eval=True,
)
# The returned wrapper exposes the underlying CellWhisperer model as `.model`
cellwhisperer_model = pl_model_cellwhisperer.model

#### Load data
# Request that the "transcriptome_embeds" stored in the processed dataset be attached
# as obsm["X_features"] (mapping format as expected by load_and_preprocess_dataset).
obsm_sources = {"X_features": (snakemake.input.processed_dataset, "transcriptome_embeds")}
adata = load_and_preprocess_dataset(
    dataset_name=dataset_name,
    read_count_table_path=snakemake.input.raw_read_count_table,
    obsm_paths=obsm_sources,
)
logging.info(f"Data loaded and preprocessed. Shape: {adata.shape}")



In [None]:
#### Predict the labels using CellWhisperer
# Keep only cells with a usable label: drop real missing values as well as the
# literal string "nan".
adata_no_nans = adata[
    ~(adata.obs[metadata_col].isna()) & ~(adata.obs[metadata_col] == "nan")
].copy()

# Candidate classes are the distinct labels present among the remaining cells.
labels = adata_no_nans.obs[metadata_col].unique().tolist()
if not labels:
    raise ValueError(
        f"No non-missing values in column '{metadata_col}' of dataset '{dataset_name}'; nothing to predict"
    )

# Build one text query per label, in the same order as `labels` (predictions are mapped
# back by position). The prefix/suffix template is applied only when it is enabled AND
# defined for this column; in every other case the raw label strings are used, so
# `text_list` is always defined.
if metadata_col not in SUFFIX_PREFIX_DICT:
    logging.warning(f"Label column {metadata_col} not found in SUFFIX_PREFIX_DICT, continuing without prefix/suffix")

if snakemake.params.use_prefix_suffix_version and metadata_col in SUFFIX_PREFIX_DICT:
    prefix, suffix = SUFFIX_PREFIX_DICT[metadata_col]
    text_list = [f"{prefix}{x}{suffix}" for x in labels]
else:
    text_list = labels.copy()

scores, true_classes = score_transcriptomes_vs_texts(
    model=cellwhisperer_model,
    logit_scale=cellwhisperer_model.discriminator.temperature.exp(),
    transcriptome_input=adata_no_nans,
    text_list_or_text_embeds=text_list,
    # With average_by_class, embeddings are presumably averaged per label group before scoring
    average_mode="embeddings" if snakemake.params.average_by_class else None,
    grouping_keys=adata_no_nans.obs[metadata_col].values,  # only relevant if average_mode is not None
    transcriptome_processor=cellwhisperer_transcriptome_processor,
    batch_size=32,
    score_norm_method=None,
)
scores = scores.T  # n_cells * n_text
# Best-scoring text per row, mapped back to its raw label (text_list is aligned with labels);
# .tolist() yields plain Python ints for indexing.
predicted_labels = [labels[x] for x in scores.argmax(axis=1).tolist()]

In [None]:
# Assemble the prediction table. There is one row per cell, or one row per class when
# predictions were made with average_by_class. Each row holds:
#   - the score for every text query,
#   - the predicted label,
#   - the true label,
#   - whether the prediction is correct.
result_df = pd.DataFrame(index=true_classes if snakemake.params.average_by_class else adata_no_nans.obs.index)

# Column i of `scores` corresponds to text_list[i]; enumerate avoids an O(n) list.index lookup per term.
for term_idx, term in enumerate(text_list):
    result_df[f"score_for_{term}"] = scores[:, term_idx]

result_df["predicted_labels"] = predicted_labels

if snakemake.params.average_by_class:
    # Rows are keyed by class, so the index itself is the true label
    result_df["label"] = result_df.index
else:
    result_df["label"] = adata_no_nans.obs[metadata_col].values

result_df["is_correct"] = result_df["predicted_labels"] == result_df["label"]

result_df.to_csv(snakemake.output.predictions, index=True)