In [None]:
# -*- coding: utf-8 -*-
"""Drug Recommendation with GraphSAGE on Hetionet"""

!pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html
!pip install pykeen torchvision

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn import SAGEConv, HeteroGraphConv
from pykeen.datasets import Hetionet
from tqdm import tqdm

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#%% [Step 1] Load and Prepare Data
dataset = Hetionet()
training_triples = dataset.training.mapped_triples.to(device)

# Create reverse mappings
id_to_entity = {v: k for k, v in dataset.entity_to_id.items()}
id_to_relation = {v: k for k, v in dataset.relation_to_id.items()}

#%% [Step 2] Convert to DGL Heterogeneous Graph
def create_hetero_graph(triples):
    data_dict = {}
    node_counts = {}
    node_mappings = {}

    # First pass: Create node mappings
    for h, _, t in triples.cpu().numpy():
        h_ent = id_to_entity[h].split('::')[0]
        t_ent = id_to_entity[t].split('::')[0]

        for ent, ntype in [(h, h_ent), (t, t_ent)]:
            if ntype not in node_mappings:
                node_mappings[ntype] = {}
                node_counts[ntype] = 0

            if ent not in node_mappings[ntype]:
                node_mappings[ntype][ent] = node_counts[ntype]
                node_counts[ntype] += 1

    # Second pass: Build edges
    for h, r, t in triples.cpu().numpy():
        h_ent = id_to_entity[h].split('::')[0]
        t_ent = id_to_entity[t].split('::')[0]
        rel = id_to_relation[r]

        src = node_mappings[h_ent][h]
        dst = node_mappings[t_ent][t]
        etype = (h_ent, rel, t_ent)

        if etype not in data_dict:
            data_dict[etype] = ([], [])

        data_dict[etype][0].append(src)
        data_dict[etype][1].append(dst)

    # Convert to tensors
    for etype in data_dict:
        data_dict[etype] = (
            torch.tensor(data_dict[etype][0]),
            torch.tensor(data_dict[etype][1])
                        )
    return dgl.heterograph(data_dict), node_mappings

hg, node_mappings = create_hetero_graph(training_triples)
hg = hg.to(device)
print("Heterogeneous Graph Created:")
print(hg)

In [None]:
#%% [Step 3] GraphSAGE Model Definition
class HeteroSAGE(nn.Module):
    def __init__(self, hg, in_feats, hidden_size, out_feats):
        super().__init__()
        self.embed = nn.ModuleDict({
            ntype: nn.Embedding(hg.num_nodes(ntype), in_feats)
            for ntype in hg.ntypes
        })

        self.conv1 = HeteroGraphConv({
            etype: SAGEConv(in_feats, hidden_size, 'mean')
            for etype in hg.etypes
        })

        self.conv2 = HeteroGraphConv({
            etype: SAGEConv(hidden_size, out_feats, 'mean')
            for etype in hg.etypes
        })

    def forward(self, g):
        h = {ntype: self.embed[ntype](torch.arange(g.num_nodes(ntype)).to(device))
             for ntype in g.ntypes}
        h = self.conv1(g, h)
        h = {k: F.leaky_relu(v) for k, v in h.items()}
        h = self.conv2(g, h)
        return h

In [None]:
ctd_triples = training_triples[training_triples[:, 1] == dataset.relation_to_id['CtD']]
pos_drugs_orig = ctd_triples[:, 0]  # Compound IDs (head)
pos_diseases_orig = ctd_triples[:, 2]  # Disease IDs (tail)
# Create mapping tensors
compound_ids = list(node_mappings['Compound'].keys())
disease_ids = list(node_mappings['Disease'].keys())

# Create lookup dictionaries
compound_orig2dgl = {orig: dgl_id for orig, dgl_id in node_mappings['Compound'].items()}
disease_orig2dgl = {orig: dgl_id for orig, dgl_id in node_mappings['Disease'].items()}

# Convert original IDs to DGL IDs
pos_drugs_dgl = torch.tensor(
    [compound_orig2dgl[oid.item()] for oid in pos_drugs_orig],
    device=device
)
pos_diseases_dgl = torch.tensor(
    [disease_orig2dgl[oid.item()] for oid in pos_diseases_orig],
    device=device
)

In [None]:
def train(epochs=1000, neg_samples=5):
    model.train()
    for epoch in range(epochs):
        # Positive samples (using DGL-mapped IDs)
        pos_diseases = pos_diseases_dgl
        pos_drugs = pos_drugs_dgl
        
        # Negative sampling (using correct node type)
        neg_drugs = torch.randint(
            0, hg.num_nodes('Compound'),
            (len(pos_diseases) * neg_samples,)
        ).to(device)

        # Get embeddings
        embeddings = model(hg)

        # Calculate scores
        pos_scores = (
            embeddings['Disease'][pos_diseases] * 
            embeddings['Compound'][pos_drugs]
        ).sum(dim=1)
        
        neg_scores = (
            embeddings['Disease'][pos_diseases.repeat_interleave(neg_samples)] * 
            embeddings['Compound'][neg_drugs]
        ).sum(dim=1)

        # Loss calculation and backprop
        loss = F.margin_ranking_loss(
            pos_scores.repeat_interleave(neg_samples),
            neg_scores,
            torch.ones_like(neg_scores),
            margin=1.0
        )
        print(epoch, loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
model = HeteroSAGE(hg, 128, 256, 128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
train()

In [42]:
#%% [Step 6] Corrected Side Effect Filtering
def calculate_side_effect_risk():
    # Get CcSE (Compound causes Side Effect) triples
    ccse_triples = training_triples[training_triples[:, 1] == dataset.relation_to_id['CcSE']]
    
    # Create mapping from original compound IDs to DGL node IDs
    compound_orig2dgl = {orig: dgl_id for orig, dgl_id in node_mappings['Compound'].items()}
    
    # Initialize counts on correct device
    drug_se_counts = torch.zeros(hg.num_nodes('Compound'), device=device)
    
    # Convert original compound IDs to DGL IDs
    original_drug_ids = ccse_triples[:, 0].unique()
    for orig_id in original_drug_ids:
        dgl_id = compound_orig2dgl.get(orig_id.item(), None)
        if dgl_id is not None:
            count = (ccse_triples[:, 0] == orig_id).sum().item()
            drug_se_counts[dgl_id] = count
    
    # Normalize risk scores
    max_se = drug_se_counts.max()
    risk_scores = drug_se_counts / max_se if max_se > 0 else drug_se_counts
    return risk_scores

risk_scores = calculate_side_effect_risk()
safe_drugs = torch.where(risk_scores <= 1)[0].tolist()
risk_scores

tensor([0.3499, 0.0000, 0.1756,  ..., 0.0000, 0.0000, 0.0000], device='cuda:0')

In [43]:
#%% [Step 7] Recommendation Function
def recommend_drugs(disease_name, top_k=10):
    # Convert disease name to ID
    disease_key = disease_name
    disease_id = None

    # Find mapped ID
    for orig_id, mapped_id in node_mappings['Disease'].items():
        if id_to_entity[orig_id] == disease_key:
            disease_id = mapped_id
            break

    if disease_id is None:
        print(f"Disease {disease_name} not found!")
        return []

    # Get embeddings
    model.eval()
    with torch.no_grad():
        embeddings = model(hg)

    # Calculate similarity scores
    disease_emb = embeddings['Disease'][disease_id]
    drug_embs = embeddings['Compound']
    scores = torch.mm(disease_emb.unsqueeze(0), drug_embs.T).squeeze(0)

    # Filter and sort
    # valid_scores = scores[safe_drugs]
    sorted_indices = torch.argsort(scores, descending=True)

    # Map back to original drug IDs
    recommendations = []
    for idx in sorted_indices[:top_k]:
        try:
            orig_drug_id = list(node_mappings['Compound'].keys())[safe_drugs[idx]]
        except:

            print(idx)
            break
        recommendations.append(id_to_entity[orig_drug_id])

    return recommendations

In [41]:
safe_drugs[6]

10

In [30]:

recommendations = recommend_drugs("DOID:10283")
print("\nTop Recommended Drugs:")
for i, drug in enumerate(recommendations, 1):
    print(f"{i}. {drug.split('::')[1]}")

Disease DOID:10283 not found!

Top Recommended Drugs:


In [57]:
def evaluate_hits_at_k(k=3):
    model.eval()
    test_triples = dataset.training.mapped_triples.to(device)
    ctd_triples = test_triples[test_triples[:, 1] == dataset.relation_to_id['CtD']]
    
    ground_truth = {}
    for drug, _, disease in ctd_triples.cpu().numpy():
        disease_name = id_to_entity[disease]
        drug_name = id_to_entity[drug]

        if disease_name not in ground_truth:
            ground_truth[disease_name] = set()
        ground_truth[disease_name].add(drug_name)

    # Compute Hits@K
    total_cases = 0
    hits = 0

    for disease_name in ground_truth:
        recommended_drugs = recommend_drugs(disease_name, top_k=k)
        
        if not recommended_drugs:
            print("none")
            continue  # Skip if no recommendations found

        true_drugs = ground_truth[disease_name]
        total_cases += 1

        # Check if at least one ground truth drug is in the top-k recommendations
        if any(drug in true_drugs for drug in recommended_drugs):
            hits += 1

    # Calculate Hits@K metric
    hits_at_k = hits / total_cases if total_cases > 0 else 0
    print(f"Hits@{k}: {hits_at_k:.4f} ({hits}/{total_cases})")

# Evaluate Hits@3
evaluate_hits_at_k(3)

Hits@3: 0.0263 (1/38)
