# Algorithm Tuning: Learned A* (v2) & Semantic Bridging (v2)

**Purpose:** Improve Algorithms 4 and 5 from the baseline benchmark with:
- **Meta-path awareness** — mine valid edge-type patterns from ground truth, use as soft/hard constraints
- **Path-level context** (A*) — edge features include hop count, distance to target, node types
- **Node-type-aware features** (A*) — replace spectral embeddings with biologically meaningful features
- **5-fold cross-validation** (A*) — eliminate train-on-test leakage using pre-assigned folds
- **Per-edge-type weighting** (Semantic) — learned penalties from GT edge-type frequency
- **Soft meta-path penalties** (Semantic) — 10x weight multiplier for edges inconsistent with valid patterns

**Data requirement:** PrimeKG `nodes.csv` and `edges.csv` in `data/raw/`. Ground truth and fold assignments in `data/processed/`.


## 1. Setup

In [1]:
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import heapq
import time
import warnings

from typing import Dict, List, Tuple, Set
from collections import Counter
from scipy.sparse.linalg import eigsh
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD

warnings.filterwarnings('ignore')
print("✓ Imports loaded")


✓ Imports loaded


## 2. Configuration

In [2]:
import os

DATA_DIR = '..'

PATHS = {
    'nodes':              f'{DATA_DIR}/data/raw/nodes.csv',
    'edges':              f'{DATA_DIR}/data/raw/edges.csv',
    'ground_truth_nodes': f'{DATA_DIR}/data/processed/benchmark_pathways_nodes.csv',
    'ground_truth_edges': f'{DATA_DIR}/data/processed/benchmark_pathways_edges.csv',
    'fold_assignments':   f'{DATA_DIR}/data/processed/pathway_fold_assignments.csv',
}

MIN_PATHWAY_NODES = 4
N_FOLDS = 5

print("Configuration:")
for name, path in PATHS.items():
    exists = "✓" if os.path.exists(path) else "✗ NOT FOUND"
    print(f"  {exists}  {name}: {path}")
print(f"  min_pathway_nodes: {MIN_PATHWAY_NODES}")
print(f"  n_folds: {N_FOLDS}")


Configuration:
  ✓  nodes: ../data/raw/nodes.csv
  ✓  edges: ../data/raw/edges.csv
  ✓  ground_truth_nodes: ../data/processed/benchmark_pathways_nodes.csv
  ✓  ground_truth_edges: ../data/processed/benchmark_pathways_edges.csv
  ✓  fold_assignments: ../data/processed/pathway_fold_assignments.csv
  min_pathway_nodes: 4
  n_folds: 5


## 3. Evaluation Helpers & Metrics

In [3]:
# ============================================================
# EVALUATION HELPERS & METRICS (all inline)
# ============================================================

def is_valid_prediction(predicted_ids):
    return predicted_ids and predicted_ids != ['NONE']

def calculate_edit_distance(predicted_ids, ground_truth_ids):
    if not predicted_ids or predicted_ids == ['NONE']:
        return 1.0
    m, n = len(predicted_ids), len(ground_truth_ids)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if predicted_ids[i - 1] == ground_truth_ids[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])
    return dp[m][n] / max(m, n)

def compute_degree_counts(edges_df):
    degree_count = Counter()
    for _, row in edges_df.iterrows():
        degree_count[row['x_index']] += 1
        degree_count[row['y_index']] += 1
    return degree_count

def compute_hub_threshold(degree_count, percentile=95):
    return np.percentile(list(degree_count.values()), percentile)

def calculate_hits_at_k(predicted_ids, ground_truth_target, k_values=[1, 3, 5]):
    hits = {f'hits_at_{k}': 0 for k in k_values}
    if not is_valid_prediction(predicted_ids):
        return hits
    for k in k_values:
        last_k = predicted_ids[-k:] if len(predicted_ids) >= k else predicted_ids
        hits[f'hits_at_{k}'] = 1 if ground_truth_target in last_k else 0
    return hits

def calculate_relation_accuracy(predicted_relations, ground_truth_edge_types):
    if not predicted_relations:
        return 0.0
    gt_types = set(ground_truth_edge_types)
    return sum(1 for r in predicted_relations if r in gt_types) / len(predicted_relations)

def calculate_path_length_mae(predicted_length, ground_truth_length):
    return abs(predicted_length - ground_truth_length)

def calculate_hub_node_ratio(predicted_indices, degree_count, hub_threshold):
    if not predicted_indices:
        return 0.0
    return sum(1 for idx in predicted_indices if degree_count.get(idx, 0) >= hub_threshold) / len(predicted_indices)

def metric_precision(predicted_ids, ground_truth_ids):
    if not is_valid_prediction(predicted_ids):
        return 0.0
    pred_set, gt_set = set(predicted_ids), set(ground_truth_ids)
    return len(pred_set & gt_set) / len(pred_set)

def metric_recall(predicted_ids, ground_truth_ids):
    if not is_valid_prediction(predicted_ids):
        return 0.0
    pred_set, gt_set = set(predicted_ids), set(ground_truth_ids)
    return len(pred_set & gt_set) / len(gt_set) if gt_set else 0.0

def metric_f1(predicted_ids, ground_truth_ids):
    p = metric_precision(predicted_ids, ground_truth_ids)
    r = metric_recall(predicted_ids, ground_truth_ids)
    return 2 * p * r / (p + r) if (p + r) > 0 else 0.0

def metric_path_length_accuracy(predicted_length, ground_truth_length):
    if predicted_length == 0 and ground_truth_length == 0:
        return 1.0
    max_len = max(predicted_length, ground_truth_length)
    return 1 - abs(predicted_length - ground_truth_length) / max_len if max_len > 0 else 0.0

def metric_mrr(predicted_ids, ground_truth_ids):
    if not is_valid_prediction(predicted_ids):
        return 0.0
    gt_set = set(ground_truth_ids)
    for rank, node in enumerate(predicted_ids, start=1):
        if node in gt_set:
            return 1 / rank
    return 0.0

print("✓ Evaluation helpers & metrics loaded")


✓ Evaluation helpers & metrics loaded


## 4. Shared Pathfinding Engine

In [4]:
# ============================================================
# TRANSITION FILTER & DIJKSTRA ENGINE
# ============================================================

def allowed_transition(G, src, u, v) -> bool:
    u_type = G.nodes[u].get("node_type", "")
    v_type = G.nodes[v].get("node_type", "")
    if u_type == "drug" and v_type == "disease":
        return False
    src_type = G.nodes[src].get("node_type", "")
    if src_type == "drug" and u == src and v_type == "drug":
        return False
    return True


def find_path_engine(graph, weighted_graph, source, target, transition_fn):
    dist = {source: 0.0}
    parent = {source: None}
    pq = [(0.0, source)]
    while pq:
        cur_cost, u = heapq.heappop(pq)
        if cur_cost != dist.get(u, float('inf')):
            continue
        if u == target:
            break
        for v in graph.successors(u):
            if not transition_fn(graph, source, u, v):
                continue
            w = weighted_graph[u][v].get("weight", 1.0)
            new_cost = cur_cost + w
            if new_cost < dist.get(v, float('inf')):
                dist[v] = new_cost
                parent[v] = u
                heapq.heappush(pq, (new_cost, v))
    if target not in dist:
        return [], [], float("inf")
    path = []
    cur = target
    while cur is not None:
        path.append(cur)
        cur = parent[cur]
    path.reverse()
    relations = []
    for i in range(len(path) - 1):
        edge_data = graph.get_edge_data(path[i], path[i + 1]) or {}
        relations.append(edge_data.get("relation", "unknown"))
    return path, relations, dist[target]


print("✓ Shared engine loaded")


✓ Shared engine loaded


## 5. Meta-Path Mining

Extract the dominant edge-type patterns from ground truth pathways. These become constraints for both v2 algorithms.


In [5]:
def mine_metapath_patterns(gt_edges_df, gt_nodes_df, min_count=2):
    """
    Extract meta-path patterns (edge-type sequences) from ground truth pathways.

    Returns:
        patterns     : list of tuples, e.g. [('drug_protein', 'disease_protein'), ...]
        pattern_counts: Counter of pattern frequencies
        prefix_set   : set of all valid prefixes (for constraining search)
    """
    # Auto-detect edge type column
    edge_type_col = None
    for candidate in ['relation_type', 'relation', 'edge_type']:
        if candidate in gt_edges_df.columns:
            edge_type_col = candidate
            break
    if edge_type_col is None:
        edge_type_col = gt_edges_df.columns[0]
    print(f"  Using edge type column: '{edge_type_col}'")

    pattern_counts = Counter()

    for pathway_id in gt_edges_df['pathway_id'].unique():
        pw_edges = gt_edges_df[gt_edges_df['pathway_id'] == pathway_id]

        # Sort edges by step_from to get the correct sequence
        if 'step_from' in pw_edges.columns:
            pw_edges = pw_edges.sort_values('step_from')

        edge_types = tuple(pw_edges[edge_type_col].tolist())
        if len(edge_types) > 0:
            pattern_counts[edge_types] += 1

    # Filter to patterns seen at least min_count times
    patterns = [p for p, c in pattern_counts.most_common() if c >= min_count]

    # Build prefix set for search-time constraint checking
    prefix_set = set()
    for pattern in patterns:
        for i in range(1, len(pattern) + 1):
            prefix_set.add(pattern[:i])

    # Also include ALL patterns (even rare) in prefix set for flexibility
    for pattern in pattern_counts:
        for i in range(1, len(pattern) + 1):
            prefix_set.add(pattern[:i])

    print(f"  Mined {len(pattern_counts)} unique meta-path patterns")
    print(f"  {len(patterns)} patterns with count >= {min_count}")
    print(f"  {len(prefix_set)} valid prefixes for search constraint")
    print(f"\n  Top 15 patterns:")
    for pattern, count in pattern_counts.most_common(15):
        print(f"    {count:3d}x  {' → '.join(pattern)}")

    return patterns, pattern_counts, prefix_set


def compute_edge_type_weights(gt_edges_df, all_edge_types, bonus=0.3, penalty=2.0):
    """
    Compute per-edge-type weight multipliers from ground truth frequency.

    Frequent GT edge types get a bonus (multiplier < 1).
    Rare/absent edge types get a penalty (multiplier > 1).
    """
    edge_type_col = None
    for candidate in ['relation_type', 'relation', 'edge_type']:
        if candidate in gt_edges_df.columns:
            edge_type_col = candidate
            break
    if edge_type_col is None:
        edge_type_col = gt_edges_df.columns[0]

    gt_counts = Counter(gt_edges_df[edge_type_col].tolist())
    total = sum(gt_counts.values())

    weights = {}
    for edge_type in all_edge_types:
        freq = gt_counts.get(edge_type, 0) / total if total > 0 else 0
        if freq > 0.05:
            weights[edge_type] = 1.0 - bonus * min(freq * 5, 1.0)  # reward: 0.7–1.0
        elif freq > 0:
            weights[edge_type] = 1.0  # neutral
        else:
            weights[edge_type] = penalty  # penalize unseen types
    
    print(f"  Edge type weights (sample):")
    for et, w in sorted(weights.items(), key=lambda x: x[1])[:10]:
        print(f"    {et:<35s}: {w:.3f}")

    return weights


print("✓ Meta-path mining functions loaded")


✓ Meta-path mining functions loaded


## 6. Algorithm 4 v2: Improved Learned A*

**Changes from v1:**
- Edge features now include **path-level context**: hop count, node types (one-hot), distance to target
- **Meta-path prefix constraint**: A* only expands edges consistent with valid GT patterns
- **Per-edge-type weight multiplier**: learned from GT frequency
- **5-fold CV built in**: `train_and_evaluate_cv()` trains on 4 folds, tests on 1


In [14]:
# ============================================================
# ALGORITHM 4 v2: Learned A* with Path Context + Meta-Path Constraints
# ============================================================

# Node type vocabulary for one-hot encoding
NODE_TYPES = ['drug', 'gene/protein', 'disease', 'biological_process',
              'molecular_function', 'pathway', 'anatomy', 'effect/phenotype',
              'cellular_component', 'exposure']

NODE_TYPE_TO_IDX = {t: i for i, t in enumerate(NODE_TYPES)}


class LearnedAStarV2:
    def __init__(self, graph: nx.DiGraph, embedding_dim: int = 64,
                 prefix_set: Set[tuple] = None, edge_type_weights: Dict[str, float] = None):
        self.graph = graph
        self.embedding_dim = embedding_dim
        self.embeddings = None
        self.edge_weights = None
        self.scaler = None
        self.mlp = None
        self.degrees = dict(graph.degree())
        self.prefix_set = prefix_set or set()
        self.edge_type_weights = edge_type_weights or {}

    def _node_type_onehot(self, node: int) -> np.ndarray:
        """One-hot encode node type."""
        vec = np.zeros(len(NODE_TYPES))
        ntype = self.graph.nodes[node].get('node_type', '')
        idx = NODE_TYPE_TO_IDX.get(ntype, -1)
        if idx >= 0:
            vec[idx] = 1.0
        return vec

    def train_embeddings(self) -> Dict[int, np.ndarray]:
        """Compute spectral embeddings."""
        print("  Computing spectral embeddings...")
        G_undirected = self.graph.to_undirected()
        largest_cc = max(nx.connected_components(G_undirected), key=len)
        G_sub = G_undirected.subgraph(largest_cc)
        L = nx.normalized_laplacian_matrix(G_sub)
        k = min(self.embedding_dim + 1, L.shape[0] - 2)
        eigenvalues, eigenvectors = eigsh(L, k=k, which='SM')
        node_list = list(G_sub.nodes())
        self.embeddings = {}
        for i, node in enumerate(node_list):
            self.embeddings[node] = eigenvectors[i, 1:]
        for node in self.graph.nodes():
            if node not in self.embeddings:
                self.embeddings[node] = np.random.randn(k - 1) * 0.01
        print(f"  Embeddings: {len(self.embeddings):,} nodes, dim={k-1}")
        return self.embeddings

    def _edge_features_v2(self, u: int, v: int, hop_count: int = 0,
                          source: int = None, target: int = None) -> np.ndarray:
        """
        V2 edge features: original features + path context + node types.
        """
        features = []

        # --- Original features ---
        if self.embeddings:
            emb_u = self.embeddings.get(u, np.zeros(self.embedding_dim))
            emb_v = self.embeddings.get(v, np.zeros(self.embedding_dim))
            norm_u, norm_v = np.linalg.norm(emb_u), np.linalg.norm(emb_v)
            cos_sim = np.dot(emb_u, emb_v) / (norm_u * norm_v) if norm_u > 0 and norm_v > 0 else 0.0
            features.append(cos_sim)
            features.append(np.linalg.norm(emb_u - emb_v))

        features.append(np.log1p(self.degrees.get(u, 0)))
        features.append(np.log1p(self.degrees.get(v, 0)))
        features.append(np.log1p(self.degrees.get(u, 1) / max(self.degrees.get(v, 1), 1)))

        # --- NEW: Path-level context ---
        features.append(float(hop_count))

        # Distance from v to target (embedding space)
        if self.embeddings and target is not None:
            emb_v = self.embeddings.get(v, np.zeros(self.embedding_dim))
            emb_t = self.embeddings.get(target, np.zeros(self.embedding_dim))
            features.append(np.linalg.norm(emb_v - emb_t))
        else:
            features.append(0.0)

        # --- NEW: Node type one-hots for u and v ---
        features.extend(self._node_type_onehot(u))
        features.extend(self._node_type_onehot(v))

        # --- NEW: Edge type one-hot ---
        edge_data = self.graph.get_edge_data(u, v) or {}
        edge_rel = edge_data.get('relation', 'unknown')
        # Use a compact encoding: hash to fixed-size vector
        edge_type_feat = np.zeros(8)
        edge_type_feat[hash(edge_rel) % 8] = 1.0
        features.extend(edge_type_feat)

        return np.array(features)

    def train_edge_weights(self, training_pathways: List[Dict], negative_ratio: float = 3.0):
        """Train MLP on v2 features with path context."""
        if self.embeddings is None:
            self.train_embeddings()
        print("  Training edge weight MLP (v2 features)...")

        X_train, y_train = [], []
        positive_edges = set()

        for pathway in training_pathways:
            path = pathway['path_nodes']
            source, target = path[0], path[-1]
            for hop_i in range(len(path) - 1):
                u, v = path[hop_i], path[hop_i + 1]
                if self.graph.has_edge(u, v):
                    positive_edges.add((u, v))
                    X_train.append(self._edge_features_v2(u, v, hop_count=hop_i,
                                                          source=source, target=target))
                    y_train.append(0.1)

        all_edges = list(self.graph.edges())
        np.random.shuffle(all_edges)
        n_negative = int(len(positive_edges) * negative_ratio)
        neg_count = 0
        for u, v in all_edges:
            if neg_count >= n_negative:
                break
            if (u, v) not in positive_edges:
                X_train.append(self._edge_features_v2(u, v, hop_count=2))
                y_train.append(1.0)
                neg_count += 1

        X_train, y_train = np.array(X_train), np.array(y_train)
        self.scaler = StandardScaler()
        X_scaled = self.scaler.fit_transform(X_train)

        self.mlp = MLPRegressor(hidden_layer_sizes=(64, 32), activation='relu',
                                max_iter=500, early_stopping=True, random_state=42)
        self.mlp.fit(X_scaled, y_train)
        print(f"  MLP trained on {len(X_train)} samples (R²={self.mlp.score(X_scaled, y_train):.3f})")
        print(f"  Feature dim: {X_train.shape[1]}")

    def _heuristic(self, node: int, target: int) -> float:
        if self.embeddings is None:
            return 0.0
        emb_n = self.embeddings.get(node, np.zeros(self.embedding_dim))
        emb_t = self.embeddings.get(target, np.zeros(self.embedding_dim))
        return np.linalg.norm(emb_n - emb_t) * 0.1

    def _is_valid_prefix(self, relation_seq: tuple) -> bool:
        """Check if the current edge-type sequence is a prefix of any known GT pattern."""
        if not self.prefix_set:
            return True  # no constraints
        return relation_seq in self.prefix_set

    def find_path(self, source: int, target: int) -> Tuple[List[int], List[str], float]:
            """A* with learned edge weights and edge-type multipliers. No relation tracking in state."""
            counter = 0
            open_set = [(self._heuristic(source, target), counter, source, [source], 0.0)]
            visited = set()

            while open_set:
                f_score, _, current, path, g_score = heapq.heappop(open_set)

                if current == target:
                    relations = []
                    for i in range(len(path) - 1):
                        edge_data = self.graph.get_edge_data(path[i], path[i + 1])
                        relations.append(edge_data.get('relation', 'unknown'))
                    return path, relations, g_score

                if current in visited:
                    continue
                visited.add(current)

                hop_count = len(path) - 1

                for neighbor in self.graph.neighbors(current):
                    if neighbor in visited:
                        continue
                    if not allowed_transition(self.graph, source, current, neighbor):
                        continue

                    # MLP-predicted base weight
                    if self.mlp is not None and self.scaler is not None:
                        feat = self._edge_features_v2(current, neighbor, hop_count=hop_count,
                                                    source=source, target=target)
                        feat_scaled = self.scaler.transform(feat.reshape(1, -1))
                        base_weight = float(np.clip(self.mlp.predict(feat_scaled)[0], 0.01, 2.0))
                    else:
                        base_weight = 1.0

                    # Edge-type multiplier (baked-in meta-path awareness)
                    edge_data = self.graph.get_edge_data(current, neighbor) or {}
                    edge_rel = edge_data.get('relation', 'unknown')
                    type_mult = self.edge_type_weights.get(edge_rel, 1.5)

                    edge_weight = base_weight * type_mult
                    new_g = g_score + edge_weight
                    new_f = new_g + self._heuristic(neighbor, target)
                    counter += 1
                    heapq.heappush(open_set, (new_f, counter, neighbor, path + [neighbor], new_g))

            return [], [], float('inf')


print("✓ LearnedAStarV2 loaded")


✓ LearnedAStarV2 loaded


## 7. Algorithm 5 v2: Improved Semantic Bridging

**Changes from v1:**
- **Per-edge-type weight multipliers** learned from GT frequency (penalize `drug_effect`, reward `drug_protein`)
- **Soft meta-path penalty**: edges inconsistent with any valid GT prefix get 10x weight
- **Graph co-occurrence features**: neighbor overlap (Jaccard) supplements TF-IDF for gene symbols
- **Separate β by node-type pair**: different similarity discounts for drug→protein vs process→disease


In [7]:
# ============================================================
# ALGORITHM 5 v2: Semantic Bridging with Edge-Type Weights + Meta-Path Penalties
# ============================================================

class SemanticBridgingV2:
    def __init__(self, graph: nx.DiGraph, beta: float = 0.3,
                 prefix_set: Set[tuple] = None,
                 edge_type_weights: Dict[str, float] = None,
                 metapath_penalty: float = 10.0):
        self.graph = graph
        self.beta = beta
        self.embeddings = None
        self.weighted_graph = None
        self.prefix_set = prefix_set or set()
        self.edge_type_weights = edge_type_weights or {}
        self.metapath_penalty = metapath_penalty
        self.descriptions = {n: graph.nodes[n].get('node_name', str(n)) for n in graph.nodes()}
        self.neighbor_cache = {}

    def compute_embeddings(self) -> Dict[int, np.ndarray]:
        """TF-IDF + SVD embeddings, supplemented with neighbor-overlap features."""
        print("  Computing TF-IDF embeddings...")
        nodes = list(self.graph.nodes())
        texts = [self.descriptions[n] for n in nodes]

        vectorizer = TfidfVectorizer(max_features=5000, stop_words='english')
        tfidf_matrix = vectorizer.fit_transform(texts)

        n_components = min(64, tfidf_matrix.shape[1] - 1)
        svd = TruncatedSVD(n_components=n_components, random_state=42)
        tfidf_emb = svd.fit_transform(tfidf_matrix)

        print("  Building neighbor cache for co-occurrence features...")
        for node in nodes:
            self.neighbor_cache[node] = set(self.graph.neighbors(node))

        self.embeddings = {node: tfidf_emb[i] for i, node in enumerate(nodes)}
        self.tfidf_dim = n_components
        print(f"  Embeddings: {len(self.embeddings):,} nodes, dim={n_components}")
        return self.embeddings

    def _cosine_similarity(self, emb1, emb2):
        n1, n2 = np.linalg.norm(emb1), np.linalg.norm(emb2)
        if n1 == 0 or n2 == 0:
            return 0.0
        return np.dot(emb1, emb2) / (n1 * n2)

    def _jaccard_similarity(self, u: int, v: int) -> float:
        """Neighbor overlap — captures structural relatedness for gene symbols."""
        n_u = self.neighbor_cache.get(u, set())
        n_v = self.neighbor_cache.get(v, set())
        if not n_u and not n_v:
            return 0.0
        intersection = len(n_u & n_v)
        union = len(n_u | n_v)
        return intersection / union if union > 0 else 0.0

    def _get_beta_for_edge(self, u: int, v: int) -> float:
        """Node-type-pair-specific beta."""
        u_type = self.graph.nodes[u].get('node_type', '')
        v_type = self.graph.nodes[v].get('node_type', '')
        pair = (u_type, v_type)
        beta_map = {
            ('drug', 'gene/protein'):           0.5,
            ('gene/protein', 'gene/protein'):   0.4,
            ('gene/protein', 'biological_process'): 0.5,
            ('biological_process', 'gene/protein'): 0.5,
            ('gene/protein', 'disease'):        0.4,
            ('gene/protein', 'pathway'):        0.4,
            ('pathway', 'gene/protein'):        0.4,
        }
        return beta_map.get(pair, self.beta * 0.5)

    def compute_edge_weights(self) -> nx.DiGraph:
        """Compute edge weights: semantic sim × edge-type multiplier. All baked in."""
        if self.embeddings is None:
            self.compute_embeddings()
        print("  Computing v2 edge weights...")
        self.weighted_graph = self.graph.copy()

        for u, v in self.weighted_graph.edges():
            emb_u = self.embeddings.get(u)
            emb_v = self.embeddings.get(v)

            if emb_u is not None and emb_v is not None:
                cos_sim = self._cosine_similarity(emb_u, emb_v)
                jac_sim = self._jaccard_similarity(u, v)
                combined_sim = 0.6 * max(0, cos_sim) + 0.4 * jac_sim
            else:
                combined_sim = 0.0

            beta = self._get_beta_for_edge(u, v)
            base_weight = 1.0 - beta * combined_sim

            # Edge-type multiplier from GT frequency
            edge_data = self.graph.get_edge_data(u, v) or {}
            edge_rel = edge_data.get('relation', 'unknown')
            type_mult = self.edge_type_weights.get(edge_rel, 1.5)

            self.weighted_graph[u][v]['weight'] = max(0.01, base_weight * type_mult)

        print(f"  Edge weights computed for {self.weighted_graph.number_of_edges():,} edges")
        return self.weighted_graph

    def find_path(self, source: int, target: int) -> Tuple[List[int], List[str], float]:
        """Standard Dijkstra on precomputed weights. Fast — no relation tracking in state."""
        if self.weighted_graph is None:
            self.compute_edge_weights()
        return find_path_engine(self.graph, self.weighted_graph, source, target, allowed_transition)


print("✓ SemanticBridgingV2 loaded")

✓ SemanticBridgingV2 loaded


## 8. Load Data & Build Graph

In [8]:
# Load PrimeKG
print("Loading PrimeKG...")
nodes_df = pd.read_csv(PATHS['nodes'])
edges_df = pd.read_csv(PATHS['edges'])
print(f"  Nodes: {len(nodes_df):,}, Edges: {len(edges_df):,}")

# Load ground truth
print("\nLoading ground truth...")
gt_nodes_df = pd.read_csv(PATHS['ground_truth_nodes'], dtype={'node_index': int})
gt_edges_df = pd.read_csv(PATHS['ground_truth_edges'])

# Load fold assignments
print("Loading fold assignments...")
folds_df = pd.read_csv(PATHS['fold_assignments'])

# Filter to >= MIN_PATHWAY_NODES
pathway_sizes = gt_nodes_df.groupby('pathway_id').size()
valid_pathways = pathway_sizes[pathway_sizes >= MIN_PATHWAY_NODES].index.tolist()

gt_nodes_df = gt_nodes_df[gt_nodes_df['pathway_id'].isin(valid_pathways)].reset_index(drop=True)
gt_edges_df = gt_edges_df[gt_edges_df['pathway_id'].isin(valid_pathways)].reset_index(drop=True)
folds_df = folds_df[folds_df['pathway_id'].isin(valid_pathways)].reset_index(drop=True)

print(f"\n  Pathways (>= {MIN_PATHWAY_NODES} nodes): {len(valid_pathways)}")
print(f"  Fold distribution:")
print(f"  {folds_df['fold'].value_counts().sort_index().to_dict()}")

# Build graph
print("\nBuilding graph...")
def build_graph(nodes_df, edges_df, bidirectional=True):
    G = nx.DiGraph()
    for _, row in nodes_df.iterrows():
        G.add_node(int(row['node_index']), node_id=str(row['node_id']),
                    node_name=str(row['node_name']), node_type=str(row['node_type']))
    for _, row in edges_df.iterrows():
        src, dst = int(row['x_index']), int(row['y_index'])
        rel, disp = str(row['relation']), str(row['display_relation'])
        G.add_edge(src, dst, relation=rel, display_relation=disp)
        if bidirectional:
            G.add_edge(dst, src, relation=rel, display_relation=disp)
    return G

G = build_graph(nodes_df, edges_df, bidirectional=True)
print(f"✓ Graph: {G.number_of_nodes():,} nodes, {G.number_of_edges():,} edges")

# Precompute evaluation helpers
degree_count = compute_degree_counts(edges_df)
hub_threshold = compute_hub_threshold(degree_count, percentile=95)
print(f"  Hub threshold: {hub_threshold:.0f}")


Loading PrimeKG...
  Nodes: 129,375, Edges: 8,100,498

Loading ground truth...
Loading fold assignments...

  Pathways (>= 4 nodes): 150
  Fold distribution:
  {0: 32, 1: 31, 2: 30, 3: 29, 4: 28}

Building graph...
✓ Graph: 129,375 nodes, 8,099,284 edges
  Hub threshold: 412


## 9. Mine Meta-Path Patterns

In [9]:
print("Mining meta-path patterns from ground truth...")
patterns, pattern_counts, prefix_set = mine_metapath_patterns(gt_edges_df, gt_nodes_df, min_count=1)

# Compute edge-type weights from GT frequency
all_edge_types = set(edges_df['relation'].unique())
edge_type_weights = compute_edge_type_weights(gt_edges_df, all_edge_types, bonus=0.3, penalty=2.0)


Mining meta-path patterns from ground truth...
  Using edge type column: 'relation'
  Mined 35 unique meta-path patterns
  35 patterns with count >= 1
  119 valid prefixes for search constraint

  Top 15 patterns:
     28x  drug_protein → drug_protein → indication
     20x  drug_protein → bioprocess_protein → bioprocess_protein → disease_protein
     16x  drug_protein → protein_protein → disease_protein
     10x  drug_protein → protein_protein → protein_protein → bioprocess_protein → bioprocess_protein → bioprocess_protein → bioprocess_protein → disease_protein
      8x  drug_protein → pathway_protein → pathway_protein → disease_protein
      8x  drug_effect → drug_effect → indication
      7x  drug_protein → molfunc_protein → molfunc_protein → protein_protein → disease_protein
      6x  drug_protein → anatomy_protein_present → anatomy_protein_present → bioprocess_protein → bioprocess_protein → protein_protein → bioprocess_protein → bioprocess_protein → disease_protein
      5x  drug_d

## 10. Algorithm Runner & Evaluation

In [10]:
def run_algorithm(algo_find_path_fn, graph, gt_nodes_df, algo_name, verbose=True):
    """Run algorithm on all pathways, return predictions DataFrame."""
    results = []
    pathways = gt_nodes_df['pathway_id'].unique()
    n_total = len(pathways)
    for idx, pathway_id in enumerate(pathways):
        pw = gt_nodes_df[gt_nodes_df['pathway_id'] == pathway_id].sort_values('step_order')
        source_idx = int(pw.iloc[0]['node_index'])
        target_idx = int(pw.iloc[-1]['node_index'])
        if verbose and idx % 25 == 0:
            print(f"  [{idx+1}/{n_total}] ...")
        start_t = time.perf_counter()
        try:
            path, relations, cost = algo_find_path_fn(source_idx, target_idx)
        except Exception as e:
            path, relations, cost = [], [], float('inf')
        elapsed_ms = (time.perf_counter() - start_t) * 1000
        if path:
            node_ids = [graph.nodes[n].get('node_id', str(n)) for n in path]
            node_names = [graph.nodes[n].get('node_name', str(n)) for n in path]
            results.append({
                'pathway_id': pathway_id, 'algorithm': algo_name,
                'predicted_node_indices': ','.join(map(str, path)),
                'predicted_node_ids': ','.join(node_ids),
                'predicted_node_names': ','.join(node_names),
                'predicted_relations': ','.join(relations),
                'predicted_length': len(path), 'ground_truth_length': len(pw),
                'time_ms': elapsed_ms,
            })
        else:
            results.append({
                'pathway_id': pathway_id, 'algorithm': algo_name,
                'predicted_node_indices': 'NONE', 'predicted_node_ids': 'NONE',
                'predicted_node_names': 'NONE', 'predicted_relations': 'NONE',
                'predicted_length': 0, 'ground_truth_length': len(pw),
                'time_ms': elapsed_ms,
            })
    df = pd.DataFrame(results)
    found = (df['predicted_length'] > 0).sum()
    avg_ms = df['time_ms'].mean()
    print(f"  ✓ {algo_name}: {found}/{n_total} paths found, avg {avg_ms:.1f}ms/pathway")
    return df


def evaluate_predictions(pred_df, gt_nodes_df, gt_edges_df, degree_count, hub_threshold):
    """Evaluate predictions against ground truth."""
    edge_type_col = None
    for c in ['relation_type', 'relation', 'edge_type']:
        if c in gt_edges_df.columns:
            edge_type_col = c
            break
    if edge_type_col is None:
        edge_type_col = gt_edges_df.columns[0]

    results = []
    for _, pred_row in pred_df.iterrows():
        pathway_id = pred_row['pathway_id']
        gt_pw = gt_nodes_df[gt_nodes_df['pathway_id'] == pathway_id].sort_values('step_order')
        gt_ed = gt_edges_df[gt_edges_df['pathway_id'] == pathway_id]
        gt_node_ids = [str(x) for x in gt_pw['node_id'].tolist()]
        gt_target_id = str(gt_pw.iloc[-1]['node_id'])
        gt_edge_types = gt_ed[edge_type_col].tolist() if not gt_ed.empty else []

        if pred_row['predicted_node_ids'] == 'NONE':
            pred_node_ids, pred_indices, pred_relations = [], [], []
        else:
            pred_node_ids = pred_row['predicted_node_ids'].split(',')
            pred_indices = [int(x) for x in pred_row['predicted_node_indices'].split(',')]
            pred_relations = pred_row['predicted_relations'].split(',') if pred_row['predicted_relations'] != 'NONE' else []

        results.append({
            'pathway_id': pathway_id,
            'algorithm': pred_row['algorithm'],
            'precision': metric_precision(pred_node_ids, gt_node_ids),
            'recall': metric_recall(pred_node_ids, gt_node_ids),
            'f1_score': metric_f1(pred_node_ids, gt_node_ids),
            'hits_at_1': calculate_hits_at_k(pred_node_ids, gt_target_id)['hits_at_1'],
            'relation_type_accuracy': calculate_relation_accuracy(pred_relations, gt_edge_types),
            'path_edit_distance': calculate_edit_distance(pred_node_ids, gt_node_ids),
            'hub_node_ratio': calculate_hub_node_ratio(pred_indices, degree_count, hub_threshold),
            'path_length_mae': calculate_path_length_mae(pred_row['predicted_length'], pred_row['ground_truth_length']),
            'path_length_accuracy': metric_path_length_accuracy(pred_row['predicted_length'], pred_row['ground_truth_length']),
            'mrr': metric_mrr(pred_node_ids, gt_node_ids),
            'time_ms': pred_row['time_ms'],
        })
    return pd.DataFrame(results)


print("✓ Runner & evaluation loaded")


✓ Runner & evaluation loaded


## 11. Run Semantic Bridging v2

Semantic Bridging has no training phase, so no CV needed — just run with the new weights and meta-path penalties.


In [11]:
print("=" * 60)
print("SEMANTIC BRIDGING V2")
print("=" * 60)

sem_v2 = SemanticBridgingV2(G, beta=0.3, prefix_set=prefix_set,
                             edge_type_weights=edge_type_weights, metapath_penalty=10.0)
sem_v2.compute_embeddings()
sem_v2.compute_edge_weights()

sem_v2_preds = run_algorithm(sem_v2.find_path, G, gt_nodes_df, 'Semantic Bridging v2')
sem_v2_eval = evaluate_predictions(sem_v2_preds, gt_nodes_df, gt_edges_df, degree_count, hub_threshold)

print("\nSemantic Bridging v2 — Mean Metrics:")
metric_cols = ['precision', 'recall', 'f1_score', 'relation_type_accuracy',
               'path_edit_distance', 'hub_node_ratio', 'path_length_mae', 'time_ms']
print(sem_v2_eval[metric_cols].mean().round(4).to_string())


SEMANTIC BRIDGING V2
  Computing TF-IDF embeddings...
  Building neighbor cache for co-occurrence features...
  Embeddings: 129,375 nodes, dim=64
  Computing v2 edge weights...
  Edge weights computed for 8,099,284 edges
  [1/150] ...
  [26/150] ...
  [51/150] ...
  [76/150] ...
  [101/150] ...
  [126/150] ...
  ✓ Semantic Bridging v2: 150/150 paths found, avg 5695.9ms/pathway

Semantic Bridging v2 — Mean Metrics:
precision                    0.7732
recall                       0.4927
f1_score                     0.5850
relation_type_accuracy       0.7223
path_edit_distance           0.5168
hub_node_ratio               0.7408
path_length_mae              2.5800
time_ms                   5695.8652


## 12. Run Learned A* v2 with 5-Fold Cross-Validation

For each fold: train on 4 folds, evaluate on the held-out fold. Concatenate results for honest evaluation.


In [15]:
print("=" * 60)
print("LEARNED A* V2 — 5-FOLD CROSS-VALIDATION")
print("=" * 60)

all_astar_preds = []

# Pre-train embeddings once (not pathway-dependent)
print("\nPre-training embeddings (shared across folds)...")
astar_base = LearnedAStarV2(G, embedding_dim=64, prefix_set=prefix_set,
                              edge_type_weights=edge_type_weights)
astar_base.train_embeddings()
shared_embeddings = astar_base.embeddings

for fold_id in range(N_FOLDS):
    print(f"\n{'='*50}")
    print(f"FOLD {fold_id}: training on folds != {fold_id}, testing on fold {fold_id}")
    print(f"{'='*50}")

    # Split pathways
    test_pws = folds_df[folds_df['fold'] == fold_id]['pathway_id'].tolist()
    train_pws = folds_df[folds_df['fold'] != fold_id]['pathway_id'].tolist()

    print(f"  Train: {len(train_pws)} pathways, Test: {len(test_pws)} pathways")

    # Prepare training data from train folds only
    training_pathways = []
    for pid in train_pws:
        pw = gt_nodes_df[gt_nodes_df['pathway_id'] == pid].sort_values('step_order')
        if len(pw) >= MIN_PATHWAY_NODES:
            training_pathways.append({'path_nodes': pw['node_index'].tolist()})
    print(f"  Training pathways (>= {MIN_PATHWAY_NODES} nodes): {len(training_pathways)}")

    # Initialize fresh algorithm with shared embeddings
    algo = LearnedAStarV2(G, embedding_dim=64, prefix_set=prefix_set,
                           edge_type_weights=edge_type_weights)
    algo.embeddings = shared_embeddings

    # Train on train folds only
    algo.train_edge_weights(training_pathways)

    # Evaluate on test fold
    test_gt = gt_nodes_df[gt_nodes_df['pathway_id'].isin(test_pws)]
    fold_preds = run_algorithm(algo.find_path, G, test_gt, f'Learned A* v2')
    all_astar_preds.append(fold_preds)

# Concatenate all fold results
astar_v2_preds = pd.concat(all_astar_preds, ignore_index=True)
astar_v2_eval = evaluate_predictions(astar_v2_preds, gt_nodes_df, gt_edges_df, degree_count, hub_threshold)

print(f"\n{'='*60}")
print(f"LEARNED A* V2 — AGGREGATE RESULTS ({len(astar_v2_preds)} pathways)")
print(f"{'='*60}")
print(astar_v2_eval[metric_cols].mean().round(4).to_string())


LEARNED A* V2 — 5-FOLD CROSS-VALIDATION

Pre-training embeddings (shared across folds)...
  Computing spectral embeddings...
  Embeddings: 129,375 nodes, dim=64

FOLD 0: training on folds != 0, testing on fold 0
  Train: 118 pathways, Test: 32 pathways
  Training pathways (>= 4 nodes): 118
  Training edge weight MLP (v2 features)...
  MLP trained on 1382 samples (R²=0.906)
  Feature dim: 35
  [1/32] ...


KeyboardInterrupt: 

## 13. Compare v1 vs v2

Load v1 baseline results and compare side-by-side.


In [None]:
# Combine v2 results
all_v2_eval = pd.concat([sem_v2_eval, astar_v2_eval], ignore_index=True)

# Print comparison
print("=" * 80)
print("V2 RESULTS SUMMARY")
print("=" * 80)

v2_summary = all_v2_eval.groupby('algorithm')[metric_cols].mean().round(4)
print(v2_summary.T.to_string())

# If v1 results exist, compare
v1_path = 'evaluation_results_all_algorithms.csv'
if os.path.exists(v1_path):
    print(f"\n\n{'='*80}")
    print("V1 vs V2 COMPARISON")
    print(f"{'='*80}")
    v1_eval = pd.read_csv(v1_path)

    # Extract v1 results for the same algorithms
    v1_sem = v1_eval[v1_eval['algorithm'] == 'Semantic Bridging'][metric_cols].mean()
    v1_astar = v1_eval[v1_eval['algorithm'] == 'Learned A*'][metric_cols].mean()

    v2_sem = sem_v2_eval[metric_cols].mean()
    v2_astar = astar_v2_eval[metric_cols].mean()

    print(f"\n{'Metric':<28s} {'Sem v1':>10s} {'Sem v2':>10s} {'Δ':>8s}    {'A* v1':>10s} {'A* v2':>10s} {'Δ':>8s}")
    print(f"{'-'*88}")
    for m in metric_cols:
        s1, s2 = v1_sem[m], v2_sem[m]
        a1, a2 = v1_astar[m], v2_astar[m]
        sd = s2 - s1
        ad = a2 - a1
        print(f"{m:<28s} {s1:>10.4f} {s2:>10.4f} {sd:>+8.4f}    {a1:>10.4f} {a2:>10.4f} {ad:>+8.4f}")
else:
    print(f"\n  (v1 results not found at '{v1_path}' — run baseline notebook first for comparison)")


## 14. Visualization

In [None]:
# F1 by pathway length
gt_len_map = gt_nodes_df.groupby('pathway_id').size().to_dict()
all_v2_eval['gt_length'] = all_v2_eval['pathway_id'].map(gt_len_map)

print("F1 by pathway length:")
pivot = all_v2_eval.pivot_table(index='gt_length', columns='algorithm', values='f1_score', aggfunc='mean').round(4)
print(pivot.to_string())

# Plot
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# F1 comparison
ax = axes[0]
for alg in all_v2_eval['algorithm'].unique():
    sub = all_v2_eval[all_v2_eval['algorithm'] == alg]
    by_len = sub.groupby('gt_length')['f1_score'].mean()
    ax.plot(by_len.index, by_len.values, 'o-', label=alg)
ax.set_xlabel('Ground Truth Path Length')
ax.set_ylabel('Mean F1 Score')
ax.set_title('F1 by Pathway Length (v2)')
ax.legend(fontsize=8)
ax.grid(alpha=0.3)

# Relation accuracy comparison
ax = axes[1]
for alg in all_v2_eval['algorithm'].unique():
    sub = all_v2_eval[all_v2_eval['algorithm'] == alg]
    by_len = sub.groupby('gt_length')['relation_type_accuracy'].mean()
    ax.plot(by_len.index, by_len.values, 'o-', label=alg)
ax.set_xlabel('Ground Truth Path Length')
ax.set_ylabel('Mean Relation Accuracy')
ax.set_title('Relation Accuracy by Length (v2)')
ax.legend(fontsize=8)
ax.grid(alpha=0.3)

# Overall bar comparison
ax = axes[2]
compare_metrics = ['f1_score', 'relation_type_accuracy', 'path_edit_distance', 'recall']
x = np.arange(len(compare_metrics))
width = 0.35
algos = all_v2_eval['algorithm'].unique()
colors = ['#9b59b6', '#f39c12']
for i, alg in enumerate(algos):
    vals = [all_v2_eval[all_v2_eval['algorithm'] == alg][m].mean() for m in compare_metrics]
    ax.bar(x + i * width, vals, width, label=alg, color=colors[i % len(colors)])
ax.set_xticks(x + width / 2)
ax.set_xticklabels(['F1 ↑', 'Rel Acc ↑', 'Edit Dist ↓', 'Recall ↑'], fontsize=9)
ax.set_title('V2 Algorithm Comparison')
ax.legend(fontsize=8)
ax.set_ylim(0, 1.0)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('v2_comparison.png', dpi=300, bbox_inches='tight')
print("✓ Saved: v2_comparison.png")
plt.show()


## 15. Export

In [None]:
# Save v2 results
all_v2_eval.to_csv('v2_evaluation_results.csv', index=False)
print(f"✓ Saved: v2_evaluation_results.csv ({len(all_v2_eval)} rows)")

sem_v2_preds.to_csv('v2_predictions_semantic_bridging.csv', index=False)
print(f"✓ Saved: v2_predictions_semantic_bridging.csv")

astar_v2_preds.to_csv('v2_predictions_learned_astar.csv', index=False)
print(f"✓ Saved: v2_predictions_learned_astar.csv")

v2_summary.to_csv('v2_algorithm_summary.csv')
print(f"✓ Saved: v2_algorithm_summary.csv")

print("\n✓ All v2 outputs saved!")
