In [1]:
# -*- coding: utf-8 -*-
"""Drug Recommendation with GraphSAGE on Hetionet"""

!pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html
!pip install pykeen torchvision

Looking in links: https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html
Collecting dgl
  Downloading https://data.dgl.ai/wheels/torch-2.4/cu124/dgl-2.4.0%2Bcu124-cp310-cp310-manylinux1_x86_64.whl (347.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m347.8/347.8 MB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
Collecting torch<=2.4.0 (from dgl)
  Downloading torch-2.4.0-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch<=2.4.0->dgl)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch<=2.4.0->dgl)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch<=2.4.0->dgl)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn import SAGEConv, HeteroGraphConv
from pykeen.datasets import Hetionet
from tqdm import tqdm

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#%% [Step 1] Load and Prepare Data
dataset = Hetionet()
training_triples = dataset.training.mapped_triples.to(device)

# Create reverse mappings
id_to_entity = {v: k for k, v in dataset.entity_to_id.items()}
id_to_relation = {v: k for k, v in dataset.relation_to_id.items()}

#%% [Step 2] Convert to DGL Heterogeneous Graph
def create_hetero_graph(triples):
    data_dict = {}
    node_counts = {}
    node_mappings = {}

    # First pass: Create node mappings
    for h, _, t in triples.cpu().numpy():
        h_ent = id_to_entity[h].split('::')[0]
        t_ent = id_to_entity[t].split('::')[0]

        for ent, ntype in [(h, h_ent), (t, t_ent)]:
            if ntype not in node_mappings:
                node_mappings[ntype] = {}
                node_counts[ntype] = 0

            if ent not in node_mappings[ntype]:
                node_mappings[ntype][ent] = node_counts[ntype]
                node_counts[ntype] += 1

    # Second pass: Build edges
    for h, r, t in triples.cpu().numpy():
        h_ent = id_to_entity[h].split('::')[0]
        t_ent = id_to_entity[t].split('::')[0]
        rel = id_to_relation[r]

        src = node_mappings[h_ent][h]
        dst = node_mappings[t_ent][t]
        etype = (h_ent, rel, t_ent)

        if etype not in data_dict:
            data_dict[etype] = ([], [])

        data_dict[etype][0].append(src)
        data_dict[etype][1].append(dst)

    # Convert to tensors
    for etype in data_dict:
        data_dict[etype] = (
            torch.tensor(data_dict[etype][0]),
            torch.tensor(data_dict[etype][1])
                        )
    return dgl.heterograph(data_dict), node_mappings

hg, node_mappings = create_hetero_graph(training_triples)
hg = hg.to(device)
print("Heterogeneous Graph Created:")
print(hg)

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


Downloading hetionet-v1.0-edges.sif.gz: 0.00B [00:00, ?B/s]

Heterogeneous Graph Created:
Graph(num_nodes={'Anatomy': 400, 'Biological Process': 11381, 'Cellular Component': 1391, 'Compound': 1538, 'Disease': 136, 'Gene': 19145, 'Molecular Function': 2884, 'Pathway': 1822, 'Pharmacologic Class': 345, 'Side Effect': 5701, 'Symptom': 415},
      num_edges={('Anatomy', 'AdG', 'Gene'): 81434, ('Anatomy', 'AeG', 'Gene'): 419428, ('Anatomy', 'AuG', 'Gene'): 77903, ('Compound', 'CbG', 'Gene'): 9476, ('Compound', 'CcSE', 'Side Effect'): 111871, ('Compound', 'CdG', 'Gene'): 16797, ('Compound', 'CpD', 'Disease'): 316, ('Compound', 'CrC', 'Compound'): 5124, ('Compound', 'CtD', 'Disease'): 599, ('Compound', 'CuG', 'Gene'): 14879, ('Disease', 'DaG', 'Gene'): 10127, ('Disease', 'DdG', 'Gene'): 6082, ('Disease', 'DlA', 'Anatomy'): 2906, ('Disease', 'DpS', 'Symptom'): 2706, ('Disease', 'DrD', 'Disease'): 423, ('Disease', 'DuG', 'Gene'): 6212, ('Gene', 'GcG', 'Gene'): 50979, ('Gene', 'GiG', 'Gene'): 118198, ('Gene', 'GpBP', 'Biological Process'): 448012, ('Gene'

In [3]:
#%% [Step 3] GraphSAGE Model Definition
class HeteroSAGE(nn.Module):
    def __init__(self, hg, in_feats, hidden_size, out_feats):
        super().__init__()
        self.embed = nn.ModuleDict({
            ntype: nn.Embedding(hg.num_nodes(ntype), in_feats)
            for ntype in hg.ntypes
        })

        self.conv1 = HeteroGraphConv({
            etype: SAGEConv(in_feats, hidden_size, 'mean')
            for etype in hg.etypes
        })

        self.conv2 = HeteroGraphConv({
            etype: SAGEConv(hidden_size, out_feats, 'mean')
            for etype in hg.etypes
        })

    def forward(self, g):
        h = {ntype: self.embed[ntype](torch.arange(g.num_nodes(ntype)).to(device))
             for ntype in g.ntypes}
        h = self.conv1(g, h)
        h = {k: F.leaky_relu(v) for k, v in h.items()}
        h = self.conv2(g, h)
        return h

In [4]:
ctd_triples = training_triples[training_triples[:, 1] == dataset.relation_to_id['CtD']]
pos_drugs_orig = ctd_triples[:, 0]  # Compound IDs (head)
pos_diseases_orig = ctd_triples[:, 2]  # Disease IDs (tail)
# Create mapping tensors
compound_ids = list(node_mappings['Compound'].keys())
disease_ids = list(node_mappings['Disease'].keys())

# Create lookup dictionaries
compound_orig2dgl = {orig: dgl_id for orig, dgl_id in node_mappings['Compound'].items()}
disease_orig2dgl = {orig: dgl_id for orig, dgl_id in node_mappings['Disease'].items()}

# Convert original IDs to DGL IDs
pos_drugs_dgl = torch.tensor(
    [compound_orig2dgl[oid.item()] for oid in pos_drugs_orig],
    device=device
)
pos_diseases_dgl = torch.tensor(
    [disease_orig2dgl[oid.item()] for oid in pos_diseases_orig],
    device=device
)

In [6]:
def train(epochs=1000, neg_samples=5):
    model.train()
    for epoch in range(epochs):
        # Positive samples (using DGL-mapped IDs)
        pos_diseases = pos_diseases_dgl
        pos_drugs = pos_drugs_dgl
        
        # Negative sampling (using correct node type)
        neg_drugs = torch.randint(
            0, hg.num_nodes('Compound'),
            (len(pos_diseases) * neg_samples,)
        ).to(device)

        # Get embeddings
        embeddings = model(hg)

        # Calculate scores
        pos_scores = (
            embeddings['Disease'][pos_diseases] * 
            embeddings['Compound'][pos_drugs]
        ).sum(dim=1)
        
        neg_scores = (
            embeddings['Disease'][pos_diseases.repeat_interleave(neg_samples)] * 
            embeddings['Compound'][neg_drugs]
        ).sum(dim=1)

        # Loss calculation and backprop
        loss = F.margin_ranking_loss(
            pos_scores.repeat_interleave(neg_samples),
            neg_scores,
            torch.ones_like(neg_scores),
            margin=1.0
        )
        if epoch % 50 == 0:
            print(epoch, loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
model = HeteroSAGE(hg, 128, 256, 128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
train()

0 67.68355560302734
50 3.6543614864349365
100 1.2399370670318604
150 0.687647819519043
200 0.3975412845611572
250 0.19452595710754395
300 0.3570535182952881
350 0.1712265908718109
400 0.06970909982919693
450 0.0900368019938469
500 0.10378123819828033
550 0.07919901609420776
600 0.07437802106142044
650 0.06348148733377457
700 0.06484603136777878
750 0.05753597244620323
800 0.030201973393559456
850 0.047552887350320816
900 0.0701623260974884
950 0.057983413338661194


In [7]:
#%% [Step 6] Corrected Side Effect Filtering
def calculate_side_effect_risk():
    # Get CcSE (Compound causes Side Effect) triples
    ccse_triples = training_triples[training_triples[:, 1] == dataset.relation_to_id['CcSE']]
    
    # Create mapping from original compound IDs to DGL node IDs
    compound_orig2dgl = {orig: dgl_id for orig, dgl_id in node_mappings['Compound'].items()}
    
    # Initialize counts on correct device
    drug_se_counts = torch.zeros(hg.num_nodes('Compound'), device=device)
    
    # Convert original compound IDs to DGL IDs
    original_drug_ids = ccse_triples[:, 0].unique()
    for orig_id in original_drug_ids:
        dgl_id = compound_orig2dgl.get(orig_id.item(), None)
        if dgl_id is not None:
            count = (ccse_triples[:, 0] == orig_id).sum().item()
            drug_se_counts[dgl_id] = count
    
    # Normalize risk scores
    max_se = drug_se_counts.max()
    risk_scores = drug_se_counts / max_se if max_se > 0 else drug_se_counts
    return risk_scores

risk_scores = calculate_side_effect_risk()
safe_drugs = torch.where(risk_scores <= 0.15)[0].tolist()
risk_scores

tensor([0.3499, 0.0000, 0.1756,  ..., 0.0000, 0.0000, 0.0000], device='cuda:0')

In [8]:
#%% [Step 7] Recommendation Function
def recommend_drugs(disease_name, top_k=10):
    # Convert disease name to ID
    disease_id = next((mapped_id for orig_id, mapped_id in node_mappings['Disease'].items() 
                       if id_to_entity[orig_id] == disease_name), None)

    if disease_id is None:
        print(f"Disease {disease_name} not found!")
        return [], []

    model.eval()
    with torch.no_grad():
        embeddings = model(hg)

    disease_emb = embeddings['Disease'][disease_id]
    drug_embs = embeddings['Compound']
    scores = torch.mm(disease_emb.unsqueeze(0), drug_embs.T).squeeze(0)

    valid_scores = scores[safe_drugs]
    sorted_indices = torch.argsort(valid_scores, descending=True)

    recommendations = []
    top_scores = []

    for idx in sorted_indices[:top_k]:
        try:
            orig_drug_id = list(node_mappings['Compound'].keys())[safe_drugs[idx]]
            recommendations.append(id_to_entity[orig_drug_id])
            top_scores.append(valid_scores[idx].item())
        except Exception as e:
            print(f"Error at index {idx}: {e}")
            break

    return recommendations, top_scores

In [11]:
recommendations = recommend_drugs("Disease::DOID:10283")
print("\nTop Recommended Drugs:")
for i, drug in enumerate(recommendations, 1):
    print(i, drug)


Top Recommended Drugs:
1 ['Compound::DB00665', 'Compound::DB06699', 'Compound::DB08866', 'Compound::DB00655', 'Compound::DB01196', 'Compound::DB00126', 'Compound::DB00253', 'Compound::DB00499', 'Compound::DB01223', 'Compound::DB00509']
2 [204.1898193359375, 203.85418701171875, 195.43780517578125, 189.95806884765625, 178.70669555664062, 178.59559631347656, 174.3504180908203, 172.75682067871094, 169.87081909179688, 168.62030029296875]


In [36]:
def evaluate_hits_at_k(k=3):
    model.eval()
    test_triples = dataset.testing.mapped_triples.to(device)
    ctd_triples = test_triples[test_triples[:, 1] == dataset.relation_to_id['CtD']]
    
    ground_truth = {}
    for drug, _, disease in ctd_triples.cpu().numpy():
        disease_name = id_to_entity[disease]
        drug_name = id_to_entity[drug]

        if disease_name not in ground_truth:
            ground_truth[disease_name] = set()
        ground_truth[disease_name].add(drug_name)

    total_cases = 0
    hits = 0
    drugs = []
    for disease_name in ground_truth:
        recommended_drugs, scores = recommend_drugs(disease_name, top_k=k)
        
        if not recommended_drugs:
            print("none")
            continue  

        true_drugs = ground_truth[disease_name]
        total_cases += 1

        if any(drug in true_drugs for drug in recommended_drugs):
            hits += 1
        drugs.append([disease_name, recommended_drugs, scores])
    hits_at_k = hits / total_cases if total_cases > 0 else 0
    print(f"Hits@{k}: {hits_at_k:.4f}")
    return drugs
    
a = evaluate_hits_at_k(3)

Hits@3: 0.7237


In [33]:
d = []
for i in a:
    l = {}
    l["type"] = "alternate_drug_global"
    l["disease_id"] = i[0].split(":")[-1]
    l["criteria"] = {
      "avoid_side_effects": True,
      "side_effect_threshold": 0.3970146364075608
    }
    l['candidates'] = [j[-4:] for j in i[1]]
    l["scores"] = i[2]
    
    d.append(l)

In [34]:
import json
with open("output.json", "w") as f:
    json.dump(d, f)