## This notebook runs inference with a GEARS model trained on the Norman et al. 2019 dataset

- Download the trained GEARS model and dataloader from Dataverse
- The model is trained on the Norman et al. 2019 (Science) dataset
- The examples below show how to predict perturbation outcomes and genetic interactions (GI)

In [1]:
import sys
sys.path.append('../')

import numpy as np

from gears import PertData, GEARS
from gears.utils import dataverse_download
from zipfile import ZipFile 


%load_ext autoreload
%load_ext jupyter_spaces
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


### Download saved model and dataloader

In [2]:
## Download the preprocessed Norman dataloader archive from Dataverse
dataverse_download('https://dataverse.harvard.edu/api/access/datafile/6979957', 'norman_umi_go.tar.gz')

## Extract it into the working directory. The archive comes from the network,
## so use tarfile's 'data' extraction filter when this Python version provides it:
## it refuses members that would be written outside the target directory and
## special files (device nodes etc.).
import tarfile
with tarfile.open('norman_umi_go.tar.gz', 'r:gz') as tar:
    if hasattr(tarfile, 'data_filter'):
        tar.extractall(filter='data')
    else:
        # Older Python without extraction filters: fall back to the plain call.
        tar.extractall()

Downloading...
100%|██████████| 1.10G/1.10G [01:35<00:00, 11.4MiB/s] 


In [3]:
## Download the pretrained GEARS model archive from Dataverse
dataverse_download('https://dataverse.harvard.edu/api/access/datafile/6979956', 'model.zip')

## Extract it into the working directory. The archive handle is named
## `model_zip` (not `zip`) so the built-in zip() stays usable in later cells.
with ZipFile('model.zip', 'r') as model_zip:
    model_zip.extractall(path = './')

Downloading...
100%|██████████| 10.9M/10.9M [00:05<00:00, 2.18MiB/s]


### Load model and dataloader

In [2]:
## Load the preprocessed Norman data from ./norman_umi_go, split it with the
## 'no_test' scheme (seed 1) and build the dataloaders
## (batch size 32; test_batch_size 128).
## `model_name` is also used as the experiment name when the GEARS model is created.
data_path = './'
data_name = 'norman_umi_go'
model_name = 'gears_misc_umi_no_test'

pert_data = PertData(data_path)
pert_data.load(data_path = data_path + data_name)
pert_data.prepare_split(split = 'no_test', seed = 1)
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128)

Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['RHOXF2BB+ctrl' 'LYL1+IER5L' 'ctrl+IER5L' 'KIAA1804+ctrl' 'IER5L+ctrl'
 'RHOXF2BB+ZBTB25' 'RHOXF2BB+SET']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Done!
Creating dataloaders....
Done!


In [6]:
## Build the GEARS model on the prepared data and load the pretrained weights
## from ./model_ckpt.
## The original run used GPU 5. Use it when present; otherwise fall back to the
## first GPU, or to the CPU when CUDA is unavailable, so the cell runs elsewhere too.
import torch

GPU_INDEX = 5
if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX:
    device = f'cuda:{GPU_INDEX}'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

gears_model = GEARS(pert_data, device = device,
                        weight_bias_track = False,
                        proj_name = 'gears',
                        exp_name = model_name)
gears_model.load_pretrained('./model_ckpt')

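## GEARS data investigation

Compare the GEARS-processed Norman data with the original Norman2019 file and the scPerturb release: matrix shapes, number of distinct perturbation labels, a mapping between the two labelling schemes, and the total counts of one cell in each version.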

In [3]:
%%space version_test
import anndata as ad
norman_gears=ad.read_h5ad("norman_umi_go/perturb_processed.h5ad")
print("norman_gears",norman_gears.shape,norman_gears.obs["condition"].nunique())
norman_original=ad.read_h5ad("/data/liz0f/sc_diffusion/datasets/Norman2019.h5ad")
print("norman_original",norman_original.shape,norman_gears.obs["condition"].nunique())
norman_scPerturb=ad.read_h5ad("/data/liz0f/sc_diffusion/scPerturb/NormanWeissman2019_filtered.h5ad")
print("norman_scPerturb",norman_scPerturb.shape,norman_scPerturb.obs["perturbation"].nunique())


norman_gears (91205, 5054) 284
norman_original (108497, 5000) 284
norman_scPerturb (111445, 33694) 237


In [121]:
%%space version_test
norman_scPertub_perturbation_list=norman_scPerturb.obs["perturbation"].unique().tolist()
with open("pertmap.tsv",'w') as f:
    for pert in norman_scPertub_perturbation_list:
        print(pert,end="\t",file=f)
        pert=pert.replace('_','+')
        if '+' in pert:
            newpert=norman_gears.obs.query(f"condition.str.contains('{pert}',regex=False)")["condition"].unique().tolist()
        else:
            newpert=norman_gears.obs.query(f"condition.str.contains('{pert}',regex=False)&condition.str.contains('ctrl',regex=False)")["condition"].unique().tolist()
        print(','.join(newpert),file=f)

In [39]:
%%space version_test
barcode_prefix="AAACGGGTCCTAGGGC"
mask=norman_gears.obs.index.str.startswith(barcode_prefix)
assert mask.sum()==1
print(norman_gears.X[mask,:].sum())
mask=norman_original.obs.index.str.startswith(barcode_prefix)
assert mask.sum()==1
print(norman_original.X[norman_original.obs.index.str.startswith(barcode_prefix),:].sum())
mask=norman_scPerturb.obs.index.str.startswith(barcode_prefix)
assert mask.sum()==1
print(norman_scPerturb.X[norman_scPerturb.obs.index.str.startswith(barcode_prefix),:].sum())

2616.059
553.6387
18919.0


### Make transcriptional outcome predictions

In [None]:
## Predict the transcriptional outcome of perturbing CNN1 and CBL together.
## predict() takes a list of perturbations, each written as a list of gene names.
query_perturbations = [['CNN1', 'CBL']]
gears_model.predict(query_perturbations)

### Predict genetic interaction (GI) outcomes

In [10]:
## GI prediction for the CNN1 + CBL pair without a gene-set file (GI_genes_file=None).
gi_gene_pair = ['CNN1', 'CBL']
gears_model.GI_predict(gi_gene_pair, GI_genes_file=None)

{'ts': TheilSenRegressor(fit_intercept=False, max_iter=1000, max_subpopulation=100000,
                   random_state=1000),
 'c1': 1.0942881586568658,
 'c2': 0.684177476331237,
 'mag': 1.290567856912458,
 'dcor': 0.8649321390185458,
 'dcor_singles': 0.7813616432466521,
 'dcor_first': 0.827889410401002,
 'dcor_second': 0.8135062057416026,
 'corr_fit': 0.9303117736028462,
 'dominance': 0.20396292696340834,
 'eq_contr': 0.9826266594563244}

In [None]:
## To reproduce the paper's results, download the same gene set used there
## (genes_with_hi_mean.npy) and pass it to GI_predict.
genes_file_url = 'https://dataverse.harvard.edu/api/access/datafile/6979958'
dataverse_download(genes_file_url, 'genes_with_hi_mean.npy')

gears_model.GI_predict(['CNN1', 'CBL'], GI_genes_file='./genes_with_hi_mean.npy')