In [1]:
import torch 
import pandas as pd 
import torch_geometric as pyg 
import numpy as np 
import os 
from tkgdti.data.GraphBuilder import GraphBuilder
from sklearn.model_selection import KFold

%load_ext autoreload
%autoreload 2

# seed 
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f021f286650>

In [2]:
root = '../../extdata/relations/'
OUT = '../../data/tkge_no_patient/'
relnames = os.listdir(root)
N_dtis = pd.read_csv('../../extdata/relations/targetome_drug_targets_gene.csv').shape[0]
K=10
val_prop = 0.075 
NO_REV = False
EXCLUDE_PATIENT_RELATIONS = True

# NOTE (2/28/25): the ablation study suggests that these relations are either detrimental or not beneficial to model performance.
DO_NOT_INCLUDE_RELATIONS=['gene->mut_missense_variant_deleterious_rev->dbgap_subject',
                          'gene->isin_fwd->pathway',
                          'gene->associates_fwd->disease',
                          'gene->A549_lincs_perturbation_rev->drug']


In [3]:
if EXCLUDE_PATIENT_RELATIONS: 
    DO_NOT_INCLUDE_RELATIONS += [x for x in relnames if 'beataml_' in x]

print( DO_NOT_INCLUDE_RELATIONS )

['gene->mut_missense_variant_deleterious_rev->dbgap_subject', 'gene->isin_fwd->pathway', 'gene->associates_fwd->disease', 'gene->A549_lincs_perturbation_rev->drug', 'beataml_mutation_stop_gained_rel_rev.csv', 'beataml_mutation_missense_variant_deleterious_rel_fwd.csv', 'beataml_mutation_frameshift_variant_rel_fwd.csv', 'beataml_mutation_missense_variant_tolerated_rel_rev.csv', 'beataml_res_rel_rev.csv', 'beataml_res_rel_fwd.csv', 'beataml_mutation_frameshift_variant_rel_rev.csv', 'beataml_sens_rel_fwd.csv', 'beataml_sens_rel_rev.csv', 'beataml_mutation_missense_variant_deleterious_rel_rev.csv', 'beataml_mutation_missense_variant_tolerated_rel_fwd.csv', 'beataml_mutation_stop_gained_rel_fwd.csv']


In [4]:
if NO_REV: 
    relnames = [r for r in relnames if '_rev' not in r]

In [None]:
excluded = 0 
relnames2 = [] 
for rname in relnames: 
    rdf = pd.read_csv(root + '/' + rname)
    if rdf.shape[0] == 0: 
        print('no relations (consider excluded):', rname)
        excluded += 1
        continue
    
    rel_type = f'{rdf.src_type.values[0]}->{rdf.relation.values[0]}->{rdf.dst_type.values[0]}'
    if rel_type in DO_NOT_INCLUDE_RELATIONS: 
        print('excluding:', rname)
        excluded += 1
        continue

    if (rdf.shape[0] > 0): 
        print('including:', rname)
        relnames2.append(rname)
relnames = relnames2 

print('# excluded:', excluded)
#assert excluded == len(DO_NOT_INCLUDE_RELATIONS), 'might be an issue/bug in the exclusion of relations'

print('# of relations types:', len(relnames))

no relations (consider excluded): beataml_mutation_stop_gained_rel_rev.csv
no relations (consider excluded): beataml_mutation_frameshift_variant_rel_fwd.csv
excluding: ctd_genes_diseases_fwd.csv
excluding: A549_lincs_perturbation_rev.csv
no relations (consider excluded): beataml_mutation_frameshift_variant_rel_rev.csv
excluding: ctd_gene_isin_pathway_fwd.csv
excluding: beataml_mutation_missense_variant_deleterious_rel_rev.csv
no relations (consider excluded): beataml_mutation_stop_gained_rel_fwd.csv
# excluded: 8
# of relations types: 39


In [None]:

for fold, (train_idxs, test_idxs) in enumerate(KFold(n_splits=K, random_state=seed, shuffle=True).split(range(N_dtis))):
    
    val_idxs = np.random.choice(train_idxs, int(val_prop*N_dtis), replace=False)
    train_idxs = np.array([i for i in train_idxs if i not in val_idxs]) # unnecessary but just to be sure

    assert len(set(train_idxs).intersection(set(val_idxs))) == 0, 'train and val overlap'
    assert len(set(train_idxs).intersection(set(test_idxs))) == 0, 'train and test overlap'
    assert len(set(val_idxs).intersection(set(test_idxs))) == 0, 'val and test overlap'

    print('init...')
    GB = GraphBuilder(root=root, relnames=relnames, val_idxs=val_idxs, test_idxs=test_idxs)
    print('building...')
    GB.build() 
    print('generating triples...')
    train, valid, test, data = GB.get_triples() 

    os.makedirs(f'{OUT}/processed/', exist_ok=True)
    os.makedirs(f'{OUT}/processed/FOLD_{fold}/', exist_ok=True)
    torch.save(train, f'{OUT}/processed/FOLD_{fold}/pos_train.pt')
    torch.save(valid, f'{OUT}/processed/FOLD_{fold}/pos_valid.pt')
    torch.save(test, f'{OUT}/processed/FOLD_{fold}/pos_test.pt')

    # this is deprecated, but may cause issues if not saved. TODO: remove and ensure that the code does not rely on these files
    torch.save(None, f'{OUT}/processed/FOLD_{fold}/neg_train.pt')
    torch.save(None, f'{OUT}/processed/FOLD_{fold}/neg_valid.pt')
    torch.save(None, f'{OUT}/processed/FOLD_{fold}/neg_test.pt')

    torch.save(data, f'{OUT}/processed/FOLD_{fold}/Data.pt')

    print(f'Fold {fold} -> # train: {len(train_idxs)}, # val: {len(val_idxs)}, # test: {len(test_idxs)}')
    print()

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 0 -> # train: 2666, # val: 242, # test: 324

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 1 -> # train: 2666, # val: 242, # test: 324

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 2 -> # train: 2667, # val: 242, # test: 323

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 3 -> # train: 2667, # val: 242, # test: 323

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 4 -> # train: 2667, # val: 242, # test: 323

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...
generating triples...
Fold 5 -> # train: 2667, # val: 242, # test: 323

init...
Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pa

In [None]:
GB.relations.head()

Unnamed: 0,src,dst,src_type,dst_type,relation,src_idx,dst_idx
0,2157,DNMT3A,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,73,3203
1,2606,SYCP2,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,313,12380
2,2034,KRAS,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,16,6088
3,2034,RAD21,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,16,10415
4,2469,MYBPC3,dbgap_subject,gene,mut_missense_variant_deleterious_fwd,241,7906


In [None]:
GB.relations.relation.unique()

array(['mut_missense_variant_deleterious_fwd',
       'PHH_lincs_perturbation_rev', 'HT29_lincs_perturbation_rev',
       'PC3_lincs_perturbation_fwd', 'associates_fwd',
       'A375_lincs_perturbation_fwd', 'predicted_weak_binding_rev',
       'high_expr_rev', 'MCF7_lincs_perturbation_rev',
       'A549_lincs_perturbation_rev', 'predicted_weak_binding_fwd',
       'HA1E_lincs_perturbation_rev', 'PC3_lincs_perturbation_rev',
       'mut_missense_variant_tolerated_rev', 'isin_rev',
       'ASC_lincs_perturbation_fwd', 'ASC_lincs_perturbation_rev',
       'low_expr_rev', 'HT29_lincs_perturbation_fwd',
       'predicted_conf_weak_binding_fwd', 'associates_rev',
       'predicted_conf_weak_binding_rev', 'resistant_to_rev',
       'A549_lincs_perturbation_fwd', 'protbert_similarity',
       'resistant_to_fwd', 'predicted_strong_binding_rev',
       'VCAP_lincs_perturbation_fwd', 'chemberta_cosine_similarity',
       'sensitive_to_fwd', 'sensitive_to_rev', 'low_expr_fwd', 'isin_fwd',
       

In [None]:
data['num_nodes_dict']

{'dbgap_subject': 470,
 'disease': 4095,
 'drug': 89,
 'gene': 14334,
 'pathway': 2363}

In [None]:
for k,v in data['node_name_dict'].items(): 
    print(k, len(v))

dbgap_subject 470
disease 4095
drug 89
gene 14334
pathway 2363


In [None]:
# report graph metrics

print('num nodes:', sum([v for k,v in data['num_nodes_dict'].items()]))
print('num edges:', sum([v.size(1) for k,v in data['edge_index_dict'].items()]))
print('num node types:', len(data['node_name_dict']))
print('num relations:', len(data['edge_index_dict']))


num nodes: 21351
num edges: 1179913
num node types: 5
num relations: 49
