In [None]:
import math
import os

import pandas as pd
import numpy as np
import torch
import anndata
import torchmetrics
from tqdm.auto import tqdm

from zero_shot_validation_scripts.dataset_preparation import load_and_preprocess_dataset
from finetuning_eval.models.geneformer import GeneformerCelltypeModel
from cellwhisperer.jointemb.geneformer_model import (
    GeneformerConfig,
    GeneformerTranscriptomeProcessor,
)
from finetuning_eval.models.scgpt import ScGPTCelltypeModel, ScGPTConfig
from finetuning_eval.models.uce import UCECelltypeModel, UCEConfig
from cellwhisperer.jointemb.uce_model import UCETranscriptomeProcessor
from cellwhisperer.jointemb.scgpt_model import ScGPTTranscriptomeProcessor

In [None]:
# Progress log declared by the Snakemake rule (`log.progress`). It is opened in
# write mode (truncating output from a previous run), receives one line per
# prediction batch, and is closed in the last cell of the notebook.
logfile_handle = open(file=snakemake.log.progress, mode="w")

In [None]:
#### Load data
adata = load_and_preprocess_dataset(
    dataset_name=snakemake.wildcards.dataset,
    read_count_table_path=snakemake.input.eval_data,
)

#### Load transfered labels
# Table with a `training_cell_type` column (cell types of the fine-tuned model) and an
# `evaluation_cell_type` column (cell types of the evaluation dataset). The value
# "none" is excluded when collecting evaluation cell types; presumably it marks
# training cell types without a counterpart in the evaluation data.
transfered_labels = pd.read_csv(snakemake.input.transfered_labels)

# Evaluation cell types that at least one training cell type maps to,
# and the cell types actually present in the evaluation dataset.
covered_celltypes = set(
    transfered_labels.evaluation_cell_type.loc[lambda x: x != "none"]
)
observed_celltypes = set(adata.obs[snakemake.params.label_col])

# Every mapped evaluation cell type must occur in the evaluation dataset.
unknown_celltypes = covered_celltypes - observed_celltypes
assert not unknown_celltypes, (
    "Transferred labels reference cell types absent from the evaluation data: "
    + ", ".join(sorted(map(str, unknown_celltypes)))
)

# The cell types annotated in CELLxGENE Census may not fully cover all cell types in
# our evaluation datasets (e.g. Tabula Sapiens). Since these cell types cannot be
# predicted, we exclude them from the evaluation.
uncovered_celltypes = observed_celltypes - covered_celltypes

if uncovered_celltypes:
    adata = adata[~adata.obs[snakemake.params.label_col].isin(uncovered_celltypes)]
    # Sorted so the message is identical across runs (set iteration order of
    # strings varies between processes); str() guards against non-string labels.
    print(
        "Excluded the following cell types (as they cannot be predicted by the fine-tuned model): "
        + ", ".join(sorted(map(str, uncovered_celltypes)))
    )

In [None]:

label_col = snakemake.params.label_col

# The classification head is sized to the number of rows of the transferred-label
# table; its outputs are later labelled with `transfered_labels.training_cell_type`.
n_training_celltypes = len(transfered_labels)
model_name = snakemake.wildcards.model

if model_name == "geneformer":
    model = GeneformerCelltypeModel(
        GeneformerConfig(), num_classes=n_training_celltypes
    )
    transcriptome_processor = GeneformerTranscriptomeProcessor(
        snakemake.threads, ["natural_language_annotation"]
    )  # second argument is irrelevant
elif model_name == "scgpt":
    model = ScGPTCelltypeModel(ScGPTConfig(), num_classes=n_training_celltypes)
    transcriptome_processor = ScGPTTranscriptomeProcessor(snakemake.threads)
elif model_name == "uce":
    model = UCECelltypeModel(UCEConfig(), num_classes=n_training_celltypes)
    transcriptome_processor = UCETranscriptomeProcessor(snakemake.threads)
else:
    raise NotImplementedError(f"Model {model_name} not implemented")

# Load the checkpoint onto the CPU first so loading does not depend on the device
# the weights were saved from; the model is moved to the GPU afterwards.
model.load_state_dict(torch.load(snakemake.input.model, map_location="cpu"))
model.eval()
model = model.to("cuda")

In [None]:
# Predict: run the model batch-wise over all cells and collect the raw outputs on the CPU.
batch_size = snakemake.params.batch_size
n_cells = len(adata)
n_batches = math.ceil(n_cells / batch_size)
cuda_device = torch.device("cuda")

batch_outputs = []
for batch_idx, start in enumerate(tqdm(range(0, n_cells, batch_size))):
    batch = adata[start : start + batch_size]

    inputs = transcriptome_processor(batch, return_tensors="pt", padding=True)
    inputs = {key: value.to(cuda_device) for key, value in inputs.items()}
    with torch.inference_mode():
        batch_outputs.append(model(**inputs).detach().cpu())

    # Completed batches out of the total number of batches, e.g. "3/10".
    logfile_handle.write(f"{batch_idx + 1}/{n_batches}\n")
    logfile_handle.flush()

if not batch_outputs:
    # torch.cat would otherwise fail with an opaque "non-empty list" error.
    raise ValueError("No cells to predict: the evaluation dataset is empty after filtering")
predictions = torch.cat(batch_outputs)

In [None]:
# Save predictions: softmax over the output dimension gives one probability per
# training cell type; columns follow the row order of `transfered_labels`.
predictions_raw_df = pd.DataFrame(
    torch.softmax(predictions, 1).numpy(),
    index=adata.obs.index,
    columns=transfered_labels.training_cell_type,
)

# Create the output directory if it does not exist. dirname() is "" for a bare
# filename, and os.makedirs("") would raise, so only create a non-empty path.
raw_output_dir = os.path.dirname(snakemake.output.predictions_raw)
if raw_output_dir:
    os.makedirs(raw_output_dir, exist_ok=True)

predictions_raw_df.to_csv(snakemake.output.predictions_raw)

In [None]:
# Display the training -> evaluation cell type table used for aggregation below.
transfered_labels

In [None]:
# Aggregate the columns, according to the transfered labels: the score of an
# evaluation cell type is the sum of the probabilities of all training cell types
# mapped to it. Columns are collected in a dict and assembled into one DataFrame,
# instead of being inserted one at a time (which fragments the frame).
aggregated_columns = {
    evaluation_cell_type: predictions_raw_df[training_cell_types].sum(axis=1).to_numpy()
    for evaluation_cell_type, training_cell_types in transfered_labels.groupby(
        "evaluation_cell_type"
    )["training_cell_type"]
}
predictions_df = pd.DataFrame(aggregated_columns, index=adata.obs.index)

# Keep only, and order by, the categories of the label column, so that column i
# matches category code i (used as target in the evaluation). Aggregated columns
# that are not categories (e.g. "none") are dropped here.
expected_celltypes = adata.obs[label_col].cat.categories
missing_celltypes = expected_celltypes.difference(predictions_df.columns)
if len(missing_celltypes) > 0:
    raise KeyError(
        "No training cell type is mapped to these evaluation cell types: "
        + ", ".join(map(str, missing_celltypes))
    )
predictions_df = predictions_df[expected_celltypes]

predictions_df.to_csv(snakemake.output.predictions)

In [None]:
# Evaluate predictions using torchmetrics.
# Targets are the category codes of the label column; `predictions_df` columns are
# ordered by those categories, so code i refers to column i. pandas stores codes in
# the smallest fitting integer dtype (often int8), hence the explicit int64 cast.
labels = torch.tensor(adata.obs[label_col].cat.codes.values, dtype=torch.long)
predictions = torch.tensor(predictions_df.values)
n_eval_classes = predictions_df.shape[1]

# Macro-averaged threshold metrics, computed from the argmax over columns.
# NOTE(review): with average="macro", torchmetrics' multiclass accuracy averages the
# per-class TP / (TP + FN), so it appears to coincide with macro recall (balanced
# accuracy). Kept unchanged so results stay comparable with earlier runs.
macro_metric_fns = {
    "accuracy": torchmetrics.functional.accuracy,
    "precision": torchmetrics.functional.precision,
    "recall": torchmetrics.functional.recall,
    "f1": torchmetrics.functional.f1_score,
}
metric_values = {
    metric_name: metric_fn(
        predictions,
        labels,
        average="macro",
        task="multiclass",
        num_classes=n_eval_classes,
    ).item()
    for metric_name, metric_fn in macro_metric_fns.items()
}
# AUROC uses the per-class scores directly (torchmetrics' default averaging).
metric_values["auroc"] = torchmetrics.functional.auroc(
    predictions,
    labels,
    task="multiclass",
    num_classes=n_eval_classes,
).item()

performance = pd.Series(metric_values, name="value")
performance.index.name = "metric"
performance.to_csv(snakemake.output.performance)

In [None]:
# Display the evaluation metrics written to `snakemake.output.performance`.
performance

In [None]:
# Close the progress log opened at the start of the notebook.
logfile_handle.close()