In [1]:
import os
import torch
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import typing as t
import pickle
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)  # Set the logging level

# Create a handler
handler = logging.StreamHandler()

# Create and set a formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

# Add the handler to the logger
logger.addHandler(handler)

# To suppress the stupid AnnData warning ...
warnings.filterwarnings("ignore", category=UserWarning, message="Transforming to str index.")

from cellarium.ml.utilities.inference.cellarium_gpt_inference import \
    CellariumGPTInferenceContext
from cellarium.ml.utilities.inference.metadata_benchmarking.calculate_metrics import \
    calculate_metrics_for_prediction_output

In [56]:
# arguments
cuda_device_index = 0
val_adata_index = 1

checkpoint_path = "/home/mehrtash/data/100M_long_run/run_001/lightning_logs/version_3/checkpoints/epoch=5-step=504000.ckpt"
ref_adata_path = "/home/mehrtash/data/data/extract_0.h5ad"
gene_info_path = "/home/mehrtash/data/gene_info/gene_info.tsv"
adata_path = f"/home/mehrtash/data/data/cellariumgpt_validation/extract_{val_adata_index}.h5ad"
ontology_resource_path = "/home/mehrtash/data/data/cellariumgpt_artifacts/ontology"
rand_prompt_vars_sublist_path = "/home/mehrtash/data/data/cellariumgpt_artifacts/autosomal_gene_ids.txt"
fixed_prompt_vars_sublist_path = "/home/mehrtash/data/data/cellariumgpt_artifacts/empty_gene_ids.txt"
output_path = "/home/mehrtash/data/data/cellariumgpt_artifacts/metadata_predictions/100M_long_run_last"

rng_seed = 42
n_cells = 1000
n_genes = 4_091
gene_selection_method = "highly_expressed"  # "random" or "highly_expressed"
chunk_size = 16

os.makedirs(output_path, exist_ok=True)

In [57]:
# load ontology resources
logger.info("Loading ontology resources ...")

ontology_benchmarking_resource_path_dict = {
    'cell_type': os.path.join(ontology_resource_path, 'cl_benchmarking_resource.pkl'),
    'development_stage': os.path.join(ontology_resource_path, 'hsapdv_benchmarking_resource.pkl'),
    'disease': os.path.join(ontology_resource_path, 'mondo_benchmarking_resource.pkl'),
    'tissue': os.path.join(ontology_resource_path, 'uberon_benchmarking_resource.pkl'),
    'sex': os.path.join(ontology_resource_path, 'sex_benchmarking_resource.pkl'),
}

ontology_propagation_resource_path_dict = {
    'cell_type': os.path.join(ontology_resource_path, 'cl_propagation_resource.pkl'),
    'development_stage': os.path.join(ontology_resource_path, 'hsapdv_propagation_resource.pkl'),
    'disease': os.path.join(ontology_resource_path, 'mondo_propagation_resource.pkl'),
    'tissue': os.path.join(ontology_resource_path, 'uberon_propagation_resource.pkl'),
    'sex': os.path.join(ontology_resource_path, 'sex_propagation_resource.pkl'),
}

n_hops_dict = {
    'cell_type': 3,
    'development_stage': 3,
    'disease': 3,
    'tissue': 3,
    'sex': 0,
}

ontology_benchmarking_resource_dicts = {}
for meta_key, path in ontology_benchmarking_resource_path_dict.items():
    with open(path, "rb") as f:
        ontology_benchmarking_resource_dicts[meta_key] = pickle.load(f)

ontology_propagation_resource_dicts = {}
for meta_key, path in ontology_propagation_resource_path_dict.items():
    with open(path, "rb") as f:
        ontology_propagation_resource_dicts[meta_key] = pickle.load(f)

2025-03-02 22:37:14,887 - __main__ - INFO - Loading ontology resources ...


In [58]:
logger.info(f"Loading the model checkpoint from {checkpoint_path} ...")

device = torch.device(f"cuda:{cuda_device_index}")

ctx = CellariumGPTInferenceContext(
    cellarium_gpt_ckpt_path=checkpoint_path,
    ref_adata_path=ref_adata_path,
    gene_info_tsv_path=gene_info_path,
    device=device,
    attention_backend="mem_efficient",
    verbose=False
)

2025-03-02 22:37:16,275 - __main__ - INFO - Loading the model checkpoint from /home/mehrtash/data/100M_long_run/run_001/lightning_logs/version_3/checkpoints/epoch=5-step=504000.ckpt ...


In [59]:
logger.info(f"Loading the validation AnnData object from {adata_path} ...")
adata = sc.read_h5ad(adata_path)

2025-03-02 22:37:20,519 - __main__ - INFO - Loading the validation AnnData object from /home/mehrtash/data/data/cellariumgpt_validation/extract_1.h5ad ...


In [None]:
if n_genes is not None:

    with open(fixed_prompt_vars_sublist_path, "r") as f:
        fixed_prompt_var_names_sublist = f.read().splitlines()
    logger.info(f"Starting with {len(fixed_prompt_var_names_sublist)} fixed genes.")

    if gene_selection_method == "highly_expressed":
        logger.info(f"In addition, using up to {n_genes} highly expressed genes.")
        X_g = np.asarray(adata.X.sum(0)).flatten()
        highly_expressed_gene_indices = np.argsort(X_g)[::-1]
        selected_gene_set = set(fixed_prompt_var_names_sublist)
        target_n_genes = n_genes + len(fixed_prompt_var_names_sublist)
        for idx in highly_expressed_gene_indices:
            if len(selected_gene_set) >= target_n_genes:
                break
            gene_id = adata.var_names[highly_expressed_gene_indices[idx]]
            selected_gene_set.add(gene_id)
        logger.info(f"Selected {len(selected_gene_set)} genes.")
        fixed_prompt_var_names_sublist = list(selected_gene_set)
        rand_prompt_var_names_sublist = []
        n_rand_prompt_vars = 0
        torch_rng = torch.Generator().manual_seed(rng_seed)
    
    elif gene_selection_method == "random":
        logger.info(f"In addition, using {n_genes} random genes (seed = {rng_seed}).")
        torch_rng = torch.Generator().manual_seed(rng_seed)
        n_rand_prompt_vars = n_genes
        with open(rand_prompt_vars_sublist_path, "r") as f:
            rand_prompt_var_names_sublist = f.read().splitlines()
    
    else:
        raise ValueError(f"Unknown gene selection method: {gene_selection_method}")
else:
    logger.info(f"Using all genes.")
    n_rand_prompt_vars = None
    rand_prompt_var_names_sublist = None
    fixed_prompt_var_names_sublist = None
    
if n_cells is None:
    logger.info(f"Using all cells.")
else:
    n_cells = min(n_cells, len(adata))
    logger.info(f"Using {n_cells} random cells (seed = {rng_seed}).")
    rng = np.random.RandomState(rng_seed)
    adata = adata[rng.choice(len(adata), n_cells, replace=False)]

2025-03-02 22:37:23,433 - __main__ - INFO - Starting with 0 fixed genes.
2025-03-02 22:37:23,433 - __main__ - INFO - In addition, using up to 4091 highly expressed genes.
2025-03-02 22:37:23,475 - __main__ - INFO - Selected 4091 genes.
2025-03-02 22:37:23,476 - __main__ - INFO - Using 1000 random cells (seed = 42).


In [61]:
logger.info(f"Predicting metadata for {len(adata)} cells ...")
preds = ctx.predict_metadata_chunked(
    adata=adata,
    chunk_size=chunk_size,
    n_rand_prompt_vars=n_rand_prompt_vars,
    rand_prompt_var_names_sublist=rand_prompt_var_names_sublist,
    fixed_prompt_var_names_sublist=fixed_prompt_var_names_sublist,
    rng=torch_rng)

2025-03-02 22:37:25,800 - __main__ - INFO - Predicting metadata for 1000 cells ...


  0%|          | 0/63 [00:00<?, ?it/s]

In [62]:
for key in preds.keys():
    best_call_n = np.argmax(preds[key], -1)
    best_label = [ctx.metadata_ontology_infos[key]["labels"][idx] for idx in best_call_n]
    adata.obs[key + "_pred"] = best_label

cols = []
for key in preds.keys():
    cols.append(key)
    cols.append(key + "_pred")
new_obs = adata.obs[cols]

  adata.obs[key + "_pred"] = best_label


In [63]:
new_obs.head(50)

Unnamed: 0,cell_type,cell_type_pred,tissue,tissue_pred,development_stage,development_stage_pred,disease,disease_pred,sex,sex_pred
31589848,oligodendrocyte,oligodendrocyte,dentate nucleus,temporal lobe,52-year-old human stage,61-year-old stage,normal,normal,male,female
31594264,oligodendrocyte,oligodendrocyte,dentate nucleus,cerebellum,52-year-old human stage,61-year-old stage,normal,normal,male,female
31601455,unknown,cerebellar granule cell,dentate nucleus,cerebellum,52-year-old human stage,61-year-old stage,normal,normal,male,female
31589769,oligodendrocyte,oligodendrocyte,dentate nucleus,temporal lobe,52-year-old human stage,61-year-old stage,normal,normal,male,female
31593942,unknown,neuron,dentate nucleus,cerebellum,52-year-old human stage,23-year-old stage,normal,normal,male,female
31598517,cerebellar granule cell,epithelial cell,dentate nucleus,cerebellum,52-year-old human stage,61-year-old stage,normal,normal,male,male
31599161,cerebellar granule cell,stratified epithelial cell,dentate nucleus,cerebellum,52-year-old human stage,61-year-old stage,normal,normal,male,female
31601198,cerebellar granule cell,neuron,dentate nucleus,cerebellum,52-year-old human stage,61-year-old stage,normal,normal,male,female
31599402,cerebellar granule cell,stratified epithelial cell,dentate nucleus,cerebellum,52-year-old human stage,61-year-old stage,normal,normal,male,male
31594186,microglial cell,macrophage,dentate nucleus,temporal lobe,52-year-old human stage,61-year-old stage,normal,normal,male,male


In [71]:
_adata = adata[adata.obs['cell_type'] == 'cerebellar granule cell'].copy()

In [72]:
sc.pp.normalize_total(_adata, target_sum=1e4)
# sc.pp.log1p(_adata)

In [73]:
expr_g = np.asarray(_adata.X.mean(0)).flatten()
order = np.argsort(expr_g)

In [74]:
for o in order[::-1][:50]:
    gene_id = _adata.var_names[o]
    expr = expr_g[o]
    gene_name = ctx.gene_id_to_gene_symbol_map[gene_id]
    print(gene_name, expr)

MALAT1 1176.4282
KCNIP4 51.889442
RBFOX1 39.08155
NRXN1 37.05956
ZNF385D 31.274553
NRXN3 29.412807
RALYL 28.302118
FGF14 27.987516
CDH18 26.29924
KCND2 25.589275
SYT1 25.392553
CADM2 24.365475
GRIK2 22.64451
FSTL5 20.709612
ADGRB3 20.459864
CADPS2 19.70784
DPP6 19.099762
CHN2 18.342234
TIAM1 17.454777
JMJD1C 16.945967
RIMS1 16.903675
SNHG14 16.197824
STXBP5L 15.808228
SNAP25 15.755832
ANKS1B 15.605362
GRID2 15.461536
KAZN 15.290413
MSRA 15.083475
CALM1 13.966373
DLGAP1 13.554987
RORA 13.512486
NTM 13.479536
NLGN1 13.344391
CALN1 13.213117
CACNA1A 12.965487
LINC00486 12.807418
NFIA 12.51874
SYNE1 12.420801
ZFPM2 12.312247
ANK3 12.283244
RBFOX3 11.96194
CA10 11.820011
UNC13C 11.696393
ZBTB20 11.657179
NKAIN2 11.602788
MAP1B 11.503829
ERC1 11.317675
OPCML 11.300125
ABLIM1 11.08845
CAMK4 11.029335
