In [2]:
"""
Plot the ground truth celltype and the keyword search results on the UMAP.
"""

'\nPlot the ground truth celltype and the keyword search results on the UMAP.\n'

In [None]:
import scanpy as sc
import torch
import matplotlib
import logging

import matplotlib.pyplot as plt
from cellwhisperer.utils.inference import score_transcriptomes_vs_texts
from cellwhisperer.utils.model_io import load_cellwhisperer_model

from zero_shot_validation_scripts.dataset_preparation import load_and_preprocess_dataset

In [None]:
# Figure styling. The project matplotlib style is applied first; scanpy's
# figure params are applied afterwards and may override some of its rcParams.
# vector_friendly + dpi_save=500 keep PDFs of scatter plots small while still
# high-quality.
matplotlib.style.use(snakemake.input.mpl_style)
sc.set_figure_params(vector_friendly=True, dpi_save=500)

# Run parameters supplied by the Snakemake rule.
ckpt_file_path = snakemake.input.model  # CellWhisperer checkpoint to load
dataset_name = snakemake.params.dataset  # dataset passed to load_and_preprocess_dataset
celltype = snakemake.wildcards.celltype  # cell type this notebook run is about

In [None]:
#### Load model
# load_cellwhisperer_model returns a 3-tuple: the wrapped (pl_) model, the text
# processor and the transcriptome processor. All three are kept as globals.
pl_model_cellwhisperer, text_processor_cellwhisperer, cellwhisperer_transcriptome_processor = (
    load_cellwhisperer_model(model_path=ckpt_file_path, eval=True)
)

# The inner model object; its .discriminator and .device are used when scoring.
cellwhisperer_model = pl_model_cellwhisperer.model

#### Load data
# obsm_paths maps an adata.obsm key to a (file path, name) pair; presumably the
# name selects an array inside that file — TODO confirm against
# load_and_preprocess_dataset.
adata = load_and_preprocess_dataset(
    dataset_name=dataset_name,
    read_count_table_path=snakemake.input.raw_read_count_table,
    obsm_paths={
        "X_umap_on_neighbors_cellwhisperer": (snakemake.input.umap, "neighbors"),
        "X_cellwhisperer": (snakemake.input.processed_dataset, "transcriptome_embeds"),
        # "X_geneformer": snakemake.input.TODO
    },
)

logging.info(f"Data loaded and preprocessed. Shape: {adata.shape}")

In [None]:
#### Score every cell against the keyword query
# Prompt template (prefix/suffix) registered under the "celltype" key.
prefix, suffix = snakemake.params.suffix_prefix_dict["celltype"]

# Free-text term used to search for this cell type.
term = snakemake.params.celltype_terms_dict[celltype]

scores, _ = score_transcriptomes_vs_texts(
    model=cellwhisperer_model,
    logit_scale=cellwhisperer_model.discriminator.temperature.exp(),
    transcriptome_input=torch.tensor(adata.obsm["X_cellwhisperer"], device=cellwhisperer_model.device),
    text_list_or_text_embeds=[f"{prefix}{term}{suffix}"],
    average_mode=None,
    grouping_keys=None,
    transcriptome_processor=cellwhisperer_transcriptome_processor,
    batch_size=32,
    score_norm_method=None,
)
scores = scores.T  # n_cells * n_text

# Exactly one query text was scored, so take its single column. Unlike
# squeeze(), this still yields one value per cell when there is only one cell.
adata.obs[f"score_for_{term}"] = scores[:, 0].tolist()

# Ground-truth indicator: 1 if the annotated cell type contains `celltype`.
# regex=False: the wildcard is a literal name, and names with regex
# metacharacters (e.g. "+", "(", ".") must not be interpreted as patterns.
# na=False: cells with a missing label count as "not contained" instead of
# producing NaN, which would make astype(int) raise.
adata.obs[f"label contains '{celltype}'"] = (
    adata.obs["celltype"].str.contains(celltype, regex=False, na=False).astype(int)
)

In [None]:
#### Plot ground truth and keyword-search scores on the CellWhisperer UMAP

# Ground-truth label (0/1 indicator column) on the UMAP.
sc.pl.embedding(
    adata,
    basis="X_umap_on_neighbors_cellwhisperer",
    color=[f"label contains '{celltype}'"],
    cmap=matplotlib.colors.ListedColormap(["silver", "firebrick"]),
    show=False,
)
fig = plt.gcf()
plt.title("Ground truth label")
fig.axes[0].set_facecolor("white")
fig.axes[1].remove()  # drop the colorbar axes; the two colours are self-explanatory
# Lay out BEFORE saving so the written file matches the displayed figure.
fig.tight_layout()
fig.savefig(snakemake.output.umap_on_neighbors_celltype)
plt.show()
plt.close(fig)

# Keyword-search scores, once with a colour scale centred on 0 and once with
# scanpy's default data-driven limits.
# Use the max ABSOLUTE score so the symmetric range [-vmax, vmax] covers every
# score (no clipping of strongly negative scores, and vmin <= vmax even when all
# scores are negative). Loop-invariant, so computed once.
vmax = adata.obs[f"score_for_{term}"].abs().max()
for make_colorscale_symmetrical, fn in zip(
    [True, False],
    [snakemake.output.colorscale_symmetrical, snakemake.output.colorscale_asymmetrical],
):
    sc.pl.embedding(
        adata,
        basis="X_umap_on_neighbors_cellwhisperer",
        color=[f"score_for_{term}"],
        cmap="RdBu_r",
        vmin=-vmax if make_colorscale_symmetrical else None,
        vmax=vmax if make_colorscale_symmetrical else None,
        show=False,
    )
    fig = plt.gcf()
    plt.title("Keyword search results")
    # label the colorbar
    fig.axes[1].set_ylabel(f"Score for: '{prefix}{term}{suffix}'", fontsize=7)
    fig.axes[0].set_facecolor("white")

    # Lay out BEFORE saving so the written file matches the displayed figure.
    fig.tight_layout()
    fig.savefig(fn)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across loop iterations