### Imports

In [1]:
import gc 
import os
import threading
import tqdm
import time
import copy
import random
from datetime import datetime

import numpy as np
import pandas as pd
from itertools import *

from rdkit import Chem

import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter


from torch_geometric.loader.link_neighbor_loader import LinkNeighborLoader
import torch_geometric.transforms as T
from torch_geometric.data import (
                                    HeteroData,
                                    Data, 
                                    Batch
                                 )   
from torch_geometric.nn import (
                                GATv2Conv,
                                SAGPooling,
                                global_add_pool,
                                HeteroConv,
                                Linear,
                                to_hetero
                                )

from sklearn.model_selection import StratifiedShuffleSplit, KFold, train_test_split, StratifiedKFold
from sklearn.metrics import (
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score,
    roc_auc_score, 
    precision_recall_curve, 
    auc, 
    average_precision_score, 
    matthews_corrcoef,
    roc_curve
    )
import sklearn.metrics

### Seed all randomness

In [2]:
SEED = 29  # single seed used for every RNG in this notebook


def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also puts cuDNN in deterministic mode and disables its autotuner
    (``benchmark``), trading some speed for run-to-run reproducibility.

    Args:
        seed: Integer seed passed to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

### Load HeteroData - idse

In [3]:
# Load the serialized HeteroData graph: drug nodes (fpt / mpnn features),
# side_effect nodes and the (drug, known, side_effect) edge type.
fnm = '../prep_data/hetero_graph/hetero_data_idse_dict.pt'
try:
    # The file pickles a whole HeteroData object, which the restricted
    # ``weights_only=True`` loader (the default in newer PyTorch releases)
    # rejects. It is a locally produced artifact, so opt out explicitly.
    # NOTE: this unpickles arbitrary objects — only load trusted files.
    data = torch.load(fnm, weights_only=False)
except TypeError:
    # Older torch has no ``weights_only`` argument; its default already unpickles.
    data = torch.load(fnm)

In [4]:
# Display the loaded graph: node stores, feature shapes and edge counts.
data

HeteroData(
  drug={
    node_id=[1007],
    fpt=[1007, 128],
    mpnn=[1007, 617]
  },
  side_effect={ node_id=[5587] },
  (drug, known, side_effect)={ edge_index=[2, 132063] }
)

### Load Transformation Maps

In [5]:
# Empty lookup tables, filled in place by the next cell:
# DrugBank id <-> node id, side-effect concept id <-> node id (file names
# suggest UMLS/MedDRA-style codes — confirm), and (drug_id, mol) pairs.
DB_TO_ID_DICT, ID_TO_DB_DICT = {}, {}
MEDRAID_TO_ID_DICT, ID_TO_MEDRAID_DICT = {}, {}
drug_id_mol_graph_tup = []

In [6]:
dict_list = [DB_TO_ID_DICT, ID_TO_DB_DICT, MEDRAID_TO_ID_DICT, ID_TO_MEDRAID_DICT, drug_id_mol_graph_tup]
file_names = ['db_to_id.pt', 'id_to_db.pt', 'uml_to_id.pt', 'id_to_uml.pt', 'drug_to_mol.pt']


def _torch_load_trusted(path):
    """``torch.load`` that also unpickles arbitrary Python objects.

    drug_to_mol.pt holds (drug_id, mol) pairs whose mols are consumed by
    get_mol_edge_list_and_feat_mtx, i.e. not plain tensors, so the restricted
    ``weights_only=True`` loader (default in newer PyTorch) would reject it.
    Only use on trusted, locally produced files.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError:
        # Older torch has no ``weights_only`` argument; its default already unpickles.
        return torch.load(path)


for data_dict, fnm in zip(dict_list, file_names):
    full_path = f"../prep_data/hetero_graph/{fnm}"
    loaded_data = _torch_load_trusted(full_path)

    # Fill the existing containers in place so the module-level names bound in
    # the previous cell see the data. Clear first so re-running this cell is
    # idempotent (previously the list was extended a second time).
    if isinstance(data_dict, dict):
        data_dict.clear()
        data_dict.update(loaded_data)
    elif isinstance(data_dict, list):
        data_dict.clear()
        data_dict.extend(loaded_data)
    else:
        # Replacing the container here could never rebind the global name,
        # so fail loudly instead of silently dropping the data.
        raise TypeError(f"Unsupported container {type(data_dict).__name__} for {fnm}")

### HeteroData Undirected (disabled — the cell below is commented out)

In [7]:
# data = T.ToUndirected()(data)

### Molecule Featurization Utils

In [8]:
# Bond featurization
def get_bond_features(bond):
    """Featurize an RDKit bond as a float32 vector of length 6.

    Layout: one-hot bond type over (SINGLE, DOUBLE, TRIPLE, AROMATIC,
    'Unknown'), followed by a ring-membership flag.
    """
    allowed_types = [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
        'Unknown',
    ]
    raw_type = bond.GetBondType()
    if raw_type not in allowed_types:
        raw_type = 'Unknown'

    type_onehot = one_of_k_encoding_unk(raw_type, allowed_types)
    ring_flag = [bond.IsInRing()]
    return np.array(type_onehot + ring_flag, dtype=np.float32)

def get_mol_edge_list_and_feat_mtx(mol_graph, n_edge_features=6):
    """Convert an RDKit molecule into PyG-style graph tensors.

    Args:
        mol_graph: RDKit molecule.
        n_edge_features: width of one bond feature vector. It is only used to
            shape the empty ``edge_attr`` of molecules without bonds and must
            match get_bond_features (5 bond-type slots + ring flag = 6).

    Returns:
        (edge_index, n_features, edge_attr):
            edge_index: LongTensor (2, 2 * n_bonds); each bond appears in both
                directions (all forward pairs first, then all reversed pairs).
            n_features: FloatTensor (n_atoms, n_atom_features); row i belongs
                to the atom with index i.
            edge_attr: FloatTensor (2 * n_bonds, n_edge_features), aligned with
                the columns of edge_index.
    """
    # Sort by atom index only, so feature rows line up with atom indices.
    atom_rows = sorted(
        ((atom.GetIdx(), atom_features(atom)) for atom in mol_graph.GetAtoms()),
        key=lambda item: item[0],
    )
    n_features = torch.from_numpy(np.stack([feat for _, feat in atom_rows])).float()

    bonds = list(mol_graph.GetBonds())
    if bonds:
        forward = torch.tensor(
            [(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in bonds], dtype=torch.long
        )
        edge_index = torch.cat([forward, forward[:, [1, 0]]], dim=0).t().contiguous()
        bond_feats = np.stack([get_bond_features(b) for b in bonds])
        # Duplicate in the same order as edge_index: forward block, then reversed block.
        edge_attr = torch.from_numpy(np.concatenate([bond_feats, bond_feats], axis=0)).float()
    else:
        # Single-atom molecules (e.g. ions): keep the 2-D shapes PyG expects
        # instead of the 1-D empty tensors produced previously.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, n_edge_features), dtype=torch.float32)

    return edge_index, n_features, edge_attr


def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` over ``allowable_set`` as a list of bools.

    Values outside ``allowable_set`` are mapped to its last entry, which
    therefore acts as the "unknown" bucket.
    """
    key = x if x in allowable_set else allowable_set[-1]
    return [key == option for option in allowable_set]

def all_of_k_encoding_unk(x, allowable_set):
    """Multi-hot encode the collection ``x`` over ``allowable_set``.

    Returns a float64 array with 1.0 at every position whose entry of
    ``allowable_set`` is contained in ``x`` and 0.0 elsewhere.
    """
    return np.array([1.0 if option in x else 0.0 for option in allowable_set], dtype=float)
    
def atom_features(atom,
                explicit_H=True,
                use_chirality=False):
    """Featurize an RDKit atom as a float32 vector.

    Layout (55 values with the default flags):
      * 44-way one-hot of the element symbol (last slot = 'Unknown'),
      * degree / 10, implicit valence, formal charge, radical electrons,
      * 5-way one-hot hybridization over (SP, SP2, SP3, SP3D, SP3D2); any
        other hybridization falls into the last slot,
      * aromatic flag,
      * total hydrogen count (only when ``explicit_H`` is True),
      * when ``use_chirality``: one-hot CIP code over ('R', 'S') plus a
        "chirality possible" flag.

    Args:
        atom: RDKit atom.
        explicit_H: append ``GetTotalNumHs()``. NOTE(review): the flag name
            suggests the opposite convention; confirm intent before changing,
            since it alters the feature width the DV model is sized for (55).
        use_chirality: append the CIP / chirality features.

    Returns:
        np.ndarray of dtype float32.
    """
    symbols = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl',
               'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In',
               'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown']
    hybridizations = [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
    ]

    results = one_of_k_encoding_unk(atom.GetSymbol(), symbols) + [
        atom.GetDegree() / 10,
        atom.GetImplicitValence(),
        atom.GetFormalCharge(),
        atom.GetNumRadicalElectrons(),
    ] + one_of_k_encoding_unk(atom.GetHybridization(), hybridizations) + [atom.GetIsAromatic()]

    if explicit_H:
        results = results + [atom.GetTotalNumHs()]

    if use_chirality:
        try:
            cip = one_of_k_encoding_unk(atom.GetProp('_CIPCode'), ['R', 'S'])
        except KeyError:
            # RDKit raises KeyError when the atom has no '_CIPCode' property;
            # anything else is a real error and should propagate.
            cip = [False, False]
        results = results + cip + [atom.HasProp('_ChiralityPossible')]

    return np.array(results).astype(np.float32)

### Molecule Featurization

In [12]:
# Cache of (edge_index, node_features, edge_attr) per drug, keyed by the
# integer drug node id (mapped from the DrugBank id via DB_TO_ID_DICT).
MOL_EDGE_LIST_FEAT_MTX = dict(
    (DB_TO_ID_DICT[db_id], get_mol_edge_list_and_feat_mtx(mol))
    for db_id, mol in drug_id_mol_graph_tup
)
len(MOL_EDGE_LIST_FEAT_MTX)



1007

In [13]:
# Sanity check for one drug: node-feature matrix (n_atoms, n_atom_features)
# and edge-attribute matrix (2 * n_bonds, n_bond_features).
MOL_EDGE_LIST_FEAT_MTX[998][1].shape, MOL_EDGE_LIST_FEAT_MTX[998][2].shape

(torch.Size([57, 55]), torch.Size([16, 6]))

### CV Split

In [13]:
def _with_known_edges(data, known_edges):
    """Deep-copy ``data`` and replace its (drug, known, side_effect) edges.

    Args:
        data: source HeteroData; left untouched.
        known_edges: int array of shape (n_edges, 2) holding (drug, side_effect) pairs.
    """
    subset = copy.deepcopy(data)
    subset['drug', 'known', 'side_effect'].edge_index = torch.tensor(known_edges.T)
    return subset


def get_kfold_data_idse(data, k=10, shuffle=True, val_size=0.1, random_state=None):
    """Yield (train, val, test) HeteroData splits for k-fold CV on known edges.

    Each fold's test edges come from KFold. Its remaining edges are split
    again into train / validation (``val_size`` fraction). Each subset is then
    passed through RandomLinkSplit to attach ``edge_label`` /
    ``edge_label_index`` with one sampled negative per positive. Only the
    train split keeps ``edge_index`` (its message-passing edges).

    Args:
        data: HeteroData with a ('drug', 'known', 'side_effect') edge type.
        k: number of folds.
        shuffle: shuffle edges before folding.
        val_size: fraction of each fold's training edges held out for validation.
        random_state: seed for KFold / train_test_split. ``None`` (the default)
            keeps the previous behaviour of drawing from NumPy's global RNG.

    NOTE(review): RandomLinkSplit samples negatives on each split's own
    subgraph, so a negative only avoids that split's positives and may be a
    positive edge of another split. Labels looked up in the full adjacency
    matrix are unaffected. Code that trusts ``edge_label`` directly sees some
    false negatives.
    """
    edge_type = ('drug', 'known', 'side_effect')
    known_edges = data[edge_type].edge_index.T.numpy()

    # sklearn rejects an explicit random_state when shuffle is False.
    kf = KFold(n_splits=k, shuffle=shuffle, random_state=random_state if shuffle else None)
    # Built once: the transform holds configuration only, no per-call state.
    transform = T.RandomLinkSplit(
        num_val=0.0,
        num_test=0.0,
        neg_sampling_ratio=1.0,
        add_negative_train_samples=True,
        edge_types=edge_type,
    )

    for train_val_index, test_index in kf.split(known_edges):
        train_index, valid_index = train_test_split(
            train_val_index, test_size=val_size, random_state=random_state
        )

        train_cv, _, _ = transform(_with_known_edges(data, known_edges[train_index]))

        val_cv, _, _ = transform(_with_known_edges(data, known_edges[valid_index]))
        del val_cv[edge_type].edge_index  # supervision only, no message passing

        test_cv, _, _ = transform(_with_known_edges(data, known_edges[test_index]))
        del test_cv[edge_type].edge_index

        yield train_cv, val_cv, test_cv


In [14]:
# Materialise only the first CV fold to inspect the split structure.
# The training loop creates its own generator, so this fold is not reused.
# NOTE(review): this runs a full fold (three deep copies + RandomLinkSplit)
# and presumably advances the global NumPy / torch RNGs, so the folds the loop
# later sees depend on whether this cell was executed.
train_, val_, test_ = next(get_kfold_data_idse(data))
train_, val_, test_

(HeteroData(
   drug={
     node_id=[1007],
     fpt=[1007, 128],
     mpnn=[1007, 617]
   },
   side_effect={ node_id=[5587] },
   (drug, known, side_effect)={
     edge_index=[2, 106970],
     edge_label=[213940],
     edge_label_index=[2, 213940]
   }
 ),
 HeteroData(
   drug={
     node_id=[1007],
     fpt=[1007, 128],
     mpnn=[1007, 617]
   },
   side_effect={ node_id=[5587] },
   (drug, known, side_effect)={
     edge_label=[23772],
     edge_label_index=[2, 23772]
   }
 ),
 HeteroData(
   drug={
     node_id=[1007],
     fpt=[1007, 128],
     mpnn=[1007, 617]
   },
   side_effect={ node_id=[5587] },
   (drug, known, side_effect)={
     edge_label=[26414],
     edge_label_index=[2, 26414]
   }
 ))

In [15]:
# val_[('drug', 'known', 'side_effect')].edge_index

### Model

#### idse-HE Model

In [16]:
class HetAgg(nn.Module):
    """idse-HE heterogeneous aggregation model scoring every (drug, side effect) pair.

    Two rounds of neighbour aggregation over fixed normalized drug<->side-effect
    matrices are followed by a bilinear read-out ``drug_emb @ se_emb.T``.

    Args:
        args: dict with ``'hidden'`` (embedding width; must be even),
            ``'D_n'`` / ``'S_n'`` (number of drugs / side effects) and
            ``'alpha'`` (LeakyReLU negative slope).
        dropout: dropout probability applied between stages.
        mpnn_feature: (D_n, drug_dim) drug feature matrix.
        fpt_feature: (D_n, fpt_dim) drug fingerprint matrix.
        drug_se_train: (D_n, S_n) normalized drug -> side-effect matrix.
        se_drug_train: (S_n, D_n) normalized side-effect -> drug matrix.
    """

    def __init__(self, args, dropout, mpnn_feature, fpt_feature, drug_se_train, se_drug_train):
        super(HetAgg, self).__init__()
        self.embed_d = args['hidden']
        if self.embed_d % 2 != 0:
            # Fd and Ft each emit hidden/2 columns that are concatenated back to
            # `hidden`; an odd width would break the side-effect aggregation.
            raise ValueError(f"args['hidden'] must be even, got {self.embed_d}")
        self.D_n = args['D_n']
        self.S_n = args['S_n']
        self.args = args
        self.dropout = dropout
        self.dim = self.embed_d // 2
        # Input widths are read from the feature tensors rather than hard-coded
        # (617 / 128 for the idse graph).
        self.drug_dim = mpnn_feature.shape[1]
        self.fpt_dim = fpt_feature.shape[1]

        # Final projections before the bilinear read-out.
        self.Wd = nn.Linear(self.embed_d, self.dim)
        self.Ws = nn.Linear(self.embed_d, self.dim)
        self.bnd = nn.BatchNorm1d(self.embed_d)
        self.bns = nn.BatchNorm1d(self.embed_d)

        # First aggregation round (node types 1 = drug, 2 = side effect).
        self.Gd = nn.Linear((self.embed_d + self.drug_dim + self.fpt_dim), self.embed_d)
        self.Gs = nn.Linear(2 * self.embed_d, self.embed_d)
        self.bd = nn.BatchNorm1d(self.embed_d + self.drug_dim + self.fpt_dim)
        self.bs = nn.BatchNorm1d(2 * self.embed_d)

        # Second aggregation round (node types 3 = drug, 4 = side effect).
        self.Gd2 = nn.Linear((self.embed_d + self.drug_dim + self.fpt_dim), self.embed_d)
        self.Gs2 = nn.Linear(2 * self.embed_d, self.embed_d)
        self.bd2 = nn.BatchNorm1d(self.embed_d + self.drug_dim + self.fpt_dim)
        self.bs2 = nn.BatchNorm1d(2 * self.embed_d)

        # Input encoders.
        self.Fd = nn.Linear(self.drug_dim, self.dim)
        self.Fs = nn.Linear(self.embed_d, self.embed_d)
        self.Ft = nn.Linear(self.fpt_dim, self.dim)

        # Learned side-effect features (registered as a parameter on assignment below).
        se_feature = nn.Parameter(torch.Tensor(self.S_n, self.embed_d))
        se_feature.data.normal_(0, 0.1)

        self.softmax = nn.Softmax(dim=1)  # currently unused; kept for compatibility
        self.act = nn.LeakyReLU(args['alpha'])

        self.drug_feature = mpnn_feature
        self.fpt_feature = fpt_feature
        self.se_feature = se_feature
        self.drug_se_train = drug_se_train
        self.se_drug_train = se_drug_train

    def init_weights(self):
        """Re-initialize every nn.Linear weight from N(0, 0.1); biases are untouched."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, std=0.1)

    def node_het_agg(self, node_type, drug_agg, se_agg):
        """Aggregate neighbour embeddings for one node type / round.

        Args:
            node_type: 1 / 3 = drugs (round 1 / 2), 2 / 4 = side effects (round 1 / 2).
            drug_agg: drug embeddings aggregated into side effects (types 2, 4).
            se_agg: side-effect embeddings aggregated into drugs (types 1, 3).

        Raises:
            ValueError: for any other ``node_type``.
        """
        if node_type in (1, 3):
            d_s_agg = torch.matmul(self.drug_se_train, se_agg)
            concate_embed = torch.cat((self.drug_feature, d_s_agg, self.fpt_feature), 1)
            lin, bn = (self.Gd, self.bd) if node_type == 1 else (self.Gd2, self.bd2)
        elif node_type in (2, 4):
            s_d_agg = torch.matmul(self.se_drug_train, drug_agg)
            concate_embed = torch.cat((self.se_feature, s_d_agg), 1)
            lin, bn = (self.Gs, self.bs) if node_type == 2 else (self.Gs2, self.bs2)
        else:
            raise ValueError(f"node_type must be 1, 2, 3 or 4, got {node_type!r}")
        return self.act(lin(bn(concate_embed)))

    def forward(self):
        """Return a (D_n, S_n) matrix of raw (pre-sigmoid) pair scores."""
        drug_embedding = self.Fd(self.drug_feature)
        se_embedding = self.Fs(self.se_feature)
        fpt_embedding = self.Ft(self.fpt_feature)

        drug_embedding = torch.cat((drug_embedding, fpt_embedding), 1)

        drug_embedding = nn.functional.dropout(drug_embedding, self.dropout, training=self.training)
        se_embedding = nn.functional.dropout(se_embedding, self.dropout, training=self.training)

        drug_embeddings_t = self.node_het_agg(1, drug_embedding, se_embedding)
        se_embeddings_t = self.node_het_agg(2, drug_embedding, se_embedding)

        drug_embeddings_t = nn.functional.dropout(drug_embeddings_t, self.dropout, training=self.training)
        se_embeddings_t = nn.functional.dropout(se_embeddings_t, self.dropout, training=self.training)

        drug_embedding_e = self.node_het_agg(3, drug_embeddings_t, se_embeddings_t)
        se_embedding_e = self.node_het_agg(4, drug_embeddings_t, se_embeddings_t)

        drug_embedding_e = nn.functional.dropout(drug_embedding_e, self.dropout, training=self.training)
        se_embedding_e = nn.functional.dropout(se_embedding_e, self.dropout, training=self.training)

        drug_embeddings = self.Wd(self.bnd(drug_embedding_e))
        se_embeddings = self.Ws(self.bns(se_embedding_e))

        return torch.mm(drug_embeddings, se_embeddings.t())


#### MHGNN Hetero

In [48]:
class HeteroMHGNN(nn.Module):
    """Stacked heterogeneous GATv2 encoder with per-node-type skip connections.

    Each layer runs one GATv2Conv per edge type (summed by HeteroConv). For
    every node type it then applies LayerNorm, adds a linear skip of the
    layer input, re-applies the same LayerNorm and an ELU. The outputs of all
    layers are concatenated per node type and normalized once more.

    Args:
        metadata: (node_types, edge_types) as returned by ``HeteroData.metadata()``.
        in_channels: input feature width shared by all node types.
        hidden_dims: per-layer output width per head.
        heads: per-layer number of attention heads (heads are concatenated).
        use_edge_attr: optional {edge_type: bool}; True builds the conv with
            ``edge_dim=1``. Missing edge types default to False.

    forward returns {node_type: tensor of width sum(hidden_dims[i] * heads[i])}.
    """

    def __init__(self, metadata, in_channels, hidden_dims, heads, use_edge_attr=None):
        super().__init__()

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleDict()
        self.skips = nn.ModuleDict()
        self.final_norms = nn.ModuleDict()

        node_types, edge_types = metadata[0], metadata[1]
        if use_edge_attr is None:
            use_edge_attr = {edge_type: False for edge_type in edge_types}

        # Module creation order matches the previous implementation so seeded
        # initialization is unchanged.
        layer_in = in_channels
        for i, (out_dim, head) in enumerate(zip(hidden_dims, heads)):
            layer_out = out_dim * head  # GATv2Conv concatenates its heads
            conv_dict = {}
            for edge_type in edge_types:
                if use_edge_attr.get(edge_type, False):
                    conv_dict[edge_type] = GATv2Conv(layer_in, out_dim, heads=head, add_self_loops=False, edge_dim=1)
                else:
                    conv_dict[edge_type] = GATv2Conv(layer_in, out_dim, heads=head, add_self_loops=False)

            self.convs.append(HeteroConv(conv_dict, aggr='sum'))

            for node_type in node_types:
                self.norms[f'{node_type}_{i}'] = nn.LayerNorm(layer_out)
                self.skips[f'{node_type}_{i}'] = Linear(layer_in, layer_out)
            layer_in = layer_out

        self.node_types = node_types
        # The concatenation of all layer outputs; previously computed from the
        # last layer only, which was wrong whenever layer widths differed.
        concat_dim = sum(d * h for d, h in zip(hidden_dims, heads))
        for node_type in node_types:
            self.final_norms[f'{node_type}'] = nn.LayerNorm(concat_dim)

        # Initialize skips with xavier init
        for skip in self.skips.values():
            nn.init.xavier_uniform_(skip.weight)
            nn.init.zeros_(skip.bias)

    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        """Encode all node types; the caller's ``x_dict`` is not modified."""
        x_dict = dict(x_dict)  # shallow copy: per-layer updates must not leak to the caller
        x_repr_dict = {node_type: [] for node_type in self.node_types}

        for i, conv in enumerate(self.convs):
            skip_x = {
                node_type: self.skips[f'{node_type}_{i}'](x_dict[node_type])
                for node_type in self.node_types
            }

            x_dict_new = conv(x_dict, edge_index_dict, edge_attr_dict)

            for node_type in self.node_types:
                norm = self.norms[f'{node_type}_{i}']
                # NOTE(review): the same LayerNorm is applied before and after the skip add.
                x = norm(x_dict_new[node_type]) + skip_x[node_type]
                x = norm(x)
                x = F.elu(x)
                x_repr_dict[node_type].append(x)
                x_dict[node_type] = x

        # Concatenate all layer representations for each node type.
        return {
            node_type: self.final_norms[f'{node_type}'](torch.cat(x_repr_dict[node_type], dim=1))
            for node_type in self.node_types
        }

# Edge types whose convolution consumes a 1-d edge attribute (edge_dim=1 in
# HeteroMHGNN). Only the drug<->side_effect "known" relations carry none.
use_edge_attr = {
    edge_type: edge_type[1] not in ('known', 'rev_known')
    for edge_type in [
        ('drug', 'known', 'side_effect'),
        ('drug', 'struct', 'drug'),
        ('drug', 'word', 'drug'),
        ('drug', 'target', 'drug'),
        ('drug', 'se_encoded', 'drug'),
        ('side_effect', 'name', 'side_effect'),
        ('side_effect', 'dg_encoded', 'side_effect'),
        ('side_effect', 'atc', 'side_effect'),
        ('side_effect', 'rev_known', 'drug'),
    ]
}


#### MHGNN - Outer HGNN

In [49]:
# class MHGNN(nn.Module):
#     def __init__(self, input_dim, hidden_dims, heads):
#         super().__init__()
#         self.GATLayers = nn.ModuleList()
#         self.norms = nn.ModuleList() 
#         self.skips = nn.ModuleList()
#         for i, (out_dim, head) in enumerate(zip(hidden_dims, heads)):
#             self.GATLayers.append(GATv2Conv(input_dim, out_dim, heads=head, add_self_loops=False, name=f'GATLayer{i}'))
#             self.norms.append(nn.LayerNorm(out_dim * head))
#             self.skips.append(nn.Linear(input_dim, out_dim * head))
#             input_dim = out_dim * head      
        
#         # # # initialize skips with xavier init
#         for skip in self.skips:
#             nn.init.xavier_uniform_(skip.weight)
#             nn.init.zeros_(skip.bias)

#     def forward(self, x, edge_index, edge_attr):
#         x_repr = []
#         for idx, (layer, skip, norm) in enumerate(zip(self.GATLayers, self.skips, self.norms)): # norm, self.norms
#             skip_x = skip(x)
#             x = layer(x, edge_index)
#             x = norm(x) + skip_x  # Add skip connection
#             x = norm(x)     # Apply normalization
#             x = F.elu(x)    # Apply activation
#             x_repr.append(x)
#             # x = F.elu(norm(x))
#             # x += skip_x
#             # x_repr.append(F.elu(x))
#             # if idx < len(self.GATLayers) - 1:
#             #     x = F.elu(x) # norm(x)
#         x_repr = torch.cat(x_repr, dim=1)
#         return x_repr

#### DVModel

In [50]:
class DrugInterView_Block(nn.Module):
    """One intra-molecule GATv2 layer followed by a SAGPooling read-out.

    Args:
        n_heads: number of attention heads (outputs are concatenated).
        in_features: atom feature width.
        head_out_feats: output width per head.
        edge_dim: bond feature width; the default 6 matches get_bond_features
            (5 bond-type slots + ring flag).
    """

    def __init__(self, n_heads, in_features, head_out_feats, edge_dim=6):
        super().__init__()
        self.n_heads = n_heads
        self.in_features = in_features
        self.out_features = head_out_feats

        self.feature_conv = GATv2Conv(in_features, head_out_feats, n_heads, edge_dim=edge_dim)

        # NOTE(review): min_score=-1 presumably keeps every node, so pooling acts
        # as attention re-weighting rather than node dropping — confirm against
        # the PyG SAGPooling docs.
        self.readout = SAGPooling(n_heads * head_out_feats, min_score=-1)

    def forward(self, mol_data):
        """Update ``mol_data.x`` in place and pool it per graph.

        Returns:
            (mol_data, graph_emb, att_scores, pooled_batch): the batch with
            updated node features, the per-graph sum-pooled embedding, the
            SAGPooling node scores, and the batch vector of the kept nodes.
        """
        mol_data.x = self.feature_conv(mol_data.x, mol_data.edge_index, mol_data.edge_attr)
        pooled_x, _, _, pooled_batch, _, att_scores = self.readout(
            mol_data.x, mol_data.edge_index, batch=mol_data.batch
        )

        graph_emb = global_add_pool(pooled_x, pooled_batch)

        return mol_data, graph_emb, att_scores, pooled_batch

In [51]:
class FinalDrugMolEmb(nn.Module):
    """Drug molecular-graph encoder built from stacked DrugInterView_Blocks.

    The per-graph embeddings of every block are concatenated, so the output
    width is ``sum(head_out_feats[i] * n_heads[i])``.

    Args:
        in_features: atom feature width.
        heads_out_feat_params: per-block output width per head.
        blocks_params: per-block number of heads.

    Raises:
        ValueError: if the two per-block lists differ in length (``zip`` used
            to silently drop the extra entries).
    """

    def __init__(self, in_features, heads_out_feat_params, blocks_params):
        super().__init__()
        if len(heads_out_feat_params) != len(blocks_params):
            raise ValueError(
                "heads_out_feat_params and blocks_params must have the same length, "
                f"got {len(heads_out_feat_params)} and {len(blocks_params)}"
            )
        self.in_features = in_features
        self.n_blocks = len(blocks_params)

        self.inital_norm = nn.LayerNorm(self.in_features)  # (sic) name kept for state_dict compatibility

        self.blocks = nn.ModuleList()
        self.net_norms = nn.ModuleList()

        for head_out_feats, n_heads in zip(heads_out_feat_params, blocks_params):
            self.blocks.append(DrugInterView_Block(n_heads, in_features, head_out_feats))
            self.net_norms.append(nn.LayerNorm(head_out_feats * n_heads))
            in_features = head_out_feats * n_heads

    def forward(self, mol_data):
        """Encode a batch of molecular graphs (``mol_data.x`` is updated in place).

        Returns:
            (graph_embedding, attention_weights, attention_batch): concatenated
            per-block graph embeddings, the per-block SAGPooling scores, and
            per-block (node batch vector, pooled batch vector) pairs.
        """
        repr_mol = []
        attention_weights = []
        attention_batch = []
        mol_data.x = self.inital_norm(mol_data.x)
        last_idx = len(self.blocks) - 1
        for idx, (block, norm) in enumerate(zip(self.blocks, self.net_norms)):
            mol_data, graph_emb, att_scores, pooled_batch = block(mol_data)
            attention_weights.append(att_scores)
            attention_batch.append((mol_data.batch, pooled_batch))
            repr_mol.append(graph_emb)
            # The last block's norm is not applied: its node features are not reused.
            if idx < last_idx:
                mol_data.x = F.elu(norm(mol_data.x))
        return torch.cat(repr_mol, dim=1), attention_weights, attention_batch

In [64]:
class DVModel(torch.nn.Module):
    """Drug-view link-prediction model for (drug, known, side_effect) edges.

    Drugs are embedded from their molecular graphs by FinalDrugMolEmb; this
    embedding can be combined with, or replaced by, tabular drug features.
    Side effects use a learned embedding table. An outer heterogeneous GNN
    (``gnn_model``) mixes both node types, and ``classifier_model`` scores the
    edges in ``edge_label_index``.

    Args:
        hidden_channels: width of the side-effect embedding and of the
            outer-GNN input LayerNorms. NOTE(review): the drug branch must
            produce the same width. FinalDrugMolEmb here outputs
            2 blocks x 64 x 3 heads = 384; this coupling is implicit.
        gnn_model: callable (x_dict, edge_index_dict, edge_attr_dict) -> x_dict.
        classifier_model: callable (x_drug, x_se, edge_label_index) -> scores.
        use_node_features: build the tabular drug-feature projection
            (reads ``data["drug"].x``).
        node_feature_mode: ``"feat"`` (with ``use_node_features``) uses only
            tabular features for drugs. ``"combined"`` adds their projection to
            the molecular embedding. Anything else uses the molecular embedding
            alone.
    """

    def __init__(self, hidden_channels, gnn_model, classifier_model, use_node_features=False, node_feature_mode="no"):
        super().__init__()
        # Side-effect embeddings; NOTE: sized from the module-level `data` graph.
        self.seff_emb = torch.nn.Embedding(data["side_effect"].num_nodes, hidden_channels)
        # DV for Drug
        self.drug_emb = FinalDrugMolEmb(in_features=55, heads_out_feat_params=[64, 64], blocks_params=[3, 3])
        outer_emb_dim = hidden_channels
        self.inital_norm_outer_drug = nn.LayerNorm(outer_emb_dim)
        self.inital_norm_outer_se = nn.LayerNorm(outer_emb_dim)
        # Outer message passing
        self.gnn = gnn_model
        self.use_node_features = use_node_features
        self.node_feature_mode = node_feature_mode
        # "combined" also needs the projection; previously it was only built
        # when use_node_features was set, a combination that never reached the
        # "combined" branch, so that mode always crashed.
        if use_node_features or node_feature_mode == "combined":
            self.drug_feat_layernorm = torch.nn.LayerNorm(data["drug"].num_features)
            self.drug_lin = torch.nn.Linear(data["drug"].num_features, hidden_channels)
        self.classifier = classifier_model

        torch.nn.init.xavier_uniform_(self.seff_emb.weight)

    def __create_graph_data(self, drug_ids, device):
        """Batch the cached molecular graphs of ``drug_ids`` onto ``device``.

        Each id is looked up in the module-level MOL_EDGE_LIST_FEAT_MTX cache
        of (edge_index, node_features, edge_attr) tuples.
        """
        graphs = []
        for drug_id in drug_ids.cpu().numpy().astype(int).tolist():
            edge_index, node_feats, edge_attr = MOL_EDGE_LIST_FEAT_MTX[drug_id]
            graphs.append(Data(x=node_feats, edge_index=edge_index, edge_attr=edge_attr))
        return Batch.from_data_list(graphs).to(device)

    def forward(self, data: HeteroData) -> tuple:
        """Score the supervision edges of ``data``.

        Returns:
            (pred, attention_weights, attention_batch). The attention entries
            are None in ``"feat"`` mode, where no molecular graphs are encoded
            (previously this path raised UnboundLocalError).
        """
        attention_weights, h_att_batch = None, None

        if self.use_node_features and self.node_feature_mode == "feat":
            x_dict = {
                "drug": self.drug_lin(self.drug_feat_layernorm(data["drug"].x)),
                "side_effect": self.seff_emb(data["side_effect"].node_id),
            }
        else:
            drug_graphs = self.__create_graph_data(data["drug"].node_id, data["drug"].node_id.device)

            drug_dv, attention_weights, h_att_batch = self.drug_emb(drug_graphs)
            if self.node_feature_mode == "combined":
                drug_dv = drug_dv + self.drug_lin(self.drug_feat_layernorm(data["drug"].x))

            # Layer-normalize the input features of the outer GNN.
            x_dict = {
                "drug": self.inital_norm_outer_drug(drug_dv),
                "side_effect": self.inital_norm_outer_se(self.seff_emb(data["side_effect"].node_id)),
            }

        # `x_dict` holds feature matrices of all node types;
        # `edge_index_dict` holds the edge indices of all edge types.
        x_dict = self.gnn(x_dict, data.edge_index_dict, data.edge_attr_dict)
        pred = self.classifier(
            x_dict["drug"],
            x_dict["side_effect"],
            data["drug", "known", "side_effect"].edge_label_index,
        )

        return pred, attention_weights, h_att_batch

#### Edge Classifier

In [65]:
# Edge scorer: the Hadamard (element-wise) product of the source and destination
# node embeddings, summed over features, gives one logit per supervision edge.
class VanillaClassifier(torch.nn.Module):
    def forward(self, x_drug: Tensor, x_se: Tensor, edge_label_index: Tensor) -> Tensor:
        """Return one score per column of ``edge_label_index`` (row 0 = drug, row 1 = side effect)."""
        src_idx, dst_idx = edge_label_index[0], edge_label_index[1]
        drug_rows = x_drug.index_select(0, src_idx)
        se_rows = x_se.index_select(0, dst_idx)
        return torch.mul(drug_rows, se_rows).sum(dim=-1)

### Train Utils

#### idse-HE train loop

In [47]:
# 185965//num_side_effects, 185965%num_side_effects # row, col

In [17]:
def row_normalize(a_matrix):
    """Scale each row of a 2-D array to sum to (almost exactly) 1.

    A 1e-12 epsilon keeps all-zero rows at zero instead of dividing by zero.
    Any non-finite result (from NaN / inf inputs) is replaced by 0.0. Returns
    a new float64 array.
    """
    mat = a_matrix.astype(float)
    denom = mat.sum(axis=1, keepdims=True) + 1e-12
    normalized = mat / denom
    normalized[~np.isfinite(normalized)] = 0.0
    return normalized

def dse_normalize(cuda, drug_se, D_n=1020, S_n=5599):
    """Row-normalize a drug x side-effect matrix and its transpose as float32 tensors.

    Args:
        cuda: move both results to the default CUDA device when truthy.
        drug_se: 2-D NumPy array (drugs x side effects).
        D_n, S_n: unused; kept for call compatibility.

    Returns:
        (drug_se_normalize, se_drug_normalize).
    """
    def _prepare(matrix):
        tensor = torch.from_numpy(row_normalize(matrix)).float()
        return tensor.cuda() if cuda else tensor

    return _prepare(drug_se), _prepare(drug_se.T)

def save_all(final_outputs, test_mask, fold, path="idse-result"):
    """Save the score matrix and the test mask of one fold as .npy files.

    Files are named ``<path>result<fold>.npy`` and ``<path>mask<fold>.npy``.
    ``path`` is a plain prefix, not a directory: no separator is inserted.
    """
    for tag, tensor in (('result', final_outputs), ('mask', test_mask)):
        np.save(path + tag + str(fold), tensor.cpu().detach().numpy())
    

In [18]:
    
def mrank(y, y_pre):
    """Sum of reciprocal ranks of the positive items when sorted by score.

    Args:
        y: 1-D array of binary labels (1 = positive).
        y_pre: 1-D array of scores; higher ranks first.

    Returns:
        sum(1 / rank) over positives (ranks are 1-based). Returns 0.0 when
        there are no positives.

    NOTE(review): despite the name, this is the *sum*, not the mean, of
    reciprocal ranks. The unused mean (which also warned when there were no
    positives) has been removed without changing the return value.
    """
    order = np.argsort(-y_pre)
    ranked_labels = y[order]
    ranks = np.flatnonzero(ranked_labels == 1) + 1
    return np.sum(1 / ranks)

def validation(y_pre, y, flag=False):
    """Compute ranking (and optionally thresholded) metrics for binary scores.

    Args:
        y_pre: 1-D array of predicted probabilities; it is NOT modified
            (previously it was thresholded in place).
        y: 1-D array of binary ground-truth labels.
        flag: also compute the thresholded (0.5) and ranking extras.

    Returns:
        (roc_auc, pr_auc, prec, recall, mcc, f1, ap, mr). When ``flag`` is
        False, the last six entries are None. Previously they were filled with
        the PR-curve threshold array.
    """
    prec_curve, recall_curve, _ = precision_recall_curve(y, y_pre)
    pr_auc = sklearn.metrics.auc(recall_curve, prec_curve)
    fpr, tpr, _ = roc_curve(y, y_pre)
    roc_auc = sklearn.metrics.auc(fpr, tpr)
    if not flag:
        return roc_auc, pr_auc, None, None, None, None, None, None

    ap = average_precision_score(y, y_pre)
    mr = mrank(y, y_pre)
    scores = np.asarray(y_pre)
    # Fresh array: the caller's scores stay intact.
    y_predict_class = (scores > 0.5).astype(scores.dtype)
    prec = precision_score(y, y_predict_class)
    recall = recall_score(y, y_predict_class)
    mcc = matthews_corrcoef(y, y_predict_class)
    f1 = f1_score(y, y_predict_class)
    return roc_auc, pr_auc, prec, recall, mcc, f1, ap, mr

def binary_cross_entropy_loss(inputs, targets, pos_weight=40):
    """Mean BCE-with-logits loss with positives up-weighted by ``pos_weight``.

    The weight tensor is created on the device and with the dtype of
    ``inputs``. The previous version hard-coded ``.cuda()`` and failed on CPU.

    Args:
        inputs: logits.
        targets: binary targets with the same shape (and dtype) as ``inputs``.
        pos_weight: multiplier for the positive-class term (default 40, as before).
    """
    weight = torch.tensor([pos_weight], dtype=inputs.dtype, device=inputs.device)
    return F.binary_cross_entropy_with_logits(inputs, targets, pos_weight=weight)

def train(model, optimizer, mask, target, train_idx, train_set):
    """Run one full-batch optimization step of an idse-HE model.

    Args:
        model: module whose ``forward()`` takes no arguments and returns a
            (drugs x side effects) score matrix.
        optimizer: optimizer over ``model``'s parameters.
        mask: 0/1 matrix selecting training pairs (non-training scores are zeroed).
        target: flattened training targets.
        train_idx: flat indices of the training pairs.
        train_set: labels of those pairs, for the AUC / AUPR report.

    Returns:
        (loss, train_auc, train_aupr, raw_score_matrix).
    """
    model.train()
    optimizer.zero_grad()
    scores = model()
    masked_flat = torch.mul(mask, scores).flatten()
    loss = binary_cross_entropy_loss(masked_flat, target)
    loss.backward()
    optimizer.step()
    train_probs = torch.sigmoid(masked_flat[train_idx]).cpu().detach().numpy()
    roc_auc, pr_auc = validation(train_probs, train_set)[:2]
    return loss.item(), roc_auc, pr_auc, scores

def compute_test(test_set, outputs, mask, test_idx, flag=False):
    """Evaluate a held-out index set against a score matrix.

    The matrix is masked, flattened, restricted to ``test_idx`` and passed
    through a sigmoid. The metric tuple from ``validation`` is returned as-is.
    """
    masked_flat = torch.flatten(mask * outputs)
    probs = torch.sigmoid(masked_flat[test_idx]).cpu().detach().numpy()
    return validation(probs, test_set, flag)

In [19]:
# Fold counter and per-fold metric accumulators for the idse-HE CV run.
counter = 1
auc_arr, aupr_arr, mcc_arr, f1_arr = [], [], [], []
prec_arr, recall_arr, ap_arr, mr_arr = [], [], [], []
valid_aupr_arr = []

# Dense drug x side-effect label matrix: 1 where a (drug, known, side_effect)
# edge exists, 0 elsewhere. `data_set` is its row-major flattening, so pair
# (d, s) lives at flat index d * S_n + s.
edge_index_data = data[('drug', 'known', 'side_effect')].edge_index.numpy()
data_set_adj = np.zeros(
    (data["drug"].node_id.shape[0], data["side_effect"].node_id.shape[0])
)
data_set_adj[tuple(edge_index_data)] = 1
data_set = data_set_adj.flatten()
D_n, S_n = data_set_adj.shape
for train_cv, val_cv, test_cv in get_kfold_data_idse(data):
#     train_index, valid_index = train_test_split(train_index, test_size=0.05) #test_index
    train_edges = train_cv[('drug', 'known', 'side_effect')].edge_label_index.numpy()
    test_edges = test_cv[('drug', 'known', 'side_effect')].edge_label_index.numpy()
    val_edges = val_cv[('drug', 'known', 'side_effect')].edge_label_index.numpy()
    
    num_side_effects = data_set_adj.shape[1]
    train_index = train_edges[0] * num_side_effects + train_edges[1]
    test_index = test_edges[0] * num_side_effects + test_edges[1]
    valid_index = val_edges[0] * num_side_effects + val_edges[1]

    train_set = data_set[train_index]
    valid_set = data_set[valid_index]
    print("train shape:", train_set.shape, ", valid shape:", valid_set.shape)
    test_set = data_set[test_index]
    print('Begin {}th folder'.format(counter),
          'train_size {}'.format(len(train_index)),
          'train_label {}'.format(np.sum(train_set)),
          'valid_label {}'.format(np.sum(valid_set)),
          'test_label {}'.format(np.sum(test_set)))
    
    train_mask = np.zeros(D_n * S_n)
    train_mask[train_index] = 1
    target = np.multiply(data_set, train_mask)
    matrix = target.reshape(D_n, S_n)
    
    print('train_mask {}'.format(np.sum(train_mask)),
            'matrix {}'.format(np.sum(matrix)))
    
    train_mask = torch.from_numpy(train_mask.reshape(D_n, S_n)).cuda()
    target = torch.from_numpy(target).cuda()
    cuda = torch.cuda.is_available()
    drug_se_train, se_drug_train = dse_normalize(cuda, matrix, D_n=D_n, S_n=S_n)
    
    test_mask = np.zeros(D_n * S_n)
    test_mask[test_index] = 1
    test_mask = torch.from_numpy(test_mask.reshape(D_n, S_n)).cuda()

    valid_mask = np.zeros(D_n * S_n)
    valid_mask[valid_index] = 1
    valid_mask = torch.from_numpy(valid_mask.reshape(D_n, S_n)).cuda()
    dropout = 0.1
    lr = 0.001
    weight_decay = 5e-3
    epochs = 2000
    patience = 500
    mpnn_feature = data["drug"].mpnn.float().cuda()
    fpt_feature = data["drug"].fpt.float().cuda()
    result_path = "idse-result"
    
    model_args = {"hidden": 1024, "D_n": D_n, "S_n": S_n, "alpha": 0.02}
    model = HetAgg(model_args, dropout, mpnn_feature, fpt_feature, drug_se_train, se_drug_train)
    model.init_weights()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    model.cuda()
    # Train model
    t_total = time.time()
    bad_counter = 0
    best_epoch = 0
    best_pr = 0
    final_outputs = []
    
    skf = StratifiedKFold(n_splits=100, shuffle=False) # batch
    
    for epoch in range(epochs):
        auc, aupr, outputs = [], [], []
        loss = 0
        t = time.time()
        loss, train_auc, train_aupr, outputs = train(model, optimizer, train_mask, target, train_index, train_set)
        print('folder: {}'.format(counter),
          'Epoch: {:04d}'.format(epoch+1),
          'loss_train: {:.4f}'.format(loss),
          'train_auc: {:.4f}'.format(train_auc),
          'train_aupr: {:.4f}'.format(train_aupr),
          'time: {:.4f}s'.format(time.time() - t),
          'lr:', optimizer.defaults['lr'])
        valid_metrics = compute_test(valid_set, outputs, valid_mask, valid_index)
        valid_auc, valid_aupr = valid_metrics[0], valid_metrics[1]
        test_metrics = compute_test(test_set, outputs, test_mask, test_index)
        test_auc, test_aupr = test_metrics[0], test_metrics[1]
        print("Valid set results:",
              "folder= {}".format(counter),
              'Epoch: {:04d}'.format(epoch+1),
              'valid_auc: {:.4f}'.format(valid_auc),
              'valid_aupr: {:.4f}'.format(valid_aupr))
        print("Test set results:",
              "folder= {}".format(counter),
              'Epoch: {:04d}'.format(epoch+1),
              'test_auc: {:.4f}'.format(test_auc),
              'test_aupr: {:.4f}'.format(test_aupr),
              'Best_epoch: {:04d}'.format(best_epoch+1))
        if valid_aupr > best_pr:
            best_pr = valid_aupr
            best_epoch = epoch
            bad_counter = 0
            final_outputs = outputs
        else:
            bad_counter += 1
    
        if bad_counter >= patience:
            break
    
    print("Optimization Finished!")
    print("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    print('Loading {}th epoch'.format(best_epoch))

    # Testing
    # save_result(final_outputs, data_set, test_mask, counter)
    save_all(final_outputs, test_mask, counter)
    test_auc, test_aupr, prec, recall, mcc, f1, ap, mr = compute_test(test_set, final_outputs, test_mask, test_index, True)
    print("Test set results:",
          "folder= {}".format(counter),
          'test_auc: {:.4f}'.format(test_auc),
          'test_aupr: {:.4f}'.format(test_aupr),
          'test_prec: {:.4f}'.format(prec),
          'test_recall: {:.4f}'.format(recall),
          'test_mcc: {:.4f}'.format(mcc),
          'test_f1: {:.4f}'.format(f1),
          'test_ap: {:.4f}'.format(ap),
          'test_mr: {:.4f}'.format(mr))
    valid_aupr_arr.append(best_pr)
    auc_arr.append(test_auc)
    aupr_arr.append(test_aupr)
    mcc_arr.append(mcc)
    f1_arr.append(f1)
    prec_arr.append(prec)
    recall_arr.append(recall)
    ap_arr.append(ap)
    mr_arr.append(mr)
    np.savetxt(result_path + 'valid_aupr_avg', [counter, np.mean(np.array(valid_aupr_arr))])
    np.savetxt(result_path + 'auc_avg', [counter, np.mean(np.array(auc_arr))])
    np.savetxt(result_path + 'aupr_avg', [counter, np.mean(np.array(aupr_arr))])
    np.savetxt(result_path + 'mcc_avg', [counter, np.mean(np.array(mcc_arr))])
    np.savetxt(result_path + 'f1_avg', [counter, np.mean(np.array(f1_arr))])
    np.savetxt(result_path + 'prec_avg', [counter, np.mean(np.array(prec_arr))])
    np.savetxt(result_path + 'recall_avg', [counter, np.mean(np.array(recall_arr))])
    np.savetxt(result_path + 'ap_avg', [counter, np.mean(np.array(ap_arr))])
    np.savetxt(result_path + 'mr_avg', [counter, np.mean(np.array(mr_arr))])
    np.savetxt(result_path + 'valid_aupr', np.array(valid_aupr_arr))
    np.savetxt(result_path + 'auc', np.array(auc_arr))
    np.savetxt(result_path + 'aupr', np.array(aupr_arr))
    np.savetxt(result_path + 'mcc', np.array(mcc_arr))
    np.savetxt(result_path + 'f1', np.array(f1_arr))
    np.savetxt(result_path + 'prec', np.array(prec_arr))
    np.savetxt(result_path + 'recall', np.array(recall_arr))
    np.savetxt(result_path + 'ap', np.array(ap_arr))
    np.savetxt(result_path + 'mr', np.array(mr_arr))
    counter += 1

train shape: (213940,) , valid shape: (23772,)
Begin 1th folder train_size 213940 train_label 107479.0 valid_label 12144.0 test_label 13480.0
train_mask 213940.0 matrix 107479.0
folder: 1 Epoch: 0001 loss_train: 34.8040 train_auc: 0.5275 train_aupr: 0.6036 time: 1.6292s lr: 0.001
Valid set results: folder= 1 Epoch: 0001 valid_auc: 0.5229 valid_aupr: 0.6080
Test set results: folder= 1 Epoch: 0001 test_auc: 0.5264 test_aupr: 0.6092 Best_epoch: 0001
folder: 1 Epoch: 0002 loss_train: 17.0624 train_auc: 0.6992 train_aupr: 0.7767 time: 0.0975s lr: 0.001
Valid set results: folder= 1 Epoch: 0002 valid_auc: 0.6875 valid_aupr: 0.7721
Test set results: folder= 1 Epoch: 0002 test_auc: 0.6850 test_aupr: 0.7701 Best_epoch: 0001
folder: 1 Epoch: 0003 loss_train: 11.6735 train_auc: 0.7232 train_aupr: 0.8005 time: 0.0901s lr: 0.001
Valid set results: folder= 1 Epoch: 0003 valid_auc: 0.7009 valid_aupr: 0.7871
Test set results: folder= 1 Epoch: 0003 test_auc: 0.6969 test_aupr: 0.7859 Best_epoch: 0002
fol

KeyboardInterrupt: 

#### Train Loop

In [66]:
def do_train_compute(batch, device, model):
    """Forward one link-prediction batch and pair the prediction with its labels.

    The model must return exactly three values; only the first one (the
    prediction) is kept. The labels come from the ``edge_label`` of the
    ``("drug", "known", "side_effect")`` edge store of ``batch``.

    ``device`` is accepted for interface compatibility but not used here: the
    training loop moves the batch to the device before calling this function.

    Returns:
        ``(pred, actual)``
    """
    pred, _aux_first, _aux_second = model(batch)
    supervised_edge_type = ("drug", "known", "side_effect")
    return pred, batch[supervised_edge_type].edge_label

# def do_train_compute(batch, device, model):
#     # batch = batch.to(device)
#     pred = model(batch)
#     actual = batch.edge_label
#     return pred, actual


def evaluate_metrics(probas_pred, ground_truth, threshold=0.5):
    """Compute binary-classification metrics from predicted probabilities.

    Args:
        probas_pred: predicted probabilities (the training loop passes
            ``torch.sigmoid`` outputs). A torch tensor on any device, or anything
            ``np.asarray`` accepts.
        ground_truth: 0/1 labels aligned with ``probas_pred``; same accepted types.
        threshold: probabilities strictly greater than this count as positive
            predictions for the thresholded metrics. Defaults to 0.5, the
            previously hard-coded cut-off.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, roc_auc, pr_auc,
        average_precision)``.
        - ``pr_auc`` is the trapezoidal area under the precision-recall curve.
        - ``average_precision`` is sklearn's step-wise average precision.
    """
    def _to_numpy(values):
        # Tensors may still carry grad or live on the GPU; normalise to a CPU ndarray.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    probas_pred = _to_numpy(probas_pred)
    ground_truth = _to_numpy(ground_truth)

    # Hard 0/1 predictions for the threshold-dependent metrics.
    binary_pred = (probas_pred > threshold).astype(int)

    accuracy = accuracy_score(ground_truth, binary_pred)
    precision = precision_score(ground_truth, binary_pred)
    recall = recall_score(ground_truth, binary_pred)
    f1 = f1_score(ground_truth, binary_pred)
    roc_auc = roc_auc_score(ground_truth, probas_pred)
    precision_curve, recall_curve, _ = precision_recall_curve(ground_truth, probas_pred)
    # Use the module-qualified name: an earlier cell assigns
    # `auc, aupr, outputs = [], [], []`. At notebook top level that shadows the
    # imported `auc` function and would make a bare `auc(...)` call fail.
    pr_auc = sklearn.metrics.auc(recall_curve, precision_curve)
    average_precision = average_precision_score(ground_truth, probas_pred)
    return accuracy, precision, recall, f1, roc_auc, pr_auc, average_precision

def _log_epoch_metrics(writer, prefix, metrics, epoch):
    """Write the seven ``evaluate_metrics`` values to TensorBoard as ``"<prefix> <Metric>"``."""
    names = ("Accuracy", "Precision", "Recall", "F1", "ROC AUC", "PR AUC", "Average Precision")
    for name, value in zip(names, metrics):
        writer.add_scalar(f"{prefix} {name}", value, epoch)


def _print_epoch_metrics(prefix, loss, metrics):
    """Print the epoch loss and the seven ``evaluate_metrics`` values, one per line."""
    print(f"{prefix} loss:", loss)
    names = ("accuracy", "precision", "recall", "f1", "roc_auc", "pr_auc", "average_precision")
    for name, value in zip(names, metrics):
        print(f"{prefix} {name}:", value)


def train_loop(model, model_name, writer, train_loader, val_loader, loss_fn, optimizer, n_epochs, device, scheduler=None, early_stopping_patience=3, early_stopping_counter=0):
    """Train ``model`` with validation-F1 early stopping and return the best checkpoint.

    Each epoch does the following:
    - Runs one optimisation pass over ``train_loader``, clipping the gradient
      norm at 1.0.
    - Evaluates on ``val_loader`` without gradients.
    - Writes per-minibatch losses, per-epoch losses and ``evaluate_metrics``
      results to ``writer``.

    Whenever validation F1 improves, the model's ``state_dict`` is saved to
    ``saved_models/<model_name>/best_model.pth``. Training stops after
    ``early_stopping_patience`` consecutive epochs without improvement, or
    after ``n_epochs`` epochs.

    Args:
        model: module run through ``do_train_compute``; its first output is
            treated as logits.
        model_name: sub-directory used for the checkpoint path.
        writer: TensorBoard writer; only ``add_scalar`` is used.
        train_loader: iterable of batches exposing ``.to(device)``.
        val_loader: iterable of batches exposing ``.to(device)``.
        loss_fn: called as ``loss_fn(logits, labels)`` on raw logits.
        optimizer: torch optimizer; the first param group's ``lr`` is displayed.
        n_epochs: maximum number of epochs.
        device: device each batch is moved to before the forward pass.
        scheduler: optional. ``scheduler.step(val_f1)`` is called after every
            epoch (fits ``ReduceLROnPlateau(mode="max")``).
        early_stopping_patience: epochs without val-F1 improvement before stopping.
        early_stopping_counter: initial value of the no-improvement counter.

    Returns:
        ``model`` with the best checkpoint from this call loaded. If no
        checkpoint was written during this call (e.g. ``n_epochs <= 0``),
        ``model`` is returned as-is instead of loading a stale or missing file.
    """
    early_stop = False
    checkpoint_saved = False
    best_val_metrics = -float("inf")
    best_model_path = f"saved_models/{model_name}/best_model.pth"
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

    print("Starting training loop at", datetime.today().strftime("%Y-%m-%d %H:%M:%S"))

    n_train_batches = len(train_loader)
    n_val_batches = len(val_loader)
    # The bar counts minibatches (train + val) over all epochs and is advanced
    # manually, so it is not also used as the loop iterator (that double-counted).
    progress_bar = tqdm.notebook.tqdm(total=(n_train_batches + n_val_batches) * n_epochs, desc="MiniBatches")
    # Global minibatch counters: per-batch TensorBoard curves must not restart
    # (and overwrite each other) at every epoch.
    train_step = 0
    val_step = 0

    for epoch in range(1, n_epochs + 1):
        start_time = time.time()
        print("Epoch", epoch)
        # Read once per epoch: the scheduler only steps at epoch end, and this keeps
        # `lr` defined for the validation postfix even with an empty train loader.
        lr = optimizer.param_groups[0]['lr']

        # ---- training pass ----
        model.train()
        train_loss = 0.0
        train_probas_pred, train_ground_truth = [], []
        for idx, batch in enumerate(train_loader):
            batch = batch.to(device)
            optimizer.zero_grad()
            out, actual = do_train_compute(batch, device, model)
            train_probas_pred.append(torch.sigmoid(out).detach().cpu())
            train_ground_truth.append(actual.detach().cpu())
            loss = loss_fn(out, actual)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            batch_loss = loss.item()
            train_loss += batch_loss
            progress_bar.set_postfix_str(f"Epoch {epoch} - LR {lr:.7f} - Train Batch {idx+1}/{n_train_batches} - Train loss: {train_loss/(idx+1):.4f}")
            progress_bar.update()
            writer.add_scalar("Training Loss MiniBatch", batch_loss, train_step)
            train_step += 1
            batch = batch.to("cpu")  # move the batch's tensors off the device

        train_loss /= n_train_batches
        writer.add_scalar("Training Loss Epoch", train_loss, epoch)

        model.eval()
        with torch.no_grad():
            train_metrics = evaluate_metrics(torch.cat(train_probas_pred, dim=0),
                                             torch.cat(train_ground_truth, dim=0))
            _log_epoch_metrics(writer, "Training", train_metrics, epoch)

            # ---- validation pass ----
            val_loss = 0.0
            val_probas_pred, val_ground_truth = [], []
            for idx_, batch in enumerate(val_loader):
                batch = batch.to(device)
                out, actual = do_train_compute(batch, device, model)
                val_probas_pred.append(torch.sigmoid(out).cpu())
                val_ground_truth.append(actual.cpu())
                batch_loss = loss_fn(out, actual).item()
                val_loss += batch_loss
                # Running mean over the validation batches seen so far (previously
                # this divided by the last *training* batch index).
                progress_bar.set_postfix_str(f"Epoch {epoch} - LR {lr:.7f} - Val Batch {idx_+1}/{n_val_batches} - Val loss: {val_loss/(idx_+1):.4f}")
                progress_bar.update()
                writer.add_scalar("Validation Loss MiniBatch", batch_loss, val_step)
                val_step += 1
                batch = batch.to("cpu")

            val_loss /= n_val_batches
            val_metrics = evaluate_metrics(torch.cat(val_probas_pred, dim=0),
                                           torch.cat(val_ground_truth, dim=0))
            writer.add_scalar("Validation Loss Epoch", val_loss, epoch)
            _log_epoch_metrics(writer, "Validation", val_metrics, epoch)

        # evaluate_metrics order: accuracy, precision, recall, f1, roc_auc, pr_auc, ap
        train_f1, train_roc_auc, train_pr_auc = train_metrics[3:6]
        val_f1, val_roc_auc, val_pr_auc = val_metrics[3:6]

        # ---- checkpointing / early stopping on validation F1 ----
        if val_f1 > best_val_metrics:
            best_val_metrics = val_f1
            early_stopping_counter = 0
            torch.save(model.state_dict(), best_model_path)
            checkpoint_saved = True
            print("New best model saved!")
        else:
            early_stopping_counter += 1
            print("Early stopping counter:", early_stopping_counter)
            if early_stopping_counter >= early_stopping_patience:
                print("Early stopping triggered!")
                early_stop = True

        if scheduler is not None:
            scheduler.step(val_f1)

        progress_bar.set_postfix_str(
            f"Train loss: {train_loss:.4f}, Train f1: {train_f1:.4f}, Train auc: {train_roc_auc:.4f}, "
            f"Train pr_auc: {train_pr_auc:.4f}, Val loss: {val_loss:.4f}, Val f1: {val_f1:.4f}, "
            f"Val auc: {val_roc_auc:.4f}, Val pr_auc: {val_pr_auc:.4f}, Best val f1: {best_val_metrics:.4f}"
        )
        print("Epoch Number:", epoch)
        print("Epoch time:", time.time() - start_time)
        _print_epoch_metrics("Train", train_loss, train_metrics)
        _print_epoch_metrics("Val", val_loss, val_metrics)
        print("Best val_f1:", best_val_metrics)
        print()

        if early_stop:
            break
        if epoch == n_epochs:
            print("Training completed!")

    progress_bar.close()

    # Restore the best-validation-F1 weights written during this call.
    if checkpoint_saved:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    return model

#### Test-set evaluation metrics

In [67]:
def mrank(y, y_pre):
    """Sum of reciprocal ranks of the positive items when ranked by score.

    Items are sorted by ``y_pre`` in descending order, so rank 1 is the highest
    score. For every ranked position whose label in ``y`` equals 1, ``1 / rank``
    is accumulated.

    Note: despite the name, this returns the *sum* of reciprocal ranks, not
    their mean. The value reported as "MR" by the evaluation code is this sum.

    Args:
        y: 1-D array of labels; positives are entries equal to 1.
        y_pre: 1-D array of scores with the same length as ``y``.

    Returns:
        The sum as a numpy float; 0.0 when ``y`` has no positives.
    """
    order = np.argsort(-y_pre)
    positive_ranks = np.flatnonzero(y[order] == 1) + 1  # 1-based ranks of positives
    return np.sum(1 / positive_ranks)

def evaluate_fold(loader, model, device, ret=False):
    """Score ``model`` on every batch of ``loader`` and print test-set metrics.

    Each batch is moved to ``device`` and run through ``model``, whose first of
    three outputs is treated as logits. The sigmoid of those logits is compared
    with the ``edge_label`` of the ``("drug", "known", "side_effect")`` edges.
    Thresholded metrics use a cut-off of 0.5.

    Args:
        loader: iterable of batches exposing ``.to(device)``.
        model: module returning a 3-tuple whose first element is the logits.
        device: device to run inference on.
        ret: when True, also return the metrics.

    Returns:
        ``(auc, ap, f1, acc, precision, recall, mr, mcc)`` if ``ret`` is True,
        otherwise ``None``. ``mr`` is the value computed by ``mrank``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    prob_chunks = []
    label_chunks = []
    model.eval()
    with torch.no_grad():
        for sampled_data in tqdm.tqdm(loader):
            sampled_data = sampled_data.to(device)
            logits, _, _ = model(sampled_data)
            # Move each batch's results to the CPU immediately so that predictions
            # for the whole loader do not accumulate in GPU memory.
            prob_chunks.append(torch.sigmoid(logits).cpu())
            label_chunks.append(sampled_data["drug", "known", "side_effect"].edge_label.cpu())

    if not prob_chunks:
        raise ValueError("evaluate_fold: loader yielded no batches")

    pred = torch.cat(prob_chunks, dim=0).numpy()
    ground_truth = torch.cat(label_chunks, dim=0).numpy()
    pred_int = (pred > 0.5).astype(int)

    # `roc_auc` (not `auc`) so the local does not shadow sklearn's `auc` function.
    roc_auc = roc_auc_score(ground_truth, pred)
    ap = average_precision_score(ground_truth, pred)
    mr = mrank(ground_truth, pred)
    f1 = f1_score(ground_truth, pred_int)
    mcc = matthews_corrcoef(ground_truth, pred_int)
    acc = (pred_int == ground_truth).mean()
    precision = precision_score(ground_truth, pred_int)
    recall = recall_score(ground_truth, pred_int)

    print()
    print(f"Test AUC: {roc_auc:.4f}")
    print(f"Test AP: {ap:.4f}")
    print(f"Test F1: {f1:.4f}")
    print(f"Test Accuracy: {acc:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"Test MCC: {mcc:.4f}")
    print(f"Test MR: {mr:.4f}")
    if ret:
        return roc_auc, ap, f1, acc, precision, recall, mr, mcc


#### Cross-validation training wrapper

In [68]:
def train_wrap_cv(data, model_name_, cv_fold=10, shuffle=True, num_neighbors=None, batch_size=64, n_epochs=10, early_stopping_patience=5,
                  lr=0.00025062971034390006, weight_decay=0.001):
    """Run k-fold cross-validation: build, train and test a fresh model on every fold.

    For each ``(train, val, test)`` loader triple produced by ``get_kfold_data``:
    - A new ``DVModel`` (``HeteroMHGNN`` encoder + ``VanillaClassifier``) is
      built.
    - It is trained with ``train_loop`` using Adam, ``ReduceLROnPlateau`` on
      validation F1, and BCE-with-logits.
    - The returned model is scored on the fold's test loader with
      ``evaluate_fold``.

    TensorBoard logs go to ``logs/<model_name_>/fold<i>/<timestamp>``.

    Args:
        data: heterogeneous graph passed to ``get_kfold_data``; its
            ``metadata()`` configures the GNN.
        model_name_: prefix for the log and checkpoint directories.
        cv_fold: number of folds (``k`` for ``get_kfold_data``).
        shuffle: forwarded to ``get_kfold_data``.
        num_neighbors: neighbour-sampling fan-out forwarded to
            ``get_kfold_data``; ``None`` means ``[10, 4]``.
        batch_size: loader batch size.
        n_epochs: forwarded to ``train_loop``.
        early_stopping_patience: forwarded to ``train_loop``.
        lr: Adam learning rate; the default is the previously hard-coded value.
        weight_decay: Adam weight decay; the default is the previously
            hard-coded value.

    Returns:
        A list with one ``[auc, ap, f1, acc, precision, recall, mr, mcc]``
        entry per fold.
    """
    if num_neighbors is None:
        num_neighbors = [10, 4]  # fresh list per call instead of a shared mutable default
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    folds = get_kfold_data(data, k=cv_fold, shuffle=shuffle,
                           num_neighbors=num_neighbors, batch_size=batch_size)

    eval_metrics = []
    for fold, (train_loader_cv, val_loader_cv, test_loader_cv) in enumerate(folds, start=1):
        print(f"Fold {fold}")
        model_name = f"{model_name_}/fold{fold}"
        # TensorBoard log directory for this fold.
        log_dir = f"logs/{model_name}/" + datetime.now().strftime("%Y%m%d-%H%M%S")
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir)
        print(f"Device: '{device}'")

        try:
            # `use_edge_attr` is read from the notebook's global scope (defined in another cell).
            gnn_model = HeteroMHGNN(data.metadata(), in_channels=384, hidden_dims=[64, 64, 64], heads=[2, 2, 2], use_edge_attr=use_edge_attr)
            model = DVModel(hidden_channels=384, gnn_model=gnn_model,
                            classifier_model=VanillaClassifier(), use_node_features=False)
            model = model.to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="max", factor=0.5, patience=2, min_lr=1e-6
            )

            print(f"Total Number of Parameters: {sum(p.numel() for p in model.parameters())}")
            print(f"Total Number of Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

            # train_loop loads the best-validation-F1 checkpoint before returning,
            # so no separate reload is needed here.
            model = train_loop(model, model_name, writer, train_loader_cv, val_loader_cv,
                               F.binary_cross_entropy_with_logits, optimizer, n_epochs=n_epochs,
                               device=device, scheduler=scheduler, early_stopping_patience=early_stopping_patience)

            auc, ap, f1, acc, precision, recall, mr, mcc = evaluate_fold(test_loader_cv, model, device, ret=True)
            eval_metrics.append([auc, ap, f1, acc, precision, recall, mr, mcc])
        finally:
            writer.close()  # flush pending TensorBoard events and release the file handle

        # Drop this fold's references so the cached CUDA memory can actually be released.
        del model, optimizer, scheduler, gnn_model
        gc.collect()
        torch.cuda.empty_cache()

    return eval_metrics

### Run cross-validated training

In [69]:
# Reclaim memory left over from earlier cells before launching cross-validation:
# Python-level garbage first, then PyTorch's cached (unused) CUDA allocations.
for release_memory in (gc.collect, torch.cuda.empty_cache):
    release_memory()

In [70]:
# ! rm -rf logs/* saved_models/*

In [71]:
# 10-fold CV of the HGNN model; one metric row per fold is collected in `eval_metrics`.
eval_metrics = train_wrap_cv(
    data,
    "sdv-hgnn-cv",
    cv_fold=10,
    shuffle=True,
    num_neighbors=[10, 4],
    batch_size=64,
)

Fold 1
Device: 'cuda'
Total Number of Parameters: 3901040
Total Number of Trainable Parameters: 3901040
Starting training loop at 2024-08-05 12:59:33


MiniBatches:   0%|          | 0/9090 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 116.96567869186401
Train loss: 0.7262760884440099
Train accuracy: 0.793792246808019
Train precision: 0.7917276195393397
Train recall: 0.7973308683343925
Train f1: 0.7945193650062607
Train roc_auc: 0.8487906594326882
Train pr_auc: 0.818858475743854
Train average_precision: 0.818269690230577
Val loss: 0.8867053204904431
Val accuracy: 0.8213648338573978
Val precision: 0.8207107115596641
Val recall: 0.8223846349961758
Val f1: 0.8215468206129554
Val roc_auc: 0.8834400426530578
Val pr_auc: 0.8829409663322565
Val average_precision: 0.8828596370664608
Best val_f1: 0.8215468206129554

Epoch 2
New best model saved!
Epoch Number: 2
Epoch time: 116.47913646697998
Train loss: 0.40411356455076647
Train accuracy: 0.836905656017101
Train precision: 0.8331143607905861
Train recall: 0.8425963371656364
Train f1: 0.8378285221887117
Train roc_auc: 0.9058975049759517
Train pr_auc: 0.8911589471402481
Train average_precision: 0.891165363237246
Val loss

100%|██████████| 409/409 [00:35<00:00, 11.59it/s]



Test AUC: 0.8901
Test AP: 0.8994
Test F1: 0.8206
Test Accuracy: 0.8263
Test Precission: 0.8483
Test Recall: 0.7947
Test MCC: 0.6539
Test MR: 9.7896
Fold 2
Device: 'cuda'
Total Number of Parameters: 3901040
Total Number of Trainable Parameters: 3901040
Starting training loop at 2024-08-05 13:14:08


MiniBatches:   0%|          | 0/9090 [00:00<?, ?it/s]

Epoch 1


KeyboardInterrupt: 

In [23]:
# Earlier 10-fold CV run saved under the "sdv-hgnn-outer-skip" experiment name.
eval_metrics = train_wrap_cv(
    data,
    "sdv-hgnn-outer-skip",
    cv_fold=10,
    shuffle=True,
    num_neighbors=[10, 4],
    batch_size=64,
)

Fold 1


2024-08-02 14:59:33.935932: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-08-02 14:59:34.237713: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-08-02 14:59:34.956283: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-08-02 14:59:34.956438: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not l

Device: 'cuda'
Total Number of Parameters: 3906032
Total Number of Trainable Parameters: 3906032
Starting training loop at 2024-08-02 14:59:38


MiniBatches:   0%|          | 0/9090 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 137.41405987739563
Train loss: 1.0651963151536896
Train accuracy: 0.7722138771737247
Train precision: 0.7707824487802075
Train recall: 0.7748570108036281
Train f1: 0.7728143591812039
Train roc_auc: 0.8263554953855519
Train pr_auc: 0.7902650274956355
Train average_precision: 0.7866666453584267
Val loss: 0.8991208196334217
Val accuracy: 0.8123565904648593
Val precision: 0.8335602141755151
Val recall: 0.7805727883062803
Val f1: 0.8061967875010971
Val roc_auc: 0.8730340215594732
Val pr_auc: 0.87425137760223
Val average_precision: 0.8741948464542115
Best val_f1: 0.8061967875010971

Epoch 2
New best model saved!
Epoch Number: 2
Epoch time: 141.76551032066345
Train loss: 0.46953649885002213
Train accuracy: 0.8159916806285747
Train precision: 0.8121504394475516
Train recall: 0.8221445490785141
Train f1: 0.8171169360627027
Train roc_auc: 0.8863735832133619
Train pr_auc: 0.8688620853270397
Train average_precision: 0.8688634629641935
Val l

100%|██████████| 409/409 [00:43<00:00,  9.40it/s]



Test AUC: 0.8941
Test AP: 0.8972
Test F1: 0.8248
Test Accuracy: 0.8194
Test Precission: 0.8007
Test Recall: 0.8504
Test MCC: 0.6400
Test MR: 9.7989
Fold 2
Device: 'cuda'
Total Number of Parameters: 3906032
Total Number of Trainable Parameters: 3906032
Starting training loop at 2024-08-02 15:24:25


MiniBatches:   0%|          | 0/9090 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 141.46417593955994
Train loss: 1.0586109393846086
Train accuracy: 0.7807210121901901
Train precision: 0.7805103336797137
Train recall: 0.7810965393725807
Train f1: 0.7808033264993792
Train roc_auc: 0.8339024820274503
Train pr_auc: 0.8010405524148596
Train average_precision: 0.7970219712639836
Val loss: 1.2563508706238202
Val accuracy: 0.7961672473867596
Val precision: 0.8936074090806415
Val recall: 0.6723888841675874
Val f1: 0.7673730662916445
Val roc_auc: 0.8763110835181989
Val pr_auc: 0.8763331632097453
Val average_precision: 0.8763155077174187
Best val_f1: 0.7673730662916445

Epoch 2
Early stopping counter: 1
Epoch Number: 2
Epoch time: 144.17532324790955
Train loss: 0.4674192983209536
Train accuracy: 0.8170171587035646
Train precision: 0.8152414327980927
Train recall: 0.8198336125714946
Train f1: 0.8175310740159295
Train roc_auc: 0.887222183824919
Train pr_auc: 0.8701217321749768
Train average_precision: 0.8701089882763089
V

100%|██████████| 409/409 [00:44<00:00,  9.16it/s]



Test AUC: 0.8876
Test AP: 0.8763
Test F1: 0.8231
Test Accuracy: 0.8294
Test Precission: 0.8548
Test Recall: 0.7937
Test MCC: 0.6606
Test MR: 9.4590
Fold 3
Device: 'cuda'
Total Number of Parameters: 3906032
Total Number of Trainable Parameters: 3906032
Starting training loop at 2024-08-02 15:49:04


MiniBatches:   0%|          | 0/9090 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 142.46999406814575
Train loss: 0.9848194807051731
Train accuracy: 0.7755069616962273
Train precision: 0.77339830872868
Train recall: 0.7793633369923162
Train f1: 0.7763693653511361
Train roc_auc: 0.8316357609226905
Train pr_auc: 0.7996743001059252
Train average_precision: 0.7971672788591346
Val loss: 0.9135174065342416
Val accuracy: 0.8073000764850854
Val precision: 0.8648103309120259
Val recall: 0.7284779468003739
Val f1: 0.7908113842889432
Val roc_auc: 0.8758788141946721
Val pr_auc: 0.8766902261070508
Val average_precision: 0.8766600455541824
Best val_f1: 0.7908113842889432

Epoch 2
Early stopping counter: 1
Epoch Number: 2
Epoch time: 143.46434998512268
Train loss: 0.4738194406197825
Train accuracy: 0.8127130394592409
Train precision: 0.8113370337350091
Train recall: 0.8149228724940782
Train f1: 0.813125999798239
Train roc_auc: 0.8852959554462283
Train pr_auc: 0.8684345854192999
Train average_precision: 0.868441076151338
Val 

100%|██████████| 409/409 [00:38<00:00, 10.71it/s]



Test AUC: 0.9000
Test AP: 0.8953
Test F1: 0.8326
Test Accuracy: 0.8385
Test Precission: 0.8640
Test Recall: 0.8034
Test MCC: 0.6787
Test MR: 9.6666
Fold 4
Device: 'cuda'
Total Number of Parameters: 3906032
Total Number of Trainable Parameters: 3906032
Starting training loop at 2024-08-02 16:13:50


MiniBatches:   0%|          | 0/9090 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 127.60084629058838
Train loss: 1.0915770986824953
Train accuracy: 0.7743081633832111
Train precision: 0.7719358533791524
Train recall: 0.7786700560402103
Train f1: 0.7752883315597229
Train roc_auc: 0.8283704353584287
Train pr_auc: 0.7968576666090832
Train average_precision: 0.7933336311874116
Val loss: 1.048677263143675
Val accuracy: 0.7801053794510071
Val precision: 0.8772030212863355
Val recall: 0.6513979773944081
Val f1: 0.7476225310899781
Val roc_auc: 0.8649389040506184
Val pr_auc: 0.8618156860909141
Val average_precision: 0.8618357521924584
Best val_f1: 0.7476225310899781

Epoch 2
New best model saved!
Epoch Number: 2
Epoch time: 129.11410212516785
Train loss: 0.48642106081553615
Train accuracy: 0.8103732162458837
Train precision: 0.8071173360011433
Train recall: 0.8156739268588595
Train f1: 0.8113730729997269
Train roc_auc: 0.8823512983568996
Train pr_auc: 0.8644113992152477
Train average_precision: 0.8644359874237489
Val 

100%|██████████| 409/409 [00:46<00:00,  8.71it/s]



Test AUC: 0.8999
Test AP: 0.9039
Test F1: 0.8291
Test Accuracy: 0.8297
Test Precission: 0.8321
Test Recall: 0.8261
Test MCC: 0.6595
Test MR: 9.7103
Fold 5
Device: 'cuda'
Total Number of Parameters: 3906032
Total Number of Trainable Parameters: 3906032
Starting training loop at 2024-08-02 16:37:08


MiniBatches:   0%|          | 0/9090 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 138.59833693504333
Train loss: 1.0045815542092384
Train accuracy: 0.7754058582240453
Train precision: 0.7736509758897818
Train recall: 0.7786122826275348
Train f1: 0.7761237006536325
Train roc_auc: 0.8291872741944155
Train pr_auc: 0.7983468050936354
Train average_precision: 0.7955279473074736
Val loss: 1.9180874624911193
Val accuracy: 0.6912552052349792
Val precision: 0.9582569741396865
Val recall: 0.3999320132574148
Val f1: 0.5643362513490827
Val roc_auc: 0.8774120042172638
Val pr_auc: 0.8888229236409324
Val average_precision: 0.8888318738882108
Best val_f1: 0.5643362513490827

Epoch 2
New best model saved!
Epoch Number: 2
Epoch time: 129.31488752365112
Train loss: 0.4802443223167039
Train accuracy: 0.8116586746779132
Train precision: 0.8062620642670603
Train recall: 0.820469120110925
Train f1: 0.8133035535320562
Train roc_auc: 0.8838999706079558
Train pr_auc: 0.8664558726720324
Train average_precision: 0.8664581295308955
Val l

100%|██████████| 409/409 [00:39<00:00, 10.45it/s]



Test AUC: 0.8899
Test AP: 0.8763
Test F1: 0.8268
Test Accuracy: 0.8293
Test Precission: 0.8392
Test Recall: 0.8148
Test MCC: 0.6590
Test MR: 9.4312
Fold 6
Device: 'cuda'
Total Number of Parameters: 3906032
Total Number of Trainable Parameters: 3906032
Starting training loop at 2024-08-02 17:00:02


MiniBatches:   0%|          | 0/9090 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 129.186297416687
Train loss: 0.9672775781419053
Train accuracy: 0.7784678490958461
Train precision: 0.7781625115420129
Train recall: 0.7790166965162633
Train f1: 0.7785893697491123
Train roc_auc: 0.8335588714095843
Train pr_auc: 0.801045958976088
Train average_precision: 0.7987745319225844
Val loss: 0.9605957349843305
Val accuracy: 0.8116767230390074
Val precision: 0.8528817473299336
Val recall: 0.7532931078439704
Val f1: 0.8000000000000002
Val roc_auc: 0.8764907930996287
Val pr_auc: 0.882917856279462
Val average_precision: 0.8828638740285847
Best val_f1: 0.8000000000000002

Epoch 2
New best model saved!
Epoch Number: 2
Epoch time: 130.9055507183075
Train loss: 0.4686463186018128
Train accuracy: 0.8157172569183662
Train precision: 0.8098721329137252
Train recall: 0.8251487665376394
Train f1: 0.8174390819728427
Train roc_auc: 0.8874339705784031
Train pr_auc: 0.8708907905460063
Train average_precision: 0.870899159899064
Val loss: 

100%|██████████| 409/409 [00:44<00:00,  9.14it/s]



Test AUC: 0.8798
Test AP: 0.8834
Test F1: 0.8148
Test Accuracy: 0.8191
Test Precission: 0.8343
Test Recall: 0.7963
Test MCC: 0.6388
Test MR: 9.6486
Fold 7
Device: 'cuda'
Total Number of Parameters: 3906032
Total Number of Trainable Parameters: 3906032
Starting training loop at 2024-08-02 17:16:07


MiniBatches:   0%|          | 0/9090 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 141.2035267353058
Train loss: 1.042460867690952
Train accuracy: 0.7757669420532671
Train precision: 0.7746008917014239
Train recall: 0.7778901149690912
Train f1: 0.7762420189383566
Train roc_auc: 0.8304113848823064
Train pr_auc: 0.7990328297498147
Train average_precision: 0.7961713363838474
Val loss: 0.9019477529977651
Val accuracy: 0.7973145236678848
Val precision: 0.8779302149724533
Val recall: 0.6906603212373588
Val f1: 0.7731164383561644
Val roc_auc: 0.8708430892688767
Val pr_auc: 0.8689556955057286
Val average_precision: 0.8689381700158675
Best val_f1: 0.7731164383561644

Epoch 2
New best model saved!
Epoch Number: 2
Epoch time: 145.39514803886414
Train loss: 0.47244576165063545
Train accuracy: 0.8157172569183662
Train precision: 0.8160185051322828
Train recall: 0.8152406262637935
Train f1: 0.815629380229181
Train roc_auc: 0.886239250599699
Train pr_auc: 0.8697265592560655
Train average_precision: 0.8697171300573272
Val los

100%|██████████| 409/409 [00:45<00:00,  8.96it/s]



Test AUC: 0.9036
Test AP: 0.9147
Test F1: 0.8170
Test Accuracy: 0.8336
Test Precission: 0.9073
Test Recall: 0.7431
Test MCC: 0.6784
Test MR: 9.8312
Fold 8
Device: 'cuda'
Total Number of Parameters: 3906032
Total Number of Trainable Parameters: 3906032
Starting training loop at 2024-08-02 17:41:32


MiniBatches:   0%|          | 0/9090 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 147.17434215545654
Train loss: 1.0977108314385475
Train accuracy: 0.7734560055462476
Train precision: 0.7714408602150538
Train recall: 0.7771679473106476
Train f1: 0.7742938137654152
Train roc_auc: 0.827495100065944
Train pr_auc: 0.794893816910331
Train average_precision: 0.7909096582024816
Val loss: 0.9556820959095722
Val accuracy: 0.8031358885017421
Val precision: 0.8687965260545906
Val recall: 0.7141157474292513
Val f1: 0.7838985027286721
Val roc_auc: 0.8711614430987776
Val pr_auc: 0.8671392423292468
Val average_precision: 0.8671066182298512
Best val_f1: 0.7838985027286721

Epoch 2
New best model saved!
Epoch Number: 2
Epoch time: 144.50898361206055
Train loss: 0.4634229708290805
Train accuracy: 0.8160494540412502
Train precision: 0.813855421686747
Train recall: 0.8195447455081172
Train f1: 0.8166901753072916
Train roc_auc: 0.8879848419374137
Train pr_auc: 0.8705860205989439
Train average_precision: 0.870595173710984
Val loss

100%|██████████| 409/409 [00:44<00:00,  9.23it/s]



Test AUC: 0.8897
Test AP: 0.8934
Test F1: 0.8314
Test Accuracy: 0.8290
Test Precission: 0.8200
Test Recall: 0.8431
Test MCC: 0.6583
Test MR: 8.6391
Fold 9
Device: 'cuda'
Total Number of Parameters: 3906032
Total Number of Trainable Parameters: 3906032
Starting training loop at 2024-08-02 18:06:13


MiniBatches:   0%|          | 0/9090 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 141.65959334373474
Train loss: 1.1273305914097924
Train accuracy: 0.7694118666589636
Train precision: 0.7695130761450657
Train recall: 0.7692241030677682
Train f1: 0.7693685624720108
Train roc_auc: 0.8231378131104766
Train pr_auc: 0.7910804676557752
Train average_precision: 0.7872422069983213
Val loss: 1.3699439704651013
Val accuracy: 0.7512110138522988
Val precision: 0.8899736147757256
Val recall: 0.5732982068496643
Val f1: 0.6973691011526335
Val roc_auc: 0.8690431695809967
Val pr_auc: 0.8627509755335933
Val average_precision: 0.8627773104786508
Best val_f1: 0.6973691011526335

Epoch 2
New best model saved!
Epoch Number: 2
Epoch time: 141.19264817237854
Train loss: 0.4953145984140209
Train accuracy: 0.8107198567219366
Train precision: 0.8115568428674873
Train recall: 0.8093766248772315
Train f1: 0.8104652676337445
Train roc_auc: 0.881034039347756
Train pr_auc: 0.8628854477807723
Train average_precision: 0.862890889212139
Val lo

100%|██████████| 409/409 [00:45<00:00,  9.00it/s]



Test AUC: 0.8829
Test AP: 0.8756
Test F1: 0.8233
Test Accuracy: 0.8292
Test Precission: 0.8527
Test Recall: 0.7958
Test MCC: 0.6598
Test MR: 9.4833
Fold 10
Device: 'cuda'
Total Number of Parameters: 3906032
Total Number of Trainable Parameters: 3906032
Starting training loop at 2024-08-02 18:30:45


MiniBatches:   0%|          | 0/9090 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 143.96696186065674
Train loss: 1.0600285984011102
Train accuracy: 0.7670576000924375
Train precision: 0.7668109668109668
Train recall: 0.7675197873938413
Train f1: 0.767165213374141
Train roc_auc: 0.8249194231286193
Train pr_auc: 0.7959472690379884
Train average_precision: 0.7932254281836104
Val loss: 1.1007291780096358
Val accuracy: 0.8141412424577208
Val precision: 0.8637213421233888
Val recall: 0.7459845330160618
Val f1: 0.8005471956224349
Val roc_auc: 0.8804737259543642
Val pr_auc: 0.8780514697548554
Val average_precision: 0.8780134456894039
Best val_f1: 0.8005471956224349

Epoch 2
Early stopping counter: 1
Epoch Number: 2
Epoch time: 143.54961252212524
Train loss: 0.49623278213076144
Train accuracy: 0.8079467329135132
Train precision: 0.8060284196928377
Train recall: 0.8110809405511583
Train f1: 0.808546787035837
Train roc_auc: 0.8786660266041876
Train pr_auc: 0.8600230548203445
Train average_precision: 0.8600272415418009
V

100%|██████████| 409/409 [00:46<00:00,  8.74it/s]



Test AUC: 0.8989
Test AP: 0.8903
Test F1: 0.8331
Test Accuracy: 0.8370
Test Precission: 0.8538
Test Recall: 0.8133
Test MCC: 0.6748
Test MR: 9.4780


### CV Performance

In [68]:
# Summarise cross-validation: collapse the per-fold test metrics in `eval_metrics`
# into one row per metric holding its mean and spread across folds.
# NOTE(review): assumes `eval_metrics` holds one entry per fold whose values are
# ordered exactly like `metrics` below (the table output confirms 'mr' ~9.7 and
# 'mcc' ~0.65 line up) — confirm against the loop that appends to it.
metrics = [
    'auc', 'ap', 'f1', 'acc',
    'precision', 'recall', 'mr', 'mcc',
]

# Reduce over the fold axis (axis=0) to get one value per metric.
metrics_mean_value = np.mean(eval_metrics, axis=0)
# np.std defaults to ddof=0, i.e. the population standard deviation over folds.
metrics_std = np.std(eval_metrics, axis=0)

summary_columns = {
    'Metric': metrics,
    'Mean': metrics_mean_value,
    'Standard Deviation': metrics_std,
}
df = pd.DataFrame(summary_columns)
df

Unnamed: 0,Metric,Mean,Standard Deviation
0,auc,0.89888,0.006975
1,ap,0.909314,0.006861
2,f1,0.826126,0.005576
3,acc,0.826956,0.009283
4,precision,0.832085,0.028054
5,recall,0.821767,0.023088
6,mr,9.752247,0.211021
7,mcc,0.655145,0.018332
