In [1]:
import os
import re
import csv
import torch
import random
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from datetime import datetime
from scipy.stats import pearsonr
from anndata import AnnData

# for flex attention
import torch._dynamo
import torch.multiprocessing as mp 
torch._dynamo.config.suppress_errors = True

sc.set_figure_params(figsize=(4, 4))

from cellarium.ml.utilities.inference.cellarium_gpt_inference import \
    CellariumGPTInferenceContext

2025-04-24 00:40:01.620254: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
data_root = '/work/hdd/bbjr/mallina1/data/mb-ml-dev-vm/data/GSE153807/tsvs'

fnames = ['GSM4654467_Nuc-RM101-2.raw.tsv', 
          'GSM4654469_Nuc-RM102-1.raw.tsv', 'GSM4654468_Nuc-RM102-2.raw.tsv', 
          'GSM4654470_Nuc-RM77-1.raw.tsv', 'GSM4654471_Nuc-RM77-2.raw.tsv',
          'GSM4654472_Nuc-RM95-1.raw.tsv', 'GSM4654473_Nuc-RM95-2.raw.tsv']

sex_per_fname = ['Female', 'Female', 'Female', 'Male', 'Male', 'Female', 'Female']
sex_ontology_type_id_per_fname = ['PATO:0000383', 'PATO:0000383', 'PATO:0000383', 'PATO:0000384', 'PATO:0000384', 'PATO:0000383', 'PATO:0000383']

gene_info_path = '/work/hdd/bbjr/mallina1/data/mb-ml-dev-vm/gene_info/gene_info.tsv'
ontology_infos_path = '/work/hdd/bbjr/mallina1/data/mb-ml-dev-vm/ontology_infos.pt'

idx_to_run = 0

In [3]:
ontology_infos = torch.load(ontology_infos_path)

gene_symb_to_gene_id = {}
gene_synonym_to_gene_id = {}
with open(gene_info_path, 'r') as fp:
    reader = csv.reader(fp, delimiter='\t')
    next(reader)
    for row in reader:
        gene_symb_to_gene_id[row[2]] = row[0]
        gene_synonym_to_gene_id[row[4]] = row[0]

  ontology_infos = torch.load(ontology_infos_path)


In [4]:
gene_symb_to_gene_id

{'A2MP1': 'ENSG00000256069',
 'ALOX12P2': 'ENSG00000262943',
 'KYAT1': 'ENSG00000171097',
 'CRYBB2P1': 'ENSG00000100058',
 'CTSLP2': 'ENSG00000266217',
 'CYP2B7P': 'ENSG00000256612',
 'FCGR1BP': 'ENSG00000198019',
 'GBAP1': 'ENSG00000160766',
 'GTF2IP1': 'ENSG00000277053',
 'NPY6R': 'ENSG00000226306',
 'PDE4C': 'ENSG00000105650',
 'PRKY': 'ENSG00000099725',
 'SAA3P': 'ENSG00000166787',
 'MBL1P': 'ENSG00000242600',
 'PDE8B': 'ENSG00000113231',
 'DLGAP2': 'ENSG00000198010',
 'TAAR3P': 'ENSG00000179073',
 'CLCA3P': 'ENSG00000153923',
 'MATR3': 'ENSG00000015479',
 'PPBPP2': 'ENSG00000248848',
 'CYP2G1P': 'ENSG00000130612',
 'GABARAPL3': 'ENSG00000238244',
 'AP1B1P1': 'ENSG00000234479',
 'OR7E47P': 'ENSG00000257542',
 'INGX': 'ENSG00000243468',
 'DUSP13B': 'ENSG00000293542',
 'CYP4F29P': 'ENSG00000228314',
 'PCDHB18P': 'ENSG00000146001',
 'VNN3P': 'ENSG00000093134',
 'LINC00216': 'ENSG00000279636',
 'PCDHGB8P': 'ENSG00000248449',
 'LINC00869': 'ENSG00000277147',
 'STAG3L4': 'ENSG00000106610

In [5]:
df = pd.read_csv(os.path.join(data_root, fnames[idx_to_run]), sep='\t', index_col=0)

original_symbols = df.index.to_series(name='gene_symbol')
mapped_ids = original_symbols.map(lambda s: gene_symb_to_gene_id.get(s))
df.index = mapped_ids

data = {
    'suspension_type': ['nucleus'] * len(df.columns),
    'total_mrna_umis': df.sum(axis=0),
    'assay_ontology_term_id': ['EFO:0009899'] * len(df.columns),
    'assay': ["10x 3' v2"] * len(df.columns),
    'sex': [sex_per_fname[idx_to_run]] * len(df.columns),
    'sex_ontology_term_id': [sex_ontology_type_id_per_fname[idx_to_run]] * len(df.columns)
}

obs = pd.DataFrame(index=df.columns, data=data)
var = pd.DataFrame(index=df.index)        # one row per gene ID
var['gene_symbol'] = original_symbols     # store the original symbol

adata = AnnData(X=df.values.T, obs=obs, var=var)

  utils.warn_names_duplicates("var")


In [6]:
adata

AnnData object with n_obs × n_vars = 5491 × 21283
    obs: 'suspension_type', 'total_mrna_umis', 'assay_ontology_term_id', 'assay', 'sex', 'sex_ontology_term_id'
    var: 'gene_symbol'

In [7]:
adata.var_names_make_unique()

In [8]:
metacell_X = np.array(adata.X.mean(axis=0)).ravel()

metacell_obs = pd.DataFrame(index=["metacell"],
                            data= {
                                'suspension_type': ['nucleus'],
                                'total_mrna_umis': metacell_X.sum(axis=0),
                                'assay_ontology_term_id': ['EFO:0009899'],
                                'assay': ["10x 3' v2"],
                                'sex': [sex_per_fname[idx_to_run]],
                                'sex_ontology_term_id': [sex_ontology_type_id_per_fname[idx_to_run]]
                            })

metacell = AnnData(
    X = metacell_X.reshape(1, -1),
    obs = metacell_obs,
    var = adata.var.copy()
)

In [9]:
metacell.obs

Unnamed: 0,suspension_type,total_mrna_umis,assay_ontology_term_id,assay,sex,sex_ontology_term_id
metacell,nucleus,3464.07139,EFO:0009899,10x 3' v2,Female,PATO:0000383


In [10]:
ROOT_PATH = "/work/hdd/bbjr/mallina1/data/mb-ml-dev-vm"

REF_ADATA_FP = '/work/hdd/bbjr/mallina1/data/mb-ml-dev-vm/data/extract_0.h5ad'
OUT_ADATA_DIR = '/work/hdd/bbjr/mallina1/data/human_cellariumgpt_v2/suspension_type_conversion'

GENE_INFO_PATH = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")

# CHECKPOINT_PATH = "/work/hdd/bbjr/mallina1/cellarium/models/compute_optimal_checkpoints/epoch=1-step=28244.ckpt"
CHECKPOINT_PATH = "/work/hdd/bbjr/mallina1/cellarium/models/compute_optimal_checkpoints/epoch=6-step=63560.ckpt"
# CHECKPOINT_PATH = "/work/hdd/bbjr/mallina1/cellarium/models/compute_optimal_checkpoints/epoch=10-step=78917.ckpt"

DEVICE = 'cuda'

ref_adata = sc.read_h5ad(REF_ADATA_FP)
ref_var_names = set(ref_adata.var_names)

In [25]:
ref_var_names

{'ENSG00000285817',
 'ENSG00000230500',
 'ENSG00000156050',
 'ENSG00000230100',
 'ENSG00000185479',
 'ENSG00000236540',
 'ENSG00000267467',
 'ENSG00000232243',
 'ENSG00000176697',
 'ENSG00000107984',
 'ENSG00000227161',
 'ENSG00000255121',
 'ENSG00000287439',
 'ENSG00000162460',
 'ENSG00000197119',
 'ENSG00000273951',
 'ENSG00000100804',
 'ENSG00000204241',
 'ENSG00000134759',
 'ENSG00000177463',
 'ENSG00000253164',
 'ENSG00000254813',
 'ENSG00000146276',
 'ENSG00000172543',
 'ENSG00000229821',
 'ENSG00000211892',
 'ENSG00000131389',
 'ENSG00000127586',
 'ENSG00000277368',
 'ENSG00000261105',
 'ENSG00000056586',
 'ENSG00000184216',
 'ENSG00000145632',
 'ENSG00000226070',
 'ENSG00000204544',
 'ENSG00000187650',
 'ENSG00000119737',
 'ENSG00000145879',
 'ENSG00000156575',
 'ENSG00000101347',
 'ENSG00000198535',
 'ENSG00000198805',
 'ENSG00000285902',
 'ENSG00000008516',
 'ENSG00000242732',
 'ENSG00000247213',
 'ENSG00000135452',
 'ENSG00000249515',
 'ENSG00000125977',
 'ENSG00000126216',


In [11]:
genes_to_keep = ["AC004448.2","AC010894.3","AC011468.3","AC011586.2","AC016708.1","AC022217.3","AC024230.1",
                 "AC044781.1","AC072062.1","AC245014.3","ACTB","AIF1","AL136454.1","ALOX5AP","AMBRA1","APOC1",
                 "APOE","APOO","ARMC9","ATP5F1E","ATP5MC2","ATP6V0B","ATP6V0E1","B2M","BAIAP2L1","BDNF-AS",
                 "BTF3","BTG2","C1QA","C1QB","C1QC","CARMIL1","CCDC200","CCL2","CCL3","CCL3L1","CCL4","CCL4L2",
                 "CD14","CD37","CD63","CD68","CD74","CEBPB","CEBPD","CFL1","CHCHD3","COMMD6","CORO1A","COX4I1",
                 "CST3","CTSB","CYBA","DAPK1","DDIT4","DNAJB1","DUSP1","EEF1A1","EEF1B2","EEF1D","EEF2","EFCAB3",
                 "EIF1","FAU","FCER1G","FCGRT","FOLR2","FOS","FP700111.1","FTH1","FTL","GADD45B","GGACT",
                 "GPR183","GPX4","GRN","GSTP1","H3F3B","HCST","HERPUD1","HLA-A","HLA-B","HLA-C","HLA-DPA1",
                 "HLA-DPB1","HLA-DRA","HLA-DRB1","HLA-DRB5","HLA-E","HMOX1","HNRNPA1","HSP90AA1","HSPA1A",
                 "HSPA1B","HSPB1","IER2","IER3","ITM2B","JUN","JUNB","KIZ-AS1","LAMTOR4","LAPTM4A","LAPTM5",
                 "LINC01500","LINC01736","LINGO1","LTC4S","MAMDC2","MARCKS","MECOM","MT-ATP6","MT-CO1","MT-CO2",
                 "MT-CO3","MT-CYB","MT-ND2","MT-ND3","MT-ND4","MYL6","NACA","NACA2","NBEAL1","NFKBIA","NHSL2",
                 "NINJ1","NPC2","OLFML3","OOEP","OTULINL","PDK4","PFDN5","PFN1","PLD4","PLEKHA6","PLEKHA7",
                 "PNRC1","PSAP","PTMA","PYCARD","RAC1","RACK1","RGS1","RGS10","RHOB","RHOG","RNASE6","RPL10",
                 "RPL10A","RPL11","RPL12","RPL13","RPL13A","RPL14","RPL15","RPL18","RPL18A","RPL19","RPL21",
                 "RPL23","RPL23A","RPL24","RPL27","RPL27A","RPL28","RPL29","RPL3","RPL30","RPL31","RPL32","RPL34",
                 "RPL35","RPL35A","RPL36","RPL36AL","RPL37","RPL37A","RPL38","RPL39","RPL4","RPL41","RPL5","RPL6",
                 "RPL7","RPL7A","RPL8","RPLP0","RPLP1","RPLP2","RPS11","RPS12","RPS13","RPS14","RPS15","RPS15A",
                 "RPS16","RPS17","RPS18","RPS19","RPS2","RPS20","RPS23","RPS24","RPS25","RPS26","RPS27","RPS27A",
                 "RPS28","RPS29","RPS3","RPS3A","RPS4X","RPS5","RPS6","RPS7","RPS8","RPS9","RPSA","S100A11","SAT1",
                 "SERF2","SIK3","SLC25A6","SLC27A4","SLC47A1","SPP1","SRGN","TEX14","TMSB10","TMSB4X","TOMM7","TPT1",
                 "TREM2","TSPO","TUBA1B","TXNRD1","TYROBP","UBA52","UBC","VSIR","XPO5","YBX1","ZFP36","ZFP36L1",
                 "ZFP36L2","ZNF90"]

print(len(genes_to_keep))

gene_ids_to_keep = []
for x in genes_to_keep:
    if x in gene_symb_to_gene_id:
        if gene_symb_to_gene_id[x] in ref_var_names:
            gene_ids_to_keep.append(gene_symb_to_gene_id[x])
    elif x in gene_synonym_to_gene_id:
        if gene_synonym_to_gene_id[x] in ref_var_names:
            gene_ids_to_keep.append(gene_synonym_to_gene_id[x])

print(len(gene_ids_to_keep))

n_fixed_query_genes = 4096 - len(gene_ids_to_keep)

246
232


In [12]:
n_fixed_query_genes

3864

In [13]:
adata

AnnData object with n_obs × n_vars = 5491 × 21283
    obs: 'suspension_type', 'total_mrna_umis', 'assay_ontology_term_id', 'assay', 'sex', 'sex_ontology_term_id'
    var: 'gene_symbol'

In [14]:
'ENSG00000132475' in gene_ids_to_keep

True

In [27]:
_adata = adata[:, adata.var_names.isin(ref_var_names)].copy()
_metacell = metacell[:, metacell.var_names.isin(ref_var_names)].copy()

In [None]:
# _adata = adata.copy()
# _metacell = metacell.copy()
# _adata = adata[:, adata.var_names.isin(list(ref_var_names) + gene_ids_to_keep)].copy()
# _metacell = metacell[:, metacell.var_names.isin(list(ref_var_names) + gene_ids_to_keep)].copy()

# var_names = np.array(list(_adata.var_names))
# _adata = _adata[:, var_names]

In [28]:
sc.pp.highly_variable_genes(_adata, flavor='seurat_v3', n_top_genes=n_fixed_query_genes)
_adata.var['highly_variable']


gene_symbol
ENSG00000121410    False
ENSG00000268895    False
ENSG00000148584     True
ENSG00000175899     True
ENSG00000245105    False
                   ...  
ENSG00000070476    False
ENSG00000203995    False
ENSG00000162378    False
ENSG00000159840     True
ENSG00000074755    False
Name: highly_variable, Length: 16513, dtype: bool

In [29]:
temp_subset = _adata[:, _adata.var['highly_variable']].copy()
print(temp_subset)

final_gene_list = list(set(gene_ids_to_keep + list(temp_subset.var_names)))
print(final_gene_list)
print(len(final_gene_list))

# subset_metacell_adata = _metacell[:, _metacell.var_names.isin(final_gene_list)].copy()
# subset_adata = _adata[:, _adata.var_names.isin(final_gene_list)].copy()

# subset_metacell_adata = _metacell[:, np.array(final_gene_list)].copy()
# subset_adata = _adata[:, final_gene_list].copy()

AnnData object with n_obs × n_vars = 5491 × 3864
    obs: 'suspension_type', 'total_mrna_umis', 'assay_ontology_term_id', 'assay', 'sex', 'sex_ontology_term_id'
    var: 'gene_symbol', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'hvg'
['ENSG00000145990', 'ENSG00000005421', 'ENSG00000167778', 'ENSG00000172985', 'ENSG00000103528', 'ENSG00000160183', 'ENSG00000130300', 'ENSG00000197405', 'ENSG00000111640', 'ENSG00000224043', 'ENSG00000176697', 'ENSG00000151834', 'ENSG00000125650', 'ENSG00000175600', 'ENSG00000152208', 'ENSG00000254008', 'ENSG00000170624', 'ENSG00000230561', 'ENSG00000136535', 'ENSG00000160439', 'ENSG00000010438', 'ENSG00000153802', 'ENSG00000131697', 'ENSG00000112981', 'ENSG00000010818', 'ENSG00000176454', 'ENSG00000186314', 'ENSG00000137463', 'ENSG00000143889', 'ENSG00000111335', 'ENSG00000173848', 'ENSG00000121898', 'ENSG00000178607', 'ENSG00000102387', 'ENSG00000163131', 'ENSG00000120254', 'ENSG00000172986', 'ENSG000001858

In [30]:
subset_adata = _adata[:, _adata.var_names.isin(final_gene_list)].copy()
subset_metacell_adata = _metacell[:, _metacell.var_names.isin(final_gene_list)].copy()

In [31]:
ctx = CellariumGPTInferenceContext(
    cellarium_gpt_ckpt_path=CHECKPOINT_PATH,
    ref_adata_path=REF_ADATA_FP,
    gene_info_tsv_path=GENE_INFO_PATH,
    device=DEVICE,
    attention_backend="mem_efficient"
)

In [32]:
subset_adata.obs['cell_type_ontology_term_id'] = None
subset_adata.obs['tissue_ontology_term_id'] = None
subset_adata.obs['disease_ontology_term_id'] = None
# subset_adata.obs['sex_ontology_term_id'] = None
subset_adata.obs['development_stage_ontology_term_id'] = None

subset_metacell_adata.obs['cell_type_ontology_term_id'] = None
subset_metacell_adata.obs['tissue_ontology_term_id'] = None
subset_metacell_adata.obs['disease_ontology_term_id'] = None
# subset_metacell_adata.obs['sex_ontology_term_id'] = None
subset_metacell_adata.obs['development_stage_ontology_term_id'] = None

In [33]:
metadata_prompt_dict = {
    "cell_type": False,
    "tissue": False,
    "disease": False,
    "sex": True,
    "development_stage": False
}

In [34]:
## run metacell first
query_genes = list(subset_metacell_adata.var_names)

tokens_dict, context_indices = ctx.generate_tokens_from_adata(subset_metacell_adata, 
                                                                obs_index=[0], 
                                                                query_var_names=query_genes,
                                                                metadata_prompt_masks_dict=metadata_prompt_dict,
                                                                query_total_mrna_umis=4900,
                                                                query_suspension_type='cell')

gene_logits_nqk = ctx.get_gene_value_logits_from_tokens(tokens_dict,
                                                        context_indices,
                                                        max_counts=None)

gene_marginal_mean_nq, _ = ctx.calculate_gene_mean_std_from_logits(gene_logits_nqk,
                                                                    gene_logits_nqk.shape[-1],
                                                                    use_logsumexp=True)

dist = torch.distributions.categorical.Categorical(logits = gene_logits_nqk)
sampled_counts = dist.sample().cpu()

In [None]:
output_adata = subset_metacell_adata.copy()

newX = np.vstack([output_adata.X, gene_marginal_mean_nq.detach().cpu().numpy()])
new_obs = output_adata.obs.copy()

new_obs.suspension_type = 'cell'
new_obs.total_mrna_umis = 4900

new_obs = pd.concat([output_adata.obs, new_obs], axis=0)

final_out = AnnData(
    X = newX.reshape(2, -1),
    obs = new_obs,
    var = output_adata.var.copy()
)

AttributeError: 'DataFrame' object has no attribute 'obs'

In [46]:
final_out.obs

Unnamed: 0,suspension_type,total_mrna_umis,assay_ontology_term_id,assay,sex,sex_ontology_term_id,cell_type_ontology_term_id,tissue_ontology_term_id,disease_ontology_term_id,development_stage_ontology_term_id
metacell,nucleus,3464.07139,EFO:0009899,10x 3' v2,Female,PATO:0000383,,,,
metacell,cell,4900.0,EFO:0009899,10x 3' v2,Female,PATO:0000383,,,,


In [55]:
final_out.obs.suspension_type = final_out.obs.suspension_type.astype('category')


In [56]:
final_out.X

array([[0.00273174, 0.14551084, 0.03132398, ..., 0.43598616, 0.05408851,
        0.07102531],
       [0.00127832, 0.07497808, 0.03085769, ..., 0.35955817, 0.02758668,
        0.11570451]])

In [None]:
sc.tl.rank_genes_groups(final_out, groupby="suspension_type", method="wilcoxon")

ValueError: Could not calculate statistics for groups cell, nucleus since they only contain one sample.

: 