### Embedding analysis

In [1]:
import os
import torch
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt

# for flex attention
import torch._dynamo
torch._dynamo.config.suppress_errors = True

# DEVICE = torch.device('cuda:1')
DEVICE = torch.device('cuda:0')
sc.set_figure_params(figsize=(4, 4))

from cellarium.ml.utilities.inference.cellarium_gpt_inference import \
    CellariumGPTInferenceContext, \
    GeneNetworkAnalysisBase

2025-02-24 15:33:50.783326: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
ROOT_PATH = "/work/hdd/bbjr/mallina1/data/mb-ml-dev-vm"
TRAIN_ROOT_PATH = "/work/hdd/bbjr/mallina1/data/human_cellariumgpt_v2/extract_files"
CHECKPOINT_PATH = "/work/hdd/bbjr/mallina1/cellarium/models/latest/epoch=5-step=452000.ckpt"

REF_ADATA_PATH = os.path.join(ROOT_PATH, "data", "extract_0.h5ad")
GENE_INFO_PATH = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")

ctx = CellariumGPTInferenceContext(
    cellarium_gpt_ckpt_path=CHECKPOINT_PATH,
    ref_adata_path=REF_ADATA_PATH,
    gene_info_tsv_path=GENE_INFO_PATH,
    device=DEVICE,
    attention_backend="mem_efficient"
)

In [39]:
tr0 = sc.read_h5ad(os.path.join(TRAIN_ROOT_PATH, 'extract_0.h5ad'))

In [46]:
tr0.obs['cell_type'].value_counts()
tr0.obs.iloc[0]

original_cell_id                                                          CCTATTAGTCAAAGAT-46
cell_type                                   naive thymus-derived CD4-positive, alpha-beta ...
assay_ontology_term_id                                                            EFO:0009899
cell_type_ontology_term_id                                                         CL:0000895
development_stage_ontology_term_id                                             HsapDv:0000206
disease_ontology_term_id                                                         PATO:0000461
donor_id                                                                              106_106
is_primary_data                                                                          True
organism_ontology_term_id                                                      NCBITaxon:9606
self_reported_ethnicity_ontology_term_id                                       HANCESTRO:0005
sex_ontology_term_id                                        

In [None]:
gene_names = tr0.var_names.tolist()
print(gene_names[:10])
print(len(gene_names))

['ENSG00000187642', 'ENSG00000078808', 'ENSG00000272106', 'ENSG00000162585', 'ENSG00000272088', 'ENSG00000204624', 'ENSG00000162490', 'ENSG00000177000', 'ENSG00000011021', 'ENSG00000120949']
36601


In [80]:
# manually select prompt metadata
custom_prompt_metadata_dict = {
    "cell_type": False,
    "tissue": False,
    "disease": True,
    "sex": False,
    "development_stage": False
}

# idx = 0

# # prompt metadata comes from sample and process function?
# samp = tr0.obs.iloc[idx]
# assay = samp.assay
# suspension_type = samp.suspension_type
# total_mrna_umis = samp.total_mrna_umis
# prompt_metadata_dict = {
#     "cell_type": samp.cell_type,
#     "tissue": samp.tissue,
#     "disease": samp.disease,
#     "sex": samp.sex
# }

# metadata_prompt_masks_dict, metadata_dict = ctx.process_user_metadata(
#     assay, suspension_type, prompt_metadata_dict, total_mrna_umis)

# print(metadata_prompt_masks_dict)

with torch.inference_mode():
    tokens_dict, context_indices = ctx.generate_tokens_from_adata(tr0, obs_index=[0], query_var_names=gene_names[:5000],
                                                                  metadata_prompt_masks_dict=custom_prompt_metadata_dict)
    # tokens_dict, context_indices = ctx.generate_tokens_from_adata(tr0, obs_index=[0], query_var_names=[],
    #                                                               metadata_prompt_masks_dict=custom_prompt_metadata_dict)
    print(tokens_dict['gene_tokens_nc']['assay'].shape)

torch.Size([1, 36601])


In [82]:
try:
    print(f"prompt cell type idx: {context_indices[f'prompt_cell_type']}")
except:
    print(f"query cell type idx: {context_indices[f'query_cell_type']}")

try:
    print(f"prompt disease idx: {context_indices[f'prompt_disease']}")
except:
    print(f"query disease idx: {context_indices[f'query_disease']}")

query cell type idx: 36601
prompt disease idx: 36604


In [83]:
print(tokens_dict['metadata_tokens_n'])
print(tokens_dict['metadata_tokens_n']['cell_type'])

{'cell_type': tensor([890], dtype=torch.int32), 'tissue': tensor([822], dtype=torch.int32), 'development_stage': tensor([191], dtype=torch.int32), 'disease': tensor([0], dtype=torch.int32), 'sex': tensor([2], dtype=torch.int32)}
tensor([890], dtype=torch.int32)


In [84]:
with torch.inference_mode():
    hidden_states_ncd, gene_hidden_states_nqd = ctx.get_embeddings_from_tokens(tokens_dict, context_indices)
    gene_hidden_states_nqd = gene_hidden_states_nqd.cpu()

In [85]:
print(gene_hidden_states_nqd.shape)
print(hidden_states_ncd.shape)

torch.Size([1, 0, 512])
torch.Size([1, 36606, 512])


OUTDATED BELOW

In [2]:
# ROOT_PATH = "/mnt/cellariumgpt-xfer/mb-ml-dev-vm"
# CHECKPOINT_PATH = "/mnt/cellariumgpt-xfer/100M_long_run/run_001/lightning_logs/version_3/checkpoints/epoch=5-step=504000.ckpt"
ROOT_PATH = "/work/hdd/bbjr/mallina1/data/mb-ml-dev-vm"
CHECKPOINT_PATH = "/work/hdd/bbjr/mallina1/cellarium/models/latest/epoch=5-step=452000.ckpt"

REF_ADATA_PATH = os.path.join(ROOT_PATH, "data", "extract_0.h5ad")
GENE_INFO_PATH = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")

ctx = CellariumGPTInferenceContext(
    cellarium_gpt_ckpt_path=CHECKPOINT_PATH,
    ref_adata_path=REF_ADATA_PATH,
    gene_info_tsv_path=GENE_INFO_PATH,
    device=DEVICE,
    attention_backend="mem_efficient"
)

Testing embedding dims from cell type prompt

In [3]:
query_gene_ids_path = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation", "query_gene_ids.csv")
query_gene_ids = pd.read_csv(query_gene_ids_path, header=None).values.flatten().tolist()[:5000]  # NOTE

assay = "10x 3' v3"
suspension_type = "cell"
prompt_metadata_dict = {
    "cell_type": "neuron",
    "tissue": "blood"
}
total_mrna_umis = 5_000

with torch.inference_mode():
    tokens_dict, context_indices, _, _ = ctx.generate_gene_tokens_by_metadata(
        assay=assay,
        suspension_type=suspension_type,
        prompt_metadata_dict=prompt_metadata_dict,
        total_mrna_umis=total_mrna_umis,
        query_gene_ids=query_gene_ids,
        perturb_gene_ids=None
    )

In [6]:
tokens_dict

{'gene_tokens_nc': {'assay': tensor([[1, 1, 1,  ..., 1, 1, 1]]),
  'suspension_type': tensor([[1, 1, 1,  ..., 1, 1, 1]]),
  'gene_id': tensor([[    0,  3269,  1251,  ..., 27863,  3842, 22118]]),
  'gene_value': tensor([[[0.0000, 1.0000, 8.5174],
           [0.0000, 1.0000, 8.5174],
           [0.0000, 1.0000, 8.5174],
           ...,
           [0.0000, 1.0000, 8.5174],
           [0.0000, 1.0000, 8.5174],
           [0.0000, 1.0000, 8.5174]]])},
 'metadata_tokens_n': {'cell_type': tensor([0], dtype=torch.int32),
  'tissue': tensor([65], dtype=torch.int32),
  'development_stage': tensor([191], dtype=torch.int32),
  'disease': tensor([350], dtype=torch.int32),
  'sex': tensor([2], dtype=torch.int32)},
 'prompt_mask_nc': tensor([[False, False, False,  ..., False, False, False]])}

In [11]:
print(tokens_dict['gene_tokens_nc']['gene_id'].shape)
print(tokens_dict['gene_tokens_nc']['gene_value'].shape)
print(tokens_dict['prompt_mask_nc'].shape)

torch.Size([1, 5001])
torch.Size([1, 5001, 3])
torch.Size([1, 5006])


In [12]:
context_indices

{'prompt_genes': [0],
 'query_genes': [1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108,
  109,
  110,
  111,
  112,
  113,
  114,
  115,
  116,
  117,
  118,
  119,
  120,
  121,
  122,
  123,
  124,
  125,
  126,
  127,
  128,
  129,
  130,
  131,
  132,
  133,
  134,
  135,
  136,
  137,
  138,
  139,
  140,
  141,
  142,
  143,
  144,
  145,
  146,
  147,
  148,
  149,
  150,
  151,
  152,
  153,


In [13]:
len(context_indices['query_genes'])

5000

In [14]:
with torch.inference_mode():
    hidden_states_ncd, gene_hidden_states_nqd = ctx.get_embeddings_from_tokens(tokens_dict, context_indices)

In [15]:
print(hidden_states_ncd.shape, gene_hidden_states_nqd.shape)

torch.Size([1, 5006, 512]) torch.Size([1, 5000, 512])
