In [None]:
import os
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data

from torch_geometric.transforms import RandomLinkSplit
import torch_geometric.transforms as T 
from torch.nn import Embedding
from torch_geometric.nn import GATConv, GraphConv, SAGEConv
import torch.nn.functional as F 
from torch.nn import Linear
import json
from sklearn.metrics import roc_auc_score
import subprocess


In [None]:
# Record the GPU status reported by nvidia-smi at the start of the run.
# Raises CalledProcessError if nvidia-smi exits non-zero.
gpu_report = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, check=True)
nvidia_smi_output = gpu_report.stdout.decode('utf-8')
print(nvidia_smi_output)


In [24]:
data_dir = '../data/'
emb_dir = os.path.join(data_dir, 'embedding/emb_GraphConv')
result_dir = os.path.join(data_dir, 'result')


# Experiment configuration.
# NOTE: "encoder", "decoder", "N_layers", "activation_func" and "objective_Func" are descriptive
# only -- GNNEncoder / EdgeDecoder / the training loop hard-code GraphConv x3, ReLU, cosine and
# binary_cross_entropy_with_logits respectively.
parameters = {
    "model":{
        "encoder": "GraphConv",
        "decoder": "cosine",
        "N_layers" : 3,        # fixed: GNNEncoder always builds three GraphConv layers
        "activation_func": 'relu',
        "hidden_channels": 15,
        "objective_Func": "binary_cross_entropy_with_logits",
        "epoch" : 500,
        "patience" : 5,
        "lr" : 0.001,
        "weight_decay" : 1e-5
    },

    "evaluation":{
        "top_k" : np.arange(5,500,10).tolist(),   # k values for prc@k / hits@k: 5, 15, ..., 495
        "rep" : 1000                              # number of independent split/train/evaluate repetitions
    }
}

# Prefer the second GPU (original choice) but fall back so the notebook still runs on
# machines with a single GPU or none at all.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
# Load the edge list and turn it into a (columns x rows) LongTensor.
# Presumably the CSV has two integer node-id columns (source, target), so edge_index ends up
# shaped (2, num_edges) -- TODO confirm against edge_index.csv.
edge_table = pd.read_csv(os.path.join(data_dir, 'edge_index.csv'), sep= ',')
edge_index = torch.tensor(edge_table.to_numpy().astype(np.int64).T, dtype=torch.long)

In [26]:
conversion_table = pd.read_csv(os.path.join(data_dir, 'conversionTable.csv'))

# Count gene nodes (rows whose 'Names' match 'entrez.').
# NOTE(review): str.contains interprets the pattern as a regex, so '.' matches any character.
is_gene = conversion_table['Names'].str.contains('entrez.')
N_genes = is_gene.sum()

In [1]:
# Data Construction
# Node features are just the node ids 0..num_nodes-1; GNNEncoder maps them through an Embedding.
# The node count is inferred as (largest id appearing in edge_index) + 1.
max_node_index = edge_index.max()
num_nodes = max_node_index + 1
x = torch.arange(num_nodes)
data = Data(edge_index=edge_index, x=x).to(device)

In [28]:
def _node_range(pattern):
    """Return (min, max) of the 'Nodes' column over conversion-table rows whose 'Names' match `pattern`.

    NOTE(review): str.contains treats `pattern` as a regex, so the trailing '.' matches any character.
    Downstream masks treat every index in [min, max] as this node type, i.e. they assume each
    type occupies one contiguous block of node ids -- TODO confirm.
    """
    nodes = conversion_table.loc[conversion_table['Names'].str.contains(pattern), 'Nodes']
    return nodes.min(), nodes.max()


first_I_adr, last_I_adr = _node_range('meddra.')          # adverse drug reactions
first_I_drug, last_I_drug = _node_range('drugbank.')      # drugs
first_I_dp, last_I_dp = _node_range('hpo.')               # 'dp' nodes (hpo.* names)
first_I_disease, last_I_disease = _node_range('mondo.')   # diseases
first_I_gene, last_I_gene = _node_range('entrez.')        # genes


In [None]:
def _in_range(node_ids, first, last):
    """Element-wise mask: first <= node_ids <= last (inclusive bounds)."""
    return (node_ids >= first) & (node_ids <= last)


def _relation_mask(first_a, last_a, first_b, last_b, both_directions=True):
    """Mask over edge_index columns whose edge joins node range A to node range B.

    With both_directions=True, edges stored as B -> A are included too.
    """
    src, dst = edge_index[0], edge_index[1]
    mask = _in_range(src, first_a, last_a) & _in_range(dst, first_b, last_b)
    if both_directions:
        mask = mask | (_in_range(src, first_b, last_b) & _in_range(dst, first_a, last_a))
    return mask


flag_drug_adr = _relation_mask(first_I_drug, last_I_drug, first_I_adr, last_I_adr)
flag_drug_gene = _relation_mask(first_I_drug, last_I_drug, first_I_gene, last_I_gene)
flag_disease_dp = _relation_mask(first_I_disease, last_I_disease, first_I_dp, last_I_dp)
flag_gene_disease = _relation_mask(first_I_disease, last_I_disease, first_I_gene, last_I_gene)
flag_gene_gene = _relation_mask(first_I_gene, last_I_gene, first_I_gene, last_I_gene, both_directions=False)

edge_index.shape, flag_drug_adr.shape, flag_drug_gene.shape, flag_disease_dp.shape, flag_gene_disease.shape, flag_gene_gene.shape


In [30]:
# Split the full edge list into one edge set per relation type, selected by the masks above.
(edge_index_drug_adr, edge_index_drug_gene, edge_index_disease_dp,
 edge_index_gene_disease, edge_index_gene_gene) = (
    edge_index[:, mask]
    for mask in (flag_drug_adr, flag_drug_gene, flag_disease_dp, flag_gene_disease, flag_gene_gene)
)

# Wrap each relation in its own Data object (sharing the same node-id features) on `device`.
(data_drug_adr, data_drug_gene, data_disease_dp,
 data_gene_disease, data_gene_gene) = (
    Data(x=x, edge_index=rel_edges).to(device)
    for rel_edges in (edge_index_drug_adr, edge_index_drug_gene, edge_index_disease_dp,
                      edge_index_gene_disease, edge_index_gene_gene)
)


In [31]:
class GNNEncoder(torch.nn.Module):
    """Node-id embedding followed by three GraphConv layers (ReLU after the first two).

    Args:
        num_nodes: size of the embedding table (node ids must lie in [0, num_nodes)).
        hidden_channels: embedding width and width of the two inner GraphConv layers.
        out_channels: width of the final GraphConv layer, i.e. of the returned node embeddings.
    """

    def __init__(self, num_nodes, hidden_channels, out_channels):
        super().__init__()
        self.embedding = Embedding(num_nodes, hidden_channels)
        # (-1, -1) lets GraphConv infer its input widths lazily on the first forward pass.
        self.conv1 = GraphConv((-1, -1), hidden_channels)
        self.conv2 = GraphConv((-1, -1), hidden_channels)
        self.conv5 = GraphConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        """Map node ids `x` to embeddings by message passing over `edge_index`."""
        h = self.embedding(x)
        for hidden_conv in (self.conv1, self.conv2):
            h = hidden_conv(h, edge_index).relu()
        return self.conv5(h, edge_index)

class EdgeDecoder(torch.nn.Module):
    """Score candidate edges by the cosine similarity of their endpoint embeddings.

    NOTE(review): the score lies in [-1, 1]. When it is fed straight into
    binary_cross_entropy_with_logits, the implied probability is confined to roughly
    [0.27, 0.73], which limits how far the loss can fall.
    """

    def forward(self, z, edge_label_index):
        """Return one similarity per column (src, dst) of `edge_label_index`, indexing rows of `z`."""
        src, dst = edge_label_index
        return F.cosine_similarity(z[src], z[dst], dim=-1, eps=1e-6)
 

class Model(torch.nn.Module):
    """Link-prediction model: a GNNEncoder producing node embeddings plus a cosine EdgeDecoder.

    Args:
        num_nodes: number of node ids the encoder's embedding table covers.
        hidden_channels: width used for every encoder layer, including its output.
    """

    def __init__(self, num_nodes, hidden_channels):
        super().__init__()
        self.encoder = GNNEncoder(num_nodes, hidden_channels, hidden_channels)
        self.decoder = EdgeDecoder()

    def forward_encoder(self, x, edge_index):
        """Node embeddings only (no edge scoring)."""
        return self.encoder(x, edge_index)

    def forward(self, x, edge_index, edge_label_index):
        """Embed nodes over `edge_index`, then score each candidate pair in `edge_label_index`."""
        node_emb = self.encoder(x, edge_index)
        return self.decoder(node_emb, edge_label_index)

In [32]:
class Evaluation():
    """Ranking and classification metrics for one set of scored candidate edges.

    y_true: 1-D tensor of 0/1 labels (1 = true edge, 0 = negative).
    y_pred: 1-D tensor of scores aligned with y_true; higher means "more likely an edge".
    """

    # Rank thresholds reported by mrr() (the mrr_hits@k columns of eval()).
    MRR_HITS_K = (1, 3, 10, 20, 50, 100)

    def __init__(self, y_true, y_pred):
        self.y_true = y_true
        self.y_pred = y_pred
        self.y_pred_pos = y_pred[y_true == 1]
        self.y_pred_neg = y_pred[y_true == 0]

    def eval(self, k_list):
        """Compute every metric and return it as a one-row DataFrame.

        Columns: prc@k and hits@k for each k in k_list, mrr_hits@k for k in MRR_HITS_K,
        then mrr, mrr2 and auc.
        """
        columns = (['prc@' + str(k) for k in k_list]
                   + ['hits@' + str(k) for k in k_list]
                   + ['mrr_hits@' + str(k) for k in self.MRR_HITS_K]
                   + ['mrr', 'mrr2', 'auc'])
        row = ([self.precision_k(k) for k in k_list]
               + [self.hits_k(k) for k in k_list]
               + self.mrr()
               + [self.mrr2(), self.auc()])
        return pd.DataFrame([row], columns=columns)

    def precision_k(self, k):
        """Fraction of true links among the k highest-scored candidates.

        k is capped at the number of candidates.
        """
        k = min(k, len(self.y_pred))
        topk_indices = torch.topk(self.y_pred, k).indices
        return self.y_true[topk_indices].float().mean().item()

    def hits_k(self, k):
        """Fraction of positives scored strictly above the k-th highest negative score.

        Returns 1 when there are fewer than k negatives and NaN when there are no positives
        (previously this raised ZeroDivisionError).
        """
        if k > len(self.y_pred_neg):
            return 1
        if len(self.y_pred_pos) == 0:
            return float('nan')
        kth_score_in_negative_edges = torch.topk(self.y_pred_neg, k=k).values[-1]
        hits = torch.sum(self.y_pred_pos > kth_score_in_negative_edges).cpu()
        return float(hits) / len(self.y_pred_pos)

    def mrr(self):
        """Rank every positive against all negatives and summarise the ranks.

        A positive's rank is 1 + (number of negatives scoring higher), with tied negatives
        counted as one half. Builds a (num_pos, num_neg) comparison matrix, so memory grows
        with num_pos * num_neg.

        Returns [hits@1, hits@3, hits@10, hits@20, hits@50, hits@100, mrr], each averaged
        over positives.
        """
        y_pred_pos = self.y_pred_pos.view(-1, 1)
        optimistic_rank = (self.y_pred_neg >= y_pred_pos).sum(dim=1)
        pessimistic_rank = (self.y_pred_neg > y_pred_pos).sum(dim=1)
        ranking_list = 0.5 * (optimistic_rank + pessimistic_rank) + 1

        hits = [(ranking_list <= k).to(torch.float).mean().item() for k in self.MRR_HITS_K]
        mrr_value = (1. / ranking_list.to(torch.float)).mean().item()
        return hits + [mrr_value]

    def mrr2(self):
        """Mean reciprocal rank of the positives within one ranking of ALL candidates.

        Candidates (positives and negatives) are sorted by descending score; rank 1 is best.
        Ties are broken by torch.argsort's order rather than averaged.
        """
        order = torch.argsort(self.y_pred, descending=True)
        # Invert the permutation: argsort gives "which candidate sits at position p";
        # we need "at which position does candidate c sit".
        ranks = torch.empty_like(order)
        ranks[order] = torch.arange(1, order.numel() + 1, device=order.device)
        mrr = (1.0 / ranks[self.y_true == 1].to(torch.float)).mean()
        return mrr.item()

    def auc(self):
        """ROC-AUC of y_pred against y_true (computed on CPU via scikit-learn)."""
        auc_score = roc_auc_score(self.y_true.detach().cpu().numpy(), self.y_pred.detach().cpu().numpy())
        return auc_score

In [34]:
# Parameters
# Unpack the training hyper-parameters from the config dict.
hidden_channels, patience_thr, Epoch, lr, weight_decay = (
    parameters["model"][key] for key in ("hidden_channels", "patience", "epoch", "lr", "weight_decay")
)
Repeat = parameters["evaluation"]["rep"]

lam = 1  # NOTE(review): not referenced anywhere else in this notebook

# Node count as a plain int (replaces the tensor-valued num_nodes from the data-construction cell).
num_nodes = len(data.x)

In [35]:
# Initial values
# Sentinel "previous loss" for early stopping (early stopping is not active).
loss_old = 1000000
loss_train, loss_val, loss_test = [], [], []
loss_total = pd.DataFrame(columns=["loss_train", "loss_val", "loss_test"])

# Per-repetition metric tables; the training loop appends one row per repetition to each.
df_drug_adr, df_gene_drug, df_disease_dp, df_gene_disease = (pd.DataFrame() for _ in range(4))

# Per-epoch loss histories, overall and per relation type.
# NOTE(review): nothing in this notebook appends to these lists.
losses_train, losses_val, losses_test = [], [], []
losses_train_drug_adr, losses_val_drug_adr, losses_test_drug_adr = [], [], []
losses_train_gene_drug, losses_val_gene_drug, losses_test_gene_drug = [], [], []
losses_train_disease_dp, losses_val_disease_dp, losses_test_disease_dp = [], [], []
losses_train_gene_disease, losses_val_gene_disease, losses_test_gene_disease = [], [], []


In [36]:
# Shared RandomLinkSplit settings: undirected graphs, no validation split, and no negatives
# generated by the transform (negatives come from negative_sample instead).
_split_kwargs = dict(
    is_undirected=True,
    num_val=0.0,
    neg_sampling_ratio=0,
    add_negative_train_samples=False,
)

# Evaluated relations: hold out 10% of edges for testing.
splitdata = T.RandomLinkSplit(num_test=0.1, **_split_kwargs)

# Gene-gene (PPI) edges: nothing is held out; every edge goes to the training split.
splitdata_PPI = T.RandomLinkSplit(num_test=0.0, **_split_kwargs)

In [37]:
def negative_sample(node_type1_range, node_type2_range, edge_index, num_test, is_undirected=True, device='cpu'):
    """Sample `num_test` distinct non-edges (u, v) between two node-id ranges.

    Args:
        node_type1_range: inclusive (min, max) id range for u.
        node_type2_range: inclusive (min, max) id range for v.
        edge_index: (2, E) known edges; any sampled (u, v) present here is rejected.
            Only (u, v) is checked, not (v, u) -- sufficient when edge_index stores both directions.
        num_test: number of negative pairs to return.
        is_undirected: if True, the reversed pairs (v, u) are appended after the num_test
            forward pairs.
        device: device of the returned tensor.

    Returns:
        LongTensor of shape (2, num_test), or (2, 2 * num_test) when is_undirected, on `device`.

    Raises:
        ValueError: if fewer than num_test distinct non-edges were found among the
            3 * num_test random draws (the node ranges are too small or too dense).

    Uses torch's global RNG for candidate draws and NumPy's global RNG for the final subset.
    """
    if num_test <= 0:
        return torch.empty((2, 0), dtype=torch.long, device=device)

    existing_edges = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))

    min1, max1 = node_type1_range
    min2, max2 = node_type2_range

    # Oversample 3x, deduplicate, then drop known edges. Sampling on CPU and converting with
    # one .tolist() avoids a device sync per element.
    n_draw = 3 * num_test
    nodes_1 = torch.randint(min1, max1 + 1, (n_draw, 1))
    nodes_2 = torch.randint(min2, max2 + 1, (n_draw, 1))
    candidates = torch.unique(torch.cat([nodes_1, nodes_2], dim=1), dim=0)

    kept = [(u, v) for u, v in candidates.tolist() if (u, v) not in existing_edges]
    if len(kept) < num_test:
        raise ValueError(
            f"negative_sample: only {len(kept)} distinct non-edges found, {num_test} requested"
        )

    negative_edge_index = torch.tensor(kept, dtype=torch.long)
    rand_ind = np.random.choice(len(kept), size=num_test, replace=False)
    negative_edge_index = negative_edge_index[rand_ind]

    if is_undirected:
        negative_edge_index = torch.cat([negative_edge_index, negative_edge_index[:, [1, 0]]])

    return negative_edge_index.T.to(device)

In [None]:
# Repeated split / train / evaluate loop. Each repetition:
#   1. draws a fresh 90/10 edge split for every evaluated relation (PPI edges stay in training),
#   2. pairs every positive with one sampled negative (train and test negatives are disjoint),
#   3. trains a new model from scratch and saves its node embeddings to emb_dir/<i+1>.csv,
#   4. appends one row of ranking metrics per evaluated relation to the df_* tables.
os.makedirs(emb_dir, exist_ok=True)
TOP_K = parameters["evaluation"]["top_k"]

# Relations that are supervised during training AND scored on their held-out edges, in
# concatenation order. name -> (id range of endpoint type A, id range of endpoint type B,
# relation-only Data to split, relation-only edge_index used to reject false negatives,
# is_undirected flag for negative_sample).
EVAL_RELATIONS = {
    "drug_adr": ((first_I_drug, last_I_drug), (first_I_adr, last_I_adr),
                 data_drug_adr, edge_index_drug_adr, False),
    "drug_gene": ((first_I_drug, last_I_drug), (first_I_gene, last_I_gene),
                  data_drug_gene, edge_index_drug_gene, True),
    "disease_dp": ((first_I_disease, last_I_disease), (first_I_dp, last_I_dp),
                   data_disease_dp, edge_index_disease_dp, True),
    "gene_disease": ((first_I_gene, last_I_gene), (first_I_disease, last_I_disease),
                     data_gene_disease, edge_index_gene_disease, True),
}


def _with_negatives(pos_index, pos_label, neg_index):
    """Append `neg_index` (labelled 0) after the positive candidates; return (label_index, label).

    Deriving the zero-label count from neg_index itself keeps labels aligned with candidates.
    """
    neg_label = torch.zeros(neg_index.size(1), dtype=torch.float, device=pos_label.device)
    return torch.cat([pos_index, neg_index], dim=1), torch.cat([pos_label, neg_label], dim=0)


for i in range(Repeat):
    print(f'repeat:{i}')
    # Seed each repetition so any single run can be reproduced on its own
    # (CUDA scatter kernels may still add small nondeterminism).
    torch.manual_seed(i)
    np.random.seed(i)

    train_mp, train_label_index, train_label = [], [], []
    test_mp, test_label_index, test_label, test_type = [], [], [], []

    for type_id, (range_a, range_b, rel_data, rel_edge_index, undirected) in enumerate(EVAL_RELATIONS.values(), start=1):
        rel_train, _, rel_test = splitdata(rel_data)
        n_train = rel_train.edge_label.numel()
        n_test = rel_test.edge_label.numel()

        # One pool of distinct negatives: columns [0, n_train) for training and
        # [n_train, n_train + n_test) for testing. Reversed copies appended by
        # negative_sample (is_undirected=True) lie beyond that and are not used.
        neg = negative_sample(range_a, range_b, rel_edge_index, num_test=n_train + n_test,
                              is_undirected=undirected, device=device).to(device)

        index, label = _with_negatives(rel_train.edge_label_index, rel_train.edge_label,
                                       neg[:, :n_train])
        train_mp.append(rel_train.edge_index)
        train_label_index.append(index)
        train_label.append(label)

        index, label = _with_negatives(rel_test.edge_label_index, rel_test.edge_label,
                                       neg[:, n_train:n_train + n_test])
        test_mp.append(rel_test.edge_index)
        test_label_index.append(index)
        test_label.append(label)
        # Relation id per test candidate, sized from the candidates so it always lines up.
        test_type.append(torch.full((index.size(1),), type_id, dtype=torch.long, device=device))

    # Gene-gene (PPI) edges: all are used for message passing and as training positives;
    # they are never evaluated.
    train_data_gene_gene, _, _ = splitdata_PPI(data_gene_gene)
    n_ppi = train_data_gene_gene.edge_label.numel()
    neg_gene_gene = negative_sample((first_I_gene, last_I_gene), (first_I_gene, last_I_gene),
                                    edge_index_gene_gene, num_test=n_ppi,
                                    is_undirected=True, device=device).to(device)
    index, label = _with_negatives(train_data_gene_gene.edge_label_index,
                                   train_data_gene_gene.edge_label, neg_gene_gene[:, :n_ppi])
    train_mp.append(train_data_gene_gene.edge_index)
    train_label_index.append(index)
    train_label.append(label)
    test_mp.append(data_gene_gene.edge_index)

    train_data = Data(x=x, edge_index=torch.cat(train_mp, dim=1),
                      edge_label=torch.cat(train_label, dim=0),
                      edge_label_index=torch.cat(train_label_index, dim=1)).to(device)
    test_data = Data(x=x, edge_index=torch.cat(test_mp, dim=1),
                     edge_label=torch.cat(test_label, dim=0),
                     edge_label_index=torch.cat(test_label_index, dim=1)).to(device)
    edge_type_flag_test = torch.cat(test_type, dim=0)

    model = Model(num_nodes, hidden_channels=hidden_channels).to(device)
    # A dry forward pass materialises the lazy (-1, -1) GraphConv weights before the
    # optimizer collects the parameters.
    with torch.no_grad():
        model.encoder(data.x, train_data.edge_index)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    model.train()
    for epoch in range(Epoch):
        optimizer.zero_grad()
        pred = model(train_data.x, train_data.edge_index, train_data.edge_label_index)
        loss = F.binary_cross_entropy_with_logits(pred, train_data.edge_label)
        loss.backward()
        optimizer.step()
    # NOTE: early stopping on `patience_thr` is not implemented (there is no validation split).

    model.eval()
    with torch.no_grad():
        # Test candidates are scored by message passing over the split's test edge_index
        # (which, per RandomLinkSplit, does not contain the held-out test edges).
        pred_test = model(test_data.x, test_data.edge_index, test_data.edge_label_index)
        # Exported embeddings use the full graph, including the held-out edges.
        z = model.forward_encoder(data.x, data.edge_index)
    np.savetxt(os.path.join(emb_dir, str(i+1)+'.csv'), z.cpu().numpy(), delimiter=',')

    scores = {
        name: Evaluation(test_data.edge_label[edge_type_flag_test == type_id],
                         pred_test[edge_type_flag_test == type_id]).eval(TOP_K)
        for type_id, name in enumerate(EVAL_RELATIONS, start=1)
    }
    df_drug_adr = pd.concat([df_drug_adr, scores["drug_adr"]], ignore_index=True)
    df_gene_drug = pd.concat([df_gene_drug, scores["drug_gene"]], ignore_index=True)
    df_disease_dp = pd.concat([df_disease_dp, scores["disease_dp"]], ignore_index=True)
    df_gene_disease = pd.concat([df_gene_disease, scores["gene_disease"]], ignore_index=True)

In [46]:
# Write one metrics table per evaluated relation (one row per repetition).
# Create the output directory first so a fresh checkout does not fail at the very end of a long run.
os.makedirs(result_dir, exist_ok=True)
for frame, file_name in (
    (df_drug_adr, 'GraphConv_res_adr_drug.csv'),
    (df_gene_drug, 'GraphConv_res_drug_gene.csv'),
    (df_disease_dp, 'GraphConv_res_dp_disease.csv'),
    (df_gene_disease, 'GraphConv_res_gene_disease.csv'),
):
    frame.to_csv(os.path.join(result_dir, file_name), sep = ',', index=False)