### Import

In [10]:
import gc 
import os
import threading
import tqdm
import time
import pickle
import copy
import random
from datetime import datetime
from math import sqrt

import numpy as np
import pandas as pd


from rdkit import Chem

import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter


from torch_geometric.loader.link_neighbor_loader import LinkNeighborLoader
import torch_geometric.transforms as T
from torch_geometric.data import (
                                    HeteroData,
                                    Data, 
                                    Batch
                                 )   
from torch_geometric.nn import (
                                GATv2Conv,
                                SAGPooling,
                                global_add_pool,
                                HeteroConv,
                                Linear,
                                to_hetero
                                )

from sklearn.model_selection import StratifiedShuffleSplit, KFold, train_test_split, StratifiedKFold
from sklearn.metrics import (
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score,
    roc_auc_score, 
    precision_recall_curve, 
    auc, 
    average_precision_score, 
    matthews_corrcoef
    )

### Seed all randomness

In [11]:
def seed_everything(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and pin cuDNN to deterministic mode."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed every random number generator once, before any data or model code runs.
seed_everything(29)

### Load HeteroData

In [12]:
# Load the serialized heterogeneous graph (the cell below displays it as a HeteroData object).
# data_dict = data.to_dict()
fnm = '../prep_data/hetero_graph/hetero_data_dict.pt'
# NOTE(review): torch.load unpickles arbitrary Python objects - only load files from a trusted source.
data = torch.load(fnm)

In [13]:
data

HeteroData(
  [1mdrug[0m={ node_id=[1007] },
  [1mside_effect[0m={ node_id=[5587] },
  [1m(drug, known, side_effect)[0m={ edge_index=[2, 132063] },
  [1m(drug, struct, drug)[0m={
    edge_index=[2, 15844],
    edge_attr=[15844]
  },
  [1m(drug, word, drug)[0m={
    edge_index=[2, 83865],
    edge_attr=[83865]
  },
  [1m(drug, target, drug)[0m={
    edge_index=[2, 3363],
    edge_attr=[3363]
  },
  [1m(drug, se_encoded, drug)[0m={
    edge_index=[2, 65854],
    edge_attr=[65854]
  },
  [1m(side_effect, name, side_effect)[0m={
    edge_index=[2, 299170],
    edge_attr=[299170]
  },
  [1m(side_effect, dg_encoded, side_effect)[0m={
    edge_index=[2, 101114],
    edge_attr=[101114]
  },
  [1m(side_effect, atc, side_effect)[0m={
    edge_index=[2, 26140],
    edge_attr=[26140]
  }
)

### Load Transformation Maps

In [14]:
# Empty containers that the loading loop in the next cell fills in place
# (dicts via .update, the list via .extend) from the saved .pt files.
DB_TO_ID_DICT = {}
drug_id_mol_graph_tup = []
ID_TO_DB_DICT = {}
MEDRAID_TO_ID_DICT = {}
ID_TO_MEDRAID_DICT = {}

In [15]:
# Each container is paired positionally with the file that holds its content.
dict_list = [DB_TO_ID_DICT, ID_TO_DB_DICT, MEDRAID_TO_ID_DICT, ID_TO_MEDRAID_DICT, drug_id_mol_graph_tup]
file_names = ['db_to_id.pt', 'id_to_db.pt', 'uml_to_id.pt', 'id_to_uml.pt', 'drug_to_mol.pt']

for data_dict, fnm in zip(dict_list, file_names):
    full_path = f"../prep_data/hetero_graph/{fnm}"
    # NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
    loaded_data = torch.load(full_path)

    # Fill the existing container in place so the module-level name bound to it
    # (e.g. DB_TO_ID_DICT) sees the loaded content.
    if isinstance(data_dict, dict):
        data_dict.update(loaded_data)
    elif isinstance(data_dict, list):
        data_dict.extend(loaded_data)
    else:
        # The previous fallback looked the container up with list.index(), which
        # matches by equality (any two empty dicts compare equal) and only rebound
        # the list slot, never the named variable. Fail loudly instead.
        raise TypeError(
            f"Cannot load '{fnm}' into a container of type {type(data_dict).__name__}; "
            "expected a dict or a list"
        )

### Variant Simple - Hetero GNN

In [16]:
# Drug-drug and side_effect-side_effect similarity relations that this variant
# drops, leaving only the ('drug', 'known', 'side_effect') relation.
remove_similarity_edges = [('drug', 'struct', 'drug'),
                           ('drug', 'word', 'drug'),
                           ('drug', 'target', 'drug'),
                           ('drug', 'se_encoded', 'drug'),
                           ('side_effect', 'name', 'side_effect'),
                           ('side_effect', 'dg_encoded', 'side_effect'),
                           ('side_effect', 'atc', 'side_effect')
                            ]
for edge in remove_similarity_edges:
    # Skip relations that are already gone so the cell can be re-run without
    # failing on the second deletion.
    if edge in data.edge_types:
        del data[edge]

In [17]:
data

HeteroData(
  [1mdrug[0m={ node_id=[1007] },
  [1mside_effect[0m={ node_id=[5587] },
  [1m(drug, known, side_effect)[0m={ edge_index=[2, 132063] }
)

### HeteroData Undirected

In [18]:
# Adds the reverse relation ('side_effect', 'rev_known', 'drug'); it appears in the
# graph summary displayed further below.
data = T.ToUndirected()(data)

### Randomly Initialized node features

In [19]:
feature_dim = 384

In [20]:
# Node counts are read off the length of each node type's `node_id` tensor.
num_drug_nodes, num_se_nodes = (
    data[node_type]['node_id'].shape[0] for node_type in ('drug', 'side_effect')
)

In [21]:
# Allocate uninitialized feature matrices, one row per node
drug_features = torch.empty(num_drug_nodes, feature_dim)
side_effect_features = torch.empty(num_se_nodes, feature_dim)

# Fill both tensors in place with nn.init.uniform_ at its default bounds.
# This is plain uniform initialization, not Xavier (that would be nn.init.xavier_uniform_).
# The last call returns the tensor, so the cell displays side_effect_features.
nn.init.uniform_(drug_features)
nn.init.uniform_(side_effect_features)

tensor([[0.8509, 0.7431, 0.8943,  ..., 0.4268, 0.6308, 0.3762],
        [0.3987, 0.4537, 0.0922,  ..., 0.2985, 0.5453, 0.5793],
        [0.9462, 0.0774, 0.0744,  ..., 0.2169, 0.2147, 0.3346],
        ...,
        [0.2994, 0.9146, 0.0501,  ..., 0.5804, 0.2876, 0.8630],
        [0.8042, 0.3490, 0.8564,  ..., 0.5564, 0.2242, 0.7845],
        [0.9705, 0.6715, 0.2157,  ..., 0.6528, 0.2216, 0.1756]])

### Hetero data - Random Features

In [22]:
data

HeteroData(
  [1mdrug[0m={ node_id=[1007] },
  [1mside_effect[0m={ node_id=[5587] },
  [1m(drug, known, side_effect)[0m={ edge_index=[2, 132063] },
  [1m(side_effect, rev_known, drug)[0m={ edge_index=[2, 132063] }
)

In [23]:
# Attach the random feature matrices as the `x` attribute of each node type.
data['drug'].x = drug_features
data['side_effect'].x = side_effect_features

### CV Split

In [24]:
# Edge types of the drug -> side-effect relation and of its reverse.
_KNOWN_EDGE = ('drug', 'known', 'side_effect')
_REV_KNOWN_EDGE = ('side_effect', 'rev_known', 'drug')


def _graph_with_known_edges(data, edge_pairs):
    """Deep-copy ``data`` and replace its known / rev_known edges with ``edge_pairs`` (shape [n, 2])."""
    graph = copy.deepcopy(data)
    forward_index = torch.tensor(edge_pairs.T)
    graph[_KNOWN_EDGE].edge_index = forward_index
    # The reverse relation is the same edge list with source and target rows swapped.
    graph[_REV_KNOWN_EDGE].edge_index = forward_index[[1, 0]]
    return graph


def _split_known_links(graph, disjoint_train_ratio, with_negatives):
    """Run RandomLinkSplit with no val/test hold-out and return its first (train) output."""
    transform = T.RandomLinkSplit(
        num_val=0.0,
        num_test=0.0,
        disjoint_train_ratio=disjoint_train_ratio,
        neg_sampling_ratio=1.0 if with_negatives else 0.0,
        add_negative_train_samples=with_negatives,
        edge_types=_KNOWN_EDGE,
        rev_edge_types=_REV_KNOWN_EDGE,
    )
    labelled_graph, _, _ = transform(graph)
    return labelled_graph


def _known_link_loader(graph, num_neighbors, batch_size, **loader_kwargs):
    """Build a LinkNeighborLoader seeded with the labelled known edges stored on ``graph``."""
    known_store = graph[_KNOWN_EDGE]
    return LinkNeighborLoader(
        data=graph,
        num_neighbors=num_neighbors,
        edge_label_index=(_KNOWN_EDGE, known_store.edge_label_index),
        edge_label=known_store.edge_label,
        batch_size=batch_size,
        **loader_kwargs,
    )


def get_kfold_data(data, k=10, shuffle=True, num_neighbors=[10, 4], batch_size=64, random_state=None):
    """Yield ``(train_loader, val_loader, test_loader)`` for each of ``k`` folds over the known edges.

    The ('drug', 'known', 'side_effect') edges are split with KFold; 10% of each
    fold's training edges are further held out for validation. Every subset is
    placed in its own deep copy of ``data`` and labelled with RandomLinkSplit.

    Args:
        data: HeteroData holding the known / rev_known edge types (edge_index on CPU).
        k: number of folds.
        shuffle: whether KFold shuffles the edges before splitting.
        num_neighbors: neighbor-sampling fan-out forwarded to every LinkNeighborLoader.
            The default list is only read, never mutated.
        batch_size: number of seed edges per mini-batch.
        random_state: optional seed forwarded to KFold (only when ``shuffle`` is true,
            as scikit-learn requires) and to train_test_split. ``None`` keeps the
            previous behaviour of drawing from NumPy's global RNG.

    Yields:
        Tuple of train, validation and test LinkNeighborLoader objects.

    NOTE(review): the validation and test graphs contain only their own held-out
    edges, and disjoint_train_ratio=0.99 turns (per the RandomLinkSplit docs) 99% of
    those into supervision-only edges. The GNN therefore evaluates on an almost
    empty message-passing graph, which may explain near-chance val/test scores.
    Consider keeping the training edges as message-passing edges in the val/test
    graphs - confirm the intended protocol before changing.
    """
    # random_state must stay None when shuffle is False, otherwise KFold raises.
    kf = KFold(n_splits=k, shuffle=shuffle, random_state=random_state if shuffle else None)
    known_pairs = data[_KNOWN_EDGE].edge_index.T.numpy()

    for train_index, test_index in kf.split(known_pairs):
        train_set, valid_set = train_test_split(train_index, test_size=0.1, random_state=random_state)

        train_graph = _graph_with_known_edges(data, known_pairs[train_set])
        val_graph = _graph_with_known_edges(data, known_pairs[valid_set])
        test_graph = _graph_with_known_edges(data, known_pairs[test_index])

        # Training keeps positives only (the train loader samples negatives on the fly).
        # NOTE(review): 0.3236... looks like a tuned hyperparameter - confirm its origin.
        train_cv = _split_known_links(train_graph, 0.3236238313900354, with_negatives=False)
        # Validation / test get one fixed negative per positive from RandomLinkSplit.
        val_cv = _split_known_links(val_graph, 0.99, with_negatives=True)
        test_cv = _split_known_links(test_graph, 0.99, with_negatives=True)

        train_loader = _known_link_loader(
            train_cv, num_neighbors, batch_size, neg_sampling_ratio=1.0, shuffle=True
        )
        # Evaluation loaders keep a fixed order and add no extra negatives.
        val_loader = _known_link_loader(val_cv, num_neighbors, batch_size, shuffle=False)
        test_loader = _known_link_loader(test_cv, num_neighbors, batch_size, shuffle=False)

        yield train_loader, val_loader, test_loader


### Model

### Outer GNN

In [25]:
class HeteroMHGNN(nn.Module):
    """Stack of heterogeneous GATv2 layers with residual skips and layer-wise concatenation.

    Each layer holds one GATv2Conv per edge type (combined by a sum-aggregating
    HeteroConv) plus, per node type, a LayerNorm and a linear skip projection.
    The outputs of all layers are concatenated per node type and passed through
    a final LayerNorm.

    Args:
        metadata: ``(node_types, edge_types)`` pair, e.g. ``HeteroData.metadata()``.
        in_channels: input feature width shared by all node types.
        hidden_dims: per-layer output width of a single attention head.
        heads: per-layer number of attention heads.
        use_edge_attr: optional mapping ``edge_type -> bool``; edge types mapped to
            True get ``edge_dim=1`` convolutions. Missing edge types (and ``None``)
            default to False.
    """

    def __init__(self, metadata, in_channels, hidden_dims, heads, use_edge_attr=None):
        super().__init__()

        node_types, edge_types = metadata[0], metadata[1]

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleDict()
        self.skips = nn.ModuleDict()
        self.final_norms = nn.ModuleDict()

        if use_edge_attr is None:
            use_edge_attr = {}

        # Width of the features entering the current layer; tracked explicitly
        # instead of overwriting the `in_channels` argument inside the loop.
        layer_in = in_channels
        for i, (out_dim, head) in enumerate(zip(hidden_dims, heads)):
            layer_out = out_dim * head

            conv_dict = {}
            for edge_type in edge_types:
                edge_kwargs = {'edge_dim': 1} if use_edge_attr.get(edge_type, False) else {}
                conv_dict[edge_type] = GATv2Conv(
                    layer_in, out_dim, heads=head, add_self_loops=False, **edge_kwargs
                )
            self.convs.append(HeteroConv(conv_dict, aggr='sum'))

            for node_type in node_types:
                self.norms[f'{node_type}_{i}'] = nn.LayerNorm(layer_out)
                self.skips[f'{node_type}_{i}'] = Linear(layer_in, layer_out)

            layer_in = layer_out

        self.node_types = node_types

        # The concatenated representation is as wide as the sum of all layer widths.
        # (Previously `out_dim * head * len(heads)`, which is only correct when every
        # layer has the same width.)
        concat_dim = sum(out_dim * head for out_dim, head in zip(hidden_dims, heads))
        for node_type in node_types:
            self.final_norms[f'{node_type}'] = nn.LayerNorm(concat_dim)

        # Initialize skips with xavier init
        for skip in self.skips.values():
            nn.init.xavier_uniform_(skip.weight)
            nn.init.zeros_(skip.bias)

    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        """Return ``{node_type: tensor}`` with the normalized concatenation of all layer outputs.

        Args:
            x_dict: mapping node type -> input feature matrix.
            edge_index_dict: mapping edge type -> edge index tensor.
            edge_attr_dict: mapping edge type -> edge attributes (may be empty).
        """
        # Work on a shallow copy so the caller's dictionary is not overwritten.
        x_dict = dict(x_dict)
        layer_outputs = {node_type: [] for node_type in self.node_types}

        for i, conv in enumerate(self.convs):
            # Skip projections are computed from the layer input, before message passing.
            residuals = {
                node_type: self.skips[f'{node_type}_{i}'](x_dict[node_type])
                for node_type in self.node_types
            }
            conv_out = conv(x_dict, edge_index_dict, edge_attr_dict)

            for node_type in self.node_types:
                norm = self.norms[f'{node_type}_{i}']
                # The same LayerNorm instance is applied before and after the
                # residual addition (kept as in the original design).
                hidden = norm(conv_out[node_type]) + residuals[node_type]
                hidden = F.elu(norm(hidden))
                layer_outputs[node_type].append(hidden)
                x_dict[node_type] = hidden

        # Concatenate all layer representations for each node type
        return {
            node_type: self.final_norms[f'{node_type}'](torch.cat(layer_outputs[node_type], dim=1))
            for node_type in self.node_types
        }




#### Edge Classifier

In [26]:
# Edge-level scorer: the Hadamard (element-wise) product of the source and
# destination node embeddings, summed over the feature axis - i.e. a dot product.
class VanillaClassifier(torch.nn.Module):
    def forward(self, x_drug: Tensor, x_se: Tensor, edge_label_index: Tensor) -> Tensor:
        """Return one score per column of ``edge_label_index`` (row 0: drug index, row 1: side-effect index)."""
        drug_idx, se_idx = edge_label_index[0], edge_label_index[1]

        # Gather the embedding of each endpoint of every supervision edge.
        drug_side = x_drug[drug_idx]
        se_side = x_se[se_idx]

        # Hadamard product, then reduce over the last (feature) dimension.
        return torch.sum(drug_side * se_side, dim=-1)

#### FinalModel

In [37]:
class HeteroModel(torch.nn.Module):
    """Link predictor: input LayerNorm per node type -> hetero GNN -> edge classifier.

    Args:
        gnn_model: module called as ``gnn(x_dict, edge_index_dict, edge_attr_dict)``
            that returns a dict with "drug" and "side_effect" embeddings.
        classifier_model: module called as
            ``classifier(x_drug, x_se, edge_label_index)`` that returns edge scores.
        num_drug_nodes: accepted for interface compatibility; not used.
        num_se_nodes: accepted for interface compatibility; not used.
        feature_dim: width of the input node features normalized before the GNN.
    """

    def __init__(self, gnn_model, classifier_model, num_drug_nodes, num_se_nodes, feature_dim=384):
        super().__init__()

        # Attribute names (including the "inital" spelling) are state_dict keys of
        # saved checkpoints, so they are intentionally left unchanged.
        self.inital_norm_outer_drug = nn.LayerNorm(feature_dim)
        self.inital_norm_outer_se = nn.LayerNorm(feature_dim)

        self.gnn = gnn_model
        self.classifier = classifier_model

    # A second, earlier `forward(self, data: Data)` used to be defined here; Python
    # keeps only the last definition, so that homogeneous-graph version was dead
    # code and has been removed.
    def forward(self, data: "HeteroData") -> Tensor:
        """Score the supervision edges stored on ``data['drug', 'known', 'side_effect']``."""
        # Layer normalization of input features for the outer GNN.
        x_dict = {
            "drug": self.inital_norm_outer_drug(data["drug"].x),
            "side_effect": self.inital_norm_outer_se(data["side_effect"].x),
        }

        # `x_dict` holds feature matrices of all node types;
        # `edge_index_dict` / `edge_attr_dict` hold the tensors of all edge types.
        x_dict = self.gnn(x_dict, data.edge_index_dict, data.edge_attr_dict)

        return self.classifier(
            x_dict["drug"],
            x_dict["side_effect"],
            data["drug", "known", "side_effect"].edge_label_index,
        )

### Train Utils

#### Train Loop

In [38]:
def do_train_compute(batch, device, model):
    """Run ``model`` on ``batch`` and return ``(prediction, edge_label)``.

    ``device`` is not used here; the batch is passed to the model as given.
    """
    prediction = model(batch)
    return prediction, batch["drug", "known", "side_effect"].edge_label

# def do_train_compute(batch, device, model):
#     # batch = batch.to(device)
#     pred = model(batch)
#     actual = batch.edge_label
#     return pred, actual


def evaluate_metrics(probas_pred, ground_truth):
    """Compute binary-classification metrics from predicted probabilities.

    Both arguments are tensors (converted with ``.numpy()``). Probabilities above
    0.5 count as positive predictions for the threshold-based metrics.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, roc_auc, pr_auc, average_precision)``.
    """
    scores = probas_pred.numpy()
    labels = ground_truth.numpy()

    # Hard 0/1 decisions at the 0.5 threshold.
    hard_pred = np.where(scores > 0.5, 1, 0)

    # Metrics that need hard decisions.
    accuracy, precision, recall, f1 = (
        score_fn(labels, hard_pred)
        for score_fn in (accuracy_score, precision_score, recall_score, f1_score)
    )

    # Ranking metrics computed directly on the probabilities.
    roc_auc = roc_auc_score(labels, scores)
    curve_precision, curve_recall, _ = precision_recall_curve(labels, scores)
    pr_auc = auc(curve_recall, curve_precision)
    average_precision = average_precision_score(labels, scores)

    return accuracy, precision, recall, f1, roc_auc, pr_auc, average_precision

def _write_split_scalars(writer, split, metrics, epoch):
    """Log the 7-tuple returned by evaluate_metrics to TensorBoard under ``split``."""
    tags = ("Accuracy", "Precision", "Recall", "F1", "ROC AUC", "PR AUC", "Average Precision")
    for tag, value in zip(tags, metrics):
        writer.add_scalar(f"{split} {tag}", value, epoch)


def _print_split_summary(split, loss, metrics):
    """Print the epoch loss and the 7-tuple returned by evaluate_metrics for ``split``."""
    names = ("accuracy", "precision", "recall", "f1", "roc_auc", "pr_auc", "average_precision")
    print(f"{split} loss:", loss)
    for name, value in zip(names, metrics):
        print(f"{split} {name}:", value)


def train_loop(model, model_name, writer, train_loader, val_loader, loss_fn, optimizer, n_epochs, device, scheduler=None, early_stopping_patience=3, early_stopping_counter=0):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    After every epoch the validation F1 is compared with the best value so far; an
    improvement saves the weights to ``saved_models/{model_name}/best_model.pth``,
    otherwise the early-stopping counter is incremented.

    Args:
        model: module called through ``do_train_compute``.
        model_name: sub-directory name used for the checkpoint path.
        writer: object exposing ``add_scalar(tag, value, step)`` (TensorBoard writer).
        train_loader, val_loader: sized iterables of batches supporting ``.to(device)``.
        loss_fn: callable ``loss_fn(logits, labels)``.
        optimizer: optimizer updated once per training batch.
        n_epochs: maximum number of epochs.
        device: device every batch is moved to.
        scheduler: optional scheduler stepped once per epoch with the validation F1.
        early_stopping_patience: non-improving epochs tolerated before stopping.
        early_stopping_counter: initial value of the non-improvement counter.

    Returns:
        ``model`` with the best checkpoint's weights loaded.
    """
    # Picks the notebook widget when available and falls back to the console bar.
    from tqdm.auto import tqdm as make_progress_bar

    early_stop = False
    best_val_metrics = -float("inf")
    best_model_path = f"saved_models/{model_name}/best_model.pth"
    # make best_model_path parent directory if it doesn't exist
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

    print("Starting training loop at", datetime.today().strftime("%Y-%m-%d %H:%M:%S"))

    n_train_batches = len(train_loader)
    n_val_batches = len(val_loader)
    # The bar counts mini-batches (train + val) over all epochs and is advanced
    # manually, once per processed batch.
    epoch_progress_bar = make_progress_bar(
        total=(n_train_batches + n_val_batches) * n_epochs, desc="MiniBatches"
    )
    try:
        for epoch in range(1, n_epochs + 1):
            start_time = time.time()
            train_loss = 0
            val_loss = 0
            train_probas_pred, train_ground_truth = [], []
            val_probas_pred, val_ground_truth = [], []
            print("Epoch", epoch)

            # The learning rate only changes when the scheduler steps at epoch end.
            lr = optimizer.param_groups[0]['lr']

            model.train()
            for idx, batch in enumerate(train_loader):
                batch = batch.to(device)
                optimizer.zero_grad()
                out, actual = do_train_compute(batch, device, model)
                train_probas_pred.append(torch.sigmoid(out).detach().cpu())
                train_ground_truth.append(actual.detach().cpu())
                loss = loss_fn(out, actual)
                loss.backward()
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                train_loss += loss.item()
                epoch_progress_bar.set_postfix_str(f"Epoch {epoch} - LR {lr:.7f} - Train Batch {idx+1}/{len(train_loader)} - Train loss: {train_loss/(idx+1):.4f}")
                epoch_progress_bar.update()
                # Global step, so curves of later epochs do not overwrite earlier ones.
                writer.add_scalar("Training Loss MiniBatch", loss.item(), (epoch - 1) * n_train_batches + idx)

            train_loss /= n_train_batches
            writer.add_scalar("Training Loss Epoch", train_loss, epoch)

            model.eval()
            with torch.no_grad():
                train_metrics = evaluate_metrics(
                    torch.cat(train_probas_pred, dim=0), torch.cat(train_ground_truth, dim=0)
                )
                _write_split_scalars(writer, "Training", train_metrics, epoch)

                for idx_, batch in enumerate(val_loader):
                    batch = batch.to(device)
                    out, actual = do_train_compute(batch, device, model)
                    val_probas_pred.append(torch.sigmoid(out).detach().cpu())
                    val_ground_truth.append(actual.detach().cpu())
                    loss = loss_fn(out, actual)
                    val_loss += loss.item()
                    # Running mean over the validation batches seen so far (the
                    # divisor previously used the last training batch index).
                    epoch_progress_bar.set_postfix_str(f"Epoch {epoch} - LR {lr:.7f} - Val Batch {idx_+1}/{len(val_loader)} - Val loss: {val_loss/(idx_+1):.4f}")
                    epoch_progress_bar.update()
                    writer.add_scalar("Validation Loss MiniBatch", loss.item(), (epoch - 1) * n_val_batches + idx_)

                val_loss /= n_val_batches
                val_metrics = evaluate_metrics(
                    torch.cat(val_probas_pred, dim=0), torch.cat(val_ground_truth, dim=0)
                )
                writer.add_scalar("Validation Loss Epoch", val_loss, epoch)
                _write_split_scalars(writer, "Validation", val_metrics, epoch)

                # Positions follow evaluate_metrics' return order.
                train_f1, train_roc_auc, train_pr_auc = train_metrics[3:6]
                val_f1, val_roc_auc, val_pr_auc = val_metrics[3:6]

                if val_f1 > best_val_metrics:
                    best_val_metrics = val_f1
                    early_stopping_counter = 0
                    torch.save(model.state_dict(), best_model_path)
                    print("New best model saved!")
                else:
                    early_stopping_counter += 1
                    print("Early stopping counter:", early_stopping_counter)
                    if early_stopping_counter >= early_stopping_patience:
                        print("Early stopping triggered!")
                        early_stop = True

            if scheduler is not None:
                # Metric-driven step (e.g. ReduceLROnPlateau in "max" mode).
                scheduler.step(val_f1)

            epoch_progress_bar.set_postfix_str("Train loss: {:.4f}, Train f1: {:.4f}, Train auc: {:.4f}, Train pr_auc: {:.4f},\
                                            Val loss: {:.4f}, Val f1: {:.4f}, Val auc: {:.4f}, Val pr_auc: {:.4f},\
                                            Best val f1: {:.4f}".format(
                train_loss, train_f1, train_roc_auc, train_pr_auc,
                val_loss, val_f1, val_roc_auc, val_pr_auc, best_val_metrics))

            print("Epoch Number:", epoch)
            print("Epoch time:", time.time() - start_time)
            _print_split_summary("Train", train_loss, train_metrics)
            _print_split_summary("Val", val_loss, val_metrics)
            print("Best val_f1:", best_val_metrics)
            print()

            if early_stop:
                break
            if epoch == n_epochs:
                print("Training completed!")
    finally:
        # Release the progress-bar widget even if training is interrupted.
        epoch_progress_bar.close()

    # load best model
    model.load_state_dict(torch.load(best_model_path))
    return model

#### Test Evaluate Metrics

In [39]:
def mrank(y, y_pre):
    """Return the sum of reciprocal ranks of the positive labels.

    Items are ranked by descending ``y_pre`` (rank 1 = highest score); for every
    position whose label in ``y`` equals 1, ``1 / rank`` is added to the result.

    NOTE(review): despite the name this is the *sum*, not the mean, of the
    reciprocal ranks - the mean was previously computed but never returned.
    Confirm which quantity the reported "MR" is meant to be.
    """
    order = np.argsort(-y_pre)
    ranked_labels = y[order]
    # 1-based ranks of the positive items.
    positive_ranks = np.array(np.where(ranked_labels == 1)) + 1
    return np.sum(1 / positive_ranks)

def evaluate_fold(loader, model, device, ret=False):
    """Score every batch of ``loader`` with ``model`` and print the test metrics.

    Sigmoid probabilities above 0.5 count as positive predictions.

    Returns:
        ``(auc, ap, f1, acc, precision, recall, mr, mcc)`` when ``ret`` is true,
        otherwise ``None``.
    """
    batch_probs, batch_labels = [], []
    model.eval()
    with torch.no_grad():
        for sampled_data in tqdm.tqdm(loader):
            sampled_data.to(device)
            # Logits -> probabilities.
            batch_probs.append(torch.sigmoid(model(sampled_data)))
            batch_labels.append(sampled_data["drug", "known", "side_effect"].edge_label)

    pred = torch.cat(batch_probs, dim=0).cpu().numpy()
    ground_truth = torch.cat(batch_labels, dim=0).cpu().numpy()
    pred_int = (pred > 0.5).astype(int)

    # Ranking metrics use the probabilities; the others use the hard decisions.
    roc_auc = roc_auc_score(ground_truth, pred)
    ap = average_precision_score(ground_truth, pred)
    mr = mrank(ground_truth, pred)
    f1 = f1_score(ground_truth, pred_int)
    mcc = matthews_corrcoef(ground_truth, pred_int)
    acc = (pred_int == ground_truth).mean()
    precision = precision_score(ground_truth, pred_int)
    recall = recall_score(ground_truth, pred_int)

    print()
    report = (
        ("AUC", roc_auc),
        ("AP", ap),
        ("F1", f1),
        ("Accuracy", acc),
        ("Precission", precision),
        ("Recall", recall),
        ("MCC", mcc),
        ("MR", mr),
    )
    for label, value in report:
        print(f"Test {label}: {value:.4f}")

    if ret:
        return roc_auc, ap, f1, acc, precision, recall, mr, mcc


#### Train Wrap CV

In [40]:
def train_wrap_cv(data, model_name_, cv_fold=10, shuffle=True, num_neighbors=[10, 4], batch_size=64, n_epochs=10, early_stopping_patience=5):
    """Train and test a fresh model on every cross-validation fold.

    Args:
        data: HeteroData with node features ``x`` and the known / rev_known edges.
        model_name_: base name for the per-fold checkpoint and log directories.
        cv_fold: number of folds.
        shuffle: whether the edges are shuffled before splitting.
        num_neighbors: neighbor-sampling fan-out forwarded to the loaders
            (the default list is only read, never mutated).
        batch_size: number of seed edges per mini-batch.
        n_epochs: maximum number of epochs per fold.
        early_stopping_patience: non-improving epochs tolerated before stopping.

    Returns:
        List with one ``[auc, ap, f1, acc, precision, recall, mr, mcc]`` entry per fold.
    """
    # One device for training and evaluation of every fold.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    eval_metrics = []
    fold_loaders = get_kfold_data(data, k=cv_fold,
                                  shuffle=shuffle,
                                  num_neighbors=num_neighbors,
                                  batch_size=batch_size)
    for i, (train_loader_cv, val_loader_cv, test_loader_cv) in enumerate(fold_loaders):
        print(f"Fold {i+1}")
        model_name = f"{model_name_}/fold{i+1}"
        # Define the log directory where TensorBoard logs will be stored
        log_dir = f"logs/{model_name}/" + datetime.now().strftime("%Y%m%d-%H%M%S")
        os.makedirs(log_dir, exist_ok=True)

        print(f"Device: '{device}'")

        # Neither relation of this variant carries edge attributes.
        use_edge_attr = {
            ('drug', 'known', 'side_effect'): False,
            ('side_effect', 'rev_known', 'drug'): False,
        }
        # in_channels stays 384 to match the input LayerNorm width used by HeteroModel.
        gnn_model = HeteroMHGNN(data.metadata(), in_channels=384, hidden_dims=[64, 64, 64], heads=[2, 2, 2], use_edge_attr=use_edge_attr)
        classifier_model = VanillaClassifier()
        # Node counts are read from the feature matrices instead of being hard-coded.
        model = HeteroModel(gnn_model=gnn_model,
                            classifier_model=classifier_model,
                            num_drug_nodes=data['drug'].x.size(0),
                            num_se_nodes=data['side_effect'].x.size(0))
        model = model.to(device)

        # NOTE(review): the learning rate looks like a tuned hyperparameter - confirm its origin.
        optimizer = torch.optim.Adam(model.parameters(), lr=0.00025062971034390006, weight_decay=0.001)
        # mode="max": train_loop steps the scheduler with the validation F1.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="max", factor=0.5, patience=2, min_lr=1e-6
        )

        print(f"Total Number of Parameters: {sum(p.numel() for p in model.parameters())}")
        print(f"Total Number of Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

        writer = SummaryWriter(log_dir)
        try:
            # train_loop returns the model with the best checkpoint already loaded.
            model = train_loop(model, model_name, writer, train_loader_cv, val_loader_cv,
                               F.binary_cross_entropy_with_logits, optimizer, n_epochs=n_epochs,
                               device=device, scheduler=scheduler, early_stopping_patience=early_stopping_patience)
        finally:
            # Flush and release the event file even if training fails.
            writer.close()

        fold_metrics = evaluate_fold(test_loader_cv, model, device, ret=True)
        eval_metrics.append(list(fold_metrics))

        # Free host and GPU memory before the next fold builds a new model.
        gc.collect()
        torch.cuda.empty_cache()

    return eval_metrics

### Run Train CV 

In [47]:
gc.collect()
torch.cuda.empty_cache()

In [48]:
data

HeteroData(
  [1mdrug[0m={
    node_id=[1007],
    x=[1007, 384]
  },
  [1mside_effect[0m={
    node_id=[5587],
    x=[5587, 384]
  },
  [1m(drug, known, side_effect)[0m={ edge_index=[2, 132063] },
  [1m(side_effect, rev_known, drug)[0m={ edge_index=[2, 132063] }
)

In [49]:
# 3-fold cross-validation; n_epochs and early_stopping_patience are left at the
# defaults declared by train_wrap_cv.
eval_metrics = train_wrap_cv(data, "sdv-hgnn-variant-hetero", cv_fold=3, shuffle=True, 
                             num_neighbors=[10, 4], batch_size=64)

Fold 1
Device: 'cuda'
Total Number of Parameters: 499968
Total Number of Trainable Parameters: 499968
Starting training loop at 2024-08-05 16:41:28


MiniBatches:   0%|          | 0/6740 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 14.518026113510132
Train loss: 3.097954424390769
Train accuracy: 0.7070041338429139
Train precision: 0.7033404842169783
Train recall: 0.7160127915139225
Train f1: 0.709620067251575
Train roc_auc: 0.7645647487975598
Train pr_auc: 0.7471270964788392
Train average_precision: 0.7199447361965523
Val loss: 8.993545267716826
Val accuracy: 0.513194125745755
Val precision: 0.6356132075471698
Val recall: 0.061840293712712255
Val f1: 0.11271434546214973
Val roc_auc: 0.5303007532474973
Val pr_auc: 0.5437487513121384
Val average_precision: 0.5429142599770541
Best val_f1: 0.11271434546214973

Epoch 2
Early stopping counter: 1
Epoch Number: 2
Epoch time: 13.872706413269043
Train loss: 1.3146244645713274
Train accuracy: 0.7524374073785196
Train precision: 0.7501545833977431
Train recall: 0.7570002339911084
Train f1: 0.75356186187352
Train roc_auc: 0.8222059402082789
Train pr_auc: 0.7907016947526165
Train average_precision: 0.787865390110706
Val

100%|██████████| 1362/1362 [00:10<00:00, 129.60it/s]



Test AUC: 0.5189
Test AP: 0.5568
Test F1: 0.2329
Test Accuracy: 0.5357
Test Precission: 0.6694
Test Recall: 0.1410
Test MCC: 0.1162
Test MR: 8.5498
Fold 2
Device: 'cuda'
Total Number of Parameters: 499968
Total Number of Trainable Parameters: 499968
Starting training loop at 2024-08-05 16:43:06


MiniBatches:   0%|          | 0/6740 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 15.225512266159058
Train loss: 3.1595100214951057
Train accuracy: 0.704527727946338
Train precision: 0.7002711268950242
Train recall: 0.7151548241166835
Train f1: 0.7076347218738543
Train roc_auc: 0.7637440590376667
Train pr_auc: 0.7514923670297626
Train average_precision: 0.7238468813374024
Val loss: 9.224569826227704
Val accuracy: 0.5185865075722809
Val precision: 0.6901408450704225
Val recall: 0.06746213859568609
Val f1: 0.1229096989966555
Val roc_auc: 0.5226323406326946
Val pr_auc: 0.5447555392905201
Val average_precision: 0.5441530973140756
Best val_f1: 0.1229096989966555

Epoch 2
Early stopping counter: 1
Epoch Number: 2
Epoch time: 15.546830177307129
Train loss: 1.3272295744341804
Train accuracy: 0.7564932532563763
Train precision: 0.7544687765998607
Train recall: 0.7604711020981203
Train f1: 0.757458048477315
Train roc_auc: 0.8273795897409435
Train pr_auc: 0.8004065377471092
Train average_precision: 0.797339557820375
Val

100%|██████████| 1362/1362 [00:09<00:00, 143.67it/s]



Test AUC: 0.5127
Test AP: 0.5582
Test F1: 0.2732
Test Accuracy: 0.5435
Test Precission: 0.6700
Test Recall: 0.1716
Test MCC: 0.1303
Test MR: 9.1365
Fold 3
Device: 'cuda'
Total Number of Parameters: 499968
Total Number of Trainable Parameters: 499968
Starting training loop at 2024-08-05 16:45:48


MiniBatches:   0%|          | 0/6740 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 14.64178466796875
Train loss: 3.1441179026392034
Train accuracy: 0.7098315264019968
Train precision: 0.7054802367767806
Train recall: 0.720419624054286
Train f1: 0.7128716692071699
Train roc_auc: 0.7694635349742367
Train pr_auc: 0.7566524085787387
Train average_precision: 0.7294150357642348
Val loss: 15.650627601086153
Val accuracy: 0.505277650298302
Val precision: 0.7421052631578947
Val recall: 0.016177145479577788
Val f1: 0.03166404671008309
Val roc_auc: 0.5796108595102696
Val pr_auc: 0.5902920670888987
Val average_precision: 0.590326503612945
Best val_f1: 0.03166404671008309

Epoch 2
Early stopping counter: 1
Epoch Number: 2
Epoch time: 14.383310317993164
Train loss: 1.3346178232880304
Train accuracy: 0.7533733718118711
Train precision: 0.753018147830828
Train recall: 0.7540753451368848
Train f1: 0.7535463756819953
Train roc_auc: 0.8263000910809006
Train pr_auc: 0.7964849601734707
Train average_precision: 0.7937768458860088
V

100%|██████████| 1362/1362 [00:09<00:00, 141.08it/s]



Test AUC: 0.4797
Test AP: 0.5469
Test F1: 0.1748
Test Accuracy: 0.5321
Test Precission: 0.7397
Test Recall: 0.0991
Test MCC: 0.1284
Test MR: 9.5148


### CV Performance

In [50]:
# Column order matches the per-fold lists collected by train_wrap_cv.
metrics = ['auc', 'ap', 'f1', 'acc', 'precision', 'recall', 'mr', 'mcc']

# Rows are folds, columns are metrics; aggregate across folds.
fold_scores = np.asarray(eval_metrics)
metrics_mean_value = fold_scores.mean(axis=0)
# Population standard deviation (NumPy's default ddof=0).
metrics_std = fold_scores.std(axis=0)

df = pd.DataFrame(
    list(zip(metrics, metrics_mean_value, metrics_std)),
    columns=['Metric', 'Mean', 'Standard Deviation'],
)
df

Unnamed: 0,Metric,Mean,Standard Deviation
0,auc,0.503781,0.01721
1,ap,0.553984,0.005007
2,f1,0.226961,0.040398
3,acc,0.537108,0.004775
4,precision,0.693028,0.03299
5,recall,0.137219,0.029711
6,mr,9.06704,0.39703
7,mcc,0.12499,0.006249
