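### Generate metadata predictions

This notebook uses a CellariumGPT checkpoint to predict cell metadata (cell type, development stage, disease, tissue, sex) for a validation extract. Each prediction is propagated over its ontology, so that a term also receives the probability of its descendants. The propagated predictions are then scored against the ground-truth annotations, and the scores are saved as an `.h5ad` file.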

In [None]:
import os
import torch
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import typing as t
import pickle
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)  # Set the logging level

# Attach a stream handler only if the logger has none yet. Without this guard every
# re-run of this cell stacks one more handler and each message is printed once more.
if not logger.handlers:
    # Create a handler
    handler = logging.StreamHandler()

    # Create and set a formatter
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)

    # Add the handler to the logger
    logger.addHandler(handler)

# Silence AnnData's noisy "Transforming to str index." UserWarning.
warnings.filterwarnings("ignore", category=UserWarning, message="Transforming to str index.")

from cellarium.ml.utilities.inference.cellarium_gpt_inference import \
    CellariumGPTInferenceContext
from cellarium.ml.utilities.inference.metadata_benchmarking.calculate_metrics import \
    calculate_metrics_for_prediction_output

In [None]:
# arguments
cuda_device_index = 0  # the model is placed on torch.device("cuda:<index>")
val_adata_index = 1  # selects the validation extract and names the output file

checkpoint_path = "/home/mehrtash/data/100M_long_run/run_001/lightning_logs/version_3/checkpoints/epoch=5-step=504000.ckpt"
ref_adata_path = "/home/mehrtash/data/data/extract_0.h5ad"
gene_info_path = "/home/mehrtash/data/gene_info/gene_info.tsv"
adata_path = f"/home/mehrtash/data/data/cellariumgpt_validation/extract_{val_adata_index}.h5ad"
ontology_resource_path = "/home/mehrtash/data/data/cellariumgpt_artifacts/ontology"
output_path = "/home/mehrtash/data/data/cellariumgpt_artifacts/metadata_predictions/100M_long_run_last"

rng_seed = 42  # seeds both the random gene prompts and the random cell subsample
n_cells = 1000  # number of random cells to score; None -> all cells
n_genes = 4_091  # genes per prompt; None -> all genes. NOTE(review): origin of 4091 undocumented
gene_selection_method = "random"  # "random" or "highly_expressed"
chunk_size = 16  # chunk size passed to ctx.predict_metadata_chunked

# Fail fast on invalid settings, before the slow ontology / checkpoint / AnnData loading cells.
# (n_genes == 0 would otherwise silently select *all* genes with "highly_expressed",
# because argsort(...)[-0:] returns the whole array.)
_valid_gene_selection_methods = ("random", "highly_expressed")
if n_genes is not None:
    if gene_selection_method not in _valid_gene_selection_methods:
        raise ValueError(f"Unknown gene selection method: {gene_selection_method}")
    if n_genes <= 0:
        raise ValueError(f"n_genes must be a positive integer or None, got {n_genes}")
if n_cells is not None and n_cells <= 0:
    raise ValueError(f"n_cells must be a positive integer or None, got {n_cells}")
if chunk_size <= 0:
    raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

# Create the output directory up front; exist_ok keeps re-runs harmless.
os.makedirs(name=output_path, exist_ok=True)

In [None]:
# load ontology resources
logger.info("Loading ontology resources ...")

# Resource file names per metadata field: (benchmarking resource, propagation resource).
_ontology_resource_file_names = {
    'cell_type': ('cl_benchmarking_resource.pkl', 'cl_propagation_resource.pkl'),
    'development_stage': ('hsapdv_benchmarking_resource.pkl', 'hsapdv_propagation_resource.pkl'),
    'disease': ('mondo_benchmarking_resource.pkl', 'mondo_propagation_resource.pkl'),
    'tissue': ('uberon_benchmarking_resource.pkl', 'uberon_propagation_resource.pkl'),
    'sex': ('sex_benchmarking_resource.pkl', 'sex_propagation_resource.pkl'),
}

ontology_benchmarking_resource_path_dict = {
    field: os.path.join(ontology_resource_path, benchmarking_file)
    for field, (benchmarking_file, _) in _ontology_resource_file_names.items()
}
ontology_propagation_resource_path_dict = {
    field: os.path.join(ontology_resource_path, propagation_file)
    for field, (_, propagation_file) in _ontology_resource_file_names.items()
}

# Number of ontology hops handed to the metric calculation for each field.
n_hops_dict = {
    'cell_type': 3,
    'development_stage': 3,
    'disease': 3,
    'tissue': 3,
    'sex': 0,
}


def _unpickle(file_path):
    """Load one pickled resource file.

    pickle executes code on load, so only point this at trusted local artifacts.
    """
    with open(file_path, "rb") as fh:
        return pickle.load(fh)


ontology_benchmarking_resource_dicts = {
    field: _unpickle(file_path)
    for field, file_path in ontology_benchmarking_resource_path_dict.items()
}
ontology_propagation_resource_dicts = {
    field: _unpickle(file_path)
    for field, file_path in ontology_propagation_resource_path_dict.items()
}

In [None]:
logger.info(f"Loading the model checkpoint from {checkpoint_path} ...")

# NOTE(review): assumes a CUDA GPU exists at index `cuda_device_index`.
device = torch.device(f"cuda:{cuda_device_index}")

# Keyword arguments for the inference context, collected in one place.
inference_context_kwargs = dict(
    cellarium_gpt_ckpt_path=checkpoint_path,
    ref_adata_path=ref_adata_path,
    gene_info_tsv_path=gene_info_path,
    device=device,
    attention_backend="mem_efficient",
    verbose=False,
)
ctx = CellariumGPTInferenceContext(**inference_context_kwargs)

In [None]:
# Read the validation extract whose metadata will be predicted and scored.
logger.info(f"Loading the validation AnnData object from {adata_path} ...")
adata = sc.read_h5ad(filename=adata_path)

In [None]:
# Choose how prompts are built (which genes) and which cells are scored.
# Both prompt settings start as None so the prediction cell below always finds them
# defined -- previously they were unset when n_genes is None, causing a NameError there.
n_rand_prompt_vars = None
torch_rng = None

if n_genes is None:
    logger.info("Using all genes.")
elif gene_selection_method == "highly_expressed":
    logger.info(f"Using {n_genes} highly expressed genes.")
    # Total counts per gene; np.asarray(...).flatten() handles dense arrays as well as
    # scipy sparse matrices, whose .sum(0) returns a 2-D (1, n_vars) matrix.
    X_g = np.asarray(adata.X.sum(0)).flatten()
    highly_expressed_gene_indices = np.argsort(X_g)[-n_genes:]
    highly_expressed_gene_ids = adata.var_names[highly_expressed_gene_indices]
    adata = adata[:, highly_expressed_gene_ids]
elif gene_selection_method == "random":
    logger.info(f"Using {n_genes} random genes (seed = {rng_seed}).")
    # Gene sampling happens inside ctx.predict_metadata_chunked, driven by these two values.
    torch_rng = torch.Generator().manual_seed(rng_seed)
    n_rand_prompt_vars = n_genes
else:
    raise ValueError(f"Unknown gene selection method: {gene_selection_method}")

# NOTE(review): this cell reassigns `adata`; re-running it without reloading the data
# subsets the already-subsetted object.
if n_cells is None:
    logger.info("Using all cells.")
else:
    n_cells = min(n_cells, len(adata))
    logger.info(f"Using {n_cells} random cells (seed = {rng_seed}).")
    rng = np.random.RandomState(rng_seed)
    adata = adata[rng.choice(len(adata), n_cells, replace=False)]

In [None]:
logger.info(f"Predicting metadata for {len(adata)} cells ...")
# `preds` maps each metadata field (e.g. "cell_type") to that field's per-cell predictions.
preds = ctx.predict_metadata_chunked(
    adata=adata,
    rng=torch_rng,
    chunk_size=chunk_size,
    n_rand_prompt_vars=n_rand_prompt_vars,
)

In [None]:
def propagate_probs_over_ontology(
        class_probs_nk: np.ndarray,
        class_names_k: t.List[str],
        ontology_resource_dict: t.Dict[str, t.Any],
) -> t.Tuple[t.List[str], np.ndarray]:
    """Sum per-class probabilities over each ontology term's descendants.

    For every term q in ``ontology_resource_dict``, the propagated probability is the sum of
    the input probabilities of the terms listed in ``ontology_resource_dict[q]["all_descendants"]``
    that appear in ``class_names_k``. Descendants the model does not predict contribute nothing.

    NOTE(review): whether "all_descendants" includes the term itself depends on the resource
    file; if it does not, a term's own probability is not part of its propagated value.

    Args:
        class_probs_nk: 2-D array (n_cells, n_classes) of per-class probabilities.
        class_names_k: ontology term ids labelling the columns of ``class_probs_nk``.
        ontology_resource_dict: maps term id -> dict with an "all_descendants" iterable.

    Returns:
        ``(all_class_names_q, propagated_class_probs_nq)``: the sorted term ids of the ontology
        and a float array of shape (n_cells, n_terms) with the propagated probabilities.

    Raises:
        ValueError: if ``class_probs_nk`` is not 2-D or its column count differs from
            ``len(class_names_k)``.
    """
    class_probs_nk = np.asarray(class_probs_nk)
    # Explicit check (not `assert`) so validation survives `python -O`.
    if class_probs_nk.ndim != 2 or len(class_names_k) != class_probs_nk.shape[1]:
        raise ValueError(
            f"class_probs_nk must be 2-D with one column per class name; got shape "
            f"{class_probs_nk.shape} and {len(class_names_k)} class names")

    all_class_names_q = sorted(ontology_resource_dict.keys())
    propagated_class_probs_nq = np.zeros((class_probs_nk.shape[0], len(all_class_names_q)))

    given_class_name_to_idx = {class_name: idx for idx, class_name in enumerate(class_names_k)}
    for q, class_name in enumerate(all_class_names_q):
        for descendant_name in ontology_resource_dict[class_name]["all_descendants"]:
            k = given_class_name_to_idx.get(descendant_name)
            if k is not None:
                propagated_class_probs_nq[:, q] += class_probs_nk[:, k]
    return all_class_names_q, propagated_class_probs_nq

def convert_meta_adata_to_query_obj_for_scoring(
        meta_adata: "sc.AnnData",
        metadata_key: str):
    """Turn the propagated predictions of one metadata field into scoring inputs.

    Args:
        meta_adata: object holding ``obsm[f"{metadata_key}_propagated_class_probs"]``
            (one row per cell), ``uns[f"{metadata_key}_propagated_ontology_term_ids"]``
            (one id per column) and the ground truth in
            ``obs[f"{metadata_key}_ontology_term_id"]``.
        metadata_key: metadata field name, e.g. ``"cell_type"``.

    Returns:
        ``(query_objs, ground_truth_ontology_term_ids)``. Each query object is
        ``{"query_cell_id": <obs index value>, "matches": [{"ontology_term_id", "score"}, ...]}``
        with one match per propagated term. The ground-truth list is in obs order.

    Raises:
        KeyError: if the propagated probabilities or term ids are missing.
    """
    probs_key = metadata_key + "_propagated_class_probs"
    term_ids_key = metadata_key + "_propagated_ontology_term_ids"
    # Explicit checks (not `assert`) so validation survives `python -O`.
    if probs_key not in meta_adata.obsm:
        raise KeyError(f"{probs_key!r} not found in meta_adata.obsm")
    if term_ids_key not in meta_adata.uns:
        raise KeyError(f"{term_ids_key!r} not found in meta_adata.uns")

    # Loop-invariant lookups; reading whole columns avoids building a pandas row per cell.
    term_ids = list(meta_adata.uns[term_ids_key])
    probs_nq = meta_adata.obsm[probs_key]
    cell_ids = meta_adata.obs.index.values
    ground_truth_ontology_term_ids = meta_adata.obs[metadata_key + "_ontology_term_id"].tolist()

    query_objs = []
    for i_cell in range(len(meta_adata)):
        query_objs.append({
            "query_cell_id": cell_ids[i_cell],
            "matches": [
                {"ontology_term_id": ontology_term_id, "score": score}
                for ontology_term_id, score in zip(term_ids, probs_nq[i_cell])
            ],
        })
    return query_objs, ground_truth_ontology_term_ids

logger.info("Inserting predictions into an AnnData object ...")
# put the predictions back into an AnnData object (obs only, no expression matrix)
meta_adata = sc.AnnData(obs=adata.obs.copy())

for meta_key, meta_preds in preds.items():
    meta_adata.obsm[meta_key + "_class_logits"] = meta_preds
    # NOTE(review): np.exp yields probabilities only if meta_preds are log-probabilities
    # (e.g. log-softmax outputs) -- confirm against predict_metadata_chunked.
    meta_adata.obsm[meta_key + "_class_probs"] = np.exp(meta_preds)
    meta_adata.uns[meta_key + "_ontology_term_ids"] = ctx.metadata_ontology_infos[meta_key]["names"]
    meta_adata.uns[meta_key + "_labels"] = ctx.metadata_ontology_infos[meta_key]["labels"]

# propagate predictions to make them ontologically consistent
for meta_key, ontology_resource_dict in ontology_benchmarking_resource_dicts.items():
    class_probs_nk = meta_adata.obsm[meta_key + "_class_probs"]
    all_class_names_q, propagated_class_probs_nq = propagate_probs_over_ontology(
        class_probs_nk=class_probs_nk,
        class_names_k=meta_adata.uns[meta_key + "_ontology_term_ids"],
        ontology_resource_dict=ontology_resource_dict,
    )
    meta_adata.obsm[meta_key + "_propagated_class_probs"] = propagated_class_probs_nq
    meta_adata.uns[meta_key + "_propagated_ontology_term_ids"] = all_class_names_q
    term_id_to_label = ontology_propagation_resource_dicts[meta_key]['ontology_term_id_to_label']
    # Fall back to the term id for terms without a label: a None entry in this list
    # cannot be serialized when meta_adata is written to .h5ad.
    meta_adata.uns[meta_key + "_propagated_labels"] = [
        term_id_to_label.get(term_id, term_id) for term_id in all_class_names_q
    ]

In [None]:
meta_keys = ontology_benchmarking_resource_dicts.keys()

results_dfs = []
for meta_key in meta_keys:
    logger.info(f"Calculating performance metrics for {meta_key} ...")
    query_objs, ground_truth_ontology_term_ids = convert_meta_adata_to_query_obj_for_scoring(
        meta_adata=meta_adata,
        metadata_key=meta_key)
    results_df = calculate_metrics_for_prediction_output(
        model_predictions=query_objs,
        ground_truth_ontology_term_ids=ground_truth_ontology_term_ids,
        ontology_resource=ontology_benchmarking_resource_dicts[meta_key],
        num_hops=n_hops_dict[meta_key])
    if "query_cell_id" in results_df.columns:
        # Index by cell id so the per-field tables align with each other and with obs.
        # This also prevents one unprefixed "query_cell_id" column per field, which
        # would give obs repeated column names that cannot be written to .h5ad.
        results_df = results_df.set_index("query_cell_id")
    else:
        # NOTE(review): assumes rows come back in query order, i.e. meta_adata.obs order.
        results_df = results_df.set_axis(meta_adata.obs.index, axis=0)
    results_df.columns = [f"{meta_key}_{col}" for col in results_df.columns]
    results_dfs.append(results_df)

# merge results dataframes, aligned on cell id and ordered like obs
final_results_df = pd.concat(results_dfs, axis=1).reindex(meta_adata.obs.index)
# Drop metric columns left by a previous run of this cell so re-running it is idempotent.
base_obs = meta_adata.obs.drop(columns=final_results_df.columns, errors="ignore")
meta_adata.obs = pd.concat([base_obs, final_results_df], axis=1)
meta_adata.obs.index.name = "query_cell_id"

In [None]:
# Output file is named after the validation extract it was computed from.
meta_adata_output_file_path = os.path.join(
    output_path,
    f"extract_{val_adata_index}_metadata_prediction_scores.h5ad",
)
logger.info(f"Saving the results to {output_path} ...")
meta_adata.write_h5ad(filename=meta_adata_output_file_path)

In [None]:
# Final log line of the notebook.
logger.info("Done!")