In [None]:
import os
import tempfile
import sys

import scanpy as sc
import anndata as ad
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import scvi
import seaborn as sns
import torch
from scvi.external import CellAssign
import torchmetrics
import logging

from cellwhisperer.utils.processing import ensure_raw_counts_adata
from zero_shot_validation_scripts.dataset_preparation import load_and_preprocess_dataset

# Select PyTorch's "medium" float32 matmul precision setting for the whole session.
# NOTE(review): presumably chosen to speed up training at reduced matmul precision — confirm.
torch.set_float32_matmul_precision("medium")

In [None]:
# Configure the root logger: INFO level, timestamped "<time> - <level> - <message>" records.
# NOTE(review): an earlier comment said logging is connected to the file
# snakemake.log.progress, but no `filename`/handler is passed here — confirm
# whether file logging was intended.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

In [None]:
# Load the dataset named by the Snakemake `dataset` wildcard from its read-count table.
adata = load_and_preprocess_dataset(
    dataset_name=snakemake.wildcards.dataset,
    read_count_table_path=snakemake.input.read_count_table,
)

# NOTE(review): presumably checks that adata holds raw (unnormalized) counts —
# confirm against cellwhisperer.utils.processing.
ensure_raw_counts_adata(adata)
# Display the loaded AnnData summary.
adata

In [None]:
# Compute size factors before gene filtering, so that library size reflects
# every gene in the dataset rather than only the marker genes kept later.
# A SciPy sparse matrix returns an (n_cells, 1) np.matrix from .sum(axis=1);
# flatten to a plain 1-D array so the obs column is always one value per cell,
# regardless of whether X is dense or sparse.
lib_size = np.asarray(adata.X.sum(axis=1)).ravel()
# Normalize by the mean library size.
adata.obs["size_factor"] = lib_size / np.mean(lib_size)

In [None]:
# Load the prepared marker table; the first CSV column becomes the index,
# which is matched against adata.var["gene_name"] further down.
# NOTE(review): presumably a genes x cell-types marker matrix — confirm.
markers = pd.read_csv(snakemake.input.prepared_markers, index_col=0)
markers

In [None]:
# Inspect the per-gene annotation table (its "gene_name" column is used for marker matching).
adata.var

In [None]:
# Marker genes that are also present in the dataset (matched by gene name).
genes = markers.index.intersection(adata.var["gene_name"])

# Keep only those genes and index the variables by gene name, so that
# adata.var_names can be matched against the marker table index.
is_marker_gene = adata.var["gene_name"].isin(genes)
adata = adata[:, is_marker_gene].copy()
adata.var.set_index("gene_name", inplace=True)

In [None]:
# Register the filtered AnnData with CellAssign, pointing it at the
# "size_factor" obs column computed before gene filtering. (The size factors
# themselves are not computed here, so the log message says what this step does.)
logging.info("Registering AnnData with CellAssign")
CellAssign.setup_anndata(adata, size_factor_key="size_factor")

In [None]:
logging.info("Training CellAssign")

# Marker rows restricted to the genes shared with the dataset.
shared_markers = markers.loc[genes]
model = CellAssign(adata, shared_markers)
model.train(batch_size=16)

In [None]:
logging.info("Inferring cell types with CellAssign")
# Per-cell prediction table from the trained model, relabelled so that its
# rows carry the same index as adata.obs.
predictions_df = model.predict().set_axis(adata.obs.index, axis=0)

In [None]:
# Persist the full prediction table (all predicted columns) before it is
# restricted to the ground-truth label categories below.
predictions_df.to_csv(snakemake.output.predictions_raw)

In [None]:
# Select and order the prediction columns by the categories of the label
# column, so that column position i corresponds to category code i.
label_categories = adata.obs[snakemake.params.label_col].cat.categories
predictions_df = predictions_df.loc[:, label_categories]

In [None]:
# Display the gene-filtered AnnData summary.
adata

In [None]:
# Display the label categories; their order defines the class indices used for evaluation.
adata.obs[snakemake.params.label_col].cat.categories

In [None]:
# Display the prediction table after restricting/reordering its columns to the label categories.
predictions_df

In [None]:
# Evaluate predictions using torchmetrics
logging.info("Evaluating predictions")

# Ground-truth class index per cell, taken from the categorical label column.
# The prediction columns were put into the same category order earlier, so
# code i corresponds to prediction column i. pandas stores category codes in
# a compact integer dtype; cast to int64 so the targets are plain long indices.
labels = torch.tensor(
    adata.obs[snakemake.params.label_col].cat.codes.values, dtype=torch.long
)
# One row of per-class scores per cell.
predictions = torch.tensor(predictions_df.values)

# Arguments shared by every metric call.
metric_kwargs = {"task": "multiclass", "num_classes": predictions_df.shape[1]}

# Metrics that are macro-averaged over classes.
macro_metric_fns = {
    "accuracy": torchmetrics.functional.accuracy,
    "precision": torchmetrics.functional.precision,
    "recall": torchmetrics.functional.recall,
    "f1": torchmetrics.functional.f1_score,
}
scores = {
    metric_name: metric_fn(
        predictions, labels, average="macro", **metric_kwargs
    ).item()
    for metric_name, metric_fn in macro_metric_fns.items()
}
# AUROC reuses the same prediction tensor (no second conversion) and, as
# before, is called without `average`, so torchmetrics' default applies.
scores["auroc"] = torchmetrics.functional.auroc(
    predictions, labels, **metric_kwargs
).item()

# One row per metric, written as a two-column CSV ("metric", "value").
performance = pd.Series(scores, name="value")
performance.index.name = "metric"
performance.to_csv(snakemake.output.performance)