In [1]:
import scanpy as sc
from anndata import AnnData

from grnndata import GRNAnnData, from_embeddings
from grnndata import utils 

from scdataloader import Preprocessor as myPreprocessor
from bengrn import BenGRN, get_sroy_gt, get_perturb_gt

import os
import warnings
import sys
import numpy as np

sys.path.insert(0, "../")
from scgpt_helper import prepare_model, prepare_dataset, generate_embedding, generate_grn
import scgpt as scg
from scgpt.tasks import GeneEmbedding
from scgpt.tokenizer.gene_tokenizer import GeneVocab
from scgpt.utils import set_seed
from scgpt.tokenizer import tokenize_and_pad_batch
import torch


os.environ["KMP_WARNINGS"] = "off"
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

print(torch.cuda.is_available())
set_seed(42)


💡 connected lamindb: jkobject/scprint




True



## Step 1: Load fine-tuned model and dataset


### 1.1 Load fine-tuned model

We are going to load a fine-tuned model for the gene interaction analysis on
Adamson dataset. The fine-tuned model can be downloaded via this
[link](https://drive.google.com/drive/folders/1HsPrwYGPXm867_u_Ye0W4Ch8AFSneXAn).
The dataset will be loaded in the next step 1.2.

To reproduce the provided fine-tuned model. Please followw the integration
fin-tuning pipeline to fine-tune the pre-trained blood model on the Adamson
perturbation dataset. Note that in the fine-tuning stage, we did not perform
highly vairable gene selection but trained on the 5000+ genes present in the
Adamson dataset. This is to provide flexbility in the inference stage to
investigate changes in attention maps across different perturbation conditions.


In [2]:
model, vocab = prepare_model(model_dir="../save/scGPT_human")

Resume model from ../save/scGPT_human/best_model.pt, the model args will override the config ../save/scGPT_human/args.json.



In [3]:
mpreprocessor = myPreprocessor(is_symbol=True, force_preprocess=True, skip_validate=True,
                            do_postp=False, min_valid_genes_id=5000, min_dataset_size=64)

genes = torch.load(
   '../../scPRINT/data/temp/vbd8bavn/epoch=17-step=90000.ckpt'
#    '/pasteur/zeus/projets/p02/ml4ig_hot/Users/jkalfon/scprint_scale/o2uniqsx/checkpoints/epoch=18-step=133000.ckpt'

)['hyper_parameters']['genes']

In [4]:
CELLTYPES = [
    'kidney distal convoluted tubule epithelial cell',
    'kidney loop of Henle thick ascending limb epithelial cell',
    'kidney collecting duct principal cell',
#    'mesangial cell', #cannot do it... too few cells
    'blood vessel smooth muscle cell',
    'podocyte',
    'macrophage',
    'leukocyte',
    'kidney interstitial fibroblast',
    'endothelial cell'
]
NUM_GENES = 5000
MAXCELLS = 1024


In [5]:
#adata = sc.read_h5ad('/home/ml4ig1/scprint/.lamindb/yBCKp6HmXuHa0cZptMo7.h5ad')
adata = sc.read_h5ad('../../scPRINT/data/yBCKp6HmXuHa0cZptMo7.h5ad')
adata.var["isTF"] = False
adata.var.loc[adata.var.symbol.isin(utils.TF), "isTF"] = True
adata

AnnData object with n_obs × n_vars = 15728 × 70116
    obs: 'donor_id', 'self_reported_ethnicity_ontology_term_id', 'organism_ontology_term_id', 'sample_uuid', 'sample_preservation_method', 'tissue_ontology_term_id', 'development_stage_ontology_term_id', 'tissue_section_uuid', 'tissue_section_thickness', 'library_uuid', 'assay_ontology_term_id', 'mapped_reference_annotation', 'is_primary_data', 'cell_type_ontology_term_id', 'author_cell_type', 'disease_ontology_term_id', 'sex_ontology_term_id', 'suspension_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'cell_culture', 'nnz', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier', 'leiden_3', 'leide

In [6]:
sc.tl.rank_genes_groups(
    adata, groupby="cell_type"
)
adata.var['ensembl_id'] = adata.var.index
metrics = {}
for celltype in CELLTYPES:
    to_use = adata.uns["rank_genes_groups"]["names"][celltype].tolist()
    to_use = [x for x in to_use if x in genes]
    to_use = to_use[:NUM_GENES]
    subadata = adata[adata.obs.cell_type == celltype]
    subadata = subadata[subadata.X.sum(1) > 500][:MAXCELLS, adata.var.index.isin(
        to_use)]
    subadata.var = subadata.var.set_index('feature_name')
    grn = generate_grn(model, vocab, subadata, batch_size = 10, num_attn_layers = 11)
    break
    metrics["scGPT_"+celltype] = BenGRN(grn).scprint_benchmark()
    grn.varp['GRN'][~grn.var.isTF,:]=0
    metrics['scGPT_tf_'+celltype] = BenGRN(grn).scprint_benchmark()

scGPT - INFO - Filtering genes by counts ...
scGPT - INFO - Normalizing total counts ...
scGPT - INFO - Log1p transforming ...
scGPT - INFO - Binning data ...


100%|██████████| 86/86 [03:53<00:00,  2.71s/it]


In [8]:
grn.grn.sum(1)

feature_name
TSPAN6           0.999814
DPM1             0.999812
FUCA2            0.999813
GCLC             0.999812
ENPP4            0.999814
                   ...   
NCBP2AS2         0.999814
MROH7-TTC4       0.999814
GTF2H5           0.999813
NUDT3            0.999813
C1QTNF3-AMACR    0.999811
Length: 5000, dtype: float32

In [9]:
grn.grn.sum(0)

feature_name
TSPAN6           0.974363
DPM1             1.020918
FUCA2            0.981289
GCLC             0.923938
ENPP4            0.964060
                   ...   
NCBP2AS2         0.940517
MROH7-TTC4       1.026443
GTF2H5           0.974878
NUDT3            1.007010
C1QTNF3-AMACR    0.877217
Length: 5000, dtype: float32

In [7]:
metrics

{'scGPT_kidney distal convoluted tubule epithelial cell': {'TF_enr': True,
  'enriched_terms_Regulators': ['celltype.gmt__Distal tubule cells',
   'celltype.gmt__Enterocytes',
   'celltype.gmt__Connecting tubule cells',
   'celltype.gmt__Foveolar cells',
   'celltype.gmt__-intercalated cells (Collecting duct system)',
   'celltype.gmt__Hepatocytes',
   'celltype.gmt__Cholangiocytes',
   'celltype.gmt__Principal cells (Collecting duct system)',
   'celltype.gmt__Cone bipolar cells',
   'celltype.gmt__Mesangial cells',
   'celltype.gmt__Alveolar macrophages',
   'celltype.gmt__Loop of Henle cells',
   'celltype.gmt__Melanocytes',
   'celltype.gmt__Acinar cells',
   'celltype.gmt__Kupffer cells',
   'celltype.gmt__Oligodendrocytes',
   'celltype.gmt__Proximal tubule cells',
   'celltype.gmt__Ductal cells',
   'celltype.gmt__Hepatic stellate cells',
   'celltype.gmt__Pulmonary alveolar type I cells'],
  'significant_enriched_TFtargets': 0.0,
  'precision': 0.0010619058203785046,
  'recall'

In [None]:
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True

In [8]:
import pandas as pd

In [9]:
res = []
for k, v in metrics.items():
    res.append([k.split('_')[-1], v['epr'], v['auprc'], v['rand_precision'], v['significant_enriched_TFtargets'], v.get('TF_enr', False), 'tf_' in k])

df = pd.DataFrame(res, columns=['name','EPR', 'AUPRC', 'RAND', 'TF_targ', 'TF_enr', 'TF_only'])
df

Unnamed: 0,name,EPR,AUPRC,RAND,TF_targ,TF_enr,TF_only
0,kidney distal convoluted tubule epithelial cell,0.418024,0.001125,0.001062,0.0,True,False
1,kidney distal convoluted tubule epithelial cell,5.49282,0.002397,0.001062,0.0,True,True
2,kidney loop of Henle thick ascending limb epit...,1.601299,0.001359,0.001212,0.0,True,False
3,kidney loop of Henle thick ascending limb epit...,7.504324,0.003709,0.001212,0.0,True,True
4,kidney collecting duct principal cell,0.667388,0.001621,0.001362,0.0,True,False
5,kidney collecting duct principal cell,8.956276,0.004798,0.001362,0.0,True,True
6,blood vessel smooth muscle cell,1.877086,0.001949,0.001771,0.0,True,False
7,blood vessel smooth muscle cell,6.532609,0.005903,0.001771,0.0,True,True
8,podocyte,1.540324,0.001728,0.001518,0.0,True,False
9,podocyte,7.603643,0.005567,0.001518,0.0,True,True
