### Imports

In [1]:
import gc 
import os
import threading
import tqdm
import time
import pickle
import copy
import random
from datetime import datetime

import numpy as np
import pandas as pd


from rdkit import Chem

import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter


from torch_geometric.loader.link_neighbor_loader import LinkNeighborLoader
import torch_geometric.transforms as T
from torch_geometric.data import (
                                    HeteroData,
                                    Data, 
                                    Batch
                                 )   
from torch_geometric.nn import (
                                GATv2Conv,
                                SAGPooling,
                                global_add_pool,
                                HeteroConv,
                                Linear,
                                to_hetero
                                )

from sklearn.model_selection import StratifiedShuffleSplit, KFold, train_test_split, StratifiedKFold
from sklearn.metrics import (
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score,
    roc_auc_score, 
    precision_recall_curve, 
    auc, 
    average_precision_score, 
    matthews_corrcoef
    )

### Seed all randomness

In [2]:
def seed_everything(seed: int) -> None:
    """Seed every RNG the notebook relies on, for reproducible runs.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA
    devices). Also forces cuDNN into deterministic mode and disables its
    autotuner, trading some speed for reproducible convolution results.

    Args:
        seed: the seed value applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Global seed for this notebook (KFold / train_test_split also draw from NumPy's global RNG).
SEED = 29
seed_everything(SEED)

### Load HeteroData

In [3]:
# Load the heterogeneous drug / side-effect graph written by the preprocessing step
# (its structure is displayed in the next cell).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files. On
# PyTorch >= 2.6 it defaults to weights_only=True, which may refuse a pickled HeteroData;
# pass weights_only=False (trusted file only) if loading fails.
fnm = '../prep_data/hetero_graph/hetero_data_dict.pt'
data = torch.load(fnm)

In [4]:
# Display the loaded graph: node counts per type and edge counts per relation.
data

HeteroData(
  [1mdrug[0m={ node_id=[1007] },
  [1mside_effect[0m={ node_id=[5587] },
  [1m(drug, known, side_effect)[0m={ edge_index=[2, 132063] },
  [1m(drug, struct, drug)[0m={
    edge_index=[2, 15844],
    edge_attr=[15844]
  },
  [1m(drug, word, drug)[0m={
    edge_index=[2, 83865],
    edge_attr=[83865]
  },
  [1m(drug, target, drug)[0m={
    edge_index=[2, 3363],
    edge_attr=[3363]
  },
  [1m(drug, se_encoded, drug)[0m={
    edge_index=[2, 65854],
    edge_attr=[65854]
  },
  [1m(side_effect, name, side_effect)[0m={
    edge_index=[2, 299170],
    edge_attr=[299170]
  },
  [1m(side_effect, dg_encoded, side_effect)[0m={
    edge_index=[2, 101114],
    edge_attr=[101114]
  },
  [1m(side_effect, atc, side_effect)[0m={
    edge_index=[2, 26140],
    edge_attr=[26140]
  }
)

### Load Transformation Maps

In [5]:
# Empty containers filled in place from disk by the next cell:
#   DB_TO_ID_DICT / ID_TO_DB_DICT             drug id (presumably DrugBank) <-> drug node id
#   MEDRAID_TO_ID_DICT / ID_TO_MEDRAID_DICT   side-effect id <-> side-effect node id
#   drug_id_mol_graph_tup                     list of (drug id, RDKit molecule) pairs
DB_TO_ID_DICT, ID_TO_DB_DICT = {}, {}
MEDRAID_TO_ID_DICT, ID_TO_MEDRAID_DICT = {}, {}
drug_id_mol_graph_tup = []

In [6]:
# Fill the lookup tables from their serialized files; `dict_list[i]` is loaded from `file_names[i]`.
dict_list = [DB_TO_ID_DICT, ID_TO_DB_DICT, MEDRAID_TO_ID_DICT, ID_TO_MEDRAID_DICT, drug_id_mol_graph_tup]
file_names = ['db_to_id.pt', 'id_to_db.pt', 'uml_to_id.pt', 'id_to_uml.pt', 'drug_to_mol.pt']

for data_dict, fnm in zip(dict_list, file_names):
    full_path = f"../prep_data/hetero_graph/{fnm}"
    # NOTE(review): torch.load unpickles arbitrary objects (e.g. RDKit molecules) — trusted
    # files only. On PyTorch >= 2.6 the weights_only=True default may refuse them; pass
    # weights_only=False if loading fails.
    loaded_data = torch.load(full_path)

    # Fill the existing container in place so the module-level names see the loaded data.
    if isinstance(data_dict, dict):
        data_dict.update(loaded_data)
    elif isinstance(data_dict, list):
        data_dict.extend(loaded_data)
    else:
        # Replacing an element of `dict_list` would not rebind the named global, so the
        # loaded data would be silently lost — fail loudly instead.
        raise TypeError(
            f"Cannot load {fnm!r} into a container of type {type(data_dict).__name__}"
        )

### Variant DV: drop all similarity edges (keep only drug–side-effect associations)

In [7]:
# Similarity edge types removed for this variant; only the drug–side-effect
# association edges remain in the graph afterwards.
remove_similarity_edges = [('drug', 'struct', 'drug'),
                           ('drug', 'word', 'drug'),
                           ('drug', 'target', 'drug'),
                           ('drug', 'se_encoded', 'drug'),
                           ('side_effect', 'name', 'side_effect'),
                           ('side_effect', 'dg_encoded', 'side_effect'),
                           ('side_effect', 'atc', 'side_effect')
                            ]
for edge in remove_similarity_edges:
    # Guarded so that re-running this cell does not raise KeyError on already-removed types.
    if edge in data.edge_types:
        del data[edge]

In [8]:
# After the removal, only the drug -> side_effect association edges remain.
data

HeteroData(
  [1mdrug[0m={ node_id=[1007] },
  [1mside_effect[0m={ node_id=[5587] },
  [1m(drug, known, side_effect)[0m={ edge_index=[2, 132063] }
)

### Make the HeteroData undirected (adds the reverse `rev_known` edges)

In [9]:
# Add reverse edges so messages also flow side_effect -> drug
# (this creates the ('side_effect', 'rev_known', 'drug') relation used below).
undirected_transform = T.ToUndirected()
data = undirected_transform(data)

### Molecule Featurization Utils

In [10]:
# Bond featurization
def get_bond_features(bond):
    """Encode an RDKit bond as a float32 vector of length 6.

    Layout: one-hot bond type over (SINGLE, DOUBLE, TRIPLE, AROMATIC, 'Unknown'),
    where any other bond type lands in the 'Unknown' slot, followed by an
    is-in-ring flag.
    """
    bond_type_choices = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                         Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC, 'Unknown']
    # one_of_k_encoding_unk maps types outside the list onto its last entry ('Unknown').
    type_one_hot = one_of_k_encoding_unk(bond.GetBondType(), bond_type_choices)
    ring_flag = [bond.IsInRing()]
    return np.array(type_one_hot + ring_flag, dtype=np.float32)

def get_mol_edge_list_and_feat_mtx(mol_graph, bond_feat_dim=6):
    """Convert an RDKit molecule into PyG-style graph tensors.

    Args:
        mol_graph: an RDKit molecule.
        bond_feat_dim: width of the bond-feature rows; only used to shape the
            empty ``edge_attr`` of a molecule without bonds (``get_bond_features``
            produces 6 values).

    Returns:
        ``(edge_index, x, edge_attr)`` where
        * ``edge_index`` is a LongTensor of shape ``(2, 2 * num_bonds)`` holding
          every bond in both directions (forward copies first, then reversed);
        * ``x`` is a float32 tensor with one ``atom_features`` row per atom,
          ordered by atom index;
        * ``edge_attr`` is a float32 tensor of shape ``(2 * num_bonds, bond_feat_dim)``
          aligned with ``edge_index`` columns.
    """
    # Order rows by atom index so row i of the feature matrix describes atom i.
    # Sort on the index only (comparing the feature arrays on ties would be ambiguous).
    indexed_rows = sorted(((atom.GetIdx(), atom_features(atom)) for atom in mol_graph.GetAtoms()),
                          key=lambda item: item[0])
    # np.stack first: torch.tensor() on a list of ndarrays is very slow.
    n_features = torch.from_numpy(np.stack([row for _, row in indexed_rows]).astype(np.float32))

    bonds = list(mol_graph.GetBonds())
    # reshape keeps a (0, 2) shape for bond-less molecules, so edge_index is (2, 0) below.
    edge_list = torch.tensor([(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in bonds],
                             dtype=torch.long).reshape(-1, 2)
    undirected_edge_list = torch.cat([edge_list, edge_list[:, [1, 0]]], dim=0)

    bond_features = [get_bond_features(bond) for bond in bonds]
    if bond_features:
        # Duplicate to match the forward + reversed edge order above.
        edge_attr = torch.from_numpy(np.stack(bond_features + bond_features).astype(np.float32))
    else:
        # Keep a 2-D shape so Batch.from_data_list can concatenate with other molecules.
        edge_attr = torch.zeros((0, bond_feat_dim), dtype=torch.float32)

    return undirected_edge_list.T, n_features, edge_attr


def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` over ``allowable_set`` as a list of booleans.

    Values not present in ``allowable_set`` are encoded as its last element
    (conventionally an 'Unknown' catch-all).
    """
    target = x if x in allowable_set else allowable_set[-1]
    return [target == option for option in allowable_set]

def all_of_k_encoding_unk(x, allowable_set):
    """Multi-hot encode the collection ``x`` over ``allowable_set``.

    Returns a float64 NumPy array with 1.0 at every position whose element of
    ``allowable_set`` is contained in ``x`` and 0.0 elsewhere.
    """
    return np.array([1.0 if candidate in x else 0.0 for candidate in allowable_set],
                    dtype=np.float64)
    
def atom_features(atom,
                explicit_H=True,
                use_chirality=False):
    """Build the float32 feature vector of an RDKit atom.

    Layout (55 values with the defaults, matching the atom-feature width used
    by the drug encoder):
      * one-hot element symbol over 44 entries (last slot = 'Unknown'),
      * degree / 10, implicit valence, formal charge, number of radical electrons,
      * one-hot hybridization over SP, SP2, SP3, SP3D, SP3D2 (any other
        hybridization falls into the last, SP3D2, slot),
      * is-aromatic flag,
      * if ``explicit_H``: the total number of hydrogens as one raw count,
      * if ``use_chirality``: one-hot CIP code over ('R', 'S'), both False when the
        atom has no '_CIPCode' property, then the '_ChiralityPossible' flag.
    """
    results = one_of_k_encoding_unk(
        atom.GetSymbol(),
        ['C','N','O', 'S','F','Si','P', 'Cl','Br','Mg','Na','Ca','Fe','As','Al','I','B','V','K','Tl',
            'Yb','Sb','Sn','Ag','Pd','Co','Se','Ti','Zn','H', 'Li','Ge','Cu','Au','Ni','Cd','In',
            'Mn','Zr','Cr','Pt','Hg','Pb','Unknown'
        ]) + [atom.GetDegree()/10, atom.GetImplicitValence(),
                atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] + \
                one_of_k_encoding_unk(atom.GetHybridization(), [
                Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
                Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
                Chem.rdchem.HybridizationType.SP3D2
                ]) + [atom.GetIsAromatic()]
    # With explicit_H=True (the default) the total hydrogen count is appended as a
    # single extra feature.
    if explicit_H:
        results = results + [atom.GetTotalNumHs()]

    if use_chirality:
        # '_CIPCode' is only present on atoms with an assigned CIP label; test for it
        # explicitly instead of catching every exception with a bare `except`.
        if atom.HasProp('_CIPCode'):
            results = results + one_of_k_encoding_unk(atom.GetProp('_CIPCode'), ['R', 'S'])
        else:
            results = results + [False, False]
        results = results + [atom.HasProp('_ChiralityPossible')]

    return np.array(results).astype(np.float32)

### Molecule Featurization

In [11]:
# Precompute (edge_index, atom-feature matrix, bond-feature matrix) for every drug,
# keyed by the drug's node id in the heterogeneous graph.
MOL_EDGE_LIST_FEAT_MTX = dict(
    (DB_TO_ID_DICT[db_id], get_mol_edge_list_and_feat_mtx(mol_graph))
    for db_id, mol_graph in drug_id_mol_graph_tup
)
len(MOL_EDGE_LIST_FEAT_MTX)



1007

In [12]:
# Sanity check on one drug (node id 998): atom-feature and bond-feature matrix shapes.
example_mol_graph = MOL_EDGE_LIST_FEAT_MTX[998]
example_mol_graph[1].shape, example_mol_graph[2].shape

(torch.Size([57, 55]), torch.Size([16, 6]))

### CV Split

In [13]:
def get_kfold_data(data, k=10, shuffle=True, num_neighbors=[10, 4], batch_size=64, random_state=None):
    """Yield ``(train_loader, val_loader, test_loader)`` for each of ``k`` CV folds.

    The ('drug', 'known', 'side_effect') edges are split with ``KFold``. In each fold:
      * the held-out fold is the test set;
      * 10 % of the remaining edges form the validation set, the rest the training set;
      * each split gets its own deep copy of ``data`` whose known / rev_known edges are
        replaced by that split's edges.
    Training edges go through ``RandomLinkSplit`` with a disjoint supervision ratio,
    and negatives are sampled on the fly by the loader (ratio 1.0). Validation and
    test edges go through ``RandomLinkSplit`` that adds one fixed negative per
    positive, with 99 % of the edges used as supervision.

    NOTE(review): with ``disjoint_train_ratio=0.99`` the val/test message-passing graph
    contains only ~1 % of that split's own edges (no training edges), and their negatives
    are sampled against that split only, so they may include known training positives.
    Confirm this is the intended evaluation protocol.

    Args:
        data: HeteroData with ('drug', 'known', 'side_effect') and its rev_known reverse.
        k: number of folds.
        shuffle: forwarded to ``KFold``.
        num_neighbors: neighbour sampling fan-out per hop for ``LinkNeighborLoader``.
        batch_size: supervision edges per mini-batch.
        random_state: optional seed for ``KFold`` and ``train_test_split``. ``None``
            (default) uses NumPy's global RNG, as before.
    """
    edge_type = ('drug', 'known', 'side_effect')
    rev_edge_type = ('side_effect', 'rev_known', 'drug')

    kf = KFold(n_splits=k, shuffle=shuffle, random_state=random_state)
    known_edges = data[edge_type].edge_index.T.numpy()  # (num_edges, 2)

    def _with_known_edges(edge_rows):
        """Deep-copy ``data`` with known / rev_known edges replaced by ``edge_rows`` ((E, 2))."""
        split_data = copy.deepcopy(data)
        edge_index = torch.tensor(edge_rows.T)
        split_data[edge_type].edge_index = edge_index
        split_data[rev_edge_type].edge_index = edge_index[[1, 0]]
        return split_data

    def _make_loader(split, shuffle_batches, **extra):
        """Build a LinkNeighborLoader over the supervision edges of ``split``."""
        store = split[edge_type]
        return LinkNeighborLoader(
            data=split,
            num_neighbors=num_neighbors,
            edge_label_index=(edge_type, store.edge_label_index),
            edge_label=store.edge_label,
            batch_size=batch_size,
            shuffle=shuffle_batches,
            **extra,
        )

    for train_index, test_index in kf.split(known_edges):
        train_set, valid_set = train_test_split(train_index, test_size=0.1, random_state=random_state)

        train_data_cv = _with_known_edges(known_edges[train_set])
        val_data_cv = _with_known_edges(known_edges[valid_set])
        test_data_cv = _with_known_edges(known_edges[test_index])

        # Training: split edges into message-passing vs. supervision; negatives are
        # sampled later by the loader.
        train_transform = T.RandomLinkSplit(
            num_val=0.0,
            num_test=0.0,
            disjoint_train_ratio=0.3236238313900354,
            neg_sampling_ratio=0.0,
            add_negative_train_samples=False,
            edge_types=edge_type,
            rev_edge_types=rev_edge_type,
        )
        train_cv, _, _ = train_transform(train_data_cv)

        # Validation / test: fixed negatives (one per positive) added once.
        eval_transform = T.RandomLinkSplit(
            num_val=0.0,
            num_test=0.0,
            disjoint_train_ratio=0.99,
            neg_sampling_ratio=1.0,
            add_negative_train_samples=True,
            edge_types=edge_type,
            rev_edge_types=rev_edge_type,
        )
        val_cv, _, _ = eval_transform(val_data_cv)
        test_cv, _, _ = eval_transform(test_data_cv)

        train_loader = _make_loader(train_cv, True, neg_sampling_ratio=1.0)
        val_loader = _make_loader(val_cv, False)
        test_loader = _make_loader(test_cv, False)
        yield train_loader, val_loader, test_loader


### Model

#### MHGNN Hetero

In [15]:
class HeteroMHGNN(nn.Module):
    """Stack of heterogeneous GATv2 layers with linear skips, LayerNorm and ELU.

    Each layer ``i`` is a ``HeteroConv`` holding one ``GATv2Conv`` per edge type,
    with messages reaching the same node type summed. For every node type, the
    conv output is passed through the layer's LayerNorm, then a linear skip
    projection of the layer input is added. The same LayerNorm is applied again,
    followed by an ELU. The outputs of all layers are concatenated per node type
    and normalised by a final LayerNorm.

    Args:
        metadata: ``(node_types, edge_types)``, e.g. from ``HeteroData.metadata()``.
        in_channels: input feature size shared by all node types.
        hidden_dims: per-layer output size of each attention head.
        heads: per-layer number of attention heads. Layer ``i`` outputs
            ``hidden_dims[i] * heads[i]`` features because the heads are concatenated.
        use_edge_attr: optional ``{edge_type: bool}``. Edge types mapped to ``True`` get
            ``GATv2Conv(edge_dim=1)``. Defaults to ``False`` for every edge type.
    """

    def __init__(self, metadata, in_channels, hidden_dims, heads, use_edge_attr=None):
        super().__init__()

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleDict()
        self.skips = nn.ModuleDict()
        self.final_norms = nn.ModuleDict()

        node_types, edge_types = metadata[0], metadata[1]

        # Define which edge types should use edge attributes
        if use_edge_attr is None:
            use_edge_attr = {edge_type: False for edge_type in edge_types}

        layer_in = in_channels
        concat_dim = 0  # width of the concatenated per-layer outputs
        for i, (out_dim, head) in enumerate(zip(hidden_dims, heads)):
            layer_out = out_dim * head
            conv_dict = {}
            for edge_type in edge_types:
                if use_edge_attr[edge_type]:
                    conv_dict[edge_type] = GATv2Conv(layer_in, out_dim, heads=head, add_self_loops=False, edge_dim=1)
                else:
                    conv_dict[edge_type] = GATv2Conv(layer_in, out_dim, heads=head, add_self_loops=False)

            self.convs.append(HeteroConv(conv_dict, aggr='sum'))

            for node_type in node_types:
                self.norms[f'{node_type}_{i}'] = nn.LayerNorm(layer_out)
                self.skips[f'{node_type}_{i}'] = Linear(layer_in, layer_out)

            concat_dim += layer_out
            layer_in = layer_out

        self.node_types = node_types
        # Sum of the per-layer widths: correct even when layers differ in size.
        for node_type in node_types:
            self.final_norms[f'{node_type}'] = nn.LayerNorm(concat_dim)

        # Initialize skips with xavier init
        for skip in self.skips.values():
            nn.init.xavier_uniform_(skip.weight)
            nn.init.zeros_(skip.bias)

    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        """Return ``{node_type: concatenated, normalised layer outputs}``.

        ``edge_attr_dict`` is forwarded to ``HeteroConv``. Edge types missing from it
        are convolved without edge attributes. ``x_dict`` is not modified.
        """
        x_dict = dict(x_dict)  # local copy: the caller's dict is left untouched
        x_repr_dict = {node_type: [] for node_type in self.node_types}

        for i, conv in enumerate(self.convs):
            # Skip projections use this layer's input, so compute them before the conv.
            skip_x = {node_type: self.skips[f'{node_type}_{i}'](x_dict[node_type])
                      for node_type in self.node_types}

            x_dict_new = conv(x_dict, edge_index_dict, edge_attr_dict)

            for node_type in self.node_types:
                norm = self.norms[f'{node_type}_{i}']
                x = norm(x_dict_new[node_type]) + skip_x[node_type]
                x = F.elu(norm(x))
                x_repr_dict[node_type].append(x)
                x_dict[node_type] = x

        # Concatenate all representations for each node type
        return {
            node_type: self.final_norms[f'{node_type}'](torch.cat(x_repr_dict[node_type], dim=1))
            for node_type in self.node_types
        }

# Edge types whose GATv2 convolutions consume a scalar edge attribute (edge_dim=1):
# the drug-side-effect association edges (and their reverse) carry none, while
# every similarity edge type does.
use_edge_attr = {('drug', 'known', 'side_effect'): False}
use_edge_attr.update(dict.fromkeys([
    ('drug', 'struct', 'drug'),
    ('drug', 'word', 'drug'),
    ('drug', 'target', 'drug'),
    ('drug', 'se_encoded', 'drug'),
    ('side_effect', 'name', 'side_effect'),
    ('side_effect', 'dg_encoded', 'side_effect'),
    ('side_effect', 'atc', 'side_effect'),
], True))
use_edge_attr[('side_effect', 'rev_known', 'drug')] = False


#### MHGNN - Outer HGNN

In [16]:
class MHGNN(nn.Module):
    """Homogeneous stack of GATv2 layers with linear skip connections.

    Layer ``i`` maps its input to ``hidden_dims[i] * heads[i]`` features. The conv
    output is layer-normalised, a linear skip projection of the layer input is
    added, and the same LayerNorm plus an ELU are applied. ``forward`` returns
    the concatenation of every layer's output along dim 1. ``edge_attr`` is
    accepted but not used.
    """

    def __init__(self, input_dim, hidden_dims, heads):
        super().__init__()
        self.GATLayers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.skips = nn.ModuleList()

        layer_in = input_dim
        for layer_idx, (head_dim, n_heads) in enumerate(zip(hidden_dims, heads)):
            layer_out = head_dim * n_heads
            self.GATLayers.append(
                GATv2Conv(layer_in, head_dim, heads=n_heads, add_self_loops=False, name=f'GATLayer{layer_idx}')
            )
            self.norms.append(nn.LayerNorm(layer_out))
            self.skips.append(nn.Linear(layer_in, layer_out))
            layer_in = layer_out

        # Xavier-uniform weights and zero bias for every skip projection.
        for projection in self.skips:
            nn.init.xavier_uniform_(projection.weight)
            nn.init.zeros_(projection.bias)

    def forward(self, x, edge_index, edge_attr):
        layer_outputs = []
        for gat, projection, norm in zip(self.GATLayers, self.skips, self.norms):
            residual = projection(x)
            hidden = norm(gat(x, edge_index)) + residual
            x = F.elu(norm(hidden))
            layer_outputs.append(x)
        return torch.cat(layer_outputs, dim=1)

#### DVModel

In [17]:
class DrugInterView_Block(nn.Module):
    """One block of the drug molecular-graph encoder.

    Runs a multi-head GATv2 convolution over the atom graph, using the 6-wide
    bond features as edge attributes, and writes the result back into
    ``mol_data.x``. It then applies ``SAGPooling`` and sum-pools the kept atoms
    into one embedding per molecule.

    ``forward`` returns ``(mol_data, graph_embedding, pooling_scores, pooled_batch)``.
    NOTE(review): with ``min_score=-1`` the pooling presumably keeps every atom, since
    SAGPooling's softmax scores are positive. It mainly supplies attention scores.
    """

    def __init__(self, n_heads, in_features, head_out_feats):
        super().__init__()
        self.n_heads = n_heads
        self.in_features = in_features
        self.out_features = head_out_feats

        self.feature_conv = GATv2Conv(in_features, head_out_feats, n_heads, edge_dim=6)
        self.readout = SAGPooling(n_heads * head_out_feats, min_score=-1)

    def forward(self, mol_data):
        mol_data.x = self.feature_conv(mol_data.x, mol_data.edge_index, mol_data.edge_attr)
        pooled = self.readout(mol_data.x, mol_data.edge_index, batch=mol_data.batch)
        # SAGPooling returns (x, edge_index, edge_attr, batch, perm, score).
        pooled_x, pooled_batch, pooling_scores = pooled[0], pooled[3], pooled[5]
        graph_embedding = global_add_pool(pooled_x, pooled_batch)
        return mol_data, graph_embedding, pooling_scores, pooled_batch

In [18]:
class FinalDrugMolEmb(nn.Module):
    """Drug encoder: stacked ``DrugInterView_Block``s over a molecular graph batch.

    Args:
        in_features: atom-feature width of ``mol_data.x``.
        heads_out_feat_params: per-block output size of each attention head.
        blocks_params: per-block number of attention heads.

    ``forward`` layer-normalises the atom features and runs the blocks in order.
    Between blocks (not after the last one) the atom features are layer-normalised
    and passed through an ELU. It returns the concatenation of every block's graph
    embedding, the per-block pooling scores, and per-block
    ``(mol_data.batch, pooled_batch)`` pairs.
    """

    def __init__(self, in_features, heads_out_feat_params, blocks_params):
        super().__init__()
        self.in_features = in_features
        self.n_blocks = len(blocks_params)

        self.inital_norm = nn.LayerNorm(self.in_features)

        self.blocks = nn.ModuleList()
        self.net_norms = nn.ModuleList()

        block_in = in_features
        for head_out_feats, n_heads in zip(heads_out_feat_params, blocks_params):
            block_out = head_out_feats * n_heads
            self.blocks.append(DrugInterView_Block(n_heads, block_in, head_out_feats))
            self.net_norms.append(nn.LayerNorm(block_out))
            block_in = block_out

    def forward(self, mol_data):
        mol_data.x = self.inital_norm(mol_data.x)
        graph_embeddings, attention_weights, attention_batch = [], [], []
        last_block = len(self.blocks) - 1
        for block_idx, (block, norm) in enumerate(zip(self.blocks, self.net_norms)):
            mol_data, graph_embedding, pooling_scores, pooled_batch = block(mol_data)
            attention_weights.append(pooling_scores)
            attention_batch.append((mol_data.batch, pooled_batch))
            graph_embeddings.append(graph_embedding)
            if block_idx != last_block:
                mol_data.x = F.elu(norm(mol_data.x))
        return torch.cat(graph_embeddings, dim=1), attention_weights, attention_batch

In [19]:
class DVModel(torch.nn.Module):
    """Link-prediction model for drug -> side-effect associations (DV variant).

    Drug node inputs come from ``FinalDrugMolEmb`` run on each drug's precomputed
    molecular graph. Side-effect node inputs come from a learnable embedding
    table. Both go through the outer heterogeneous GNN ``gnn_model``, and the
    resulting embeddings are scored by ``classifier_model`` on the supervision
    edges.

    Args:
        hidden_channels: size of the side-effect embedding, and of ``drug_lin``'s
            output when node features are used.
        gnn_model: called as ``gnn(x_dict, edge_index_dict, edge_attr_dict)``.
        classifier_model: called as ``classifier(x_drug, x_se, edge_label_index)``.
        use_node_features: if true, also build a LayerNorm + Linear over ``data["drug"].x``.
        node_feature_mode: only meaningful with ``use_node_features``.
            * ``"feat"``: use only the projected drug features.
            * ``"combined"``: add them to the molecular embedding.
            * Any other value: use the molecular embedding alone.

    Raises:
        ValueError: if ``node_feature_mode == "combined"`` without ``use_node_features``.
            That combination has no ``drug_lin`` layer to combine with.

    Note: reads the module-level ``data`` graph (side-effect count, drug feature width)
    and ``MOL_EDGE_LIST_FEAT_MTX`` (molecular graphs keyed by drug node id).
    """

    def __init__(self, hidden_channels, gnn_model, classifier_model, use_node_features=False, node_feature_mode="no"):
        super().__init__()
        if node_feature_mode == "combined" and not use_node_features:
            raise ValueError('node_feature_mode="combined" requires use_node_features=True')
        # Instantiate node embeddings:
        self.seff_emb = torch.nn.Embedding(data["side_effect"].num_nodes, hidden_channels)
        # DV for Drug: molecular-graph encoder
        drug_head_out_feats = [64, 64]
        drug_block_heads = [3, 3]
        self.drug_emb = FinalDrugMolEmb(in_features=55, heads_out_feat_params=drug_head_out_feats,
                                        blocks_params=drug_block_heads)
        # FinalDrugMolEmb concatenates one (head_out_feats * n_heads) embedding per block.
        outer_emb_dim = sum(f * h for f, h in zip(drug_head_out_feats, drug_block_heads))
        self.inital_norm_outer_drug = nn.LayerNorm(outer_emb_dim)
        # Outer message passing
        self.gnn = gnn_model
        self.use_node_features = use_node_features
        self.node_feature_mode = node_feature_mode
        if use_node_features:
            self.drug_feat_layernorm = torch.nn.LayerNorm(data["drug"].num_features)
            self.drug_lin = torch.nn.Linear(data["drug"].num_features, hidden_channels)
        # Instantiate classifier:
        self.classifier = classifier_model

        torch.nn.init.xavier_uniform_(self.seff_emb.weight)

    def __create_graph_data(self, drug_ids, device):
        """Batch the precomputed molecular graphs of ``drug_ids`` and move them to ``device``."""
        graphs = []
        for drug_id in drug_ids.cpu().numpy().astype(int).tolist():
            edge_index, atom_x, bond_attr = MOL_EDGE_LIST_FEAT_MTX[drug_id]
            graphs.append(Data(x=atom_x, edge_index=edge_index, edge_attr=bond_attr))
        return Batch.from_data_list(graphs).to(device)

    def forward(self, data: HeteroData) -> Tensor:
        """Return ``(logits, attention_weights, attention_batch)``.

        The two attention entries are ``None`` in ``"feat"`` mode, where the
        molecular encoder is not run.
        """
        attention_weights, h_att_batch = None, None
        if self.use_node_features and self.node_feature_mode == "feat":
            drug_x = self.drug_lin(self.drug_feat_layernorm(data["drug"].x))
        else:
            drug_ids = data["drug"].node_id
            mol_batch = self.__create_graph_data(drug_ids, drug_ids.device)
            drug_dv, attention_weights, h_att_batch = self.drug_emb(mol_batch)
            if self.use_node_features and self.node_feature_mode == "combined":
                drug_dv = drug_dv + self.drug_lin(self.drug_feat_layernorm(data["drug"].x))
            # Layer normalization of the drug inputs to the outer GNN.
            drug_x = self.inital_norm_outer_drug(drug_dv)

        x_dict = {
            "drug": drug_x,
            "side_effect": self.seff_emb(data["side_effect"].node_id),
        }
        x_dict = self.gnn(x_dict, data.edge_index_dict, data.edge_attr_dict)
        pred = self.classifier(
            x_dict["drug"],
            x_dict["side_effect"],
            data["drug", "known", "side_effect"].edge_label_index,
        )
        return pred, attention_weights, h_att_batch

#### Edge Classifier

In [20]:
# Final classifier: each supervision edge is scored by the Hadamard (element-wise)
# product of its drug and side-effect embeddings summed over the feature axis,
# i.e. their dot product. The scores are raw logits (callers apply a sigmoid).
class VanillaClassifier(torch.nn.Module):
    def forward(self, x_drug: Tensor, x_se: Tensor, edge_label_index: Tensor) -> Tensor:
        drug_rows = x_drug.index_select(0, edge_label_index[0])
        se_rows = x_se.index_select(0, edge_label_index[1])
        return torch.mul(drug_rows, se_rows).sum(dim=-1)

### Train Utils

#### Train Loop

In [21]:
def do_train_compute(batch, device, model):
    """Run ``model`` on ``batch`` and return ``(logits, labels)`` for the supervision edges.

    ``model`` must return a 3-tuple whose first item is the logits. ``device`` is
    accepted for interface compatibility but unused, since callers move ``batch``
    beforehand.
    """
    logits, _attention, _attention_batch = model(batch)
    labels = batch["drug", "known", "side_effect"].edge_label
    return logits, labels


def evaluate_metrics(probas_pred, ground_truth, threshold=0.5):
    """Compute binary-classification metrics from predicted probabilities.

    Args:
        probas_pred: CPU tensor of predicted probabilities (e.g. sigmoid outputs).
        ground_truth: CPU tensor of 0/1 labels, same length.
        threshold: a prediction counts as positive when strictly greater than this.

    Returns:
        ``(accuracy, precision, recall, f1, roc_auc, pr_auc, average_precision)``.
        ``pr_auc`` is the trapezoidal area under the precision-recall curve.
        Precision, recall and F1 are 0.0 when undefined (no predicted or true positives).

    Raises:
        ValueError: from ``roc_auc_score`` when ``ground_truth`` holds a single class.
    """
    probas = probas_pred.numpy()
    labels = ground_truth.numpy()
    binary_pred = np.where(probas > threshold, 1, 0)

    # zero_division=0 returns the same value sklearn uses by default, without the
    # UndefinedMetricWarning emitted e.g. when the model predicts no positives.
    accuracy = accuracy_score(labels, binary_pred)
    precision = precision_score(labels, binary_pred, zero_division=0)
    recall = recall_score(labels, binary_pred, zero_division=0)
    f1 = f1_score(labels, binary_pred, zero_division=0)
    roc_auc = roc_auc_score(labels, probas)
    precision_curve, recall_curve, _ = precision_recall_curve(labels, probas)
    pr_auc = auc(recall_curve, precision_curve)
    average_precision = average_precision_score(labels, probas)
    return accuracy, precision, recall, f1, roc_auc, pr_auc, average_precision

# TensorBoard tag suffixes, in the order returned by `evaluate_metrics`.
_EPOCH_METRIC_NAMES = ("Accuracy", "Precision", "Recall", "F1", "ROC AUC", "PR AUC", "Average Precision")


def _log_epoch_metrics(writer, prefix, metrics, epoch):
    """Write the metrics tuple from `evaluate_metrics` as '<prefix> <name>' scalars."""
    for name, value in zip(_EPOCH_METRIC_NAMES, metrics):
        writer.add_scalar(f"{prefix} {name}", value, epoch)


def train_loop(model, model_name, writer, train_loader, val_loader, loss_fn, optimizer, n_epochs, device, scheduler=None, early_stopping_patience=3, early_stopping_counter=0):
    """Train ``model`` with early stopping on validation F1; return it with the best weights.

    Each epoch does one pass over ``train_loader`` (gradients clipped to norm 1.0),
    then evaluates on ``val_loader`` under ``torch.no_grad``. Losses and metrics go
    to the TensorBoard ``writer`` and stdout. When validation F1 improves, the
    ``state_dict`` is saved to ``saved_models/<model_name>/best_model.pth``. After
    ``early_stopping_patience`` epochs without improvement, training stops. If a
    ``scheduler`` is given, ``scheduler.step(val_f1)`` is called once per epoch, so it
    must accept a metric (as ``ReduceLROnPlateau`` does).

    Args:
        early_stopping_counter: initial value of the no-improvement counter.

    Returns:
        ``model`` with the best saved checkpoint loaded.
    """
    # `import tqdm` alone does not load `tqdm.notebook`; tqdm.auto picks the notebook
    # widget inside Jupyter and a text bar elsewhere.
    from tqdm.auto import tqdm as auto_tqdm

    early_stop = False
    best_val_metrics = -float("inf")
    best_model_path = f"saved_models/{model_name}/best_model.pth"
    # make best_model_path parent directory if it doesn't exist
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

    print("Starting training loop at", datetime.today().strftime("%Y-%m-%d %H:%M:%S"))

    total_train_val_steps = len(train_loader) + len(val_loader)
    # One tick per train/val mini-batch across all epochs.
    progress_bar = auto_tqdm(total=total_train_val_steps * n_epochs, desc="MiniBatches")
    lr = optimizer.param_groups[0]['lr']
    # Monotonic TensorBoard steps: a per-epoch batch index would restart at 0 each
    # epoch and overwrite the previous epochs' mini-batch points.
    train_step = 0
    val_step = 0

    for epoch in range(1, n_epochs + 1):
        start_time = time.time()
        print("Epoch", epoch)

        # ---- training pass ----
        model.train()
        train_loss = 0
        train_probas_pred, train_ground_truth = [], []
        for idx, batch in enumerate(train_loader):
            batch = batch.to(device)
            lr = optimizer.param_groups[0]['lr']
            optimizer.zero_grad()
            out, actual = do_train_compute(batch, device, model)
            train_probas_pred.append(torch.sigmoid(out).detach().cpu())
            train_ground_truth.append(actual.detach().cpu())
            loss = loss_fn(out, actual)
            loss.backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()
            progress_bar.set_postfix_str(f"Epoch {epoch} - LR {lr:.7f} - Train Batch {idx+1}/{len(train_loader)} - Train loss: {train_loss/(idx+1):.4f}")
            progress_bar.update()
            writer.add_scalar("Training Loss MiniBatch", loss.item(), train_step)
            train_step += 1
            batch = batch.to("cpu")

        train_loss /= len(train_loader)
        writer.add_scalar("Training Loss Epoch", train_loss, epoch)

        # ---- metrics + validation pass ----
        model.eval()
        with torch.no_grad():
            train_metrics = evaluate_metrics(torch.cat(train_probas_pred, dim=0),
                                             torch.cat(train_ground_truth, dim=0))
            (train_accuracy, train_precision, train_recall, train_f1,
             train_roc_auc, train_pr_auc, train_average_precision) = train_metrics
            _log_epoch_metrics(writer, "Training", train_metrics, epoch)

            val_loss = 0
            val_probas_pred, val_ground_truth = [], []
            for idx_, batch in enumerate(val_loader):
                batch = batch.to(device)
                out, actual = do_train_compute(batch, device, model)
                val_probas_pred.append(torch.sigmoid(out).detach().cpu())
                val_ground_truth.append(actual.detach().cpu())
                loss = loss_fn(out, actual)
                val_loss += loss.item()
                progress_bar.set_postfix_str(f"Epoch {epoch} - LR {lr:.7f} - Val Batch {idx_+1}/{len(val_loader)} - Val loss: {val_loss/(idx_+1):.4f}")
                progress_bar.update()
                writer.add_scalar("Validation Loss MiniBatch", loss.item(), val_step)
                val_step += 1
                batch = batch.to("cpu")
            val_loss /= len(val_loader)
            val_metrics = evaluate_metrics(torch.cat(val_probas_pred, dim=0),
                                           torch.cat(val_ground_truth, dim=0))
            (val_accuracy, val_precision, val_recall, val_f1,
             val_roc_auc, val_pr_auc, val_average_precision) = val_metrics

            writer.add_scalar("Validation Loss Epoch", val_loss, epoch)
            _log_epoch_metrics(writer, "Validation", val_metrics, epoch)

            if val_f1 > best_val_metrics:
                best_val_metrics = val_f1
                early_stopping_counter = 0
                torch.save(model.state_dict(), best_model_path)
                print("New best model saved!")
            else:
                early_stopping_counter += 1
                print("Early stopping counter:", early_stopping_counter)
                if early_stopping_counter >= early_stopping_patience:
                    print("Early stopping triggered!")
                    early_stop = True

        if scheduler is not None:
            scheduler.step(val_f1)

        progress_bar.set_postfix_str(
            f"Train loss: {train_loss:.4f}, Train f1: {train_f1:.4f}, Train auc: {train_roc_auc:.4f}, "
            f"Train pr_auc: {train_pr_auc:.4f}, Val loss: {val_loss:.4f}, Val f1: {val_f1:.4f}, "
            f"Val auc: {val_roc_auc:.4f}, Val pr_auc: {val_pr_auc:.4f}, Best val f1: {best_val_metrics:.4f}"
        )
        print("Epoch Number:", epoch)
        print("Epoch time:", time.time() - start_time)
        print("Train loss:", train_loss)
        print("Train accuracy:", train_accuracy)
        print("Train precision:", train_precision)
        print("Train recall:", train_recall)
        print("Train f1:", train_f1)
        print("Train roc_auc:", train_roc_auc)
        print("Train pr_auc:", train_pr_auc)
        print("Train average_precision:", train_average_precision)

        print("Val loss:", val_loss)
        print("Val accuracy:", val_accuracy)
        print("Val precision:", val_precision)
        print("Val recall:", val_recall)
        print("Val f1:", val_f1)
        print("Val roc_auc:", val_roc_auc)
        print("Val pr_auc:", val_pr_auc)
        print("Val average_precision:", val_average_precision)
        print("Best val_f1:", best_val_metrics)
        print()
        if early_stop:
            break
        if epoch == n_epochs:
            print("Training completed!")

    progress_bar.close()

    # load best model (map onto the training device regardless of where it was saved)
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    return model

#### Test-set evaluation metrics

In [22]:
def mrank(y, y_pre):
    """Sum of reciprocal ranks of the positive items when ranked by descending score.

    Despite the name, this returns the *sum* of ``1 / rank`` over all positives
    (ranks are 1-based), not their mean. Ties in ``y_pre`` are ordered as
    ``np.argsort`` places them.

    Args:
        y: 1-D array of labels; entries equal to 1 are positives.
        y_pre: 1-D array of scores, same length as ``y``.

    Returns:
        Sum of reciprocal ranks of the positives (0.0 when there are none).
    """
    ranked_labels = y[np.argsort(-y_pre)]
    positive_ranks = np.flatnonzero(ranked_labels == 1) + 1
    return np.sum(1 / positive_ranks)

def evaluate_fold(loader, model, device, ret=False, threshold=0.5):
    """Score ``model`` on every batch from ``loader`` and print binary-classification metrics.

    Each batch is moved to ``device`` and passed to ``model``. The first element of
    the returned 3-tuple is treated as logits and turned into probabilities with a
    sigmoid. Labels are read from
    ``sampled_data["drug", "known", "side_effect"].edge_label``.

    Args:
        loader: iterable of mini-batches (e.g. the test loader of one CV fold).
        model: module returning ``(logits, _, _)`` for a batch.
        device: device the batches are moved to; should match the model's device.
        ret: if True, also return the metrics; otherwise they are only printed.
        threshold: probability cut-off used for the hard-label metrics (F1, MCC,
            accuracy, precision, recall). Defaults to 0.5, the previous fixed value.

    Returns:
        ``(auc, ap, f1, acc, precision, recall, mr, mcc)`` when ``ret`` is True,
        otherwise ``None``. ``mr`` is ``mrank(labels, probabilities)``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    prob_batches = []
    label_batches = []
    with torch.no_grad():
        for sampled_data in tqdm.tqdm(loader):
            sampled_data = sampled_data.to(device)
            logits, _, _ = model(sampled_data)
            # Copy each batch to host memory right away so GPU usage does not grow with the loader.
            prob_batches.append(torch.sigmoid(logits).cpu())
            label_batches.append(sampled_data["drug", "known", "side_effect"].edge_label.cpu())

    if not prob_batches:
        raise ValueError("evaluate_fold: loader yielded no batches")

    pred = torch.cat(prob_batches, dim=0).numpy()
    ground_truth = torch.cat(label_batches, dim=0).numpy()
    pred_int = (pred > threshold).astype(int)

    # Ranking metrics use probabilities; the rest use thresholded labels.
    # Named roc_auc (not auc) so sklearn.metrics.auc is not shadowed inside this function.
    roc_auc = roc_auc_score(ground_truth, pred)
    ap = average_precision_score(ground_truth, pred)
    mr = mrank(ground_truth, pred)
    f1 = f1_score(ground_truth, pred_int)
    mcc = matthews_corrcoef(ground_truth, pred_int)
    acc = (pred_int == ground_truth).mean()
    precision = precision_score(ground_truth, pred_int)
    recall = recall_score(ground_truth, pred_int)

    print()
    print(f"Test AUC: {roc_auc:.4f}")
    print(f"Test AP: {ap:.4f}")
    print(f"Test F1: {f1:.4f}")
    print(f"Test Accuracy: {acc:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"Test MCC: {mcc:.4f}")
    print(f"Test MR: {mr:.4f}")
    if ret:
        return roc_auc, ap, f1, acc, precision, recall, mr, mcc


#### Cross-validation training wrapper

In [23]:
def train_wrap_cv(data, model_name_, cv_fold=10, shuffle=True, num_neighbors=None, batch_size=64,
                  n_epochs=10, early_stopping_patience=5, lr=0.00025062971034390006,
                  use_edge_attr=None):
    """Run k-fold cross-validation, training and testing a fresh DVModel per fold.

    For every fold yielded by ``get_kfold_data``, this function:

    - builds a new ``HeteroMHGNN`` + ``VanillaClassifier`` wrapped in ``DVModel``;
    - trains it with ``train_loop`` using Adam and ``ReduceLROnPlateau`` in ``"max"`` mode;
    - reloads the checkpoint at ``saved_models/<model_name_>/fold<i>/best_model.pth``;
    - scores that checkpoint on the fold's test loader with ``evaluate_fold``.

    TensorBoard logs are written to ``logs/<model_name_>/fold<i>/<timestamp>``.

    Args:
        data: heterogeneous graph; ``data.metadata()`` configures the hetero GNN.
        model_name_: prefix for the checkpoint and log directories.
        cv_fold: number of folds, forwarded to ``get_kfold_data`` as ``k``.
        shuffle: forwarded to ``get_kfold_data``.
        num_neighbors: forwarded to ``get_kfold_data`` (presumably the per-hop
            neighbour-sampling fan-out); ``None`` means ``[10, 4]``.
        batch_size: forwarded to ``get_kfold_data``.
        n_epochs: maximum number of training epochs per fold.
        early_stopping_patience: forwarded to ``train_loop``.
        lr: Adam learning rate.
        use_edge_attr: forwarded to ``HeteroMHGNN``. ``None`` falls back to a
            notebook-global ``use_edge_attr`` (the previous, implicit behaviour).

    Returns:
        List with one ``[auc, ap, f1, acc, precision, recall, mr, mcc]`` entry per fold.

    Raises:
        NameError: if ``use_edge_attr`` is neither passed nor defined globally.
    """
    if num_neighbors is None:
        num_neighbors = [10, 4]
    if use_edge_attr is None:
        # Backward compatibility: this flag used to be read implicitly from the notebook globals.
        if 'use_edge_attr' not in globals():
            raise NameError("use_edge_attr was not passed and no global 'use_edge_attr' is defined")
        use_edge_attr = globals()['use_edge_attr']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    eval_metrics = []
    folds = get_kfold_data(data, k=cv_fold, shuffle=shuffle,
                           num_neighbors=num_neighbors, batch_size=batch_size)
    for fold, (train_loader_cv, val_loader_cv, test_loader_cv) in enumerate(folds, start=1):
        print(f"Fold {fold}")
        model_name = f"{model_name_}/fold{fold}"
        # One TensorBoard run directory per fold and launch time.
        log_dir = f"logs/{model_name}/" + datetime.now().strftime("%Y%m%d-%H%M%S")
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir)

        print(f"Device: '{device}'")

        gnn_model = HeteroMHGNN(data.metadata(), in_channels=384, hidden_dims=[64, 64, 64],
                                heads=[2, 2, 2], use_edge_attr=use_edge_attr)
        classifier_model = VanillaClassifier()
        model = DVModel(hidden_channels=384, gnn_model=gnn_model,
                        classifier_model=classifier_model, use_node_features=False)
        model = model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        # train_loop steps this scheduler with the validation F1, hence mode="max".
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="max", factor=0.5, patience=2, min_lr=1e-6
        )

        print(f"Total Number of Parameters: {sum(p.numel() for p in model.parameters())}")
        print(f"Total Number of Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

        try:
            model = train_loop(model, model_name, writer, train_loader_cv, val_loader_cv,
                               F.binary_cross_entropy_with_logits, optimizer, n_epochs=n_epochs,
                               device=device, scheduler=scheduler,
                               early_stopping_patience=early_stopping_patience)
        finally:
            # Flush pending TensorBoard events and release the event-file handle, even on failure.
            writer.close()

        # Evaluate the best checkpoint stored for this fold on the fold's held-out test split.
        model.load_state_dict(torch.load(f'saved_models/{model_name}/best_model.pth',
                                         map_location=device))
        fold_scores = evaluate_fold(test_loader_cv, model, device, ret=True)
        eval_metrics.append(list(fold_scores))

        # Drop every reference to this fold's model/optimizer so gc + empty_cache can actually
        # release its GPU memory before the next fold allocates a new model.
        del model, optimizer, scheduler, gnn_model, classifier_model
        gc.collect()
        torch.cuda.empty_cache()

    return eval_metrics

### Run cross-validated training

In [30]:
# Collect unreachable Python objects first, then ask PyTorch to release its unused cached CUDA memory.
gc.collect(); torch.cuda.empty_cache()

In [31]:
# Show the HeteroData summary (node and edge stores) right before cross-validation runs on it.
data

HeteroData(
  [1mdrug[0m={ node_id=[1007] },
  [1mside_effect[0m={ node_id=[5587] },
  [1m(drug, known, side_effect)[0m={ edge_index=[2, 132063] },
  [1m(side_effect, rev_known, drug)[0m={ edge_index=[2, 132063] }
)

In [32]:
# 3-fold CV run; arguments not listed here keep train_wrap_cv's defaults.
eval_metrics = train_wrap_cv(
    data,
    "sdv-hgnn-variant-dv",
    batch_size=64,
    num_neighbors=[10, 4],
    shuffle=True,
    cv_fold=3,
)

Fold 1
Device: 'cuda'
Total Number of Parameters: 2744944
Total Number of Trainable Parameters: 2744944
Starting training loop at 2024-08-05 17:24:26


MiniBatches:   0%|          | 0/6740 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 40.29823327064514
Train loss: 1.2928061621296139
Train accuracy: 0.6899032836752204
Train precision: 0.6952407682129826
Train recall: 0.6762343030964824
Train f1: 0.6856058359527905
Train roc_auc: 0.7452964353067086
Train pr_auc: 0.7049551082612389
Train average_precision: 0.7038764893767748
Val loss: 28.02231058284374
Val accuracy: 0.5088916934373566
Val precision: 0.9532163742690059
Val recall: 0.018701239100504818
Val f1: 0.036682795093957464
Val roc_auc: 0.7369683096670817
Val pr_auc: 0.6992310159967932
Val average_precision: 0.6992773528720391
Best val_f1: 0.036682795093957464

Epoch 2
New best model saved!
Epoch Number: 2
Epoch time: 40.98165273666382
Train loss: 0.5402394868340575
Train accuracy: 0.7926838780126355
Train precision: 0.7845174008643566
Train recall: 0.807035332657359
Train f1: 0.7956170703575548
Train roc_auc: 0.866621624965101
Train pr_auc: 0.8463757947788898
Train average_precision: 0.8464003271641122
Val

100%|██████████| 1362/1362 [00:24<00:00, 56.26it/s]



Test AUC: 0.8247
Test AP: 0.8228
Test F1: 0.6551
Test Accuracy: 0.7201
Test Precission: 0.8534
Test Recall: 0.5316
Test MCC: 0.4753
Test MR: 10.7556
Fold 2
Device: 'cuda'
Total Number of Parameters: 2744944
Total Number of Trainable Parameters: 2744944
Starting training loop at 2024-08-05 17:31:40


MiniBatches:   0%|          | 0/6740 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 38.23499536514282
Train loss: 1.3154718109645749
Train accuracy: 0.699672412448327
Train precision: 0.7044891764517933
Train recall: 0.6878948599953202
Train f1: 0.696093133385951
Train roc_auc: 0.7480504801733219
Train pr_auc: 0.7080829727609126
Train average_precision: 0.7068647663415635
Val loss: 11.480780214182685
Val accuracy: 0.544229004130335
Val precision: 0.7320891029500302
Val recall: 0.13951353832033042
Val f1: 0.2343644598631589
Val roc_auc: 0.7561502667626501
Val pr_auc: 0.719676654334731
Val average_precision: 0.7197343102659153
Best val_f1: 0.2343644598631589

Epoch 2
Early stopping counter: 1
Epoch Number: 2
Epoch time: 41.550193071365356
Train loss: 0.5504975961776742
Train accuracy: 0.7825832618360502
Train precision: 0.780287792047037
Train recall: 0.7866781062319632
Train f1: 0.7834699188254943
Train roc_auc: 0.8575510193907225
Train pr_auc: 0.8399591860247834
Train average_precision: 0.839991787075716
Val lo

100%|██████████| 1362/1362 [00:24<00:00, 55.75it/s]



Test AUC: 0.8234
Test AP: 0.8190
Test F1: 0.5711
Test Accuracy: 0.6833
Test Precission: 0.8845
Test Recall: 0.4217
Test MCC: 0.4302
Test MR: 9.3302
Fold 3
Device: 'cuda'
Total Number of Parameters: 2744944
Total Number of Trainable Parameters: 2744944
Starting training loop at 2024-08-05 17:38:51


MiniBatches:   0%|          | 0/6740 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 39.88079023361206
Train loss: 1.1206484318076821
Train accuracy: 0.7069066375477732
Train precision: 0.7015537742658512
Train recall: 0.7201856329459481
Train f1: 0.7107476185894352
Train roc_auc: 0.7618790022871729
Train pr_auc: 0.7127700777777222
Train average_precision: 0.7112217662249343
Val loss: 27.516790300218858
Val accuracy: 0.5122189077558513
Val precision: 0.8338557993730408
Val recall: 0.03051858650757228
Val f1: 0.05888212506917542
Val roc_auc: 0.7322533387980432
Val pr_auc: 0.6887988015987283
Val average_precision: 0.688899239199931
Best val_f1: 0.05888212506917542

Epoch 2
New best model saved!
Epoch Number: 2
Epoch time: 39.24109148979187
Train loss: 0.5504288887442496
Train accuracy: 0.7864441151236253
Train precision: 0.7834156505633585
Train recall: 0.7917869120973403
Train f1: 0.7875790372008223
Train roc_auc: 0.86053962645777
Train pr_auc: 0.8412033496351133
Train average_precision: 0.8412136954114807
Val lo

100%|██████████| 1362/1362 [00:25<00:00, 53.45it/s]



Test AUC: 0.7571
Test AP: 0.7201
Test F1: 0.5026
Test Accuracy: 0.6267
Test Precission: 0.7528
Test Recall: 0.3772
Test MCC: 0.2923
Test MR: 9.0090


### CV Performance

In [33]:
# Per-metric summary across folds; names follow the order of evaluate_fold's returned tuple.
metrics = ['auc', 'ap', 'f1', 'acc', 'precision', 'recall', 'mr', 'mcc']
# np.std uses ddof=0, i.e. the population standard deviation over folds.
metrics_mean_value, metrics_std = np.mean(eval_metrics, axis=0), np.std(eval_metrics, axis=0)
df = pd.DataFrame(dict(zip(
    ['Metric', 'Mean', 'Standard Deviation'],
    [metrics, metrics_mean_value, metrics_std],
)))
df

Unnamed: 0,Metric,Mean,Standard Deviation
0,auc,0.801719,0.031582
1,ap,0.787321,0.04755
2,f1,0.576253,0.062373
3,acc,0.676694,0.038439
4,precision,0.830204,0.056209
5,recall,0.443491,0.064889
6,mr,9.69825,0.759048
7,mcc,0.399273,0.077853
