In [1]:
import numpy as np
import torch as th
import torch.nn as nn
import argparse
import easydict

In [2]:
th.cuda.is_available()

True

In [3]:
# Hyperparameters are hard-coded here instead of being parsed from the command line,
# so the notebook can run without argv.

args = easydict.EasyDict({
     "epochs": 40,  # number of training iterations per fold (range(args.epochs))
     "rounds": 1,  # how many times the whole 10-fold cross-validation is repeated
     "device":"cuda",  # device string handed to th.device(...)
     "dim_embedding":1000,  # width of every node embedding and of the Linear / relation matrices
     "k":1,  # number of propagation iterations in Propagation
     "lr":0.001,  # Adam learning rate
     "weight_decay":0,  # Adam weight decay
     "reg_lambda":1,  # weight of the explicit L2 penalty in the total loss
     "patience":6,  # non-improving validation checks tolerated before early stopping
     "alpha":0.9,  # weight of the initial features when mixing in Propagation
     "edge_drop":0.5  # dropout probability applied to the edge weights in Propagation
})

In [4]:
def row_normalize(t):
    """Scale every row of ``t`` so that it sums to one.

    The input is cast to float32, each row is divided by its sum (plus a
    1e-12 guard against division by zero), and any NaN/Inf produced by the
    division is replaced with 0.

    Args:
        t: tensor whose dimension 1 is summed over.

    Returns:
        A new float32 tensor with the same shape as ``t``.
    """
    values = t.float()
    denominators = values.sum(1) + 1e-12
    normalized = values / denominators.unsqueeze(1)
    invalid = th.isnan(normalized) | th.isinf(normalized)
    normalized[invalid] = 0.0
    return normalized


def col_normalize(a_matrix, substract_self_loop):
    """Scale every column of a matrix so that it sums to one.

    Args:
        a_matrix: 2-D array-like. It is never modified; the function works on
            a float copy (the previous implementation zeroed the diagonal of
            the caller's array in place).
        substract_self_loop: when truthy, the diagonal is set to 0 before the
            column sums are computed.

    Returns:
        A new float ndarray with the same shape as ``a_matrix``. Entries that
        would be NaN/Inf after the division are replaced with 0.
    """
    # np.array(..., dtype=float) always copies, so the caller's data stays intact.
    matrix = np.array(a_matrix, dtype=float)
    if substract_self_loop:
        np.fill_diagonal(matrix, 0)
    # 1e-12 keeps all-zero columns from dividing by zero.
    col_sums = matrix.sum(axis=0) + 1e-12
    new_matrix = matrix / col_sums[np.newaxis, :]
    new_matrix[np.isnan(new_matrix) | np.isinf(new_matrix)] = 0.0
    return new_matrix


def l2_norm(t, axit=1):
    """Divide ``t`` by its L2 norm taken along dimension ``axit``.

    The input is cast to float32; 1e-12 is added to the norm so all-zero
    slices do not divide by zero, and any remaining NaN/Inf is set to 0.

    Args:
        t: input tensor.
        axit: dimension along which the norm is computed (default 1).

    Returns:
        A new float32 tensor with the same shape as ``t``.
    """
    values = t.float()
    lengths = th.norm(values, 2, axit, True) + 1e-12
    unit = values / lengths
    invalid = th.isnan(unit) | th.isinf(unit)
    unit[invalid] = 0.0
    return unit

In [6]:
!pip install dgl

Collecting dgl
  Downloading dgl-0.6.1-cp37-cp37m-manylinux1_x86_64.whl (4.4 MB)
[K     |████████████████████████████████| 4.4 MB 29.9 MB/s 
Installing collected packages: dgl
Successfully installed dgl-0.6.1


In [7]:
from torch import nn
from dgl import function as fn

DGL backend not selected or invalid.  Assuming PyTorch for now.
Using backend: pytorch


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


In [8]:
class Propagation(nn.Module):
    """Iterative feature propagation over a graph, mixed with the initial features.

    Each of the ``k`` iterations scales the node features by
    ``in_degree ** -0.5``, sums them over incoming edges (each edge carries a
    weight of 1 passed through dropout), scales by ``in_degree ** -0.5`` again
    and finally blends the result with the *initial* features:
    ``(1 - alpha) * propagated + alpha * initial``.

    Args:
        k: number of propagation iterations.
        alpha: weight given to the initial features in the blend.
        edge_drop: dropout probability applied to the edge weights
            (``nn.Dropout`` semantics, so it is only active in training mode).
    """

    def __init__(self, k, alpha, edge_drop=0.):
        super(Propagation, self).__init__()
        self._k = k
        self._alpha = alpha
        self.edge_drop = nn.Dropout(edge_drop)

    def forward(self, graph, feat):
        """Propagate ``feat`` over ``graph``.

        Args:
            graph: homogeneous DGL graph whose node order matches the rows of ``feat``.
            feat: node feature tensor, one row per node.

        Returns:
            Tensor with the same shape as ``feat``.
        """
        # Run on whatever device the features live on. The graph used to be
        # moved to a hard-coded 'cuda', which crashed on CPU-only runs.
        graph = graph.local_var().to(feat.device)
        # NOTE(review): nodes without incoming edges get norm = 1e6 because of the
        # 1e-12 clamp; this looks harmless only if such nodes also have no outgoing
        # edges — confirm for the graphs in use.
        norm = th.pow(graph.in_degrees().float().clamp(min=1e-12), -0.5)
        # Reshape to (N, 1, ..., 1) so it broadcasts over the feature dimensions.
        shp = norm.shape + (1,) * (feat.dim() - 1)
        norm = th.reshape(norm, shp).to(feat.device)
        feat_0 = feat
        for _ in range(self._k):
            feat = feat * norm
            graph.ndata['h'] = feat
            # Edge weights start at 1; dropout zeroes a fraction of them during training.
            edge_weight = th.ones(graph.number_of_edges(), 1).to(feat.device)
            graph.edata['w'] = self.edge_drop(edge_weight)
            graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
            feat = graph.ndata.pop('h')
            feat = feat * norm
            feat = (1 - self._alpha) * feat + self._alpha * feat_0

        return feat

In [9]:
import dgl
import torch.nn.functional as F

In [10]:
class GRDTI(nn.Module):
    """Embedding model over a drug / protein / disease / side-effect network.

    Every node type owns a learnable embedding table. ``forward`` first
    aggregates transformed neighbour embeddings per relation, then propagates
    the stacked node features over the homogeneous version of ``g`` and
    finally reconstructs each relation matrix with a bilinear form. The
    returned loss is the sum of squared reconstruction errors plus an L2
    penalty on every non-bias parameter.

    Args:
        g: DGL heterograph; it is converted with ``dgl.to_homogeneous`` on
            every forward pass.
            NOTE(review): the slicing in ``forward`` assumes the homogeneous
            node order is disease, drug, protein, sideeffect — confirm against
            the DGL version in use.
        n_disease, n_drug, n_protein, n_sideeffect: node counts per type.
        args: object exposing ``device``, ``dim_embedding``, ``reg_lambda``,
            ``k``, ``alpha`` and ``edge_drop``.
    """

    def __init__(self, g, n_disease, n_drug, n_protein, n_sideeffect, args):
        super(GRDTI, self).__init__()
        self.g = g
        self.device = th.device(args.device)
        self.dim_embedding = args.dim_embedding

        # Stored for compatibility only: forward() applies F.relu, not this attribute.
        self.activation = F.elu
        self.reg_lambda = args.reg_lambda

        self.num_disease = n_disease
        self.num_drug = n_drug
        self.num_protein = n_protein
        self.num_sideeffect = n_sideeffect

        # Learnable input embeddings, one row per node. The creation order is kept
        # so parameter registration order and RNG consumption do not change.
        self.drug_feat = self._new_embedding(self.num_drug)
        self.protein_feat = self._new_embedding(self.num_protein)
        self.disease_feat = self._new_embedding(self.num_disease)
        self.sideeffect_feat = self._new_embedding(self.num_sideeffect)

        # Weight matrices for neighbour information, i.e. W_r and b_r in Eq. (1) of the paper.
        self.fc_DDI = self._new_linear()
        self.fc_D_ch = self._new_linear()
        self.fc_D_Di = self._new_linear()
        self.fc_D_Side = self._new_linear()
        self.fc_D_P = self._new_linear()
        self.fc_PPI = self._new_linear()
        self.fc_P_seq = self._new_linear()
        self.fc_P_Di = self._new_linear()
        self.fc_P_D = self._new_linear()
        self.fc_Di_D = self._new_linear()
        self.fc_Di_P = self._new_linear()
        self.fc_Side_D = self._new_linear()

        self.propagation = Propagation(args.k, args.alpha, args.edge_drop)

        # Relation matrices for reconstruction. Each is a full (dim x dim) parameter
        # initialised as a diagonal matrix; `scratch` is re-filled in place before
        # every th.diag call, so each matrix receives its own random diagonal.
        scratch = th.randn(self.dim_embedding).float()
        self.re_DDI = self._new_relation(scratch)
        self.re_D_ch = self._new_relation(scratch)
        self.re_D_Di = self._new_relation(scratch)
        self.re_D_Side = self._new_relation(scratch)
        self.re_D_P = self._new_relation(scratch)
        self.re_PPI = self._new_relation(scratch)
        self.re_P_seq = self._new_relation(scratch)
        self.re_P_Di = self._new_relation(scratch)

        self.reset_parameters()

    def _new_embedding(self, num_rows):
        """Create a (num_rows x dim_embedding) float32 parameter drawn from N(0, 0.1^2)."""
        table = nn.Parameter(th.empty(num_rows, self.dim_embedding, dtype=th.float32))
        nn.init.normal_(table, mean=0, std=0.1)
        return table

    def _new_linear(self):
        """Create a square float32 Linear layer of width dim_embedding."""
        return nn.Linear(self.dim_embedding, self.dim_embedding).float()

    @staticmethod
    def _new_relation(scratch):
        """Re-sample ``scratch`` in place from N(0, 0.1^2) and wrap its diagonal matrix as a parameter."""
        return nn.Parameter(th.diag(nn.init.normal_(scratch, mean=0, std=0.1)))

    def reset_parameters(self):
        """Re-initialise every Linear layer: weights from N(0, 0.1^2), biases set to 0.1."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight.data, mean=0, std=0.1)
                if module.bias is not None:
                    module.bias.data.fill_(0.1)

    @staticmethod
    def _neighbor_message(adjacency, fc, source_feat):
        """Row-normalised adjacency times the ReLU-activated linear transform of the source embeddings."""
        return th.mm(row_normalize(adjacency).float(), F.relu(fc(source_feat)))

    @staticmethod
    def _mean_of(messages):
        """Element-wise mean of equally shaped (N x dim) tensors."""
        return th.mean(th.stack(messages, dim=1), dim=1)

    @staticmethod
    def _reconstruct(left, relation, right):
        """Bilinear reconstruction ``left @ relation @ right.T``."""
        return th.mm(th.mm(left, relation), right.t())

    @staticmethod
    def _squared_error(reconstruction, target):
        """Sum of squared differences between a reconstruction and its target matrix."""
        return th.sum((reconstruction - target.float()) ** 2)

    def forward(self, drug_drug, drug_chemical, drug_disease, drug_sideeffect, protein_protein,
                protein_sequence, protein_disease, drug_protein, drug_protein_mask):
        """Compute the training loss and the drug-protein reconstruction.

        Args:
            drug_drug, drug_chemical, drug_disease, drug_sideeffect,
            protein_protein, protein_sequence, protein_disease: relation
                matrices as tensors; they are both aggregation weights and
                reconstruction targets.
            drug_protein: (drug x protein) training interaction matrix.
            drug_protein_mask: tensor multiplied element-wise with the
                drug-protein error, so only entries with a non-zero mask
                contribute to the drug-protein loss.

        Returns:
            Tuple ``(total_loss, drug_protein_loss, L2_loss,
            drug_protein_reconstruct, DTI_potential)`` where ``DTI_potential``
            is ``drug_protein_reconstruct - drug_protein``. The last two are
            still attached to the autograd graph; detach them before storing.
        """
        message = self._neighbor_message

        # Per node type: average the relation-specific neighbour messages together
        # with the node's own embedding.
        disease_feat = self._mean_of((
            message(drug_disease.T, self.fc_Di_D, self.drug_feat),
            message(protein_disease.T, self.fc_Di_P, self.protein_feat),
            self.disease_feat))

        drug_feat = self._mean_of((
            message(drug_drug, self.fc_DDI, self.drug_feat),
            message(drug_chemical, self.fc_D_ch, self.drug_feat),
            message(drug_disease, self.fc_D_Di, self.disease_feat),
            message(drug_sideeffect, self.fc_D_Side, self.sideeffect_feat),
            message(drug_protein, self.fc_D_P, self.protein_feat),
            self.drug_feat))

        protein_feat = self._mean_of((
            message(protein_protein, self.fc_PPI, self.protein_feat),
            message(protein_sequence, self.fc_P_seq, self.protein_feat),
            message(protein_disease, self.fc_P_Di, self.disease_feat),
            message(drug_protein.T, self.fc_P_D, self.drug_feat),
            self.protein_feat))

        sideeffect_feat = self._mean_of((
            message(drug_sideeffect.T, self.fc_Side_D, self.drug_feat),
            self.sideeffect_feat))

        node_feat = th.cat((disease_feat, drug_feat, protein_feat, sideeffect_feat), dim=0)
        node_feat = self.propagation(dgl.to_homogeneous(self.g), node_feat)

        # Split the stacked matrix back with explicit offsets. A negative slice such
        # as node_feat[-n:] would return the whole matrix when n == 0.
        disease_end = self.num_disease
        drug_end = disease_end + self.num_drug
        protein_end = drug_end + self.num_protein
        sideeffect_end = protein_end + self.num_sideeffect

        disease_embedding = node_feat[:disease_end].to(self.device)
        drug_embedding = node_feat[disease_end:drug_end].to(self.device)
        protein_embedding = node_feat[drug_end:protein_end].to(self.device)
        sideeffect_embedding = node_feat[protein_end:sideeffect_end].to(self.device)

        disease_vector = l2_norm(disease_embedding)
        drug_vector = l2_norm(drug_embedding)
        protein_vector = l2_norm(protein_embedding)
        sideeffect_vector = l2_norm(sideeffect_embedding)

        # Auxiliary relations: (left vectors, relation matrix, right vectors, target).
        auxiliary = (
            (drug_vector, self.re_DDI, drug_vector, drug_drug),
            (drug_vector, self.re_D_ch, drug_vector, drug_chemical),
            (drug_vector, self.re_D_Di, disease_vector, drug_disease),
            (drug_vector, self.re_D_Side, sideeffect_vector, drug_sideeffect),
            (protein_vector, self.re_PPI, protein_vector, protein_protein),
            (protein_vector, self.re_P_seq, protein_vector, protein_sequence),
            (protein_vector, self.re_P_Di, disease_vector, protein_disease),
        )
        other_loss = None
        for left, relation, right, target in auxiliary:
            term = self._squared_error(self._reconstruct(left, relation, right), target)
            other_loss = term if other_loss is None else other_loss + term

        drug_protein_reconstruct = self._reconstruct(drug_vector, self.re_D_P, protein_vector)
        DTI_potential = drug_protein_reconstruct - drug_protein.float()
        masked_error = th.mul(drug_protein_mask.float(), DTI_potential)
        drug_protein_reconstruct_loss = th.sum(masked_error ** 2)

        # Explicit L2 penalty over every parameter whose name does not contain 'bias'.
        L2_loss = 0.
        for name, param in self.named_parameters():
            if 'bias' not in name:
                L2_loss = L2_loss + th.sum(param.pow(2))
        L2_loss = L2_loss * 0.5

        tloss = drug_protein_reconstruct_loss + 1.0 * other_loss + self.reg_lambda * L2_loss

        return tloss, drug_protein_reconstruct_loss, L2_loss, drug_protein_reconstruct, DTI_potential

In [None]:
# -*- coding: utf-8 -*-

import dgl
import time
import torch as th
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.model_selection import train_test_split, StratifiedKFold


def loda_data(network_path='../data/'):
    """Load the eight relation matrices of the network from text files.

    Args:
        network_path: directory containing the ``mat_*.txt`` and
            ``Similarity_Matrix_*.txt`` files. Defaults to ``'../data/'``,
            the location that used to be hard-coded.

    Returns:
        Tuple ``(drug_drug, drug_chemical, drug_disease, drug_sideeffect,
        protein_protein, protein_sequence, protein_disease, drug_protein)``
        of NumPy arrays. ``drug_chemical`` is cropped to the number of drugs
        in ``mat_drug_drug.txt`` and has the identity subtracted;
        ``protein_sequence`` is divided by 100 and has the identity subtracted.
    """
    import os

    def _load(filename):
        # os.path.join works whether or not network_path ends with a separator.
        return np.loadtxt(os.path.join(network_path, filename))

    drug_drug = _load('mat_drug_drug.txt')
    drug_chemical = _load('Similarity_Matrix_Drugs.txt')
    drug_disease = _load('mat_drug_disease.txt')
    drug_sideeffect = _load('mat_drug_se.txt')

    protein_protein = _load('mat_protein_protein.txt')
    protein_sequence = _load('Similarity_Matrix_Proteins.txt')
    protein_disease = _load('mat_protein_disease.txt')

    num_drug = len(drug_drug)
    num_protein = len(protein_protein)

    # The similarity file may hold more rows/columns than there are drugs; keep the
    # leading num_drug x num_drug block (this used to be a hard-coded 708).
    drug_chemical = drug_chemical[:num_drug, :num_drug]

    # Remove the self-loops (subtract the identity from both similarity matrices).
    drug_chemical = drug_chemical - np.identity(num_drug)
    protein_sequence = protein_sequence / 100.
    protein_sequence = protein_sequence - np.identity(num_protein)

    drug_protein = _load('mat_drug_protein.txt')

    # Alternative input with DTIs of similar drugs or proteins removed:
    # drug_protein = _load('mat_drug_protein_homo_protein_drug.txt')

    print("Load data finished.")

    return drug_drug, drug_chemical, drug_disease, drug_sideeffect, protein_protein, protein_sequence, \
           protein_disease, drug_protein


def _positive_pairs(matrix, n_rows, n_cols):
    """Return ``(pairs, reversed_pairs)`` for the entries of ``matrix`` that are > 0.

    ``pairs`` lists ``(row, col)`` in row-major order and ``reversed_pairs``
    lists the matching ``(col, row)`` tuples in the same order. Indices are
    plain Python ints. Torch tensors are moved to the CPU first.
    """
    if hasattr(matrix, 'detach'):
        matrix = matrix.detach().cpu().numpy()
    block = np.asarray(matrix)[:n_rows, :n_cols]
    rows, cols = np.nonzero(block > 0)
    rows, cols = rows.tolist(), cols.tolist()
    return list(zip(rows, cols)), list(zip(cols, rows))


def ConstructGraph(drug_drug, drug_chemical, drug_disease, drug_sideeffect, protein_protein, protein_sequence,
                   protein_disease, drug_protein):
    """Build the heterogeneous DGL graph from the relation matrices.

    An edge is created for every strictly positive matrix entry. Drug-protein,
    drug-sideeffect, drug-disease and protein-disease relations are added in
    both directions. ``drug_chemical`` and ``protein_sequence`` are accepted
    for signature compatibility but do not contribute edges.

    Args:
        drug_drug, drug_disease, drug_sideeffect, protein_protein,
        protein_disease, drug_protein: 2-D matrices (NumPy arrays or torch
            tensors) indexed as their names suggest.
        drug_chemical, protein_sequence: unused.

    Returns:
        The sub-graph of the heterograph restricted to the ten real relation
        types (the "virtual" self-loop relations are dropped).
    """
    num_drug = len(drug_drug)
    num_protein = len(protein_protein)
    num_disease = len(drug_disease.T)
    num_sideeffect = len(drug_sideeffect.T)

    # Self-loops over every node id of each type. They are only part of the full
    # heterograph and are excluded again by edge_type_subgraph below.
    # NOTE(review): presumably they exist so that every node id is present even if
    # a node has no real edge — confirm.
    list_drug = [(i, i) for i in range(num_drug)]
    list_protein = [(i, i) for i in range(num_protein)]
    list_disease = [(i, i) for i in range(num_disease)]
    list_sideeffect = [(i, i) for i in range(num_sideeffect)]

    # Vectorised replacement of the former nested Python loops; the pair order
    # (row-major) is the same.
    list_DDI, _ = _positive_pairs(drug_drug, num_drug, num_drug)
    list_PPI, _ = _positive_pairs(protein_protein, num_protein, num_protein)
    list_drug_protein, list_protein_drug = _positive_pairs(drug_protein, num_drug, num_protein)
    list_drug_sideeffect, list_sideeffect_drug = _positive_pairs(drug_sideeffect, num_drug, num_sideeffect)
    list_drug_disease, list_disease_drug = _positive_pairs(drug_disease, num_drug, num_disease)
    list_protein_disease, list_disease_protein = _positive_pairs(protein_disease, num_protein, num_disease)

    g_HIN = dgl.heterograph({
        ('disease', 'disease_disease virtual', 'disease'): list_disease,
        ('drug', 'drug_drug virtual', 'drug'): list_drug,
        ('protein', 'protein_protein virtual', 'protein'): list_protein,
        ('sideeffect', 'sideeffect_sideeffect virtual', 'sideeffect'): list_sideeffect,
        ('drug', 'drug_drug interaction', 'drug'): list_DDI,
        ('protein', 'protein_protein interaction', 'protein'): list_PPI,
        ('drug', 'drug_protein interaction', 'protein'): list_drug_protein,
        ('protein', 'protein_drug interaction', 'drug'): list_protein_drug,
        ('drug', 'drug_sideeffect association', 'sideeffect'): list_drug_sideeffect,
        ('sideeffect', 'sideeffect_drug association', 'drug'): list_sideeffect_drug,
        ('drug', 'drug_disease association', 'disease'): list_drug_disease,
        ('disease', 'disease_drug association', 'drug'): list_disease_drug,
        ('protein', 'protein_disease association', 'disease'): list_protein_disease,
        ('disease', 'disease_protein association', 'protein'): list_disease_protein,
    })

    g = g_HIN.edge_type_subgraph(['drug_drug interaction', 'protein_protein interaction',
                                  'drug_protein interaction', 'protein_drug interaction',
                                  'drug_sideeffect association', 'sideeffect_drug association',
                                  'drug_disease association', 'disease_drug association',
                                  'protein_disease association', 'disease_protein association'
                                  ])

    return g


def _ranking_curves(scores, labels, pos, neg):
    """Walk the test pairs from highest to lowest score and record ROC / PR points.

    Returns two lists with one entry per test pair: ``[fp / neg, tp / pos]``
    (ROC) and ``[tp / pos, tp / (tp + fp)]`` (precision-recall). Ties keep
    their input order.
    """
    order = np.argsort(-np.asarray(scores, dtype=float), kind='stable')
    tp, fp = 0., 0.
    xy_roc, xy_pr = [], []
    for idx in order:
        if labels[idx]:
            tp = tp + 1
        else:
            fp = fp + 1
        xy_roc.append([fp / neg, tp / pos])
        xy_pr.append([tp / pos, tp / (tp + fp)])
    return xy_roc, xy_pr


def _step_area(points):
    """Area under a curve given as ``[x, y]`` points, adding ``dx * y`` whenever x changes."""
    area, prev_x = 0., 0.
    for x, y in points:
        if x != prev_x:
            area += (x - prev_x) * y
            prev_x = x
    return area


def TrainAndEvaluate(DTItrain, DTIvalid, DTItest, args, drug_drug, drug_chemical, drug_disease,
                     drug_sideeffect, protein_protein, protein_sequence, protein_disease):
    """Train one GRDTI model on a fold and evaluate it on the test pairs.

    Args:
        DTItrain, DTIvalid, DTItest: integer arrays of ``(drug, protein, label)`` rows.
        args: configuration object with ``device``, ``epochs``, ``lr``,
            ``weight_decay``, ``patience`` plus the fields GRDTI reads. An
            optional ``eval_every`` (default 25) sets how often the model is
            evaluated.
        drug_drug ... protein_disease: relation matrices as NumPy arrays.

    Returns:
        ``(test_auc, test_aupr, xy_roc_sampling, xy_pr_sampling,
        best_DTI_potential)``. The metrics and curves belong to the evaluation
        with the best validation AUPR; the curves keep every 10th point.
        ``best_DTI_potential`` is a detached tensor, or ``None`` if no
        evaluation ran.

    Notes:
        * Scores are taken from the forward pass of the training step (before
          the optimizer update, with the model in train mode).
        * The sampled curves are now produced at the end of every run. They
          used to be filled only when early stopping triggered, so a run that
          used up all epochs returned empty curves.
        * Test pairs whose score is exactly 0 are no longer skipped when the
          curves are built, so every fold yields one point per test pair.
    """
    device = th.device(args.device)
    # How often (in epochs) validation/test metrics are computed.
    eval_every = getattr(args, 'eval_every', 25)

    # Numbers of different nodes
    num_disease = len(drug_disease.T)
    num_drug = len(drug_drug)
    num_protein = len(protein_protein)
    num_sideeffect = len(drug_sideeffect.T)

    train_pairs = np.asarray(DTItrain).reshape(-1, 3)
    valid_pairs = np.asarray(DTIvalid).reshape(-1, 3)
    test_pairs = np.asarray(DTItest).reshape(-1, 3)

    # Training interaction matrix and the mask of entries that count in the loss,
    # filled with one vectorised assignment each.
    train_rows = th.as_tensor(np.ascontiguousarray(train_pairs[:, 0]), dtype=th.long)
    train_cols = th.as_tensor(np.ascontiguousarray(train_pairs[:, 1]), dtype=th.long)
    drug_protein = th.zeros((num_drug, num_protein))
    drug_protein[train_rows, train_cols] = th.as_tensor(np.ascontiguousarray(train_pairs[:, 2]),
                                                        dtype=drug_protein.dtype)
    mask = th.zeros((num_drug, num_protein))
    mask[train_rows, train_cols] = 1
    mask = mask.to(device)

    best_valid_aupr = 0.
    test_aupr = 0.
    test_auc = 0.
    patience = 0
    best_DTI_potential = None
    xy_roc, xy_pr = [], []

    test_labels = test_pairs[:, 2]
    pos = np.count_nonzero(test_labels)
    neg = np.size(test_labels) - pos

    # The graph only contains the training interactions (drug_protein is still on the CPU here).
    g = ConstructGraph(drug_drug, drug_chemical, drug_disease, drug_sideeffect, protein_protein, protein_sequence,
                       protein_disease, drug_protein)

    drug_drug = th.tensor(drug_drug).to(device)
    drug_chemical = th.tensor(drug_chemical).to(device)
    drug_disease = th.tensor(drug_disease).to(device)
    drug_sideeffect = th.tensor(drug_sideeffect).to(device)
    protein_protein = th.tensor(protein_protein).to(device)
    protein_sequence = th.tensor(protein_sequence).to(device)
    protein_disease = th.tensor(protein_disease).to(device)
    drug_protein = drug_protein.to(device)

    model = GRDTI(g, num_disease, num_drug, num_protein, num_sideeffect, args)
    model.to(device)

    optimizer = th.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    for i in range(args.epochs):

        model.train()
        tloss, dtiloss, l2loss, dp_re, DTI_p = model(drug_drug, drug_chemical, drug_disease, drug_sideeffect,
                                                     protein_protein, protein_sequence, protein_disease,
                                                     drug_protein, mask)

        results = dp_re.detach().cpu().numpy()
        optimizer.zero_grad()
        loss = tloss
        loss.backward()
        th.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

        model.eval()

        if i % eval_every != 0:
            continue

        print("step", i, ":", "Total_loss & DTIloss & L2_loss:", loss.cpu().data.numpy(), ",", dtiloss.item(),
              ",", l2loss.item())

        valid_scores = results[valid_pairs[:, 0], valid_pairs[:, 1]]
        valid_truth = valid_pairs[:, 2]
        valid_auc = roc_auc_score(valid_truth, valid_scores)
        valid_aupr = average_precision_score(valid_truth, valid_scores)

        if valid_aupr >= best_valid_aupr:
            best_valid_aupr = valid_aupr
            # Detach so the stored tensor does not keep this step's autograd graph alive.
            best_DTI_potential = DTI_p.detach()
            patience = 0

            # Calculating AUC & AUPR (pos:neg=1:10)
            test_scores = results[test_pairs[:, 0], test_pairs[:, 1]]
            xy_roc, xy_pr = _ranking_curves(test_scores, test_labels, pos, neg)
            test_auc = _step_area(xy_roc)
            test_aupr = _step_area(xy_pr)
        else:
            patience += 1
            if patience > args.patience:
                print("Early Stopping")
                break

        print('Valid auc & aupr:', valid_auc, valid_aupr, ";  ", 'Test auc & aupr:', test_auc, test_aupr)
        th.cuda.empty_cache()

    # sampling (pos:neg=1:10) for averaging and plotting: keep every 10th point of
    # the curves belonging to the best validation score.
    xy_roc_sampling = xy_roc[::10]
    xy_pr_sampling = xy_pr[::10]

    return test_auc, test_aupr, xy_roc_sampling, xy_pr_sampling, best_DTI_potential
    

def main(args):
    """Run repeated 10-fold cross-validation and write metrics and curves to disk.

    Positives are all entries of the drug-protein matrix equal to 1; negatives
    are sampled at a 10:1 ratio from the entries equal to 0. For every fold a
    model is trained with ``TrainAndEvaluate``, the potential-interaction
    matrix and its top-40 protein indices per drug are saved, and the sampled
    ROC / PR curves are collected. After all rounds the per-round mean AUC /
    AUPR and the fold-averaged curves are written as text files in the
    current directory.

    Args:
        args: configuration object; ``args.rounds`` is read here and the
            object is forwarded to ``TrainAndEvaluate``. If it has a ``seed``
            attribute, NumPy and torch are seeded with it and the fold /
            validation splits use ``seed + round`` as ``random_state``.
            Without ``seed`` the run is unseeded, as before.
    """
    seed = getattr(args, 'seed', None)
    if seed is not None:
        np.random.seed(seed)
        th.manual_seed(seed)

    drug_d, drug_ch, drug_di, drug_side, protein_p, protein_seq, protein_di, dti_original = loda_data()

    # sampling: index pairs in row-major order, as the former nested loops produced them
    dti_labels = np.asarray(dti_original).astype(int)
    whole_positive_index = np.argwhere(dti_labels == 1)
    whole_negative_index = np.argwhere(dti_labels == 0)
    num_positive = len(whole_positive_index)

    # pos:neg=1:10 (capped at the number of available negatives so that
    # replace=False cannot fail on small matrices)
    num_negative = min(10 * num_positive, len(whole_negative_index))
    negative_sample_index = np.random.choice(np.arange(len(whole_negative_index)),
                                             size=num_negative, replace=False)

    # All unknown DTI pairs all treated as negative examples:
    # negative_sample_index = np.random.choice(np.arange(len(whole_negative_index)),
    #                                          size=len(whole_negative_index), replace=False)

    # Rows are (drug, protein, label): positives first, then the sampled negatives.
    data_set = np.zeros((num_negative + num_positive, 3), dtype=int)
    data_set[:num_positive, :2] = whole_positive_index
    data_set[:num_positive, 2] = 1
    data_set[num_positive:, :2] = whole_negative_index[negative_sample_index]

    test_auc_round = []
    test_aupr_round = []
    tpr_mean = []
    fpr = []
    precision_mean = []
    recall = []

    rounds = args.rounds
    for r in range(rounds):
        print("----------------------------------------")

        test_auc_fold = []
        test_aupr_fold = []

        split_state = None if seed is None else seed + r
        kf = StratifiedKFold(n_splits=10, random_state=split_state, shuffle=True)
        k_fold = 0

        for train_index, test_index in kf.split(data_set[:, :2], data_set[:, 2]):
            train = data_set[train_index]
            DTItest = data_set[test_index]
            DTItrain, DTIvalid = train_test_split(train, test_size=0.05, random_state=split_state)

            k_fold += 1
            print("--------------------------------------------------------------")
            print("round ", r + 1, " of ", rounds, ":", "KFold ", k_fold, " of 10")
            print("--------------------------------------------------------------")

            time_roundStart = time.time()

            t_auc, t_aupr, xy_roc, xy_pr, DTI_potential = TrainAndEvaluate(DTItrain, DTIvalid, DTItest, args, drug_d,
                                                                           drug_ch, drug_di, drug_side, protein_p,
                                                                           protein_seq, protein_di)

            time_roundEnd = time.time()
            print("Time spent in this fold:", time_roundEnd - time_roundStart)
            test_auc_fold.append(t_auc)
            test_aupr_fold.append(t_aupr)

            fold_tag = 'r' + str(r + 1) + '_f' + str(k_fold) + '.csv'
            np.savetxt('DTI_potential_' + fold_tag, DTI_potential.detach().cpu().numpy(), fmt='%-.4f', delimiter=',')
            # topk raises if k exceeds the number of columns, so cap it.
            top_k = min(40, DTI_potential.shape[-1])
            top_values, top_indices = th.topk(DTI_potential, top_k)
            np.savetxt('top40_' + fold_tag, top_indices.detach().cpu().numpy(), fmt='%d', delimiter=',')

            # pos:neg=1:10 — the x axes come from the first fold that produced a curve
            if not fpr:
                fpr = [_v[0] for _v in xy_roc]
            if not recall:
                recall = [_v[0] for _v in xy_pr]

            tpr_mean.append([_v[1] for _v in xy_roc])
            precision_mean.append([_v[1] for _v in xy_pr])

        print("Training and evaluation is OK.")

        test_auc_round.append(np.mean(test_auc_fold))
        test_aupr_round.append(np.mean(test_aupr_fold))

    t1 = time.localtime()
    time_creat_txt = '_'.join(str(part) for part in (t1.tm_year, t1.tm_mon, t1.tm_mday, t1.tm_hour, t1.tm_min))
    np.savetxt('test_auc_' + time_creat_txt, test_auc_round)
    np.savetxt('test_aupr_' + time_creat_txt, test_aupr_round)

    # pos:neg=1:10
    # Test folds can differ in size by one pair, so the sampled curves may differ in
    # length by one point; truncate to the shortest before averaging.
    roc_points = min([len(curve) for curve in tpr_mean] + [len(fpr)]) if tpr_mean else 0
    pr_points = min([len(curve) for curve in precision_mean] + [len(recall)]) if precision_mean else 0
    fpr = fpr[:roc_points]
    recall = recall[:pr_points]
    if tpr_mean:
        tpr = np.mean(np.array([curve[:roc_points] for curve in tpr_mean]), axis=0).tolist()
    else:
        tpr = []
    if precision_mean:
        precision = np.mean(np.array([curve[:pr_points] for curve in precision_mean]), axis=0).tolist()
    else:
        precision = []

    np.savetxt('fpr.csv', fpr, fmt='%-.4f', delimiter=',')
    np.savetxt('tpr.csv', tpr, fmt='%-.4f', delimiter=',')
    np.savetxt('recall.csv', recall, fmt='%-.4f', delimiter=',')
    np.savetxt('precision.csv', precision, fmt='%-.4f', delimiter=',')


if __name__ == "__main__":
    # `args` is the EasyDict defined in the configuration cell near the top of the
    # notebook; no command-line parsing happens here.
    print(args)

    # Wall-clock time of the complete cross-validation run.
    start = time.time()
    main(args)
    end = time.time()
    print("Total time:", end - start)