In [119]:
import time
import numpy as np
import csv
from numpy import genfromtxt, interp
import pandas as pd
import random

import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.model_selection import KFold
from scipy import interpolate


import torch
import torch.nn as nn
import dgl.function as fn
import torch.nn.functional as F
from torch import tensor
import torch_geometric.utils
from torch_geometric.nn.conv import MessagePassing
import dgl

import os

In [121]:
def nSGCConv(graph, feats, order):
    """Parameter-free SGC-style propagation that keeps the output of every hop.

    Each hop computes ``h_v <- sum over in-edges (u -> v) of norm[u] * norm[v] * h_u``,
    where ``norm = in_degree ** -0.5`` with in-degrees clamped to at least 1.
    The input features and all ``order`` propagated versions are concatenated
    along dim 1, so the result has ``feats.shape[1] * (order + 1)`` columns.

    Args:
        graph: DGL graph to propagate over (its node/edge data is not modified,
            thanks to ``local_scope``).
        feats: node feature tensor, one row per node.
        order: number of propagation hops.
    """
    with graph.local_scope():
        in_deg = graph.in_degrees().float().clamp(min=1)
        graph.ndata['norm'] = in_deg.pow(-0.5).to(feats.device).unsqueeze(1)
        # Per-edge weight = norm[src] * norm[dst].
        graph.apply_edges(fn.u_mul_v('norm', 'norm', 'weight'))

        hops = [0 + feats]
        current = feats
        for _ in range(order):
            graph.ndata['h'] = current
            graph.update_all(fn.u_mul_e('h', 'weight', 'm'), fn.sum('m', 'h'))
            current = graph.ndata.pop('h')
            hops.append(current)
    return torch.cat(hops, dim=1)



def pretreatment(graph, feats):
    """Return ``(edge_index, edge_weight)`` for symmetric GCN normalisation.

    ``edge_index`` stacks the graph's (src, dst) edge lists into a ``2 x E``
    tensor. Edge (u, v) gets weight ``deg(u) ** -0.5 * deg(v) ** -0.5``. Here
    ``deg`` is the in-degree (count of edges ending at each node) over
    ``feats.size(0)`` nodes, computed in ``feats.dtype``.

    Nodes with zero in-degree get a factor of 0 instead of ``inf``. Without
    this, a source node that never receives edges would give its outgoing
    edges infinite weights. This matches PyG's ``gcn_norm``.
    """
    with graph.local_scope():
        row, col = graph.edges()
        edge_index = torch.vstack((row, col))
        # Same computation as torch_geometric.utils.degree(col, N, dtype).
        deg = torch.zeros(feats.size(0), dtype=feats.dtype, device=col.device)
        deg.scatter_add_(0, col, torch.ones(col.size(0), dtype=feats.dtype, device=col.device))
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt = deg_inv_sqrt.masked_fill(torch.isinf(deg_inv_sqrt), 0.0)
        edge_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col]

    return edge_index, edge_weight


class nSGCN(MessagePassing):
    """One hop of weighted message passing with message-level dropout.

    ``forward`` obtains normalised edge weights from ``pretreatment`` and
    propagates ``feats`` once using the default (sum) aggregation of
    ``MessagePassing``. In training mode each message is passed through
    ``F.dropout`` with the given rate. The layer has no learnable parameters.
    The constructor flags and the ``order`` argument of ``forward`` are
    accepted but unused.
    """

    def __init__(self, add_self_loops: bool = True, normalize: bool = True):
        super().__init__()
        # Edge weights from the latest forward() call; message() reads them.
        self.edge_weight = None

    def reset_parameters(self):
        # Nothing to reset: the layer holds no parameters.
        return None

    def forward(self, graph, feats, drop_rate, order):
        edge_index, weights = pretreatment(graph, feats)
        self.edge_weight = weights
        return self.propagate(edge_index=edge_index, size=None, x=feats, drop_rate=drop_rate)

    # NOTE: the parameter names below are read by MessagePassing.propagate and must stay as-is.
    def message(self, x_j, drop_rate: float):
        weights = self.edge_weight
        scaled = x_j if weights is None else x_j * weights.view(-1, 1)
        if self.training:
            scaled = F.dropout(scaled, drop_rate)
        return scaled


class nSGC(nn.Module):
    """Disease–miRNA link predictor built on parameter-free graph propagation.

    Steps:
    1. Disease and miRNA similarity rows stored on ``G`` (``ndata['d_sim']`` and
       ``ndata['m_sim']``) are projected to ``hid_dim`` by bias-free linear layers.
    2. The projected features are propagated over the graph passed to ``forward``.
    3. The features from every hop are concatenated and mapped to ``out_dim``.
    4. Each node pair is scored with a sigmoid over ``[h_a, h_b, h_a * h_b]``.

    ``n_class``, ``batchnorm``, ``d_sim_dim``, ``m_sim_dim``, ``slope``,
    ``input_droprate`` and ``hidden_droprate`` are accepted for interface
    compatibility but do not affect the computation.
    """

    def __init__(self, G, hid_dim, n_class, K, batchnorm, num_diseases, num_mirnas,
                 d_sim_dim, m_sim_dim, out_dim, dropout, slope, node_dropout=0.5, input_droprate=0.0,
                 hidden_droprate=0.0):
        super(nSGC, self).__init__()
        self.G = G
        self.hid_dim = hid_dim

        self.K = K
        self.n_class = n_class
        self.num_diseases = num_diseases
        self.num_mirnas = num_mirnas
        # Drop probability applied to messages by the training-branch propagation layers.
        self.node_dropout = node_dropout
        # Node 'type' 1 marks diseases and 0 marks miRNAs (set that way by build_graph).
        self.disease_nodes = G.filter_nodes(lambda nodes: nodes.data['type'] == 1)
        self.mirna_nodes = G.filter_nodes(lambda nodes: nodes.data['type'] == 0)

        self.m_fc = nn.Linear(G.ndata['m_sim'].shape[1], hid_dim, bias=False)
        self.d_fc = nn.Linear(G.ndata['d_sim'].shape[1], hid_dim, bias=False)
        # Every hop is hid_dim wide: the projections above output hid_dim and propagation
        # preserves width. So the concatenated input scales with hid_dim, not out_dim.
        # This is identical to the previous sizing whenever hid_dim == out_dim.
        self.f_fc = nn.Linear(hid_dim * (K + 1), out_dim)   # eval branch: K + 1 hops
        self.f_fc2 = nn.Linear(hid_dim * K, out_dim)        # training branch: K hops
        # NOTE(review): dropout1, predict and predict_onlycross are never used in forward();
        # they are kept (in the original creation order) so saved state dicts still load.
        self.dropout1 = nn.Dropout(dropout)
        self.predict = nn.Linear(out_dim * 2, 1)
        self.predict_onlycross = nn.Linear(out_dim, 1)

        self.predict_addcross = nn.Linear(out_dim * 3, 1)

        self.backbone = nSGCN(True, True)

    def _score_pairs(self, h, diseases, mirnas):
        """Sigmoid link score from the concatenated endpoint embeddings and their product."""
        h_diseases = h[diseases]
        h_mirnas = h[mirnas]
        h_edge = torch.cat((h_diseases, h_mirnas, h_diseases * h_mirnas), 1)
        return torch.sigmoid(self.predict_addcross(h_edge))

    def forward(self, graph, diseases, mirnas, training=True):
        """Score the node pairs ``(diseases[i], mirnas[i])``.

        Args:
            graph: graph used for message passing. Node ids must match ``self.G``,
                because the input features are taken from ``self.G``.
            diseases, mirnas: index tensors for the first and second endpoint of
                each pair, of equal length.
            training: selects the branch.
                - True: an nSGCN stack of ``K - 1`` layers plus the input features
                  (K hops), reduced by ``f_fc2``.
                - False: ``nSGCConv`` with ``K`` steps (K + 1 hops), reduced by
                  ``f_fc``.
                Message dropout itself follows ``self.training`` (i.e.
                ``model.train()`` / ``model.eval()``).

        Returns:
            Tensor of shape ``(len(diseases), 1)`` with probabilities.
        """
        self.G.apply_nodes(lambda nodes: {'z': self.d_fc(nodes.data['d_sim'])}, self.disease_nodes)
        self.G.apply_nodes(lambda nodes: {'z': self.m_fc(nodes.data['m_sim'])}, self.mirna_nodes)
        feats = self.G.ndata.pop('z')  # (num_nodes, hid_dim)

        if training:
            # K hops in total: the projected input plus K - 1 propagation layers
            # (the same K - 1 layers as before for K >= 2; K = 1 now uses the input alone).
            hops = [feats]
            x = feats
            for _ in range(self.K - 1):
                x = F.relu(self.backbone(graph, x, self.node_dropout, 1))
                hops.append(x)
            h = self.f_fc2(torch.cat(hops, dim=1))
        else:
            h = self.f_fc(nSGCConv(graph, feats, self.K))
        return self._score_pairs(h, diseases, mirnas)


In [122]:
# Suffix appended to the mean-metrics filename written by Train
# (cv/metrics/mean/mean_metrics<parameter>.txt).
parameter = 'FINAL'

_ASSOCIATION_COLUMNS = ['miRNA', 'disease', 'label']


def _integrated_similarities(directory):
    """Load the similarity matrices under ``directory`` and build the node features.

    Returns:
        ID: integrated disease similarity. It is the average of the two semantic
            matrices (diagonals forced to 1), fused with the Gaussian disease
            similarity.
        IM: miRNA features. These are the sequence, functional and fused family
            similarities concatenated column-wise.
    """
    D_SSM1 = np.loadtxt(directory + '/disease_sim/SemSim1.txt')
    D_SSM2 = np.loadtxt(directory + '/disease_sim/SemSim2.txt')
    # NOTE(review): 'diseaseSim' differs from the 'disease_sim' folder used above; kept as-is
    # because it must match the on-disk layout — confirm.
    D_GSM = np.loadtxt(directory + '/diseaseSim/DG.txt')
    np.fill_diagonal(D_SSM1, 1)
    np.fill_diagonal(D_SSM2, 1)

    M_FSM = np.loadtxt(directory + '/miRNA_sim/FuncSim.txt')
    M_SeSM = np.loadtxt(directory + '/miRNA_sim/SeqSim.txt')
    M_Fam = np.loadtxt(directory + '/miRNA_sim/Famsim.txt')
    M_GSM = np.loadtxt(directory + '/miRNA_sim/MG.txt')

    D_SSM = (D_SSM1 + D_SSM2) / 2
    # Where the semantic similarity is 0, fall back to the Gaussian one; otherwise average them.
    ID = np.where(D_SSM == 0, D_GSM, (D_GSM + D_SSM) / 2)

    # Family similarity: average with the Gaussian one where the family score is 1,
    # otherwise use the Gaussian similarity alone.
    M_FmSM = np.where(M_Fam == 1, (M_Fam + M_GSM) / 2, M_GSM)

    # Feature variant #7 of the experiments (sequence + functional + family). Earlier variants
    # used a single matrix or a pair of these.
    IM = np.concatenate((M_SeSM, M_FSM, M_FmSM), axis=1)
    return ID, IM


def load_data(directory, random_seed):
    """Load features plus a balanced sample of labelled miRNA–disease pairs.

    All positives (label 1) from ``adjacency_matrix.csv`` are kept. An equal
    number of negatives (label 0) is drawn with ``random_state=random_seed``.

    Returns:
        (ID, IM, samples): ``samples`` is an array of ``[miRNA, disease, label]``
        rows, with positives first.
    """
    ID, IM = _integrated_similarities(directory)
    all_associations = pd.read_csv(directory + '/adjacency_matrix.csv', names=_ASSOCIATION_COLUMNS)

    known_associations = all_associations.loc[all_associations['label'] == 1]
    unknown_associations = all_associations.loc[all_associations['label'] == 0]
    random_negative = unknown_associations.sample(n=known_associations.shape[0],
                                                  random_state=random_seed, axis=0)
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    sample_df = pd.concat([known_associations, random_negative]).reset_index(drop=True)
    samples = sample_df.values
    return ID, IM, samples


def load_data_full(directory, random_seed):
    """Load features plus every row of ``new_adjacency_matrix.csv``.

    ``random_seed`` is accepted for signature parity with ``load_data`` but is
    unused, because no sampling takes place.

    Returns:
        (ID, IM, samples): ``samples`` holds all ``[miRNA, disease, label]`` rows.
    """
    ID, IM = _integrated_similarities(directory)
    all_associations = pd.read_csv(directory + '/new_adjacency_matrix.csv', names=_ASSOCIATION_COLUMNS)
    samples = all_associations.values
    return ID, IM, samples
    
def _graph_from_samples(ID, IM, samples):
    """Build the disease–miRNA DGL graph whose edges are the labelled ``samples``.

    Node layout:
    - Vertices ``0 .. n_d - 1`` are diseases (``type`` 1, ``d_sim`` = rows of ID).
    - Vertices ``n_d .. n_d + n_m - 1`` are miRNAs (``type`` 0, ``m_sim`` = rows of IM).
    - The other block of each feature matrix is zero-filled.

    Each sample ``[miRNA_id, disease_id, label]`` uses 1-based ids. It becomes two
    directed edges (disease -> miRNA and miRNA -> disease), both carrying
    ``edata['label']``.

    Returns:
        (g, sample_disease_vertices, sample_mirna_vertices): the vertex lists are
        aligned with the rows of ``samples``.
    """
    num_diseases = ID.shape[0]
    num_mirnas = IM.shape[0]

    g = dgl.DGLGraph()
    g.add_nodes(num_diseases + num_mirnas)
    node_type = torch.zeros(g.number_of_nodes(), dtype=torch.int64)
    node_type[:num_diseases] = 1
    g.ndata['type'] = node_type

    d_sim = torch.zeros(g.number_of_nodes(), ID.shape[1])
    d_sim[:num_diseases, :] = torch.from_numpy(ID.astype('float32'))
    g.ndata['d_sim'] = d_sim

    m_sim = torch.zeros(g.number_of_nodes(), IM.shape[1])
    m_sim[num_diseases:num_diseases + num_mirnas, :] = torch.from_numpy(IM.astype('float32'))
    g.ndata['m_sim'] = m_sim

    # 1-based ids -> 0-based vertices. The lookup tables make an out-of-range id fail loudly (KeyError).
    disease_ids_invmap = {id_: i for i, id_ in enumerate(range(1, num_diseases + 1))}
    mirna_ids_invmap = {id_: i for i, id_ in enumerate(range(1, num_mirnas + 1))}
    sample_disease_vertices = [disease_ids_invmap[id_] for id_ in samples[:, 1]]
    sample_mirna_vertices = [mirna_ids_invmap[id_] + num_diseases for id_ in samples[:, 0]]

    g.add_edges(sample_disease_vertices, sample_mirna_vertices,
                data={'label': torch.from_numpy(samples[:, 2].astype('float32'))})
    g.add_edges(sample_mirna_vertices, sample_disease_vertices,
                data={'label': torch.from_numpy(samples[:, 2].astype('float32'))})
    return g, sample_disease_vertices, sample_mirna_vertices


def build_graph(directory, random_seed):
    """Graph over the balanced (positives + sampled negatives) pairs from ``load_data``.

    Returns:
        (g, sample_disease_vertices, sample_mirna_vertices, ID, IM, samples)
    """
    ID, IM, samples = load_data(directory, random_seed)
    g, sample_disease_vertices, sample_mirna_vertices = _graph_from_samples(ID, IM, samples)
    return g, sample_disease_vertices, sample_mirna_vertices, ID, IM, samples


def build_graph_full(directory, random_seed):
    """Graph over every pair returned by ``load_data_full``.

    Returns:
        (g, sample_disease_vertices, sample_mirna_vertices, ID, IM, samples)
    """
    ID, IM, samples = load_data_full(directory, random_seed)
    g, sample_disease_vertices, sample_mirna_vertices = _graph_from_samples(ID, IM, samples)
    return g, sample_disease_vertices, sample_mirna_vertices, ID, IM, samples

def weight_reset(m):
    """Re-initialise ``m`` in place when it is an ``nn.Linear``; ignore any other module.

    Intended to be used as ``model.apply(weight_reset)``.
    """
    if not isinstance(m, nn.Linear):
        return
    m.reset_parameters()

In [123]:
def draw_roc_curve(fprs, tprs, auc_result):
    """Plot per-fold ROC curves, their mean and a ±1 std band; save to cv/ROC_curves.png.

    Args:
        fprs, tprs: per-fold false/true positive rate arrays (e.g. from ``roc_curve``).
        auc_result: per-fold AUC values, used in the legend and for the std label.

    Raises:
        ValueError: if no folds are given.
    """
    if len(fprs) == 0:
        raise ValueError('draw_roc_curve needs at least one fold')
    os.makedirs('cv', exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 8))

    mean_fpr = np.linspace(0, 1, 100)
    tprs_interp = []
    for i, (fpr, tpr, auc) in enumerate(zip(fprs, tprs, auc_result), 1):
        fold_tpr = np.interp(mean_fpr, fpr, tpr)
        fold_tpr[0] = 0.0  # anchor every interpolated curve at the origin
        tprs_interp.append(fold_tpr)
        ax.plot(fpr, tpr, lw=2, alpha=0.3, label=f'ROC Fold {i} (AUC = {auc:.2f})')

    mean_tpr = np.mean(tprs_interp, axis=0)
    mean_tpr[-1] = 1.0
    mean_auc = metrics.auc(mean_fpr, mean_tpr)
    std_tpr = np.std(tprs_interp, axis=0)
    ax.plot(mean_fpr, mean_tpr, color='b', label=f'Mean ROC (AUC = {mean_auc:.2f})', lw=2, alpha=.4)

    std_auc = np.std(auc_result)
    ax.fill_between(mean_fpr, mean_tpr - std_tpr, mean_tpr + std_tpr, alpha=.2, color='g',
                    label=f'±1 std. dev. (AUC = {std_auc:.2f})')

    ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='Chance', alpha=.8)
    ax.set_xlim([-0.05, 1.05])
    ax.set_ylim([-0.05, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc="lower right")
    fig.savefig('cv/ROC_curves.png')
    plt.close(fig)


def draw_pr_curve(precisions, recalls, aupr_result):
    """Plot per-fold PR curves, their mean and a ±1 std band; save to cv/PR_curves.png and show.

    Args:
        precisions, recalls: per-fold arrays as returned by ``precision_recall_curve``.
        aupr_result: per-fold AUPR values, used in the legend and for the std label.

    Raises:
        ValueError: if no folds are given.
    """
    if len(precisions) == 0:
        raise ValueError('draw_pr_curve needs at least one fold')
    os.makedirs('cv', exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 8))

    mean_recall = np.linspace(0, 1, 1000)
    precisions_interp = []
    for i, (precision, recall, aupr) in enumerate(zip(precisions, recalls, aupr_result), 1):
        # sklearn returns recall in decreasing order; np.interp needs increasing x values.
        precision = precision[::-1]
        recall = recall[::-1]
        precisions_interp.append(np.interp(mean_recall, recall, precision))
        ax.plot(recall, precision, lw=1, alpha=0.3, label=f'PR Fold {i} (AUPR = {aupr:.3f})')

    mean_precision = np.mean(precisions_interp, axis=0)
    std_precision = np.std(precisions_interp, axis=0)
    mean_aupr = metrics.auc(mean_recall, mean_precision)
    ax.plot(mean_recall, mean_precision, color='b',
            label=f'Mean PR (AUPR = {mean_aupr:.3f})', lw=2, alpha=.4)

    std_aupr = np.std(aupr_result)
    ax.fill_between(mean_recall, mean_precision - std_precision,
                    mean_precision + std_precision, alpha=.2, color='g',
                    label=f'±1 std. dev. (AUPR std = {std_aupr:.3f})')

    ax.set_xlabel('Recall', fontsize=12)
    ax.set_ylabel('Precision', fontsize=12)
    ax.set_title('Precision-Recall Curve', fontsize=14)
    ax.legend(loc="lower left", fontsize=10)

    fig.savefig('cv/PR_curves.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)


def save_mean_metrics(metrics_dict, param=''):
    """Write cross-validation summary lines to cv/metrics/mean/mean_metrics<param>.txt.

    ``metrics_dict`` maps a metric name to a dict with 'mean' and 'var' entries.
    Each metric becomes one line, formatted to 4 decimals.
    """
    out_dir = 'cv/metrics/mean'
    os.makedirs(out_dir, exist_ok=True)
    with open(f'{out_dir}/mean_metrics{param}.txt', 'w') as handle:
        handle.writelines(
            f"{name} mean: {stats['mean']:.4f}, variance: {stats['var']:.4f}\n"
            for name, stats in metrics_dict.items()
        )

def _save_train_val_curve(train_values, val_values, metric_name, fold, legend_loc, out_dir, stem):
    """Plot a per-epoch training-vs-validation curve and save it to ``{out_dir}/{stem}{fold}.png``."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(range(len(train_values)), train_values, label=f'Training {metric_name}')
    ax.plot(range(len(val_values)), val_values, label=f'Validation {metric_name}')
    ax.set_xlabel('Epochs')
    ax.set_ylabel(metric_name)
    ax.set_title(f'Training and Validation {metric_name} {fold}')
    ax.legend(loc=legend_loc)
    ax.grid(True)
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f'{out_dir}/{stem}{fold}.png')
    plt.close(fig)


def save_loss_plot(train_losses, val_losses, fold=''):
    """Save per-epoch training/validation loss to cv/loss/training_validation_loss<fold>.png."""
    _save_train_val_curve(train_losses, val_losses, 'Loss', fold, 'upper right',
                          'cv/loss', 'training_validation_loss')


def save_roc_curve(fpr, tpr, auc, fold=''):
    """Save a single ROC curve (plus the chance diagonal) to cv/roc/roc_curve<fold>.png."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(fpr, tpr, label=f'ROC Curve (AUC = {auc:.4f})')
    ax.plot([0, 1], [0, 1], linestyle='--', label='Random Classifier')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title(f'Receiver Operating Characteristic (ROC) Curve {fold}')
    ax.legend(loc='lower right')
    ax.grid(True)
    os.makedirs('cv/roc', exist_ok=True)
    fig.savefig(f'cv/roc/roc_curve{fold}.png')
    plt.close(fig)


def save_pr_curve(recall, precision, aupr, fold=''):
    """Save a single precision-recall curve to cv/pr/pr_curve<fold>.png."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(recall, precision, label=f'PR Curve (AUPR = {aupr:.4f})')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title(f'Precision-Recall Curve {fold}')
    ax.legend(loc='lower left')
    ax.grid(True)
    os.makedirs('cv/pr', exist_ok=True)
    fig.savefig(f'cv/pr/pr_curve{fold}.png')
    plt.close(fig)


def save_accuracy_curve(train_accs, val_accs, fold=''):
    """Save per-epoch training/validation accuracy to cv/acc/accuracy_curve<fold>.png."""
    _save_train_val_curve(train_accs, val_accs, 'Accuracy', fold, 'lower right',
                          'cv/acc', 'accuracy_curve')


def save_auc_curve(train_aucs, val_aucs, fold=''):
    """Save per-epoch training/validation AUC to cv/auc/auc_curve<fold>.png."""
    _save_train_val_curve(train_aucs, val_aucs, 'AUC', fold, 'lower right',
                          'cv/auc', 'auc_curve')

def save_metrics(metrics_dict, fold):
    """Write one fold's metrics to cv/metrics/fold<fold>_metrics.txt as 'name: value' lines (4 decimals)."""
    os.makedirs('cv/metrics', exist_ok=True)
    target = f'cv/metrics/fold{fold}_metrics.txt'
    with open(target, 'w') as out:
        for name in metrics_dict:
            out.write(f"{name}: {metrics_dict[name]:.4f}\n")

def calculate_metrics(label, score):
    """Compute binary-classification metrics from labels and predicted probabilities.

    Predictions are thresholded at 0.5; a score below 0.5 becomes 0, otherwise 1.

    Returns:
        (acc, pre, rec, spe, f1, auc, aupr, mcc, fpr, tpr, precision, recall).
        ``auc`` is the ROC-AUC. ``aupr`` is the trapezoidal area under the PR curve.
        ``fpr``/``tpr`` and ``precision``/``recall`` are the raw curve arrays.
    """
    score = np.asarray(score).ravel()
    label = np.asarray(label).ravel()
    # Written as `< 0.5 -> 0` so NaN scores still map to 1, as in the previous list comprehension.
    pred = np.where(score < 0.5, 0, 1)

    auc = metrics.roc_auc_score(label, score)
    acc = metrics.accuracy_score(label, pred)
    pre = metrics.precision_score(label, pred)
    rec = metrics.recall_score(label, pred)
    f1 = metrics.f1_score(label, pred)
    mcc = metrics.matthews_corrcoef(label, pred)

    fpr, tpr, _ = metrics.roc_curve(label, score)
    # labels=[0, 1] keeps the matrix 2x2 even when a class is missing from label and pred;
    # otherwise the 4-way unpack fails.
    tn, fp, fn, tp = metrics.confusion_matrix(label, pred, labels=[0, 1]).ravel()
    spe = tn / (tn + fp)
    precision, recall, _ = metrics.precision_recall_curve(label, score)
    aupr = metrics.auc(recall, precision)

    return acc, pre, rec, spe, f1, auc, aupr, mcc, fpr, tpr, precision, recall


In [124]:
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session, which can
# hide useful signals (e.g. pandas chained-assignment or sklearn ill-defined-metric warnings).
# Consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")

def _accuracy_and_auc(label, score):
    """Return (accuracy at a 0.5 threshold, ROC-AUC) for matching label/score tensors."""
    score_np = np.squeeze(score.cpu().detach().numpy())
    label_np = np.squeeze(label.cpu().detach().numpy())
    pred = [0 if j < 0.5 else 1 for j in score_np]
    return metrics.accuracy_score(label_np, pred), metrics.roc_auc_score(label_np, score_np)


def Train(directory, epochs, n_classes, in_size, out_dim, dropout, slope, lr, wd, random_seed, cuda, kk):
    """Run 5-fold cross-validation of nSGC on the balanced sample graph.

    For each fold:
    - The model is trained on the training edges, with message passing over
      the training subgraph.
    - It is scored on the held-out edges, with message passing over the full
      sample graph.
    - Per-fold curves and metrics are written under ``cv/``.

    After all folds, ROC/PR summaries and mean metrics are saved. The
    last fold's weights go to ``best_model.pth``.

    Args:
        directory: data directory passed to ``build_graph``.
        epochs: training epochs per fold.
        n_classes, in_size, out_dim, dropout, slope, kk: forwarded to nSGC as
            ``n_class``, ``hid_dim``, ``out_dim``, ``dropout``, ``slope`` and ``K``.
        lr, wd: Adam learning rate and weight decay.
        random_seed: seeds the python/numpy/torch RNGs, negative sampling and the KFold shuffle.
        cuda: currently ignored; everything runs on the CPU.

    Returns:
        (fprs, tprs, auc_result, pres, recs, aupr_result): per-fold ROC/PR curves and scores.
    """
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    # NOTE(review): `cuda` is not honoured; the whole pipeline runs on the CPU.
    context = torch.device('cpu')

    g, disease_vertices, mirna_vertices, ID, IM, samples = build_graph(directory, random_seed)
    # DGLGraph.to returns the moved graph rather than moving in place.
    g = g.to(context)

    fold_results = []  # one metrics dict per fold, in fold order
    fprs, tprs, pres, recs = [], [], [], []

    kf = KFold(n_splits=5, shuffle=True, random_state=random_seed)
    for i, (train_idx, test_idx) in enumerate(kf.split(samples[:, 2]), 1):
        train_losses, train_accs, val_losses, val_accs = [], [], [], []
        train_aucs, val_aucs = [], []
        print('\nTraining for Fold', i)

        # Flag this fold's training pairs in a fresh array. The former
        # `samples_df['train'].iloc[train_idx] = 1` was a chained assignment, which
        # becomes a silent no-op under pandas Copy-on-Write.
        train_flags = np.zeros(len(samples), dtype='int64')
        train_flags[train_idx] = 1
        edge_data = {'train': torch.from_numpy(train_flags)}

        # Every sample exists as two directed edges; tag both directions.
        g.edges[disease_vertices, mirna_vertices].data.update(edge_data)
        g.edges[mirna_vertices, disease_vertices].data.update(edge_data)

        train_eid = g.filter_edges(lambda edges: edges.data['train'])
        g_train = g.edge_subgraph(train_eid, relabel_nodes=False)
        label_train = g_train.edata['label'].unsqueeze(1)
        src_train, dst_train = g_train.all_edges()
        test_eid = g.filter_edges(lambda edges: edges.data['train'] == 0)
        src_test, dst_test = g.find_edges(test_eid)
        label_test = g.edges[test_eid].data['label'].unsqueeze(1)
        print('Training edges:', len(train_eid))
        print('Testing edges:', len(test_eid))

        model = nSGC(G=g_train,
                     hid_dim=in_size,
                     n_class=n_classes,
                     K=kk,
                     batchnorm=False,
                     num_diseases=ID.shape[0],
                     num_mirnas=IM.shape[0],
                     d_sim_dim=ID.shape[1],
                     m_sim_dim=IM.shape[1],
                     out_dim=out_dim,
                     dropout=dropout,
                     slope=slope)
        model.apply(weight_reset)
        model.to(context)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
        loss_fn = nn.BCELoss()

        for epoch in range(epochs):
            start = time.time()

            model.train()
            score_train = model(g_train, src_train, dst_train, True)
            loss_train = loss_fn(score_train, label_train)
            optimizer.zero_grad()
            loss_train.backward()
            optimizer.step()

            model.eval()
            with torch.no_grad():
                score_val = model(g, src_test, dst_test, True)
                loss_val = loss_fn(score_val, label_test)

            train_acc, train_auc = _accuracy_and_auc(label_train, score_train)
            val_acc, val_auc = _accuracy_and_auc(label_test, score_val)

            train_losses.append(loss_train.item())
            val_losses.append(loss_val.item())
            train_accs.append(train_acc)
            val_accs.append(val_acc)
            train_aucs.append(train_auc)
            val_aucs.append(val_auc)

            end = time.time()
            if (epoch + 1) % 50 == 0:
                print('    Epoch:', epoch + 1, 'Train Loss: %.4f' % loss_train.item(),
                      'Val Loss: %.4f' % loss_val.item(), 'Time: %.2f' % (end - start))

        model.eval()
        with torch.no_grad():
            score_test = model(g, src_test, dst_test, True)

        score_test_cpu = np.squeeze(score_test.cpu().detach().numpy())
        label_test_cpu = np.squeeze(label_test.cpu().detach().numpy())
        (test_acc, test_pre, test_rec, test_spe, test_f1, test_auc, test_aupr, test_mcc,
         fpr, tpr, precision, recall) = calculate_metrics(label_test_cpu, score_test_cpu)

        save_loss_plot(train_losses, val_losses, f'Fold{i}')
        save_accuracy_curve(train_accs, val_accs, f'Fold{i}')
        save_auc_curve(train_aucs, val_aucs, f'Fold{i}')
        save_roc_curve(fpr, tpr, test_auc, f'Fold{i}')
        save_pr_curve(recall, precision, test_aupr, f'Fold{i}')

        print('Fold:', i, 'Acc: %.4f' % test_acc, 'Pre: %.4f' % test_pre, 'Rec: %.4f' % test_rec, 'Spec: %.4f' % test_spe,
              'F1: %.4f' % test_f1, 'AUC: %.4f' % test_auc, 'AUPR: %.4f' % test_aupr, 'MCC: %.4f' % test_mcc)

        fold_metrics = {
            'Accuracy': test_acc,
            'Precision': test_pre,
            'Recall': test_rec,
            'F1-score': test_f1,
            'AUC': test_auc,
            'AUPR': test_aupr,
            'MCC': test_mcc,
            'Specificity': test_spe
        }
        save_metrics(fold_metrics, i)
        fold_results.append(fold_metrics)

        fprs.append(fpr)
        tprs.append(tpr)
        pres.append(precision)
        recs.append(recall)

    auc_result = [m['AUC'] for m in fold_results]
    aupr_result = [m['AUPR'] for m in fold_results]
    draw_roc_curve(fprs, tprs, auc_result)
    draw_pr_curve(pres, recs, aupr_result)

    # NOTE(review): 'var' holds the standard deviation (np.std), although the saved and
    # printed label says "variance" — confirm which statistic is intended.
    summary_order = ('AUC', 'Accuracy', 'Precision', 'Recall', 'F1-score', 'AUPR', 'MCC', 'Specificity')
    mean_metrics = {}
    for name in summary_order:
        values = [m[name] for m in fold_results]
        mean_metrics[name] = {'mean': np.mean(values), 'var': np.std(values)}
    save_mean_metrics(mean_metrics, parameter)

    print('DONE!')
    print('-----------------------------------------------------------------------------------------------')
    for metric, values in mean_metrics.items():
        print(f'{metric} mean: {values["mean"]:.4f}, variance: {values["var"]:.4f}')

    # NOTE(review): this is the last fold's model, not one selected by validation score.
    torch.save(model.state_dict(), 'best_model.pth')
    return fprs, tprs, auc_result, pres, recs, aupr_result


In [125]:
# Hyper-parameters for the 5-fold cross-validation run.
CV_CONFIG = dict(
    directory='data',
    epochs=200,
    n_classes=64,
    in_size=64,
    out_dim=64,
    dropout=0.5,
    slope=0.2,
    lr=0.0001,
    wd=5e-3,
    random_seed=1225,
    cuda=True,
    kk=4,
)
fprs, tprs, auc, precisions, recalls, aupr = Train(**CV_CONFIG)

In [174]:
def load_full_data(directory):
    """Build the graph over every pair in new_adjacency_matrix.csv, with the seed fixed at 42.

    Returns the same 6-tuple as ``build_graph_full``:
    (g, disease_vertices, mirna_vertices, ID, IM, samples).
    """
    return build_graph_full(directory, random_seed=42)

def train_full_model(g, disease_vertices, mirna_vertices, ID, IM, samples,
                     epochs=200, lr=0.0001, wd=5e-3, save_path='./GCN-MDA'):
    """Train nSGC on every edge of ``g`` and optionally save its state dict.

    The architecture matches the cross-validation run: hid_dim/out_dim 64, K=4,
    dropout 0.5, slope 0.2. ``disease_vertices``, ``mirna_vertices`` and
    ``samples`` are accepted for call-site compatibility but are not used.

    Args:
        g: graph whose edges all carry ``edata['label']``. All edges are used as training targets.
        ID, IM: disease / miRNA feature matrices; only their shapes are used.
        epochs: number of full-batch training steps.
        lr, wd: Adam learning rate and weight decay.
        save_path: file for ``model.state_dict()``; pass None to skip saving.

    Returns:
        The trained model (still in training mode).
    """
    model = nSGC(G=g,
                 hid_dim=64,
                 n_class=64,
                 K=4,
                 batchnorm=False,
                 num_diseases=ID.shape[0],
                 num_mirnas=IM.shape[0],
                 d_sim_dim=ID.shape[1],
                 m_sim_dim=IM.shape[1],
                 out_dim=64,
                 dropout=0.5,
                 slope=0.2)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    loss_fn = nn.BCELoss()

    src, dst = g.all_edges()
    label = g.edata['label'].unsqueeze(1)

    model.train()
    for epoch in range(epochs):
        start = time.time()

        score = model(g, src, dst, True)
        loss = loss_fn(score, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        end = time.time()
        if (epoch + 1) % 50 == 0:
            print(f'Epoch: {epoch + 1}, Loss: {loss.item():.4f}, Time: {end - start:.2f}s')

    if save_path is not None:
        torch.save(model.state_dict(), save_path)
    return model

def predict_associations(model, g, node_id):
    """Score every edge of ``g`` whose destination is in ``node_id``.

    Runs ``model`` in eval mode without gradients, using ``g`` for message passing.

    Returns:
        (scores, g_sub): the model output for the incoming edges, and the
        ``in_subgraph`` they came from (node ids not relabelled).
    """
    incoming = g.in_subgraph(node_id, relabel_nodes=False, store_ids=True)
    sources, targets = incoming.all_edges()

    model.eval()
    with torch.no_grad():
        edge_scores = model(g, sources, targets, True)
    return edge_scores, incoming

def predict_associations2(model, g, node_id):
    """Like ``predict_associations``, but return only the scores.

    The scores are returned as ``squeeze().tolist()``, which gives a plain float
    when there is a single edge.
    """
    subgraph = g.in_subgraph(node_id, relabel_nodes=False, store_ids=True)
    edge_src, edge_dst = subgraph.all_edges()

    model.eval()
    with torch.no_grad():
        result = model(g, edge_src, edge_dst, True)
    return result.squeeze().tolist()

def predict_associations3(model, g, ind_d, mir):
    """Score one disease node against the first ``mir`` miRNA nodes.

    The previous body passed a one-hot Python list and the integer ``mir`` as
    node indices. That made the model's ``torch.cat`` fail on tensors of
    different rank, so the function could never return. This version builds
    explicit, equal-length index tensors instead.

    Args:
        model: trained model exposing ``num_diseases``. miRNA nodes are assumed to be
            numbered ``num_diseases + k``, as in build_graph's node layout.
        g: graph used for message passing.
        ind_d: node index of the disease to score.
        mir: number of miRNA nodes to score.

    Returns:
        List of ``mir`` floats: the score of (``ind_d``, miRNA ``k``) for ``k = 0 .. mir - 1``.
    """
    disease_idx = torch.full((mir,), ind_d, dtype=torch.int64)
    mirna_idx = torch.arange(model.num_diseases, model.num_diseases + mir, dtype=torch.int64)

    model.eval()
    with torch.no_grad():
        scores = model(g, disease_idx, mirna_idx, True)
    # reshape(-1) keeps the result a list even when mir == 1.
    return scores.reshape(-1).tolist()


In [127]:
# g, disease_vertices, mirna_vertices, ID, IM, samples = load_full_data('data/')
# model = nSGC(G=g,
#              hid_dim=64,           # Use the same parameters as used during training
#              n_class=64,
#              K=4,
#              batchnorm=False,
#              num_diseases=ID.shape[0],
#              num_mirnas=IM.shape[0],
#              d_sim_dim=ID.shape[1],
#              m_sim_dim=IM.shape[1],
#              out_dim=64,
#              dropout=0.5,
#              slope=0.2)
# model.load_state_dict(torch.load('./GCN-MDA'))  # load_state_dict returns key-match info, not the model
# #disease_nodes = g.filter_nodes(lambda nodes: nodes.data['type'] == 0)
# predictions, _ = predict_associations(model, g, [0])
# predictions = predictions.squeeze().tolist()
# #len(predictions.squeeze().tolist())

In [176]:
# Build the graph over every miRNA–disease pair listed in data/new_adjacency_matrix.csv.
g, disease_vertices, mirna_vertices, ID, IM, samples = load_full_data('data')

In [None]:
# Train on all edges of the full graph; the state dict is saved to ./GCN-MDA.
model = train_full_model(g, disease_vertices, mirna_vertices, ID, IM, samples)

In [None]:
K = 50  # number of top-ranked miRNAs to report per disease

disease_df = pd.read_csv("data/disease_name.csv", names=['disease_name'])
mirna_df = pd.read_csv("data/miRNA_name.csv", names=['mirna_name'])
output_dir = 'disease_predictions'
os.makedirs(output_dir, exist_ok=True)

diseases_to_predict = ['breast neoplasms',
                       'leukemia', 'lung neoplasms']

# miRNA vertices follow the disease block in the graph (see build_graph), so vertex v maps to
# row v - num_disease_nodes of miRNA_name.csv.
num_disease_nodes = ID.shape[0]

for disease_name in diseases_to_predict:
    print(f"\nPredicting top {K} miRNAs for {disease_name}:")
    disease_id = disease_df.index[disease_df['disease_name'] == disease_name].tolist()
    if not disease_id:
        print(f"Disease '{disease_name}' not found in disease_name.csv; skipping.")
        continue
    disease_node = disease_id[0]

    predictions, gg = predict_associations(model, g, disease_id)
    src, dst = gg.all_edges()
    df = pd.DataFrame({
        'src': src.tolist(),
        'dst': dst.tolist(),
        # reshape(-1) keeps a list even when only one edge is scored.
        'preds': predictions.reshape(-1).tolist()
    })

    # Rank the edges touching the disease node by predicted score.
    filtered_df = df[(df['src'] == disease_node) | (df['dst'] == disease_node)]
    top_df = filtered_df.sort_values(by='preds', ascending=False).head(K)
    # The miRNA is whichever endpoint is not the disease node.
    partners = np.where(top_df['src'] == disease_node, top_df['dst'], top_df['src'])

    top_k_results = []
    for rank, (vertex, score) in enumerate(zip(partners, top_df['preds'].values), 1):
        mirna_name = mirna_df['mirna_name'].iloc[int(vertex) - num_disease_nodes]
        top_k_results.append({
            'miRNA': mirna_name,
            'Rank': rank,
            'Score': score
        })
        print(f"{rank}. {mirna_name}: {score:.4f}")

    # Save this disease's top K results to a CSV file.
    disease_filename = disease_name.replace(" ", "_").lower()
    output_file = os.path.join(output_dir, f'{disease_filename}_top_{K}_predictions2.csv')
    fieldnames = ['miRNA', 'Score', 'Rank']

    with open(output_file, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for result in top_k_results:
            writer.writerow(result)

    print(f"Top {K} prediction results for {disease_name} have been saved to {output_file}")
