## Example: Benchmarking a Model

This section demonstrates how to benchmark a model's cell embeddings with czbenchmarks, using a pre-trained scVI model as the example.

### Step 1: Setup and Imports

In [None]:
# Create an isolated virtual environment for scVI and czbenchmarks (run once).
# Uncomment the code below if you need to create a new virtual environment.
# !python3 -m venv .venv_scvi

# # Install model required packages
# !.venv_scvi/bin/python -m pip install --upgrade pip
# !.venv_scvi/bin/python -m pip install ipykernel numpy pandas scvi-tools #czbenchmarks

# # Register the new environment as a Jupyter kernel (if not already registered)
# !.venv_scvi/bin/python -m ipykernel install --user --name venv_scvi --display-name "Python (.venv_scvi)"

# print("Virtual environment '.venv_scvi' created, dependencies installed, and kernel registered.")


In [12]:
import functools
import json
import logging
import sys
from pathlib import Path

import numpy as np
from czbenchmarks.datasets import load_dataset
from czbenchmarks.datasets.single_cell_labeled import SingleCellLabeledDataset
from czbenchmarks.tasks.types import CellRepresentation
from czbenchmarks.tasks import (
    ClusteringTask,
    EmbeddingTask,
    MetadataLabelPredictionTask,
)
from czbenchmarks.tasks.clustering import ClusteringTaskInput
from czbenchmarks.tasks.embedding import EmbeddingTaskInput
from czbenchmarks.tasks.label_prediction import MetadataLabelPredictionTaskInput

# Model specific imports
import scvi #other imports can be used as required by model

# Send INFO-level (and above) log records to stdout so the library's messages show up in cell output
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

### Step 2: Load a Dataset

Load the pre-configured `tsv2_prostate` dataset. The library handles automatic download, caching, and loading as a `SingleCellLabeledDataset` for streamlined reuse.

**Loaded dataset provides:**
- `dataset.adata`: AnnData object with gene expression data.
- `dataset.labels`: pandas Series of cell type labels.

In [13]:
# Load the benchmark dataset by name; the annotation records the wrapper type this notebook expects back.
dataset_name = "tsv2_prostate"
dataset: SingleCellLabeledDataset = load_dataset(dataset_name)

INFO:czbenchmarks.file_utils:File already exists in cache: /Users/sgupta/.cz-benchmarks/datasets/homo_sapiens_10df7690-6d10-4029-a47e-0f071bb2df83_Prostate_v2_curated.h5ad
INFO:czbenchmarks.datasets.single_cell:Loading dataset from /Users/sgupta/.cz-benchmarks/datasets/homo_sapiens_10df7690-6d10-4029-a47e-0f071bb2df83_Prostate_v2_curated.h5ad


### Step 3: Run Model Inference and Get Output
Use a pre-trained scVI model to generate cell embeddings for evaluation within the benchmarking framework.

In [16]:
# Locate the pre-trained scVI weights under the current user's home directory
# rather than hard-coding one machine's absolute path. The value is kept as a
# str, matching the type this cell has always passed to scvi-tools.
model_weights_dir = str(Path.home() / ".cz-benchmarks" / "models")
required_obs_keys = ["dataset_id", "assay", "suspension_type", "donor_id"]

# Work on a copy so the benchmark dataset's own AnnData is left untouched
adata = dataset.adata.copy()

# Fail early with a readable message if any batch-defining column is absent
missing_obs_keys = [key for key in required_obs_keys if key not in adata.obs.columns]
if missing_obs_keys:
    raise KeyError(f"adata.obs is missing required columns: {missing_obs_keys}")

# Build a single "batch" label per cell by concatenating the required obs
# columns as strings (no separator).
# NOTE(review): presumably this must match how the reference model's batch
# labels were built -- confirm before changing the format.
batch_keys = required_obs_keys
adata.obs["batch"] = functools.reduce(
    lambda a, b: a + b, [adata.obs[c].astype(str) for c in batch_keys]
)

# Use the scvi-tools API to map our dataset to the reference model
scvi.model.SCVI.prepare_query_anndata(adata, model_weights_dir)
scvi_model = scvi.model.SCVI.load_query_data(adata, model_weights_dir)
# Mark the model as trained so the pre-trained weights can be used for
# inference without running a training step first
scvi_model.is_trained = True

# Now, generate the latent representation (the embedding)
scvi_embedding = scvi_model.get_latent_representation()
# `model_output` is the name the benchmarking steps below consume
model_output = scvi_embedding

print(f"Generated scVI embedding with shape: {scvi_embedding.shape}")

[34mINFO    [0m File [35m/Users/sgupta/.cz-benchmarks/models/[0m[95mmodel.pt[0m already downloaded                                      
[34mINFO    [0m Found [1;36m44.05[0m% reference vars in query data.                                                                


  return _pad_and_sort_query_anndata(adata, var_names, inplace)


[34mINFO    [0m File [35m/Users/sgupta/.cz-benchmarks/models/[0m[95mmodel.pt[0m already downloaded                                      


  warn(
  _, _, device = parse_device_args(


Generated scVI embedding with shape: (2044, 50)


#### Optional: Fine-Tune the Model

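Fine-tuning adjusts the pre-trained model to your dataset. Use a small number of epochs and a low learning rate to refine model weights without overwriting pre-trained knowledge. Note that running this cell replaces `model_output`, so all subsequent steps will evaluate the fine-tuned embedding; skip it to benchmark the pre-trained model as-is.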

In [None]:
# Optional step: continue training the loaded scVI model on this dataset.
# NOTE: this cell reassigns `model_output`, so any step run afterwards
# evaluates the fine-tuned embedding instead of the pre-trained one.
finetune_settings = {
    "max_epochs": 50,              # A smaller number of epochs for fine-tuning
    "plan_kwargs": {"lr": 5e-5},   # A lower learning rate for stable fine-tuning
    "early_stopping": True,        # Recommended to prevent overfitting
    "early_stopping_patience": 10,
}

print("Starting scVI model fine-tuning...")
scvi_model.train(**finetune_settings)
print("Fine-tuning complete.")

# Re-embed the cells with the *fine-tuned* weights
model_output = scvi_model.get_latent_representation()
print(f"Generated fine-tuned scVI embedding with shape: {model_output.shape}")

### Step 4: Run the Clustering Task

Evaluate the embedding by measuring clustering performance using Adjusted Rand Index (ARI) and Normalized Mutual Information (NMI). The task compares Leiden clusters from the embedding to true labels. Higher scores indicate better clustering. Compare `clustering_results` to `clustering_baseline_results` to assess model performance against the PCA baseline.

In [17]:
# Create the clustering task and describe what it should compare against
clustering_task = ClusteringTask()
clustering_task_input = ClusteringTaskInput(
    input_labels=dataset.labels,  # ground-truth labels the clusters are scored against
    obs=dataset.adata.obs,        # full per-cell observation metadata
)

# Score the model's embedding
clustering_results = clustering_task.run(
    task_input=clustering_task_input,
    cell_representation=model_output,
)

# Build the task's baseline representation from the expression matrix and score it the same way
expression_data = dataset.adata.X
clustering_baseline = clustering_task.compute_baseline(expression_data)
clustering_baseline_results = clustering_task.run(
    task_input=clustering_task_input,
    cell_representation=clustering_baseline,
)

# Print model metrics first, then baseline metrics, one JSON document per metric
report_sections = (
    ("--- Clustering Model Results ---", clustering_results),
    ("\n--- Clustering Baseline Results ---", clustering_baseline_results),
)
for heading, metric_results in report_sections:
    print(heading)
    for result in metric_results:
        print(result.model_dump_json(indent=2))

--- Clustering Model Results ---
{
  "metric_type": "adjusted_rand_index",
  "value": 0.7282581538681618,
  "params": {}
}
{
  "metric_type": "normalized_mutual_info",
  "value": 0.8693815660627174,
  "params": {}
}

--- Clustering Baseline Results ---
{
  "metric_type": "adjusted_rand_index",
  "value": 0.626707020983652,
  "params": {}
}
{
  "metric_type": "normalized_mutual_info",
  "value": 0.8326481406592264,
  "params": {}
}


### Step 5: Run Additional Benchmarking Tasks

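Run the remaining tasks, `EmbeddingTask` and `MetadataLabelPredictionTask`, following the same pattern as Step 4: initialize the task, build its input from the dataset labels, and evaluate both the model output and the PCA baseline. The final cell collects the results from all tasks into a single JSON report.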

In [18]:
# Expression matrix from which each task derives its baseline representation
expression_data = dataset.adata.X

# --- Embedding task: score the model output, then the task's baseline ---
embedding_task = EmbeddingTask()
embedding_task_input = EmbeddingTaskInput(input_labels=dataset.labels)
embedding_results = embedding_task.run(
    cell_representation=model_output,
    task_input=embedding_task_input,
)
embedding_baseline = embedding_task.compute_baseline(expression_data)
embedding_baseline_results = embedding_task.run(
    cell_representation=embedding_baseline,
    task_input=embedding_task_input,
)

# --- Label prediction task: same pattern, with labels as the prediction target ---
prediction_task = MetadataLabelPredictionTask()
prediction_task_input = MetadataLabelPredictionTaskInput(labels=dataset.labels)
prediction_results = prediction_task.run(
    cell_representation=model_output,
    task_input=prediction_task_input,
)
prediction_baseline = prediction_task.compute_baseline(expression_data)
prediction_baseline_results = prediction_task.run(
    cell_representation=prediction_baseline,
    task_input=prediction_task_input,
)

INFO:2025-07-18 09:45:13,061:jax._src.xla_bridge:925: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-07-18 09:45:13,068:jax._src.xla_bridge:925: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/sgupta/.pyenv/versions/3.10.16/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Users/sgupta/.pyenv/versions/3.10.16/lib/libtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no suc

In [19]:
# Collect every task's model and baseline metrics into one nested dictionary.
#
# model_dump(mode="json") converts each result to JSON-compatible values, so
# enum fields come out as their plain value (e.g. "adjusted_rand_index"), the
# same form model_dump_json() prints in Step 4. The previous combination of
# model_dump() and json.dumps(default=str) emitted the enum's repr instead
# (e.g. "MetricType.ADJUSTED_RAND_INDEX").
results_by_task = {
    "clustering": (clustering_results, clustering_baseline_results),
    "embedding": (embedding_results, embedding_baseline_results),
    "prediction": (prediction_results, prediction_baseline_results),
}

all_results = {
    task_name: {
        "model": [r.model_dump(mode="json") for r in model_results],
        "baseline": [r.model_dump(mode="json") for r in baseline_results],
    }
    for task_name, (model_results, baseline_results) in results_by_task.items()
}

# Every value is already JSON-compatible, so no `default=` fallback is needed
print(json.dumps(all_results, indent=2))

{
  "clustering": {
    "model": [
      {
        "metric_type": "MetricType.ADJUSTED_RAND_INDEX",
        "value": 0.7282581538681618,
        "params": {}
      },
      {
        "metric_type": "MetricType.NORMALIZED_MUTUAL_INFO",
        "value": 0.8693815660627174,
        "params": {}
      }
    ],
    "baseline": [
      {
        "metric_type": "MetricType.ADJUSTED_RAND_INDEX",
        "value": 0.626707020983652,
        "params": {}
      },
      {
        "metric_type": "MetricType.NORMALIZED_MUTUAL_INFO",
        "value": 0.8326481406592264,
        "params": {}
      }
    ]
  },
  "embedding": {
    "model": [
      {
        "metric_type": "MetricType.SILHOUETTE_SCORE",
        "value": 0.627913236618042,
        "params": {}
      }
    ],
    "baseline": [
      {
        "metric_type": "MetricType.SILHOUETTE_SCORE",
        "value": 0.6501824855804443,
        "params": {}
      }
    ]
  },
  "prediction": {
    "model": [
      {
        "metric_type": "MetricType