In [1]:
import os
import argparse
import math
import torch
import pickle
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
from braceexpand import braceexpand
from tqdm import tqdm
import multiprocessing as mp
from copy import copy

# for flex attention
import torch._dynamo
import torch.multiprocessing as mp 
torch._dynamo.config.suppress_errors = True

sc.set_figure_params(figsize=(4, 4))

from cellarium.ml.utilities.inference.cellarium_gpt_inference import \
    CellariumGPTInferenceContext, \
    GeneNetworkAnalysisBase

2025-03-21 17:14:29.562832: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
ADATA_FP = '/work/hdd/bbjr/mallina1/data/mb-ml-dev-vm/data/pbmc_adata.h5ad'
REF_ADATA_FP = '/work/hdd/bbjr/mallina1/data/mb-ml-dev-vm/data/extract_0.h5ad'

In [3]:
ref_adata = sc.read_h5ad(REF_ADATA_FP)
ref_adata.obs['suspension_type'].unique()

['cell', 'nucleus']
Categories (2, object): ['nucleus', 'cell']

In [53]:
cellarium_var_names = set(ref_adata.var_names)

In [43]:
'dendritic cell' in ref_adata.obs[ref_adata.obs['suspension_type'] == 'nucleus'].cell_type.unique()

False

In [4]:
val_adata = sc.read_h5ad(ADATA_FP)
val_adata.obs

Unnamed: 0,NAME,nGene,nUMI,percent.mito,Cluster,CellType,Experiment,Method
0,pbmc1_Celseq2_1_ACAGAC,3290,9030,0.0902547065337763,1,CD4+ T cell,pbmc1,CEL-Seq2
1,pbmc1_Celseq2_1_ACAGGA,2797,8482,0.116953548691346,1,CD4+ T cell,pbmc1,CEL-Seq2
2,pbmc1_Celseq2_1_ACGTTG,2651,6787,0.111978782967438,1,CD4+ T cell,pbmc1,CEL-Seq2
3,pbmc1_Celseq2_1_AGACCA,2766,6004,0.0932711525649567,1,CD4+ T cell,pbmc1,CEL-Seq2
4,pbmc1_Celseq2_1_CAACTC,3510,10162,0.11621728006298,1,CD4+ T cell,pbmc1,CEL-Seq2
...,...,...,...,...,...,...,...,...
30490,pbmc2_inDrops_1_TAGTCTCT.GAGCCTTA.ATCCGCTA,453,717,0.097629009762901,11,Plasmacytoid dendritic cell,pbmc2,inDrops
30491,pbmc2_inDrops_1_TCCAGAAG.TTATGCGA.TAAGACGG,592,938,0.035181236673774,11,Plasmacytoid dendritic cell,pbmc2,inDrops
30492,pbmc2_inDrops_1_TGAATCCT.GAGCCTTA.CCCAAGCA,406,662,0.138972809667674,11,Plasmacytoid dendritic cell,pbmc2,inDrops
30493,pbmc2_inDrops_1_TGAATCCT.TTATGCGA.CATCTCCC,1001,2066,0.0556631171345595,11,Plasmacytoid dendritic cell,pbmc2,inDrops


In [None]:
pbmc_var_names = set(val_adata.var_names)
pbmc_var_names - cellarium_var_names


In [58]:
val_adata

AnnData object with n_obs × n_vars = 30495 × 33694
    obs: 'NAME', 'nGene', 'nUMI', 'percent.mito', 'Cluster', 'CellType', 'Experiment', 'Method'
    var: 'gene_symbols'
    uns: 'CellType_colors', 'Method_colors'
    obsm: 'X_harmony'

In [None]:
val_adata[:, val_adata.var_names.isin(cellarium_var_names)]

View of AnnData object with n_obs × n_vars = 30495 × 32351
    obs: 'NAME', 'nGene', 'nUMI', 'percent.mito', 'Cluster', 'CellType', 'Experiment', 'Method'
    var: 'gene_symbols'
    uns: 'CellType_colors', 'Method_colors'
    obsm: 'X_harmony'

In [74]:
val_adata[val_adata.obs.Experiment == 'pbmc1'].obs.Method.value_counts()

Method
10x Chromium (v2) A    3222
10x Chromium (v2) B    3222
10x Chromium (v3)      3222
Drop-seq               3222
Seq-Well               3222
inDrops                3222
CEL-Seq2                253
Name: count, dtype: int64

In [45]:
val_adata.obs['Method'].unique()

['CEL-Seq2', '10x Chromium (v2) A', '10x Chromium (v2) B', '10x Chromium (v3)', 'Drop-seq', 'Seq-Well', 'inDrops', '10x Chromium (v2)']
Categories (8, object): ['10x Chromium (v2)', '10x Chromium (v2) A', '10x Chromium (v2) B', '10x Chromium (v3)', 'CEL-Seq2', 'Drop-seq', 'Seq-Well', 'inDrops']

In [46]:
val_adata.obs['CellType'].unique()

['CD4+ T cell', 'Cytotoxic T cell', 'Natural killer cell', 'CD16+ monocyte', 'CD14+ monocyte', 'Megakaryocyte', 'B cell', 'Dendritic cell', 'Plasmacytoid dendritic cell', 'Unassigned']
Categories (10, object): ['B cell', 'CD4+ T cell', 'CD14+ monocyte', 'CD16+ monocyte', ..., 'Megakaryocyte', 'Natural killer cell', 'Plasmacytoid dendritic cell', 'Unassigned']

In [None]:
'''
Cellarium Assay Labels:
            "Seq-Well",
            "10x 3' v3",
            "SPLiT-seq",
            "Smart-seq v4",
            "Drop-seq",
            "sci-RNA-seq",
            "10x 5' v2",
            "10x 5' transcription profiling",
            "inDrop",
            "microwell-seq",
            "10x multiome",
            "10x 3' v1",
            "ScaleBio single cell RNA sequencing",
            "Smart-seq2",
            "10x 3' transcription profiling",
            "Seq-Well S3",
            "10x 3' v2",
            "MARS-seq",
            "10x 5' v1"
'''

assay_label_map = {
    "10x Chromium (v2)": "10x 3' v2",
    "10x Chromium (v2) A": "10x 3' v2",
    "10x Chromium (v2) B": "10x 3' v2",
    "10x Chromium (v3)": "10x 3' v3",
    "CEL-Seq2": "", # drop this
    "Drop-seq": "Drop-seq",
    "Seq-Well": "Seq-Well",
    "inDrops": "inDrop"
}

In [11]:
val_adata = sc.read_h5ad(ADATA_FP)
ref_adata = sc.read_h5ad(REF_ADATA_FP)
val_adata

AnnData object with n_obs × n_vars = 30495 × 33694
    obs: 'NAME', 'nGene', 'nUMI', 'percent.mito', 'Cluster', 'CellType', 'Experiment', 'Method'
    var: 'gene_symbols'
    uns: 'CellType_colors', 'Method_colors'
    obsm: 'X_harmony'

In [5]:
pbmc_umis = val_adata[val_adata.obs.CellType == 'B cell']
pbmc_umis = pbmc_umis[pbmc_umis.obs.Method == '10x Chromium (v3)']

In [8]:
pbmc_umis.obs

Unnamed: 0,NAME,nGene,nUMI,percent.mito,Cluster,CellType,Experiment,Method
6860,pbmc1_10x_v3_CCACAAATCTGGGTCG,2600,8744,0.105901189387008,3,B cell,pbmc1,10x Chromium (v3)
6975,pbmc1_10x_v3_GGAATGGTCGAGATGG,2575,11338,0.116334450520374,3,B cell,pbmc1,10x Chromium (v3)
7132,pbmc1_10x_v3_TGCATCCAGACTTCCA,2904,11065,0.101129688206055,3,B cell,pbmc1,10x Chromium (v3)
8754,pbmc1_10x_v3_AAAGAACAGATTGTGA,1446,5101,0.0848853166045873,3,B cell,pbmc1,10x Chromium (v3)
8755,pbmc1_10x_v3_AAAGGATAGCCGGATA,1093,3804,0.169558359621451,3,B cell,pbmc1,10x Chromium (v3)
...,...,...,...,...,...,...,...,...
9098,pbmc1_10x_v3_TTTACCAGTAATGATG,1390,4969,0.120547393841819,3,B cell,pbmc1,10x Chromium (v3)
9099,pbmc1_10x_v3_TTTACCATCTCGTGAA,1269,4455,0.162289562289562,3,B cell,pbmc1,10x Chromium (v3)
9100,pbmc1_10x_v3_TTTCAGTCACCCGTAG,1258,4052,0.100691016781836,3,B cell,pbmc1,10x Chromium (v3)
9101,pbmc1_10x_v3_TTTGGAGTCAAGCTTG,1734,6409,0.0756748322671244,3,B cell,pbmc1,10x Chromium (v3)


In [10]:
pbmc_umis.obs.nUMI.to_numpy()

array(['8744', '11338', '11065', '5101', '3804', '6379', '4039', '6318',
       '5195', '5526', '5984', '5997', '7075', '5140', '6937', '4061',
       '3514', '4132', '5440', '5729', '5951', '5081', '4211', '4554',
       '3680', '7183', '5622', '3823', '3923', '2734', '5356', '4081',
       '3374', '6907', '4231', '4894', '4874', '3024', '4964', '3633',
       '5090', '5195', '8843', '4041', '3568', '4217', '4313', '5108',
       '5045', '6188', '5110', '6796', '4982', '2813', '3568', '5462',
       '3686', '4451', '5591', '3693', '4174', '4882', '3849', '3451',
       '4669', '4420', '4364', '4308', '4046', '4796', '6738', '5441',
       '3619', '4923', '6732', '7162', '4435', '4548', '3837', '3506',
       '3619', '3997', '11446', '5277', '4450', '4853', '4816', '6075',
       '5374', '5534', '4609', '3830', '4714', '4828', '3986', '3417',
       '10630', '4486', '4720', '6017', '5050', '4632', '4923', '6923',
       '3913', '5514', '4703', '5837', '4883', '4964', '4560', '5126',
  

In [13]:
ref_var_names = set(ref_adata.var_names)
ref_var_names

{'ENSG00000087087',
 'ENSG00000261462',
 'ENSG00000278740',
 'ENSG00000237560',
 'ENSG00000243620',
 'ENSG00000241669',
 'ENSG00000287938',
 'ENSG00000286369',
 'ENSG00000249917',
 'ENSG00000223993',
 'ENSG00000248515',
 'ENSG00000177034',
 'ENSG00000231560',
 'ENSG00000063241',
 'ENSG00000169548',
 'ENSG00000267986',
 'ENSG00000100197',
 'ENSG00000265243',
 'ENSG00000228470',
 'ENSG00000229487',
 'ENSG00000283992',
 'ENSG00000162511',
 'ENSG00000272047',
 'ENSG00000286290',
 'ENSG00000114861',
 'ENSG00000198542',
 'ENSG00000239480',
 'ENSG00000100897',
 'ENSG00000105711',
 'ENSG00000152556',
 'ENSG00000228606',
 'ENSG00000077063',
 'ENSG00000254237',
 'ENSG00000275743',
 'ENSG00000270157',
 'ENSG00000122547',
 'ENSG00000233405',
 'ENSG00000211676',
 'ENSG00000285774',
 'ENSG00000284664',
 'ENSG00000287232',
 'ENSG00000287143',
 'ENSG00000163607',
 'ENSG00000188986',
 'ENSG00000236115',
 'ENSG00000276473',
 'ENSG00000135218',
 'ENSG00000226124',
 'ENSG00000150471',
 'ENSG00000128951',


In [None]:
random_genes = np.random.choice(list(ref_var_names), size=10, replace=False)

array(['ENSG00000287950', 'ENSG00000232653', 'ENSG00000231233',
       'ENSG00000226862', 'ENSG00000287323', 'ENSG00000211945',
       'ENSG00000198796', 'ENSG00000188186', 'ENSG00000245662',
       'ENSG00000287759'], dtype='<U15')

In [26]:
np.array(pbmc_umis.X.sum(-1)).squeeze().shape

(346,)

In [33]:
ref_adata.obs.sex_ontology_term_id.unique()

['PATO:0000384', 'unknown', 'PATO:0000383']
Categories (3, object): ['PATO:0000383', 'unknown', 'PATO:0000384']

In [34]:
val_adata = sc.read_h5ad(ADATA_FP)

In [41]:
val_adata.var_names

Index(['ENSG00000000003', 'ENSG00000000005', 'ENSG00000000419',
       'ENSG00000000457', 'ENSG00000000460', 'ENSG00000000938',
       'ENSG00000000971', 'ENSG00000001036', 'ENSG00000001084',
       'ENSG00000001167',
       ...
       'ENSG00000283078', 'ENSG00000283083', 'ENSG00000283088',
       'ENSG00000283093', 'ENSG00000283095', 'ENSG00000283096',
       'ENSG00000283103', 'ENSG00000283117', 'ENSG00000283118',
       'ENSG00000283125'],
      dtype='object', length=33694)