In [16]:
# Environment check: print the installed versions of the graph-learning stack.
# The torch_sparse probe is left disabled; note that the training cell's recorded
# output ends with an ImportError asking for either 'pyg-lib' or 'torch-sparse'.
# import torch_sparse
# print(torch_sparse.__version__)

import torch_geometric
print(torch_geometric.__version__)

import torch
print(torch.__version__)

2.7.0
2.4.1


In [17]:
import pickle
import random
import warnings

import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
# import torch.sparse as tsp
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
from torch.utils.data import random_split
from torch_geometric.data import Data
from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GINConv, HeteroConv, GraphConv
from torch_geometric.utils import train_test_split_edges
warnings.filterwarnings("ignore")

# Path of the pickled dataset dictionary, relative to the notebook's working directory.
file = '../data/final_data_dict.pkl'
# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.
with open(file, 'rb') as f:   # the whole dataset
    data = pickle.load(f)

    
class HeteroGNN_GraphConv(torch.nn.Module):
    """Two-layer heterogeneous GNN using one ``GraphConv`` per relation.

    ``HeteroConv`` only returns embeddings for node types that occur as the
    destination of at least one relation.  With the relations used here,
    ``'disease'`` is only ever a source, so it would be missing from the output
    of the first layer: the second layer would then receive ``None`` as the
    disease features, and callers that read ``out['disease']`` would fail.
    To keep every input node type alive, source-only node types are carried
    through each layer by a plain linear projection of their own features.

    Args:
        hidden_channels: Embedding size produced by the first layer.
        out_channels: Embedding size produced by the second layer.
    """

    # Relations handled by both layers, as (source, relation, destination).
    _EDGE_TYPES = (
        ('disease', 'interacts', 'drug'),
        ('drug', 'interacts', 'protein'),
        ('disease', 'interacts', 'protein'),
        ('protein', 'interacts', 'protein'),
    )

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = HeteroConv({
            edge_type: GraphConv((-1, -1), hidden_channels)
            for edge_type in self._EDGE_TYPES
        })
        self.conv2 = HeteroConv({
            edge_type: GraphConv((hidden_channels, hidden_channels), out_channels)
            for edge_type in self._EDGE_TYPES
        })

        # Node types that send messages but never receive any.
        sources = {src for src, _, _ in self._EDGE_TYPES}
        destinations = {dst for _, _, dst in self._EDGE_TYPES}
        source_only = sorted(sources - destinations)

        # The raw feature width is not known here, hence the lazy first projection.
        self.skip1 = torch.nn.ModuleDict({
            node_type: torch.nn.LazyLinear(hidden_channels) for node_type in source_only
        })
        self.skip2 = torch.nn.ModuleDict({
            node_type: torch.nn.Linear(hidden_channels, out_channels)
            for node_type in source_only
        })

    @staticmethod
    def _carry_source_only(out_dict, in_dict, projections):
        """Add projected embeddings for node types the convolution did not emit."""
        for node_type, projection in projections.items():
            if node_type not in out_dict and node_type in in_dict:
                out_dict[node_type] = projection(in_dict[node_type])
        return out_dict

    def forward(self, x_dict, edge_index_dict):
        """Return a dict mapping each node type to its output embedding.

        Args:
            x_dict: Node features keyed by node type.
            edge_index_dict: Edge indices keyed by ``(src, rel, dst)`` edge type.
        """
        hidden = self.conv1(x_dict, edge_index_dict)
        hidden = self._carry_source_only(hidden, x_dict, self.skip1)
        hidden = {key: F.relu(x) for key, x in hidden.items()}
        out = self.conv2(hidden, edge_index_dict)
        return self._carry_source_only(out, hidden, self.skip2)


class HeteroGNN_SAGE(torch.nn.Module):
    """Two-layer heterogeneous GNN using one ``SAGEConv`` per relation.

    ``HeteroConv`` only returns embeddings for node types that occur as the
    destination of at least one relation.  With the relations used here,
    ``'disease'`` is only ever a source, so it would be missing from the output
    of the first layer: the second layer would then receive ``None`` as the
    disease features, and callers that read ``out['disease']`` would fail.
    To keep every input node type alive, source-only node types are carried
    through each layer by a plain linear projection of their own features.

    Args:
        hidden_channels: Embedding size produced by the first layer.
        out_channels: Embedding size produced by the second layer.
    """

    # Relations handled by both layers, as (source, relation, destination).
    _EDGE_TYPES = (
        ('disease', 'interacts', 'drug'),
        ('drug', 'interacts', 'protein'),
        ('disease', 'interacts', 'protein'),
        ('protein', 'interacts', 'protein'),
    )

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in self._EDGE_TYPES
        })
        self.conv2 = HeteroConv({
            edge_type: SAGEConv((hidden_channels, hidden_channels), out_channels)
            for edge_type in self._EDGE_TYPES
        })

        # Node types that send messages but never receive any.
        sources = {src for src, _, _ in self._EDGE_TYPES}
        destinations = {dst for _, _, dst in self._EDGE_TYPES}
        source_only = sorted(sources - destinations)

        # The raw feature width is not known here, hence the lazy first projection.
        self.skip1 = torch.nn.ModuleDict({
            node_type: torch.nn.LazyLinear(hidden_channels) for node_type in source_only
        })
        self.skip2 = torch.nn.ModuleDict({
            node_type: torch.nn.Linear(hidden_channels, out_channels)
            for node_type in source_only
        })

    @staticmethod
    def _carry_source_only(out_dict, in_dict, projections):
        """Add projected embeddings for node types the convolution did not emit."""
        for node_type, projection in projections.items():
            if node_type not in out_dict and node_type in in_dict:
                out_dict[node_type] = projection(in_dict[node_type])
        return out_dict

    def forward(self, x_dict, edge_index_dict):
        """Return a dict mapping each node type to its output embedding.

        Args:
            x_dict: Node features keyed by node type.
            edge_index_dict: Edge indices keyed by ``(src, rel, dst)`` edge type.
        """
        hidden = self.conv1(x_dict, edge_index_dict)
        hidden = self._carry_source_only(hidden, x_dict, self.skip1)
        hidden = {key: F.relu(x) for key, x in hidden.items()}
        out = self.conv2(hidden, edge_index_dict)
        return self._carry_source_only(out, hidden, self.skip2)


def create_pyg_heterograph(data):
    """Assemble a ``HeteroData`` graph from the dataset dictionary.

    Node features are read from ``data['d_feat']`` (drug), ``data['dis_feat']``
    (disease) and ``data['p_feat']`` (protein), converted to dense float32
    tensors.  Edge indices are read from the ``row``/``col`` attributes of the
    adjacency objects stored under ``'dd_adj'``, ``'dp_adj'``, ``'disp_adj'``
    and ``'pp_adj'``.

    Args:
        data: Mapping holding the feature and adjacency entries listed above.

    Returns:
        The populated ``HeteroData`` object.
    """
    def _dense_features(key):
        # Convert to float32 and densify if the result is a sparse tensor.
        feats = torch.tensor(data[key], dtype=torch.float32)
        return feats.to_dense() if feats.is_sparse else feats

    def _edge_index(key):
        # Stack the adjacency's row/col coordinates into a 2 x num_edges index.
        adj = data[key]
        rows = torch.tensor(adj.row, dtype=torch.long)
        cols = torch.tensor(adj.col, dtype=torch.long)
        return torch.stack([rows, cols])

    node_sources = (
        ('drug', 'd_feat'),
        ('disease', 'dis_feat'),
        ('protein', 'p_feat'),
    )
    # NOTE(review): rows of 'dd_adj' become the *disease* (source) ids below —
    # confirm that its rows index diseases and its columns index drugs.
    edge_sources = (
        (('disease', 'interacts', 'drug'), 'dd_adj'),
        (('drug', 'interacts', 'protein'), 'dp_adj'),
        (('disease', 'interacts', 'protein'), 'disp_adj'),
        (('protein', 'interacts', 'protein'), 'pp_adj'),
    )

    features = [(node_type, _dense_features(key)) for node_type, key in node_sources]
    hetero = HeteroData()
    for node_type, feats in features:
        hetero[node_type].x = feats

    edges = [(edge_type, _edge_index(key)) for edge_type, key in edge_sources]
    for edge_type, index in edges:
        hetero[edge_type].edge_index = index

    return hetero
    
def generate_negative_pairs(num_diseases, num_drugs, positive_pairs, num_neg_samples=None):
    """Sample distinct (disease, drug) pairs that are not known positives.

    Args:
        num_diseases: Number of disease nodes; valid disease ids are
            ``0 .. num_diseases - 1``.
        num_drugs: Number of drug nodes; valid drug ids are ``0 .. num_drugs - 1``.
        positive_pairs: Two-row index structure (e.g. an ``edge_index`` tensor)
            whose row 0 holds disease ids and row 1 the matching drug ids.
        num_neg_samples: How many negatives to draw.  Defaults to the number of
            positive pairs.

    Returns:
        A ``torch.long`` tensor of shape ``(2, num_neg_samples)`` with disease
        ids in row 0 and drug ids in row 1, placed on the device of
        ``positive_pairs[0]`` when it exposes one.

    Raises:
        ValueError: If fewer than ``num_neg_samples`` negative pairs exist.
    """
    pos_diseases = positive_pairs[0].tolist()
    pos_drugs = positive_pairs[1].tolist()
    if num_neg_samples is None:
        num_neg_samples = len(pos_diseases)

    total_pairs = max(num_diseases, 0) * max(num_drugs, 0)
    # Encode every in-range positive as a single integer so membership tests
    # are O(1) and the full disease x drug grid never has to be materialised.
    positive_codes = {
        disease * num_drugs + drug
        for disease, drug in zip(pos_diseases, pos_drugs)
        if 0 <= disease < num_diseases and 0 <= drug < num_drugs
    }

    num_available = total_pairs - len(positive_codes)
    if num_neg_samples < 0 or num_neg_samples > num_available:
        raise ValueError(
            f"Cannot draw {num_neg_samples} negative pairs: only "
            f"{num_available} non-positive pairs exist."
        )

    if 2 * (len(positive_codes) + num_neg_samples) <= total_pairs:
        # Plenty of room: rejection sampling accepts each draw with
        # probability of at least one half.
        sampled, seen = [], set()
        while len(sampled) < num_neg_samples:
            code = random.randrange(total_pairs)
            if code in positive_codes or code in seen:
                continue
            seen.add(code)
            sampled.append(code)
    else:
        # Dense case: enumerating the remaining candidates is cheap enough
        # and guarantees termination.
        candidates = [code for code in range(total_pairs) if code not in positive_codes]
        sampled = random.sample(candidates, num_neg_samples)

    device = getattr(positive_pairs[0], 'device', None)
    # Building the two rows explicitly keeps the shape (2, 0) when nothing is sampled.
    return torch.tensor(
        [[code // num_drugs for code in sampled],
         [code % num_drugs for code in sampled]],
        dtype=torch.long,
        device=device,
    )

def create_data_loaders(graph, num_parts=100, batch_size=32, shuffle=True,
                        held_out_disease=1446, num_hops=2, seed=None):
    """Split disease nodes 70/15/15 and build one ``NeighborLoader`` per split.

    One disease node (``held_out_disease``, COVID-19 by default) is removed
    from the random split and always appended to the test set.

    Args:
        graph: Heterogeneous graph with a ``'disease'`` node type carrying ``x``.
        num_parts: Neighbours sampled per hop.  An int is expanded to one entry
            per hop; a list or dict is passed to ``NeighborLoader`` unchanged.
        batch_size: Number of seed disease nodes per mini-batch.
        shuffle: Whether the training loader shuffles its seed nodes.
        held_out_disease: Disease node id that must end up in the test set.
        num_hops: Number of sampling hops used when ``num_parts`` is an int;
            the default of 2 matches the two-layer models in this notebook.
        seed: Optional seed making the random split reproducible (it does not
            affect the loaders' own shuffling).

    Returns:
        ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If ``held_out_disease`` is not a valid disease node id.
    """
    num_disease_nodes = graph['disease'].x.size(0)
    if not 0 <= held_out_disease < num_disease_nodes:
        raise ValueError(
            f"held_out_disease={held_out_disease} is out of range for "
            f"{num_disease_nodes} disease nodes."
        )

    disease_indices = torch.arange(num_disease_nodes)
    disease_indices = disease_indices[disease_indices != held_out_disease]

    train_size = int(0.7 * num_disease_nodes)
    val_size = int(0.15 * num_disease_nodes)
    test_size = num_disease_nodes - train_size - val_size - 1  # held-out node is added below

    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_subset, val_subset, test_subset = random_split(
        disease_indices, [train_size, val_size, test_size], **split_kwargs
    )

    def _node_ids(subset):
        # ``subset.indices`` are positions inside ``disease_indices``, not node
        # ids: map them back, otherwise every id at or above the held-out one
        # is shifted and the held-out node can leak into train/val.
        positions = torch.tensor(subset.indices, dtype=torch.long)
        return disease_indices[positions]

    train_idx = _node_ids(train_subset)
    val_idx = _node_ids(val_subset)
    test_idx = torch.cat([_node_ids(test_subset), torch.tensor([held_out_disease])])

    # NeighborLoader expects one fan-out value per hop, not a bare int.
    num_neighbors = [num_parts] * num_hops if isinstance(num_parts, int) else num_parts

    def _make_loader(node_idx, shuffle_nodes):
        return NeighborLoader(
            graph,
            input_nodes=('disease', node_idx),
            num_neighbors=num_neighbors,
            batch_size=batch_size,
            shuffle=shuffle_nodes,
        )

    train_loader = _make_loader(train_idx, shuffle)
    val_loader = _make_loader(val_idx, False)    # Validation data is not shuffled
    test_loader = _make_loader(test_idx, False)  # Test data is not shuffled

    return train_loader, val_loader, test_loader
    
def _link_prediction_loss(model, batch):
    """Pairwise ranking loss on a batch's disease-drug edges (None if it has none)."""
    pos_edges = batch['disease', 'interacts', 'drug'].edge_index
    num_pos = pos_edges.size(1)
    if num_pos == 0:
        # The mean over an empty tensor is NaN and would poison the epoch loss.
        return None

    out = model(batch.x_dict, batch.edge_index_dict)

    num_diseases = batch['disease'].x.size(0)
    num_drugs = batch['drug'].x.size(0)
    neg_edges = generate_negative_pairs(num_diseases, num_drugs, pos_edges, num_neg_samples=num_pos)
    neg_edges = neg_edges.to(pos_edges.device)

    # Dot-product score of each (disease, drug) pair.
    pos_score = (out['drug'][pos_edges[1]] * out['disease'][pos_edges[0]]).sum(dim=1)
    neg_score = (out['drug'][neg_edges[1]] * out['disease'][neg_edges[0]]).sum(dim=1)
    # Penalise every negative that scores above its paired positive.
    return torch.mean(F.softplus(neg_score - pos_score))


def train_hetero_gnn_with_loader(train_loader, model, optimizer, device, num_epochs=100,
                                 patience=10, val_loader=None):
    """Train a link-prediction model on disease-drug edges, mini-batch by mini-batch.

    Args:
        train_loader: Iterable of heterogeneous mini-batches used for training.
        model: Module called as ``model(x_dict, edge_index_dict)`` that returns
            embeddings for at least the ``'disease'`` and ``'drug'`` node types.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to.
        num_epochs: Maximum number of epochs.
        patience: Number of consecutive epochs without validation improvement
            tolerated before stopping.  Only used when ``val_loader`` is given.
        val_loader: Optional iterable of validation mini-batches.  When
            omitted, all ``num_epochs`` epochs are run.
    """
    best_val_loss = float('inf')
    counter = 0
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        model.train()
        total_loss = 0

        # Iterate over mini-batches
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            loss = _link_prediction_loss(model, batch)
            if loss is None:
                continue

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss}")

        if val_loader is None:
            continue

        # Early stopping on the mean validation loss.
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_loader:
                val_loss = _link_prediction_loss(model, batch.to(device))
                if val_loss is not None:
                    val_losses.append(val_loss.item())
        if not val_losses:
            # No validation signal this epoch; do not count it against patience.
            continue

        mean_val_loss = sum(val_losses) / len(val_losses)
        print(f"Epoch {epoch + 1}/{num_epochs}, Val Loss: {mean_val_loss}")
        if mean_val_loss < best_val_loss:
            best_val_loss = mean_val_loss
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print(f"Early stopping after epoch {epoch + 1}")
                break

def evaluate_model_with_loader(test_loader, model, device):
    """Score disease-drug link prediction over every batch of a loader.

    Each batch contributes its positive disease-drug edges plus an equal number
    of sampled negatives.  Scores are pooled across all batches so that AUC,
    F1, precision and recall are computed once for the whole split instead of
    once per mini-batch.

    Args:
        test_loader: Iterable of heterogeneous mini-batches.
        model: Module called as ``model(x_dict, edge_index_dict)`` that returns
            embeddings for at least the ``'disease'`` and ``'drug'`` node types.
        device: Device each batch is moved to.

    Returns:
        A dict with keys ``'auc'``, ``'f1'``, ``'precision'`` and ``'recall'``,
        or ``None`` when no batch contained a positive edge.
    """
    model.eval()
    all_scores = []
    all_labels = []

    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)

            pos_edges = batch['disease', 'interacts', 'drug'].edge_index
            num_pos = pos_edges.size(1)
            if num_pos == 0:
                # Nothing to score; the metrics below need both classes present.
                continue

            out = model(batch.x_dict, batch.edge_index_dict)

            num_diseases = batch['disease'].x.size(0)
            num_drugs = batch['drug'].x.size(0)
            neg_edges = generate_negative_pairs(num_diseases, num_drugs, pos_edges, num_neg_samples=num_pos)
            neg_edges = neg_edges.to(pos_edges.device)

            pos_score = (out['drug'][pos_edges[1]] * out['disease'][pos_edges[0]]).sum(dim=1).sigmoid()
            neg_score = (out['drug'][neg_edges[1]] * out['disease'][neg_edges[0]]).sum(dim=1).sigmoid()

            all_scores.append(torch.cat([pos_score, neg_score]).cpu())
            all_labels.append(torch.cat([torch.ones(pos_score.size(0)), torch.zeros(neg_score.size(0))]))

    if not all_scores:
        print("No positive disease-drug edges found; metrics are undefined.")
        return None

    scores = torch.cat(all_scores).numpy()
    labels = torch.cat(all_labels).numpy()
    predictions = (scores > 0.5).astype(int)

    auc = roc_auc_score(labels, scores)
    f1 = f1_score(labels, predictions)
    precision = precision_score(labels, predictions)
    recall = recall_score(labels, predictions)

    print(f"AUC: {auc}, F1: {f1}, Precision: {precision}, Recall: {recall}")
    return {'auc': auc, 'f1': f1, 'precision': precision, 'recall': recall}

def get_top_hits_for_disease(model, disease_index, graph, top_k=50):
    """Rank drugs for one disease by embedding dot product.

    Args:
        model: Module called as ``model(x_dict, edge_index_dict)`` that returns
            a dict with ``'disease'`` and ``'drug'`` embeddings.
        disease_index: Row of the disease embedding matrix to score against.
        graph: Object exposing ``x_dict`` and ``edge_index_dict``.
        top_k: Maximum number of drugs to return.

    Returns:
        Tensor of drug indices ordered from highest to lowest score.  It holds
        ``top_k`` entries, or fewer when the graph has fewer drugs than that.
    """
    model.eval()
    with torch.no_grad():
        out = model(graph.x_dict, graph.edge_index_dict)
    disease_feature = out['disease'][disease_index]
    drug_scores = torch.matmul(out['drug'], disease_feature)
    # torch.topk raises when k exceeds the number of candidates, so clamp it.
    k = max(0, min(top_k, drug_scores.size(0)))
    top_drugs = torch.topk(drug_scores, k=k).indices
    return top_drugs

In [12]:
# Build the heterogeneous graph from the loaded dataset dictionary.
graph = create_pyg_heterograph(data)

# Embedding sizes shared by both candidate models.
hidden_channels = 16
out_channels = 4

# Candidate architectures, keyed by the name printed during training.
models = {
    'GraphConv': HeteroGNN_GraphConv(hidden_channels, out_channels),
    'GraphSAGE': HeteroGNN_SAGE(hidden_channels, out_channels)
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
graph = graph.to(device)

# NOTE: the recorded output of this cell ends with
# "ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'",
# so one of those packages must be installed before the loaders can be iterated.
train_loader, val_loader, test_loader = create_data_loaders(graph, num_parts=100, batch_size=32)

# Train each candidate, then report metrics on the train, validation and test loaders.
for model_name, model in models.items():
    print(f"\nTraining and evaluating {model_name} model:")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    
    train_hetero_gnn_with_loader(train_loader, model, optimizer, device)
    print("\nEvaluate Train")
    evaluate_model_with_loader(train_loader, model, device)
    print("\nEvaluate Val")
    evaluate_model_with_loader(val_loader, model, device)
    print("\nEvaluate Test")
    evaluate_model_with_loader(test_loader, model, device)
    # 1446 is the disease index that create_data_loaders forces into the test set.
    # NOTE: top_drugs is overwritten on every iteration and is not displayed here.
    top_drugs = get_top_hits_for_disease(model, 1446, graph)


Training and evaluating GraphConv model:
Epoch 1/100


ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'

In [None]:
# NOTE(review): the `models` dict built in the training cell only defines the keys
# 'GraphConv' and 'GraphSAGE', so this lookup would raise KeyError — confirm which
# model is meant.
best_model = models['GAT']
# NOTE(review): no visible cell writes 'best_model.pt' — confirm where the
# checkpoint is produced before relying on it.
best_model.load_state_dict(torch.load('best_model.pt'))
# Use the best model for predictions or further analysis
# For example, get top hits for multiple diseases
# NOTE(review): 2000 and 3000 are placeholders — confirm they are valid disease ids.
disease_indices = [1446, 2000, 3000]  # example disease indices
for disease_index in disease_indices:
    top_drugs = get_top_hits_for_disease(best_model, disease_index, graph)
    print(f"Top drugs for disease index {disease_index}: {top_drugs}")