In [1]:
import torch
import torch.nn.functional as F
#from torch_geometric.nn.dense.linear import Linear
from torch.nn import Linear
from torch_geometric.nn import GCNConv, APPNP, GCN2Conv, GATv2Conv, ResGatedGraphConv, GENConv
from typing import Optional
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import (
    Adj,
    OptPairTensor,
    OptTensor,
    SparseTensor,
)
from  torch_geometric.nn.conv.gcn_conv import gcn_norm
import torch.nn as nn

import numpy as np

from scipy.sparse import coo_matrix
import torch.optim as optim



In [2]:
class Graph_Conv(MessagePassing):

    _cached_edge_index: Optional[OptPairTensor]
    _cached_adj_t: Optional[SparseTensor]

    def __init__(
        self,
        improved: bool = False,
        cached: bool = False,
        add_self_loops: Optional[bool] = None,
        normalize: bool = True,
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        if add_self_loops is None:
            add_self_loops = normalize

        if add_self_loops and not normalize:
            raise ValueError(f"'{self.__class__.__name__}' does not support "
                             f"adding self-loops to the graph when no "
                             f"on-the-fly normalization is applied")
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize

        self._cached_edge_index = None
        self._cached_adj_t = None

    def reset_parameters(self):
        super().reset_parameters()
        self._cached_edge_index = None
        self._cached_adj_t = None

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:

        if self.normalize:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim))
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
                else:
                    edge_index, edge_weight = cache[0], cache[1]

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim))
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        out = self.propagate(edge_index, x=x, edge_weight=edge_weight)
        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
        return spmm(adj_t, x, reduce=self.aggr)

In [3]:
def full_attention_conv(q, k, v, output_attn=False):
    sqrt_n = torch.sqrt(torch.tensor(q.shape[0], dtype=torch.float32))

    a = torch.einsum("lmh,ldh->mdh", k/sqrt_n, v)

    attention = torch.softmax(a, dim=0)

    output = torch.einsum("lmh,mdh->ldh", q, attention)

    if output_attn:
        return output, attention

    return output

class TransConvLayer(nn.Module):
    def __init__(self, in_channels,
                 out_channels,
                 num_heads,
                 use_weight=True):
        super().__init__()
        self.Wk = nn.Linear(in_channels, out_channels * num_heads)
        self.Wq = nn.Linear(in_channels, out_channels * num_heads)
        if use_weight:
            self.Wv = nn.Linear(in_channels, out_channels * num_heads)

        self.out_channels = out_channels
        self.num_heads = num_heads
        self.use_weight = use_weight

    def reset_parameters(self):
        self.Wk.reset_parameters()
        self.Wq.reset_parameters()
        if self.use_weight:
            self.Wv.reset_parameters()

    def forward(self, query_input, source_input, output_attn=False):
        # feature transformation
        query = self.Wq(query_input).reshape(-1, self.out_channels ,
                                             self.num_heads)
        key = self.Wk(source_input).reshape(-1, self.out_channels ,
                                            self.num_heads)
        if self.use_weight:
            value = self.Wv(source_input).reshape(-1, self.out_channels,
                                                  self.num_heads)
        else:
            value = source_input.reshape(-1, self.out_channels, 1)

        # compute full attentive aggregation
        if output_attn:
            attention_output, attn = full_attention_conv(
                query, key, value, output_attn)  # [N, H, D]
        else:
            attention_output = full_attention_conv(
                query, key, value)  # [N, H, D]

        final_output = attention_output
        final_output = final_output.mean(dim=-1)

        if output_attn:
            return final_output, attn
        else:
            return final_output

class TransConv(nn.Module):
    def __init__(self, in_channels, hidden_channels, activation, num_layers=1, num_heads=2,
                 alpha=0.1, dropout=0.3, use_bn=True, use_residual=True, use_weight=True, use_act=True):
        super().__init__()
        self.convs = nn.ModuleList()
        self.fcs = nn.ModuleList()
        self.fcs.append(nn.Linear(in_channels, hidden_channels))
        self.bns = nn.ModuleList()
        self.bns.append(nn.LayerNorm(hidden_channels))
        for i in range(num_layers):
            self.convs.append(
                TransConvLayer(hidden_channels, hidden_channels, num_heads=num_heads, use_weight=use_weight))
            self.bns.append(nn.LayerNorm(hidden_channels))
        self.dropout = dropout
        self.use_bn = use_bn
        self.residual = use_residual
        self.alpha = alpha
        self.use_act = use_act
        self.activation = activation
        self.reset_parameters()
    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        for fc in self.fcs:
            fc.reset_parameters()

    def forward(self, x):
        layer_ = []
        # input MLP layer
        x = self.fcs[0](x)
        if self.use_bn:
            x = self.bns[0](x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        layer_.append(x)

        for i, conv in enumerate(self.convs):
            x = conv(x, x)
            if self.residual:
                x = self.alpha * x + (1 - self.alpha) * layer_[i]
            if self.use_bn:
                x = self.bns[i + 1](x)
            if self.use_act:
                x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            layer_.append(x)
        return x

    def get_attentions(self, x):
        layer_, attentions = [], []
        x = self.fcs[0](x)
        if self.use_bn:
            x = self.bns[0](x)
        x = self.activation(x)
        layer_.append(x)
        for i, conv in enumerate(self.convs):
            x, attn = conv(x, x, output_attn=True)
            attentions.append(attn)
            if self.residual:
                x = self.alpha * x + (1 - self.alpha) * layer_[i]
            if self.use_bn:
                x = self.bns[i + 1](x)
            layer_.append(x)
        return torch.stack(attentions, dim=0)  # [layer num, N, N]



In [4]:
class DUALFormer_Model(torch.nn.Module):
    def __init__(self, sno_feature, dis_feature, input_dim,
                 hidden_dim,
                 output_dim,
                 activation,
                 num_gnns,
                 num_trans,
                 num_heads,
                 gat_heads,
                 dropout_trans,
                 dropout,
                 alpha,
                 use_bn,
                 lammda=0.1,
                 GraphConv='sgc'):
        super(DUALFormer_Model, self).__init__()

        self.sno_feature = sno_feature
        self.dis_feature = dis_feature
        self.lin_sf = Linear(sno_feature.shape[1], input_dim)
        self.lin_df = Linear(dis_feature.shape[1], input_dim)   
        
        self.activation = activation()
        self.num_gnns = num_gnns
        self.layers_trans = TransConv(input_dim, hidden_dim, self.activation,
                                      num_layers=num_trans, num_heads=num_heads,
                                      alpha=alpha, dropout=dropout_trans,
                                      use_bn=use_bn, use_residual=True,
                                      use_weight=True, use_act=True)

        if GraphConv == 'sgc':
            self.convs = torch.nn.ModuleList()
            for _ in range(num_gnns):
                self.convs.append(Graph_Conv())
        elif GraphConv == 'gcn':
            self.convs = torch.nn.ModuleList()
            for _ in range(num_gnns):
                self.convs.append(GCNConv(hidden_dim, hidden_dim))
        elif GraphConv == 'appnp':
            self.convs = APPNP(num_gnns, lammda)
        elif GraphConv == 'gcn2':
            self.convs = GCN2Conv(hidden_dim, lammda, theta = 0.1, layer = num_gnns)
        elif GraphConv == 'gat':
            self.convs = torch.nn.ModuleList()
            for _ in range(num_gnns):
                self.convs.append(GATv2Conv(hidden_dim, hidden_dim, heads = gat_heads, concat = False))
        elif GraphConv == 'resgatedgcn':
            self.convs = torch.nn.ModuleList()
            for _ in range(num_gnns):
                self.convs.append(ResGatedGraphConv(hidden_dim, hidden_dim))
        elif GraphConv == 'gen':
            self.convs = torch.nn.ModuleList()
            for _ in range(num_gnns):
                self.convs.append(GENConv(hidden_dim, hidden_dim))

        self.GraphConv = GraphConv
        #self.linear_project = Linear(hidden_dim, output_dim, weight_initializer = "glorot")
        self.linear_project = Linear(hidden_dim, output_dim)
        self.dropout = dropout

        self.params1 = list(self.layers_trans.parameters())
        self.params2 = list(self.linear_project.parameters())
        self.params3 = list(self.lin_sf.parameters())
        self.params4 = list(self.lin_df.parameters())

        self.traning = True
        self.reset_parameters()

    def reset_parameters(self):

        self.layers_trans.reset_parameters()
        self.linear_project.reset_parameters()
        self.lin_sf.reset_parameters()
        self.lin_df.reset_parameters()

    def forward(self, edge_index):

        sno_x = self.lin_sf(self.sno_feature)
        dis_x = self.lin_df(self.dis_feature)
        x = torch.cat((sno_x, dis_x), dim = 0)

        #x = F.dropout(x, p = self.dropout, training=self.training)
        #x = F.relu(x)
        z = self.layers_trans(x)
        temp = z
        if self.GraphConv in ['sgc', 'gcn', 'resgatedgcn', 'gen']:#sgc, gcn
            for i, conv in enumerate(self.convs):
                #z = F.dropout(z, p = self.dropout, training=self.training)
                #z = F.relu(z)
                z = conv(z, edge_index)
        elif self.GraphConv == 'appnp':
            z = self.convs(z, edge_index) #appnp
        elif self.GraphConv == 'gcn2':
            z = self.convs(z, temp, edge_index)
        elif self.GraphConv == 'gat':
            for i, conv in enumerate(self.convs):
                z = conv(z, edge_index)
                
        z = F.dropout(z, p=self.dropout, training=self.training)
        z = self.linear_project(z)
        adj_rec = torch.sigmoid(torch.mm(z, z.T))

        return adj_rec


In [5]:
def construct_adj_mat(training_mask):
    adj_tmp = training_mask.copy()
    rna_mat = np.zeros((training_mask.shape[0], training_mask.shape[0]))
    dis_mat = np.zeros((training_mask.shape[1], training_mask.shape[1]))

    mat1 = np.hstack((rna_mat, adj_tmp))
    mat2 = np.hstack((adj_tmp.T, dis_mat))
    ret = np.vstack((mat1, mat2))
    return ret

In [6]:
def get_metrics(real_score, predict_score, roc_path, pr_path, i):
    real_score, predict_score = real_score.flatten(), predict_score.flatten()
    sorted_predict_score = np.array(
        sorted(list(set(np.array(predict_score).flatten()))))
    sorted_predict_score_num = len(sorted_predict_score)
    thresholds = sorted_predict_score[np.int32(
        sorted_predict_score_num*np.arange(1, 1000)/1000)]
    thresholds = np.mat(thresholds)
    thresholds_num = thresholds.shape[1]

    predict_score_matrix = np.tile(predict_score, (thresholds_num, 1))
    negative_index = np.where(predict_score_matrix < thresholds.T)
    positive_index = np.where(predict_score_matrix >= thresholds.T)
    predict_score_matrix[negative_index] = 0
    predict_score_matrix[positive_index] = 1
    TP = predict_score_matrix.dot(real_score.T)
    FP = predict_score_matrix.sum(axis=1)-TP
    FN = real_score.sum()-TP
    TN = len(real_score.T)-TP-FP-FN

    fpr = FP/(FP+TN)
    tpr = TP/(TP+FN)
    ROC_dot_matrix = np.mat(sorted(np.column_stack((fpr, tpr)).tolist())).T
    ROC_dot_matrix.T[0] = [0, 0]
    ROC_dot_matrix = np.c_[ROC_dot_matrix, [1, 1]]

    #np.savetxt(roc_path.format(i), ROC_dot_matrix)

    x_ROC = ROC_dot_matrix[0].T
    y_ROC = ROC_dot_matrix[1].T
    auc = 0.5*(x_ROC[1:]-x_ROC[:-1]).T*(y_ROC[:-1]+y_ROC[1:])

    recall_list = tpr
    precision_list = TP/(TP+FP)
    PR_dot_matrix = np.mat(sorted(np.column_stack(
        (recall_list, precision_list)).tolist())).T
    PR_dot_matrix.T[0] = [0, 1]
    PR_dot_matrix = np.c_[PR_dot_matrix, [1, 0]]

    #np.savetxt(pr_path.format(i), PR_dot_matrix)

    x_PR = PR_dot_matrix[0].T
    y_PR = PR_dot_matrix[1].T
    aupr = 0.5*(x_PR[1:]-x_PR[:-1]).T*(y_PR[:-1]+y_PR[1:])

    f1_score_list = 2*TP/(len(real_score.T)+TP-TN)
    accuracy_list = (TP+TN)/len(real_score.T)
    specificity_list = TN/(TN+FP)
    # plt.plot(x_ROC, y_ROC)
    # plt.plot(x_PR,y_PR)
    # plt.show()
    max_index = np.argmax(f1_score_list)
    f1_score = f1_score_list[max_index]
    accuracy = accuracy_list[max_index]
    specificity = specificity_list[max_index]
    recall = recall_list[max_index]
    precision = precision_list[max_index]
    print( ' auc:{:.4f} ,aupr:{:.4f},f1_score:{:.4f}, accuracy:{:.4f}, recall:{:.4f}, specificity:{:.4f}, precision:{:.4f}'.format( auc[0, 0],aupr[0, 0], f1_score, accuracy, recall, specificity, precision))
    return [auc[0, 0], aupr[0, 0], f1_score, accuracy, recall, specificity, precision]


def cv_model_evaluate(interaction_matrix, predict_matrix, train_matrix):
    test_index = np.where(train_matrix == 0)
    real_score = interaction_matrix[test_index]
    predict_score = predict_matrix[test_index]
    return get_metrics(real_score, predict_score)


# turn dense matrix into a sparse foramt
def dense2sparse(matrix: np.ndarray):
    mat_coo = coo_matrix(matrix)
    edge_idx = np.vstack((mat_coo.row, mat_coo.col))
    return edge_idx, mat_coo.data

def calculate_loss(pred, pos_edge_idx, neg_edge_idx, device):
    pos_pred_socres = pred[pos_edge_idx[0], pos_edge_idx[1]]
    neg_pred_socres = pred[neg_edge_idx[0], neg_edge_idx[1]]
    pred_scores = torch.hstack((pos_pred_socres, neg_pred_socres))
    true_labels = torch.hstack((torch.ones(pos_pred_socres.shape[0]), torch.zeros(neg_pred_socres.shape[0]))).to(device)
    loss_fun=torch.nn.BCELoss(reduction='mean')
    # loss_fun=torch.nn.BCEWithLogitsLoss(reduction='mean')
    return loss_fun(pred_scores, true_labels)

def calculate_evaluation_metrics(pred_mat, pos_edges, neg_edges, roc_path, pr_path, i):
    pos_pred_socres = pred_mat[pos_edges[0], pos_edges[1]]
    neg_pred_socres = pred_mat[neg_edges[0], neg_edges[1]]
    pred_labels = np.hstack((pos_pred_socres, neg_pred_socres))
    true_labels = np.hstack((np.ones(pos_pred_socres.shape[0]), np.zeros(neg_pred_socres.shape[0])))
    return get_metrics(true_labels, pred_labels, roc_path, pr_path, i)


In [7]:
random_seed = 10086
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [8]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

input_graph_all = r"SDI/data/graph_data/hetero_graph_2-4.pkl"
graph_all = torch.load(input_graph_all)

snorna_feature = torch.tensor(graph_all["snorna"]["x"], dtype=torch.float, device = device)
disease_feature = torch.tensor(graph_all["disease"]["x"], dtype=torch.float, device = device)
snorna_disease_edge_index = graph_all["snorna", "to", "disease"]["edge_index"]
snorna_disease_mat = np.zeros((snorna_feature.shape[0], disease_feature.shape[0]))
snorna_disease_mat[snorna_disease_edge_index[0], snorna_disease_edge_index[1]] = 1

new_adj_mat = construct_adj_mat(snorna_disease_mat)
edge_index =  np.array(tuple(np.where(new_adj_mat !=0)))

edge_index = torch.tensor(edge_index, dtype=torch.long, device = device)

  graph_all = torch.load(input_graph_all)
  snorna_feature = torch.tensor(graph_all["snorna"]["x"], dtype=torch.float, device = device)
  disease_feature = torch.tensor(graph_all["disease"]["x"], dtype=torch.float, device = device)


In [45]:
def main(edge_index, new_adj_mat, snorna_feature, disease_feature, args_config, device):
    # initialize parameters
    lr = args_config['lr']
    weight_decay = args_config['weight_decay']
    kfolds = args_config['kfolds']
    num_epoch = args_config['num_epoch']    
    input_dim = args_config['input_dim']
    activation = args_config['activation']
    num_gnns = args_config['num_gnns']
    num_heads = args_config['num_heads']
    gat_heads = args_config['gat_heads']
    num_trans = args_config['num_trans']
    dropout_trans = args_config['dropout_trans']
    dropout = args_config['dropout']
    alpha = args_config['alpha']
    use_bn = args_config['use_bn']
    lammda = args_config['lammda']
    GraphConv = args_config['GraphConv']
    hidden_dim = args_config['hidden_dim']
    output_dim = args_config['output_dim']      

    rng = np.random.default_rng(10086)
    pos_samples, edge_attr = dense2sparse(new_adj_mat)
    pos_samples_shuffled = rng.permutation(pos_samples, axis=1)

    # get the edge index of negative samples
    rng = np.random.default_rng(10086)
    neg_samples = np.where(new_adj_mat == 0)
    neg_samples_shuffled = rng.permutation(neg_samples, axis=1)[:, :pos_samples_shuffled.shape[1]]
    
    pos_edges = pos_samples_shuffled
    neg_edges = neg_samples_shuffled
    idx = np.arange(pos_edges.shape[1])
    np.random.shuffle(idx)
    idx_splited = np.array_split(idx, kfolds)
    metrics_tensor = np.zeros((1, 7))
    for i in range(kfolds):
        tmp = []
        for j in range(1, kfolds):
            tmp.append(idx_splited[(j + i) % kfolds])
        tmp = np.concatenate(tmp)
        training_pos_edges = pos_edges[:, tmp]
        training_neg_edges = neg_edges[:, tmp]
        test_pos_edges = pos_edges[:, idx_splited[i]]
        test_neg_edges = neg_edges[:, idx_splited[i]]      

        print(f'################Fold {i + 1} of {kfolds}################')
        #sno_feature, dis_feature, hidden_dim = 64, out_dim = 32, K = [16, 8, 4], drop_rate = 0.5
        model = DUALFormer_Model(sno_feature = snorna_feature, dis_feature = disease_feature, input_dim = input_dim,
                    hidden_dim = hidden_dim,
                    output_dim = output_dim,
                    activation =  activation,
                    num_gnns = num_gnns,
                    num_trans = num_trans,
                    num_heads = num_heads,
                    gat_heads = gat_heads,
                    dropout_trans = dropout_trans,
                    dropout = dropout,
                    alpha = alpha,
                    use_bn = use_bn,
                    lammda = lammda,
                    GraphConv = GraphConv
                    ).to(device)
        #print(model)

        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=weight_decay)
        #optimizer = optim.RMSprop(model.parameters(),lr=lr,alpha=0.99, eps=1e-08, momentum=0.1,weight_decay=weight_decay,centered=False)
        """
        base_lr=5e-5       
        scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr, max_lr=lr, step_size_up=200,
                                                step_size_down=200, mode='exp_range', gamma=0.99, scale_fn=None,
                                                cycle_momentum=False, last_epoch=-1)
        """
        
        for epoch in range(num_epoch):
            model.train()
            output = model(edge_index)
            
            #print(de_result)
            loss = calculate_loss(output, training_pos_edges, training_neg_edges, device = device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #scheduler.step()
            
            if (epoch + 1) % 500 == 0 or epoch == 0:
                pass
                print('------EPOCH {} of {}------'.format(epoch + 1, args_config['num_epoch']))
                print('Loss: {}'.format(loss))
        model.eval()
        with torch.no_grad():
            output = model(edge_index)
            
            roc_path = r'SDI/data/graph_data/SDI_DUAL_ROC_fold_{}.csv'
            pr_path = r'SDI/data/graph_data/SDI_DUAL_PR_fold_{}.csv'

            metrics = calculate_evaluation_metrics(output.detach().cpu(), test_pos_edges, 
                                                   test_neg_edges, roc_path, pr_path, i)
            metrics_tensor += metrics
            
    print('Average result:' ,end='')
    avg_metrics = metrics_tensor / kfolds
    del metrics_tensor
    # print( ' {:.4f} {:.4f} {:.4f}  {:.4f} {:.4f} {:.4f} {:.4f}'
    #        .format(avg_metrics[0][0],avg_metrics[0][1],avg_metrics[0][2],avg_metrics[0][3],avg_metrics[0][4],avg_metrics[0][5],avg_metrics[0][6]))
    print(avg_metrics)
    return avg_metrics


In [46]:
args_config = {
        'kfolds': 10,         
        'input_dim' : 256,
        'hidden_dim' : 128,
        'output_dim' : 64,
        'activation' : nn.ReLU,
        'num_gnns' : 3,#3 best
        'num_trans' : 4,#1-4 up, 5 down
        'num_heads' : 4,#
        'gat_heads' : 4,#this para work only when uses gat  
        'dropout_trans' : 0.2,
        'dropout' : 0.2,
        'alpha' : 0.1,
        'use_bn' : True,
        'lammda' : 0.3,
        'GraphConv' : "gcn",  #"sgc"\"gcn"\"appnp"\"gcn2"\"gat"\"resgatedgcn"      
        'lr': 0.0005,#1e-3，la是1e-3   #0.0001,0.0005,0.001,0.005,0.01,0.05 
        'weight_decay': 0.0001,#5e-3   #0.0001,0.0005,0.001,0.005,0.01,0.05    
        'num_epoch' : 1200
    }

reulst = main(edge_index, new_adj_mat, snorna_feature, disease_feature, args_config, device)

################Fold 1 of 10################
------EPOCH 1 of 1200------
Loss: 2.0123496055603027
------EPOCH 500 of 1200------
Loss: 0.4511253833770752
------EPOCH 1000 of 1200------
Loss: 0.4037947356700897
 auc:0.9847 ,aupr:0.9818,f1_score:0.9558, accuracy:0.9560, recall:0.9505, specificity:0.9615, precision:0.9611
################Fold 2 of 10################
------EPOCH 1 of 1200------
Loss: 3.035507917404175
------EPOCH 500 of 1200------
Loss: 0.4818693697452545
------EPOCH 1000 of 1200------
Loss: 0.4142950773239136
 auc:0.9821 ,aupr:0.9773,f1_score:0.9539, accuracy:0.9533, recall:0.9670, specificity:0.9396, precision:0.9412
################Fold 3 of 10################
------EPOCH 1 of 1200------
Loss: 2.7753405570983887
------EPOCH 500 of 1200------
Loss: 0.47443529963493347
------EPOCH 1000 of 1200------
Loss: 0.4202680289745331
 auc:0.9933 ,aupr:0.9913,f1_score:0.9672, accuracy:0.9670, recall:0.9725, specificity:0.9615, precision:0.9620
################Fold 4 of 10############

In [11]:
model = DUALFormer_Model(sno_feature = snorna_feature, dis_feature = disease_feature, input_dim = 128,
                    hidden_dim = 64,
                    output_dim = 32,
                    activation =  nn.ReLU,
                    num_gnns = 2,
                    num_trans = 1,
                    num_heads = 2,
                    gat_heads = 2,
                    dropout_trans = 0.5,
                    dropout = 0.5,
                    alpha = 0.1,
                    use_bn = True,
                    lammda = 0.1,
                    GraphConv = "gcn"
                    ).to(device)

In [12]:
import torch_geometric.typing
torch_geometric.typing.WITH_TORCH_SPLINE_CONV

False

In [13]:
model.parameters()

<generator object Module.parameters at 0x7aa514b66b90>