### Import

In [1]:
import gc 
import os
import threading
import tqdm
import time
import pickle
import copy
import random
from datetime import datetime
from math import sqrt

import numpy as np
import pandas as pd


from rdkit import Chem

import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter


from torch_geometric.loader.link_neighbor_loader import LinkNeighborLoader
import torch_geometric.transforms as T
from torch_geometric.data import (
                                    HeteroData,
                                    Data, 
                                    Batch
                                 )   
from torch_geometric.nn import (
                                GATv2Conv,
                                SAGPooling,
                                global_add_pool,
                                HeteroConv,
                                Linear,
                                to_hetero
                                )

from sklearn.model_selection import StratifiedShuffleSplit, KFold, train_test_split, StratifiedKFold
from sklearn.metrics import (
    accuracy_score, 
    precision_score, 
    recall_score, 
    f1_score,
    roc_auc_score, 
    precision_recall_curve, 
    auc, 
    average_precision_score, 
    matthews_corrcoef
    )

### Seed all randomness

In [2]:
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Usage example:
seed_everything(42)  # Set the seed to 42

### Load HeteroData

In [3]:
# data_dict = data.to_dict()
fnm = '../prep_data/hetero_graph/hetero_data_dict.pt'
data = torch.load(fnm)

In [4]:
data

HeteroData(
  [1mdrug[0m={ node_id=[1007] },
  [1mside_effect[0m={ node_id=[5587] },
  [1m(drug, known, side_effect)[0m={ edge_index=[2, 132063] },
  [1m(drug, struct, drug)[0m={
    edge_index=[2, 15844],
    edge_attr=[15844]
  },
  [1m(drug, word, drug)[0m={
    edge_index=[2, 83865],
    edge_attr=[83865]
  },
  [1m(drug, target, drug)[0m={
    edge_index=[2, 3363],
    edge_attr=[3363]
  },
  [1m(drug, se_encoded, drug)[0m={
    edge_index=[2, 65854],
    edge_attr=[65854]
  },
  [1m(side_effect, name, side_effect)[0m={
    edge_index=[2, 299170],
    edge_attr=[299170]
  },
  [1m(side_effect, dg_encoded, side_effect)[0m={
    edge_index=[2, 101114],
    edge_attr=[101114]
  },
  [1m(side_effect, atc, side_effect)[0m={
    edge_index=[2, 26140],
    edge_attr=[26140]
  }
)

### Load Transformation Maps

In [5]:
DB_TO_ID_DICT = {}
drug_id_mol_graph_tup = []
ID_TO_DB_DICT = {}
MEDRAID_TO_ID_DICT = {}
ID_TO_MEDRAID_DICT = {}

In [6]:
dict_list = [DB_TO_ID_DICT, ID_TO_DB_DICT, MEDRAID_TO_ID_DICT, ID_TO_MEDRAID_DICT, drug_id_mol_graph_tup]
file_names = ['db_to_id.pt', 'id_to_db.pt', 'uml_to_id.pt', 'id_to_uml.pt', 'drug_to_mol.pt']

for data_dict, fnm in zip(dict_list, file_names):
    full_path = f"../prep_data/hetero_graph/{fnm}"
    loaded_data = torch.load(full_path)
    
    if isinstance(data_dict, dict):
        data_dict.update(loaded_data)
    elif isinstance(data_dict, list):
        data_dict.extend(loaded_data)
    else:
        # If it's neither a dict nor a list, just replace it
        index = dict_list.index(data_dict)
        dict_list[index] = loaded_data

### Variant Simple - Homogenous GNN

In [7]:
remove_similarity_edges = [('drug', 'struct', 'drug'),
                           ('drug', 'word', 'drug'),
                           ('drug', 'target', 'drug'),
                           ('drug', 'se_encoded', 'drug'),
                           ('side_effect', 'name', 'side_effect'),
                           ('side_effect', 'dg_encoded', 'side_effect'),
                           ('side_effect', 'atc', 'side_effect')
                            ]
for edge in remove_similarity_edges:
    del data[edge]

In [8]:
data

HeteroData(
  [1mdrug[0m={ node_id=[1007] },
  [1mside_effect[0m={ node_id=[5587] },
  [1m(drug, known, side_effect)[0m={ edge_index=[2, 132063] }
)

### HeteroData Undirected

In [9]:
# data = T.ToUndirected()(data)

### Random Initalized node features

In [10]:
feature_dim = 384

In [11]:
num_drug_nodes = data['drug']['node_id'].size()[0]
num_se_nodes = data['side_effect']['node_id'].size()[0]

In [12]:
# Initialize feature tensors
drug_features = torch.empty(num_drug_nodes, feature_dim)
side_effect_features = torch.empty(num_se_nodes, feature_dim)

# Xavier Uniform Initialization with different parameters
nn.init.uniform_(drug_features)
nn.init.uniform_(side_effect_features)

tensor([[0.7288, 0.7229, 0.2304,  ..., 0.4231, 0.8584, 0.4674],
        [0.1594, 0.9352, 0.9086,  ..., 0.6596, 0.9826, 0.5885],
        [0.6754, 0.1170, 0.0593,  ..., 0.5903, 0.3413, 0.2729],
        ...,
        [0.8108, 0.0930, 0.7023,  ..., 0.3677, 0.7749, 0.9355],
        [0.8262, 0.4098, 0.4337,  ..., 0.0729, 0.7103, 0.2575],
        [0.0270, 0.8912, 0.3906,  ..., 0.7522, 0.7644, 0.2061]])

### Homogenous data

In [13]:
data

HeteroData(
  [1mdrug[0m={ node_id=[1007] },
  [1mside_effect[0m={ node_id=[5587] },
  [1m(drug, known, side_effect)[0m={ edge_index=[2, 132063] }
)

In [14]:
homo_data = data.to_homogeneous()
homo_data

Data(edge_index=[2, 132063], node_id=[6594], node_type=[6594], edge_type=[132063])

In [15]:
homo_data.node_type

tensor([0, 0, 0,  ..., 1, 1, 1])

In [16]:
# Step 2: Manually combine node features
all_features = torch.cat([drug_features, side_effect_features], dim=0)

# Step 3: Assign the combined features to the homogeneous data
homo_data.x = all_features

In [17]:
homo_data

Data(edge_index=[2, 132063], node_id=[6594], node_type=[6594], edge_type=[132063], x=[6594, 384])

### CV Split

In [18]:
def get_kfold_data(data, k=10, shuffle=True, num_neighbors=[10, 4], batch_size=64):
    kf = KFold(n_splits=k, shuffle=shuffle)
    kf.get_n_splits()
    train_val_data_X = data.edge_index.T.numpy()
    for train_index, test_index in kf.split(train_val_data_X):
        train_index_, valid_index_ = train_test_split(train_index, test_size=0.1)
        train_set = train_index_
        valid_set = valid_index_
        
        train_data_cv = copy.deepcopy(data)
        train_data_cv.edge_index = torch.tensor(train_val_data_X[train_set].T)
        
    
        val_data_cv = copy.deepcopy(data)
        val_data_cv.edge_index = torch.tensor(train_val_data_X[valid_set].T)
        
        
        test_data_cv = copy.deepcopy(data)
        test_data_cv.edge_index = torch.tensor(train_val_data_X[test_index].T)
    
        
        # use RandomLinkSplit to get disjoint train ratio an other pyg transforms
        transform = T.RandomLinkSplit(
            num_val=0.0,
            num_test=0.0,
            disjoint_train_ratio=0.3236238313900354,
            neg_sampling_ratio=0.0,
            add_negative_train_samples=False,
            is_undirected=True,
        )
        train_cv, _, _ = transform(train_data_cv)
        
        transform = T.RandomLinkSplit(
            num_val=0.0,
            num_test=0.0,
            disjoint_train_ratio=0.99,
            neg_sampling_ratio=1.0,
            add_negative_train_samples=True,
            is_undirected=True
        )
        
        val_cv, _, _ = transform(val_data_cv)
       

        test_cv, _, _ = transform(test_data_cv)
        # Define seed edges:
        edge_label_index = train_cv.edge_label_index
        edge_label = train_cv.edge_label

        train_loader = LinkNeighborLoader(
            data=train_cv,
            num_neighbors=num_neighbors,
            neg_sampling_ratio=1.0,
            edge_label_index=edge_label_index,
            edge_label=edge_label,
            batch_size=batch_size,
            shuffle=True,
            # disjoint=True,
        )
        
        edge_label_index = val_cv.edge_label_index
        edge_label = val_cv.edge_label
        # num_neighbors is a dictionary, it uses the specified number for each edge type
        val_loader = LinkNeighborLoader(
            data=val_cv,
            num_neighbors=num_neighbors,
            edge_label_index=edge_label_index,
            edge_label=edge_label,
            batch_size=batch_size,
            shuffle=False,
        )
        
        
        
        edge_label_index = test_cv.edge_label_index
        edge_label = test_cv.edge_label

        test_loader = LinkNeighborLoader(
            data=test_cv,
            num_neighbors=num_neighbors,
            edge_label_index= edge_label_index,
            edge_label=edge_label,
            batch_size=batch_size,
            shuffle=False
        )
        yield train_loader, val_loader, test_loader


### Model

### Outer GNN

In [19]:
class HeteroMHGNN(nn.Module):
    def __init__(self, metadata, in_channels, hidden_dims, heads, use_edge_attr=None):
        super().__init__()
        
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleDict()
        self.skips = nn.ModuleDict()
        self.final_norms = nn.ModuleDict()
        
        # Define which edge types should use edge attributes
        if use_edge_attr is None:
            use_edge_attr = {edge_type: False for edge_type in metadata[1]}
        
        for i, (out_dim, head) in enumerate(zip(hidden_dims, heads)):
            conv_dict = {}
            for edge_type in metadata[1]:
                src, _, dst = edge_type
                if i == 0:
                    in_channels = in_channels
                else:
                    in_channels = hidden_dims[i-1] * heads[i-1]
                
                if use_edge_attr[edge_type]:
                    conv_dict[edge_type] = GATv2Conv(in_channels, out_dim, heads=head, add_self_loops=False, edge_dim=1)
                else:
                    conv_dict[edge_type] = GATv2Conv(in_channels, out_dim, heads=head, add_self_loops=False)
            
            self.convs.append(HeteroConv(conv_dict, aggr='sum'))
            
            for node_type in metadata[0]:
                self.norms[f'{node_type}_{i}'] = nn.LayerNorm(out_dim * head)
                if i == 0:
                    self.skips[f'{node_type}_{i}'] = Linear(in_channels, out_dim * head)
                else:
                    self.skips[f'{node_type}_{i}'] = Linear(hidden_dims[i-1] * heads[i-1], out_dim * head)
        
        self.node_types = metadata[0]
        for node_type in metadata[0]:
            self.final_norms[f'{node_type}'] = nn.LayerNorm(out_dim * head *len(heads))
        
        # Initialize skips with xavier init
        for skip in self.skips.values():
            nn.init.xavier_uniform_(skip.weight)
            nn.init.zeros_(skip.bias)

    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        x_repr_dict = {node_type: [] for node_type in self.node_types}
        # edge_attr_dict = {key: value.to(torch.float32) for key, value in edge_attr_dict.items()}

        
        for i, conv in enumerate(self.convs):
            skip_x = {}
            for node_type in self.node_types:
                skip_x[node_type] = self.skips[f'{node_type}_{i}'](x_dict[node_type])
            
            x_dict_new = conv(x_dict, edge_index_dict, edge_attr_dict)
            
            for node_type in self.node_types:
                # skip_x = self.skips[f'{node_type}_{i}'](x_dict[node_type])
                x = x_dict_new[node_type]
                x = self.norms[f'{node_type}_{i}'](x) + skip_x[node_type]
                x = self.norms[f'{node_type}_{i}'](x)
                x = F.elu(x)
                x_repr_dict[node_type].append(x)
                x_dict[node_type] = x
        
        # Concatenate all representations for each node type
        for node_type in self.node_types:
            x_repr_dict[node_type] = self.final_norms[f'{node_type}'](torch.cat(x_repr_dict[node_type], dim=1))
        
        return x_repr_dict




#### Edge Classifier

In [20]:
# Our final classifier applies the hammard-product between source and destination
# node embeddings to derive edge-level predictions:
class VanillaHomoClassifier(torch.nn.Module):
    def forward(self, x: Tensor, edge_label_index: Tensor) -> Tensor:
        # Convert node embeddings to edge-level representations:
        edge_feat_drug = x[edge_label_index[0]]
        edge_feat_se = x[edge_label_index[1]]

        # Apply hammard-product to get a prediction per supervision edge:
        return (edge_feat_drug * edge_feat_se).sum(dim=-1)

#### FinalModel

In [21]:
class HomoModel(torch.nn.Module):
    def __init__(self, gnn_model, classifier_model, num_nodes):
        super().__init__()
        self.inital_norm = nn.LayerNorm(384)
        # self.node_emb = torch.nn.Embedding(num_nodes, 384)
        
        self.gnn = gnn_model 
        # Instantiate classifier:
        self.classifier = classifier_model   
    
    def forward(self, data: Data) -> Tensor:
        x = self.inital_norm(data.x)
        x_dict = {"node": x}
        edge_index_dict = {('node', 'edge', 'node'):data.edge_index}
        edge_index_attr = {('node', 'edge', 'node'):data.edge_attr}
        x = self.gnn(x_dict, edge_index_dict, edge_index_attr)
        pred = self.classifier(
            x_dict["node"],
            data.edge_label_index,
        )

        return pred

### Train Utils

#### Train Loop

In [22]:
def do_train_compute(batch, device, model):
    # batch = batch.to(device)
    pred = model(batch)
    actual = batch.edge_label
    return pred, actual

# def do_train_compute(batch, device, model):
#     # batch = batch.to(device)
#     pred = model(batch)
#     actual = batch.edge_label
#     return pred, actual


def evaluate_metrics(probas_pred, ground_truth):
    # compute binary classification metrics using sklearn
    # convert to numpy array
    probas_pred = probas_pred.numpy()
    
    ground_truth = ground_truth.numpy()
    
    # convert to binary predictions
    binary_pred = np.where(probas_pred > 0.5, 1, 0)

    
    # compute metrics
    accuracy = accuracy_score(ground_truth, binary_pred)
    precision = precision_score(ground_truth, binary_pred)
    recall = recall_score(ground_truth, binary_pred)
    f1 = f1_score(ground_truth, binary_pred)
    roc_auc = roc_auc_score(ground_truth, probas_pred)
    precision_, recall_, _ = precision_recall_curve(ground_truth, probas_pred)
    pr_auc = auc(recall_, precision_)
    average_precision = average_precision_score(ground_truth, probas_pred)
    return accuracy, precision, recall, f1, roc_auc, pr_auc, average_precision

def train_loop(model, model_name, writer, train_loader, val_loader, loss_fn, optimizer, n_epochs, device, scheduler=None, early_stopping_patience=3, early_stopping_counter=0):
    early_stop = False
    best_val_metrics = -float("inf") #-float("inf")
    best_model_path = f"saved_models/{model_name}/best_model.pth"
    # make best_model_path parent directory if it doesn't exist
    os.makedirs(os.path.dirname(best_model_path), exist_ok=True)
    
    print("Starting training loop at", datetime.today().strftime("%Y-%m-%d %H:%M:%S"))
    
    total_train_val_steps = len(train_loader) + len(val_loader)
    epoch_progress_bar = tqdm.notebook.tqdm(range(1, (total_train_val_steps*n_epochs)+1), desc="MiniBatches")
    epoch = 0
    for _ in epoch_progress_bar:
        epoch += 1
        start_time = time.time()
        train_loss = 0
        val_loss = 0
        train_probas_pred = []
        train_ground_truth = []
        val_probas_pred = []
        val_ground_truth = []
        print("Epoch", epoch)
        
        model.train()
        for idx, batch in enumerate(train_loader):
            batch = batch.to(device)
            lr = optimizer.param_groups[0]['lr']
            optimizer.zero_grad()
            out, actual = do_train_compute(batch, device, model)
            pred = torch.sigmoid(out)
            train_probas_pred.append(pred.detach().cpu())
            train_ground_truth.append(actual.detach().cpu())
            loss = loss_fn(out, actual)
            loss.backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Adjust max_norm as needed

            optimizer.step()
            train_loss += loss.item()
            epoch_progress_bar.set_postfix_str(f"Epoch {epoch} - LR {lr:.7f} - Train Batch {idx+1}/{len(train_loader)} - Train loss: {train_loss/(idx+1):.4f}")
            epoch_progress_bar.update()
            writer.add_scalar("Training Loss MiniBatch", loss.item(), idx)
            batch = batch.to("cpu")
            # if scheduler is not None: # cosine annealing scheduler
            #     scheduler.step()
        
        train_loss /= len(train_loader)
        writer.add_scalar("Training Loss Epoch", train_loss, epoch)
        model.eval()
        with torch.no_grad():
            train_probas_pred = torch.cat(train_probas_pred, dim=0)
            train_ground_truth = torch.cat(train_ground_truth, dim=0)
            train_accuracy, train_precision, train_recall, train_f1, \
                train_roc_auc, train_pr_auc, train_average_precision = evaluate_metrics(train_probas_pred, train_ground_truth)
            writer.add_scalar("Training Accuracy", train_accuracy, epoch)
            writer.add_scalar("Training Precision", train_precision, epoch)
            writer.add_scalar("Training Recall", train_recall, epoch)
            writer.add_scalar("Training F1", train_f1, epoch)
            writer.add_scalar("Training ROC AUC", train_roc_auc, epoch)
            writer.add_scalar("Training PR AUC", train_pr_auc, epoch)
            writer.add_scalar("Training Average Precision", train_average_precision, epoch)

            for idx_, batch in enumerate(val_loader):
                batch = batch.to(device)
                out, actual = do_train_compute(batch, device, model)
                pred = torch.sigmoid(out)
                val_probas_pred.append(pred.detach().cpu())
                val_ground_truth.append(actual.detach().cpu())
                loss = loss_fn(out, actual)
                val_loss += loss.item()
                epoch_progress_bar.set_postfix_str(f"Epoch {epoch} - LR {lr:.7f} - Val Batch {idx_+1}/{len(val_loader)} - Val loss: {val_loss/(idx+1):.4f}")
                epoch_progress_bar.update()
                writer.add_scalar("Validation Loss MiniBatch", loss.item(), idx_)
                batch = batch.to("cpu")
            val_loss /= len(val_loader)
            val_probas_pred = torch.cat(val_probas_pred, dim=0)
            val_ground_truth = torch.cat(val_ground_truth, dim=0)
            val_accuracy, val_precision, val_recall, val_f1, \
                val_roc_auc, val_pr_auc, val_average_precision = evaluate_metrics(val_probas_pred, val_ground_truth)
            
            writer.add_scalar("Validation Loss Epoch", val_loss, epoch)
            writer.add_scalar("Validation Accuracy", val_accuracy, epoch)
            writer.add_scalar("Validation Precision", val_precision, epoch)
            writer.add_scalar("Validation Recall", val_recall, epoch)
            writer.add_scalar("Validation F1", val_f1, epoch)
            writer.add_scalar("Validation ROC AUC", val_roc_auc, epoch)
            writer.add_scalar("Validation PR AUC", val_pr_auc, epoch)
            writer.add_scalar("Validation Average Precision", val_average_precision, epoch)
            
            if val_f1 > best_val_metrics:
                best_val_metrics = val_f1
                early_stopping_counter = 0
                torch.save(model.state_dict(), best_model_path)
                print("New best model saved!") 
            else:
                early_stopping_counter += 1
                print("Early stopping counter:", early_stopping_counter)
                if early_stopping_counter >= early_stopping_patience:
                    print("Early stopping triggered!")
                    early_stop = True
        
        if scheduler is not None:
            scheduler.step(val_f1) #
      
        
        epoch_progress_bar.set_postfix_str("Train loss: {:.4f}, Train f1: {:.4f}, Train auc: {:.4f}, Train pr_auc: {:.4f},\
                                            Val loss: {:.4f}, Val f1: {:.4f}, Val auc: {:.4f}, Val pr_auc: {:.4f},\
                                            Best val f1: {:.4f}".format(train_loss, train_f1, train_roc_auc, train_pr_auc,\
                                            val_loss, val_f1, val_roc_auc, val_pr_auc, best_val_metrics))
        epoch_progress_bar.update()
        print("Epoch Number:", epoch)   
        print("Epoch time:", time.time() - start_time)
        print("Train loss:", train_loss)
        print("Train accuracy:", train_accuracy)
        print("Train precision:", train_precision)
        print("Train recall:", train_recall)
        print("Train f1:", train_f1)
        print("Train roc_auc:", train_roc_auc)
        print("Train pr_auc:", train_pr_auc)
        print("Train average_precision:", train_average_precision)
        
        print("Val loss:", val_loss)
        print("Val accuracy:", val_accuracy)
        print("Val precision:", val_precision)
        print("Val recall:", val_recall)
        print("Val f1:", val_f1)
        print("Val roc_auc:", val_roc_auc)
        print("Val pr_auc:", val_pr_auc)
        print("Val average_precision:", val_average_precision)
        print("Best val_f1:", best_val_metrics)
        print()
        if early_stop:
            break
        if epoch == n_epochs:
            print("Training completed!")
            break
    
    # load best model 
    model.load_state_dict(torch.load(best_model_path))
    return model

#### Test Evaluate Metrics

In [23]:
def mrank(y, y_pre):
    index = np.argsort(-y_pre)
    r_label = y[index]
    r_index = np.array(np.where(r_label == 1)) + 1
    reci_sum = np.sum(1 / r_index)
    reci_rank = np.mean(1 / r_index)
    return reci_sum

def evaluate_fold(loader, model, device, ret=False):
    preds = []
    ground_truths = []
    model.eval()
    for sampled_data in tqdm.tqdm(loader):
        with torch.no_grad():
            sampled_data.to(device)
            pred = model(sampled_data) 
            # Applying sigmoid activation function to the predicted values
            output_probs = torch.sigmoid(pred)

            preds.append(output_probs)
            ground_truths.append(sampled_data.edge_label)

    pred = torch.cat(preds, dim=0).cpu().numpy()
    pred_int = (pred>0.5).astype(int)
    ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()

    auc = roc_auc_score(ground_truth, pred)
    ap = average_precision_score(ground_truth, pred)
    mr = mrank(ground_truth, pred)
    f1 = f1_score(ground_truth, pred_int)
    mcc = matthews_corrcoef(ground_truth, pred_int)
    acc = (pred_int == ground_truth).mean()
    precision = precision_score(ground_truth, pred_int)
    recall = recall_score(ground_truth, pred_int)
    print()
    print(f"Test AUC: {auc:.4f}")
    print(f"Test AP: {ap:.4f}")
    print(f"Test F1: {f1:.4f}")
    print(f"Test Accuracy: {acc:.4f}")
    print(f"Test Precission: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"Test MCC: {mcc:.4f}")
    print(f"Test MR: {mr:.4f}")
    if ret:
        return auc, ap, f1, acc, precision, recall, mr, mcc


#### Train Wrap CV

In [24]:
homo_data

Data(edge_index=[2, 132063], node_id=[6594], node_type=[6594], edge_type=[132063], x=[6594, 384])

In [25]:
def train_wrap_cv(data, model_name_, cv_fold=10, shuffle=True, num_neighbors=[10, 4], batch_size=64, n_epochs=10, early_stopping_patience=5):
    eval_metrics = []
    for i, (train_loader_cv, val_loader_cv,  test_loader_cv) in enumerate(get_kfold_data(data, k=cv_fold,
                                                                shuffle=shuffle,
                                                                num_neighbors=num_neighbors, 
                                                                batch_size=batch_size)):
        print(f"Fold {i+1}")
        model_name = f"{model_name_}/fold{i+1}"
        # Define the log directory where TensorBoard logs will be stored
        log_dir = f"logs/{model_name}/" + datetime.now().strftime("%Y%m%d-%H%M%S")
        os.makedirs(log_dir, exist_ok=True)

        # Create a SummaryWriter
        writer = SummaryWriter(log_dir)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Device: '{device}'")

        # gnn_model = MHGNN(input_dim=384, hidden_dims=[64, 64, 64], heads=[2, 2, 2])

        # gnn_model = to_hetero(gnn_model, metadata=data.metadata())
        use_edge_attr = {('node', 'edge', 'node'): False}
        homo_metadata = (['node'], [('node', 'edge', 'node')])
        gnn_model = HeteroMHGNN(homo_metadata, in_channels=384, hidden_dims=[64, 64, 64], heads=[2, 2, 2], use_edge_attr=use_edge_attr)
        classifier_model = VanillaHomoClassifier()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = HomoModel(gnn_model=gnn_model,
                        classifier_model=classifier_model, num_nodes=6594)
        
        model = model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.00025062971034390006,  weight_decay=0.001)
        
        # optimizer = torch.optim.SGD(model.parameters(), lr=0.00025062971034390006, weight_decay=0.001)
        # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 
        #                                                                  T_0=len(train_loader_cv), 
        #                                                                  T_mult=1, eta_min=1e-5, 
        #                                                                  verbose=False)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="max", factor=0.5, patience=2, min_lr=1e-6
            )
        # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.96 ** (epoch))
    
        print(f"Total Number of Parameters: {sum(p.numel() for p in model.parameters())}")
        print(f"Total Number of Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
        
        model = train_loop(model, model_name, writer, train_loader_cv, val_loader_cv,
                             F.binary_cross_entropy_with_logits, optimizer, n_epochs=n_epochs, 
                             device=device, scheduler=scheduler, early_stopping_patience=early_stopping_patience)

        # load best model and store evaluation metrics
        model.load_state_dict(torch.load(f'saved_models/{model_name}/best_model.pth'))
        auc, ap, f1, acc, precision, recall, mr, mcc = evaluate_fold(test_loader_cv, model, torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), ret=True)
        eval_metrics.append([auc, ap, f1, acc, precision, recall, mr, mcc])
        gc.collect()
        torch.cuda.empty_cache()

    return eval_metrics

### Run Train CV 

In [26]:
gc.collect()
torch.cuda.empty_cache()

In [28]:
# ! rm -rf logs/* saved_models/*

In [29]:
eval_metrics = train_wrap_cv(homo_data, "sdv-hgnn-variant-homo", cv_fold=3, shuffle=True, 
                             num_neighbors=[10, 4], batch_size=64)

Fold 1


2024-08-05 16:12:47.433748: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-08-05 16:12:47.547455: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-08-05 16:12:47.991116: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2024-08-05 16:12:47.991396: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not l

Device: 'cuda'
Total Number of Parameters: 249984
Total Number of Trainable Parameters: 249984
Starting training loop at 2024-08-05 16:12:50


MiniBatches:   0%|          | 0/6740 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 11.21910834312439
Train loss: 2.7331895480429442
Train accuracy: 0.6133686919897043
Train precision: 0.577590348582715
Train recall: 0.8439279307386319
Train f1: 0.6858084553463903
Train roc_auc: 0.6971765095740495
Train pr_auc: 0.687608227563087
Train average_precision: 0.6639685847950041
Val loss: 1.2457727996748447
Val accuracy: 0.447854520422212
Val precision: 0.47171572593191863
Val recall: 0.8696649839375861
Val f1: 0.61166027839419
Val roc_auc: 0.4225265022458736
Val pr_auc: 0.47782948429123806
Val average_precision: 0.4778991467296386
Best val_f1: 0.61166027839419

Epoch 2
Early stopping counter: 1
Epoch Number: 2
Epoch time: 10.18732738494873
Train loss: 1.1962706754332468
Train accuracy: 0.6812261134076905
Train precision: 0.6240987021310688
Train recall: 0.9113953669760549
Train f1: 0.7408698960182603
Train roc_auc: 0.7456062402215877
Train pr_auc: 0.69300839604029
Train average_precision: 0.6900671400854346
Val loss:

100%|██████████| 1362/1362 [00:05<00:00, 234.88it/s]



Test AUC: 0.3420
Test AP: 0.4871
Test F1: 0.5163
Test Accuracy: 0.3766
Test Precission: 0.4218
Test Recall: 0.6655
Test MCC: -0.3023
Test MR: 9.8770
Fold 2
Device: 'cuda'
Total Number of Parameters: 249984
Total Number of Trainable Parameters: 249984
Starting training loop at 2024-08-05 16:13:58


MiniBatches:   0%|          | 0/6740 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 10.759796142578125
Train loss: 2.8548794336747054
Train accuracy: 0.6154356134466891
Train precision: 0.5787862656374767
Train recall: 0.8480227751345449
Train f1: 0.6880022780484718
Train roc_auc: 0.6943458668849813
Train pr_auc: 0.6842439714341574
Train average_precision: 0.6589796954367474
Val loss: 1.740153941206443
Val accuracy: 0.4471661312528683
Val precision: 0.47142768505304955
Val recall: 0.8717301514456173
Val f1: 0.6119276768815688
Val roc_auc: 0.3954311170122583
Val pr_auc: 0.45429998211703115
Val average_precision: 0.4543304756151678
Best val_f1: 0.6119276768815688

Epoch 2
Early stopping counter: 1
Epoch Number: 2
Epoch time: 11.508731365203857
Train loss: 1.2372667997555244
Train accuracy: 0.6718274705561189
Train precision: 0.6154007333682556
Train recall: 0.9163091802511505
Train f1: 0.7362978283350569
Train roc_auc: 0.7291389197782203
Train pr_auc: 0.674632308321145
Train average_precision: 0.6713832763348493


100%|██████████| 1362/1362 [00:06<00:00, 209.35it/s]



Test AUC: 0.3067
Test AP: 0.4599
Test F1: 0.5245
Test Accuracy: 0.3810
Test Precission: 0.4258
Test Recall: 0.6829
Test MCC: -0.2986
Test MR: 9.6716
Fold 3
Device: 'cuda'
Total Number of Parameters: 249984
Total Number of Trainable Parameters: 249984
Starting training loop at 2024-08-05 16:15:14


MiniBatches:   0%|          | 0/6740 [00:00<?, ?it/s]

Epoch 1
New best model saved!
Epoch Number: 1
Epoch time: 10.446845054626465
Train loss: 2.7653167259663416
Train accuracy: 0.616683566024491
Train precision: 0.5798931909212283
Train recall: 0.8469308166289681
Train f1: 0.6884232549293096
Train roc_auc: 0.69648837553673
Train pr_auc: 0.6839607214744425
Train average_precision: 0.6601892782516559
Val loss: 1.404423611340942
Val accuracy: 0.453132170720514
Val precision: 0.4746131377788826
Val recall: 0.8762046810463515
Val f1: 0.6157133067279398
Val roc_auc: 0.4152265902927123
Val pr_auc: 0.4655298276074895
Val average_precision: 0.4656009232402104
Best val_f1: 0.6157133067279398

Epoch 2
Early stopping counter: 1
Epoch Number: 2
Epoch time: 11.460914850234985
Train loss: 1.382160448225359
Train accuracy: 0.6802706497153108
Train precision: 0.6250169033130494
Train recall: 0.9012557522814133
Train f1: 0.7381382052797164
Train roc_auc: 0.7391693600667744
Train pr_auc: 0.6885370844276095
Train average_precision: 0.6845510138536021
Val lo

100%|██████████| 1362/1362 [00:07<00:00, 194.50it/s]



Test AUC: 0.3172
Test AP: 0.4696
Test F1: 0.4944
Test Accuracy: 0.3610
Test Precission: 0.4090
Test Recall: 0.6248
Test MCC: -0.3272
Test MR: 9.0816


### CV Performance

In [30]:
metrics = ['auc', 'ap', 'f1', 'acc', 'precision', 'recall', 'mr', 'mcc']
metrics_mean_value = np.mean(eval_metrics, axis=0)
metrics_std = np.std(eval_metrics, axis=0)
df = pd.DataFrame({
    'Metric': metrics,
    'Mean': metrics_mean_value,
    'Standard Deviation': metrics_std
})
df

Unnamed: 0,Metric,Mean,Standard Deviation
0,auc,0.321979,0.014771
1,ap,0.472182,0.011243
2,f1,0.51176,0.012723
3,acc,0.372881,0.008572
4,precision,0.418875,0.007154
5,recall,0.657741,0.024335
6,mr,9.543407,0.337145
7,mcc,-0.309364,0.012701
