## Benchmarking scVI MLflow Model with czbenchmarks

This notebook demonstrates how to:

- Package a pre-trained scVI model using MLflow
- Use czbenchmarks datasets for evaluation
- Run inference through the MLflow model interface
- Evaluate the model embeddings using czbenchmarks tasks
- Compare results with PCA baselines



### Setup and Model Packaging

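In this section, we'll clone the `vcp-model-pkg-client-tools` repository, switch into its `scvi_mlflow_pkg` example, set up a virtual environment, download the pre-trained model artifacts from S3, and package the scVI model as an MLflow artifact. To make the example work on a Mac, we comment out the NVIDIA/CUDA-specific libraries in `requirements.txt` before installing.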

In [None]:
import os
import subprocess
import sys
import json
import tempfile
import shutil
from pathlib import Path

# Fetch the client-tools repository, but only when no local copy is present
repo_url = "https://github.com/chanzuckerberg/vcp-model-pkg-client-tools.git"
repo_dir = "vcp-model-pkg-client-tools"

repo_already_cloned = os.path.exists(repo_dir)
if not repo_already_cloned:
    print("🔄 Cloning repository...")
    clone_cmd = ["git", "clone", repo_url]
    subprocess.run(clone_cmd, check=True)
    print("✅ Repository cloned successfully")

# Switch into the scVI packaging example inside the cloned repository;
# every relative path used for packaging below is resolved from here.
package_subpath = ("examples", "mlflow_pkgs", "scvi_mlflow_pkg")
scvi_pkg_dir = os.path.join(repo_dir, *package_subpath)
os.chdir(scvi_pkg_dir)

print(f"📁 Working directory: {os.getcwd()}")

# Read and modify requirements.txt to remove NVIDIA dependencies for Mac compatibility
def comment_out_gpu_requirements(requirements_text, markers):
    """Comment out requirement lines that mention any of the given markers.

    Args:
        requirements_text: Full contents of a requirements file.
        markers: Lower-case substrings identifying GPU-only packages.

    Returns:
        The text with each matching line prefixed by ``# `` and suffixed with
        a note. Lines that are already comments are left untouched, so
        applying the function repeatedly yields the same result.
    """
    rewritten = []
    for line in requirements_text.split('\n'):
        already_commented = line.lstrip().startswith('#')
        lowered = line.lower()
        if not already_commented and any(marker in lowered for marker in markers):
            rewritten.append(f"# {line}  # Commented for Mac compatibility")
        else:
            rewritten.append(line)
    return '\n'.join(rewritten)


requirements_path = "requirements.txt"
if os.path.exists(requirements_path):
    with open(requirements_path, 'r') as f:
        requirements = f.read()

    # Comment out NVIDIA-related libraries
    nvidia_libs = ['nvidia-', 'cupy-', 'torch+cu', 'cuda']
    modified_requirements = comment_out_gpu_requirements(requirements, nvidia_libs).split('\n')

    # Write modified requirements back in place; the helper skips lines that
    # are already commented so re-running this cell does not stack comments.
    with open(requirements_path, 'w') as f:
        f.write('\n'.join(modified_requirements))

    print("✅ Modified requirements.txt for Mac compatibility")

# Create virtual environment for MLflow packaging
venv_name = ".venv_mlflow_scvi"
if not os.path.exists(venv_name):
    print(f"🔄 Creating virtual environment: {venv_name}")
    subprocess.run([sys.executable, "-m", "venv", venv_name], check=True)
    print("✅ Virtual environment created")

# Locate the virtual environment's executables (layout differs on Windows)
venv_bin_dir = os.path.join(venv_name, "Scripts" if os.name == 'nt' else "bin")
exe_suffix = ".exe" if os.name == 'nt' else ""
pip_path = os.path.join(venv_bin_dir, "pip" + exe_suffix)
python_path = os.path.join(venv_bin_dir, "python" + exe_suffix)

# Install requirements in the virtual environment.
# pip is invoked as `python -m pip` rather than through the pip executable:
# on Windows the running pip.exe cannot replace itself during an upgrade.
print("🔄 Installing requirements...")
subprocess.run([python_path, "-m", "pip", "install", "--upgrade", "pip"], check=True)
subprocess.run(
    [python_path, "-m", "pip", "install", "--no-deps", "-r", requirements_path],
    check=True,
)
print("✅ Requirements installed")

# Download model artifacts from S3
model_data_dir = "model_data"
os.makedirs(model_data_dir, exist_ok=True)

print("🔄 Downloading model artifacts from S3...")
s3_model_prefix = "s3://cz-benchmarks-data/models/v1/scvi_2023_12_15"
try:
    # Sync the human and mouse model directories one after the other
    for organism_dir in ("homo_sapiens", "mus_musculus"):
        subprocess.run([
            "aws", "s3", "sync",
            f"{s3_model_prefix}/{organism_dir}",
            f"{model_data_dir}/{organism_dir}"
        ], check=True)

    print("✅ Model artifacts downloaded successfully")
except (subprocess.CalledProcessError, FileNotFoundError) as e:
    # FileNotFoundError is what subprocess raises when the `aws` executable
    # itself is missing, which is exactly the case the hint below describes.
    print(f"❌ Error downloading model artifacts: {e}")
    print("Please ensure AWS CLI is installed and configured with appropriate credentials")

# Package the MLflow model
print("🔄 Packaging MLflow model...")
mlflow_artifact_dir = "mlflow_model_artifact"

# Remove existing artifact directory if it exists so stale output is never reused
if os.path.exists(mlflow_artifact_dir):
    shutil.rmtree(mlflow_artifact_dir)

# Run the MLflow packager with the virtual environment's interpreter
# NOTE(review): the model files are downloaded under model_data/ earlier in this
# cell, but the artifacts are referenced here without that prefix — confirm how
# mlflow_packager.py resolves artifact paths.
package_cmd = [
    python_path, "mlflow_packager.py",
    "--model-class", "model_code.scvi_mlflow_model:ScviMLflowModel",
    "--artifact", "homo_sapiens=homo_sapiens",
    "--artifact", "mus_musculus=mus_musculus",
    "--model-config-json", '{"organism":"human"}',
    "--skip-inference"
]

try:
    subprocess.run(package_cmd, check=True, cwd=".")
    print("✅ MLflow model packaged successfully")

    # Verify the artifact structure by printing it as an indented tree
    if os.path.exists(mlflow_artifact_dir):
        print(f"📦 MLflow model artifact created at: {os.path.abspath(mlflow_artifact_dir)}")
        print("📁 Artifact structure:")
        for root, _dirs, files in os.walk(mlflow_artifact_dir):
            # Depth of `root` below the artifact directory drives the indentation
            relative_root = os.path.relpath(root, mlflow_artifact_dir)
            level = 0 if relative_root == os.curdir else relative_root.count(os.sep) + 1
            indent = ' ' * 2 * level
            print(f"{indent}{os.path.basename(root)}/")
            subindent = ' ' * 2 * (level + 1)
            for file in files:
                print(f"{subindent}{file}")

except (subprocess.CalledProcessError, FileNotFoundError) as e:
    # FileNotFoundError covers a missing virtual-environment interpreter
    print(f"❌ Error packaging MLflow model: {e}")

# Store the MLflow model path for later use; only announce the model as ready
# when the artifact directory was actually produced.
MLFLOW_MODEL_PATH = os.path.abspath(mlflow_artifact_dir)
if os.path.isdir(MLFLOW_MODEL_PATH):
    print(f"🎯 MLflow model ready at: {MLFLOW_MODEL_PATH}")
else:
    print(f"⚠️  MLflow model artifact not found at: {MLFLOW_MODEL_PATH}")

### Dataset Preparation with czbenchmarks

Now we'll load a dataset from czbenchmarks and prepare it for inference with our MLflow-packaged scVI model. We need to ensure the dataset has the required observation columns (batch_keys) and save it in the correct format.

In [None]:
# Step back out of the packaging example (four directory levels below the
# notebook directory) so the remaining cells run from the notebook directory
os.chdir("../../../..")
print(f"📁 Working directory: {os.getcwd()}")

# Make sure czbenchmarks is importable, installing it with the current
# interpreter's pip when the import fails
try:
    import czbenchmarks
except ImportError:
    print("🔄 Installing czbenchmarks...")
    install_cmd = [sys.executable, "-m", "pip", "install", "czbenchmarks"]
    subprocess.run(install_cmd, check=True)
    print("✅ czbenchmarks installed")
else:
    print("✅ czbenchmarks already available")

# Import required libraries
import logging
import functools
import pandas as pd
import numpy as np
from czbenchmarks.datasets import load_dataset
from czbenchmarks.datasets.single_cell_labeled import SingleCellLabeledDataset
import anndata as ad

# Set up logging
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

print("\n" + "=" * 60)
print("--- Loading and Preparing czbenchmarks Dataset ---")
print("=" * 60)

# Load the dataset
print("🔄 Loading tsv2_prostate dataset...")
dataset: SingleCellLabeledDataset = load_dataset("tsv2_prostate")
print("✅ Dataset loaded successfully")
print(f"📊 Dataset shape: {dataset.adata.shape}")
print(f"🏷️  Labels: {dataset.labels.name} with {len(dataset.labels.unique())} unique values")

# Prepare the dataset for scVI MLflow model
print("\n🔄 Preparing dataset for scVI MLflow model...")

# Work on a copy so the dataset's own AnnData stays untouched for the baselines
adata = dataset.adata.copy()
print(f"📋 Available observation columns: {list(adata.obs.columns)}")

# Default batch_keys expected by the model
expected_batch_keys = ["dataset_id", "assay", "suspension_type", "donor_id"]

# Every expected key ends up available: existing columns are reported,
# missing ones are filled with a constant placeholder value.
available_batch_keys = []
for key in expected_batch_keys:
    if key in adata.obs.columns:
        print(f"✅ Found {key}: {adata.obs[key].nunique()} unique values")
    else:
        adata.obs[key] = f"default_{key}"
        print(f"⚠️  Created missing {key} with default value")
    available_batch_keys.append(key)

# Create batch identifier by joining the batch_keys columns with "_".
# The separator is inserted only between columns, so underscores that belong
# to the last column's value are preserved (stripping trailing underscores
# afterwards could merge distinct batches).
print(f"🔗 Creating batch identifier from: {available_batch_keys}")
adata.obs["batch"] = functools.reduce(
    lambda joined, column: joined + "_" + column,
    [adata.obs[c].astype(str) for c in available_batch_keys],
)

print(f"📦 Created {adata.obs['batch'].nunique()} unique batch identifiers")

# Save the prepared dataset
prepared_data_path = "prepared_dataset.h5ad"
adata.write_h5ad(prepared_data_path)
print(f"💾 Prepared dataset saved to: {os.path.abspath(prepared_data_path)}")

# Store information for next steps
DATASET_PATH = os.path.abspath(prepared_data_path)
BATCH_KEYS_STR = ",".join(available_batch_keys)
ORIGINAL_LABELS = dataset.labels.copy()
ORIGINAL_ADATA_FOR_BASELINE = dataset.adata.copy()  # Keep original for baseline computation

print(f"🎯 Dataset ready for MLflow model inference")
print(f"📁 Dataset path: {DATASET_PATH}")
print(f"🔑 Batch keys: {BATCH_KEYS_STR}")

### MLflow Model Inference

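In this section, we'll prepare the input for our MLflow-packaged scVI model, run inference, and extract the embeddings. We first build the JSON payload (dataset path plus parameters) and save it to `mlflow_input.json` for reference; the prediction itself is then run through the `mlflow.pyfunc` Python API using the same dataset path and parameters.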

In [None]:
import mlflow
import json
import tempfile

print("\n" + "=" * 60)
print("--- Running MLflow Model Inference ---")
print("=" * 60)

# Assemble the request payload piece by piece
print("🔄 Preparing MLflow model input...")

# One-row table whose single column carries the path of the prepared dataset
input_table = {
    "columns": ["input_uri"],
    "data": [[DATASET_PATH]],
}
# NOTE(review): the original note says return_dist=True yields mean embeddings —
# confirm against the packaged model's handling of this flag.
inference_params = {
    "organism": "human",
    "return_dist": True,
    "batch_keys": BATCH_KEYS_STR,
}
mlflow_input = {
    "dataframe_split": input_table,
    "params": inference_params,
}

# Persist the payload next to the notebook (a regular file in the working
# directory, not a temporary file) and echo the same text to the output
input_json_path = "mlflow_input.json"
serialized_input = json.dumps(mlflow_input, indent=2)
with open(input_json_path, 'w') as f:
    f.write(serialized_input)

print("✅ MLflow input prepared:")
print(serialized_input)

# Announce what is about to be run
print(f"\n🔄 Running inference with MLflow model...")
print(f"📦 Model path: {MLFLOW_MODEL_PATH}")
print(f"📄 Input path: {os.path.abspath(input_json_path)}")

output_json_path = "mlflow_output.json"

try:
    # Use MLflow's predict API
    print("🚀 Starting MLflow model prediction...")

    # Load the MLflow model
    loaded_model = mlflow.pyfunc.load_model(MLFLOW_MODEL_PATH)

    # Prepare input DataFrame for prediction
    input_df = pd.DataFrame({"input_uri": [DATASET_PATH]})

    # Run prediction with parameters
    predictions = loaded_model.predict(
        input_df,
        params={
            "organism": "human",
            "return_dist": True,
            "batch_keys": BATCH_KEYS_STR
        }
    )

    print("✅ MLflow model prediction completed successfully")

    # Convert to a numpy array before touching `.shape`: the prediction may be
    # a pandas object, an ndarray, or a plain sequence, and only the converted
    # array is guaranteed to expose array attributes.
    if isinstance(predictions, (pd.DataFrame, pd.Series)):
        model_embeddings = predictions.to_numpy()
    else:
        model_embeddings = np.asarray(predictions)

    print(f"📊 Predictions shape: {model_embeddings.shape}")
    print(f"📈 Predictions type: {type(predictions)}")

    print(f"🎯 Generated embeddings shape: {model_embeddings.shape}")
    print(f"📊 Embedding statistics:")
    print(f"   Mean: {model_embeddings.mean():.4f}")
    print(f"   Std: {model_embeddings.std():.4f}")
    print(f"   Min: {model_embeddings.min():.4f}")
    print(f"   Max: {model_embeddings.max():.4f}")

    # Save embeddings for later use
    np.save("scvi_mlflow_embeddings.npy", model_embeddings)
    print("💾 Embeddings saved to: scvi_mlflow_embeddings.npy")

except Exception as e:
    # Print troubleshooting hints, then re-raise so the notebook stops here
    print(f"❌ Error running MLflow model: {e}")
    print("🔍 Troubleshooting tips:")
    print("   1. Ensure the MLflow model was packaged correctly")
    print("   2. Check that the dataset has required batch_keys columns")
    print("   3. Verify the model artifacts are properly downloaded")
    raise

# Verify embeddings are ready for czbenchmarks tasks
print(f"\n🎯 Embeddings ready for czbenchmarks evaluation!")
print(f"📐 Embedding dimensions: {model_embeddings.shape}")
print(f"🔢 Number of cells: {model_embeddings.shape[0]}")
print(f"📏 Latent dimensions: {model_embeddings.shape[1]}")

### czbenchmarks Task Evaluation

Finally, we'll use the embeddings from our MLflow model to run czbenchmarks tasks including clustering, embedding quality assessment, and metadata label prediction. We'll compare the results with PCA baselines and generate visualizations.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from tabulate import tabulate
import scanpy as sc

# Import czbenchmarks tasks
from czbenchmarks.tasks import (
    ClusteringTask,
    EmbeddingTask,
    MetadataLabelPredictionTask,
)
from czbenchmarks.tasks.clustering import ClusteringTaskInput
from czbenchmarks.tasks.embedding import EmbeddingTaskInput
from czbenchmarks.tasks.label_prediction import MetadataLabelPredictionTaskInput

print("\n" + "=" * 60)
print("--- czbenchmarks Task Evaluation ---")
print("=" * 60)

# Set up visualization style.
# The matplotlib 'default' style resets rcParams, so it has to be applied
# before the seaborn theme; in the opposite order the whitegrid theme is lost.
plt.style.use('default')
sns.set_theme(style="whitegrid")

# Load embeddings from previous step
try:
    model_embeddings = np.load("scvi_mlflow_embeddings.npy")
except Exception:
    # Report the failure and re-raise; KeyboardInterrupt/SystemExit pass through
    print("❌ Could not load embeddings from previous step")
    raise
else:
    print(f"✅ Loaded MLflow model embeddings: {model_embeddings.shape}")

# Prepare data for tasks from the unmodified dataset copies stored earlier
expression_data = ORIGINAL_ADATA_FOR_BASELINE.X
labels = ORIGINAL_LABELS
obs_data = ORIGINAL_ADATA_FOR_BASELINE.obs

print(f"📊 Expression data shape: {expression_data.shape}")
print(f"🏷️  Labels shape: {labels.shape}")
print(f"📋 Number of unique labels: {labels.nunique()}")

# Store all results, keyed by task name
all_results = {}

# =============================================================================
# Task 1: Clustering Performance
# =============================================================================
clustering_rule = '=' * 50
print(f"\n{clustering_rule}")
print("🎯 Task 1: Clustering Performance")
print(f"{clustering_rule}")

print("🔄 Initializing clustering task...")
clustering_task = ClusteringTask()
clustering_task_input = ClusteringTaskInput(obs=obs_data, input_labels=labels)

# Score the model embeddings first, then the task's own baseline embedding
print("🚀 Running clustering evaluation on MLflow model embeddings...")
clustering_results_model = clustering_task.run(
    cell_representation=model_embeddings,
    task_input=clustering_task_input,
)

print("📊 Computing PCA baseline...")
clustering_baseline_embedding = clustering_task.compute_baseline(expression_data)
clustering_results_baseline = clustering_task.run(
    cell_representation=clustering_baseline_embedding,
    task_input=clustering_task_input,
)

# Keep plain-dict copies of both result lists under the task name
all_results["clustering"] = dict(
    model=[result.model_dump() for result in clustering_results_model],
    baseline=[result.model_dump() for result in clustering_results_baseline],
)

print("✅ Clustering evaluation completed")

# Build one long table (model rows followed by baseline rows) for plotting
print("📈 Creating clustering performance visualization...")
df_clustering_model = pd.DataFrame(all_results["clustering"]["model"]).assign(
    source="scVI MLflow Model"
)
df_clustering_baseline = pd.DataFrame(all_results["clustering"]["baseline"]).assign(
    source="PCA Baseline"
)
df_clustering = pd.concat([df_clustering_model, df_clustering_baseline])
df_clustering["metric_name"] = [
    metric.name for metric in df_clustering["metric_type"]
]

# Grouped bar chart: one group per metric, one bar per method
plt.figure(figsize=(12, 6))
clustering_palette = ["#2E86AB", "#A23B72"]
sns.barplot(
    x="metric_name",
    y="value",
    hue="source",
    data=df_clustering,
    palette=clustering_palette,
)
plt.title("Clustering Performance: scVI MLflow Model vs. PCA Baseline", fontsize=16, pad=20)
plt.xlabel("Metric", fontsize=12)
plt.ylabel("Score", fontsize=12)
plt.legend(title="Method", frameon=True)
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig("clustering_performance_mlflow.png", dpi=300, bbox_inches='tight')
plt.show()

# =============================================================================
# Task 2: Embedding Quality
# =============================================================================
embedding_rule = '=' * 50
print(f"\n{embedding_rule}")
print("🎯 Task 2: Embedding Quality Assessment")
print(f"{embedding_rule}")

print("🔄 Initializing embedding quality task...")
embedding_task = EmbeddingTask()
embedding_task_input = EmbeddingTaskInput(input_labels=labels)

# Score the model embeddings first, then the task's own baseline embedding
print("🚀 Evaluating MLflow model embedding quality...")
embedding_results_model = embedding_task.run(model_embeddings, embedding_task_input)

print("📊 Computing PCA baseline...")
embedding_baseline_embedding = embedding_task.compute_baseline(expression_data)
embedding_results_baseline = embedding_task.run(
    embedding_baseline_embedding,
    embedding_task_input,
)

# Keep plain-dict copies of both result lists under the task name
all_results["embedding"] = dict(
    model=[result.model_dump() for result in embedding_results_model],
    baseline=[result.model_dump() for result in embedding_results_baseline],
)

print("✅ Embedding quality evaluation completed")

# Build one long table (model rows followed by baseline rows) for plotting
print("📈 Creating embedding quality visualization...")
df_embedding_model = pd.DataFrame(all_results["embedding"]["model"]).assign(
    source="scVI MLflow Model"
)
df_embedding_baseline = pd.DataFrame(all_results["embedding"]["baseline"]).assign(
    source="PCA Baseline"
)
df_embedding = pd.concat([df_embedding_model, df_embedding_baseline])
df_embedding["metric_name"] = [
    metric.name for metric in df_embedding["metric_type"]
]

# Grouped bar chart: one group per metric, one bar per method
plt.figure(figsize=(10, 6))
embedding_palette = ["#F18F01", "#C73E1D"]
sns.barplot(
    x="metric_name",
    y="value",
    hue="source",
    data=df_embedding,
    palette=embedding_palette,
)
plt.title("Embedding Quality: scVI MLflow Model vs. PCA Baseline", fontsize=16, pad=20)
plt.xlabel("Metric", fontsize=12)
plt.ylabel("Silhouette Score", fontsize=12)
plt.legend(title="Method", frameon=True)
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig("embedding_quality_mlflow.png", dpi=300, bbox_inches='tight')
plt.show()

# =============================================================================
# Task 3: Metadata Label Prediction
# =============================================================================
prediction_rule = '=' * 50
print(f"\n{prediction_rule}")
print("🎯 Task 3: Metadata Label Prediction")
print(f"{prediction_rule}")

print("🔄 Initializing label prediction task...")
prediction_task = MetadataLabelPredictionTask()
prediction_task_input = MetadataLabelPredictionTaskInput(labels=labels)

# Score the model embeddings first, then the task's own baseline embedding
print("🚀 Running label prediction on MLflow model embeddings...")
prediction_results_model = prediction_task.run(model_embeddings, prediction_task_input)

print("📊 Computing PCA baseline...")
prediction_baseline_embedding = prediction_task.compute_baseline(expression_data)
prediction_results_baseline = prediction_task.run(
    prediction_baseline_embedding,
    prediction_task_input,
)

# Keep plain-dict copies of both result lists under the task name
all_results["prediction"] = dict(
    model=[result.model_dump() for result in prediction_results_model],
    baseline=[result.model_dump() for result in prediction_results_baseline],
)

print("✅ Label prediction evaluation completed")

# Build one long table (model rows followed by baseline rows) for plotting
print("📈 Creating label prediction visualization...")
df_pred_model = pd.DataFrame(all_results["prediction"]["model"]).assign(
    source="scVI MLflow Model"
)
df_pred_baseline = pd.DataFrame(all_results["prediction"]["baseline"]).assign(
    source="PCA Baseline"
)
df_pred = pd.concat([df_pred_model, df_pred_baseline])
df_pred["metric_name"] = [metric.name for metric in df_pred["metric_type"]]
# Results without a "classifier" entry in their params are labelled "Overall"
df_pred["classifier"] = [
    params.get("classifier", "Overall") for params in df_pred["params"]
]

# Plot only the rows whose classifier label contains "MEAN" to keep the chart readable
is_mean_row = df_pred["classifier"].str.contains("MEAN", na=False)
df_pred_mean = df_pred[is_mean_row]

plt.figure(figsize=(14, 7))
prediction_palette = ["#3A86FF", "#FF006E"]
sns.barplot(
    x="metric_name",
    y="value",
    hue="source",
    data=df_pred_mean,
    palette=prediction_palette,
)
plt.title("Label Prediction Performance: scVI MLflow Model vs. PCA Baseline", fontsize=16, pad=20)
plt.xlabel("Metric", fontsize=12)
plt.ylabel("Score", fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.legend(title="Method", frameon=True)
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig("label_prediction_mlflow.png", dpi=300, bbox_inches='tight')
plt.show()

# =============================================================================
# Summary Results Table
# =============================================================================
# Section banner: a blank line, then the title framed by two 60-character rules
summary_rule = '=' * 60
print(f"\n{summary_rule}")
print("📊 COMPREHENSIVE RESULTS SUMMARY")
print(f"{summary_rule}")

def _metric_display_name(metric_type):
    """Return a printable name for a metric type given as enum, dict, or string."""
    if metric_type is None:
        return "Unknown"
    if isinstance(metric_type, dict):
        return metric_type.get("name", "Unknown")
    # Enum members (including str-mixin enums) expose their name via `.name`;
    # this must be checked before falling back to the plain string form.
    name = getattr(metric_type, "name", None)
    if name is not None:
        return str(name)
    return str(metric_type)


def create_summary_table(results_dict):
    """Flatten the nested task results into one summary row per metric.

    Args:
        results_dict: Mapping of task name to a mapping of source ("model" or
            anything else, which is treated as the baseline) to a list of
            metric dicts. Each metric dict may carry "metric_type" (an enum
            member, a dict with a "name" entry, or a string) and "value".

    Returns:
        A list of dicts with the keys "Task", "Method", "Metric" and "Score".
        Float values are formatted with four decimals; other values are
        converted with ``str``. A missing metric type is reported as
        "Unknown" and a missing value as 0.
    """
    summary_rows = []

    for task, sources in results_dict.items():
        task_label = task.replace("_", " ").title()
        for source, metrics in sources.items():
            method_label = "scVI MLflow" if source == "model" else "PCA Baseline"
            for metric in metrics:
                value = metric.get("value", 0)
                formatted_value = f"{value:.4f}" if isinstance(value, float) else str(value)

                summary_rows.append({
                    "Task": task_label,
                    "Method": method_label,
                    "Metric": _metric_display_name(metric.get("metric_type")),
                    "Score": formatted_value,
                })

    return summary_rows

# Create and display summary table
summary_data = create_summary_table(all_results)
summary_df = pd.DataFrame(summary_data)

print("\n🏆 Performance Summary:")
print(tabulate(summary_df, headers='keys', tablefmt='grid', showindex=False))

# Calculate improvement over baseline
print(f"\n📈 Performance Improvements (scVI MLflow vs PCA Baseline):")
improvement_summary = []

for task in all_results:
    model_results = pd.DataFrame(all_results[task]["model"])
    baseline_results = pd.DataFrame(all_results[task]["baseline"])
    if model_results.empty or baseline_results.empty:
        # Nothing to compare for this task
        continue

    # Results of one task can share a metric type and differ only in their
    # params (e.g. one entry per classifier), so params take part in matching.
    compare_params = (
        "params" in model_results.columns and "params" in baseline_results.columns
    )

    for _, model_row in model_results.iterrows():
        metric_type = model_row["metric_type"]
        model_value = model_row["value"]
        model_params = model_row["params"] if compare_params else None

        # Find corresponding baseline result: same metric type and, when
        # possible, the same params; fall back to the metric type alone.
        same_metric = baseline_results["metric_type"].apply(lambda x: x == metric_type)
        baseline_row = baseline_results[same_metric]
        if compare_params:
            same_params = baseline_results["params"].apply(lambda p: p == model_params)
            exact_match = baseline_results[same_metric & same_params]
            if not exact_match.empty:
                baseline_row = exact_match

        if baseline_row.empty:
            continue

        baseline_value = baseline_row.iloc[0]["value"]
        if baseline_value == 0:
            # A relative change is undefined against a zero baseline
            improvement_text = "n/a (baseline is 0)"
        else:
            # Divide by the magnitude so a negative baseline (scores such as
            # silhouette can be negative) does not flip the sign of the change.
            improvement = ((model_value - baseline_value) / abs(baseline_value)) * 100
            improvement_text = f"{improvement:+.2f}%"

        # Disambiguate rows that share a metric type but differ by classifier
        metric_label = metric_type.name
        if isinstance(model_params, dict) and "classifier" in model_params:
            metric_label = f"{metric_label} ({model_params['classifier']})"

        improvement_summary.append({
            "Task": task.replace("_", " ").title(),
            "Metric": metric_label,
            "Improvement": improvement_text
        })

improvement_df = pd.DataFrame(improvement_summary)
print(tabulate(improvement_df, headers='keys', tablefmt='grid', showindex=False))

print(f"\n🎯 Evaluation completed successfully!")
print(f"📁 Visualizations saved:")
print(f"   • clustering_performance_mlflow.png")
print(f"   • embedding_quality_mlflow.png")
print(f"   • label_prediction_mlflow.png")

print(f"\n💡 Key Insights:")
print(f"   • scVI MLflow model generated {model_embeddings.shape[1]}D embeddings for {model_embeddings.shape[0]} cells")
print(f"   • Performance compared against PCA baseline across multiple tasks")
print(f"   • Results show the effectiveness of pre-trained scVI models via MLflow deployment")