In [337]:
!pip install torch



In [338]:
!pip install torch_geometric



In [339]:
import torch
import numpy as np
import pandas as pd
import random
from sklearn import metrics
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import degree
from torch_geometric.utils import negative_sampling


In [340]:
data_path = "drkg.tsv"
df = pd.read_csv(data_path, sep="\t")

In [341]:
df['bioarx::HumGenHumGen:Gene:Gene'].unique()

array(['bioarx::HumGenHumGen:Gene:Gene', 'bioarx::VirGenHumGen:Gene:Gene',
       'bioarx::DrugVirGen:Compound:Gene',
       'bioarx::DrugHumGen:Compound:Gene',
       'bioarx::Covid2_acc_host_gene::Disease:Gene',
       'bioarx::Coronavirus_ass_host_gene::Disease:Gene',
       'DGIDB::INHIBITOR::Gene:Compound',
       'DGIDB::ANTAGONIST::Gene:Compound', 'DGIDB::OTHER::Gene:Compound',
       'DGIDB::AGONIST::Gene:Compound', 'DGIDB::BINDER::Gene:Compound',
       'DGIDB::MODULATOR::Gene:Compound', 'DGIDB::BLOCKER::Gene:Compound',
       'DGIDB::CHANNEL BLOCKER::Gene:Compound',
       'DGIDB::ANTIBODY::Gene:Compound',
       'DGIDB::POSITIVE ALLOSTERIC MODULATOR::Gene:Compound',
       'DGIDB::ALLOSTERIC MODULATOR::Gene:Compound',
       'DGIDB::ACTIVATOR::Gene:Compound',
       'DGIDB::PARTIAL AGONIST::Gene:Compound',
       'DRUGBANK::x-atc::Compound:Atc',
       'DRUGBANK::ddi-interactor-in::Compound:Compound',
       'DRUGBANK::target::Compound:Gene',
       'DRUGBANK::enzyme::Compou

In [342]:
temp_graph = df[(df['bioarx::HumGenHumGen:Gene:Gene'].str.contains('Compound:Disease')) | (df['bioarx::HumGenHumGen:Gene:Gene'].str.contains('Disease:Compound')) | (df['bioarx::HumGenHumGen:Gene:Gene'].str.contains('Compound:Gene')) | (df['bioarx::HumGenHumGen:Gene:Gene'].str.contains('Gene:Compound')) | (df['bioarx::HumGenHumGen:Gene:Gene'].str.contains('Disease:Gene')) | (df['bioarx::HumGenHumGen:Gene:Gene'].str.contains('Gene:Disease'))]

In [343]:
temp_graph

Unnamed: 0,Gene::2157,bioarx::HumGenHumGen:Gene:Gene,Gene::2157.1
58628,Compound::DB02573,bioarx::DrugVirGen:Compound:Gene,Gene::NVA376
58629,Compound::DB05105,bioarx::DrugVirGen:Compound:Gene,Gene::NVA193
58630,Compound::DB05105,bioarx::DrugVirGen:Compound:Gene,Gene::NVA345
58631,Compound::DB00244,bioarx::DrugVirGen:Compound:Gene,Gene::NVA298
58632,Compound::DB00684,bioarx::DrugVirGen:Compound:Gene,Gene::NVA175
...,...,...,...
4123201,Compound::DB00619,INTACT::PHYSICAL ASSOCIATION::Compound:Gene,Gene::780
4123202,Compound::DB00619,INTACT::PHYSICAL ASSOCIATION::Compound:Gene,Gene::84959
4123203,Compound::CHEMBL9506,INTACT::PHYSICAL ASSOCIATION::Compound:Gene,Gene::886
4123204,Compound::DB01037,INTACT::PHYSICAL ASSOCIATION::Compound:Gene,Gene::4129


In [344]:
combination_graph = pd.DataFrame({'source_node':[],'destination_node':[]})
combination_graph['source_node'] = temp_graph['Gene::2157']
combination_graph['destination_node'] = temp_graph['Gene::2157.1']

In [345]:
combination_graph

Unnamed: 0,source_node,destination_node
58628,Compound::DB02573,Gene::NVA376
58629,Compound::DB05105,Gene::NVA193
58630,Compound::DB05105,Gene::NVA345
58631,Compound::DB00244,Gene::NVA298
58632,Compound::DB00684,Gene::NVA175
...,...,...
4123201,Compound::DB00619,Gene::780
4123202,Compound::DB00619,Gene::84959
4123203,Compound::CHEMBL9506,Gene::886
4123204,Compound::DB01037,Gene::4129


In [346]:
def load_node_mapping(df, source_node, destination_node, category, offset=0):
    seta = list(df[(df[source_node].str.contains(category))][source_node].unique())
    setb = list(df[(df[destination_node].str.contains(category))][destination_node].unique())
    nodes = list(set(seta+setb))
    mapping = {index_id: i + offset for i, index_id in enumerate(nodes)}
    return mapping

def load_edge_list(df, src_col, dz_mapping, dst_col, compound_mapping,gene_mapping):

    src_nodes = []
    for i in df[src_col]:
        if 'Compound' in i:
            src_nodes.append(compound_mapping[i])
        elif 'Gene' in i:
            src_nodes.append(gene_mapping[i])
        elif 'Disease' in i:
            src_nodes.append(dz_mapping[i])
            
    dst_nodes = []
    for i in df[dst_col]:
        if 'Compound' in i:
            dst_nodes.append(compound_mapping[i])
        elif 'Gene' in i:
            dst_nodes.append(gene_mapping[i])
        elif 'Disease' in i:
            dst_nodes.append(dz_mapping[i])
            
    edge_index = torch.tensor([src_nodes, dst_nodes])
    return edge_index

def initialize_data(df, num_features=1):
    source_node, destination_node = "source_node", "destination_node"
    dz_mapping = load_node_mapping(df, source_node, destination_node, 'Disease', offset=0)
    compound_mapping = load_node_mapping(df, source_node, destination_node,'Compound', offset=5070)
    gene_mapping = load_node_mapping(df, source_node, destination_node,'Gene', offset=5070+23311)

  # Get edge index in terms of the integer indeces assigned to the nodes.
    edge_index = load_edge_list(df, source_node, dz_mapping, destination_node, compound_mapping,gene_mapping)

  # Add the reverse direction (aka make it a undirected graph)
    rev_edge_index = load_edge_list(df, destination_node, dz_mapping, source_node,  compound_mapping,gene_mapping)


  # Construct a Data object.
    data = Data()
    data.num_nodes = len(dz_mapping) + len(compound_mapping) + len(gene_mapping)
    data.edge_index = torch.cat((edge_index, rev_edge_index), dim=1)
    data.x = torch.ones((data.num_nodes, num_features))

    return data, compound_mapping, dz_mapping, gene_mapping

In [347]:
data_object, compound_mapping, dz_mapping, gene_mapping = initialize_data(combination_graph)
print(data_object)
print("Number of Compounds:", len(compound_mapping))
print("Number of Diseases:", len(dz_mapping))
print("Number of Genes:", len(gene_mapping))

Data(num_nodes=59743, edge_index=[2, 837052], x=[59743, 1])
Number of Compounds: 23311
Number of Diseases: 5070
Number of Genes: 31362


In [348]:
len(compound_mapping),len(dz_mapping),len(gene_mapping)

(23311, 5070, 31362)

In [349]:
print(data_object.edge_index)

tensor([[27313, 13758, 13758,  ..., 57332, 56767, 42142],
        [49131, 52019, 38824,  ..., 17187, 26513, 24099]])


In [350]:
reverse_dz_mapping = {j: i for i,j in dz_mapping.items()}
reverse_gene_mapping = {j: i for i,j in gene_mapping.items()}
reverse_compound_mapping = {j: i for i,j in compound_mapping.items()}

degrees = degree(data_object.edge_index[0]).numpy()
sorted_degrees_i = np.argsort(-1* degrees)

print("Disease Nodes of highest degree")
top_k = 10
num_dz_nodes = 0
for i in sorted_degrees_i:
    if i < 5070:   
        node_id = reverse_dz_mapping[i]
        node_degree = degrees[i]
        print("node_index=" + str(i), "node_degree=" + str(int(node_degree)),node_id)
        num_dz_nodes += 1
        if num_dz_nodes >= top_k:
            break

print("\nGene Nodes of highest degree")
num_gene_nodes = 0
for i in sorted_degrees_i:
    if i >= 5070+23311:
        node_id = reverse_gene_mapping[i]
        node_degree = degrees[i]
        print("node_index=" + str(i), "node_degree=" + str(int(node_degree)), node_id)
        num_gene_nodes += 1
        if num_gene_nodes >= top_k:
              break
    
print("\nCompound Nodes of highest degree")
num_gene_nodes = 0
for i in sorted_degrees_i:
    if 5070+23311>i >= 5070:
        node_id = reverse_compound_mapping[i]
        node_degree = degrees[i]
        print("node_index=" + str(i), "node_degree=" + str(int(node_degree)), node_id)
        num_gene_nodes += 1
        if num_gene_nodes >= top_k:
            break

Disease Nodes of highest degree
node_index=4803 node_degree=8156 Disease::MESH:D009369
node_index=3367 node_degree=3428 Disease::MESH:D001943
node_index=1643 node_degree=2679 Disease::MESH:D006528
node_index=2086 node_degree=2677 Disease::MESH:D015179
node_index=4237 node_degree=2551 Disease::MESH:D011471
node_index=334 node_degree=2463 Disease::MESH:D008545
node_index=3228 node_degree=2400 Disease::MESH:D013274
node_index=1280 node_degree=2139 Disease::MESH:D064420
node_index=1316 node_degree=2034 Disease::MESH:D010190
node_index=1301 node_degree=2007 Disease::MESH:D010051

Gene Nodes of highest degree
node_index=55672 node_degree=2483 Gene::1576
node_index=46152 node_degree=1213 Gene::1565
node_index=58050 node_degree=1166 Gene::5243
node_index=38535 node_degree=1126 Gene::3630
node_index=55287 node_degree=1093 Gene::7124
node_index=48561 node_degree=1041 Gene::1559
node_index=34537 node_degree=987 Gene::213
node_index=42056 node_degree=986 Gene::1544
node_index=52031 node_degree=910

In [351]:
NUM_FEATURES =   20
data_object.x = torch.ones((data_object.num_nodes, NUM_FEATURES))
print("Using dummy embeddings as initial node features.")
print("Number of features set to ", NUM_FEATURES)

Using dummy embeddings as initial node features.
Number of features set to  20


In [352]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        # First layer with ReLU activation
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Second layer (output layer)
        x = self.conv2(x, edge_index)
        return x

In [387]:
# Model parameters
NUM_FEATURES=20
in_channels = NUM_FEATURES  # Number of input features per node
hidden_channels = 64  # Size of the hidden layer
out_channels = 32  # Size of the output layer

# Initialize the model and move it to GPU if available
model = GCN(in_channels, hidden_channels, out_channels)
#data_object = data_object.to(device)

# Define the optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
criterion = torch.nn.BCEWithLogitsLoss()  # Suitable for binary classification (link prediction)

In [388]:
# Apply the RandomLinkSplit transformation to create train, validation, and test splits
from torch_geometric.transforms import RandomLinkSplit

transform = RandomLinkSplit(num_val=0.1, num_test=0.1, is_undirected=True, add_negative_train_samples=True)
train_data, val_data, test_data = transform(data_object)


# If pos_edge_label_index and neg_edge_label_index are missing, we may need to add them manually
def add_pos_neg_edges(data):
    # Positive edges
    data.pos_edge_label_index = data.edge_index
    
    # Generate negative samples
    data.neg_edge_label_index = negative_sampling(
        edge_index=data.edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=data.edge_index.size(1),
        method='sparse'
    )
    return data

# Apply this function to add attributes if they are missing
if not hasattr(train_data, 'pos_edge_label_index'):
    train_data = add_pos_neg_edges(train_data)
if not hasattr(val_data, 'pos_edge_label_index'):
    val_data = add_pos_neg_edges(val_data)
if not hasattr(test_data, 'pos_edge_label_index'):
    test_data = add_pos_neg_edges(test_data)

In [389]:
def train():
    model.train()
    optimizer.zero_grad()
    
    # Forward pass
    z = model(train_data)
    
    # Positive and negative edges for training
    pos_edge_index = train_data.pos_edge_label_index
    neg_edge_index = train_data.neg_edge_label_index
    
    # Calculate link prediction loss
    pos_out = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=1)
    neg_out = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=1)
    
    # Create labels for positive and negative edges
    pos_label = torch.ones(pos_out.size(0))
    neg_label = torch.zeros(neg_out.size(0))
    
    # Concatenate outputs and labels
    out = torch.cat([pos_out, neg_out], dim=0)
    label = torch.cat([pos_label, neg_label], dim=0)
    
    # Compute loss
    loss = criterion(out, label)
    loss.backward()
    optimizer.step()
    
    return loss.item()

In [390]:
@torch.no_grad()
def test(data):
    model.eval()
    z = model(data)
    
    pos_edge_index = data.pos_edge_label_index
    neg_edge_index = data.neg_edge_label_index
    
    pos_out = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=1)
    neg_out = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=1)
    
    # Labels
    pos_label = torch.ones(pos_out.size(0))
    neg_label = torch.zeros(neg_out.size(0))
    
    out = torch.cat([pos_out, neg_out], dim=0)
    label = torch.cat([pos_label, neg_label], dim=0)
    
    # Calculate ROC AUC score
    auc = metrics.roc_auc_score(label.cpu().numpy(), out.cpu().numpy())
    return auc

In [391]:
epochs = 50
for epoch in range(1, epochs + 1):
    loss = train()
    if epoch % 10 == 0:
        val_auc = test(val_data)
        print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Validation AUC: {val_auc:.4f}")

Epoch: 010, Loss: 0.8859, Validation AUC: 0.9624
Epoch: 020, Loss: 0.5367, Validation AUC: 0.9618
Epoch: 030, Loss: 0.5317, Validation AUC: 0.9623
Epoch: 040, Loss: 0.5287, Validation AUC: 0.9630
Epoch: 050, Loss: 0.5276, Validation AUC: 0.9630


In [392]:
test_auc = test(test_data)
print(f"Test AUC: {test_auc:.4f}")

Test AUC: 0.9672


In [393]:
@torch.no_grad()
def predict_for_target_gene(target_gene_id, gene_mapping, compound_mapping, data):
    model.eval()
    # Generate node embeddings (z) using the trained model
    z = model(data)
    
    # Get the embedding of the target gene
    if target_gene_id in gene_mapping:
        gene_index = gene_mapping[target_gene_id]  # Map gene name to node index
    else:
        print("Target gene not found in mapping.")
        return
    
    target_gene_embedding = z[gene_index].unsqueeze(0)  # Get the embedding for the target gene
    
    # Compare target gene embedding with all compound embeddings
    compound_indices = torch.tensor(list(compound_mapping.values()))
    compound_embeddings = z[compound_indices]
    
    # Compute dot product (similarity) between target gene and all compounds
    similarities = torch.matmul(compound_embeddings, target_gene_embedding.T).squeeze()
    
    # Rank the compounds by similarity (higher similarity means stronger potential link)
    top_k = 10  # Choose how many top predictions to return
    top_k_indices = torch.topk(similarities, top_k).indices
    
    # Map the top compound indices back to compound names
    reverse_compound_mapping = {v: k for k, v in compound_mapping.items()}
    
    print(f"Top {top_k} predicted drugs for target gene '{target_gene_id}':")
    for idx in top_k_indices:
        compound_node_id = compound_indices[idx].item()
        compound_name = reverse_compound_mapping[compound_node_id]
        similarity_score = similarities[idx].item()
        print(f"Compound: {compound_name}, Similarity score: {similarity_score:.4f}")

In [402]:
target_gene_id = 'Gene::3630' 
predict_for_target_gene(target_gene_id, gene_mapping, compound_mapping, test_data)

Top 10 predicted drugs for target gene 'Gene::3630':
Compound: Compound::DB09341, Similarity score: 25.3869
Compound: Compound::DB01373, Similarity score: 24.2335
Compound: Compound::CHEBI:33704, Similarity score: 23.9705
Compound: Compound::CHEBI:18186, Similarity score: 23.3798
Compound: Compound::DB00898, Similarity score: 21.0437
Compound: Compound::DB01593, Similarity score: 20.8333
Compound: Compound::DB01234, Similarity score: 19.9081
Compound: Compound::DB09140, Similarity score: 18.6631
Compound: Compound::DB04540, Similarity score: 18.5006
Compound: Compound::MESH:D004967, Similarity score: 18.0445


In [399]:
def predict_for_target_disease(target_disease_id, dz_mapping, compound_mapping, model, data, k=10):
    # Get the embedding of the target disease
    if target_disease_id in dz_mapping:
        disease_index = dz_mapping[target_disease_id]  # Map disease name to node index
    else:
        print("Target disease not found in mapping.")
        return
    
    disease_embedding = model(data)[disease_index].unsqueeze(0)  # Get the embedding for the target disease
    
    # Compare target disease embedding with all compound embeddings
    compound_indices = torch.tensor(list(compound_mapping.values()), device=data.device)
    compound_embeddings = model(data)[compound_indices]

    # Compute dot product (similarity) between target disease and all compounds
    similarities = torch.matmul(compound_embeddings, disease_embedding.T).squeeze()

    # Rank the compounds by similarity (higher similarity means stronger potential link)
    top_k_indices = torch.topk(similarities, k).indices

    # Map the top compound indices back to compound names
    reverse_compound_mapping = {v: k for k, v in compound_mapping.items()}

    # Display the top-k compounds with their similarity scores
    print(f"Top {k} predicted drugs for target disease '{target_disease_id}':")
    for idx in top_k_indices:
        compound_node_id = compound_indices[idx].item()
        compound_name = reverse_compound_mapping[compound_node_id]
        similarity_score = similarities[idx].item()
        print(f"Compound: {compound_name}, Similarity score: {similarity_score:.4f}")


In [404]:
target_disease_id = 'Disease::MESH:D008545'  
predict_for_target_gene(target_disease_id, dz_mapping, compound_mapping, test_data)

Top 10 predicted drugs for target gene 'Disease::MESH:D008545':
Compound: Compound::DB09341, Similarity score: 32.5077
Compound: Compound::DB01373, Similarity score: 30.9982
Compound: Compound::CHEBI:33704, Similarity score: 30.6467
Compound: Compound::CHEBI:18186, Similarity score: 29.9532
Compound: Compound::DB00898, Similarity score: 26.8225
Compound: Compound::DB01593, Similarity score: 26.6728
Compound: Compound::DB01234, Similarity score: 25.4039
Compound: Compound::DB09140, Similarity score: 23.8226
Compound: Compound::DB04540, Similarity score: 23.6440
Compound: Compound::MESH:D004967, Similarity score: 23.0577
