## Graph Neural Network for Molecular Interaction Prediction

This Jupyter Notebook outlines the process for training a Graph Neural Network (GNN) model to predict molecular interactions using the GATv2 architecture. The goal of this project is to leverage the inherent graph structure of molecules for effective prediction of binding to RNA, a crucial factor in drug discovery and biological research.

Each run of this notebook is a distinct experiment with its own hyperparameters and configuration. Results and models from each run are saved separately so that runs can be compared and the variability of the results across random seeds can be assessed.

### Notebook Details:

- **Objective**: Predict molecular interactions with RNA using a GNN.
- **Model Architecture**: GATv2Conv from the Deep Graph Library (DGL).
- **Data Source**: Preprocessed molecular interaction datasets.
- **Run Number**: This notebook facilitates multiple runs. Specific details for each run, including the random state and run number, are set at the beginning to ensure reproducibility.

Before executing the notebook, please adjust the `RANDOM_STATE` and `RUN_NUMBER` variables at the top of the notebook to reflect the specific experiment being conducted. This setup ensures each run's outputs are unique and traceable.
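
One way to keep per-run outputs unique is to embed the run number in every file name. The sketch below is only an illustration: the `RunConfig` class and its `artifact` method are hypothetical helpers, not part of this notebook, and the values mirror the defaults set in the first code cell.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RunConfig:
    """Hypothetical container for the two per-run settings."""
    random_state: int
    run_number: int

    def artifact(self, stem: str, ext: str) -> str:
        # Embed the run number so files from different runs never collide.
        return f"{stem}_run_{self.run_number}.{ext}"


config = RunConfig(random_state=420, run_number=1)
print(config.artifact("gatv2_best_hyperparams", "json"))
# gatv2_best_hyperparams_run_1.json
```

Changing `run_number` changes every derived file name, which is what makes each run's outputs traceable.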


In [1]:
"""
This script implements a Graph Neural Network (GNN) using the GATv2 architecture
for the purpose of predicting molecular interactions. The implementation leverages
the Deep Graph Library (DGL) for constructing and manipulating graphs, as well as
Optuna for hyperparameter optimization. The model includes features such as dropout,
early stopping, and gradient scaling for improved training stability and performance.
"""



In [2]:
# Set the random seed and run number at the top for reproducibility and to differentiate runs
RANDOM_STATE = 420  # Change for each run if needed
RUN_NUMBER = 1  # Change for each run

import numpy as np
import torch
import dgl

np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)
dgl.seed(RANDOM_STATE)
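
The effect of fixing a seed can be illustrated with the standard library alone. This is a minimal sketch using Python's `random` module as a stand-in; the function `draw` is hypothetical, and the NumPy, PyTorch, and DGL generators seeded above are separate generators that follow the same principle.

```python
import random


def draw(seed: int, n: int = 5) -> list:
    """Return n pseudo-random integers from a generator seeded with `seed`."""
    rng = random.Random(seed)
    return [rng.randint(0, 99) for _ in range(n)]


first = draw(420)
second = draw(420)
# The same seed reproduces the same sequence.
print(first == second)  # True
```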


## Import Necessary Libraries

In [3]:
import os
import sys
import json
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import dgl
from dgl.nn import GATv2Conv, GlobalAttentionPooling
from dgl.dataloading import GraphDataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
import optuna
from optuna.pruners import MedianPruner
from torch.cuda.amp import GradScaler, autocast


## The Graph Neural Network Model



In [4]:
class GraphClsGATv2(nn.Module):
    """
    A Graph Neural Network (GNN) model using the GATv2 architecture for graph 
    classification.
    
    Parameters
    ----------
    in_feats : int
        The number of input features.
    hidden_dim : int
        The hidden feature size per attention head; also the size of the
        graph-level embedding passed to the classifier.
    num_heads : int
        The number of attention heads.
    feat_drop : float
        The dropout rate for the input features.
    attn_drop : float
        The dropout rate for the attention scores.
    negative_slope : float
        The negative slope for the LeakyReLU activation function.
    num_cls : int
        The number of classes for classification.
    """
    def __init__(self, in_feats, hidden_dim, num_heads, feat_drop, attn_drop, 
                 negative_slope, num_cls):
        super(GraphClsGATv2, self).__init__()
        self.gatv2_1 = GATv2Conv(in_feats,
                                 hidden_dim,
                                 num_heads,
                                 feat_drop=feat_drop,
                                 attn_drop=attn_drop,
                                 negative_slope=negative_slope,
                                 activation=F.elu,
                                 residual=True)
        # The second GATv2 layer
        self.gatv2_2 = GATv2Conv((hidden_dim * num_heads),
                                 hidden_dim,  
                                 1,  # Single head in the second layer
                                 feat_drop=feat_drop,
                                 attn_drop=attn_drop,
                                 negative_slope=negative_slope,
                                 activation=F.elu,
                                 residual=True)
        # The global attention pooling layer
        self.pooling = GlobalAttentionPooling(nn.Linear(hidden_dim, 1))
        # The fully connected layer for classification
        self.fc = nn.Linear(hidden_dim, num_cls)
        self.dropout = nn.Dropout(feat_drop)

    def forward(self, graph, feat):
        # Apply the first GATv2 layer
        h = self.gatv2_1(graph, feat)
        # Concatenate the attention heads:
        # (N, num_heads, hidden_dim) -> (N, num_heads * hidden_dim)
        h = h.reshape(h.shape[0], -1)
        h = self.dropout(h)
        # The single-head output has shape (N, 1, hidden_dim); drop the head
        # axis explicitly so that a batch of one graph keeps its batch axis
        h = self.gatv2_2(graph, h).squeeze(1)
        # Aggregate node features to graph-level features: (B, hidden_dim)
        hg = self.pooling(graph, h)
        # Classify based on the graph-level representation
        return self.fc(hg)

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        for layer in self.children():
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()
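
The shape bookkeeping in `forward` can be traced without any tensors. The sketch below assumes DGL's documented output convention for `GATv2Conv`, `(N, num_heads, out_feats)`; the function `trace_shapes` and the example sizes are hypothetical and only illustrate the arithmetic.

```python
def trace_shapes(num_nodes, num_graphs, hidden_dim, num_heads, num_cls):
    """Trace tensor shapes through the model (bookkeeping only, no tensors)."""
    return {
        # First layer: one feature slice per attention head
        "gatv2_1": (num_nodes, num_heads, hidden_dim),
        # Heads concatenated before the second layer
        "flattened": (num_nodes, num_heads * hidden_dim),
        # Second layer uses a single head
        "gatv2_2": (num_nodes, 1, hidden_dim),
        # One vector per graph once the singleton head axis is dropped
        "pooled": (num_graphs, hidden_dim),
        "logits": (num_graphs, num_cls),
    }


shapes = trace_shapes(
    num_nodes=120, num_graphs=4, hidden_dim=64, num_heads=8, num_cls=2)
print(shapes["flattened"])  # (120, 512)
```

This is why the second layer is constructed with `hidden_dim * num_heads` input features.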
                

## Early Stopping Mechanism



In [5]:
class EarlyStopping:
    """Stops training early if the validation loss does not improve
    within a given patience.
    
    Parameters
    ----------
    patience : int
        How long to wait after last time validation loss improved.
    verbose : bool
        If True, prints a message for each validation loss improvement.
    delta : float
        Minimum change in the monitored quantity to qualify as an improvement.
    path : str
        The file path where the model will be saved.
    print_freq : int
        Print the counter message every `print_freq` consecutive epochs
        without improvement.
    """

    def __init__(
            self,
            patience=7,
            verbose=True,
            delta=0.001,
            path='checkpoint.pt',
            print_freq=5):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
        self.delta = delta
        self.path = path
        self.best_epoch = 0
        self.print_freq = print_freq
        
    def __call__(self, val_loss, model, epoch):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.best_epoch = epoch
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose and self.counter % self.print_freq == 0:
                print(
                    f'EarlyStopping counter: {self.counter} '
                    f'out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
            self.best_epoch = epoch

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(
                f'Validation loss decreased ({self.val_loss_min:.6f} '
                f'--> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
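
The patience logic can be traced on a made-up sequence of validation losses. The sketch below reproduces only the comparison rule of `EarlyStopping.__call__`, without checkpointing; the function `early_stop_epoch` and the loss values are hypothetical.

```python
def early_stop_epoch(val_losses, patience, delta=0.0):
    """Return (stop_epoch, best_epoch), both 1-based.

    stop_epoch is None if the patience is never exhausted.
    """
    best_score = None
    best_epoch = 0
    counter = 0
    for epoch, loss in enumerate(val_losses, start=1):
        score = -loss  # higher is better, as in the class above
        if best_score is None or score >= best_score + delta:
            # Improvement by at least delta: record it and reset the counter
            best_score, best_epoch, counter = score, epoch, 0
        else:
            counter += 1
            if counter >= patience:
                return epoch, best_epoch
    return None, best_epoch


losses = [0.90, 0.80, 0.81, 0.79, 0.795, 0.80]
print(early_stop_epoch(losses, patience=2, delta=0.001))  # (6, 4)
```

The loss improves at epochs 2 and 4, then fails to improve for two consecutive epochs, so training stops at epoch 6 with epoch 4 recorded as the best.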
        

## The Collate Function for DataLoader



In [6]:
def collate(samples):
    """
    Function to collate samples into a batch for the GraphDataLoader.
    
    Parameters
    ----------
    samples : list
        A list of tuples of the form (graph, label).
    """
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    labels = torch.tensor(labels, dtype=torch.long)
    return batched_graph, labels
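
The `zip(*samples)` idiom in `collate` turns a list of `(graph, label)` pairs into two parallel lists. A minimal sketch, using plain strings as stand-ins for DGL graphs (the function `split_samples` is hypothetical):

```python
def split_samples(samples):
    """Separate (item, label) pairs into a list of items and a list of labels."""
    items, labels = map(list, zip(*samples))
    return items, labels


# Strings stand in for DGL graph objects here.
samples = [("graph_a", 1), ("graph_b", 0), ("graph_c", 1)]
items, labels = split_samples(samples)
print(items)   # ['graph_a', 'graph_b', 'graph_c']
print(labels)  # [1, 0, 1]
```

In the real function, the list of graphs is then merged into one batched graph and the labels become a single tensor.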


## Training and Evaluation Pipeline


In [7]:
class TrainingPipeline:
    def __init__(self, device):
        self.device = device

    def train_and_evaluate(
            self,
            model,
            train_loader,
            val_loader,
            optimizer,
            criterion,
            early_stopping,
            num_epochs,
            plot_curves=False,
            accumulation_steps=8):
        train_losses, val_losses = [], []
        scaler = GradScaler()  # Initialize the gradient scaler

        for epoch in range(num_epochs):
            model.train()
            train_loss = 0.0
            optimizer.zero_grad()  # Initialize gradients to zero

            for batch_idx, (batched_graph, labels) in enumerate(train_loader):
                batched_graph, labels = batched_graph.to(
                    self.device), labels.to(self.device)

                with autocast():  # Enable automatic mixed precision
                    logits = model(
                        batched_graph, batched_graph.ndata['h'].float())
                    loss = criterion(logits, labels) / \
                        accumulation_steps  # Scale loss

                # Scale the loss and call backward to propagate gradients
                scaler.scale(loss).backward()
                # Correct scaling for logging purposes
                train_loss += loss.item() * accumulation_steps

                if (batch_idx + 1) % accumulation_steps == 0 or \
                        batch_idx == len(train_loader) - 1:
                    # Perform optimizer step using scaled gradients
                    scaler.step(optimizer)
                    scaler.update()  # Update the scaler for the next iteration
                    optimizer.zero_grad()  # Initialize gradients to zero

            train_loss /= len(train_loader)
            train_losses.append(train_loss)

            val_loss = 0.0
            val_correct = 0
            total = 0
            if val_loader is not None:
                model.eval()
                with torch.no_grad():
                    for batched_graph, labels in val_loader:
                        batched_graph, labels = batched_graph.to(
                            self.device), labels.to(self.device)
                        with autocast():  # Enable automatic mixed precision
                            logits = model(
                                batched_graph, batched_graph.ndata['h'].float()
                            )
                            loss = criterion(logits, labels)
                        val_loss += loss.item()
                        _, predicted = torch.max(logits.data, 1)
                        total += labels.size(0)
                        val_correct += (predicted == labels).sum().item()

                    val_loss /= len(val_loader)
                    val_losses.append(val_loss)
                    val_accuracy = val_correct / total

                    if early_stopping:
                        early_stopping(val_loss, model, epoch + 1)
                        if early_stopping.early_stop:
                            print(
                                f"Early stopping triggered "
                                f"at epoch {epoch + 1}")
                            break

                if (epoch + 1) % 5 == 0 or epoch == 0:
                    print(
                        f'Epoch {epoch + 1}/{num_epochs} - '
                        f'Train Loss: {train_loss:.4f}, '
                        f'Val Loss: {val_loss:.4f} '
                        f'| Val accuracy: {100 * val_accuracy:.2f}%')

        if plot_curves and val_loader is not None:
            self.plot_loss_curves(train_losses, val_losses)

    @staticmethod
    def plot_loss_curves(train_losses, val_losses):
        sns.set(style="whitegrid")
        plt.figure(figsize=(10, 6))
        epochs = range(1, len(train_losses) + 1)
        plt.plot(epochs, train_losses, label='Training Loss')
        plt.plot(epochs, val_losses, label='Validation Loss')
        plt.title('Training and Validation Loss of GATv2Conv')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.savefig(f'loss_curves_{RUN_NUMBER}.png', dpi=300)
        plt.show()

    def evaluate_on_test(self, model, test_loader, criterion, run_id):
        model.eval()
        test_loss = 0.0
        test_accuracy = 0.0
        all_preds = []
        all_probs = []
        all_labels = []
        with torch.no_grad():
            for batched_graph, labels in test_loader:
                batched_graph, labels = batched_graph.to(
                    self.device), labels.to(self.device)
                logits = model(batched_graph, batched_graph.ndata['h'].float())
                loss = criterion(logits, labels)
                test_loss += loss.item()
                preds = torch.argmax(logits, dim=1)
                test_accuracy += torch.sum(preds == labels).item()
                # Keep the positive-class probability for ROC-AUC
                # (assumes binary classification with class 1 as positive)
                probs = torch.softmax(logits, dim=1)[:, 1]
                all_preds.extend(preds.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        test_loss /= len(test_loader)
        test_accuracy /= len(test_loader.dataset)

        # Calculate ROC-AUC from predicted probabilities, not hard labels
        roc_auc = roc_auc_score(all_labels, all_probs)
        
        # Calculate and save confusion matrix
        cm = confusion_matrix(all_labels, all_preds)
        disp = ConfusionMatrixDisplay(confusion_matrix=cm)
        disp.plot()
        plt.savefig(f'confusion_matrix_{run_id}.png', dpi=300)
        
        # Append results to CSV including ROC-AUC
        results_df = pd.DataFrame({
            'Run ID': [run_id],
            'Test Loss': [test_loss],
            'Test Accuracy': [test_accuracy],
            'ROC-AUC': [roc_auc]
        })
        results_df.to_csv('test_results.csv', mode='a', index=False, 
                          header=not os.path.exists('test_results.csv'))
        
        print(f"Test Loss: {test_loss}")
        print(f"Test Accuracy: {test_accuracy}")
        print(f"ROC-AUC: {roc_auc}")
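
The gradient accumulation schedule in `train_and_evaluate` can be checked in isolation. The sketch below reuses the same stepping condition to list the batches after which the optimizer would step; the function `optimizer_step_batches` and the batch count of 20 are hypothetical.

```python
def optimizer_step_batches(num_batches, accumulation_steps):
    """Return the 0-based batch indices after which the optimizer steps."""
    return [
        batch_idx
        for batch_idx in range(num_batches)
        if (batch_idx + 1) % accumulation_steps == 0
        or batch_idx == num_batches - 1  # flush the final partial group
    ]


steps = optimizer_step_batches(num_batches=20, accumulation_steps=8)
print(steps)  # [7, 15, 19]
```

With 20 batches and 8 accumulation steps, the last update covers only 4 batches. Because each loss is still divided by `accumulation_steps`, that final update is scaled down relative to the full groups, which is worth keeping in mind when the number of batches is small.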


## Hyperparameter Optimization Using Optuna



In [8]:
class HyperparameterOptimizer:
    def __init__(
            self,
            device,
            subset_train_graphs,
            subset_train_labels,
            subset_val_graphs,
            subset_val_labels,
            num_trials,
            num_epochs):
        self.device = device
        self.subset_train_graphs = subset_train_graphs
        self.subset_train_labels = subset_train_labels
        self.subset_val_graphs = subset_val_graphs
        self.subset_val_labels = subset_val_labels
        self.num_trials = num_trials
        self.num_epochs = num_epochs

    def objective(self, trial):
        # Adjusting the hyperparameters for GATv2Conv
        in_feats = 74  
        hidden_dim = trial.suggest_int('hidden_dim', 16, 256)
        num_heads = trial.suggest_categorical(
            'num_heads',
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20])
        feat_drop = trial.suggest_float('feat_drop', 0.0, 0.5)
        attn_drop = trial.suggest_float('attn_drop', 0.0, 0.5)
        negative_slope = trial.suggest_float('negative_slope', 0.01, 0.2)
        lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
        batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128,
                                                                256, 512])

        # Create the model, optimizer, and loaders
        model = GraphClsGATv2(
            in_feats=in_feats,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            feat_drop=feat_drop,
            attn_drop=attn_drop,
            negative_slope=negative_slope,
            num_cls=2,
        ).to(self.device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        train_loader = GraphDataLoader(
            list(zip(self.subset_train_graphs, self.subset_train_labels)),
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate,
            num_workers=3)
        val_loader = GraphDataLoader(
            list(zip(self.subset_val_graphs, self.subset_val_labels)),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=collate,
            num_workers=3)

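        # Training loop with pruning
        for epoch in range(self.num_epochs):
            # Re-enable training mode every epoch, because the validation
            # phase below switches the model to eval mode
            model.train()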
            for batched_graph, labels in train_loader:
                batched_graph, labels = batched_graph.to(
                    self.device), labels.to(self.device)
                optimizer.zero_grad()
                logits = model(batched_graph, batched_graph.ndata['h'].float())
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()

            # Validation phase and report for pruning
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for batched_graph, labels in val_loader:
                    batched_graph, labels = batched_graph.to(
                        self.device), labels.to(self.device)
                    logits = model(
                        batched_graph, batched_graph.ndata['h'].float())
                    loss = criterion(logits, labels)
                    val_loss += loss.item()

            val_loss /= len(val_loader)
            # Report intermediate value to the pruner
            trial.report(val_loss, epoch)

            if trial.should_prune():  # Handle pruning based on the 
                                      # intermediate value
                raise optuna.TrialPruned()

        return val_loss

    def optimize(self):
        """Run the hyperparameter optimization.

        Returns
        -------
        dict
            The best hyperparameters found by the optimization.
        """
        study = optuna.create_study(direction='minimize',
                                    pruner=MedianPruner())
        study.optimize(self.objective, n_trials=self.num_trials)

        best_hyperparams = study.best_trial.params
        with open(f'gatv2_best_hyperparams_run_{RUN_NUMBER}.json', 'w') as f:
            json.dump(best_hyperparams, f)
        print(f"Best hyperparameters are {best_hyperparams}.")
        print("Best hyperparameters saved.")
        return best_hyperparams
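
The idea behind median pruning can be sketched in a few lines. This is a simplified illustration, not Optuna's implementation: the function `should_prune`, the `n_startup_trials` default of 5, and the loss values are assumptions, and the real `MedianPruner` has further options such as warm-up steps.

```python
from statistics import median


def should_prune(current_value, previous_values_at_step, n_startup_trials=5):
    """Simplified median rule for a study that minimizes its objective."""
    # Never prune until enough earlier trials exist to form a baseline.
    if len(previous_values_at_step) < n_startup_trials:
        return False
    # Prune when the current trial is worse than the median of earlier trials.
    return current_value > median(previous_values_at_step)


# Made-up validation losses reported by earlier trials at the same epoch.
previous = [0.62, 0.58, 0.65, 0.60, 0.70]
print(should_prune(0.66, previous))      # True
print(should_prune(0.59, previous))      # False
print(should_prune(0.66, previous[:3]))  # False
```

A trial that reports a loss above the median at a given epoch is unlikely to end up best, so stopping it early frees time for other trials.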
        

## Main Training Loop



In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [10]:
# Load data and prepare for training
reloaded_df = pd.read_csv("data_mvi/combined_df.csv")
graphs, labels_dict = dgl.load_graphs("data_mvi/graphs.bin")
labels = reloaded_df['binds_to_rna'].values

# Split dataset train, test
train_indices, test_indices, train_labels, test_labels = train_test_split(
    range(len(reloaded_df)), labels, test_size=0.2, stratify=labels,
    random_state=42)

# Split dataset train, validation
train_indices, val_indices, train_labels, val_labels = train_test_split(
    train_indices, train_labels, test_size=0.2, stratify=train_labels,
    random_state=42)

# Select the graphs that belong to each split
train_graphs = [graphs[i] for i in train_indices]
test_graphs = [graphs[i] for i in test_indices]
val_graphs = [graphs[i] for i in val_indices]

subset_train_indices = np.random.choice(
    len(train_graphs), size=int(len(train_graphs) * 0.3), replace=False)
subset_train_graphs = [train_graphs[i] for i in subset_train_indices]
subset_train_labels = train_labels[subset_train_indices]

subset_val_indices = np.random.choice(
    len(val_graphs), size=int(len(val_graphs) * 0.3), replace=False)
subset_val_graphs = [val_graphs[i] for i in subset_val_indices]
subset_val_labels = val_labels[subset_val_indices]

# Combine train and validation graphs and labels for retraining
combined_train_graphs = train_graphs + val_graphs
combined_train_labels = np.concatenate((train_labels, val_labels))

# Announce the start of the project
print("Starting the project...")
print("")

# Announce the start of the data loading
print("Starting data loading...")
print(
    f'Train: {len(train_graphs)}, Validation: {len(val_graphs)}, '
    f'Test: {len(test_graphs)}, \nSubset Train: {len(subset_train_graphs)}, '
    f'Subset Val: {len(subset_val_graphs)}'
)
print("")
print("Completed data loading.")
print("")
sys.stdout.flush()  # Force flushing of the buffer


Starting the project...

Starting data loading...
Train: 47275, Validation: 11819, Test: 14774, 
Subset Train: 14182, Subset Val: 3545

Completed data loading.
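
The printed split sizes can be reproduced with simple arithmetic. The sketch below assumes a total of 73,868 samples (the sum of the train, validation, and test sizes printed above) and assumes that a fractional `test_size` is rounded up, which is how scikit-learn's `train_test_split` behaves as far as I know; the function `split_sizes` is hypothetical.

```python
import math


def split_sizes(n_total, test_frac=0.2, val_frac=0.2, subset_frac=0.3):
    """Compute split sizes for a train/test split followed by a train/val split."""
    n_test = math.ceil(test_frac * n_total)      # held-out size rounds up
    n_train_full = n_total - n_test
    n_val = math.ceil(val_frac * n_train_full)
    n_train = n_train_full - n_val
    return {
        "train": n_train,
        "val": n_val,
        "test": n_test,
        # int() truncates, matching the subset sampling code above
        "subset_train": int(n_train * subset_frac),
        "subset_val": int(n_val * subset_frac),
    }


sizes = split_sizes(73868)
print(sizes)
# {'train': 47275, 'val': 11819, 'test': 14774,
#  'subset_train': 14182, 'subset_val': 3545}
```

The effective proportions are therefore roughly 64% train, 16% validation, and 20% test, with hyperparameter search running on 30% of the train and validation splits.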



## Perform Hyperparameter Optimization


In [None]:
# 1. Hyperparameter Optimization on a subset of the data
print("Starting hyperparameter optimization...")
sys.stdout.flush()
print("")

# Specify the number of trials and epochs for hyperparameter optimization
optimizer = HyperparameterOptimizer(
    device,
    subset_train_graphs,
    subset_train_labels,
    subset_val_graphs,
    subset_val_labels,
    num_trials=50,
    num_epochs=30)
optimizer.optimize()
print("Completed hyperparameter optimization.")
sys.stdout.flush()

print("")


In [None]:
# Load the best hyperparameters
with open(f'gatv2_best_hyperparams_run_{RUN_NUMBER}.json', 'r') as f:
    best_hyperparams = json.load(f)

train_loader = GraphDataLoader(list(zip(train_graphs,
                                        train_labels)),
                                batch_size=best_hyperparams['batch_size'],
                                shuffle=True,
                                collate_fn=collate,
                                num_workers=8)
val_loader = GraphDataLoader(list(zip(val_graphs,
                                        val_labels)),
                                batch_size=best_hyperparams['batch_size'],
                                shuffle=False,
                                collate_fn=collate,
                                num_workers=8)
test_loader = GraphDataLoader(list(zip(test_graphs,
                                        test_labels)),
                                batch_size=best_hyperparams['batch_size'],
                                shuffle=False,
                                collate_fn=collate,
                                num_workers=8)
combined_train_loader = GraphDataLoader(
    list(
        zip(
            combined_train_graphs,
            combined_train_labels)),
    batch_size=best_hyperparams['batch_size'],
    shuffle=True,
    collate_fn=collate,
    num_workers=8)
print("Data loaders created.")
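
The save-and-reload round trip for hyperparameters can be tested on its own. In this sketch the parameter values are made up and a temporary directory replaces the working directory; the file name is built with an f-string so that the run number is actually substituted into it.

```python
import json
import os
import tempfile

RUN_NUMBER = 1  # plays the same role as the notebook variable
params = {"hidden_dim": 64, "num_heads": 4, "lr": 0.001}  # made-up values

with tempfile.TemporaryDirectory() as tmp_dir:
    # The f prefix is required; without it the braces stay in the name.
    path = os.path.join(
        tmp_dir, f"gatv2_best_hyperparams_run_{RUN_NUMBER}.json")
    with open(path, "w") as f:
        json.dump(params, f)
    with open(path, "r") as f:
        reloaded = json.load(f)
    saved_name = os.path.basename(path)

print(saved_name)          # gatv2_best_hyperparams_run_1.json
print(reloaded == params)  # True
```

JSON preserves the distinction between integers and floats, so values such as `hidden_dim` and `batch_size` come back as `int` and can be passed straight to the model and data loaders.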


In [None]:
# 2. Retraining with best hyperparameters (on a larger train and val set)
print("Retraining with best hyperparameters...")
model = GraphClsGATv2(
    in_feats=74,  # Adjust this based on your dataset
    hidden_dim=best_hyperparams['hidden_dim'],
    num_heads=best_hyperparams['num_heads'],
    feat_drop=best_hyperparams['feat_drop'],
    attn_drop=best_hyperparams['attn_drop'],
    negative_slope=best_hyperparams['negative_slope'],
    num_cls=2,  # Adjust based on your dataset
).to(device)

print("")

# Reset model weights and biases parameters before retraining
model.reset_parameters()

optimizer = optim.Adam(model.parameters(), lr=best_hyperparams['lr'])
criterion = nn.CrossEntropyLoss()

early_stopping = EarlyStopping(patience=35, verbose=True)
training_pipeline = TrainingPipeline(device)
training_pipeline.train_and_evaluate(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    early_stopping,
    500,
    plot_curves=True)
optimal_epoch = early_stopping.best_epoch

# Before final training on the combined train and val dataset, reset the
# model weights and biases again
model.reset_parameters()
print("Completed training.")
print("")


## Train Model with Best Hyperparameters



In [None]:
# 3. Final training on the combined train and val dataset
print("Final training on the combined train and val dataset...")
optimizer = optim.Adam(model.parameters(), lr=best_hyperparams['lr'])
criterion = nn.CrossEntropyLoss()
# No early stopping in this stage: there is no validation set left, so the
# model trains for the optimal number of epochs found in the previous step
training_pipeline.train_and_evaluate(
    model,
    combined_train_loader,
    None,
    optimizer,
    criterion,
    None,
    optimal_epoch,
    plot_curves=False)
print("Completed training.")
print("")


## Evaluate the Model



In [None]:

# Evaluation on the test set
print("Evaluating on the test set...")
training_pipeline.evaluate_on_test(model, test_loader, criterion, RUN_NUMBER)
print("Completed evaluation.")
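
ROC-AUC is a ranking metric, so its value depends on whether it receives continuous scores or thresholded class labels. The sketch below computes AUC directly from its pairwise definition to show the difference; the function `roc_auc` and the example data are hypothetical and are not the scikit-learn implementation.

```python
def roc_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a win.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


labels = [1, 1, 1, 0, 0, 0]
probs = [0.9, 0.6, 0.4, 0.55, 0.3, 0.1]       # made-up probabilities
hard = [1 if p >= 0.5 else 0 for p in probs]  # thresholded predictions

print(round(roc_auc(labels, probs), 3))  # 0.889
print(round(roc_auc(labels, hard), 3))   # 0.667
```

Thresholding discards the ordering within each predicted class, so the AUC from hard predictions reflects only one operating point rather than the full ranking quality.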
