# Model and Dataset Evaluation

This notebook demonstrates how to load and evaluate a model on a given dataset. It includes:

1. Parameter configuration
2. Data loading (both the dataset and any relevant AnnData objects)
3. Generating embeddings via a model
4. Visualizing embeddings with UMAP
5. Pairwise analysis of omics and text embeddings (joint clustering and similarity)
6. Integration metrics with `scibEvaluator`
7. Annotation, text queries, and zero-shot classification
8. (Optional) Saving the executed notebook

Adjust the parameters in the first code cell to switch between models and datasets.

## 1. Imports and Configuration

In [3]:
import logging

import anndata
import numpy as np
import pandas as pd
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# mmcontext imports
from mmcontext.engine import OmicsQueryAnnotator
from mmcontext.eval import evaluate_annotation_accuracy, scibEvaluator, zero_shot_classification_roc
from mmcontext.eval.utils import create_emb_pair_dataframe
from mmcontext.pl import plot_umap, visualize_embedding_clusters
from mmcontext.pl.plotting import plot_embedding_similarity, plot_query_scores_umap
from mmcontext.pp.utils import consolidate_low_frequency_categories
from mmcontext.utils import load_test_adata_from_hf_dataset

# setLevel alone does not display INFO messages: a handler is needed, which basicConfig
# attaches to the root logger (it has no effect if the root logger is already configured).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

############################################
#         CONFIGURABLE PARAMETERS          #
############################################

SAVE_RESULTS = False  # Optionally toggle saving results
SAVE_FIGURES = False  # Optionally toggle saving figures

MODEL_NAME = "jo-mengr/mmcontext-geo7k-cellxgene3.5k-multiplets"  # Example model name
DATASET_NAME = "cellxgene_pseudo_bulk_3_5k_multiplets_natural_language_annotation"  # Example dataset name

# These keys can be adapted to your AnnData
BATCH_KEY = "_scvi_batch"
LABEL_KEY = "cell_type"  # The column used for cell type labels
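# Embeddings (keys in adata.obsm) to compare with scibEvaluator. mmcontext_emb and
# mmcontext_text_emb are created in section 3; the other keys must already exist in the AnnData.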
EMBEDDING_KEYS = ["mmcontext_emb", "mmcontext_text_emb", "X_geneformer", "X_hvg", "X_pca", "X_scvi"]

# Zero-shot classification function parameters
ZERO_SHOT_LABEL_KEY = LABEL_KEY
ZERO_SHOT_EMB_KEY = "mmcontext_emb"
ZERO_SHOT_TEXT_TEMPLATE = "A sample of {} from a healthy individual"

logger.info("Configuration parameters set.")

## 2. Data Loading

In [2]:
logger.info("Loading dataset from HuggingFace...")
dataset = load_dataset(f"jo-mengr/{DATASET_NAME}")

logger.info("Using the 'val' split as the test set...")
test_dataset = dataset["val"]

logger.info("Loading model...")
model = SentenceTransformer(MODEL_NAME)

logger.info("Loading AnnData from dataset...")
adata = load_test_adata_from_hf_dataset(test_dataset)

logger.info("Ensuring batch_key is categorical...")
adata.obs[BATCH_KEY] = adata.obs[BATCH_KEY].astype("category")

logger.info("Data loading complete.")
adata

## 3. Generate Embeddings

We generate two embeddings and store them in `adata.obsm`:

- `mmcontext_emb` from the omics data
- `mmcontext_text_emb` from the text annotation
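
Both arrays are stored in `adata.obsm`, so each must be a 2-D array with exactly one row per observation, ordered like `adata.obs`. The helper below is a hypothetical guard (it is not part of mmcontext) that checks the shape and finiteness of an embedding before it is stored; it can verify the row count but not the row order. Random NumPy arrays stand in for the output of `model.encode`.

```python
import numpy as np


def validate_embedding(emb, n_obs: int) -> np.ndarray:
    """Return `emb` as a 2-D float32 array with one row per observation.

    Hypothetical helper: raises ValueError if the array is not 2-D, if its row
    count differs from `n_obs`, or if it contains NaN/inf values.
    """
    arr = np.asarray(emb, dtype=np.float32)
    if arr.ndim != 2:
        raise ValueError(f"expected a 2-D array, got shape {arr.shape}")
    if arr.shape[0] != n_obs:
        raise ValueError(f"expected {n_obs} rows, got {arr.shape[0]}")
    if not np.all(np.isfinite(arr)):
        raise ValueError("embedding contains NaN or infinite values")
    return arr


# Stand-in for model.encode(...): 5 samples, 8 dimensions.
omics = np.random.default_rng(0).normal(size=(5, 8))
checked = validate_embedding(omics, n_obs=5)
print(checked.shape, checked.dtype)  # (5, 8) float32
```

In the embedding cell, it could wrap each assignment, e.g. `adata.obsm["mmcontext_emb"] = validate_embedding(omics_embeddings, adata.n_obs)`. The cast to `float32` is an assumption made to keep memory use predictable; remove it if you need a different precision.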

In [None]:
logger.info("Generating omics embeddings...")
omics_embeddings = model.encode(test_dataset["anndata_ref"])

logger.info("Generating text embeddings...")
text_annotations = adata.obs["natural_language_annotation"].values.tolist()
text_embeddings = model.encode(text_annotations)

logger.info("Storing embeddings in AnnData...")
adata.obsm["mmcontext_emb"] = omics_embeddings
adata.obsm["mmcontext_text_emb"] = text_embeddings

logger.info("Embedding generation complete.")
adata

## 4. UMAP Visualization

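Here, we visualize one of the embeddings (`mmcontext_text_emb`) with UMAP. Before plotting, `consolidate_low_frequency_categories` consolidates `LABEL_KEY` categories that have few samples (`threshold=10`) so the legend stays readable. This only affects the plot when coloring by `LABEL_KEY`; the default `color_key` below is `BATCH_KEY`.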

In [None]:
logger.info("Consolidating low-frequency categories...")
adata_cut = consolidate_low_frequency_categories(adata, [LABEL_KEY], threshold=10)

# The color_key can be changed to anything in adata.obs, e.g. BATCH_KEY or LABEL_KEY.
color_key = BATCH_KEY

logger.info("Plotting UMAP...")
plot_umap(adata_cut, color_key=color_key, embedding_key="mmcontext_text_emb")

if SAVE_FIGURES:
    import matplotlib.pyplot as plt

    plt.savefig("umap_visualization.png", dpi=150)
    logger.info("UMAP plot saved to umap_visualization.png")
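
The exact rule used by `consolidate_low_frequency_categories` is not visible in this notebook. As a rough illustration, assuming it replaces every category with fewer than `threshold` samples by a single catch-all value, a pandas version could look like this (the helper name and the `"remaining"` label are hypothetical):

```python
import pandas as pd


def consolidate_rare_labels(labels: pd.Series, threshold: int, other: str = "remaining") -> pd.Series:
    """Replace labels occurring fewer than `threshold` times with `other`.

    Hypothetical re-implementation for illustration only; the catch-all name
    `other` is an assumption, not necessarily what mmcontext uses.
    """
    counts = labels.value_counts()
    rare = counts[counts < threshold].index
    # Keep frequent labels as they are; replace the rare ones with the catch-all value.
    consolidated = labels.astype(str).where(~labels.isin(rare), other)
    return consolidated.astype("category")


cell_types = pd.Series(["T cell"] * 12 + ["B cell"] * 11 + ["mast cell"] * 3)
print(consolidate_rare_labels(cell_types, threshold=10).value_counts())
```

Merging rare labels into one category, rather than dropping their samples, keeps every observation in the plot while limiting the number of legend entries.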

## 5. Pairwise Embedding Analysis

We use `create_emb_pair_dataframe` to build a DataFrame that pairs the omics and text embeddings (`mmcontext_emb` and `mmcontext_text_emb`) of each sample. We then use:

- `visualize_embedding_clusters` to see how clusters form in the joint space.
- `plot_embedding_similarity` to examine similarity distributions for subsets.
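
In a well-aligned joint space, each sample's omics embedding should be more similar to its own text embedding than to those of other samples. The sketch below checks this with synthetic data: the "text" embeddings are noisy copies of the "omics" embeddings, so row *i* of one pairs with row *i* of the other. It only illustrates the idea and is not how `plot_embedding_similarity` is implemented.

```python
import numpy as np


def cosine_similarity_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Cosine similarity between every row of `a` and every row of `b`."""
    a_norm = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_norm = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a_norm @ b_norm.T


rng = np.random.default_rng(42)
omics = rng.normal(size=(20, 16))
# Synthetic stand-in: text embeddings are noisy copies, so row i of `text` pairs with row i of `omics`.
text = omics + 0.3 * rng.normal(size=omics.shape)

sim = cosine_similarity_matrix(omics, text)
matched = np.diag(sim)  # omics_i vs text_i
mismatched = sim[~np.eye(len(sim), dtype=bool)]  # omics_i vs text_j, i != j
print(f"matched mean: {matched.mean():.3f}, mismatched mean: {mismatched.mean():.3f}")
```

With real data, `omics` and `text` would be `adata.obsm["mmcontext_emb"]` and `adata.obsm["mmcontext_text_emb"]`; a large gap between the matched and mismatched means indicates that the two modalities are aligned.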

In [None]:
logger.info("Creating embedding pair dataframe...")
emb_pair_df = create_emb_pair_dataframe(
    adata,
    emb1_key="mmcontext_emb",
    emb2_key="mmcontext_text_emb",
    subset_size=20,
    label_keys=[BATCH_KEY, LABEL_KEY],
)

logger.info("Visualizing embedding clusters with UMAP...")
visualize_embedding_clusters(emb_pair_df, method="umap", metric="cosine", n_neighbors=15, min_dist=0.1, random_state=42)

if SAVE_FIGURES:
    import matplotlib.pyplot as plt

    plt.savefig("emb_clusters.png", dpi=150)
    logger.info("Embedding clusters plot saved to emb_clusters.png")

In [None]:
logger.info("Plot embedding similarity with subset=10...")
plot_embedding_similarity(emb_pair_df, emb1_type="omics", emb2_type="text", subset=10, label_key=BATCH_KEY)

if SAVE_FIGURES:
    import matplotlib.pyplot as plt

    plt.savefig("similarity_subset10_batch.png", dpi=150)
    logger.info("Similarity plot (subset=10, batch) saved.")

logger.info("Plot embedding similarity with subset=10 using cell_type...")
plot_embedding_similarity(emb_pair_df, emb1_type="omics", emb2_type="text", subset=10, label_key=LABEL_KEY)

if SAVE_FIGURES:
    import matplotlib.pyplot as plt

    plt.savefig("similarity_subset10_celltype.png", dpi=150)
    logger.info("Similarity plot (subset=10, cell_type) saved.")

To show more samples in the similarity plot, increase `subset` (here to 200):

In [None]:
# Example: subset of 200
logger.info("Plot embedding similarity with subset=200...")
plot_embedding_similarity(emb_pair_df, emb1_type="omics", emb2_type="text", subset=200, label_key=LABEL_KEY)

if SAVE_FIGURES:
    import matplotlib.pyplot as plt

    plt.savefig("similarity_subset200_celltype.png", dpi=150)
    logger.info("Similarity plot (subset=200, cell_type) saved.")

## 6. scibEvaluator

`scibEvaluator` computes the batch-integration and biological-conservation metrics of the scIB benchmark (Luecken et al., 2022) for each embedding in `EMBEDDING_KEYS`.
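
As an illustration of one bio-conservation metric, the sketch below computes a label average silhouette width (ASW) with NumPy: the mean silhouette of an embedding with respect to its cell-type labels, rescaled from [-1, 1] to [0, 1]. The rescaling `(asw + 1) / 2` is assumed to follow the scib package's convention, and this is not the code `scibEvaluator` runs. It builds a full pairwise distance matrix, so it is only suitable for small inputs.

```python
import numpy as np


def silhouette_samples_euclidean(x: np.ndarray, labels) -> np.ndarray:
    """Per-sample silhouette width with Euclidean distances (needs >= 2 distinct labels)."""
    labels = np.asarray(labels)
    uniq = np.unique(labels)
    if len(uniq) < 2:
        raise ValueError("silhouette needs at least two distinct labels")
    dist = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    scores = np.zeros(len(x))
    for i in range(len(x)):
        same = labels == labels[i]
        n_same = same.sum()
        if n_same == 1:
            continue  # singleton cluster: silhouette defined as 0
        a = dist[i, same].sum() / (n_same - 1)  # mean intra-label distance, excluding self
        b = min(dist[i, labels == other].mean() for other in uniq if other != labels[i])
        denom = max(a, b)
        scores[i] = (b - a) / denom if denom > 0 else 0.0
    return scores


def label_asw(x: np.ndarray, labels) -> float:
    """Mean silhouette rescaled to [0, 1] (assumed scib-style scaling)."""
    return float((silhouette_samples_euclidean(x, labels).mean() + 1) / 2)


rng = np.random.default_rng(0)
# Two tight, well-separated groups of 10 points each.
x = np.vstack([rng.normal(0, 0.1, size=(10, 2)), rng.normal(5, 0.1, size=(10, 2))])
labels = np.array(["T cell"] * 10 + ["B cell"] * 10)
print(round(label_asw(x, labels), 3))
```

Higher values mean that samples sharing a label sit closer to each other than to other labels. With real data, `x` would be an embedding such as `adata.obsm["mmcontext_emb"]` and `labels` would be `adata.obs[LABEL_KEY].to_numpy()`.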

In [None]:
logger.info("Initializing scibEvaluator...")
evaluator = scibEvaluator(
    adata=adata,
    batch_key=BATCH_KEY,
    label_key=LABEL_KEY,
    embedding_key=EMBEDDING_KEYS,
    n_top_genes=5000,
    max_cells=5000,
)

logger.info("Running scibEvaluator...")
res = evaluator.evaluate()
res_df = pd.DataFrame(res)
res_df

## 7. Annotation and Query

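`OmicsQueryAnnotator` annotates the omics data with a set of candidate labels (the result is read from the `best_label` column below) and scores samples against free-text queries.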
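
A minimal sketch of how annotation with candidate labels can work in a joint embedding space, assuming each label's text embedding is compared with every omics embedding by cosine similarity and the most similar label wins. This is an assumption about the approach, not `OmicsQueryAnnotator`'s implementation; the helper names and the 2-D embeddings are hypothetical stand-ins. The accuracy helper also shows why the candidate labels must be spelled exactly as in the ground-truth column: a plain string comparison treats "T-Cell" and "T cell" as different labels.

```python
import numpy as np


def annotate_by_nearest_label(omics_emb, label_emb, labels):
    """Assign each sample the label whose embedding has the highest cosine similarity."""
    o = omics_emb / np.linalg.norm(omics_emb, axis=1, keepdims=True)
    t = label_emb / np.linalg.norm(label_emb, axis=1, keepdims=True)
    scores = o @ t.T  # shape (n_samples, n_labels)
    return np.asarray(labels)[scores.argmax(axis=1)]


def annotation_accuracy(true_labels, inferred_labels) -> float:
    """Fraction of samples whose inferred label equals the ground-truth label."""
    return float(np.mean(np.asarray(true_labels) == np.asarray(inferred_labels)))


labels = ["T-Cell", "B-Cell"]
label_emb = np.array([[1.0, 0.0], [0.0, 1.0]])  # stand-ins for model.encode(labels)
omics_emb = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.5]])  # stand-in omics embeddings
best_label = annotate_by_nearest_label(omics_emb, label_emb, labels)
print(best_label)  # ['T-Cell' 'B-Cell' 'T-Cell']
print(annotation_accuracy(["T-Cell", "B-Cell", "B-Cell"], best_label))
```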

In [None]:
logger.info("Annotating omics data...")
annotator = OmicsQueryAnnotator(model)

# Suppose we have some labels to annotate, e.g. from an external source
labels = ["T-Cell", "B-Cell"]
annotator.annotate_omics_data(adata, labels)

logger.info("Plotting annotated data...")
plot_umap(adata, color_key="best_label", embedding_key="mmcontext_emb")

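logger.info("Evaluating annotation accuracy...")
# Compare against the ground-truth cell types. The candidate labels above must use the same
# spelling as the values in adata.obs[LABEL_KEY] (e.g. "T-Cell" vs. "T cell") for the score to be meaningful.
score = evaluate_annotation_accuracy(
    adata,
    true_key=LABEL_KEY,
    inferred_key="best_label",
)
logger.info("Accuracy of annotation: %s", score)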

### Example: Query with text

You can query the dataset with free-text prompts and visualize the resulting scores on a UMAP.
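
Text queries differ from annotation in that each query yields a score for every sample rather than a single winning label. The sketch below assumes cosine similarity as the score and uses hypothetical helper names and hand-made embeddings; it returns one score array per query and ranks samples by it. It is not the implementation of `query_with_text`.

```python
import numpy as np


def query_scores(omics_emb: np.ndarray, query_emb: np.ndarray, queries: list[str]) -> dict[str, np.ndarray]:
    """Return one cosine-similarity score array (one value per sample) for each query."""
    o = omics_emb / np.linalg.norm(omics_emb, axis=1, keepdims=True)
    q = query_emb / np.linalg.norm(query_emb, axis=1, keepdims=True)
    sims = o @ q.T  # shape (n_samples, n_queries)
    return {query: sims[:, j] for j, query in enumerate(queries)}


def top_k_samples(scores: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k highest-scoring samples, best first."""
    return np.argsort(scores)[::-1][:k]


queries = ["B-Cell", "macrophage"]
query_emb = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # stand-ins for model.encode(queries)
omics_emb = np.array([[0.9, 0.1, 0.0], [0.1, 0.9, 0.1], [0.5, 0.5, 0.7], [0.8, 0.0, 0.2]])
scores = query_scores(omics_emb, query_emb, queries)
print(top_k_samples(scores["B-Cell"], k=2))
```

In the notebook, `omics_emb` would correspond to `adata.obsm["mmcontext_emb"]`; keeping the raw scores (instead of an argmax) is what allows them to be colored on a UMAP.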

In [None]:
# Score every sample against each free-text query
logger.info("Querying dataset for 'B-Cell' and 'macrophage'...")
adata_new = annotator.query_with_text(adata, ["B-Cell", "macrophage"])

# Visualize
plot_query_scores_umap(adata_new)
if SAVE_FIGURES:
    import matplotlib.pyplot as plt

    plt.savefig("query_scores_umap.png", dpi=150)
    logger.info("Query scores UMAP saved.")

## 8. Zero-Shot Classification ROC

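`zero_shot_classification_roc` evaluates label prediction (here cell types) in a zero-shot manner: each label in `ZERO_SHOT_LABEL_KEY` is inserted into `ZERO_SHOT_TEXT_TEMPLATE` and embedded with the model, and the resulting text embeddings are used to score the omics embeddings in `ZERO_SHOT_EMB_KEY`. It returns the macro-averaged ROC-AUC and the per-label AUCs.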
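
A minimal sketch of one-vs-rest zero-shot ROC-AUC, assuming cosine similarity between each filled-in template and the omics embeddings as the score. It is not mmcontext's implementation: the helper names are hypothetical and the label embeddings are hand-made stand-ins for `model.encode(prompts)`. The AUC is computed as the probability that a randomly chosen positive sample outscores a randomly chosen negative one (ties count half), which equals the area under the ROC curve.

```python
import numpy as np


def roc_auc(scores: np.ndarray, is_positive: np.ndarray) -> float:
    """ROC-AUC as P(positive score > negative score), with ties counted as 0.5."""
    pos = scores[is_positive]
    neg = scores[~is_positive]
    greater = (pos[:, None] > neg[None, :]).mean()
    ties = (pos[:, None] == neg[None, :]).mean()
    return float(greater + 0.5 * ties)


def zero_shot_macro_auc(omics_emb, label_emb, labels, true_labels):
    """One-vs-rest AUC per label from cosine similarity, plus their unweighted mean."""
    o = omics_emb / np.linalg.norm(omics_emb, axis=1, keepdims=True)
    t = label_emb / np.linalg.norm(label_emb, axis=1, keepdims=True)
    sims = o @ t.T  # shape (n_samples, n_labels)
    true_labels = np.asarray(true_labels)
    per_label = {}
    for j, label in enumerate(labels):
        is_pos = true_labels == label
        if is_pos.all() or not is_pos.any():
            continue  # AUC is undefined without both positives and negatives
        per_label[label] = roc_auc(sims[:, j], is_pos)
    return float(np.mean(list(per_label.values()))), per_label


template = "A sample of {} from a healthy individual"
labels = ["T cell", "B cell"]
prompts = [template.format(label) for label in labels]
print(prompts[0])  # A sample of T cell from a healthy individual

label_emb = np.array([[1.0, 0.0], [0.0, 1.0]])  # stand-ins for model.encode(prompts)
omics_emb = np.array([[0.9, 0.2], [0.7, 0.4], [0.1, 0.9], [0.3, 0.8]])
true_labels = ["T cell", "T cell", "B cell", "B cell"]
macro, details = zero_shot_macro_auc(omics_emb, label_emb, labels, true_labels)
print(macro, details)
```

Skipping labels that have no positives (or no negatives) keeps the macro average defined; a label present in every sample cannot be ranked against anything.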

In [None]:
logger.info("Computing zero-shot classification ROC...")
macro_auc, auc_details = zero_shot_classification_roc(
    adata,
    model,
    label_key=ZERO_SHOT_LABEL_KEY,
    emb_key=ZERO_SHOT_EMB_KEY,
    text_template=ZERO_SHOT_TEXT_TEMPLATE,
    device="cpu",
)
logger.info("Macro AUC: %s", macro_auc)
logger.info("Detail per label: %s", auc_details)

## 9. (Optional) Saving Notebook Outputs

If you want to execute and save this notebook programmatically (e.g., for batch runs or CI/CD), you can use `nbformat` and `nbconvert`, or `papermill`. The snippet below uses `nbformat` and `nbconvert` and is commented out by default.

In [None]:
# Uncomment and adapt if you want to save the executed notebook programmatically.
# Run it from a separate script rather than from inside this notebook: if the notebook on disk
# has SAVE_RESULTS = True, the executed copy reaches this cell again and could recurse.
# import nbformat
# from nbconvert import HTMLExporter
# from nbconvert.preprocessors import ExecutePreprocessor
#
# if SAVE_RESULTS:
#     logger.info("Saving notebook...")
#     with open("evaluation_notebook.ipynb", encoding="utf-8") as f:
#         nb = nbformat.read(f, as_version=4)
#     ep = ExecutePreprocessor(timeout=600)
#     ep.preprocess(nb, {"metadata": {"path": "./"}})
#     with open("evaluation_notebook_executed.ipynb", "w", encoding="utf-8") as f:
#         nbformat.write(nb, f)
#
#     # Optionally export to HTML
#     html_exporter = HTMLExporter()
#     body, _ = html_exporter.from_notebook_node(nb)
#     with open("evaluation_notebook_executed.html", "w", encoding="utf-8") as f:
#         f.write(body)
#
#     logger.info("Notebook saved as evaluation_notebook_executed.ipynb and HTML version.")
