In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import f1_score, precision_score, recall_score
from torch_geometric.nn import GATConv
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import random
import os

# Give NumPy, PyTorch and Python's own RNG the same fixed seed so runs are repeatable.
# random.seed stays last: it returns None, so the cell produces no output value.
np.random.seed(42)
torch.manual_seed(42)
random.seed(42)

In [2]:
# Only touch the CUDA / cuDNN settings when a GPU is actually present.
if torch.cuda.is_available():
    # Disable cuDNN benchmarking and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # Seed the RNG of every visible CUDA device with the notebook-wide seed.
    torch.cuda.manual_seed_all(42)

In [3]:
# Device configuration: use CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
# Model Definitions
class MultiLabelLinkPredictor(nn.Module):
    """Graph-attention encoder plus an MLP head that scores node pairs.

    Three ``GATConv`` layers (two heads each, head outputs not concatenated)
    turn node features into ``hidden_dim``-sized embeddings. For every
    requested pair the two endpoint embeddings are concatenated and passed
    through a two-layer MLP that emits ``num_classes`` raw scores.

    Args:
        model_type: Backbone selector; only ``'gat'`` is accepted.
        in_dim: Size of the input node features.
        hidden_dim: Width of every GAT layer and of the MLP hidden layer.
        num_classes: Number of output scores per pair.
        dropout: Dropout probability used after each GAT layer and in the MLP.

    Raises:
        ValueError: If ``model_type`` is anything other than ``'gat'``.
    """

    def __init__(self, model_type, in_dim, hidden_dim, num_classes, dropout=0.3):
        super().__init__()
        self.dropout = dropout
        self.model_type = model_type

        # Guard clause: GAT is the only supported backbone.
        if model_type != 'gat':
            raise ValueError("model_type must be 'gat'")

        # GNN encoder (created in a fixed order so parameter initialisation is stable).
        self.conv1 = GATConv(in_dim, hidden_dim, heads=2, concat=False)
        self.conv2 = GATConv(hidden_dim, hidden_dim, heads=2, concat=False)
        self.conv3 = GATConv(hidden_dim, hidden_dim, heads=2, concat=False)

        # Pair classifier: input is the concatenation of two node embeddings.
        head_layers = [
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, x, edge_index, edge_pairs):
        """Score the node pairs in ``edge_pairs``.

        Args:
            x: Node feature tensor fed to the first GAT layer.
            edge_index: Graph connectivity handed to every GAT layer.
            edge_pairs: Unpacked into ``(src, dst)`` index tensors that select
                the endpoint embeddings of each pair to score.

        Returns:
            The MLP output for each pair (no activation is applied here).
        """
        node_emb = x
        # Every encoder layer is followed by ReLU and (training-only) dropout.
        for conv in (self.conv1, self.conv2, self.conv3):
            node_emb = F.relu(conv(node_emb, edge_index))
            node_emb = F.dropout(node_emb, p=self.dropout, training=self.training)

        src, dst = edge_pairs
        pair_features = torch.cat((node_emb[src], node_emb[dst]), dim=1)
        return self.mlp(pair_features)


In [5]:
def train_model(model, optimizer, loss_fn, x, edge_index, train_loader, val_pairs, val_labels, patience=10, max_epochs=200):
    """Train ``model`` with early stopping on validation micro-F1.

    Each epoch runs one pass over ``train_loader`` and then scores the whole
    validation set. The learning rate is halved by ``ReduceLROnPlateau`` when
    the validation micro-F1 has not improved for 3 epochs, and training stops
    once it has not improved for ``patience`` epochs. The weights from the
    best validation epoch are restored before returning.

    Args:
        model: Module called as ``model(x, edge_index, pairs)`` returning logits.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar tensor.
        x: Node features; its device is used for all batches.
        edge_index: Graph connectivity forwarded to the model unchanged.
        train_loader: Iterable of ``(pairs, labels)`` batches; ``pairs`` is
            transposed before being handed to the model.
        val_pairs: Validation pairs in the layout the model expects.
        val_labels: Validation targets matching ``val_pairs``.
        patience: Epochs without validation improvement before stopping.
        max_epochs: Upper bound on the number of epochs.

    Returns:
        The same ``model`` instance, loaded with the best-scoring weights
        (or left at its final state if no epoch ran).
    """
    # Follow the device of the tensors actually passed in instead of relying
    # on a notebook-level global.
    run_device = x.device
    val_pairs = val_pairs.to(run_device)
    val_labels = val_labels.to(run_device)

    best_val_f1 = -1.0
    best_state = None
    patience_counter = 0
    # The deprecated ``verbose`` flag is omitted; the current learning rate is
    # already part of the per-epoch log line below.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)
    # Avoid a ZeroDivisionError when the loader yields no batches.
    num_batches = max(len(train_loader), 1)

    for epoch in range(max_epochs):
        model.train()
        total_loss = 0.0
        for batch_pairs, batch_labels in train_loader:
            batch_pairs = batch_pairs.T.to(run_device)
            batch_labels = batch_labels.to(run_device)

            optimizer.zero_grad()
            logits = model(x, edge_index, batch_pairs)
            loss = loss_fn(logits, batch_labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / num_batches

        # Validation step
        model.eval()
        with torch.no_grad():
            logits_val = model(x, edge_index, val_pairs)
            val_loss = loss_fn(logits_val, val_labels).item()
            preds_bin = (torch.sigmoid(logits_val) > 0.5).long()
            micro_f1 = f1_score(val_labels.cpu().numpy(), preds_bin.cpu().numpy(), average='micro')

        scheduler.step(micro_f1)

        if micro_f1 > best_val_f1:
            best_val_f1 = micro_f1
            # state_dict() hands back references to the live parameters, so the
            # tensors must be cloned; otherwise later optimizer steps would
            # silently overwrite the "best" snapshot.
            best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Micro F1: {micro_f1:.4f}, Patience: {patience_counter}/{patience}, LR: {optimizer.param_groups[0]['lr']:.6f}")

        if patience_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1} due to no improvement in validation F1.")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("Warning: No best model state saved. Returning final model state.")
    return model

In [6]:
def test_model(model, x, edge_index, test_pairs, test_labels, batch_size):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        test_dataset = TensorDataset(test_pairs.T, test_labels)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        for batch_pairs, batch_labels in test_loader:
            logits = model(x, edge_index, batch_pairs.T)
            preds = (torch.sigmoid(logits) > 0.5).long()
            all_preds.append(preds.cpu())
            all_labels.append(batch_labels.cpu())

    all_preds = torch.cat(all_preds, dim=0).numpy()
    all_labels = torch.cat(all_labels, dim=0).numpy()

    micro_f1 = f1_score(all_labels, all_preds, average='micro')
    micro_precision = precision_score(all_labels, all_preds, average='micro')
    micro_recall = recall_score(all_labels, all_preds, average='micro')

    print("Final Test Results")
    print(f"Micro F1 Score: {micro_f1:.4f}")
    print(f"Micro Precision: {micro_precision:.4f}")
    print(f"Micro Recall: {micro_recall:.4f}")

In [None]:
if __name__ == '__main__':
    # Load data. map_location remaps stored tensors onto the active device, so a
    # file written on a GPU machine can still be opened on a CPU-only one.
    # NOTE(review): torch.load unpickles the file, so only open graphs from a
    # trusted source. Recent PyTorch releases default to weights_only=True, which
    # may reject a pickled graph object — confirm against the installed version.
    data = torch.load('../processed/ddi_graph.pt', map_location=device)
    data = data.to(device)
    x = data.x.to(device)

    # Split the edges into train, validation, and test sets (70 / 15 / 15).
    num_edges = data.edge_index.shape[1]
    all_indices = torch.randperm(num_edges)
    train_ratio = 0.7
    val_ratio = 0.15
    test_ratio = 0.15

    train_size = int(train_ratio * num_edges)
    val_size = int(val_ratio * num_edges)
    # The test split takes whatever remains so that rounding never drops an edge.
    test_size = num_edges - train_size - val_size
    train_idx = all_indices[:train_size]
    val_idx = all_indices[train_size: train_size + val_size]
    test_idx = all_indices[train_size + val_size:]

    # Sanity check: the three splits partition the edge set exactly.
    assert len(test_idx) == test_size
    assert len(train_idx) + len(val_idx) + len(test_idx) == num_edges

    train_pairs = data.edge_index[:, train_idx]
    train_labels = data.y[train_idx]
    val_pairs = data.edge_index[:, val_idx]
    val_labels = data.y[val_idx]
    test_pairs = data.edge_index[:, test_idx]
    test_labels = data.y[test_idx]

    # Per-class positive weights for the BCE loss: negatives / positives in the
    # training split, clamped to [1, 10]; the epsilon guards classes with no positives.
    pos_counts = train_labels.sum(dim=0).float()
    neg_counts = (train_labels.shape[0] - pos_counts).float()
    pos_weight = (neg_counts / (pos_counts + 1e-6)).clamp(1.0, 10.0).to(device)

    # Best hyperparameters from Optuna for GAT
    best_params = {'hidden_dim': 512, 'lr': 0.0017986978630912245, 'dropout': 0.12511777776933947, 'weight_decay': 2.05208440810529e-06, 'batch_size': 256, 'use_focal_loss': False}

    # Initialize model, optimizer, and loss function
    model_gat = MultiLabelLinkPredictor('gat', x.shape[1], best_params['hidden_dim'], data.y.shape[1], best_params['dropout']).to(device)
    optimizer_gat = torch.optim.Adam(model_gat.parameters(), lr=best_params['lr'], weight_decay=best_params['weight_decay'])
    loss_fn_gat = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='mean').to(device)

    # Prepare data loader (one dataset row per training pair)
    train_dataset = TensorDataset(train_pairs.T, train_labels)
    train_loader = DataLoader(train_dataset, batch_size=best_params['batch_size'], shuffle=True)

    # Train the model.
    # NOTE(review): the full data.edge_index is passed as the message-passing
    # graph here and at test time, so validation/test edges are visible to the
    # encoder — confirm this transductive setup is intended.
    trained_model_gat = train_model(model_gat, optimizer_gat, loss_fn_gat, x, data.edge_index, train_loader, val_pairs, val_labels)

    # Test the model
    test_model(trained_model_gat, x, data.edge_index, test_pairs, test_labels, batch_size=best_params['batch_size'])




Epoch 1/200, Train Loss: 0.7280, Val Loss: 0.6876, Val Micro F1: 0.6552, Patience: 0/10, LR: 0.001799
Epoch 2/200, Train Loss: 0.6784, Val Loss: 0.6624, Val Micro F1: 0.6541, Patience: 1/10, LR: 0.001799
Epoch 3/200, Train Loss: 0.6647, Val Loss: 0.6552, Val Micro F1: 0.6572, Patience: 0/10, LR: 0.001799
Epoch 4/200, Train Loss: 0.6537, Val Loss: 0.6440, Val Micro F1: 0.6739, Patience: 0/10, LR: 0.001799
Epoch 5/200, Train Loss: 0.6463, Val Loss: 0.6406, Val Micro F1: 0.6776, Patience: 0/10, LR: 0.001799
Epoch 6/200, Train Loss: 0.6415, Val Loss: 0.6339, Val Micro F1: 0.6832, Patience: 0/10, LR: 0.001799
Epoch 7/200, Train Loss: 0.6376, Val Loss: 0.6298, Val Micro F1: 0.6833, Patience: 0/10, LR: 0.001799
Epoch 8/200, Train Loss: 0.6346, Val Loss: 0.6273, Val Micro F1: 0.6838, Patience: 0/10, LR: 0.001799
