In [1]:
import os
import torch
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import typing as t
import pickle
import logging

# Notebook-level logger: DEBUG and above are emitted through a single
# StreamHandler using a timestamped "time - name - level - message" format.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Re-running this cell used to attach one more StreamHandler on every run, so
# each message was printed once per execution of the cell. Detach whatever is
# already attached to this logger so exactly one handler remains afterwards.
for _stale_handler in list(logger.handlers):
    logger.removeHandler(_stale_handler)

# Create the handler and give it the message format.
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

# Attach the (single) handler to the logger.
logger.addHandler(handler)

# Silence the UserWarning whose message starts with "Transforming to str index."
# (emitted by AnnData, according to the original note on this line).
warnings.filterwarnings("ignore", category=UserWarning, message="Transforming to str index.")

from cellarium.ml.utilities.inference.cellarium_gpt_inference import \
    CellariumGPTInferenceContext
from cellarium.ml.utilities.inference.metadata_benchmarking.calculate_metrics import \
    calculate_metrics_for_prediction_output

In [30]:
# ---- Run configuration -------------------------------------------------------

# CUDA device used for inference, and the index that selects the validation
# extract `adata_path` points at.
cuda_device_index = 0
val_adata_index = 1

# Sub-sampling and batching controls.
rng_seed = 42                     # seeds both the gene and the cell sub-sampling
n_cells = 10_000                  # None -> keep every cell
n_genes = 4_091                   # None -> keep every gene
gene_selection_method = "random"  # "random" or "highly_expressed"
chunk_size = 16                   # forwarded to predict_metadata_chunked

# Input artifacts.
checkpoint_path = "/home/mehrtash/data/100M_long_run/run_001/lightning_logs/version_3/checkpoints/epoch=5-step=504000.ckpt"
ref_adata_path = "/home/mehrtash/data/data/extract_0.h5ad"
gene_info_path = "/home/mehrtash/data/gene_info/gene_info.tsv"
adata_path = f"/home/mehrtash/data/data/cellariumgpt_validation/extract_{val_adata_index}.h5ad"
ontology_resource_path = "/home/mehrtash/data/data/cellariumgpt_artifacts/ontology"

# Output directory; created up front so later writes cannot fail on a missing folder.
output_path = "/home/mehrtash/data/data/cellariumgpt_artifacts/metadata_predictions/100M_long_run_last"
os.makedirs(output_path, exist_ok=True)

In [31]:
# Load the pickled ontology resources (benchmarking + propagation) for every
# metadata field, keyed by metadata field name.
logger.info("Loading ontology resources ...")


def _read_pickle(pickle_path):
    """Return the object unpickled from ``pickle_path``.

    Unpickling can execute arbitrary code, so only point this at trusted files.
    """
    with open(pickle_path, "rb") as handle:
        return pickle.load(handle)


# metadata field -> path of its benchmarking resource
ontology_benchmarking_resource_path_dict = {
    meta_key: os.path.join(ontology_resource_path, file_name)
    for meta_key, file_name in (
        ('cell_type', 'cl_benchmarking_resource.pkl'),
        ('development_stage', 'hsapdv_benchmarking_resource.pkl'),
        ('disease', 'mondo_benchmarking_resource.pkl'),
        ('tissue', 'uberon_benchmarking_resource.pkl'),
        ('sex', 'sex_benchmarking_resource.pkl'),
    )
}

# metadata field -> path of its propagation resource
ontology_propagation_resource_path_dict = {
    meta_key: os.path.join(ontology_resource_path, file_name)
    for meta_key, file_name in (
        ('cell_type', 'cl_propagation_resource.pkl'),
        ('development_stage', 'hsapdv_propagation_resource.pkl'),
        ('disease', 'mondo_propagation_resource.pkl'),
        ('tissue', 'uberon_propagation_resource.pkl'),
        ('sex', 'sex_propagation_resource.pkl'),
    )
}

# metadata field -> hop count (not consumed by any cell visible here)
n_hops_dict = dict(
    cell_type=3,
    development_stage=3,
    disease=3,
    tissue=3,
    sex=0,
)

# Read everything into memory: benchmarking resources first, then propagation.
ontology_benchmarking_resource_dicts = {
    meta_key: _read_pickle(path)
    for meta_key, path in ontology_benchmarking_resource_path_dict.items()
}

ontology_propagation_resource_dicts = {
    meta_key: _read_pickle(path)
    for meta_key, path in ontology_propagation_resource_path_dict.items()
}

2025-02-27 17:36:09,593 - __main__ - INFO - Loading ontology resources ...


In [32]:
logger.info(f"Loading the model checkpoint from {checkpoint_path} ...")

# Inference runs on the CUDA device chosen in the arguments cell.
device = torch.device(f"cuda:{cuda_device_index}")

# Gather the constructor arguments in one place, then build the inference
# context from the checkpoint, the reference AnnData file and the gene table.
_ctx_kwargs = dict(
    cellarium_gpt_ckpt_path=checkpoint_path,
    ref_adata_path=ref_adata_path,
    gene_info_tsv_path=gene_info_path,
    device=device,
    attention_backend="mem_efficient",
    verbose=False,
)
ctx = CellariumGPTInferenceContext(**_ctx_kwargs)

2025-02-27 17:36:10,323 - __main__ - INFO - Loading the model checkpoint from /home/mehrtash/data/100M_long_run/run_001/lightning_logs/version_3/checkpoints/epoch=5-step=504000.ckpt ...


In [33]:
# Load the AnnData object that the remaining cells sub-sample and predict on.
# NOTE(review): this reads `ref_adata_path` -- the same file that is handed to
# the inference context -- while `adata_path` (the validation extract selected
# by `val_adata_index` in the arguments cell) is never read by any visible cell.
# Confirm that evaluating on the reference extract is intended.
adata = sc.read_h5ad(ref_adata_path)

In [34]:
# Decide which genes and which cells go into the prediction step.
#
# Outputs consumed by the prediction cell:
#   adata              -- possibly restricted to a gene subset and/or a cell subset
#   n_rand_prompt_vars -- number of random genes to prompt with, or None
#   torch_rng          -- seeded torch.Generator for the random gene draw, or None
#
# NOTE: `adata` (and `n_cells`) are overwritten in place, so re-running this
# cell sub-samples the already sub-sampled data; re-run the loading cell first.

# Bind both names up front. Previously the "use all genes" path
# (n_genes is None) never assigned them, and the prediction cell then failed
# with a NameError.
n_rand_prompt_vars = None
torch_rng = None

if n_genes is None:
    logger.info("Using all genes.")
elif gene_selection_method == "highly_expressed":
    logger.info(f"Using {n_genes} highly expressed genes.")
    # Per-gene totals: sum of X over axis 0 (the observations axis).
    X_g = np.asarray(adata.X.sum(0)).flatten()
    # argsort is ascending, so the last n_genes entries are the largest totals.
    highly_expressed_gene_indices = np.argsort(X_g)[-n_genes:]
    highly_expressed_gene_ids = adata.var_names[highly_expressed_gene_indices]
    adata = adata[:, highly_expressed_gene_ids]
elif gene_selection_method == "random":
    logger.info(f"Using {n_genes} random genes (seed = {rng_seed}).")
    # The random subset itself is drawn inside predict_metadata_chunked;
    # here we only prepare its size and a seeded generator.
    torch_rng = torch.Generator().manual_seed(rng_seed)
    n_rand_prompt_vars = n_genes
else:
    raise ValueError(f"Unknown gene selection method: {gene_selection_method}")

if n_cells is None:
    logger.info("Using all cells.")
else:
    # Never ask for more cells than are available.
    n_cells = min(n_cells, len(adata))
    logger.info(f"Using {n_cells} random cells (seed = {rng_seed}).")
    rng = np.random.RandomState(rng_seed)
    # Sample row positions without replacement so no cell appears twice.
    adata = adata[rng.choice(len(adata), n_cells, replace=False)]

2025-02-27 17:36:17,330 - __main__ - INFO - Using 4091 random genes (seed = 42).
2025-02-27 17:36:17,331 - __main__ - INFO - Using 10000 random cells (seed = 42).


In [35]:
logger.info(f"Predicting metadata for {len(adata)} cells ...")

# Run chunked metadata prediction. `n_rand_prompt_vars` and `torch_rng` come
# from the gene-selection cell and are None unless a random gene subset was
# requested.
preds = ctx.predict_metadata_chunked(
    adata=adata,
    chunk_size=chunk_size,
    n_rand_prompt_vars=n_rand_prompt_vars,
    rng=torch_rng,
)

2025-02-27 17:36:22,307 - __main__ - INFO - Predicting metadata for 10000 cells ...


  0%|          | 0/625 [00:00<?, ?it/s]

In [None]:
# Top-1 prediction per cell for a single metadata field.
metadata_key = 'cell_type'

# `preds` is exponentiated here, i.e. it is treated as log-probabilities.
probs = np.exp(preds[metadata_key])

# Vocabulary of this field: term names and their labels, in matching order.
ontology_names = np.asarray(ctx.metadata_ontology_infos[metadata_key]['names'])
ontology_labels = np.asarray(ctx.metadata_ontology_infos[metadata_key]['labels'])
names_to_labels_map = dict(zip(ontology_names, ontology_labels))

# Highest-probability vocabulary entry along the last axis, and its probability.
best_indices = np.argmax(probs, axis=-1)
best_probs = np.max(probs, axis=-1)
best_names = ontology_names[best_indices]
best_labels = [names_to_labels_map.get(name) for name in best_names]

# Copy of the observation table with the predicted term added as a new column.
new_obs = adata.obs.copy()
new_obs['predicted_' + metadata_key] = best_names

In [42]:
# For every metadata field, print the cross-entropy of the ground-truth term:
# the negated mean (and the std) of the value `preds` holds for the true term,
# with `preds` treated as per-cell log-probabilities over the field vocabulary.
# Cells whose ground-truth term cannot be scored are left out of the statistics.
for metadata_key in ['cell_type', 'development_stage', 'disease', 'tissue', 'sex']:
    truth_term_ids = adata.obs[f"{metadata_key}_ontology_term_id"].values

    vocab_names = ctx.metadata_ontology_infos[metadata_key]["names"]
    vocab_names_to_idx = {v: i for i, v in enumerate(vocab_names)}
    # "unknown" labels are mapped to the sentinel -1 and excluded below.
    vocab_names_to_idx["unknown"] = -1

    # A ground-truth term missing from the vocabulary used to raise KeyError and
    # abort the whole report. Treat such cells like "unknown" ones instead, and
    # say so, because they silently shrink the evaluated set otherwise.
    oov_term_ids = sorted(
        {str(term_id) for term_id in truth_term_ids if term_id not in vocab_names_to_idx})
    if oov_term_ids:
        logger.warning(
            f"{metadata_key}: {len(oov_term_ids)} ground-truth term(s) are not in the model "
            f"vocabulary and are excluded from the loss: {oov_term_ids}")

    truth_vocab_indices = np.array(
        [vocab_names_to_idx.get(term_id, -1) for term_id in truth_term_ids], dtype=int)
    good_indices = np.where(truth_vocab_indices != -1)[0]

    if len(good_indices) == 0:
        # Nothing to average; avoid printing NaN from an empty mean.
        print(f"{metadata_key} xent loss: n/a (no cells with an in-vocabulary label)")
        continue

    # Pick, for each evaluable cell, the entry at its true vocabulary index.
    truth_logprobs = preds[metadata_key][good_indices, truth_vocab_indices[good_indices]]
    mean_truth_logprobs = -truth_logprobs.mean()  # cross-entropy = -mean log p(truth)
    std_truth_logprobs = truth_logprobs.std()
    print(f"{metadata_key} xent loss: {mean_truth_logprobs:.2f} +/- {std_truth_logprobs:.2f}")


cell_type xent loss: 3.58 +/- 2.81
development_stage xent loss: 7.35 +/- 2.68
disease xent loss: 3.51 +/- 5.61
tissue xent loss: 5.31 +/- 3.37
sex xent loss: 1.00 +/- 1.52
