In [1]:
import torch.nn as nn
import torch
import math
import numpy as np
import pandas as pd
from numpy import linalg
from scipy.linalg import eigh
import torch.nn.functional as F
import random
# import wandb
from torch.nn.parameter import Parameter
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import scipy.sparse as sp
from torch import optim
import gc
import os.path as osp
import csv
from typing import Optional
from torch_geometric.typing import OptTensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import get_laplacian
from torch.autograd import Variable
from torch_geometric.nn import GCNConv, GATConv,RGCNConv
from torch_geometric.nn.inits import glorot, uniform
from torch_geometric.utils import softmax
from torch_sparse import SparseTensor, set_diag
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch.nn import Sequential, Linear, ReLU, Dropout
from scipy.special import comb
from torch import Tensor
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv.gcn_conv import gcn_norm
import copy
import sklearn.metrics
import time
from sklearn.cluster import KMeans
from scipy.sparse import coo_matrix
from torch_geometric.nn.dense.linear import Linear as gLinear

In [2]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

cudaid = "cpu"
device = torch.device(cudaid)

root = r"/home/lu/data/circRNA"

In [3]:
seq_sim_matrix = pd.read_csv(root + r"/gene_seq_sim.csv", index_col=0, dtype=np.float32).to_numpy()
str_sim_matrix = pd.read_csv(root + r"/drug_str_sim.csv", index_col=0, dtype=np.float32).to_numpy()
association = pd.read_csv(root + r"/association.csv", index_col=0).to_numpy()


print(seq_sim_matrix.shape)
print(str_sim_matrix.shape)
print(association.shape)


(271, 271)
(218, 218)
(271, 218)


In [4]:
def dense2sparse(matrix: np.ndarray):
    mat_coo = coo_matrix(matrix)
    edge_idx = np.vstack((mat_coo.row, mat_coo.col))
    return edge_idx, mat_coo.data

def load_data():
    cir_sim = pd.read_csv(root + r"/gene_seq_sim.csv", index_col=0, dtype=np.float32).to_numpy()
    drug_sim=pd.read_csv(root + r"/drug_str_sim.csv", index_col=0, dtype=np.float32).to_numpy()
    drug_cir_ass =pd.read_csv(root + r"/association.csv", index_col=0).to_numpy().T
    diag = np.diag(cir_sim)
    if np.sum(diag) != 0:
        cir_sim = cir_sim - np.diag(diag)

    # get the edge idx of positives samplese
    rng = np.random.default_rng(10086)
    pos_samples, edge_attr = dense2sparse(drug_cir_ass)
    pos_samples_shuffled = rng.permutation(pos_samples, axis=1)

    # get the edge index of negative samples
    rng = np.random.default_rng(10086)
    neg_samples = np.where(drug_cir_ass == 0)
    neg_samples_shuffled = rng.permutation(neg_samples, axis=1)[:, :pos_samples_shuffled.shape[1]]
    # split positive samples into training message samples, training supervision samples, test samples
    edge_idx_dict = dict()
    edge_idx_dict['pos_edges'] = pos_samples_shuffled
    edge_idx_dict['neg_edges'] = neg_samples_shuffled

    return drug_sim,cir_sim, edge_idx_dict,drug_cir_ass

In [5]:
def k_matrix(matrix, k=20):
    num = matrix.shape[0]
    knn_graph = np.zeros(matrix.shape)
    idx_sort = np.argsort(-(matrix - np.eye(num)), axis=1)
    for i in range(num):
        #将第i行最大的前k个值赋值给knn_graph(确保是对称矩阵)
        knn_graph[i, idx_sort[i, :k + 1]] = matrix[i, idx_sort[i, :k + 1]]
        knn_graph[idx_sort[i, :k + 1], i] = matrix[idx_sort[i, :k + 1], i]
    return knn_graph + np.eye(num)


In [6]:
def GIP_kernel(Asso_RNA_Dis):
    # the number of row
    nc = Asso_RNA_Dis.shape[0]
    # initate a matrix as results matrix
    matrix = np.zeros((nc, nc))
    # calculate the down part of GIP fmulate
    r = getGosiR(Asso_RNA_Dis)
    # calculate the results matrix
    for i in range(nc):
        for j in range(nc):
            # calculate the up part of GIP formulate
            temp_up = np.square(np.linalg.norm(Asso_RNA_Dis[i, :] - Asso_RNA_Dis[j, :]))
            if r == 0:
                matrix[i][j] = 0
            elif i == j:
                matrix[i][j] = 1
            else:
                matrix[i][j] = np.e ** (-temp_up / r)
    return matrix


def getGosiR(Asso_RNA_Dis):
    # calculate the r in GOsi Kerel
    nc = Asso_RNA_Dis.shape[0]
    summ = 0
    for i in range(nc):
        x_norm = np.linalg.norm(Asso_RNA_Dis[i, :])
        x_norm = np.square(x_norm)
        summ = summ + x_norm
    r = summ / nc
    return r


def get_syn_sim(A, seq_sim, str_sim, mode):
    """
    :param A:
    :param seq_sim:
    :param str_sim:
    :param mode: 0 = GIP kernel sim
    :return:
    """

    GIP_c_sim = GIP_kernel(A)
    GIP_d_sim = GIP_kernel(A.T)

    if mode == 0:
        return GIP_c_sim, GIP_d_sim

    syn_c = np.zeros((A.shape[0], A.shape[0]))
    syn_d = np.zeros((A.shape[1], A.shape[1]))

    for i in range(A.shape[0]):
        for j in range(A.shape[0]):
            if seq_sim[i, j] == 0:
                syn_c[i, j] = GIP_c_sim[i, j]
            else:
                syn_c[i, j] = (GIP_c_sim[i, j] + seq_sim[i, j]) / 2

    for i in range(A.shape[1]):
        for j in range(A.shape[1]):
            if str_sim[i, j] == 0:
                syn_d[i, j] = GIP_d_sim[i, j]
            else:
                syn_d[i, j] = (GIP_d_sim[i, j] + str_sim[i, j]) / 2

    return syn_c, syn_d


def sim_thresholding(matrix: np.ndarray, threshold):
    matrix_copy = matrix.copy()
    matrix_copy[matrix_copy >= threshold] = 1
    matrix_copy[matrix_copy < threshold] = 0
    print(f"rest links: {np.sum(np.sum(matrix_copy))}")
    return matrix_copy


# ######################################################################################################################


def get_syn_sim_circ_drug(A, seq_sim, str_sim, k1, k2):
    disease_sim1 = str_sim
    circRNA_sim1 = seq_sim

    GIP_c_sim = GIP_kernel(A)
    GIP_d_sim = GIP_kernel(A.T)
    # miRNA_sim1 = GIP_m_sim
    m1 = new_normalization(circRNA_sim1)
    m2 = new_normalization(GIP_c_sim)

    Sm_1 = KNN_kernel(circRNA_sim1, k1)
    Sm_2 = KNN_kernel(GIP_c_sim, k1)
    Pm = circRNA_updating(Sm_1, Sm_2, m1, m2)
    Pm_final = (Pm + Pm.T) / 2

    d1 = new_normalization(disease_sim1)
    d2 = new_normalization(GIP_d_sim)

    Sd_1 = KNN_kernel(disease_sim1, k2)
    Sd_2 = KNN_kernel(GIP_d_sim, k2)
    Pd = disease_updating(Sd_1, Sd_2, d1, d2)
    Pd_final = (Pd + Pd.T) / 2

    return Pm_final, Pd_final


def new_normalization(w):
    m = w.shape[0]
    p = np.zeros([m, m])
    for i in range(m):
        for j in range(m):
            if i == j:
                p[i][j] = 1 / 2
            elif np.sum(w[i, :]) - w[i, i] > 0:
                p[i][j] = w[i, j] / (2 * (np.sum(w[i, :]) - w[i, i]))
    return p


def KNN_kernel(S, k):
    n = S.shape[0]
    S_knn = np.zeros([n, n])
    for i in range(n):
        sort_index = np.argsort(S[i, :])
        for j in sort_index[n - k:n]:
            if np.sum(S[i, sort_index[n - k:n]]) > 0:
                S_knn[i][j] = S[i][j] / (np.sum(S[i, sort_index[n - k:n]]))
    return S_knn


def circRNA_updating(S1, S2, P1, P2):
    P = (P1 + P2) / 2
    dif = 1
    while dif > 0.0000001:
        P111 = np.dot(np.dot(S1, P2), S1.T)
        P111 = new_normalization(P111)
        P222 = np.dot(np.dot(S2, P1), S2.T)
        P222 = new_normalization(P222)
        P1 = P111
        P2 = P222
        P_New = (P1 + P2) / 2
        dif = np.linalg.norm(P_New - P) / np.linalg.norm(P)
        P = P_New
    return P


def disease_updating(S1, S2, P1, P2):
    P = (P1 + P2) / 2
    dif = 1
    while dif > 0.0000001:
        P111 = np.dot(np.dot(S1, P2), S1.T)
        P111 = new_normalization(P111)
        P222 = np.dot(np.dot(S2, P1), S2.T)
        P222 = new_normalization(P222)
        P1 = P111
        P2 = P222
        P_New = (P1 + P2) / 2
        dif = np.linalg.norm(P_New - P) / np.linalg.norm(P)
        P = P_New
    return P


# ====================================================================
def skf_normalization(w):
    row_sum = np.sum(w, axis=0)
    p = (w / row_sum).T
    return p


def skf(A, seq_sim, str_sim, k1, k2):
    disease_sim1 = str_sim
    circRNA_sim1 = seq_sim

    GIP_c_sim = GIP_kernel(A)
    GIP_d_sim = GIP_kernel(A.T)
    # miRNA_sim1 = GIP_m_sim
    m1 = skf_normalization(circRNA_sim1)
    m2 = skf_normalization(GIP_c_sim)

    Sm_1 = KNN_kernel(circRNA_sim1, k1)
    Sm_2 = KNN_kernel(GIP_c_sim, k1)
    Pm = skf_updating(Sm_1, Sm_2, m1, m2, 0.1)
    nei_weight1 = neighborhood_Com(Pm, k1)
    Pm_final = Pm * nei_weight1

    d1 = skf_normalization(disease_sim1)
    d2 = skf_normalization(GIP_d_sim)

    Sd_1 = KNN_kernel(disease_sim1, k2)
    Sd_2 = KNN_kernel(GIP_d_sim, k2)
    Pd = skf_updating(Sd_1, Sd_2, d1, d2, 0.1)
    nei_weight2 = neighborhood_Com(Pd, k2)
    Pd_final = Pd * nei_weight2

    return Pm_final, Pd_final


def skf_updating(S1, S2, P1, P2, alpha):
    P = (P1 + P2) / 2
    dif = 1
    while dif > 0.0000001:
        P111 = alpha * np.dot(np.dot(S1, P2), S1.T) + (1 - alpha) * P2
        P111 = new_normalization(P111)
        P222 = alpha * np.dot(np.dot(S2, P1), S2.T) + (1 - alpha) * P1
        P222 = new_normalization(P222)
        P1 = P111
        P2 = P222
        P_New = (P1 + P2) / 2
        dif = np.linalg.norm(P_New - P) / np.linalg.norm(P)
        P = P_New
    return P


def neighborhood_Com(sim, k):
    weight = np.zeros(sim.shape)

    for i in range(sim.shape[0]):
        iu = sim[i, :]
        iu_list = np.abs(np.sort(-iu))
        iu_nearest_list_end = iu_list[k - 1]
        for j in range(sim.shape[1]):
            ju = sim[:, j]
            ju_list = np.abs(np.sort(-ju))
            ju_nearest_list_end = ju_list[k - 1]

            if sim[i, j] >= iu_nearest_list_end and sim[i, j] >= ju_nearest_list_end:
                weight[i, j] = 1
                weight[j, i] = 1
            elif sim[i, j] < iu_nearest_list_end and sim[i, j] < ju_nearest_list_end:
                weight[i, j] = 0
                weight[j, i] = 0
            else:
                weight[i, j] = 0.5
                weight[j, i] = 0.5

    return weight

In [7]:
def get_all_the_samples(A):
    m,n = A.shape
    pos = []
    neg = []
    for i in range(m):
        for j in range(n):
            if A[i,j] ==1:
                pos.append([i,j,1])
            else:
                neg.append([i,j,0])
    n = len(pos)
    neg_new = random.sample(neg, n)
    tep_samples = pos + neg_new
    samples = random.sample(tep_samples, len(tep_samples))
    samples = random.sample(samples, len(samples))
    samples = np.array(samples)
    return samples

def update_Adjacency_matrix (A, test_samples):
    m = test_samples.shape[0]
    A_tep = A.copy()
    for i in range(m):
        if test_samples[i,2] ==1:
            A_tep [test_samples[i,0], test_samples[i,1]] = 0
    return A_tep

def set_digo_zero(sim, z):
    sim_new = sim.copy()
    n = sim.shape[0]
    for i in range(n):
        sim_new[i][i] = z
    return sim_new

In [8]:
def calculate_evaluation_metrics(pred_mat, pos_edges, neg_edges):
    pos_pred_socres = pred_mat[pos_edges[0], pos_edges[1]]
    neg_pred_socres = pred_mat[neg_edges[0], neg_edges[1]]
    pred_labels = np.hstack((pos_pred_socres, neg_pred_socres))
    true_labels = np.hstack((np.ones(pos_pred_socres.shape[0]), np.zeros(neg_pred_socres.shape[0])))
    return get_metrics_new(true_labels, pred_labels)

def get_metrics_new(real_score, predict_score):
    real_score, predict_score = real_score.flatten(), predict_score.flatten()
    sorted_predict_score = np.array(
        sorted(list(set(np.array(predict_score).flatten()))))
    sorted_predict_score_num = len(sorted_predict_score)
    thresholds = sorted_predict_score[np.int32(
        sorted_predict_score_num*np.arange(1, 1000)/1000)]
    thresholds = np.mat(thresholds)
    thresholds_num = thresholds.shape[1]

    predict_score_matrix = np.tile(predict_score, (thresholds_num, 1))
    negative_index = np.where(predict_score_matrix < thresholds.T)
    positive_index = np.where(predict_score_matrix >= thresholds.T)
    predict_score_matrix[negative_index] = 0
    predict_score_matrix[positive_index] = 1
    TP = predict_score_matrix.dot(real_score.T)
    FP = predict_score_matrix.sum(axis=1)-TP
    FN = real_score.sum()-TP
    TN = len(real_score.T)-TP-FP-FN

    fpr = FP/(FP+TN)
    tpr = TP/(TP+FN)
    ROC_dot_matrix = np.mat(sorted(np.column_stack((fpr, tpr)).tolist())).T
    ROC_dot_matrix.T[0] = [0, 0]
    ROC_dot_matrix = np.c_[ROC_dot_matrix, [1, 1]]

    # np.savetxt(roc_path.format(i), ROC_dot_matrix)

    x_ROC = ROC_dot_matrix[0].T
    y_ROC = ROC_dot_matrix[1].T
    auc = 0.5*(x_ROC[1:]-x_ROC[:-1]).T*(y_ROC[:-1]+y_ROC[1:])

    recall_list = tpr
    precision_list = TP/(TP+FP)
    PR_dot_matrix = np.mat(sorted(np.column_stack(
        (recall_list, precision_list)).tolist())).T
    PR_dot_matrix.T[0] = [0, 1]
    PR_dot_matrix = np.c_[PR_dot_matrix, [1, 0]]

    # np.savetxt(pr_path.format(i), PR_dot_matrix)

    x_PR = PR_dot_matrix[0].T
    y_PR = PR_dot_matrix[1].T
    aupr = 0.5*(x_PR[1:]-x_PR[:-1]).T*(y_PR[:-1]+y_PR[1:])

    f1_score_list = 2*TP/(len(real_score.T)+TP-TN)
    accuracy_list = (TP+TN)/len(real_score.T)
    specificity_list = TN/(TN+FP)
    # plt.plot(x_ROC, y_ROC)
    # plt.plot(x_PR,y_PR)
    # plt.show()
    max_index = np.argmax(f1_score_list)
    f1_score = f1_score_list[max_index]
    accuracy = accuracy_list[max_index]
    specificity = specificity_list[max_index]
    recall = recall_list[max_index]
    precision = precision_list[max_index]
    print( ' auc:{:.4f} ,aupr:{:.4f},f1_score:{:.4f}, accuracy:{:.4f}, recall:{:.4f}, specificity:{:.4f}, precision:{:.4f}'.format( auc[0, 0],aupr[0, 0], f1_score, accuracy, recall, specificity, precision))
    return [auc[0, 0], aupr[0, 0], f1_score, accuracy, recall, specificity, precision]


In [9]:

def get_edge_index_f(matrix, threshold):
    edge_index = [[], []]
    edge_type = []    
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            if matrix[i, j] >= 0.001:                
                edge_index[0].append(i)
                edge_index[1].append(j)
                if (i < threshold) and (j < threshold):                
                    edge_type.append(0)
                elif (i >= threshold) and (j < threshold):                
                    edge_type.append(1)
                elif (i < threshold) and (j >= threshold):                
                    edge_type.append(2)
                else:
                    edge_type.append(3)
    return torch.LongTensor(edge_index), torch.LongTensor(edge_type)

In [10]:

def get_edge_index(matrix, new_A, threshold):
    edge_index = [[], []]
    edge_type = []
    threshold_list = new_A.sum(1)
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            if matrix[i, j] >= 0.001:                
                edge_index[0].append(i)
                edge_index[1].append(j)
                if (threshold_list[i] <= threshold) and (threshold_list[j] <= threshold):                
                    edge_type.append(0)
                elif (threshold_list[i] <= threshold) and (threshold_list[j] > threshold):                
                    edge_type.append(1)
                elif (threshold_list[i] > threshold) and (threshold_list[j] <= threshold):                
                    edge_type.append(2)
                else:
                    edge_type.append(3)
    return torch.LongTensor(edge_index), torch.LongTensor(edge_type)

In [11]:
def get_edge_index_h(matrix, threshold_m, threshold_d):
    edge_index = [[], []]
    edge_type = []
    threshold_list_m = matrix.sum(1)
    threshold_list_d = matrix.sum(0)
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            if matrix[i, j] != 0:                
                edge_index[0].append(i)
                edge_index[1].append(j)
                if (threshold_list_m[i] <= threshold_m) and (threshold_list_d[j] <= threshold_d):                
                    edge_type.append(0)
                elif (threshold_list_m[i] <= threshold_m) and (threshold_list_d[j] > threshold_d):                
                    edge_type.append(1)
                elif (threshold_list_m[i] > threshold_m) and (threshold_list_d[j] <= threshold_d):                
                    edge_type.append(2)
                else:
                    edge_type.append(3)
    return torch.LongTensor(edge_index), torch.LongTensor(edge_type)



In [12]:
class DGATConv(MessagePassing):
    def __init__(self, in_hid, out_hid, 
                 num_edge_types,negative_slope=0.2,dual=True,heads=1,mask=None,global_weight=True):
        super(DGATConv, self).__init__(aggr='add')

        self.in_hid = in_hid
        self.out_hid = out_hid
        self.num_edge_types = num_edge_types
        self.negative_slope=negative_slope
        self.dual=dual
        self.mask=mask
        self.global_weight=global_weight
        
        self.rel_wi=nn.Parameter(torch.Tensor(num_edge_types,out_hid*2,1))

        self.rel_bt=nn.Parameter(torch.Tensor(out_hid*2,1))
        self.w_wi=nn.Linear(in_hid, out_hid, bias=False)
        self.w_bt=nn.Linear(out_hid,out_hid,bias=False)

        self.w_out=nn.Linear(out_hid,out_hid,bias=False)
        self.q_trans=nn.Parameter(torch.Tensor(out_hid,1))

        self.norm=nn.LayerNorm(out_hid)
        self.norm_list=nn.ModuleList()
        for i in range(num_edge_types):
            self.norm_list.append(nn.LayerNorm(out_hid))


        self.skip = nn.Parameter(torch.ones(1))
        self.beta_weight=nn.Parameter(torch.ones(1))
        self.overall_beta=nn.Parameter(torch.randn(num_edge_types))
        # self.drop=Dropout(0.2)

        glorot(self.rel_wi)
        glorot(self.rel_bt)
        glorot(self.q_trans)



    def forward(self, x, edge_idx, edge_type):

        x=self.w_wi(x)
        out_list=[]
        edg_list=[]
        overall_rel=[]
        for i in range(self.num_edge_types):
            mask = (edge_type == i)
            edge_index = edge_idx[:, mask]
            if mask.sum() !=0:
                rs=self.w_bt(F.leaky_relu(self.norm_list[i](self.propagate(edge_index, x=x,edge_type=i)),self.negative_slope))   #Nxd
                out_list+=[rs]
                edg_list+=[i]
     
            
        if self.dual:
            overall_beta=F.softmax(self.overall_beta,dim=0)

            rs_list=[]
            for i in range(len(edg_list)):
                conc=torch.cat((x,out_list[i]),dim=1)                                       #Nx2d
                rs=torch.matmul(conc,self.rel_bt)         #Nx1
                rs_list+=[rs]

            rs=torch.stack(rs_list)                                                         #rxNx1
            beta=F.softmax(rs,dim=0)                                                        #rxNx1
            res=0
            if self.mask:
                for i in self.mask:
                    out_list[i]=torch.zeros_like(out_list[i])
            beta_weight=torch.sigmoid(self.beta_weight)
            for i in range(len(edg_list)):
                if self.global_weight:
                    res+=out_list[i]*((1-beta_weight)*beta[i]+beta_weight*overall_beta[i])
                else:
                    res+=out_list[i]*beta[i]
        else:
            res=0
            for i in range(len(edg_list)):
                res+=out_list[i]
        
        final_weight=torch.sigmoid(self.skip)
        res = self.norm(F.gelu(res) * (final_weight) + x* (1 - final_weight))

        return res


    def message(self,edge_index,x_i, x_j,edge_type):
        
        node_f = torch.cat((x_i, x_j), 1)                                       #nx2d

        temp = torch.matmul(node_f, self.rel_wi[edge_type]).to(x_i.device)      #nx1

        alpha=softmax(temp,edge_index[1])

        rs=x_j*alpha                                                            #nxd
        return rs


sft = torch.nn.Softmax(dim=0)


class HetGATConv(MessagePassing):
    def __init__(self, in_hid, out_hid, negative_slope=0.2,norm=True,dual=True,global_weight=True):
        super(HetGATConv, self).__init__(aggr='add')

        self.in_hid = in_hid
        self.out_hid = out_hid
        self.negative_slope=negative_slope
        self.norm=norm
        self.dual=dual
        self.global_weight=global_weight

        
        self.rel_wi=nn.Parameter(torch.Tensor(4,out_hid*2,1))
        self.rel_bt=nn.Parameter(torch.Tensor(out_hid*2,1))
        self.w_bt=nn.Linear(out_hid,out_hid,bias=False)
        self.w_out=nn.Linear(out_hid,out_hid,bias=False)

        self.out_norm=nn.LayerNorm(out_hid)

        self.skip = nn.Parameter(torch.ones(1))

        glorot(self.rel_wi)
        glorot(self.rel_bt)
        

    def forward(self, a_hid,p_hid, edge_idx, edge_type):

        xi=p_hid[edge_idx[1]]
        out_list=[]
        num_edge_types=4
        edg_list=[]
        for i in range(num_edge_types):
            mask = (edge_type == i)
            edge_index = edge_idx[:, mask]
            if mask.sum() !=0:
                rs=self.w_bt(F.leaky_relu(self.propagate(edge_index, x=(a_hid,p_hid),edge_type=i),self.negative_slope))   #Nxd
                out_list+=[rs]
                edg_list+=[i]
            
        if self.dual:
            rs_list=[]
            for i in range(len(edg_list)):
                conc=torch.cat((p_hid,out_list[i]),dim=1)                                       #Nx2d
                rs=torch.matmul(conc,self.rel_bt)                                             #Nx1
                rs_list+=[rs]

            rs=torch.stack(rs_list)                                                         #Nxr
            beta=F.softmax(rs,dim=0)                                                        #Nxr
            res=0
            for i in range(len(edg_list)):
                res+=out_list[i]*beta[i]
        else:
            res=0
            for i in range(len(edg_list)):
                res+=out_list[i]
        final_weight=torch.sigmoid(self.skip)
        res = self.out_norm(F.gelu(res)* (final_weight) + p_hid* (1 - final_weight))

        return res


    def message(self,edge_index,x_i, x_j,edge_type):
        
        node_f = torch.cat((x_i, x_j), 1)                                       #nx2d

        temp = torch.matmul(node_f, self.rel_wi[edge_type]).to(x_i.device)      #nx1

        alpha=softmax(temp,edge_index[1])

        rs=x_j*alpha                                                            #nxd
        return rs




In [13]:

def laplacian(kernel):
    d1 = sum(kernel)
    D_1 = torch.diag(d1)
    #I_M = t.diag(t.ones(len(d1)))
    L_D_1 = D_1 - kernel# + I_M
    D_5 = D_1.rsqrt()
    D_5 = torch.where(torch.isinf(D_5), torch.full_like(D_5, 0), D_5)
    L_D_11 = torch.mm(D_5, L_D_1)
    L_D_11 = torch.mm(L_D_11, D_5)
    return L_D_11


def normalized_embedding(embeddings):
    #[row, col] = embeddings.size()
    [row, col] = embeddings.shape
    ne = torch.zeros([row, col])
    for i in range(row):
        if (max(embeddings[i, :]) - min(embeddings[i, :])) != 0:
            ne[i, :] = (embeddings[i, :] - min(embeddings[i, :])) / (max(embeddings[i, :]) - min(embeddings[i, :]))
        else:
            ne[i, :] = (embeddings[i, :] - min(embeddings[i, :]))
        #ne[i, :] = (embeddings[i, :] - t.mean(embeddings[i, :])) / (t.std(embeddings[i, :]))
    return ne


def getGipKernel(y, trans, gamma, normalized=False):
    if trans:
        y = y.T
    if normalized:
        y = normalized_embedding(y)
    krnl = torch.mm(y, y.T)
    krnl = krnl / torch.mean(torch.diag(krnl))
    krnl = torch.exp(-kernelToDistance(krnl) * gamma)
    #krnl = cosine_kernel(krnl, krnl)
    return krnl


def kernelToDistance(k):
    di = torch.diag(k).T
    d = di.repeat(len(k)).reshape(len(k), len(k)).T + di.repeat(len(k)).reshape(len(k), len(k)) - 2 * k
    return d


def cosine_kernel(tensor_1, tensor_2):
    return torch.DoubleTensor([torch.cosine_similarity(tensor_1[i], tensor_2, dim=-1).tolist() for i in
                           range(tensor_1.shape[0])])


def normalized_kernel(K):
    K = abs(K)
    k = K.flatten().sort()[0]
    min_v = k[torch.nonzero(k, as_tuple=False)[0]]
    K[torch.where(K == 0)] = min_v
    D = torch.diag(K)
    D = D.sqrt()
    S = K / (D * D.T)
    return S

In [14]:
def construct_het_mat(rna_dis_mat, dis_mat, rna_mat):
    mat1 = np.hstack((rna_mat, rna_dis_mat))
    mat2 = np.hstack((rna_dis_mat.T, dis_mat))
    ret = np.vstack((mat1, mat2))
    return ret

def construct_adj_mat(training_mask):
    adj_tmp = training_mask.copy()
    rna_mat = np.zeros((training_mask.shape[0], training_mask.shape[0]))
    dis_mat = np.zeros((training_mask.shape[1], training_mask.shape[1]))

    mat1 = np.hstack((rna_mat, adj_tmp))
    mat2 = np.hstack((adj_tmp.T, dis_mat))
    ret = np.vstack((mat1, mat2))
    return ret

In [15]:
class Mylossw(nn.Module):
    def __init__(self):
        super(Mylossw, self).__init__()
        
    def forward(self, target, prediction, miRNA_lap, dis_lap, alpha1, alpha2, phi1, phi2):
        
        loss_ls = torch.norm((target - prediction), p='fro') ** 2        
        loss_ls = loss_ls.sum() 
        miRNA_reg = torch.trace(torch.mm(torch.mm(alpha1.T, miRNA_lap), alpha1))
        dis_reg = torch.trace(torch.mm(torch.mm(alpha2.T, dis_lap), alpha2))
        graph_reg = phi1 * miRNA_reg + phi2 * dis_reg

        loss_sum = loss_ls + graph_reg

        return loss_sum.sum()


In [16]:
class HRGATConv(nn.Module):
    def __init__(self,in_hid_1, in_hid_2, out_hid,num_m1,num_m2,conv_name="hrgat",n_heads=8,n_layers=2,
                 dropout=0.2,norm=True,hgt_layer=2, feature_MFm = None, feature_MFd = None, sim_m = None, 
                 sim_d = None,  gamma = 1/128, phi1 = 1/256, phi2 = 1/256, **kwargs):
        super(HRGATConv,self).__init__()
        self.conv_name=conv_name
        self.hetgat=nn.ModuleList()
        self.layer=n_layers
        self.hgt_layer = hgt_layer

        self.feature_MFm = torch.Tensor(feature_MFm)
        self.feature_MFd = torch.Tensor(feature_MFd)
        self.sim_m, self.sim_d = torch.Tensor(sim_m), torch.Tensor(sim_d)
        self.miRNA_size, self.dis_size = sim_m.shape[0], sim_d.shape[0]
        
        self.gamma = gamma
        self.phi1, self.phi2 = phi1, phi2

        self.miRNA_l = []
        self.dis_l = []

        self.miRNA_k = []
        self.dis_k = []
        
        self.alpha1 = torch.randn(self.miRNA_size, self.dis_size).double()
        self.alpha2 = torch.randn(self.dis_size, self.miRNA_size).double()        

        self.hgt=nn.ModuleList()
        self.norm=nn.LayerNorm(out_hid)
        self.drop=Dropout(dropout)
        self.in_hid = 32
        self.proj_a=gLinear(in_hid_2,out_hid,weight_initializer="kaiming_uniform", bias=True)
        self.proj_p=gLinear(in_hid_1,out_hid,weight_initializer="kaiming_uniform", bias=True)
        
        self.CNN_drug = nn.Conv2d(in_channels = hgt_layer + n_layers,
                       out_channels=128,
                       kernel_size=(16, 1),
                       stride=1,
                       bias=True)
        self.CNN_cir = nn.Conv2d(in_channels = hgt_layer + n_layers,
                      out_channels=128,
                      kernel_size=(16, 1),
                      stride=1,
                      bias=True) 
        

        for _ in range(hgt_layer):
            if _ == 0:
                if self.conv_name == "rgcn":
                    self.hgt.append(RGCNConv(in_hid_1, out_hid, 4))
                    self.hgt.append(RGCNConv(in_hid_2, out_hid, 4))
                elif self.conv_name == "dgat":
                    self.hgt.append(DGATConv(in_hid_1, out_hid, num_m1,heads=n_heads))
                    self.hgt.append(DGATConv(in_hid_2, out_hid, num_m2,heads=n_heads))
            else:
                if self.conv_name == "rgcn":
                    self.hgt.append(RGCNConv(out_hid, out_hid, 4))
                    self.hgt.append(RGCNConv(out_hid, out_hid, 4))
                elif self.conv_name == "dgat":
                    self.hgt.append(DGATConv(out_hid, out_hid, num_m1,heads=n_heads))
                    self.hgt.append(DGATConv(out_hid, out_hid, num_m2,heads=n_heads))

        if self.conv_name == "dhan1":
            for n in range(n_layers):
                self.hetgat.append(HetGATConv(out_hid, out_hid,dual=False))
                self.hetgat.append(HetGATConv(out_hid, out_hid,dual=False))
        else:
            for n in range(n_layers):
                self.hetgat.append(HetGATConv(out_hid, out_hid))
                self.hetgat.append(HetGATConv(out_hid, out_hid))
    # _gh=(node_feature, edge_index, edge_type, id_list)
        self.lenw = int(len(self.hgt)/2+len(self.hetgat)/2+2)
        self.weight_mlp = (torch.ones(2, self.lenw)/self.lenw).double()
        
        self.params1 = list(self.hgt.parameters())
        self.params2 = list(self.hetgat.parameters())
        #self.params3 = list(self.proj_a.parameters())
        #self.params4 = list(self.proj_p.parameters()) 

    def forward(self,m_f, d_f, edge_index_m,edge_index_d, edge_index_h, edge_type_m, edge_type_d, edge_type_h, device):

        miRNA_kernels = []
        dis_kernels = []
        tem_m = []
        tem_d = []
        
        miRNA_kernels.append(self.sim_m)
        dis_kernels.append(self.sim_d)         
                        
        miRNA_kernels.append(torch.DoubleTensor(getGipKernel(self.feature_MFm, 0, self.gamma, True).double()))
        dis_kernels.append(torch.DoubleTensor(getGipKernel(self.feature_MFd, 0, self.gamma, True).double()))

        #h_p_cnn = self.proj_p(m_f.to(device))
        #h_a_cnn = self.proj_a(d_f.to(device))

        h_p=m_f.to(device)
        h_a=d_f.to(device)
        for hl in range(int((len(self.hgt)/2))):
            if hl==0:
                h_p=self.hgt[2*hl](h_p, edge_index_m.to(device), edge_type_m.to(device))
                h_a=self.hgt[2*hl+1](h_a, edge_index_d.to(device), edge_type_d.to(device))
                miRNA_kernels.append(torch.DoubleTensor(getGipKernel(h_p, 0, self.gamma, True).double()))
                dis_kernels.append(torch.DoubleTensor(getGipKernel(h_a, 0, self.gamma, True).double()))
                tem_m.append(h_p)
                tem_d.append(h_a)
            else:
                h_p = self.hgt[2 * hl](h_p.to(device), edge_index_m.to(device), edge_type_m.to(device))
                h_a = self.hgt[2 * hl + 1](h_a.to(device), edge_index_d.to(device), edge_type_d.to(device))
                miRNA_kernels.append(torch.DoubleTensor(getGipKernel(h_p, 0, self.gamma, True).double()))
                dis_kernels.append(torch.DoubleTensor(getGipKernel(h_a, 0, self.gamma, True).double()))
                tem_m.append(h_p)
                tem_d.append(h_a)

        #h_p_cnn = self.proj_p(m_f.to(device))
        #h_a_cnn = self.proj_a(d_f.to(device))
        for ly in range(int(len(self.hetgat)/2)):
            edge_index_h = torch.stack((edge_index_h[1],edge_index_h[0]))
            p_hid = self.hetgat[2*ly](h_a, h_p, edge_index_h.to(device), edge_type_h.to(device))
            miRNA_kernels.append(torch.DoubleTensor(getGipKernel(p_hid, 0, self.gamma, True).double()))

            edge_index_h = torch.stack((edge_index_h[1],edge_index_h[0]))
            a_hid = self.hetgat[2*ly+1](h_p, h_a, edge_index_h.to(device), edge_type_h.to(device))
            dis_kernels.append(torch.DoubleTensor(getGipKernel(a_hid, 0, self.gamma, True).double()))
            tem_m.append(p_hid)
            tem_d.append(a_hid)

            #edge_indx_h = torch.stack((edge_index_h[1], edge_index_h[0]))
            h_a=a_hid
            h_p=p_hid
        #h_a=self.drop(self.norm(h_a))
        #h_p=self.drop(self.norm(h_p))
        #A_pre = h_p@h_a.T
        #cnn_embd_cir = h_p_cnn
        """
        cnn_embd_cir = tem_m[0]
        for i in range(1, len(tem_m)):            
            cnn_embd_cir = torch.cat((cnn_embd_cir, tem_m[i]), 1)
        cnn_embd_cir = cnn_embd_cir.t().view(1, len(tem_m), 16, self.miRNA_size)
        cnn_embd_cir = self.CNN_cir(cnn_embd_cir)
        cnn_embd_cir = cnn_embd_cir.view(128, self.miRNA_size).t()
        h_p = cnn_embd_cir
        
        cnn_embd_drug = tem_d[0]       
        for i in range(1, len(tem_d)):            
            cnn_embd_drug = torch.cat((cnn_embd_drug, tem_d[i]), 1)
        cnn_embd_drug = cnn_embd_drug.t().view(1, len(tem_d), 16, self.dis_size)
        cnn_embd_drug = self.CNN_drug(cnn_embd_drug)
        cnn_embd_drug = cnn_embd_drug.view(128, self.dis_size).t()
        h_a = cnn_embd_drug
        """
        #miRNA_kernels.append(torch.DoubleTensor(getGipKernel(h_p, 0, self.gamma, True).double()))
        #dis_kernels.append(torch.DoubleTensor(getGipKernel(h_a, 0, self.gamma, True).double()))
        
        miRNA_k = sum([(1/len(miRNA_kernels)) * miRNA_kernels[i] for i in range(len(miRNA_kernels))])
        self.miRNA_k = normalized_kernel(miRNA_k)
        dis_k = sum([(1/len(dis_kernels)) * dis_kernels[i] for i in range(len(dis_kernels))])        
        self.dis_k = normalized_kernel(dis_k)
        
        self.miRNA_l = laplacian(miRNA_k)
        self.dis_l = laplacian(dis_k)

        out1 = torch.mm(self.miRNA_k, self.alpha1)
        out2 = torch.mm(self.dis_k, self.alpha2)

        #out = (out1 + out2.T + h_p@h_a.T) / 3 
        out = (out1 + out2.T) / 2

        #return A_pre
        return out, miRNA_kernels, dis_kernels
        #return h_p@h_a.T, 0, 0

In [17]:
def k_matrix(matrix, k=20):
    num = matrix.shape[0]
    knn_graph = np.zeros(matrix.shape)
    idx_sort = np.argsort(-(matrix - np.eye(num)), axis=1)
    for i in range(num):
        #将第i行最大的前k个值赋值给knn_graph(确保是对称矩阵)
        knn_graph[i, idx_sort[i, :k + 1]] = matrix[i, idx_sort[i, :k + 1]]
        knn_graph[idx_sort[i, :k + 1], i] = matrix[idx_sort[i, :k + 1], i]
    return knn_graph + np.eye(num)


In [18]:

def cross_validation_experiment_3(edge_idx_dict, A, k_fold = 5, k1 = 27, k2 = 22, seq_sim_matrix = None, str_sim_matrix = None,
                  lr = 0.05, weight_decay = 0.0005, threshold_m = 15, threshold_d = 19,
                  dropout = 0.1, device = device, hgt_layer = 1, k = 25, epoch = 50,
                  n_layers = 1, n_heads = 1, conv_name = 'dgat', num_m1 = 4, num_m2 = 4, out_hid = 16,
                  gamma = 1/128, phi1 = 1/100, phi2 = 1/100):
    #A = miRNA_dis_matrix
        
    metric = np.zeros((1, 7))
    pre_matrix = np.zeros(A.shape)
    
    pos_edges = edge_idx_dict['pos_edges']
    neg_edges = edge_idx_dict['neg_edges']
    idx = np.arange(pos_edges.shape[1])
    np.random.shuffle(idx)
    idx_splited = np.array_split(idx, k_fold)        
       
    
    for i in range(k_fold):
        print("------this is %dth cross validation------" % (i + 1))
        tmp = []
        for j in range(1, k_fold):
            tmp.append(idx_splited[(j + i) % k_fold])
        tmp = np.concatenate(tmp)
        training_pos_edges = pos_edges[:, tmp]
        print("training_pos_edges.shape: {}".format(training_pos_edges.shape))
        training_neg_edges = neg_edges[:, tmp]
        test_pos_edges = pos_edges[:, idx_splited[i]]
        test_neg_edges = neg_edges[:, idx_splited[i]]
        temp_drug_dis = np.zeros((A.shape[0], A.shape[1]))
        temp_drug_dis[training_pos_edges[0], training_pos_edges[1]] = 1        
        new_A = temp_drug_dis
        sim_d, sim_m = skf(new_A, str_sim_matrix, seq_sim_matrix, k2, k1)
        sim_m_0 = set_digo_zero(sim_m, 0)
        sim_d_0 = set_digo_zero(sim_d, 0)        
        
        m_adj = k_matrix(sim_m, k = k)
        d_adj = k_matrix(sim_d, k = k)
        m_adj_0 = set_digo_zero(m_adj, 1)
        d_adj_0 = set_digo_zero(d_adj, 1)
        
        #trainlu = np.hstack((training_neg_edges, training_pos_edges))
        #testlu = np.hstack((test_neg_edges, test_pos_edges))
        #trainlu = np.vstack((trainlu[1], trainlu[0]))
        #testlu = np.vstack((testlu[1], testlu[0]))
                

        feature_MFd1, feature_MFm1 = d_adj@new_A, m_adj@new_A.T
        #feature_MFd1, feature_MFm1 = torch.Tensor(d_adj)@feature_d_Z.detach(), torch.Tensor(m_adj)@feature_m_Z.detach()
        feature_MFm1, feature_MFd1 = torch.Tensor(feature_MFm1),torch.Tensor(feature_MFd1)
        #feature_MFd1, feature_MFm1 = feature_d_Z.detach(), feature_m_Z.detach()
        
        gnn = None
        
        gnn = HRGATConv(in_hid_1 = feature_MFm1.shape[1], in_hid_2 = feature_MFd1.shape[1], out_hid = out_hid, num_m1 = num_m1, 
                        num_m2 = num_m2, conv_name = conv_name, n_heads= n_heads, n_layers = n_layers, 
                        dropout = dropout, hgt_layer = hgt_layer, feature_MFm = feature_MFm1, feature_MFd = feature_MFd1, 
                        sim_m = sim_m, sim_d = sim_d, gamma = gamma, phi1 = phi1, phi2 = phi2).to(device)


        
        optimizer = optim.Adam([{'params':gnn.params1, 'weight_decay':0.001},{'params':gnn.params2, 'weight_decay':0.001},], lr = lr)
        #optimizer = optim.Adam(gnn.parameters(), weight_decay = weight_decay, lr = lr)
        
        
        regression_critgnn = Mylossw()
        #regression_critMLP = nn.BCEWithLogitsLoss()
        

        edge_index_m, edge_type_m = get_edge_index(m_adj, new_A.T, threshold = threshold_m)
        edge_index_d, edge_type_d = get_edge_index(d_adj, new_A, threshold = threshold_d)
        
        edge_index_h, edge_type_h = get_edge_index_h(new_A.T, threshold_m = threshold_m, threshold_d = threshold_d)

        new_A = new_A.T
        for i in range(epoch):
            gnn.zero_grad()    
            
            PRE, m_k, d_k = gnn(feature_MFm1, feature_MFd1, edge_index_m,edge_index_d, edge_index_h, 
                  edge_type_m, edge_type_d, edge_type_h, device)      
            
            #"""
            gnn.alpha1 = torch.mm(
            torch.mm((torch.mm(gnn.miRNA_k, gnn.miRNA_k) + gnn.phi1 * gnn.miRNA_l).inverse(), 
                 gnn.miRNA_k),2 * torch.Tensor(new_A) - torch.mm(gnn.alpha2.T, gnn.dis_k.T)).detach()
            gnn.alpha2 = torch.mm(torch.mm((torch.mm(gnn.dis_k, gnn.dis_k) + gnn.phi2 * gnn.dis_l).inverse(), gnn.dis_k),
                  2 * torch.Tensor(new_A).T - torch.mm(gnn.alpha1.T, gnn.miRNA_k.T)).detach()
    
    
    
            loss = regression_critgnn(torch.Tensor(new_A), PRE, gnn.miRNA_l, gnn.dis_l, gnn.alpha1,
                           gnn.alpha2, gnn.phi1, gnn.phi2)
            #"""
            #loss = regression_critMLP(PRE[tuple(np.array(trainlu))], torch.Tensor(new_A[tuple(np.array(trainlu))]))
                    
            loss = loss.requires_grad_()
            
            loss.backward()
            optimizer.step()
            #scheduler.step()
            if i%100 == 0:
                print(loss)


        new_A = new_A.T
        gnn.eval()
        
        PRE_test, m_k_t, d_k_t = gnn(feature_MFm1, feature_MFd1, edge_index_m,edge_index_d, edge_index_h, 
                  edge_type_m, edge_type_d, edge_type_h, device) 

        #INN0 = PRE_test.detach().numpy()
        metric_tmp = calculate_evaluation_metrics(PRE_test.detach().T, test_pos_edges, test_neg_edges)

        print(metric_tmp)
        metric = metric + metric_tmp
        gc.collect()

    print("{} coval".format(k_fold))
    print("auc[0, 0], aupr[0, 0], f1_score, accuracy, recall, specificity, precision")
    print(metric / k_fold)
    metric = np.array(metric / k_fold)
    return metric



In [19]:
k1 = 27
k2 = 22
drug_sim, cir_sim, edge_idx_dict, drug_cir_matrix = load_data()
A = drug_cir_matrix

In [94]:
result = cross_validation_experiment_3(edge_idx_dict, A, k_fold = 5, k1 = 25, k2 = 25, seq_sim_matrix = cir_sim, 
                  str_sim_matrix = drug_sim,
                  lr = 0.05, weight_decay = 0.01, threshold_m = 18, threshold_d = 37,
                  dropout = 0.026, device = device, hgt_layer = 1, k = 25, epoch = 40,
                  n_layers = 1, n_heads = 5, conv_name = 'dgat', num_m1 = 4, num_m2 = 4, out_hid = 16,
                  gamma = 1/75, phi1 = 1/120, phi2 = 1/120)

------this is 1th cross validation------
training_pos_edges.shape: (2, 3307)
tensor(6945137.8885, dtype=torch.float64, grad_fn=<SumBackward0>)
 auc:0.9309 ,aupr:0.9362,f1_score:0.8618, accuracy:0.8561, recall:0.8972, specificity:0.8150, precision:0.8291
[0.9308700172093883, 0.9361883821468933, 0.8617886178861789, 0.8561064087061668, 0.8972188633615478, 0.814993954050786, 0.829050279329609]
------this is 2th cross validation------
training_pos_edges.shape: (2, 3307)
tensor(7048559.8945, dtype=torch.float64, grad_fn=<SumBackward0>)
 auc:0.9014 ,aupr:0.9135,f1_score:0.8330, accuracy:0.8349, recall:0.8235, specificity:0.8464, precision:0.8428
[0.9013501401461261, 0.9134607314176039, 0.8330275229357799, 0.8349455864570737, 0.8234582829504232, 0.8464328899637243, 0.8428217821782178]
------this is 3th cross validation------
training_pos_edges.shape: (2, 3307)
tensor(7368466.3480, dtype=torch.float64, grad_fn=<SumBackward0>)
 auc:0.9147 ,aupr:0.9219,f1_score:0.8453, accuracy:0.8416, recall:0.8

In [None]:
"""
lr = 0.05, weight_decay = 0.01, threshold_m = 15, threshold_d = 19,dropout = 0.02, device = device, hgt_layer = 1, 
k = 25, epoch = 50,n_layers = 1, n_heads = 1, conv_name = 'dgat', num_m1 = 4, num_m2 = 4, out_hid = 16,
gamma = 1/128, phi1 = 1/100, phi2 = 1/100
[[0.91079151 0.91802955 0.84281533 0.84300558 0.84083373 0.84517744  0.84550056]]

lr = 0.05, weight_decay = 0.01, threshold_m = 15, threshold_d = 19,dropout = 0.02, device = device, hgt_layer = 1, 
k = 25, epoch = 50,n_layers = 1, n_heads = 1, conv_name = 'dgat', num_m1 = 4, num_m2 = 4, out_hid = 16,
gamma = 1/128, phi1 = 1/100, phi2 = 1/100  set_digo_zero(d_adj, 1)
[[0.90982799 0.91853294 0.84172715 0.84167972 0.84155338 0.84180606  0.84230898]]

lr = 0.05, weight_decay = 0.01, threshold_m = 15, threshold_d = 19,dropout = 0.02, device = device, hgt_layer = 1, 
k = 25, epoch = 50,n_layers = 1, n_heads = 1, conv_name = 'dgat', num_m1 = 4, num_m2 = 4, out_hid = 16,
gamma = 1/128, phi1 = 1/100, phi2 = 1/100  set_digo_zero(d_adj, 1)
[[0.91104823 0.91897384 0.84544202 0.84313045 0.85824225 0.82801866  0.83341387]]

k1 = 25, k2 = 27, lr = 0.05, weight_decay = 0.01, threshold_m = 25, threshold_d = 36, dropout = 0.02, hgt_layer = 1, 
k = 25, epoch = 60, n_layers = 1, n_heads = 5, conv_name = 'dgat', num_m1 = 4, num_m2 = 4, out_hid = 16,
gamma = 1/75, phi1 = 1/100, phi2 = 1/100





"""