In [None]:
import subprocess
import os

import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
import torch_geometric.transforms as T 
import torch.nn.functional as F 
from sklearn.metrics import roc_auc_score

import json
import networkx as nx
from node2vec import Node2Vec    # https://github.com/eliorc/node2vec 
from sklearn.metrics.pairwise import cosine_similarity

from src.utilities import *

In [18]:
# Paths: inputs are read from data_dir; node2vec embeddings and evaluation
# tables are written to the two sub-directories below it.
data_dir = '../data/'
emb_dir, result_dir = (
    os.path.join(data_dir, sub_dir)
    for sub_dir in ('embedding/emb_Node2vec', 'result')
)

# Experiment configuration.
parameters = dict(
    model=dict(
        encoder='node2vec_encoder',   # NOTE(review): not read anywhere in this notebook
        decoder='node2vec_decoder',   # NOTE(review): not read anywhere in this notebook
        hidden_channels=15,           # embedding size, passed as Node2Vec(dimensions=...)
        walk_length=100,
        walks_per_node=100,           # passed as Node2Vec(num_walks=...)
        workers=30,
        window=5,
        min_count=1,
        batch_words=4,
    ),
    evaluation=dict(
        top_k=list(range(5, 500, 10)),  # cut-offs 5, 15, ..., 495
        rep=1000,                       # number of split/train/evaluate repetitions
        n_cpu=1,
    ),
)

device = torch.device('cpu')

In [29]:
# Load the edge list and convert it to a (num_columns, num_rows) int64 tensor.
# NOTE(review): the CSV presumably holds one edge per row in two columns
# (endpoint ids), giving PyG's (2, num_edges) layout — TODO confirm.
edge_index = torch.tensor(
    pd.read_csv(os.path.join(data_dir, 'edge_index.csv'), sep=',')
    .to_numpy()
    .astype(np.int64)
    .T,
    dtype=torch.long,
)

In [30]:
# Mapping between node ids ('Nodes') and their namespaced names ('Names').
conversion_table = pd.read_csv(os.path.join(data_dir, 'conversionTable.csv'))
# Number of rows whose name contains the literal text 'entrez.'.
# regex=False keeps the '.' literal; with pandas' default regex=True it would
# match any character after 'entrez'.
N_genes = conversion_table['Names'].str.contains('entrez.', regex=False).sum()

In [None]:
# Data Construction
# Full graph over node ids 0..max(edge_index); each node's feature is simply
# its own integer id.
max_node_index = edge_index.max()
num_nodes = max_node_index + 1          # 0-dim tensor
x = torch.arange(num_nodes)
data = Data(edge_index=edge_index, x=x)

In [32]:
# Inclusive node-id range [first, last] of each node type, selected by the
# namespace text in the conversion table's 'Names' column.
# NOTE(review): the edge-type masks built from these ranges assume every
# type's ids are contiguous — TODO confirm against conversionTable.csv.

def _node_id_range(prefix):
    """Return (min, max) of 'Nodes' over conversion-table rows whose 'Names' contain ``prefix``.

    The match is a literal substring test (regex=False): with pandas' default
    regex=True the '.' in e.g. 'hpo.' would match any character.
    """
    mask = conversion_table['Names'].str.contains(prefix, regex=False)
    nodes = conversion_table.loc[mask, 'Nodes']
    return nodes.min(), nodes.max()


first_I_adr, last_I_adr = _node_id_range('meddra.')
first_I_drug, last_I_drug = _node_id_range('drugbank.')
first_I_dp, last_I_dp = _node_id_range('hpo.')
first_I_disease, last_I_disease = _node_id_range('mondo.')
first_I_gene, last_I_gene = _node_id_range('entrez.')


In [None]:
# Boolean masks over the columns of edge_index, one per relation type.
# Each mask accepts an edge in either direction (A->B or B->A).

def _in_range(ids, lo, hi):
    """Elementwise test lo <= ids <= hi."""
    return (ids >= lo) & (ids <= hi)


def _edge_type_mask(lo_a, hi_a, lo_b, hi_b):
    """Mask of edges joining an id in [lo_a, hi_a] with an id in [lo_b, hi_b], in either direction."""
    src, dst = edge_index[0], edge_index[1]
    forward = _in_range(src, lo_a, hi_a) & _in_range(dst, lo_b, hi_b)
    backward = _in_range(src, lo_b, hi_b) & _in_range(dst, lo_a, hi_a)
    return forward | backward


flag_drug_adr = _edge_type_mask(first_I_drug, last_I_drug, first_I_adr, last_I_adr)
flag_drug_gene = _edge_type_mask(first_I_drug, last_I_drug, first_I_gene, last_I_gene)
flag_Disease_dp = _edge_type_mask(first_I_disease, last_I_disease, first_I_dp, last_I_dp)
flag_gene_Disease = _edge_type_mask(first_I_disease, last_I_disease, first_I_gene, last_I_gene)
# Both endpoints are genes, so the two directions coincide.
flag_gene_gene = _edge_type_mask(first_I_gene, last_I_gene, first_I_gene, last_I_gene)

edge_index.shape, flag_drug_adr.shape, flag_drug_gene.shape, flag_Disease_dp.shape, flag_gene_Disease.shape, flag_gene_gene.shape


In [34]:
# Split the global edge list by relation type and wrap each subset in its own
# Data object; every subset shares the full node tensor x.
(edge_index_drug_adr, edge_index_drug_gene, edge_index_disease_dp,
 edge_index_gene_disease, edge_index_gene_gene) = (
    edge_index[:, mask]
    for mask in (flag_drug_adr, flag_drug_gene, flag_Disease_dp, flag_gene_Disease, flag_gene_gene)
)

(data_drug_adr, data_drug_gene, data_disease_dp,
 data_gene_disease, data_gene_gene) = (
    Data(x=x, edge_index=subset)
    for subset in (edge_index_drug_adr, edge_index_drug_gene, edge_index_disease_dp,
                   edge_index_gene_disease, edge_index_gene_gene)
)


In [35]:
# Parameters
# Number of independent split/train/evaluate repetitions, and total node count.
Repeat = parameters['evaluation']['rep']
num_nodes = data.x.size(0)

In [13]:
# Initial values
# One results table per evaluated edge type; each is extended on every repetition.
res_drug_adr, res_gene_drug, res_disease_dp, res_gene_disease = (
    pd.DataFrame() for _ in range(4)
)

# Loss histories (six independent empty lists).
# NOTE(review): none of these lists is appended to in this notebook.
(loss_train,
 loss_test_drug_adr_all,
 loss_test_drug_gene_all,
 loss_test_disease_dp_all,
 loss_test_gene_disease_all,
 loss_test_total) = ([] for _ in range(6))


In [14]:
# Link split applied to each evaluated edge type: 10% of the (undirected)
# edges are held out for testing, no validation split, one sampled negative
# per held-out positive, and no negatives added to the training edges.
splitdata = T.RandomLinkSplit(
    num_val=0.0,
    num_test=0.1,
    is_undirected=True,
    add_negative_train_samples=False,
    neg_sampling_ratio=1,
)


In [15]:
def cosine_sim(edge_label_index, model, rng=None):
    """Score candidate links by the cosine similarity of their node embeddings.

    Parameters
    ----------
    edge_label_index : torch.Tensor
        Shape (2, num_links); column j holds the integer ids of the two end
        nodes of link j.
    model :
        Fitted embedding model; vectors are looked up as
        ``model.wv[str(node_id)]``.
    rng : numpy.random.Generator, optional
        Source of the fallback score used when an end node has no embedding.
        Defaults to the global ``np.random`` state, as before.

    Returns
    -------
    numpy.ndarray
        float64 array of length num_links with values (up to rounding) in [-1, 1].

    Raises
    ------
    Any error other than a missing-vocabulary ``KeyError`` (e.g. a model
    without ``wv``) propagates instead of being silently replaced by random
    scores.
    """
    similarity = np.zeros(edge_label_index.shape[1])
    for i, (node1, node2) in enumerate(edge_label_index.T):
        try:
            vec1 = np.asarray(model.wv[str(node1.item())], dtype=np.float64)
            vec2 = np.asarray(model.wv[str(node2.item())], dtype=np.float64)
        except KeyError:
            # No vector learned for this node (presumably it had no training
            # edges): fall back to a uniform random score in [-1, 1).
            similarity[i] = rng.uniform(-1.0, 1.0) if rng is not None else np.random.rand() * 2 - 1
            continue
        denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        # A zero vector yields similarity 0, matching sklearn's cosine_similarity.
        similarity[i] = float(vec1 @ vec2) / denom if denom > 0 else 0.0

    return similarity

In [None]:
# Main experiment. Each repetition:
#   1. randomly splits every evaluated edge type into train/test links;
#   2. trains node2vec on all training links plus every gene-gene edge
#      (gene-gene edges are never held out);
#   3. scores the held-out links by embedding cosine similarity;
#   4. evaluates each edge type separately;
#   5. saves the embeddings.

# Cut-offs handed to Evaluation.eval. The loop previously passed `k`, a name
# never defined in this notebook. Presumably the configured top_k list was
# intended — confirm against src.utilities.Evaluation.
top_k = parameters['evaluation']['top_k']
# DataFrame.to_csv does not create missing directories.
os.makedirs(emb_dir, exist_ok=True)

# Evaluated edge types; list position + 1 is the integer type id used below.
eval_edge_data = [data_drug_adr, data_drug_gene, data_disease_dp, data_gene_disease]


def _evaluate_edge_type(type_id, edge_type_flag, target, pred, cutoffs):
    """Run Evaluation on the test links whose type flag equals ``type_id``."""
    mask = edge_type_flag == type_id
    return Evaluation(target[mask], pred[mask]).eval(cutoffs)


for i in range(Repeat):

    print(f'repeat:{i}')
    # Split each edge type independently, in the same order as the data list.
    splits = [splitdata(d) for d in eval_edge_data]
    train_splits = [train for train, _, _ in splits]
    test_splits = [test for _, _, test in splits]

    train_data_edge_index = torch.cat([t.edge_index for t in train_splits] + [edge_index_gene_gene],
                                      dim=1).to(device)
    test_data_edge_label = torch.cat([t.edge_label for t in test_splits], dim=0).to(device)
    test_data_edge_label_index = torch.cat([t.edge_label_index for t in test_splits], dim=1).to(device)
    # Edge-type id (1..4) of every test link, aligned with the labels above.
    edge_type_flag_test = torch.cat([(type_id + 1) * torch.ones(t.edge_label.shape)
                                     for type_id, t in enumerate(test_splits)], dim=0)

    # training: node2vec runs on a networkx graph keyed by string node ids.
    # Nodes with no training edge do not appear in the graph (cosine_sim then
    # falls back to a random score for them).
    graph = nx.Graph()
    graph.add_edges_from((str(a.item()), str(b.item())) for a, b in train_data_edge_index.T)

    node2vec = Node2Vec(graph,
                        dimensions=parameters['model']['hidden_channels'],
                        walk_length=parameters['model']['walk_length'],
                        num_walks=parameters['model']['walks_per_node'],
                        workers=parameters['model']['workers'])

    model = node2vec.fit(window=parameters['model']['window'],
                         min_count=parameters['model']['min_count'],
                         batch_words=parameters['model']['batch_words'])

    pred_test = cosine_sim(test_data_edge_label_index, model)
    pred_test_tensor = torch.tensor(pred_test, dtype=torch.float32)
    target_test = test_data_edge_label.cpu()

    # NOTE(review): cosine similarities are fed in as logits, so this is only
    # a rough monitoring number, not a calibrated BCE.
    loss_test = F.binary_cross_entropy_with_logits(pred_test_tensor, target_test)
    print(f'Loss_test: {loss_test:.4f}')

    # Per-edge-type evaluation, appended to the running result tables.
    res_drug_adr = pd.concat(
        [res_drug_adr, _evaluate_edge_type(1, edge_type_flag_test, target_test, pred_test_tensor, top_k)],
        ignore_index=True)
    res_gene_drug = pd.concat(
        [res_gene_drug, _evaluate_edge_type(2, edge_type_flag_test, target_test, pred_test_tensor, top_k)],
        ignore_index=True)
    res_disease_dp = pd.concat(
        [res_disease_dp, _evaluate_edge_type(3, edge_type_flag_test, target_test, pred_test_tensor, top_k)],
        ignore_index=True)
    res_gene_disease = pd.concat(
        [res_gene_disease, _evaluate_edge_type(4, edge_type_flag_test, target_test, pred_test_tensor, top_k)],
        ignore_index=True)

    # Save this repetition's embeddings (one row per node in the model vocabulary).
    df_embeddings = pd.DataFrame(model.wv.vectors, index=model.wv.index_to_key)
    df_embeddings.to_csv(os.path.join(emb_dir, 'emb'+'_rep'+str(i+1)+'.csv'))


In [30]:
# Write one evaluation table per edge type.
# DataFrame.to_csv does not create missing directories.
os.makedirs(result_dir, exist_ok=True)
for results, file_name in [
    (res_drug_adr, 'adr_drug_1.csv'),
    (res_gene_drug, 'drug_gene_1.csv'),
    (res_disease_dp, 'dp_disease_1.csv'),
    (res_gene_disease, 'gene_disease_1.csv'),
]:
    results.to_csv(os.path.join(result_dir, file_name), sep=',', index=False)
