In [1]:
import sys
sys.path.append('../')

import pickle as pkl

import numpy as np
import anndata as ad
import pandas as pd
import torch
from gears import PertData, GEARS
from gears.utils import parse_single_pert,parse_combo_pert
%load_ext autoreload
%load_ext jupyter_spaces
%autoreload 2

Load data. We use the Norman dataset (`norman_umi_go`) as an example.

In [55]:
%%space preprocess
adata_bak=ad.read_h5ad("norman_umi_go/perturb_processed_bak.h5ad")
pertmap=pd.read_csv("pertmap.tsv",sep="\t",header=None)

pertmap_dict=dict()
for index,row in pertmap.iterrows():
    for k in row[1].split(','):
        pertmap_dict[k]=row[0]
        
def rename_conditions(x):
    """Rewrite a '+'-separated condition string using the gears.utils parsers.

    - single token (e.g. 'ctrl'): returned unchanged
    - two tokens, one of them 'ctrl': 'ctrl+' + parse_single_pert(x)
    - two tokens, neither 'ctrl': tokens from parse_combo_pert(x) rejoined with '+'

    Raises
    ------
    ValueError
        For any other shape (three or more tokens). Previously such input
        fell through and returned None, which `.map` silently stored in
        obs["condition"].
    """
    parts = x.split('+')
    if len(parts) == 1:
        return parts[0]
    if len(parts) == 2:
        if 'ctrl' in parts:
            return 'ctrl+' + parse_single_pert(x)
        return "+".join(parse_combo_pert(x))
    raise ValueError(f"unexpected condition format: {x!r}")
    
adata_bak.obs["condition"] = adata_bak.obs["condition"].map(rename_conditions)
# Any condition the renamer could not handle would show up as a missing value.
assert adata_bak.obs["condition"].notna().all(), "some conditions were mapped to a missing value"

# Overwriting the processed file is opt-in: set to True to persist the renamed conditions.
WRITE_PROCESSED = False
if WRITE_PROCESSED:
    adata_bak.write("norman_umi_go/perturb_processed.h5ad")

In [43]:
# Dataset root (keep the trailing slash: it is concatenated with data_name
# later), dataset folder, and a run label.
data_path, data_name, model_name = './', 'norman_umi_go', 'gears_misc_umi_no_test'

# NOTE(review): `model_name` is not referenced in the cells shown here.
pert_data = PertData(data_path)

Found local copy ./gene2go_all.pkl ...


In [5]:
# Load the processed dataset from './norman_umi_go'. The path is built by plain
# string concatenation, so `data_path` must end with a separator.
pert_data.load(data_path = data_path + data_name)

Found local copy ./essential_all_data_pert_genes.pkl ...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['ctrl+RHOXF2BB' 'LYL1+IER5L' 'ctrl+IER5L' 'ctrl+KIAA1804'
 'RHOXF2BB+ZBTB25' 'RHOXF2BB+SET']
Local copy of pyg dataset is detected. Loading...
Done!


In [11]:
def map_fn(condition):
    """Classify a '+'-separated condition string by its number of 'ctrl' tokens.

    Returns
    -------
    'ctrl' for any single-token condition,
    'one'  for two tokens with exactly one 'ctrl',
    'two'  for two tokens with no 'ctrl',
    None   otherwise ('ctrl+ctrl' or three or more tokens).
    """
    parts = condition.split('+')
    if len(parts) == 1:
        return 'ctrl'
    if len(parts) != 2:
        return None
    return {1: 'one', 0: 'two'}.get(parts.count('ctrl'))
        
def make_custom_split(conditions, n_combo_test, seed=0):
    """Build a train/val/test split over perturbation conditions.

    Every single-gene condition goes to train. ``n_combo_test`` combo
    conditions are drawn at random (without replacement) for validation, and
    the remaining combos go to train. 'test' is the same list object as
    'val', so there is no held-out set beyond validation.

    Parameters
    ----------
    conditions : list of str
        All condition names; must contain 'ctrl'.
    n_combo_test : int
        Number of combo conditions to hold out; must not exceed the number
        of combo conditions.
    seed : int, optional
        Passed to ``np.random.seed`` before drawing (default 0, the value
        that was previously hard-coded).

    Returns
    -------
    dict
        ``{'train': [...], 'val': [...], 'test': <same list as 'val'>}``.

    Raises
    ------
    AssertionError
        If 'ctrl' is missing, ``n_combo_test`` is too large, or a gene used in
        a combo has no single-gene condition.
    """
    # Seeds numpy's global RNG, as before, so the drawn split is unchanged.
    np.random.seed(seed)
    assert 'ctrl' in conditions

    single_conditions = [cond for cond in conditions if map_fn(cond) == 'one']
    combo_conditions = [cond for cond in conditions if map_fn(cond) == 'two']

    print("all conditions (including ctrl)", len(conditions))
    print("single_conditions", len(single_conditions))
    print("combo_conditions", len(combo_conditions))
    assert n_combo_test <= len(combo_conditions)

    # Draw `n_combo_test` items. The original drew the global `n_combo`, which
    # only worked because the calling loop happened to define that name.
    val_combo_conditions = list(np.random.choice(combo_conditions, n_combo_test, replace=False))
    held_out = set(val_combo_conditions)
    train_combo_conditions = [cond for cond in combo_conditions if cond not in held_out]

    train_conditions = single_conditions + train_combo_conditions
    val_conditions = list(val_combo_conditions)
    print("train_conditions", len(train_conditions), "train_single_conditions", len(train_conditions) - len(train_combo_conditions), "train_combo_conditions", len(train_combo_conditions))
    print("val_conditions", len(val_conditions), "val_combo_conditions", len(val_combo_conditions))

    # Every gene appearing in a combo must also have its own single-gene condition.
    single_genes = {parse_single_pert(cond) for cond in single_conditions}
    combo_genes = [gene for cond in combo_conditions for gene in parse_combo_pert(cond)]
    assert all(gene in single_genes for gene in combo_genes)
    return {
        'train': train_conditions,
        'val': val_conditions,
        'test': val_conditions,
    }
# Write one split file per held-out combo count. The condition list is the
# same on every iteration, so compute it once. Per the printed counts there
# are 128 combos, so n_combo=128 leaves no combo condition in train.
all_conditions = pert_data.adata.obs.condition.unique().tolist()
for n_combo in [66, 99, 128]:
    custom_split = make_custom_split(all_conditions, n_combo)
    custom_split_path = "./norman_umi_go/splits/custom_split_%d.pkl" % (n_combo)
    with open(custom_split_path, 'wb') as f:
        pkl.dump(custom_split, f)

all conditions (including ctrl) 231
single_conditions 102
combo_conditions 128
train_conditions 164 train_single_conditions 102 train_combo_conditions 62
val_conditions 66 val_combo_conditions 66
all conditions (including ctrl) 231
single_conditions 102
combo_conditions 128
train_conditions 131 train_single_conditions 102 train_combo_conditions 29
val_conditions 99 val_combo_conditions 99
all conditions (including ctrl) 231
single_conditions 102
combo_conditions 128
train_conditions 102 train_single_conditions 102 train_combo_conditions 0
val_conditions 128 val_combo_conditions 128


In [10]:
# Use the split file that holds out 66 combo conditions.
# NOTE(review): `split_path` is not an upstream GEARS prepare_split argument;
# presumably a local extension that loads this file. Confirm it overrides `split`.
n_combo = 66
pert_data.prepare_split(
    split='combo_seen2',
    seed=1,
    split_path=f"norman_umi_go/splits/custom_split_{n_combo}.pkl",
)

Local copy of split is detected. Loading...
Done!


In [11]:
# Create the dataloaders for the prepared split: 32 graphs per training batch
# (confirmed by the batch inspection below) and 128 for the test loader.
pert_data.get_dataloader(batch_size = 32, test_batch_size = 128)

Creating dataloaders....
Done!


Create a model object. If you use [wandb](https://wandb.ai), you can track model training and evaluation by setting `weight_bias_track` to `True` and choosing your own `proj_name` and `exp_name`.

In [12]:
# Pull a single training batch to inspect its fields in the following cells.
data=next(iter(pert_data.dataloader["train_loader"]))

In [11]:
# Node features of the whole batch stacked row-wise: the output 161728 equals
# 32 graphs x 5054 rows per graph (see the torch.unique counts below).
data.x.shape

torch.Size([161728, 1])

In [30]:
# Splits into chunks of 32 *rows*, not 32 graphs. Each graph has 5054 rows, so
# this first chunk presumably covers only the start of graph 0.
torch.split(data.x,32)[0].shape

torch.Size([32, 1])

In [15]:
# One index array per graph in the batch (20 indices each in the output shown).
# Presumably the differentially expressed gene indices for that perturbation; confirm in GEARS.
data.de_idx

[array([ 100,  832, 1524, 1624, 1708, 1729, 1873, 2109, 2324, 2555, 2637,
        3178, 3594, 4058, 4067, 4166, 4409, 4531, 4736, 4898]),
 array([  30,  279,  445,  930, 1899, 2137, 2492, 2653, 2789, 3095, 3211,
        3366, 3392, 3721, 4224, 4445, 4465, 4613, 4737, 4869]),
 array([ 392,  521,  646,  746,  771, 1288, 1322, 1323, 1694, 2362, 2555,
        2651, 2653, 2738, 2770, 2785, 3446, 3877, 4457, 4737]),
 array([ 185,  342,  350,  432,  746, 1702, 2492, 2602, 2738, 2811, 3182,
        3208, 3346, 3367, 3830, 4460, 4661, 4859, 4879, 5031]),
 array([ 585,  815, 1473, 1502, 1662, 1798, 2442, 2738, 2783, 3369, 3538,
        3721, 4022, 4038, 4315, 4447, 4588, 4753, 4970, 4996]),
 array([  42,  148,  332, 1026, 1322, 1323, 1431, 1695, 1696, 1808, 2137,
        2731, 2811, 2941, 3802, 3936, 4149, 4460, 4684, 4970]),
 array([ 185,  342,  350,  432,  746, 1702, 2492, 2602, 2738, 2811, 3182,
        3208, 3346, 3367, 3830, 4460, 4661, 4859, 4879, 5031]),
 array([ 493,  691,  728, 1052, 10

In [17]:
# Condition label for each of the 32 graphs in the batch.
data.pert

['ctrl+BAK1',
 'ctrl+TMSB4X',
 'ctrl+MEIS1',
 'ctrl+MAP2K6',
 'ctrl+ARRDC3',
 'ctrl+CEBPA',
 'ctrl+MAP2K6',
 'ctrl+BCL2L11',
 'ctrl+HOXB9',
 'CEBPE+KLF1',
 'ctrl+FOXL2',
 'ctrl+CEBPA',
 'ctrl+IGDCC3',
 'FEV+MAP7D1',
 'ctrl+CELF2',
 'ctrl+CEBPA',
 'ctrl+CSRNP1',
 'ctrl+OSR2',
 'ctrl+CEBPA',
 'ctrl+C19orf26',
 'ctrl+PTPN1',
 'ctrl+DUSP9',
 'ctrl+FOXO4',
 'ctrl+C3orf72',
 'TBX3+TBX2',
 'ctrl+CEBPE',
 'PTPN12+OSR2',
 'ctrl+MAP7D1',
 'ctrl+SLC4A1',
 'ctrl+CEBPB',
 'ctrl+UBASH3B',
 'ctrl+OSR2']

In [24]:
# Rows per graph in the batch: all 32 graphs have 5054 rows, matching config["num_genes"].
torch.unique(data.batch,return_counts=True)

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]),
 tensor([5054, 5054, 5054, 5054, 5054, 5054, 5054, 5054, 5054, 5054, 5054, 5054,
         5054, 5054, 5054, 5054, 5054, 5054, 5054, 5054, 5054, 5054, 5054, 5054,
         5054, 5054, 5054, 5054, 5054, 5054, 5054, 5054]))

In [24]:
# NOTE(review): `gears_model` is created by the GEARS(...) cell further down; this
# cell only works if that cell ran first, so it fails under Restart & Run All.
gears_model.config["num_genes"]

5054

In [25]:
# NOTE(review): depends on `gears_model` from the GEARS(...) cell further down.
gears_model.config["num_perts"]

9853

In [27]:
# The max value in G_go (9852) equals num_perts - 1, consistent with indices into
# the perturbation set. NOTE(review): depends on `gears_model` defined further down.
gears_model.config["G_go"].max()

tensor(9852)

In [28]:
# The max value in G_coexpress (5053) equals num_genes - 1, consistent with indices into
# the gene set. NOTE(review): depends on `gears_model` defined further down.
gears_model.config["G_coexpress"].max()

tensor(5053)

In [33]:
# Pickle the model config together with the gene and perturbation name lists
# into gears_config.pkl.
# NOTE(review): depends on `gears_model` from the GEARS(...) cell further down.
with open("gears_config.pkl",'wb') as f:
    pkl.dump({
        "config":gears_model.config,
        "gene_names":pert_data.gene_names,
        "pert_names":pert_data.pert_names
    },f)
    

In [23]:
# Build the GEARS wrapper around the prepared data. wandb tracking is off;
# proj_name/exp_name only matter when weight_bias_track=True.
# Falls back to CPU when CUDA is unavailable so the notebook still runs end to end.
gears_model = GEARS(
    pert_data,
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
    weight_bias_track=False,
    proj_name='pertnet',
    exp_name='pertnet',
)
# no_GO=True presumably disables the GO graph (training prints "GO disabled").
# NOTE(review): the captured output shows an ipdb breakpoint inside gears.py
# model_initialize; remove it before running non-interactively.
gears_model.model_initialize(hidden_size=64, no_GO=True)
print(gears_model.tunable_parameters())

Found local copy ./go_essential_all ...


> [0;32m/data/liz0f/GEARS/gears/gears.py[0m(238)[0;36mmodel_initialize[0;34m()[0m
[0;32m    237 [0;31m        [0;32mimport[0m [0mipdb[0m[0;34m;[0m [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 238 [0;31m        [0mself[0m[0;34m.[0m[0mmodel[0m [0;34m=[0m [0mGEARS_Model[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mconfig[0m[0;34m)[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mdevice[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    239 [0;31m        [0mself[0m[0;34m.[0m[0mbest_model[0m [0;34m=[0m [0mdeepcopy[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mmodel[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> !self.config.keys()
dict_keys(['hidden_size', 'num_go_gnn_layers', 'num_gene_gnn_layers', 'decoder_hidden_size', 'num_similar_genes_go_graph', 'num_similar_genes_co_express_graph', 'coexpress_threshold', 'uncertainty', 'uncertainty_reg', 'direction_lambda', 'G_go', 

You can list the tunable parameters accepted by `model_initialize` via `gears_model.tunable_parameters()` (printed by the cell above).

Train your model:

Note: the cell below trains for 20 epochs, the setting for a full model. For a quick demo, set `epochs = 1`.

In [8]:
# Full training run: 20 epochs at lr 1e-3. The saved output shows this run was
# interrupted (KeyboardInterrupt) during epoch 1.
gears_model.train(epochs = 20, lr = 1e-3)

Start Training...
Epoch 1 Step 1 Train Loss: 0.9232


GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled
GO disabled


KeyboardInterrupt: 

Save and load pretrained models:

In [27]:
# Save the current model under a path keyed by the split size, then reload it
# from that same path.
print(f"n_combo = {n_combo}")
gears_model.save_model(f'best_model-{n_combo}')
gears_model.load_pretrained(f'best_model-{n_combo}')

In [29]:
# Evaluate the best model on the test loader and summarize the metrics.
# NOTE(review): with the custom split built by make_custom_split, 'test' is the
# same condition list as 'val'. If GEARS selects best_model on validation (confirm),
# these metrics are not from an independent held-out set.
from gears.inference import evaluate,compute_metrics
test_res = evaluate(gears_model.dataloader['test_loader'], gears_model.best_model,
                    gears_model.config['uncertainty'], gears_model.device)
test_metrics, test_pert_res = compute_metrics(test_res)  
print(test_metrics)

{'mse': 0.011455394, 'mse_de': 0.31053308, 'pearson': 0.9896291897540179, 'pearson_de': 0.8852968530890082}


Make predictions for new perturbations:

Gene list can be found here: