# Drug Repurposing Pathfinding Algorithm Benchmark

**Purpose:** Evaluate graph pathfinding algorithms for drug repurposing by comparing predicted mechanistic pathways against curated ground truth pathways.



---

## Setup


In [38]:
# Standard libraries
import pandas as pd
import numpy as np
import networkx as nx
import time
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
import heapq
from typing import Dict, List, Tuple
from collections import deque, Counter
from evaluation_helpers import *
from evaluation_metrics import *
from Algorithms import *

In [28]:


import os

# === CONFIGURE PATHS ===
# Project root directory. Set the PRIMEKG_BENCHMARK_DIR environment variable to
# point at your own checkout instead of editing the notebook; the fallback is
# the original author's local path.
DATA_DIR = os.environ.get(
    'PRIMEKG_BENCHMARK_DIR',
    '/Users/maxchiu/Desktop/Primkg_github/PrimeKG-Pathfinding-Algorithm-Benchmark-Laboratory',
)

# File paths (relative to DATA_DIR)
PATHS = {
    'nodes': f'{DATA_DIR}/Data/nodes_cleaned.csv',
    'edges': f'{DATA_DIR}/Data/edges_cleaned.csv',
    'ground_truth_nodes': f'{DATA_DIR}/Cleaned_Ground_Truths/benchmark_pathways_nodes.csv',
    'ground_truth_edges': f'{DATA_DIR}/Cleaned_Ground_Truths/benchmark_pathways_edges.csv'
}

print("Configuration complete. Paths set to:")
for name, path in PATHS.items():
    print(f"  {name}: {path}")

Configuration complete. Paths set to:
  nodes: /Users/maxchiu/Desktop/Primkg_github/PrimeKG-Pathfinding-Algorithm-Benchmark-Laboratory/Data/nodes_cleaned.csv
  edges: /Users/maxchiu/Desktop/Primkg_github/PrimeKG-Pathfinding-Algorithm-Benchmark-Laboratory/Data/edges_cleaned.csv
  ground_truth_nodes: /Users/maxchiu/Desktop/Primkg_github/PrimeKG-Pathfinding-Algorithm-Benchmark-Laboratory/Cleaned_Ground_Truths/benchmark_pathways_nodes.csv
  ground_truth_edges: /Users/maxchiu/Desktop/Primkg_github/PrimeKG-Pathfinding-Algorithm-Benchmark-Laboratory/Cleaned_Ground_Truths/benchmark_pathways_edges.csv


---
## Load Data

Load the PrimeKG knowledge graph and ground truth pathways.

In [29]:
# Load PrimeKG data
# NOTE(review): the PrimeKG CSVs are read as latin1 while the ground-truth CSVs
# use pandas' default (utf-8); non-ASCII node names may therefore be decoded
# differently between the two sources — confirm the files' real encodings.
print("Loading PrimeKG data...")
nodes = pd.read_csv(PATHS['nodes'], encoding="latin1")
edges = pd.read_csv(PATHS['edges'], encoding="latin1")

print(f"  Nodes: {len(nodes):,}")
print(f"  Edges: {len(edges):,}")
print(f"  Node types: {nodes[':LABEL'].nunique()}")
print(f"  Edge types: {edges[':TYPE'].nunique()}")

# Load ground truth
print("\nLoading ground truth pathways...")
ground_truth_nodes = pd.read_csv(PATHS['ground_truth_nodes'], dtype={'node_index': int})
ground_truth_edges = pd.read_csv(PATHS['ground_truth_edges'])

pathways = ground_truth_nodes['pathway_id'].unique()
print(f"  Pathways: {len(pathways)}")
# One grouped count instead of re-filtering the whole frame per pathway;
# sort=False keeps first-appearance order, matching `unique()` above.
nodes_per_pathway = ground_truth_nodes.groupby('pathway_id', sort=False).size()
for p in pathways:
    print(f"    - {p}: {nodes_per_pathway.get(p, 0)} nodes")

# Every ground-truth node must exist in PrimeKG, otherwise graph searches for
# that pathway cannot start or finish. Report (don't abort) so the run continues.
missing_gt = ~ground_truth_nodes['node_index'].isin(nodes['node_index:ID'])
if missing_gt.any():
    print(
        f"  ⚠ {int(missing_gt.sum())} ground-truth node(s) not found in PrimeKG "
        f"(affecting {ground_truth_nodes.loc[missing_gt, 'pathway_id'].nunique()} pathway(s))"
    )

Loading PrimeKG data...
  Nodes: 129,375
  Edges: 4,050,064
  Node types: 10
  Edge types: 30

Loading ground truth pathways...
  Pathways: 343
    - valganciclovir_CMV_infection: 3 nodes
    - antazoline_Vasomotor_rhinitis: 3 nodes
    - apalutamide_Nonmetastatic_prostate_cancer: 3 nodes
    - aminoglutethimide_Secondary_malignant_neoplasm_o: 3 nodes
    - anastrozole_Hormone_receptor_positive_mali: 3 nodes
    - diflunisal_Rheumatoid_arthritis: 5 nodes
    - mepyramine_Vasomotor_rhinitis: 3 nodes
    - azilsartan_medoxomil_Hypertensive_disorder: 3 nodes
    - bromocriptine_Hyperprolactinemia: 5 nodes
    - amobarbital_Epilepsy: 6 nodes
    - terbutaline_Asthma: 3 nodes
    - tafamidis_Amyloidosis: 3 nodes
    - FYX-051_Hyperuricemia: 3 nodes
    - FYX-051_Gout: 3 nodes
    - glimepiride_Diabetes_mellitus_type_2: 3 nodes
    - capecitabine_Malignant_tumor_of_stomach: 4 nodes
    - oxaprozin_Rheumatoid_arthritis: 5 nodes
    - pimozide_Gilles_de_la_Tourette's_syndro: 3 nodes
    - clem

---
## Build Knowledge Graph

Construct a NetworkX directed graph with node/edge attributes for pathfinding.

In [30]:
def build_graph(nodes_df, edges_df, bidirectional=True):
    """
    Build a NetworkX graph from PrimeKG data.

    Rows are read column-wise (``zip`` over the Series) rather than with
    ``DataFrame.iterrows``. ``iterrows`` materialises a Series per row and is
    extremely slow on the ~4M-row PrimeKG edge table.

    Args:
        nodes_df: DataFrame with columns [node_index:ID, node_id, node_name, :LABEL]
        edges_df: DataFrame with columns [:START_ID, :END_ID, :TYPE, display_relation]
        bidirectional: If True, also add every edge in the reverse direction
            with the same attributes.

    Returns:
        NetworkX DiGraph. Node keys are ``int(node_index:ID)`` with string
        attributes ``node_id``, ``node_name`` and ``node_type``; edges carry
        ``relation`` (from ``:TYPE``) and ``display_relation`` as strings.

    Note:
        A DiGraph stores at most one edge per ordered (u, v) pair. When several
        rows map to the same pair (including another row's reverse copy), the
        attributes written last win. Edges are inserted in row order, forward
        before reverse, so which edge wins is deterministic.
    """
    G = nx.DiGraph()

    # Add nodes as (node, attr_dict) pairs.
    G.add_nodes_from(
        (
            int(idx),
            {'node_id': str(nid), 'node_name': str(name), 'node_type': str(label)},
        )
        for idx, nid, name, label in zip(
            nodes_df['node_index:ID'],
            nodes_df['node_id'],
            nodes_df['node_name'],
            nodes_df[':LABEL'],
        )
    )

    def _edge_records():
        # Yield the forward edge, then (optionally) its reverse, row by row.
        # This keeps the original insertion order and overwrite behaviour.
        for start, end, rel, display in zip(
            edges_df[':START_ID'],
            edges_df[':END_ID'],
            edges_df[':TYPE'],
            edges_df['display_relation'],
        ):
            u, v = int(start), int(end)
            attrs = {'relation': str(rel), 'display_relation': str(display)}
            yield u, v, attrs
            if bidirectional:
                yield v, u, attrs

    G.add_edges_from(_edge_records())

    return G

# Build the bidirectional PrimeKG graph shared by every algorithm below.
print("Building graph...")
G = build_graph(nodes_df=nodes, edges_df=edges, bidirectional=True)
print(
    f"Graph built: {G.number_of_nodes():,} nodes, "
    f"{G.number_of_edges():,} edges"
)

Building graph...
Graph built: 129,375 nodes, 8,099,284 edges


---
## Algorithm 1 - Shortest Path Baseline

The simplest baseline: find the shortest path (by hop count) between drug and disease.

**Expected behavior:**
- ✅ Will always find the target disease (if connected)
- ❌ May take shortcuts through direct drug→disease edges
- ❌ Ignores edge types and biological mechanism

In [32]:
def run_shortest_path(graph, ground_truth_df):
    """
    Run the unweighted (hop-count) shortest-path baseline on all pathways.

    For each pathway, the node with the lowest ``step_order`` is used as the
    source (drug) and the node with the highest as the target (disease).

    Args:
        graph: NetworkX graph whose nodes may carry ``node_id`` / ``node_name``
            attributes (falls back to ``str(idx)``) and whose edges carry a
            ``relation`` attribute.
        ground_truth_df: Long-format ground truth with columns ``pathway_id``,
            ``step_order``, ``node_index`` and ``node_name``.

    Returns:
        DataFrame with one row per pathway (first-appearance order) and columns
        ``pathway_id``, ``predicted_node_indices``, ``predicted_node_ids``,
        ``predicted_node_names``, ``predicted_relations`` (comma-joined strings,
        or ``'NONE'`` when no path can be produced), ``predicted_length`` (node
        count, 0 when no path) and ``ground_truth_length``.
    """
    def _no_path_row(pathway_id, gt_length):
        # Sentinel row used whenever no prediction can be made.
        return {
            'pathway_id': pathway_id,
            'predicted_node_indices': 'NONE',
            'predicted_node_ids': 'NONE',
            'predicted_node_names': 'NONE',
            'predicted_relations': 'NONE',
            'predicted_length': 0,
            'ground_truth_length': gt_length,
        }

    results = []

    for pathway_id in ground_truth_df['pathway_id'].unique():
        pathway_df = ground_truth_df[ground_truth_df['pathway_id'] == pathway_id].sort_values('step_order')

        # Get source (drug) and target (disease) indices
        source_idx = int(pathway_df.iloc[0]['node_index'])
        target_idx = int(pathway_df.iloc[-1]['node_index'])

        source_name = pathway_df.iloc[0]['node_name']
        target_name = pathway_df.iloc[-1]['node_name']

        print(f"\n{pathway_id}: {source_name} → {target_name}")

        try:
            predicted_path = nx.shortest_path(graph, source_idx, target_idx)
        except nx.NodeNotFound:
            # A ground-truth endpoint is missing from the graph. Record a miss
            # instead of aborting the whole benchmark.
            print("  ✗ Source or target node not in graph")
            results.append(_no_path_row(pathway_id, len(pathway_df)))
            continue
        except nx.NetworkXNoPath:
            print("  ✗ No path found")
            results.append(_no_path_row(pathway_id, len(pathway_df)))
            continue

        predicted_node_ids = [graph.nodes[idx].get('node_id', str(idx)) for idx in predicted_path]
        predicted_node_names = [graph.nodes[idx].get('node_name', str(idx)) for idx in predicted_path]

        # Relation of each consecutive (u, v) hop along the path.
        predicted_relations = [
            graph.get_edge_data(u, v)['relation']
            for u, v in zip(predicted_path, predicted_path[1:])
        ]

        print(f"  ✓ Found path: {len(predicted_path)} nodes")
        print(f"  Path: {' → '.join(predicted_node_names[:5])}{'...' if len(predicted_path) > 5 else ''}")

        # NOTE: fields are comma-joined; a node name that itself contains a
        # comma will not round-trip through a later split(',').
        results.append({
            'pathway_id': pathway_id,
            'predicted_node_indices': ','.join(map(str, predicted_path)),
            'predicted_node_ids': ','.join(predicted_node_ids),
            'predicted_node_names': ','.join(predicted_node_names),
            'predicted_relations': ','.join(predicted_relations),
            'predicted_length': len(predicted_path),
            'ground_truth_length': len(pathway_df)
        })

    return pd.DataFrame(results)


# Algorithm 1: plain hop-count shortest path over the full graph.
print("\n".join(["=" * 60, "Running Shortest Path Algorithm", "=" * 60]))

sp_predictions = run_shortest_path(graph=G, ground_truth_df=ground_truth_nodes)
# Optional: persist the predictions.
# sp_predictions.to_csv('baseline_shortest_path_predictions.csv', index=False)
# print("\n✓ Saved: baseline_shortest_path_predictions.csv")

Running Shortest Path Algorithm

valganciclovir_CMV_infection: valganciclovir → CMV infection
  ✓ Found path: 3 nodes
  Path: Valganciclovir → Ganciclovir → cytomegalovirus infection

antazoline_Vasomotor_rhinitis: antazoline → Vasomotor rhinitis
  ✓ Found path: 2 nodes
  Path: Antazoline → vasomotor rhinitis

apalutamide_Nonmetastatic_prostate_cancer: apalutamide → Nonmetastatic prostate cancer
  ✓ Found path: 2 nodes
  Path: Apalutamide → prostate cancer

aminoglutethimide_Secondary_malignant_neoplasm_o: aminoglutethimide → Secondary malignant neoplasm of female breast
  ✓ Found path: 3 nodes
  Path: Aminoglutethimide → Paclitaxel → breast neoplasm

anastrozole_Hormone_receptor_positive_mali: anastrozole → Hormone receptor positive malignant neoplasm of breast
  ✓ Found path: 3 nodes
  Path: Anastrozole → Bevacizumab → breast neoplasm

diflunisal_Rheumatoid_arthritis: diflunisal → Rheumatoid arthritis
  ✓ Found path: 2 nodes
  Path: Diflunisal → rheumatoid arthritis

mepyramine_Vasom

---
## Meta-Path Constrained BFS

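An improved baseline that enforces biologically valid edge-type sequences. *Status: the implementation cell below is entirely commented out, so this baseline is currently neither run nor evaluated.*

**Valid meta-path patterns** (illustrative — the authoritative list is `VALID_METAPATHS` in the next cell, which also includes pathway-mediated variants):
1. `drug → protein → disease` (direct mechanism — currently commented out of `VALID_METAPATHS`)
2. `drug → protein → protein → disease` (protein interactions)
3. `drug → protein → anatomy → protein → disease` (tissue-specific)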

**Invalid shortcuts blocked:**
- ❌ `drug → disease` (clinical indication, not mechanism)
- ❌ `drug → drug → disease` (drug similarity)

In [37]:
# Allowed drug → disease meta-paths. Each entry is the ordered sequence of edge
# `relation` values a candidate path must follow, hop by hop.
# The direct drug → protein → disease pattern is currently disabled:
#     ['drug_protein', 'disease_protein'],
VALID_METAPATHS = [
    # Protein–protein interaction bridges (one or two PPI hops)
    ['drug_protein', 'protein_protein', 'disease_protein'],
    ['drug_protein', 'protein_protein', 'protein_protein', 'disease_protein'],

    # Pathway-mediated bridges
    ['drug_protein', 'pathway_protein', 'disease_protein'],
    ['drug_protein', 'pathway_protein', 'pathway_protein', 'disease_protein'],
    ['drug_protein', 'pathway_protein', 'pathway_pathway', 'pathway_protein', 'disease_protein'],

    # Tissue-specific bridge through a shared anatomy node
    ['drug_protein', 'anatomy_protein_present', 'anatomy_protein_present', 'disease_protein'],

    # Mixed / longer bridges
    ['drug_protein', 'protein_protein', 'pathway_protein', 'disease_protein'],
    ['drug_protein', 'pathway_protein', 'pathway_protein', 'pathway_protein', 'disease_protein'],
]

print(f"Defined {len(VALID_METAPATHS)} valid meta-path patterns:")
print("\n".join(
    f"  {number}. {' → '.join(relations)}"
    for number, relations in enumerate(VALID_METAPATHS, start=1)
))

Defined 8 valid meta-path patterns:
  1. drug_protein → protein_protein → disease_protein
  2. drug_protein → protein_protein → protein_protein → disease_protein
  3. drug_protein → pathway_protein → disease_protein
  4. drug_protein → pathway_protein → pathway_protein → disease_protein
  5. drug_protein → pathway_protein → pathway_pathway → pathway_protein → disease_protein
  6. drug_protein → anatomy_protein_present → anatomy_protein_present → disease_protein
  7. drug_protein → protein_protein → pathway_protein → disease_protein
  8. drug_protein → pathway_protein → pathway_protein → pathway_protein → disease_protein


In [35]:
# def is_valid_metapath(relations, valid_metapaths):
#     """Check if a relation sequence matches any valid meta-path pattern."""
#     return relations in valid_metapaths


# def could_match_metapath(relations, valid_metapaths):
#     """Check if the current relation sequence could potentially lead to a valid path."""
#     for pattern in valid_metapaths:
#         if len(relations) <= len(pattern):
#             if relations == pattern[:len(relations)]:
#                 return True
#     return False


# def metapath_constrained_bfs(source_idx, target_idx, graph, valid_metapaths, max_length=10):
#     """
#     Find shortest path that follows valid meta-path patterns.
    
#     Uses BFS but only explores edges that could lead to a valid meta-path.
    
#     Returns:
#         (path_nodes, path_relations) or ([], []) if no valid path found
#     """
#     # Queue: (current_node, path_so_far, relations_so_far)
#     queue = deque([(source_idx, [source_idx], [])])
#     visited = {source_idx: []}  # Track visited states with relation sequences
    
#     while queue:
#         current_node, path, relations = queue.popleft()
        
#         # Check if we reached target with valid meta-path
#         if current_node == target_idx:
#             if is_valid_metapath(relations, valid_metapaths):
#                 return path, relations
        
#         # Stop if path too long
#         if len(path) >= max_length:
#             continue
        
#         # Explore neighbors
#         for neighbor in graph.neighbors(current_node):
#             edge_data = graph.get_edge_data(current_node, neighbor)
#             new_relation = edge_data['relation']
#             new_relations = relations + [new_relation]
            
#             # Only continue if this could lead to a valid meta-path
#             if could_match_metapath(new_relations, valid_metapaths):
#                 state_key = (neighbor, tuple(new_relations))
                
#                 # Avoid revisiting same state
#                 if neighbor not in visited or visited[neighbor] != new_relations:
#                     visited[neighbor] = new_relations
#                     queue.append((neighbor, path + [neighbor], new_relations))
    
#     return [], []  # No valid path found


# def run_metapath_algorithm(graph, ground_truth_df, valid_metapaths):
#     """
#     Run meta-path constrained BFS on all pathways.
#     """
#     results = []
    
#     for pathway_id in ground_truth_df['pathway_id'].unique():
#         pathway_df = ground_truth_df[ground_truth_df['pathway_id'] == pathway_id].sort_values('step_order')
        
#         source_idx = int(pathway_df.iloc[0]['node_index'])
#         target_idx = int(pathway_df.iloc[-1]['node_index'])
        
#         source_name = pathway_df.iloc[0]['node_name']
#         target_name = pathway_df.iloc[-1]['node_name']
        
#         gt_path = ' → '.join(pathway_df['node_name'].tolist())
        
#         print(f"\n{pathway_id}: {source_name} → {target_name}")
        
#         # Find meta-path constrained path
#         predicted_path, predicted_relations = metapath_constrained_bfs(
#             source_idx, target_idx, graph, valid_metapaths
#         )
        
#         if predicted_path:
#             predicted_node_ids = [graph.nodes[idx]['node_id'] for idx in predicted_path]
#             predicted_node_names = [graph.nodes[idx]['node_name'] for idx in predicted_path]
            
#             print(f"  ✓ Found valid path: {len(predicted_path)} nodes")
#             print(f"  Meta-path: {' → '.join(predicted_relations)}")
#             print(f"  Path: {' → '.join(predicted_node_names)}")
#             print(f"  Ground truth: {gt_path}")
            
#             results.append({
#                 'pathway_id': pathway_id,
#                 'predicted_node_indices': ','.join(map(str, predicted_path)),
#                 'predicted_node_ids': ','.join(predicted_node_ids),
#                 'predicted_node_names': ','.join(predicted_node_names),
#                 'predicted_relations': ','.join(predicted_relations),
#                 'predicted_length': len(predicted_path),
#                 'ground_truth_length': len(pathway_df)
#             })
#         else:
#             print(f"  ✗ No valid meta-path found")
#             results.append({
#                 'pathway_id': pathway_id,
#                 'predicted_node_indices': 'NONE',
#                 'predicted_node_ids': 'NONE',
#                 'predicted_node_names': 'NONE',
#                 'predicted_relations': 'NONE',
#                 'predicted_length': 0,
#                 'ground_truth_length': len(pathway_df)
#             })
    
#     return pd.DataFrame(results)


# # Run meta-path algorithm
# print("="*60)
# print("Running Meta-Path Constrained Algorithm")
# print("="*60)

# mp_predictions = run_metapath_algorithm(G, ground_truth_nodes, VALID_METAPATHS)
# # mp_predictions.to_csv('baseline_metapath_predictions.csv', index=False)
# # print("\n✓ Saved: baseline_metapath_predictions.csv")

## Algorithm 2: Hub-Penalized Weighted Shortest Path

**Core Idea:** High-degree "hub" nodes (like inflammation markers) connect to everything but don't represent specific mechanisms. Penalize them.

**Weight Formula:** `weight[u,v] = 1 + α * log(degree[v])`

- α = 0.5 is a good default (can be tuned)
- Higher degree → higher weight → less preferred

In [39]:
# ============================================================
# ALGORITHM 2: Hub-Penalized Weighted Shortest Path
# ============================================================


def run_hub_penalized(graph, ground_truth_df, alpha=0.5):
    """
    Evaluate ``HubPenalizedShortestPath`` on every ground-truth pathway.

    Each pathway's first step (lowest ``step_order``) is the source drug and
    its last step is the target disease. Predicted paths are flattened into
    comma-joined strings. Pathways without a path get ``'NONE'`` placeholders.

    Args:
        graph: Graph handed to the algorithm; ``graph.nodes[idx]`` is read for
            optional ``node_id`` / ``node_name`` attributes (default ``str(idx)``).
        ground_truth_df: Columns ``pathway_id``, ``step_order``, ``node_index``,
            ``node_name``.
        alpha: Hub-penalty strength passed to ``HubPenalizedShortestPath``.

    Returns:
        One-row-per-pathway DataFrame (first-appearance order) with columns
        ``pathway_id``, ``predicted_node_indices``, ``predicted_node_ids``,
        ``predicted_node_names``, ``predicted_relations``, ``predicted_length``
        and ``ground_truth_length``.
    """
    rows = []

    print("Initializing Hub-Penalized algorithm...")
    solver = HubPenalizedShortestPath(graph, alpha=alpha)
    print(f"  Edge weights computed (α={alpha})")

    for pathway_id in ground_truth_df['pathway_id'].unique():
        steps = ground_truth_df[ground_truth_df['pathway_id'] == pathway_id].sort_values('step_order')
        start_row, end_row = steps.iloc[0], steps.iloc[-1]
        source, target = int(start_row['node_index']), int(end_row['node_index'])
        truth_len = len(steps)

        print(f"\n{pathway_id}: {start_row['node_name']} → {end_row['node_name']}")

        path, relations, _cost = solver.find_path(source, target)

        if not path:
            print("  ✗ No path found")
            rows.append({
                'pathway_id': pathway_id,
                'predicted_node_indices': 'NONE',
                'predicted_node_ids': 'NONE',
                'predicted_node_names': 'NONE',
                'predicted_relations': 'NONE',
                'predicted_length': 0,
                'ground_truth_length': truth_len,
            })
            continue

        ids = [graph.nodes[node].get('node_id', str(node)) for node in path]
        names = [graph.nodes[node].get('node_name', str(node)) for node in path]

        print(f"  ✓ Found path: {len(path)} nodes")
        print(f"  Path: {' → '.join(names)}")

        rows.append({
            'pathway_id': pathway_id,
            'predicted_node_indices': ','.join(str(node) for node in path),
            'predicted_node_ids': ','.join(ids),
            'predicted_node_names': ','.join(names),
            'predicted_relations': ','.join(relations),
            'predicted_length': len(path),
            'ground_truth_length': truth_len,
        })

    return pd.DataFrame(rows)


# Algorithm 2: hub-penalized weighted shortest path (α = 0.5).
print("\n".join(["=" * 60, "Running Hub-Penalized Algorithm", "=" * 60]))

hub_predictions = run_hub_penalized(graph=G, ground_truth_df=ground_truth_nodes, alpha=0.5)
# Optional: persist the predictions.
# hub_predictions.to_csv('hub_penalized_predictions.csv', index=False)
# print("\n✓ Saved: hub_penalized_predictions.csv")

Running Hub-Penalized Algorithm
Initializing Hub-Penalized algorithm...
  Edge weights computed (α=0.5)

valganciclovir_CMV_infection: valganciclovir → CMV infection
  ✓ Found path: 4 nodes
  Path: Valganciclovir → Splenomegaly → fetal cytomegalovirus syndrome → cytomegalovirus infection

antazoline_Vasomotor_rhinitis: antazoline → Vasomotor rhinitis
  ✓ Found path: 3 nodes
  Path: Antazoline → HRH1 → vasomotor rhinitis

apalutamide_Nonmetastatic_prostate_cancer: apalutamide → Nonmetastatic prostate cancer
  ✓ Found path: 3 nodes
  Path: Apalutamide → CYP2C19 → prostate cancer

aminoglutethimide_Secondary_malignant_neoplasm_o: aminoglutethimide → Secondary malignant neoplasm of female breast
  ✓ Found path: 3 nodes
  Path: Aminoglutethimide → CYP19A1 → breast neoplasm

anastrozole_Hormone_receptor_positive_mali: anastrozole → Hormone receptor positive malignant neoplasm of breast
  ✓ Found path: 3 nodes
  Path: Anastrozole → CYP19A1 → breast neoplasm

diflunisal_Rheumatoid_arthritis: d

## Algorithm 3: PageRank-Inverse Weighted Shortest Path

**Core Idea:** PageRank captures global graph centrality. Nodes with HIGH PageRank are generic hubs. We want paths through LOW PageRank (more specific) nodes.

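**Weight Formula (as documented):** `weight[u,v] = 1 / (1 + pagerank[v])`

- Intended: low PageRank → low weight → preferred.
- ⚠️ **Caveat:** the formula as written *decreases* as PageRank grows, so it gives high-PageRank hubs a slightly *lower* weight — the opposite of the stated intent. PageRank scores also sum to 1, so over ~129k nodes a typical score is ~10⁻⁵ and almost every weight is ≈ 1. The search then behaves nearly like unweighted hop-count shortest path; the first results below coincide with Algorithm 1's. An increasing transform applied to node-count-scaled PageRank would match the intent. **TODO:** confirm against `PageRankInverseShortestPath` in `Algorithms.py`.
- PageRank is computed once upfront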

In [40]:
# ============================================================
# ALGORITHM 3: PageRank-Inverse Weighted Shortest Path
# ============================================================

def run_pagerank_inverse(graph, ground_truth_df, damping=0.85):
    """
    Evaluate ``PageRankInverseShortestPath`` on every ground-truth pathway.

    Source = pathway step with the lowest ``step_order`` (drug); target = the
    highest (disease). Results are serialised as comma-joined strings, with
    ``'NONE'`` placeholders when the algorithm returns an empty path.

    Args:
        graph: Graph handed to the algorithm; ``graph.nodes[idx]`` supplies
            optional ``node_id`` / ``node_name`` (default ``str(idx)``).
        ground_truth_df: Columns ``pathway_id``, ``step_order``, ``node_index``,
            ``node_name``.
        damping: PageRank damping factor passed to the algorithm.

    Returns:
        DataFrame, one row per pathway in first-appearance order.
    """
    columns = (
        'pathway_id', 'predicted_node_indices', 'predicted_node_ids',
        'predicted_node_names', 'predicted_relations',
        'predicted_length', 'ground_truth_length',
    )
    records = []

    print("Initializing PageRank-Inverse algorithm...")
    algo = PageRankInverseShortestPath(graph, damping=damping)
    print("  Edge weights computed")

    for pathway_id in ground_truth_df['pathway_id'].unique():
        is_member = ground_truth_df['pathway_id'] == pathway_id
        steps = ground_truth_df[is_member].sort_values('step_order')
        head, tail = steps.iloc[0], steps.iloc[-1]
        src, dst = int(head['node_index']), int(tail['node_index'])

        print(f"\n{pathway_id}: {head['node_name']} → {tail['node_name']}")

        path, relations, _ = algo.find_path(src, dst)

        if path:
            attrs = [graph.nodes[node] for node in path]
            ids = [a.get('node_id', str(node)) for node, a in zip(path, attrs)]
            names = [a.get('node_name', str(node)) for node, a in zip(path, attrs)]

            print(f"  ✓ Found path: {len(path)} nodes")
            print(f"  Path: {' → '.join(names)}")

            values = (
                pathway_id, ','.join(map(str, path)), ','.join(ids),
                ','.join(names), ','.join(relations), len(path), len(steps),
            )
        else:
            print("  ✗ No path found")
            values = (pathway_id, 'NONE', 'NONE', 'NONE', 'NONE', 0, len(steps))

        records.append(dict(zip(columns, values)))

    return pd.DataFrame(records)


# Algorithm 3: PageRank-inverse weighted shortest path (damping = 0.85).
print("\n".join(["=" * 60, "Running PageRank-Inverse Algorithm", "=" * 60]))

pr_predictions = run_pagerank_inverse(graph=G, ground_truth_df=ground_truth_nodes, damping=0.85)
# Optional: persist the predictions.
# pr_predictions.to_csv('pagerank_inverse_predictions.csv', index=False)
# print("\n✓ Saved: pagerank_inverse_predictions.csv")

Running PageRank-Inverse Algorithm
Initializing PageRank-Inverse algorithm...
  Computing PageRank (this may take a minute)...
  PageRank computed for 129,375 nodes
  Edge weights computed

valganciclovir_CMV_infection: valganciclovir → CMV infection
  ✓ Found path: 3 nodes
  Path: Valganciclovir → Ganciclovir → cytomegalovirus infection

antazoline_Vasomotor_rhinitis: antazoline → Vasomotor rhinitis
  ✓ Found path: 2 nodes
  Path: Antazoline → vasomotor rhinitis

apalutamide_Nonmetastatic_prostate_cancer: apalutamide → Nonmetastatic prostate cancer
  ✓ Found path: 2 nodes
  Path: Apalutamide → prostate cancer

aminoglutethimide_Secondary_malignant_neoplasm_o: aminoglutethimide → Secondary malignant neoplasm of female breast
  ✓ Found path: 3 nodes
  Path: Aminoglutethimide → Paclitaxel → breast neoplasm

anastrozole_Hormone_receptor_positive_mali: anastrozole → Hormone receptor positive malignant neoplasm of breast
  ✓ Found path: 3 nodes
  Path: Anastrozole → Paclitaxel → breast neop

## Algorithm 4: Learned Embeddings + A* with Supervised Edge Weights

**Core Idea:** Learn from known drug repurposing pathways what makes a "good" edge.

**Two Phases:**
1. **Embed:** Train Node2Vec (or use spectral embeddings) to capture graph structure
2. **Learn:** Train MLP to predict edge goodness from:
   - Embedding similarity
   - Degree features
   - Edge type

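**Search:** A* with learned weights + embedding-based heuristic

> ⚠️ **Evaluation caveat:** the run below trains the edge-weight MLP on *every* ground-truth pathway and then evaluates on those same pathways. Algorithm 4's scores are therefore optimistic (train/test leakage). Hold out a subset of pathways from training for a fair comparison with the unsupervised algorithms.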

In [41]:
# ============================================================
# ALGORITHM 4: Learned Embeddings + A* with Supervised Edge Weights
# ============================================================

def run_learned_astar(graph, ground_truth_df, embedding_dim=64, holdout_fraction=0.0, seed=0):
    """
    Run Learned Embeddings + A* algorithm on all pathways.

    The edge-weight model is supervised with ground-truth pathways. With the
    default ``holdout_fraction=0.0`` it is trained on *every* pathway and then
    evaluated on those same pathways (the original behaviour). That leaks the
    answers into training and inflates this algorithm's scores relative to
    the unsupervised baselines. Set ``holdout_fraction`` > 0 to train on a
    random subset and evaluate only on the held-out remainder.

    Args:
        graph: Graph passed to ``LearnedEmbeddingsAStar``; ``graph.nodes[idx]``
            supplies optional ``node_id`` / ``node_name`` (default ``str(idx)``).
        ground_truth_df: Columns ``pathway_id``, ``step_order``, ``node_index``,
            ``node_name``.
        embedding_dim: Embedding size forwarded to ``LearnedEmbeddingsAStar``.
        holdout_fraction: Fraction of pathways (``0 <= f < 1``) withheld from
            training and used for evaluation; ``0`` disables the split.
        seed: Seed for the reproducible train / held-out split.

    Returns:
        DataFrame with one row per evaluated pathway (first-appearance order)
        and the same columns as the other ``run_*`` functions.

    Raises:
        ValueError: If ``holdout_fraction`` is outside ``[0, 1)``.
    """
    if not 0.0 <= holdout_fraction < 1.0:
        raise ValueError(f"holdout_fraction must be in [0, 1), got {holdout_fraction}")

    results = []

    # Initialize algorithm
    print("Initializing Learned Embeddings + A* algorithm...")
    algo = LearnedEmbeddingsAStar(graph, embedding_dim=embedding_dim)
    algo.train_embeddings()

    pathway_ids = list(ground_truth_df['pathway_id'].unique())
    # Sort each pathway once; reused for both training and evaluation.
    pathway_steps = {
        pid: ground_truth_df[ground_truth_df['pathway_id'] == pid].sort_values('step_order')
        for pid in pathway_ids
    }

    if holdout_fraction > 0:
        rng = np.random.default_rng(seed)
        n_eval = int(round(len(pathway_ids) * holdout_fraction))
        # Keep at least one pathway on each side of the split when possible.
        n_eval = min(max(n_eval, 1), max(len(pathway_ids) - 1, 0))
        eval_positions = set(rng.permutation(len(pathway_ids))[:n_eval].tolist())
        eval_ids = [pid for i, pid in enumerate(pathway_ids) if i in eval_positions]
        train_ids = [pid for i, pid in enumerate(pathway_ids) if i not in eval_positions]
        print(f"  Hold-out split: {len(train_ids)} training / {len(eval_ids)} evaluation pathways (seed={seed})")
    else:
        # Original behaviour: train and evaluate on every pathway.
        train_ids = eval_ids = pathway_ids

    # Prepare training data from ground truth
    training_pathways = [
        {'path_nodes': pathway_steps[pid]['node_index'].tolist()} for pid in train_ids
    ]
    algo.train_edge_weights(training_pathways)

    for pathway_id in eval_ids:
        pathway_df = pathway_steps[pathway_id]

        source_idx = int(pathway_df.iloc[0]['node_index'])
        target_idx = int(pathway_df.iloc[-1]['node_index'])
        source_name = pathway_df.iloc[0]['node_name']
        target_name = pathway_df.iloc[-1]['node_name']

        print(f"\n{pathway_id}: {source_name} → {target_name}")

        path, relations, weight = algo.find_path(source_idx, target_idx)

        if path:
            node_ids = [graph.nodes[idx].get('node_id', str(idx)) for idx in path]
            node_names = [graph.nodes[idx].get('node_name', str(idx)) for idx in path]

            print(f"  ✓ Found path: {len(path)} nodes")
            print(f"  Path: {' → '.join(node_names)}")

            results.append({
                'pathway_id': pathway_id,
                'predicted_node_indices': ','.join(map(str, path)),
                'predicted_node_ids': ','.join(node_ids),
                'predicted_node_names': ','.join(node_names),
                'predicted_relations': ','.join(relations),
                'predicted_length': len(path),
                'ground_truth_length': len(pathway_df)
            })
        else:
            print("  ✗ No path found")
            results.append({
                'pathway_id': pathway_id,
                'predicted_node_indices': 'NONE',
                'predicted_node_ids': 'NONE',
                'predicted_node_names': 'NONE',
                'predicted_relations': 'NONE',
                'predicted_length': 0,
                'ground_truth_length': len(pathway_df)
            })

    return pd.DataFrame(results)


# Algorithm 4: learned embeddings + A* with supervised edge weights.
print("\n".join(["=" * 60, "Running Learned Embeddings + A* Algorithm", "=" * 60]))

learned_predictions = run_learned_astar(graph=G, ground_truth_df=ground_truth_nodes, embedding_dim=64)
# Optional: persist the predictions.
# learned_predictions.to_csv('learned_astar_predictions.csv', index=False)
# print("\n✓ Saved: learned_astar_predictions.csv")

Running Learned Embeddings + A* Algorithm
Initializing Learned Embeddings + A* algorithm...
  Computing spectral embeddings (sparse method)...
  Embeddings computed for 129,375 nodes
  Training edge weight MLP...
  MLP trained on 2176 samples (R²=0.627)
  Precomputing edge weights...
  Edge weights computed for 8,099,284 edges

valganciclovir_CMV_infection: valganciclovir → CMV infection
  ✓ Found path: 9 nodes
  Path: Valganciclovir → SLC15A2 → AQP1 → TMEM237 → CORO1C → MMGT1 → ABCC2 → Letermovir → cytomegalovirus infection

antazoline_Vasomotor_rhinitis: antazoline → Vasomotor rhinitis
  ✓ Found path: 3 nodes
  Path: Antazoline → HRH1 → vasomotor rhinitis

apalutamide_Nonmetastatic_prostate_cancer: apalutamide → Nonmetastatic prostate cancer
  ✓ Found path: 2 nodes
  Path: Apalutamide → prostate cancer

aminoglutethimide_Secondary_malignant_neoplasm_o: aminoglutethimide → Secondary malignant neoplasm of female breast
  ✓ Found path: 9 nodes
  Path: Aminoglutethimide → CYP11A1 → Glute

## Algorithm 5: Semantic Bridging with Intermediate Node Scoring

**Core Idea:** Use NLP to find paths where consecutive nodes are semantically related (they "make sense" together).

**Weight Formula:** `weight[u,v] = 1 - β * cosine_sim(text_emb[u], text_emb[v])`

- β = 0.3 balances semantic preference with path length
- Uses TF-IDF embeddings (or SciBERT if available)

In [None]:
# ============================================================
# ALGORITHM 5: Semantic Bridging with Intermediate Node Scoring
# ============================================================

def run_semantic_bridging(graph, ground_truth_df, beta=0.3):
    """
    Evaluate ``SemanticBridgingPath`` on every ground-truth pathway.

    The algorithm's text embeddings and edge weights are computed once, then
    each pathway is searched from its first step (drug) to its last step
    (disease) by ``step_order``.

    Args:
        graph: Graph handed to the algorithm; ``graph.nodes[idx]`` supplies
            optional ``node_id`` / ``node_name`` (default ``str(idx)``).
        ground_truth_df: Columns ``pathway_id``, ``step_order``, ``node_index``,
            ``node_name``.
        beta: Semantic-similarity weight passed to ``SemanticBridgingPath``.

    Returns:
        DataFrame, one row per pathway in first-appearance order; rows for
        pathways without a path use ``'NONE'`` strings and length 0.
    """
    def describe(path, key):
        # Look up one attribute for every node on the path, defaulting to its index.
        return [graph.nodes[node].get(key, str(node)) for node in path]

    output_rows = []

    print("Initializing Semantic Bridging algorithm...")
    bridge = SemanticBridgingPath(graph, beta=beta)
    bridge.compute_embeddings()
    bridge.compute_edge_weights()

    for pathway_id in ground_truth_df['pathway_id'].unique():
        ordered = ground_truth_df[ground_truth_df['pathway_id'] == pathway_id].sort_values('step_order')
        origin = int(ordered.iloc[0]['node_index'])
        goal = int(ordered.iloc[-1]['node_index'])
        print(f"\n{pathway_id}: {ordered.iloc[0]['node_name']} → {ordered.iloc[-1]['node_name']}")

        path, relations, _score = bridge.find_path(origin, goal)

        row = {'pathway_id': pathway_id}
        if path:
            node_names = describe(path, 'node_name')
            print(f"  ✓ Found path: {len(path)} nodes")
            print(f"  Path: {' → '.join(node_names)}")
            row['predicted_node_indices'] = ','.join(str(node) for node in path)
            row['predicted_node_ids'] = ','.join(describe(path, 'node_id'))
            row['predicted_node_names'] = ','.join(node_names)
            row['predicted_relations'] = ','.join(relations)
            row['predicted_length'] = len(path)
        else:
            print("  ✗ No path found")
            for field in ('predicted_node_indices', 'predicted_node_ids',
                          'predicted_node_names', 'predicted_relations'):
                row[field] = 'NONE'
            row['predicted_length'] = 0
        row['ground_truth_length'] = len(ordered)

        output_rows.append(row)

    return pd.DataFrame(output_rows)


# Algorithm 5: semantic bridging over TF-IDF node-text similarity (β = 0.3).
print("\n".join(["=" * 60, "Running Semantic Bridging Algorithm", "=" * 60]))

semantic_predictions = run_semantic_bridging(graph=G, ground_truth_df=ground_truth_nodes, beta=0.3)
# Optional: persist the predictions.
# semantic_predictions.to_csv('semantic_bridging_predictions.csv', index=False)
# print("\n✓ Saved: semantic_bridging_predictions.csv")

Running Semantic Bridging Algorithm
Initializing Semantic Bridging algorithm...
  Computing TF-IDF embeddings...
  Embeddings computed for 129,375 nodes
  Computing edge weights...
  Edge weights computed for 8,099,284 edges

valganciclovir_CMV_infection: valganciclovir → CMV infection
  ✓ Found path: 3 nodes
  Path: Valganciclovir → Ganciclovir → cytomegalovirus infection

antazoline_Vasomotor_rhinitis: antazoline → Vasomotor rhinitis
  ✓ Found path: 2 nodes
  Path: Antazoline → vasomotor rhinitis

apalutamide_Nonmetastatic_prostate_cancer: apalutamide → Nonmetastatic prostate cancer
  ✓ Found path: 2 nodes
  Path: Apalutamide → prostate cancer

aminoglutethimide_Secondary_malignant_neoplasm_o: aminoglutethimide → Secondary malignant neoplasm of female breast
  ✓ Found path: 3 nodes
  Path: Aminoglutethimide → Paclitaxel → breast neoplasm

anastrozole_Hormone_receptor_positive_mali: anastrozole → Hormone receptor positive malignant neoplasm of breast
  ✓ Found path: 3 nodes
  Path: An

---
## Evaluate All Algorithms

Calculate all 9 metrics for each algorithm's predictions and compare.

In [None]:
def evaluate_predictions(predictions_df, ground_truth_nodes_df, ground_truth_edges_df, edges_df, algorithm_name):
    """
    Calculate all 9 evaluation metrics for predictions.
    
    Returns:
        DataFrame with metrics for each pathway
    """
    results = []
    
    for _, pred_row in predictions_df.iterrows():
        pathway_id = pred_row['pathway_id']
        
        # Get ground truth for this pathway
        gt_nodes = ground_truth_nodes_df[ground_truth_nodes_df['pathway_id'] == pathway_id].sort_values('step_order')
        gt_edges = ground_truth_edges_df[ground_truth_edges_df['pathway_id'] == pathway_id]
        
        gt_node_ids = gt_nodes['node_id'].tolist()
        gt_target_id = str(gt_nodes.iloc[-1]['node_id'])
        gt_edge_types = gt_edges['relation_type'].tolist() if not gt_edges.empty else []
        
        # Parse predictions
        if pred_row['predicted_node_ids'] == 'NONE':
            pred_node_ids = []
            pred_node_indices = []
            pred_relations = []
        else:
            pred_node_ids = pred_row['predicted_node_ids'].split(',')
            pred_node_indices = [int(x) for x in pred_row['predicted_node_indices'].split(',')]
            pred_relations = pred_row['predicted_relations'].split(',') if pred_row['predicted_relations'] != 'NONE' else []
        
        # Calculate metrics
        precision, recall, f1 = f1_score(pred_node_ids, [str(x) for x in gt_node_ids])
        hits = calculate_hits_at_k(pred_node_ids, gt_target_id)
        path_mae = calculate_path_length_mae(pred_row['predicted_length'], pred_row['ground_truth_length'])
        relation_acc = calculate_relation_accuracy(pred_relations, gt_edge_types)
        hub_ratio = calculate_hub_node_ratio(pred_node_indices, edges_df)
        edit_dist = calculate_edit_distance(pred_node_ids, [str(x) for x in gt_node_ids])
        
        results.append({
            'pathway_id': pathway_id,
            'algorithm': algorithm_name,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'hits_at_1': hits['hits_at_1'],
            'hits_at_3': hits['hits_at_3'],
            'hits_at_5': hits['hits_at_5'],
            'path_length_mae': path_mae,
            'relation_type_accuracy': relation_acc,
            'hub_node_ratio': hub_ratio,
            'path_edit_distance': edit_dist
        })
    
    return pd.DataFrame(results)


def _evaluate_algorithm(label, predictions):
    """Announce and score one algorithm's predictions against the shared ground truth."""
    print(f"Evaluating {label}...")
    return evaluate_predictions(predictions, ground_truth_nodes, ground_truth_edges, edges, label)


sp_eval = _evaluate_algorithm('Shortest Path', sp_predictions)
mp_eval = _evaluate_algorithm('Meta-Path', mp_predictions)

# Stack per-pathway metrics for both algorithms into one table and persist it.
all_eval = pd.concat([sp_eval, mp_eval], ignore_index=True)
all_eval.to_csv('evaluation_results_all.csv', index=False)
print("\n✓ Saved: evaluation_results_all.csv")

Evaluating Shortest Path...


KeyError: 'relation_type'

---
## Results Summary

Compare algorithm performance across all metrics.

In [None]:
# Average every per-pathway metric produced by evaluate_predictions, per algorithm.
# hits_at_3 / hits_at_5 were previously computed but left out of the summary.
metrics = ['precision', 'recall', 'f1_score', 'hits_at_1', 'hits_at_3', 'hits_at_5',
           'relation_type_accuracy', 'hub_node_ratio', 'path_edit_distance', 'path_length_mae']

summary = all_eval.groupby('algorithm')[metrics].mean().round(3)

print("="*70)
print("ALGORITHM COMPARISON: Average Metrics Across All Pathways")
print("="*70)
print(summary.T.to_string())
print("\n" + "="*70)

# Count perfect matches for every evaluated algorithm (not a hard-coded pair),
# in the order the algorithms appear in all_eval.
print("\nPerfect Matches (Edit Distance = 0):")
for alg in all_eval['algorithm'].unique():
    alg_eval = all_eval[all_eval['algorithm'] == alg]
    perfect = (alg_eval['path_edit_distance'] == 0).sum()
    print(f"  {alg}: {perfect}/{len(alg_eval)} pathways")

---
## Visualization

Create comparison charts for the two algorithms.

In [None]:
# Four-panel comparison of Shortest Path vs Meta-Path: averaged metrics come from
# `summary`, per-pathway path lengths from the raw prediction tables.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('Algorithm Comparison: Shortest Path vs Meta-Path', fontsize=14, fontweight='bold')

# Color scheme
colors = {'Shortest Path': '#e74c3c', 'Meta-Path': '#3498db'}

# --- Subplot 1: Node Accuracy Metrics ---
ax1 = axes[0, 0]
node_metrics = ['precision', 'recall', 'f1_score']
x = np.arange(len(node_metrics))
width = 0.35

sp_vals = [summary.loc['Shortest Path', m] for m in node_metrics]
mp_vals = [summary.loc['Meta-Path', m] for m in node_metrics]

ax1.bar(x - width/2, sp_vals, width, label='Shortest Path', color=colors['Shortest Path'])
ax1.bar(x + width/2, mp_vals, width, label='Meta-Path', color=colors['Meta-Path'])
ax1.set_ylabel('Score')
ax1.set_title('Node Accuracy Metrics')
ax1.set_xticks(x)
ax1.set_xticklabels(['Precision', 'Recall', 'F1 Score'])
ax1.legend()
ax1.set_ylim(0, 1.1)
ax1.grid(axis='y', alpha=0.3)

# --- Subplot 2: Mechanistic Quality ---
ax2 = axes[0, 1]
mech_metrics = ['relation_type_accuracy', 'hub_node_ratio']
x = np.arange(len(mech_metrics))

sp_vals = [summary.loc['Shortest Path', m] for m in mech_metrics]
mp_vals = [summary.loc['Meta-Path', m] for m in mech_metrics]

ax2.bar(x - width/2, sp_vals, width, label='Shortest Path', color=colors['Shortest Path'])
ax2.bar(x + width/2, mp_vals, width, label='Meta-Path', color=colors['Meta-Path'])
# Draw the reference line BEFORE legend(): legend() only collects artists that already
# exist, so the 'Baseline' label was previously missing from the legend.
# NOTE(review): what the 0.5 line is a baseline for is not stated here — confirm.
ax2.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Baseline')
ax2.set_ylabel('Score')
ax2.set_title('Mechanistic Quality Metrics')
ax2.set_xticks(x)
ax2.set_xticklabels(['Relation Accuracy', 'Hub Node Ratio'])
ax2.legend()
ax2.set_ylim(0, 1.1)
ax2.grid(axis='y', alpha=0.3)

# --- Subplot 3: Path Length Comparison ---
ax3 = axes[1, 0]
pathways = sp_predictions['pathway_id'].tolist()
x = np.arange(len(pathways))
width = 0.25

gt_lengths = sp_predictions['ground_truth_length'].tolist()
sp_lengths = sp_predictions['predicted_length'].tolist()
# Align Meta-Path lengths to the Shortest Path pathway order by pathway_id rather than by
# row position, so bars stay matched even if the two tables are ordered differently.
# A pathway absent from mp_predictions is drawn as 0, the predicted_length used for
# "no path found" rows.
mp_lengths = (
    mp_predictions.set_index('pathway_id')['predicted_length']
    .reindex(pathways)
    .fillna(0)
    .tolist()
)

ax3.bar(x - width, gt_lengths, width, label='Ground Truth', color='#2ecc71')
ax3.bar(x, sp_lengths, width, label='Shortest Path', color=colors['Shortest Path'])
ax3.bar(x + width, mp_lengths, width, label='Meta-Path', color=colors['Meta-Path'])
ax3.set_ylabel('Path Length (nodes)')
ax3.set_title('Path Length by Pathway')
ax3.set_xticks(x)
ax3.set_xticklabels(pathways, rotation=45, ha='right')
ax3.legend()
ax3.grid(axis='y', alpha=0.3)

# --- Subplot 4: Radar Chart ---
# Remove the Cartesian placeholder (instead of only hiding it) so the polar axes is the
# only axes occupying the bottom-right slot.
ax4 = axes[1, 1]
ax4.remove()
ax4_radar = fig.add_subplot(2, 2, 4, projection='polar')

radar_metrics = ['precision', 'recall', 'f1_score', 'hits_at_1', 'relation_type_accuracy']
radar_labels = ['Precision', 'Recall', 'F1', 'Hits@1', 'Relation Acc']

# Get values and close the polygon by repeating the first point
sp_radar = [summary.loc['Shortest Path', m] for m in radar_metrics] + [summary.loc['Shortest Path', radar_metrics[0]]]
mp_radar = [summary.loc['Meta-Path', m] for m in radar_metrics] + [summary.loc['Meta-Path', radar_metrics[0]]]

angles = np.linspace(0, 2 * np.pi, len(radar_metrics), endpoint=False).tolist()
angles += angles[:1]  # Close the polygon

ax4_radar.plot(angles, sp_radar, 'o-', linewidth=2, label='Shortest Path', color=colors['Shortest Path'])
ax4_radar.fill(angles, sp_radar, alpha=0.25, color=colors['Shortest Path'])
ax4_radar.plot(angles, mp_radar, 'o-', linewidth=2, label='Meta-Path', color=colors['Meta-Path'])
ax4_radar.fill(angles, mp_radar, alpha=0.25, color=colors['Meta-Path'])

ax4_radar.set_xticks(angles[:-1])
ax4_radar.set_xticklabels(radar_labels)
ax4_radar.set_ylim(0, 1)
ax4_radar.set_title('Performance Profile', pad=20)
ax4_radar.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0))

fig.tight_layout()
fig.savefig('algorithm_comparison.png', dpi=300, bbox_inches='tight')
print("✓ Saved: algorithm_comparison.png")
plt.show()

---
## Detailed Pathway Analysis

Look at individual pathway results to understand where each algorithm succeeds or fails.

In [None]:
def _print_labelled_path(label, n_nodes, path_text):
    """Print a blank line, '<label> (<n_nodes> nodes):', then the path indented by two spaces."""
    print(f"\n{label} ({n_nodes} nodes):")
    print(f"  {path_text}")


def _metrics_row(pathway, algorithm):
    """Return the first all_eval row for the given pathway and algorithm label."""
    mask = (all_eval['pathway_id'] == pathway) & (all_eval['algorithm'] == algorithm)
    return all_eval[mask].iloc[0]


print("\n".join(["=" * 80, "DETAILED PATHWAY ANALYSIS", "=" * 80]))

for pathway_id in sp_predictions['pathway_id'].unique():
    divider = '=' * 60
    print(f"\n{divider}\nPathway: {pathway_id}\n{divider}")

    # Curated ground-truth pathway, in step order
    gt = ground_truth_nodes.loc[ground_truth_nodes['pathway_id'] == pathway_id].sort_values('step_order')
    gt_path = ' → '.join(gt['node_name'].tolist())
    _print_labelled_path("Ground Truth", len(gt), gt_path)

    # Each algorithm's prediction (names are stored comma-joined)
    sp = sp_predictions.loc[sp_predictions['pathway_id'] == pathway_id].iloc[0]
    sp_path = sp['predicted_node_names'].replace(',', ' → ')
    _print_labelled_path("Shortest Path", sp['predicted_length'], sp_path)

    mp = mp_predictions.loc[mp_predictions['pathway_id'] == pathway_id].iloc[0]
    mp_path = mp['predicted_node_names'].replace(',', ' → ')
    _print_labelled_path("Meta-Path", mp['predicted_length'], mp_path)

    # Side-by-side table of a few headline metrics
    sp_metrics = _metrics_row(pathway_id, 'Shortest Path')
    mp_metrics = _metrics_row(pathway_id, 'Meta-Path')

    print("\nMetrics:")
    print(f"  {'Metric':<25} {'Shortest Path':>15} {'Meta-Path':>15}")
    print("  " + "-" * 55)
    for metric in ('f1_score', 'relation_type_accuracy', 'path_edit_distance'):
        print(f"  {metric:<25} {sp_metrics[metric]:>15.3f} {mp_metrics[metric]:>15.3f}")