# Benchmarking on OmniPath


In [1]:
from scprint import scPrint
from scprint.tasks import GRNfer

from bengrn import BenGRN
import scanpy as sc

from bengrn.base import train_classifier

from anndata.utils import make_index_unique
from bengrn import compute_genie3
from grnndata import utils as grnutils

%load_ext autoreload
%autoreload 2 

import torch
# Let PyTorch use lower internal precision for float32 matrix multiplications
# ('medium' trades accuracy for speed; see the torch documentation).
torch.set_float32_matmul_precision('medium')

💡 connected lamindb: jkobject/scprint


2024-05-22 14:35:58,757:INFO - Downloading data from `https://omnipathdb.org/queries/enzsub?format=json`
2024-05-22 14:35:58,879:INFO - Downloading data from `https://omnipathdb.org/queries/interactions?format=json`
2024-05-22 14:35:58,969:INFO - Downloading data from `https://omnipathdb.org/queries/complexes?format=json`
2024-05-22 14:35:59,063:INFO - Downloading data from `https://omnipathdb.org/queries/annotations?format=json`
2024-05-22 14:35:59,150:INFO - Downloading data from `https://omnipathdb.org/queries/intercell?format=json`
2024-05-22 14:35:59,361:INFO - Downloading data from `https://omnipathdb.org/about?format=text`
  warn(
  Shape = jax.core.Shape
For more information, see https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html
  PRNGKey = jax.random.KeyArray
For more information, see https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html
  IntOrKey = Union[int, jax.random.KeyArray]
For more information, see https://jax.readthedocs.io/en/latest/jep/9263-type

In [None]:
# Number of top-ranked genes kept per cell type in the benchmarks below.
NUM_GENES = 4000

# Restore the scPRINT model from a local training checkpoint.
checkpoint_path = '../data/temp/vbd8bavn/epoch=17-step=90000.ckpt'
model = scPrint.load_from_checkpoint(checkpoint_path, precpt_gene_emb=None)

In [None]:
# Cell types to benchmark. Each entry is compared against
# `adata.obs.cell_type` and used as a key into the rank_genes_groups results.
CELLTYPES = [
    'kidney distal convoluted tubule epithelial cell',
    'kidney loop of Henle thick ascending limb epithelial cell',
    'kidney collecting duct principal cell',
    'mesangial cell',
    'blood vessel smooth muscle cell',
    'podocyte',
    'macrophage',
    'leukocyte',
    'kidney interstitial fibroblast',
    'endothelial cell'
]

In [None]:
# Load the dataset and flag every gene whose symbol is listed in grnutils.TF.
adata = sc.read_h5ad('/home/ml4ig1/scprint/.lamindb/yBCKp6HmXuHa0cZptMo7.h5ad')
adata.var["isTF"] = adata.var.symbol.isin(grnutils.TF)
adata

In [None]:
# Rank genes per cell type, then benchmark GENIE3 networks for each cell type
# in CELLTYPES: once with regulators restricted to genes flagged `isTF`, and
# once with no regulator restriction.
sc.tl.rank_genes_groups(adata, groupby="cell_type")
adata.var['ensembl_id'] = adata.var.index
metrics = {}
for celltype in CELLTYPES:
    # Top NUM_GENES ranked gene names for this cell type.
    ranked_names = adata.uns["rank_genes_groups"]["names"][celltype]
    to_use = ranked_names[:NUM_GENES].tolist()

    # First 1024 cells of this type, restricted to genes that are both known
    # to the model and among the top-ranked genes.
    celltype_cells = adata[adata.obs.cell_type == celltype]
    gene_mask = adata.var.index.isin(model.genes) & adata.var.index.isin(to_use)
    subadata = celltype_cells[:1024, gene_mask]
    print(subadata)

    # (metrics key prefix, extra keyword arguments for compute_genie3)
    variants = (
        ('genie3_tf_', {"regulators": adata.var[adata.var.isTF].index.tolist()}),
        ('genie3_', {}),
    )
    for prefix, extra_kwargs in variants:
        genie_grn = compute_genie3(subadata, nthreads=32, **extra_kwargs)
        # Re-index genes by (de-duplicated) symbol before benchmarking.
        symbols = genie_grn.var['symbol'].astype(str)
        genie_grn.var.index = make_index_unique(symbols)
        benchmark = BenGRN(genie_grn, do_auc=True, doplot=True)
        metrics[prefix + celltype] = benchmark.scprint_benchmark()

In [None]:
# Display the benchmark results collected so far.
metrics

In [None]:
# Benchmark scPRINT-derived gene networks for each cell type in CELLTYPES.
# Three entries are stored per cell type:
#   <celltype>_scprint          - network inferred from model layers 8 and up
#   <celltype>_scprint_class    - network replaced by the `train_classifier`
#                                 output stored in varp['classified']
#   <celltype>_scprint_class_TF - same, with rows of genes whose symbol is not
#                                 in grnutils.TF set to zero
metrics = {}  # NOTE: rebinding drops any results collected by earlier cells

# Loop-invariant inputs, computed once rather than twice per iteration.
expressed_mask = adata.X.sum(1) > 500  # keep rows of X whose sum exceeds 500
# `.iloc[0]` makes the positional lookup explicit; plain `[0]` on a Series
# with a non-integer index relies on a deprecated pandas fallback.
organism = adata.obs['organism_ontology_term_id'].iloc[0]
# Keyword arguments shared by both GRNfer configurations below.
shared_kwargs = dict(
    preprocess="softmax",
    filtration="none",
    forward_mode="none",
    organisms=organism,
    max_cells=1024,
    doplot=False,
    batch_size=32,
)

for celltype in CELLTYPES:
    # --- Variant 1: max-aggregated heads over layers 8 and up ---------------
    # NOTE(review): this variant uses num_genes=3000 while the one below uses
    # NUM_GENES - confirm the difference is intended.
    grn_inferer = GRNfer(model, adata[expressed_mask],
                         how="random expr",
                         head_agg='max',
                         num_genes=3000,
                         **shared_kwargs)
    grn = grn_inferer(layer=list(range(8, model.nlayers)), cell_type=celltype)
    # Re-index genes by (de-duplicated) symbol before benchmarking.
    grn.var.index = make_index_unique(grn.var['symbol'].astype(str))
    metrics[celltype+'_scprint'] = BenGRN(grn).scprint_benchmark()

    # --- Variant 2: un-aggregated heads over all layers + classifier --------
    grn_inferer = GRNfer(model, adata[expressed_mask],
                         how="most var across",
                         head_agg='none',
                         num_genes=NUM_GENES,
                         **shared_kwargs)
    grn = grn_inferer(layer=list(range(model.nlayers)), cell_type=celltype)
    grn, m, clf_omni = train_classifier(grn, C=0.3, train_size=0.5, class_weight={
                                        1: 100, 0: 1}, shuffle=False)
    # Benchmark the classifier output in place of the raw network.
    grn.varp['GRN'] = grn.varp['classified']
    grn.var.index = make_index_unique(grn.var['symbol'].astype(str))
    metrics[celltype+'_scprint_class'] = BenGRN(grn).scprint_benchmark()

    # Zero every row whose gene symbol is not in grnutils.TF, then benchmark
    # again (presumably rows are regulators - TODO confirm).
    grn.varp['GRN'][~grn.var.index.isin(grnutils.TF),:] = 0
    metrics[celltype+'_scprint_class_TF'] = BenGRN(grn).scprint_benchmark()

### About 50% of the OmniPath ground truth comes from protein-interaction-type evidence. The model does not necessarily get much better without it.

### The transcript-only dataset has 75% fewer sources and a similar number of targets.

### In the end, most of the transcript-level ground truth does not overlap with the PPI-level one, so dropping it would help the results, but on both sides, and it should not help much with EPR.


In [None]:
# Display the scPRINT benchmark results.
metrics

In [None]:
# 0.000999, 4.7: 2nd axis -> cls
# 0.00327, 9.1: 1st axis -> cls


In [None]:
# Zero every row of the last computed network whose gene symbol is not in
# grnutils.TF, then benchmark the result.
non_tf_rows = ~grn.var.index.isin(grnutils.TF)
grn.varp['GRN'][non_tf_rows, :] = 0
BenGRN(grn).scprint_benchmark()