In [11]:
import torch 
import pandas as pd 
import torch_geometric as pyg 
import numpy as np 
import os 
from tkgdti.data.GraphBuilder import GraphBuilder
from sklearn.model_selection import KFold

%load_ext autoreload
%autoreload 2

# seed 
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


<torch._C.Generator at 0x7fad778ba650>

In [12]:
root = '../../extdata/relations/'
relnames = os.listdir(root)
N_dtis = pd.read_csv('../../extdata/relations/targetome_drug_targets_gene.csv').shape[0]
K=10
val_prop = 0.075 

NO_REV = False
if NO_REV: 
    relnames = [r for r in relnames if '_rev' not in r]


In [13]:

relnames2 = [] 
for rname in relnames: 
    rdf = pd.read_csv(root + '/' + rname)
    if rdf.shape[0] > 0: 
        relnames2.append(rname)
relnames = relnames2 

print('# of relations types:', len(relnames))

# of relations types: 49


In [14]:

for fold, (train_idxs, test_idxs) in enumerate(KFold(n_splits=K, random_state=seed, shuffle=True).split(range(N_dtis))):
    
    val_idxs = np.random.choice(train_idxs, int(val_prop*N_dtis), replace=False)
    train_idxs = np.array([i for i in train_idxs if i not in val_idxs]) # unnecessary but just to be sure

    assert len(set(train_idxs).intersection(set(val_idxs))) == 0, 'train and val overlap'
    assert len(set(train_idxs).intersection(set(test_idxs))) == 0, 'train and test overlap'
    assert len(set(val_idxs).intersection(set(test_idxs))) == 0, 'val and test overlap'

    print('init...')
    GB = GraphBuilder(root=root, relnames=relnames, val_idxs=val_idxs, test_idxs=test_idxs)
    print('building...')
    GB.build() 
    print('generating triples...')
    train, valid, test, data = GB.get_triples() 

    os.makedirs('../../data/tkg/processed/', exist_ok=True)
    os.makedirs(f'../../data/tkg/processed/FOLD_{fold}/', exist_ok=True)
    torch.save(train, f'../../data/tkg/processed/FOLD_{fold}/pos_train.pt')
    torch.save(valid, f'../../data/tkg/processed/FOLD_{fold}/pos_valid.pt')
    torch.save(test, f'../../data/tkg/processed/FOLD_{fold}/pos_test.pt')

    torch.save(None, f'../../data/tkg/processed/FOLD_{fold}/neg_train.pt')
    torch.save(None, f'../../data/tkg/processed/FOLD_{fold}/neg_valid.pt')
    torch.save(None, f'../../data/tkg/processed/FOLD_{fold}/neg_test.pt')

    torch.save(data, f'../../data/tkg/processed/FOLD_{fold}/Data.pt')

    print(f'Fold {fold} -> # train: {len(train_idxs)}, # val: {len(val_idxs)}, # test: {len(test_idxs)}')
    print()

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 0 -> # train: 2666, # val: 242, # test: 324

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 1 -> # train: 2666, # val: 242, # test: 324

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 2 -> # train: 2667, # val: 242, # test: 323

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 3 -> # train: 2667, # val: 242, # test: 323

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 4 -> # train: 2667, # val: 242, # test: 323

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 5 -> # train: 2667, # val: 242, # test: 323

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pa

In [15]:
GB.relations.head()

Unnamed: 0,src,dst,src_type,dst_type,relation,src_idx,dst_idx
0,2157,DNMT3A,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,73,3203
1,2606,SYCP2,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,313,12380
2,2034,KRAS,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,16,6088
3,2034,RAD21,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,16,10415
4,2469,MYBPC3,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,241,7906


In [16]:
GB.relations.relation.unique()

array(['mut_missense_variant_deleterious_fwd',
       'PHH_lincs_perturbation_rev', 'HT29_lincs_perturbation_rev',
       'PC3_lincs_perturbation_fwd', 'associates_fwd',
       'A375_lincs_perturbation_fwd', 'predicted_weak_binding_rev',
       'high_expr_rev', 'MCF7_lincs_perturbation_rev',
       'A549_lincs_perturbation_rev', 'predicted_weak_binding_fwd',
       'HA1E_lincs_perturbation_rev', 'PC3_lincs_perturbation_rev',
       'mut_missense_variant_tolerated_rev', 'isin_rev',
       'ASC_lincs_perturbation_fwd', 'ASC_lincs_perturbation_rev',
       'low_expr_rev', 'HT29_lincs_perturbation_fwd',
       'predicted_conf_weak_binding_fwd', 'associates_rev',
       'predicted_conf_weak_binding_rev', 'resistant_to_rev',
       'A549_lincs_perturbation_fwd', 'protbert_similarity',
       'resistant_to_fwd', 'predicted_strong_binding_rev',
       'VCAP_lincs_perturbation_fwd', 'chemberta_cosine_similarity',
       'sensitive_to_fwd', 'sensitive_to_rev', 'low_expr_fwd', 'isin_fwd',
       

In [17]:
data['num_nodes_dict']

{'dbgap_subject': 470,
 'disease': 4095,
 'drug': 89,
 'gene': 14334,
 'pathway': 2363}

In [18]:
for k,v in data['node_name_dict'].items(): 
    print(k, len(v))

dbgap_subject 470
disease 4095
drug 89
gene 14334
pathway 2363


In [19]:
valid

{'head_type': array(['drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'dr

In [20]:
# report graph metrics

print('num nodes:', sum([v for k,v in data['num_nodes_dict'].items()]))
print('num edges:', sum([v.size(1) for k,v in data['edge_index_dict'].items()]))
print('num node types:', len(data['node_name_dict']))
print('num relations:', len(data['edge_index_dict']))


num nodes: 21351
num edges: 1179913
num node types: 5
num relations: 49
