# Biolink Predicate Granularity Explorer

This notebook explores the biolink model predicate hierarchy to design a granularity-based filtering system.

**Goal**: Allow users to exclude vague predicates (like `related_to`) while keeping more specific ones.

**Workflow:**
1. Fetch/cache biolink-model.yaml (version-aware)
2. Parse predicate hierarchy from `is_a` relationships
3. Build tree structure with treelib
4. Analyze depth distribution
5. Visualize "cuts" at different granularity levels
6. Test filtering against real TCT MetaKG predicates

## 1. Setup & Dependencies

In [None]:
import json
import requests
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
import yaml
from treelib import Tree
import pandas as pd

# Configuration
# Cached artifacts live under ../data, i.e. relative to the working directory
# the notebook is run from.
DATA_DIR = Path("../data")
# exist_ok=True keeps this cell re-runnable once the directory exists.
DATA_DIR.mkdir(exist_ok=True)

# Cached copy of the model YAML, plus a sidecar file holding the release tag
# that copy was downloaded from.
BIOLINK_CACHE_PATH = DATA_DIR / "biolink-model.yaml"
BIOLINK_VERSION_PATH = DATA_DIR / "biolink-model-version.txt"
# "owner/name" slug interpolated into the GitHub API and raw-content URLs.
GITHUB_REPO = "biolink/biolink-model"

# Toggle to exclude literature co-occurrence predicates
# NOTE(review): in the cells visible here this flag is only printed; the calls
# to filter_predicates_by_granularity pass exclude_literature explicitly.
EXCLUDE_LITERATURE_COOCCURRENCE = False

# Echo the effective configuration so a run's output records it.
print(f"Data directory: {DATA_DIR.resolve()}")
print(f"Cache path: {BIOLINK_CACHE_PATH}")
print(f"Exclude literature co-occurrence: {EXCLUDE_LITERATURE_COOCCURRENCE}")

## 2. Version-Aware Biolink Model Fetching

- Checks GitHub for latest release tag
- Compares to locally cached version
- Updates cache only if newer version available
- Graceful fallback on network issues

In [None]:
def get_latest_biolink_version() -> Optional[str]:
    """Return the tag name of the latest GitHub release, or None if unknown.

    Queries the GitHub "latest release" endpoint for ``GITHUB_REPO`` with a
    10 second timeout. Every failure mode (request error, non-200 status,
    body that is not valid JSON) prints a warning and yields None so the
    caller can fall back to the local cache instead of crashing.

    Returns:
        The ``tag_name`` field of the response, or None when it could not be
        determined.
    """
    try:
        resp = requests.get(
            f"https://api.github.com/repos/{GITHUB_REPO}/releases/latest",
            timeout=10,
        )
        if resp.status_code == 200:
            return resp.json().get("tag_name")
        print(f"Warning: GitHub API returned status {resp.status_code}")
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a 200 response whose body is not valid JSON:
        # depending on the installed requests version that error is a
        # ValueError subclass that does not derive from RequestException.
        print(f"Warning: Could not check GitHub for updates: {e}")
    return None


def get_local_version() -> Optional[str]:
    """Return the release tag recorded beside the cache, or None if absent.

    The tag is read from ``BIOLINK_VERSION_PATH`` with surrounding
    whitespace removed.
    """
    if not BIOLINK_VERSION_PATH.exists():
        return None
    recorded_tag = BIOLINK_VERSION_PATH.read_text()
    return recorded_tag.strip()


def fetch_biolink_model(version: str) -> str:
    """Download biolink-model.yaml as it exists at the given tag.

    Args:
        version: Git ref (release tag) to fetch the file from.

    Returns:
        The response body as text.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (via ``raise_for_status``).
    """
    url = "https://raw.githubusercontent.com/{}/{}/biolink-model.yaml".format(
        GITHUB_REPO, version
    )
    print("Fetching from: " + url)
    # The model file is large, hence the generous 60 second timeout.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    body = response.text
    return body


def load_biolink_model() -> dict:
    """Load the biolink model, refreshing the local cache when needed.

    The cache is refreshed when the latest GitHub release tag is known and
    either no version is recorded locally, the recorded version differs, or
    the cached YAML file itself is missing (e.g. deleted while the version
    file was left behind). If the latest tag cannot be determined, or the
    download fails, an existing cache is used as-is.

    Returns:
        The parsed YAML document.

    Raises:
        RuntimeError: If no cached model exists and none could be downloaded.
    """
    local_version = get_local_version()
    latest_version = get_latest_biolink_version()

    need_update = False

    if not latest_version:
        print("Could not check for updates. Using local cache if available.")
    elif local_version is None or not BIOLINK_CACHE_PATH.exists():
        # A recorded version without the YAML file is as good as no cache:
        # download instead of failing later with a misleading network hint.
        print(f"No local cache found. Downloading {latest_version}...")
        need_update = True
    elif latest_version != local_version:
        print(f"Update available: {local_version} -> {latest_version}")
        need_update = True
    else:
        print(f"Local cache is current: {local_version}")

    if need_update:
        try:
            yaml_content = fetch_biolink_model(latest_version)
            # Explicit UTF-8 so the cache does not depend on the platform's
            # locale encoding (which may be unable to encode the model text).
            BIOLINK_CACHE_PATH.write_text(yaml_content, encoding="utf-8")
            BIOLINK_VERSION_PATH.write_text(latest_version)
            print(f"Successfully cached version {latest_version}")
        except Exception as e:
            # Deliberately broad: any download/write failure should degrade
            # to the existing cache rather than abort the notebook.
            print(f"Warning: Failed to download update: {e}")
            if not BIOLINK_CACHE_PATH.exists():
                raise RuntimeError(
                    "No local cache and cannot download biolink model"
                ) from e
            print("Falling back to existing cache.")

    if not BIOLINK_CACHE_PATH.exists():
        raise RuntimeError(
            f"No biolink-model.yaml found at {BIOLINK_CACHE_PATH}. "
            "Check network connection."
        )

    print(f"Loading from: {BIOLINK_CACHE_PATH}")
    return yaml.safe_load(BIOLINK_CACHE_PATH.read_text(encoding="utf-8"))


# Load the model
# `biolink_model` is the parsed YAML document; later cells read its "slots" mapping.
biolink_model = load_biolink_model()
print(f"\nLoaded biolink model with {len(biolink_model.get('slots', {}))} slots")

## 3. Parse Predicates from Biolink Model

Extract predicates (slots) that have `is_a` relationships forming the hierarchy.

In [None]:
def normalize_predicate_name(name: str) -> str:
    """Normalize a predicate name: trim, spaces to underscores, lowercase.

    Surrounding whitespace is removed *before* spaces are converted, so a
    padded name such as ``" related to "`` becomes ``"related_to"`` rather
    than ``"_related_to_"``.

    Args:
        name: Slot name as written in the model (e.g. ``"related to"``).

    Returns:
        The underscore-separated, lowercase form (e.g. ``"related_to"``).
    """
    return name.strip().replace(" ", "_").lower()


def extract_predicates(model: dict) -> Dict[str, Dict]:
    """Collect the slots that belong to the ``related_to`` hierarchy.

    Works in two passes over ``model["slots"]``:

    1. Keep every slot that declares an ``is_a`` parent, plus ``related_to``
       itself, recording a small info dict per slot.
    2. Keep only those slots whose ``is_a`` chain reaches ``related_to``.

    Slot names in the model use spaces ("related to"); keys and parent
    references in the result use the underscore form ("related_to").

    Args:
        model: Parsed biolink model document.

    Returns:
        Mapping of normalized predicate name to a dict with the keys
        ``is_a``, ``description``, ``inverse``, ``symmetric``, ``slot_uri``,
        ``abstract`` and ``original_name``.
    """
    root_name = "related_to"
    candidates = {}

    # Pass 1: gather the root and every slot that names a parent.
    for raw_name, spec in model.get("slots", {}).items():
        if spec is None:
            continue

        key = normalize_predicate_name(raw_name)
        raw_parent = spec.get("is_a")
        parent = normalize_predicate_name(raw_parent) if raw_parent else None

        if key != root_name and not parent:
            continue

        raw_inverse = spec.get("inverse")
        candidates[key] = {
            "is_a": parent,
            "description": spec.get("description", ""),
            "inverse": normalize_predicate_name(raw_inverse) if raw_inverse else None,
            "symmetric": spec.get("symmetric", False),
            "slot_uri": spec.get("slot_uri", ""),
            "abstract": spec.get("abstract", False),
            "original_name": raw_name,  # model spelling, kept for reference
        }

    print(f"First pass: found {len(candidates)} slots with is_a relationships")

    if root_name in candidates:
        print("  ✓ Found 'related_to' root predicate")
    else:
        print("  ✗ WARNING: 'related_to' not found!")

    def reaches_root(start: str) -> bool:
        """Walk the is_a chain upward; True once it arrives at the root."""
        seen = set()
        current = start
        while True:
            if current in seen:
                return False  # is_a cycle
            seen.add(current)
            if current == root_name:
                return True
            if current not in candidates:
                return False
            current = candidates[current].get("is_a")
            if not current:
                return False

    # Pass 2: drop anything whose ancestry never reaches the root.
    return {
        name: info
        for name, info in candidates.items()
        if name == root_name or reaches_root(name)
    }


# `predicates` maps normalized predicate name -> info dict; later cells build
# the tree from it.
predicates = extract_predicates(biolink_model)
print(f"\nExtracted {len(predicates)} predicates in related_to hierarchy")

# Show sample
# First 10 entries in the dict's insertion order, with their parent.
print("\nSample predicates:")
for name, info in list(predicates.items())[:10]:
    print(f"  {name}: is_a={info['is_a']}")

## 4. Build Predicate Hierarchy Tree

Use treelib to create a navigable tree structure from the `is_a` relationships.

In [None]:
def build_predicate_tree(predicates: Dict[str, Dict]) -> Tree:
    """Assemble a treelib Tree that mirrors the predicates' is_a links.

    Each node uses the predicate name as both tag and identifier and carries
    the predicate's info dict as ``data``. A node is only created once its
    parent is in the tree, so parents are inserted on demand first.

    Args:
        predicates: Mapping of predicate name to info dict; the ``is_a``
            entry names the parent (or is None).

    Returns:
        The populated tree.
    """
    hierarchy = Tree()
    inserted = set()  # identifiers already present in `hierarchy`

    def insert(pred_name: str, parent_name: Optional[str] = None):
        """Insert `pred_name`, inserting its known ancestors beforehand."""
        if pred_name in inserted:
            return

        parent_pending = bool(parent_name) and parent_name not in inserted
        # Parents outside the predicate set are never created.
        if parent_pending and parent_name in predicates:
            insert(parent_name, predicates[parent_name].get("is_a"))

        # Attach under the parent only if it made it into the tree;
        # otherwise the node is created without a parent.
        attach_to = parent_name if parent_name in inserted else None
        hierarchy.create_node(
            tag=pred_name,
            identifier=pred_name,
            parent=attach_to,
            data=predicates.get(pred_name, {}),
        )
        inserted.add(pred_name)

    for pred_name, details in predicates.items():
        insert(pred_name, details.get("is_a"))

    return hierarchy


# `predicate_tree` is reused by the display, depth and cut cells below.
predicate_tree = build_predicate_tree(predicates)
print(f"Built tree with {len(predicate_tree)} nodes")

# Find roots (nodes with no parent)
# A node counts as a root when predicate_tree.parent(...) returns None for it.
roots = [node.identifier for node in predicate_tree.all_nodes() 
         if predicate_tree.parent(node.identifier) is None]
print(f"Root nodes: {roots}")

## 5. Visualize Full Predicate Tree

Display the complete hierarchy using treelib's built-in display.

In [None]:
# Banner: rule, title, rule, then one blank line.
print("=" * 70, "BIOLINK PREDICATE HIERARCHY", "=" * 70, "", sep="\n")

# Prefer the single tree rooted at related_to; if that node is missing,
# fall back to printing every root found earlier.
if not predicate_tree.contains("related_to"):
    for root in roots:
        print("\n--- Tree rooted at: {} ---".format(root))
        predicate_tree.show(nid=root)
else:
    predicate_tree.show(idhidden=False)

## 6. Calculate Predicate Depths

Compute the depth of each predicate from the root (`related_to`).

In [None]:
def get_all_depths(tree: Tree) -> Dict[str, int]:
    """Map every node identifier to its distance from its root.

    Roots (nodes whose ``tree.parent`` is None) get depth 0 and each child
    is one deeper than its parent. Nodes are recorded in depth-first
    pre-order, following the order ``tree.children`` returns.

    Args:
        tree: The predicate hierarchy.

    Returns:
        Dict of node identifier to depth.
    """
    depth_by_id = {}

    for node in tree.all_nodes():
        if tree.parent(node.identifier) is not None:
            continue

        # Explicit stack instead of recursion; children are pushed in
        # reverse so they are popped (and recorded) in their listed order.
        pending = [(node.identifier, 0)]
        while pending:
            node_id, level = pending.pop()
            depth_by_id[node_id] = level
            for child in reversed(tree.children(node_id)):
                pending.append((child.identifier, level + 1))

    return depth_by_id


# `predicate_depths` (name -> depth) and `max_depth` are reused by every
# later filtering and summary cell.
predicate_depths = get_all_depths(predicate_tree)

# Analyze depth distribution
# depth_counts: depth -> number of predicates at exactly that depth.
depth_counts = {}
for pred, depth in predicate_depths.items():
    depth_counts[depth] = depth_counts.get(depth, 0) + 1

print("Predicate Depth Distribution:")
print("-" * 50)
for depth in sorted(depth_counts.keys()):
    count = depth_counts[depth]
    # Text histogram; the bar is capped at 50 characters.
    bar = "#" * min(count, 50)
    print(f"  Depth {depth}: {count:3d} predicates  {bar}")

# Guard against an empty mapping so max() does not raise.
max_depth = max(predicate_depths.values()) if predicate_depths else 0
print(f"\nMax depth: {max_depth}")

# Show examples at each depth
print("\nExamples at each depth:")
print("-" * 50)
for depth in sorted(depth_counts.keys()):
    # Up to five names per depth, in predicate_depths' iteration order.
    examples = [p for p, d in predicate_depths.items() if d == depth][:5]
    print(f"  Depth {depth}: {', '.join(examples)}{'...' if depth_counts[depth] > 5 else ''}")

## 7. Define Granularity Filtering Functions

Functions to filter predicates based on depth (granularity level).

In [None]:
def get_predicates_at_min_depth(depths: Dict[str, int], min_depth: int) -> Set[str]:
    """Select the predicates whose depth is ``min_depth`` or greater.

    Args:
        depths: Mapping of predicate name to depth.
        min_depth: Inclusive lower bound on depth.

    Returns:
        Set of predicate names that pass the cut.
    """
    return {name for name in depths if depths[name] >= min_depth}


def get_predicates_excluded_at_depth(depths: Dict[str, int], min_depth: int) -> Set[str]:
    """Select the predicates whose depth is below ``min_depth``.

    Args:
        depths: Mapping of predicate name to depth.
        min_depth: Exclusive upper bound on depth.

    Returns:
        Set of predicate names removed by the cut.
    """
    return {name for name in depths if depths[name] < min_depth}


def filter_predicates_by_granularity(
    predicates_list: List[str],
    min_depth: int,
    depths: Dict[str, int],
    exclude_literature: bool = False
) -> List[str]:
    """Keep only the predicates that are specific enough.

    Each entry is looked up by its bare name (any ``biolink:`` text removed,
    surrounding whitespace trimmed). Entries are returned in input order and
    in their original spelling.

    Args:
        predicates_list: Predicates to filter; may carry a ``biolink:`` prefix.
        min_depth: Minimum depth required (0 = all, 1 = exclude root, etc.).
        depths: Mapping of bare predicate name to depth. Names missing from
            it are dropped.
        exclude_literature: When True, also drop predicates whose bare name
            contains "literature" (case-insensitive).

    Returns:
        The surviving predicates.
    """

    def passes(raw: str) -> bool:
        bare = raw.replace("biolink:", "").strip()
        if bare not in depths or depths[bare] < min_depth:
            return False
        return not (exclude_literature and "literature" in bare.lower())

    return [raw for raw in predicates_list if passes(raw)]


# Test at different granularity levels
# Levels run from 0 up to max_depth, capped at 8 iterations.
print("Predicates remaining at each granularity level:")
print("-" * 60)
for level in range(min(max_depth + 1, 8)):
    remaining = get_predicates_at_min_depth(predicate_depths, level)
    excluded = get_predicates_excluded_at_depth(predicate_depths, level)
    print(f"Level {level}: {len(remaining):3d} allowed, {len(excluded):3d} excluded")
    if excluded:
        # Show at most five excluded names, alphabetically, with an
        # ellipsis when more were cut.
        excluded_list = sorted(excluded)[:5]
        suffix = "..." if len(excluded) > 5 else ""
        print(f"         Excluded: {', '.join(excluded_list)}{suffix}")

## 8. Visualize Tree "Cuts" at Different Levels

Show what the predicate tree looks like when cut at different granularity levels.

In [None]:
def show_tree_at_depth(
    tree: Tree, 
    depths: Dict[str, int], 
    min_depth: int, 
    max_display: int = 100
):
    """Print what remains of the hierarchy after a cut at ``min_depth``.

    Prints a header with the number of allowed predicates and up to ten
    excluded names. Then, if at most ``max_display`` predicates are allowed,
    prints them as a tree under an artificial "ALLOWED_PREDICATES" root,
    each attached to its nearest allowed ancestor. Otherwise prints a flat
    list (up to 30) of the predicates at exactly ``min_depth``.

    Args:
        tree: Full predicate hierarchy.
        depths: Mapping of predicate name to depth.
        min_depth: Granularity level; shallower predicates are excluded.
        max_display: Largest allowed-set size still rendered as a tree.
    """
    kept = get_predicates_at_min_depth(depths, min_depth)
    dropped = get_predicates_excluded_at_depth(depths, min_depth)
    rule = "=" * 70

    print("\n" + rule)
    print(f"GRANULARITY LEVEL {min_depth}: {len(kept)} predicates allowed")

    # Excluded names, truncated to ten with an ellipsis when there are more.
    dropped_sorted = sorted(dropped)
    if len(dropped) > 10:
        print(f"Excluded ({len(dropped)}): {', '.join(dropped_sorted[:10])}...")
    else:
        print(f"Excluded ({len(dropped)}): {', '.join(dropped_sorted)}")
    print(rule)

    if not kept:
        print("(No predicates at this level)")
        return

    if len(kept) > max_display:
        print(f"(Too many to display full tree - showing depth {min_depth} predicates only)")
        exactly_here = sorted(p for p, d in depths.items() if d == min_depth)
        print(f"\nPredicates at exactly depth {min_depth} ({len(exactly_here)}):")
        for name in exactly_here[:30]:
            print(f"  - {name}")
        if len(exactly_here) > 30:
            print(f"  ... and {len(exactly_here) - 30} more")
        return

    # The cut can leave several disconnected subtrees, so hang them all
    # under one artificial root.
    forest = Tree()
    forest.create_node("ALLOWED_PREDICATES", "root")
    placed = {"root"}

    def nearest_placed_ancestor(name: str) -> str:
        """Closest allowed ancestor already in the forest, else "root"."""
        ancestor = tree.parent(name)
        while ancestor:
            ancestor_id = ancestor.identifier
            if ancestor_id in kept and ancestor_id in placed:
                return ancestor_id
            ancestor = tree.parent(ancestor_id)
        return "root"

    # Shallowest first, so a parent is placed before its children.
    for name in sorted(kept, key=lambda p: depths.get(p, 0)):
        if name in placed:
            continue
        forest.create_node(name, name, parent=nearest_placed_ancestor(name))
        placed.add(name)

    forest.show()


# Show cuts at key levels
# Levels 0 through max_depth, capped at five iterations (levels 0-4).
for level in range(min(max_depth + 1, 5)):
    show_tree_at_depth(predicate_tree, predicate_depths, level)

## 9. Test with Real TCT MetaKG Data

Load the Translator MetaKG and test how granularity filtering affects real predicates.

In [None]:
# Load TCT if available
# On success this cell defines APInames, metaKG, metakg_predicates and sets
# TCT_AVAILABLE = True. On ImportError it sets TCT_AVAILABLE = False and
# metakg_predicates = [].
# NOTE(review): any other exception raised inside the try (e.g. while loading
# resources) propagates and leaves TCT_AVAILABLE / metakg_predicates
# undefined for later cells.
try:
    from TCT import TCT
    from TCT import translator_metakg

    print("Loading Translator MetaKG...")
    print("(This may take 1-2 minutes)\n")
    
    # Third return value is unused here.
    APInames, metaKG, _ = translator_metakg.load_translator_resources()

    # Get all predicates in MetaKG
    # presumably metaKG is a table with a "Predicate" column - confirm against TCT.
    metakg_predicates = list(set(metaKG["Predicate"]))
    print(f"MetaKG contains {len(metakg_predicates)} unique predicates")
    
    # Show sample
    print(f"\nSample MetaKG predicates:")
    for pred in sorted(metakg_predicates)[:10]:
        # Depth lookup uses the name without "biolink:"; "N/A" when the
        # predicate is not in predicate_depths.
        pred_name = pred.replace("biolink:", "")
        depth = predicate_depths.get(pred_name, "N/A")
        print(f"  {pred} (depth: {depth})")

    # Test filtering at each level
    print("\n" + "=" * 60)
    print("Filtering MetaKG predicates by granularity:")
    print("=" * 60)
    
    # Levels 0 through max_depth, capped at six iterations.
    for level in range(min(max_depth + 1, 6)):
        # Without literature exclusion
        filtered = filter_predicates_by_granularity(
            metakg_predicates,
            level,
            predicate_depths,
            exclude_literature=False
        )
        
        # With literature exclusion
        filtered_no_lit = filter_predicates_by_granularity(
            metakg_predicates,
            level,
            predicate_depths,
            exclude_literature=True
        )
        
        # Percentages are relative to all MetaKG predicates; the guard avoids
        # dividing by zero when the list is empty.
        pct = len(filtered) / len(metakg_predicates) * 100 if metakg_predicates else 0
        pct_no_lit = len(filtered_no_lit) / len(metakg_predicates) * 100 if metakg_predicates else 0
        
        print(f"Level {level}: {len(filtered):3d} ({pct:5.1f}%)  |  excl. lit: {len(filtered_no_lit):3d} ({pct_no_lit:5.1f}%)")

    TCT_AVAILABLE = True

except ImportError as e:
    # TCT is optional: record its absence and leave an empty predicate list
    # so the following cell can still run.
    print(f"TCT not available - skipping MetaKG analysis")
    print(f"Error: {e}")
    TCT_AVAILABLE = False
    metakg_predicates = []

In [None]:
# Predicates of interest
# Hand-picked names (underscore form, no "biolink:" prefix) to look up below.
predicates_of_interest = [
    "related_to",
    "associated_with",
    "correlated_with",
    "occurs_together_in_literature_with",
    "interacts_with",
    "physically_interacts_with",
    "directly_physically_interacts_with",
    "affects",
    "regulates",
    "causes",
    "contributes_to",
    "treats",
    "gene_associated_with_condition",
    "genetically_associated_with",
    "in_clinical_trials_for",
    "has_part"
]

# Fixed-width table: predicate name, depth, MetaKG membership.
print("Predicates of Interest - Depth Analysis:")
print("-" * 60)
print(f"{'Predicate':<45} {'Depth':>6} {'In MetaKG':>10}")
print("-" * 60)

for pred in predicates_of_interest:
    # "N/A" when the name is not in predicate_depths.
    depth = predicate_depths.get(pred, "N/A")
    # Match either the prefixed or the bare spelling. Requires
    # metakg_predicates to have been defined by the MetaKG cell.
    in_metakg = "Yes" if f"biolink:{pred}" in metakg_predicates or pred in metakg_predicates else "No"
    print(f"{pred:<45} {str(depth):>6} {in_metakg:>10}")

## 10. Analyze Specific Predicates of Interest

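Look at specific predicates that might be problematic or useful. The code cell immediately above this heading prints each one's depth in the hierarchy and whether it appears in the MetaKG.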

## 11. Summary & Recommendations

In [34]:
print("=" * 70)
print("SUMMARY: PREDICATE GRANULARITY LEVELS")
print("=" * 70)

# Create summary table
# One row per level, 0 through max_depth, capped at seven rows.
summary_data = []
for level in range(min(max_depth + 1, 7)):
    allowed = get_predicates_at_min_depth(predicate_depths, level)
    excluded = get_predicates_excluded_at_depth(predicate_depths, level)
    
    # Get excluded examples
    # First three excluded names alphabetically, with an ellipsis if more.
    excluded_examples = ", ".join(sorted(excluded)[:3])
    if len(excluded) > 3:
        excluded_examples += "..."
    
    summary_data.append({
        "Level": level,
        "Allowed": len(allowed),
        "Excluded": len(excluded),
        "Excluded Examples": excluded_examples if excluded else "(none)"
    })

# Render through pandas for aligned plain-text output.
summary_df = pd.DataFrame(summary_data)
print("\n")
print(summary_df.to_string(index=False))

print("\n" + "=" * 70)
print("CONFIGURATION NOTES")
print("=" * 70)
print(f"\n- EXCLUDE_LITERATURE_COOCCURRENCE = {EXCLUDE_LITERATURE_COOCCURRENCE}")
print("- Level 0 = All predicates (most permissive, includes 'related_to')")
print("- Level 1 = Exclude 'related_to' (the most vague predicate)")
print("- Level 2+ = Increasingly specific predicates only")
print(f"- Max available level: {max_depth}")

print("\n" + "=" * 70)
print("NEXT STEPS")
print("=" * 70)
# Static follow-up checklist; nothing here is computed.
print("""
1. Review the tree cuts above to decide on granularity presets
2. Consider which levels make biological sense for your use case
3. Decide on UI: slider vs named presets ("All", "Moderate", "Specific")
4. Implement predicate filtering in trapi_client.py
5. Add UI selector in app.py (similar to intermediate_types)
""")

SUMMARY: PREDICATE GRANULARITY LEVELS


 Level  Allowed  Excluded                                                   Excluded Examples
     0      245         0                                                              (none)
     1      244         1                                                          related_to
     2      238         7 composed_primarily_of, disease_has_location, location_of_disease...
     3      143       102                         active_in, acts_upstream_of, affected_by...
     4       48       197               active_in, actively_involved_in, actively_involves...
     5        6       239               active_in, actively_involved_in, actively_involves...
     6        1       244               active_in, actively_involved_in, actively_involves...

CONFIGURATION NOTES

- EXCLUDE_LITERATURE_COOCCURRENCE = False
- Level 0 = All predicates (most permissive, includes 'related_to')
- Level 1 = Exclude 'related_to' (the most vague predicate)
- Level 2+ = Inc