In [5]:
import pickle
import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
import torch
import torch.nn as nn
from torch.utils.data import random_split
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import train_test_split_edges
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GINConv, HeteroConv
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
import warnings
warnings.filterwarnings("ignore")

In [6]:
file = '../data/final_data_dict.pkl'
with open(file, 'rb') as f:
    data = pickle.load(f)

file = "../data/processed_step2/map_drug.pkl"
with open(file, 'rb') as f:
    drug_map = pickle.load(f)

inverse_drug_map = {}
for key in drug_map:
    value = drug_map[key]
    inverse_drug_map[value] = key

In [9]:
def get_edge_index(adj_matrix):
    edge_index = torch.tensor([adj_matrix.row, adj_matrix.col], dtype=torch.long)
    return edge_index

def generate_negative_samples(n_drug, n_prot, n_dis, pos_edge_index, num_samples):
    disease_offset = n_drug + n_prot
    pos_pairs = set(zip(pos_edge_index[0].cpu().numpy(), pos_edge_index[1].cpu().numpy()))

    all_disease_drug_pairs = [(disease_offset + dis, drug) for dis in range(n_dis)
                          if dis != 1446  # Exclude disease 1446
                          for drug in range(n_drug)]
    # all_disease_drug_pairs = [(disease_offset + dis, drug) for dis in range(n_dis) for drug in range(n_drug)]
    
    negative_pairs = [pair for pair in all_disease_drug_pairs if tuple(pair) not in pos_pairs]
    sampled_negative_pairs = np.random.choice(len(negative_pairs), num_samples, replace=False)
    negative_edge_index = torch.tensor([negative_pairs[i] for i in sampled_negative_pairs], dtype=torch.long).t()
    return negative_edge_index

class FeatureProjector(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(FeatureProjector, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
    def forward(self, x):
        return self.linear(x)

class GCN(nn.Module):
    def __init__(self, drug_input_dim, protein_input_dim, disease_input_dim, hidden_channels, out_channels, feat_dim):
        super(GCN, self).__init__()
        self.drug_projector = FeatureProjector(drug_input_dim, feat_dim)
        self.protein_projector = FeatureProjector(protein_input_dim, feat_dim)
        self.disease_projector = FeatureProjector(disease_input_dim, feat_dim)
        self.conv1 = GCNConv(feat_dim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, drug_feat, protein_feat, disease_feat, edge_index):
        drug_feat = self.drug_projector(drug_feat)
        protein_feat = self.protein_projector(protein_feat)
        disease_feat = self.disease_projector(disease_feat)
        x = torch.cat([drug_feat, protein_feat, disease_feat], dim=0)
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

class GAT(torch.nn.Module):
    def __init__(self, drug_input_dim, protein_input_dim, disease_input_dim, n_heads, hidden_channels, out_channels, feat_dim):
        super(GAT, self).__init__()
        self.drug_projector = FeatureProjector(drug_input_dim, feat_dim)
        self.protein_projector = FeatureProjector(protein_input_dim, feat_dim)
        self.disease_projector = FeatureProjector(disease_input_dim, feat_dim)
        self.conv1 = GATConv(feat_dim, hidden_channels, heads=n_heads, concat=True)
        self.conv2 = GATConv(hidden_channels * n_heads, out_channels, heads=1, concat=False)

    def forward(self, drug_feat, protein_feat, disease_feat, edge_index):
        drug_feat = self.drug_projector(drug_feat)
        protein_feat = self.protein_projector(protein_feat)
        disease_feat = self.disease_projector(disease_feat)
        x = torch.cat([drug_feat, protein_feat, disease_feat], dim=0)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = self.conv2(x, edge_index)
        return x

class GraphSAGE(torch.nn.Module):
    def __init__(self, drug_input_dim, protein_input_dim, disease_input_dim, hidden_channels, out_channels, feat_dim):
        super(GraphSAGE, self).__init__()
        self.drug_projector = FeatureProjector(drug_input_dim, feat_dim)
        self.protein_projector = FeatureProjector(protein_input_dim, feat_dim)
        self.disease_projector = FeatureProjector(disease_input_dim, feat_dim)
        self.conv1 = SAGEConv(feat_dim, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, drug_feat, protein_feat, disease_feat, edge_index):
        drug_feat = self.drug_projector(drug_feat)
        protein_feat = self.protein_projector(protein_feat)
        disease_feat = self.disease_projector(disease_feat)
        x = torch.cat([drug_feat, protein_feat, disease_feat], dim=0)
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

class GIN(torch.nn.Module):
    def __init__(self, drug_input_dim, protein_input_dim, disease_input_dim, hidden_channels, out_channels, feat_dim):
        super(GIN, self).__init__()
        self.drug_projector = FeatureProjector(drug_input_dim, feat_dim)
        self.protein_projector = FeatureProjector(protein_input_dim, feat_dim)
        self.disease_projector = FeatureProjector(disease_input_dim, feat_dim)
        self.conv1 = GINConv(torch.nn.Sequential(torch.nn.Linear(feat_dim, hidden_channels), torch.nn.ReLU(), torch.nn.Linear(hidden_channels, hidden_channels)))
        self.conv2 = GINConv(torch.nn.Sequential(torch.nn.Linear(hidden_channels, hidden_channels), torch.nn.ReLU(), torch.nn.Linear(hidden_channels, out_channels)))

    def forward(self, drug_feat, protein_feat, disease_feat, edge_index):
        drug_feat = self.drug_projector(drug_feat)
        protein_feat = self.protein_projector(protein_feat)
        disease_feat = self.disease_projector(disease_feat)
        x = torch.cat([drug_feat, protein_feat, disease_feat], dim=0)
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

def train(model, optimizer, data, train_pos_edge_index, train_neg_edge_index, drug_feat, protein_feat, disease_feat):
    model.train()
    optimizer.zero_grad()
    z = model(drug_feat, protein_feat, disease_feat, data.edge_index)
    pos_score = (z[train_pos_edge_index[0]] * z[train_pos_edge_index[1]]).sum(dim=1)
    neg_score = (z[train_neg_edge_index[0]] * z[train_neg_edge_index[1]]).sum(dim=1)

    all_scores = torch.cat([pos_score, neg_score], dim=0)
    all_truths = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)], dim=0)
    loss = F.binary_cross_entropy_with_logits(all_scores, all_truths)
    
    # pos_loss = F.binary_cross_entropy_with_logits(pos_score, torch.ones_like(pos_score))
    # neg_loss = F.binary_cross_entropy_with_logits(neg_score, torch.zeros_like(neg_score))
    # loss = pos_loss + neg_loss
    
    loss.backward()
    optimizer.step()
    return loss.item()

def test(model, data, test_pos_edge_index, test_neg_edge_index, drug_feat, protein_feat, disease_feat):
    model.eval()
    with torch.no_grad():
        z = model(drug_feat, protein_feat, disease_feat, data.edge_index)
    
    pos_score = (z[test_pos_edge_index[0]] * z[test_pos_edge_index[1]]).sum(dim=1).sigmoid()
    neg_score = (z[test_neg_edge_index[0]] * z[test_neg_edge_index[1]]).sum(dim=1).sigmoid()
    pos_labels = torch.ones(test_pos_edge_index.size(1))
    neg_labels = torch.zeros(test_neg_edge_index.size(1))
    
    all_scores = torch.cat([pos_score, neg_score], dim=0)
    all_labels = torch.cat([pos_labels, neg_labels], dim=0)
    
    auc = roc_auc_score(all_labels.cpu(), all_scores.cpu())
    f1 = f1_score(all_labels.cpu(), all_scores.cpu().round())
    precision = precision_score(all_labels.cpu(), all_scores.cpu().round())
    recall = recall_score(all_labels.cpu(), all_scores.cpu().round())
    
    return auc, f1, precision, recall

def get_top_hits_for_disease(model, disease_index, data, drug_feat, protein_feat, disease_feat, n_drug, n_prot, top_k=50):
    model.eval()
    with torch.no_grad():
        z = model(drug_feat, protein_feat, disease_feat, data.edge_index)
    disease_feature = z[n_drug + n_prot + disease_index]
    drug_scores = torch.matmul(z[:n_drug], disease_feature)
    top_drug_indices = drug_scores.topk(top_k).indices
    top_drug_scores = drug_scores.topk(top_k).values
    return (top_drug_indices, top_drug_scores)

In [10]:
num_nodes = data['n_drug'] + data['n_prot'] + data['n_dis']

n_drug = data['n_drug']
n_prot = data['n_prot']
n_dis = data['n_dis']

drug_feat = torch.tensor(data['d_feat'], dtype=torch.float32)
disease_feat = torch.tensor(data['dis_feat'], dtype=torch.float32)
protein_feat = torch.tensor(data['p_feat'], dtype=torch.float32)
if drug_feat.is_sparse:
    drug_feat = drug_feat.to_dense()
if disease_feat.is_sparse:
    disease_feat = disease_feat.to_dense()
if protein_feat.is_sparse:
    protein_feat = protein_feat.to_dense()

dp_edge_index = get_edge_index(data['dp_adj'])  # Drug-protein edges
pp_edge_index = get_edge_index(data['pp_adj'])  # Protein-protein edges
dd_edge_index = get_edge_index(data['dd_adj'])  # Disease-drug edges
disp_edge_index = get_edge_index(data['disp_adj'])  # Disease-protein edges

# Callibration for homo
dp_edge_index[1] += n_drug  # Offset protein indices
pp_edge_index += n_drug  # Offset both protein indices
dd_edge_index[0] += n_drug + n_prot  # Offset disease indices
disp_edge_index[0] += n_drug + n_prot  # Offset disease indices
disp_edge_index[1] += n_drug  # Offset protein indices
edge_index = torch.cat([dp_edge_index, pp_edge_index, dd_edge_index, disp_edge_index], dim=1)

graph_data = Data(edge_index=edge_index)

n_dd_edges = dd_edge_index.size(1)
train_size = int(0.7 * n_dd_edges)
val_size = int(0.15 * n_dd_edges)
test_size = n_dd_edges - train_size - val_size

train_edges, val_edges, test_edges = random_split(range(n_dd_edges), [train_size, val_size, test_size])
train_pos_edge_index = dd_edge_index[:, train_edges.indices]
val_pos_edge_index = dd_edge_index[:, val_edges.indices]
test_pos_edge_index = dd_edge_index[:, test_edges.indices]

neg_dd_edges_index = generate_negative_samples(n_drug, n_prot, n_dis, dd_edge_index, dd_edge_index.size(1))
n_neg_dd_edges = neg_dd_edges_index.size(1)
train_edges, val_edges, test_edges = random_split(range(n_neg_dd_edges), [train_size, val_size, test_size])
train_neg_edge_index = neg_dd_edges_index[:, train_edges.indices]
val_neg_edge_index = neg_dd_edges_index[:, val_edges.indices]
test_neg_edge_index = neg_dd_edges_index[:, test_edges.indices]

In [11]:
print(n_drug + n_prot + 1446)

26109


In [12]:
for i in set(zip(neg_dd_edges_index[0].cpu().numpy(), neg_dd_edges_index[1].cpu().numpy())):
    if(i[0]==26109):
        print(i)

In [13]:
print(test_neg_edge_index)
print(test_pos_edge_index)

tensor([[25119, 24847, 24762,  ..., 24982, 24757, 25370],
        [ 3819,  4352,  1171,  ...,  1158,  4303,  4777]])
tensor([[25525, 25226, 25065,  ..., 25524, 25529, 25729],
        [ 6104,  2638,  1110,  ...,  1327,   519,   115]])


In [15]:
feat_dim = 100  # The desired dimension after projection
hidden_channels = 64
out_channels = 16
epochs = 50
topk = 20
n_heads = 4

models = [GCN(drug_feat.shape[1], protein_feat.shape[1], disease_feat.shape[1], hidden_channels, out_channels, feat_dim),
          GAT(drug_feat.shape[1], protein_feat.shape[1], disease_feat.shape[1], n_heads, hidden_channels, out_channels, feat_dim),
          GraphSAGE(drug_feat.shape[1], protein_feat.shape[1], disease_feat.shape[1], hidden_channels, out_channels, feat_dim),
          GIN(drug_feat.shape[1], protein_feat.shape[1], disease_feat.shape[1], hidden_channels, out_channels, feat_dim)]
model_names = ["GCN", "GAT", "SAGE", "GIN"]

top_drugs_dict = dict()

for i in range(len(models)):
    model = models[i]
    name = model_names[i]
    print("!!!!!!!!!!!Starting Training for", name)
    print()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(epochs):
        loss = train(model, optimizer, graph_data, train_pos_edge_index, train_neg_edge_index, drug_feat, protein_feat, disease_feat)
        if epoch % 5 == 0:
            print(f'Epoch {epoch}, Loss: {loss}')
    print(f'Epoch {epoch}, Loss: {loss}')
    print()
    auc, f1, precision, recall = test(model, graph_data, train_pos_edge_index, train_neg_edge_index, drug_feat, protein_feat, disease_feat)
    print(f"Train Metrics:\tAUC: {auc:.4f}, F1: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")
    auc, f1, precision, recall = test(model, graph_data, val_pos_edge_index, val_neg_edge_index, drug_feat, protein_feat, disease_feat)
    print(f"Val Metrics:\tAUC: {auc:.4f}, F1: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")

    print("\nTop drugs list for index 1446")
    top_drugs, top_drugs_scores = get_top_hits_for_disease(model, 1446, graph_data, drug_feat, protein_feat, disease_feat, n_drug, n_prot, top_k=topk)
    top_drugs = top_drugs.tolist()
    top_drugs_scores = top_drugs_scores.tolist()
    top_drugs_scores = [round(num, 3) for num in top_drugs_scores]
    print(top_drugs)
    print(top_drugs_scores)
    top_drug_names = []
    for drug in top_drugs:
        top_drug_names.append(inverse_drug_map[drug])
    print(top_drug_names)
    top_drugs_dict[name] = top_drug_names
    print("\n")

!!!!!!!!!!!Starting Training for GCN

Epoch 0, Loss: 0.6815246939659119
Epoch 5, Loss: 0.41296181082725525
Epoch 10, Loss: 0.2525148391723633
Epoch 15, Loss: 0.14256665110588074
Epoch 20, Loss: 0.12041120231151581
Epoch 25, Loss: 0.1127784475684166
Epoch 30, Loss: 0.10672105103731155
Epoch 35, Loss: 0.10297480970621109
Epoch 40, Loss: 0.100586898624897
Epoch 45, Loss: 0.09942039102315903
Epoch 49, Loss: 0.09882950782775879

Train Metrics:	AUC: 0.9931, F1: 0.9625, Precision: 0.9459, Recall: 0.9796
Val Metrics:	AUC: 0.9914, F1: 0.9603, Precision: 0.9435, Recall: 0.9777

Top drugs list for index 1446
[280, 282, 707, 74, 887, 682, 355, 1889, 1073, 497, 465, 451, 137, 574, 1069, 2634, 606, 113, 1327, 1361]
[6.334, 6.102, 5.673, 5.532, 5.477, 5.271, 4.993, 4.815, 4.69, 4.676, 4.652, 4.508, 4.5, 4.483, 4.477, 4.474, 4.445, 4.25, 4.247, 4.192]
['DB00313', 'DB00316', 'DB00783', 'DB00091', 'DB00977', 'DB00755', 'DB00396', 'DB02709', 'DB01174', 'DB00550', 'DB00515', 'DB00499', 'DB00158', 'DB00636

In [16]:
drug_counts = {}
for name in top_drugs_dict:
    drug_list = top_drugs_dict[name]
    for drug in drug_list:
        if(drug not in drug_counts):
            drug_counts[drug] = 0
        drug_counts[drug]+=1

sorted_dict = dict(sorted(drug_counts.items(), key=lambda item: item[1], reverse=True))
print(sorted_dict)

{'DB00316': 4, 'DB00783': 4, 'DB00091': 4, 'DB00977': 4, 'DB00755': 4, 'DB00396': 4, 'DB02709': 4, 'DB01174': 4, 'DB00515': 4, 'DB00158': 4, 'DB04216': 4, 'DB00675': 4, 'DB00313': 3, 'DB00550': 3, 'DB00499': 3, 'DB00636': 3, 'DB01169': 3, 'DB00134': 3, 'DB01593': 3, 'DB01645': 3, 'DB00122': 2, 'DB01234': 2, 'DB01030': 1, 'DB00997': 1, 'DB00136': 1, 'DB00907': 1}
