In [None]:
import torch 
import pandas as pd 
import torch_geometric as pyg 
import numpy as np 
import os 
from tkgdti.data.GraphBuilder import GraphBuilder
from sklearn.model_selection import KFold

%load_ext autoreload
%autoreload 2

# seed 
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
root = '../../extdata/relations/'
relnames = os.listdir(root)
N_dtis = pd.read_csv('../../extdata/relations/targetome_drug_targets_gene.csv').shape[0]
K=10
val_prop = 0.075 

NO_REV = False
if NO_REV: 
    relnames = [r for r in relnames if '_rev' not in r]


# relations to include: 53


In [26]:

relnames2 = [] 
for rname in relnames: 
    rdf = pd.read_csv(root + '/' + rname)
    if rdf.shape[0] > 0: 
        relnames2.append(rname)
relnames = relnames2 

print('# of relations types:', len(relnames))

# of relations types: 49


In [None]:

for fold, (train_idxs, test_idxs) in enumerate(KFold(n_splits=K, random_state=seed, shuffle=True).split(range(N_dtis))):
    
    val_idxs = np.random.choice(train_idxs, int(val_prop*N_dtis), replace=False)
    train_idxs = np.array([i for i in train_idxs if i not in val_idxs]) # unnecessary but just to be sure

    assert len(set(train_idxs).intersection(set(val_idxs))) == 0, 'train and val overlap'
    assert len(set(train_idxs).intersection(set(test_idxs))) == 0, 'train and test overlap'
    assert len(set(val_idxs).intersection(set(test_idxs))) == 0, 'val and test overlap'

    print('init...')
    GB = GraphBuilder(root=root, relnames=relnames, val_idxs=val_idxs, test_idxs=test_idxs)
    print('building...')
    GB.build() 
    print('generating triples...')
    train, valid, test, data = GB.get_triples() 

    os.makedirs('../../data/tkg/processed/', exist_ok=True)
    os.makedirs(f'../../data/tkg/processed/FOLD_{fold}/', exist_ok=True)
    torch.save(train, f'../../data/tkg/processed/FOLD_{fold}/pos_train.pt')
    torch.save(valid, f'../../data/tkg/processed/FOLD_{fold}/pos_valid.pt')
    torch.save(test, f'../../data/tkg/processed/FOLD_{fold}/pos_test.pt')

    torch.save(None, f'../../data/tkg/processed/FOLD_{fold}/neg_train.pt')
    torch.save(None, f'../../data/tkg/processed/FOLD_{fold}/neg_valid.pt')
    torch.save(None, f'../../data/tkg/processed/FOLD_{fold}/neg_test.pt')

    torch.save(data, f'../../data/tkg/processed/FOLD_{fold}/Data.pt')

    print(f'Fold {fold} -> # train: {len(train_idxs)}, # val: {len(val_idxs)}, # test: {len(test_idxs)}')
    print()

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 0 -> # train: 358, # val: 21, # test: 43

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 1 -> # train: 358, # val: 21, # test: 43

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 2 -> # train: 359, # val: 21, # test: 42

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 3 -> # train: 359, # val: 21, # test: 42

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 4 -> # train: 359, # val: 21, # test: 42

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 5 -> # train: 359, # val: 21, # test: 42

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building..

In [4]:
GB.relations.head()

Unnamed: 0,src,dst,src_type,dst_type,relation,edge_type,src_idx,dst_idx
0,2157,DNMT3A,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,,73,3193
1,2606,SYCP2,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,,313,12353
2,2034,KRAS,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,,16,6077
3,2034,RAD21,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,,16,10397
4,2469,MYBPC3,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,,241,7894


In [5]:
GB.relations.relation.unique()

array(['mut_missense_variant_deleterious_fwd',
       'mut_frameshift_variant__fwd', 'PHH_lincs_perturbation_rev',
       'HT29_lincs_perturbation_rev', 'PC3_lincs_perturbation_fwd', nan,
       'A375_lincs_perturbation_fwd', 'predicted_weak_binding_rev',
       'high_expr_rev', 'MCF7_lincs_perturbation_rev',
       'A549_lincs_perturbation_rev', 'predicted_weak_binding_fwd',
       'HA1E_lincs_perturbation_rev', 'PC3_lincs_perturbation_rev',
       'mut_missense_variant_tolerated_rev', 'isin_rev',
       'ASC_lincs_perturbation_fwd', 'ASC_lincs_perturbation_rev',
       'low_expr_rev', 'HT29_lincs_perturbation_fwd',
       'predicted_conf_weak_binding_fwd', 'associates_rev',
       'predicted_conf_weak_binding_rev', 'mut_stop_gained__rev',
       'resistant_to_rev', 'A549_lincs_perturbation_fwd',
       'associates_fwd', 'protbert_similarity', 'resistant_to_fwd',
       'predicted_strong_binding_rev', 'VCAP_lincs_perturbation_fwd',
       'chemberta_cosine_similarity', 'mut_stop_gaine

In [6]:
data['num_nodes_dict']

{'dbgap_subject': 470,
 'disease': 4095,
 'drug': 89,
 'gene': 14304,
 'pathway': 2363}

In [7]:
for k,v in data['node_name_dict'].items(): 
    print(k, len(v))

dbgap_subject 470
disease 4095
drug 89
gene 14304
pathway 2363


In [None]:
valid

{'head_type': array(['drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug', 'drug',
        'drug', 'drug', 'drug', 'drug', 'drug'], dtype='<U4'),
 'head': array([35, 50, 46, 13, 63, 62, 19, 51, 39, 72, 55, 65, 81, 17, 20,  9, 23,
        69, 20, 77, 75]),
 'relation': array([36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36,
        36, 36, 36, 36]),
 'tail_type': array(['gene', 'gene', 'gene', 'gene', 'gene', 'gene', 'gene', 'gene',
        'gene', 'gene', 'gene', 'gene', 'gene', 'gene', 'gene', 'gene',
        'gene', 'gene', 'gene', 'gene', 'gene'], dtype='<U4'),
 'tail': array([ 1062,  5049,  1875, 10416,  7842,  5908,  1062,  3623,  6690,
        10587,  2131,    81,  5990,  4040, 10152,  4040,  5734,  9326,
        10122,   941,  3623])}

In [None]:
# report graph metrics

print('num nodes:', sum([v for k,v in data['num_nodes_dict'].items()]))
print('num edges:', sum([v.size(1) for k,v in data['edge_index_dict'].items()]))
print('num node types:', len(data['node_name_dict']))
print('num relations:', len(data['edge_index_dict']))


num nodes: 21321
num edges: 1112377
num node types: 5
num relations: 51
