In [None]:
%load_ext autoreload
%autoreload 2
    
import subprocess
import os
import numpy as np
import pandas as pd
import torch
import torch_geometric
import json

import torch.nn.functional as F 
from torch_geometric.utils import negative_sampling
import torch_geometric.transforms as T 

from src.model import *
from src.utilities import *
from src.data import *


In [None]:
# Environment sanity check: active conda env, NVIDIA driver/GPU status, and whether PyTorch sees CUDA.
print(os.environ.get('CONDA_DEFAULT_ENV'))
try:
    nvidia_smi_output = subprocess.check_output(['nvidia-smi']).decode('utf-8')
except (OSError, subprocess.CalledProcessError) as err:
    # nvidia-smi missing (OSError/FileNotFoundError) or exiting non-zero: report it instead of
    # aborting "Restart & Run All" on machines without an NVIDIA driver.
    nvidia_smi_output = f'nvidia-smi unavailable: {err}'
print(nvidia_smi_output)
torch.cuda.is_available()

In [3]:
# --- Paths ---
data_dir = '../data/'
emb_dir = os.path.join(data_dir, 'embedding')   # per-repeat node embeddings are written here
result_dir = os.path.join(data_dir, 'result')   # per-relation evaluation tables are written here

# Names of the node-attribute and edge-list inputs handed to GraphConstruction together with
# data_dir (presumably file names/stems resolved inside src.data — confirm there).
graph_data_dir = {
    'node_attr' : {
        'drug' : 'node_attr_drug',
        'adr' : 'node_attr_adr',
        'disease' : 'node_attr_disease',
        'dp' : 'node_attr_dp',
        'gene' : 'node_attr_gene',
    },
    'edge_index' : {
        'drug_adr' : 'edge_index_drug_adr',
        'gene_drug' : 'edge_index_gene_drug',
        'disease_dp' : 'edge_index_disease_dp',
        'gene_disease' : 'edge_index_gene_disease',
        'gene_gene' : 'edge_index_ppi'
    }
}

# Experiment configuration. The training cell reads model.hidden_channels / patience / epoch /
# lr / weight_decay and evaluation.rep / top_k; the remaining entries are descriptive metadata.
parameters = {
    'Network': {
        'nodes': 'Heterogenious Network',
        'directed': False,
               },
    'model':{
        'encoder': 'GraphConv',
        'decoder': 'cosine',
        'N_layers' : 3,        # fix
        'hidden_channels': 200,
        'objective_Func': 'binary_cross_entropy_with_logits',
        'epoch' : 500,
        'patience' : 5,
        'lr' : 0.001,
        'weight_decay' : 1e-5,
        'iter_neg' : 'Yes'
    },

    'evaluation':{
        'top_k' : np.arange(5,500,10).tolist(),   # 5, 15, ..., 495
        'rep' : 1000                              # number of independent split/train repeats
    }
}

# Prefer the first GPU; fall back to CPU so later `.to(device)` calls still work without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Build the heterogeneous graph from the inputs listed in graph_data_dir, move it to `device`,
# then apply T.ToUndirected so every relation also has a reverse counterpart for message passing.
Undirected = T.ToUndirected()
data = Undirected(
    GraphConstruction(data_dir, graph_data_dir)
    .create_graph()
    .to(device)
)

In [6]:
# Relations that get held-out test edges, paired by position with the reverse relations created by
# T.ToUndirected, so that a held-out edge is removed from message passing in both directions.
held_out_edge_types = [
    ('drug', 'conntectedTo', 'adr'),
    ('gene', 'targetedBy', 'drug'),
    ('disease', 'conntectedTo', 'dp'),
    ('gene', 'associatedWith', 'disease'),
]
held_out_rev_edge_types = [
    ('adr', 'rev_conntectedTo', 'drug'),
    ('drug', 'rev_targetedBy', 'gene'),
    ('dp', 'rev_conntectedTo', 'disease'),
    ('disease', 'rev_associatedWith', 'gene'),
]

# 90% train / 0% val / 10% test. Negatives (1:1) are attached to the evaluation splits only;
# add_negative_train_samples=False means training negatives must be drawn in the training loop.
splitdata = T.RandomLinkSplit(
    num_val=0.0,
    num_test=0.1,
    is_undirected=True,
    neg_sampling_ratio=1,
    add_negative_train_samples=False,
    edge_types=held_out_edge_types,
    rev_edge_types=held_out_rev_edge_types,
)

# PPI edges are never held out: every gene-gene interaction becomes a positive training label.
splitdata_PPI = T.RandomLinkSplit(
    num_val=0.0,
    num_test=0.0,
    is_undirected=False,
    neg_sampling_ratio=1,
    add_negative_train_samples=False,
    edge_types=[('gene', 'interactWith', 'gene')],
)


In [None]:
# --- Hyper-parameters (read from the config cell) ---
hidden_channels = parameters['model']['hidden_channels']
patience_thr = parameters['model']['patience']   # NOTE(review): early stopping is currently disabled
Epoch = parameters['model']['epoch']
lr = parameters['model']['lr']
weight_decay = parameters['model']['weight_decay']
Repeat = parameters['evaluation']['rep']
# NOTE(review): the evaluation call previously used an undefined name `k`; assumed to mean the
# configured top-k cut-offs. Confirm this is what `Evaluation.eval` expects.
top_k = parameters['evaluation']['top_k']

# NOTE(review): `raw_data_dir` is not defined in any visible cell. It must come from a `src.*`
# star import or leftover kernel state; define it explicitly in the config cell.

# Supervised relations: (result key, edge type, decoder `type` id passed to model.decoder).
SUPERVISED_RELATIONS = [
    ('drug_adr', ('drug', 'conntectedTo', 'adr'), 1),
    ('gene_drug', ('gene', 'targetedBy', 'drug'), 2),
    ('disease_dp', ('disease', 'conntectedTo', 'dp'), 3),
    ('gene_disease', ('gene', 'associatedWith', 'disease'), 4),
]
PPI_EDGE_TYPE = ('gene', 'interactWith', 'gene')
PPI_DECODER_TYPE = 5
NEG_RESAMPLE_EVERY = 10   # epochs between fresh draws of training negatives


def _num_labels(store):
    """Number of supervision edges in an edge store (0 when the split produced none)."""
    labels = getattr(store, 'edge_label', None)
    return 0 if labels is None else labels.numel()


def sample_negatives(split):
    """Draw, for each supervised relation, as many random non-edges as it has positive labels.

    `negative_sampling` excludes pairs present in the split's message-passing `edge_index`;
    bipartite relations pass their node counts as a (source, destination) tuple.
    Returns {result key: negative edge_label_index}.
    """
    negatives = {}
    for name, edge_type, _ in SUPERVISED_RELATIONS:
        src_type, _, dst_type = edge_type
        store = split[edge_type]
        negatives[name] = negative_sampling(
            store.edge_index,
            num_nodes=(split.num_nodes_dict[src_type], split.num_nodes_dict[dst_type]),
            num_neg_samples=store.edge_label.shape[0],
        )
    return negatives


def training_loss(model, split, split_ppi, negatives):
    """Summed BCE-with-logits over every supervised relation (positives + sampled negatives) plus PPI."""
    z = model.forward_encoder(split.x_dict, split.edge_index_dict, raw_data_dir, graph_data_dir)
    total = 0
    for name, edge_type, decoder_type in SUPERVISED_RELATIONS:
        store = split[edge_type]
        neg_index = negatives[name]
        logits = torch.cat([
            model.decoder(z, store.edge_label_index, type=decoder_type),
            model.decoder(z, neg_index, type=decoder_type),
        ], dim=0)
        targets = torch.cat([
            store.edge_label,
            torch.zeros(neg_index.size(1), dtype=torch.float, device=device),
        ], dim=0)
        total = total + F.binary_cross_entropy_with_logits(logits, targets)

    # NOTE(review): splitdata_PPI uses add_negative_train_samples=False and no PPI negatives are
    # sampled here, so presumably every PPI target is 1 — consider adding PPI negatives.
    ppi = split_ppi[PPI_EDGE_TYPE]
    ppi_logits = model.decoder(z, ppi.edge_label_index, type=PPI_DECODER_TYPE)
    return total + F.binary_cross_entropy_with_logits(ppi_logits, ppi.edge_label)


@torch.no_grad()
def evaluate_split(model, split):
    """Score the supervised relations of an evaluation split without building autograd graphs.

    Returns ({result key: (targets, logits)}, summed BCE loss as a float). Relations without
    supervision edges (e.g. the empty split produced by num_val=0.0) are skipped instead of
    yielding a NaN mean over an empty tensor; the loss is NaN only if every relation is empty.
    """
    model.eval()
    z = model.forward_encoder(split.x_dict, split.edge_index_dict, raw_data_dir, graph_data_dir)
    scored = {}
    total = 0.0
    for name, edge_type, decoder_type in SUPERVISED_RELATIONS:
        store = split[edge_type]
        if _num_labels(store) == 0:
            continue
        logits = model.decoder(z, store.edge_label_index, type=decoder_type)
        scored[name] = (store.edge_label, logits)
        total += F.binary_cross_entropy_with_logits(logits, store.edge_label).item()
    return scored, (total if scored else float('nan'))


os.makedirs(emb_dir, exist_ok=True)   # np.savetxt does not create missing directories

results = {name: pd.DataFrame() for name, _, _ in SUPERVISED_RELATIONS}
res_drug_adr, res_gene_drug, res_disease_dp, res_gene_disease = (
    results[name] for name, _, _ in SUPERVISED_RELATIONS)
losses_train = []
losses_val = []

for i in range(Repeat):
    print(f'repeat:{i}')

    # Fresh random held-out split for this repeat; the PPI split only attaches positive labels.
    train_data, val_data, test_data = splitdata(data)
    train_data_PPI, _, _ = splitdata_PPI(data)

    model = Model(data.num_nodes_dict, hidden_channels, data).to(device)
    # Gradient-free forward pass before building the optimizer (presumably to materialise
    # lazily-initialised parameters — TODO confirm in src.model).
    with torch.no_grad():
        model.forward_encoder(train_data.x_dict, train_data.edge_index_dict, raw_data_dir, graph_data_dir)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    losses_train.clear()
    losses_val.clear()
    has_val = any(_num_labels(val_data[edge_type]) > 0 for _, edge_type, _ in SUPERVISED_RELATIONS)

    for epoch in range(Epoch):
        if epoch % NEG_RESAMPLE_EVERY == 0:
            negatives = sample_negatives(train_data)

        model.train()
        optimizer.zero_grad()
        loss_total_train = training_loss(model, train_data, train_data_PPI, negatives)
        loss_total_train.backward()
        optimizer.step()
        losses_train.append(loss_total_train.item())

        # Validation loss is only tracked when the split actually holds out validation edges
        # (num_val=0.0 yields none). Early stopping on patience_thr stays disabled, as before.
        if has_val:
            _, loss_total_val = evaluate_split(model, val_data)
            losses_val.append(round(loss_total_val, 3))

    # Score the held-out test edges once with the final weights (previously recomputed every
    # epoch, although only the last epoch's scores were used).
    scored_test, loss_total_test = evaluate_split(model, test_data)
    for name, (targets, logits) in scored_test.items():
        results[name] = pd.concat(
            [results[name], Evaluation(targets, logits).eval(top_k)], ignore_index=True)
    # Refresh the public result frames every repeat so partial results survive an interrupted run.
    res_drug_adr, res_gene_drug, res_disease_dp, res_gene_disease = (
        results[name] for name, _, _ in SUPERVISED_RELATIONS)

    # Node embeddings from the encoder on the full graph, saved per node type for this repeat.
    model.eval()
    with torch.no_grad():
        z = model.forward_encoder(data.x_dict, data.edge_index_dict, raw_data_dir, graph_data_dir)
    embeddings = {node_type: z[node_type].detach().cpu().numpy()
                  for node_type in ('adr', 'dp', 'drug', 'disease', 'gene')}
    emb = np.concatenate(list(embeddings.values()), axis=0)

    for node_type, node_emb in embeddings.items():
        np.savetxt(os.path.join(emb_dir, f'emb_{node_type}_rep{i + 1}.csv'), node_emb, delimiter=',')


In [None]:
# Write one CSV per supervised relation holding the per-repeat evaluation rows.
# NOTE(review): the 'RGCN_' prefix does not match parameters['model']['encoder'] ('GraphConv');
# the names are kept unchanged for compatibility with downstream readers.
os.makedirs(result_dir, exist_ok=True)   # DataFrame.to_csv does not create missing directories
for res_frame, res_stem in (
    (res_drug_adr, 'RGCN_res_adr_drug'),
    (res_gene_drug, 'RGCN_res_drug_gene'),
    (res_disease_dp, 'RGCN_res_dp_disease'),
    (res_gene_disease, 'RGCN_res_gene_disease'),
):
    res_frame.to_csv(os.path.join(result_dir, res_stem + '.csv'), sep=',', index=False)