In [1]:
%load_ext autoreload
%autoreload 2
from collections import defaultdict
from itertools import combinations

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
from ipywidgets import IntSlider, interact

from hierarchy_functions import get_unique_coded_terms, create_SNOMED_CT_graph_based_on_terms, find_lca_and_distance, compute_lcas_and_distances

# Building DAG Hierarchy

**Goal: Group related terms under a common ancestor name to create a hierarchy of DAG nodes, from broad down to specific.**

Import Data
1. `dag_df` - Standardized DAG dataframe from workshop
2. `concept_df` - Athena concept dataframe containing all the Athena IDs, concept codes, and concept names
3. `concept_relationship_df` - Athena dataframe containing the relationships between all the concepts
4. `concept_ancestor_df` - Athena dataframe containing information on the ancestors of terms (unfortunately, it seems to be incomplete)

In [2]:
# Standardized DAG edges; exposure/outcome codes are kept as strings.
dag_df = pd.read_csv(
    '../data/DAGs_standardized.csv',
    dtype={'Exposure': str, 'Outcome': str},
)

# Athena vocabulary tables (tab-separated). All id/code columns are read as
# strings so they compare equal to the codes stored in `dag_df`.
concept_df = pd.read_csv(
    '../Standardization/athena_vocabulary/CONCEPT.csv',
    sep='\t',
    dtype={'concept_code': str, 'concept_id': str},
    low_memory=False,
)
concept_relationship_df = pd.read_csv(
    '../Standardization/athena_vocabulary/CONCEPT_RELATIONSHIP.csv',
    sep='\t',
    dtype={'concept_id_1': str, 'concept_id_2': str},
    low_memory=False,
)
concept_ancestor_df = pd.read_csv(
    '../Standardization/athena_vocabulary/CONCEPT_ANCESTOR.csv',
    sep='\t',
    dtype={'ancestor_concept_id': str, 'descendant_concept_id': str},
    low_memory=False,
)

Get set of terms

In [3]:
# Unique coded terms taken from the DAG dataframe. Further down they are matched
# against `concept_df.concept_id`, i.e. they are used as Athena concept ids.
my_terms = get_unique_coded_terms(dag_df)

Dictionaries to convert between codes (Athena concept IDs) and names

In [4]:
# `code_to_name` is keyed by the Athena `concept_id` column (not `concept_code`).
code_to_name = dict(zip(concept_df["concept_id"], concept_df["concept_name"]))
# Reverse lookup; when several concepts share a name, the last row in `concept_df` wins.
name_to_code = dict(zip(concept_df["concept_name"], concept_df["concept_id"]))
# Human-readable names of the DAG terms; raises KeyError if a term has no row in `concept_df`.
my_terms_written = [code_to_name[code] for code in my_terms]

### Hierarchy

In [5]:
# Relationship ids handed to `create_SNOMED_CT_graph_based_on_terms` when the
# per-class graphs are built below; only "Subsumes" is selected.
selected_relationship_ids = ["Subsumes"]

In [6]:
def build_hierarchy(my_terms_filtered, G_clinical, compute_lcas_and_distances, find_lca_and_distance):
    """
    Build a hierarchy by repeatedly grouping the two closest terms under their LCA.

    In every round the pair of remaining terms with the smallest distance is
    recorded under its lowest common ancestor (LCA). A grouped term is dropped
    from the working set unless one of its graph descendants is still waiting
    to be grouped, and the LCA joins the working set so that it can be grouped
    further up. The loop ends when fewer than two terms remain or when no
    unprocessed pair has a finite distance.

    Parameters:
    - my_terms_filtered (iterable): The terms to process. It is passed as-is to
      `compute_lcas_and_distances` and is not modified.
    - G_clinical (networkx graph): The graph representing the relationships.
    - compute_lcas_and_distances (function): Called once as
      `compute_lcas_and_distances(my_terms_filtered, G_clinical)`. Its result is
      indexed as `[0]` (pair -> LCA) and `[1]` (pair -> distance); both mappings
      are keyed by `frozenset({term_a, term_b})` and are extended in place for
      pairs that were not precomputed.
    - find_lca_and_distance (function): Called as
      `find_lca_and_distance(G_clinical, term_a, term_b)` for pairs missing from
      the precomputed mappings; must return `(lca, distance)`.

    Returns:
    - hierarchy (defaultdict(list)): Maps each LCA to the list of
      `(term_a, term_b)` pairs grouped under it. A single input term yields
      `{term: []}`; an empty input yields an empty mapping.
    """
    # Precompute LCAs and distances
    precomputed = compute_lcas_and_distances(my_terms_filtered, G_clinical)
    lca_by_pair = precomputed[0]
    distance_by_pair = precomputed[1]

    # Work on a private set so the caller's collection is never mutated and
    # lists / frozensets are accepted as well.
    remaining_terms = set(my_terms_filtered)
    hierarchy = defaultdict(list)
    processed_pairs = set()  # Track processed pairs to avoid infinite loops

    def has_descendants(term, remaining_terms, graph):
        """Check if a term has descendants in the remaining terms."""
        return not nx.descendants(graph, term).isdisjoint(remaining_terms)

    # Special case: If there is only one node, add it to the hierarchy and return
    if len(remaining_terms) == 1:
        single_node = next(iter(remaining_terms))
        hierarchy[single_node] = []  # No children
        return hierarchy

    while len(remaining_terms) > 1:
        shortest_distance = float("inf")
        best_pair = None
        best_lca = None

        # Each unordered pair is examined exactly once; ties keep the first
        # pair encountered because the comparison below is strict.
        for node1, node2 in combinations(remaining_terms, 2):
            key = frozenset((node1, node2))
            if key in processed_pairs:
                continue
            if key not in distance_by_pair:
                lca, distance = find_lca_and_distance(G_clinical, node1, node2)
                lca_by_pair[key] = lca
                distance_by_pair[key] = distance
            dist = distance_by_pair[key]

            # Update the best pair if this distance is shorter
            if dist < shortest_distance:
                shortest_distance = dist
                best_pair = (node1, node2)
                best_lca = lca_by_pair[key]

        # Nothing left that can be grouped (all pairs processed or unreachable).
        if best_pair is None:
            break

        hierarchy[best_lca].append(best_pair)
        processed_pairs.add(frozenset(best_pair))

        # Replace grouped nodes with their LCA if they no longer have descendants
        for node in best_pair:
            if not has_descendants(node, remaining_terms, G_clinical):
                remaining_terms.remove(node)
        remaining_terms.add(best_lca)

    return hierarchy

In [7]:
def hierarchy_to_graph(hierarchy, original_terms):
    """
    Turn the LCA -> pairs hierarchy dictionary into a directed graph.

    Every key of `hierarchy` becomes a parent node with an edge to each member
    of the pairs listed under it. Duplicate edges and edges from a node to
    itself are not added. Each node carries a 'type' attribute: "original" when
    the node is in `original_terms`, "lca" otherwise.
    """
    def node_kind(term):
        return "original" if term in original_terms else "lca"

    dag = nx.DiGraph()

    for parent, grouped_pairs in hierarchy.items():
        if parent not in dag:
            dag.add_node(parent, type=node_kind(parent))

        # Members are handled identically no matter which pair they came from,
        # so the pairs are flattened into one stream.
        for member in (term for pair in grouped_pairs for term in pair):
            if member not in dag:
                dag.add_node(member, type=node_kind(member))

            # Skip self-references and edges that already exist.
            if parent != member and not dag.has_edge(parent, member):
                dag.add_edge(parent, member)

    return dag

In [8]:
def prune_redundant_edges(graph):
    """
    Remove redundant edges that create shortcuts in the hierarchy.

    An edge (node, child) is a shortcut when `child` can also be reached from
    `node` through another direct child, no matter how many intermediate nodes
    that longer path has. Such edges are removed so that only the most specific
    parent-child links remain.

    The graph is modified in place and also returned. The input is expected to
    be acyclic; self-loops are left untouched.
    """
    # Everything reachable from each node, computed once up front.
    reachable = {node: nx.descendants(graph, node) for node in graph.nodes}

    edges_to_remove = set()
    for node in graph.nodes:
        children = [child for child in graph.successors(node) if child != node]
        for child in children:
            # `child` is redundant if some sibling already leads to it.
            if any(child in reachable[sibling] for sibling in children if sibling != child):
                edges_to_remove.add((node, child))

    graph.remove_edges_from(edges_to_remove)
    return graph

In [9]:
def replace_nodes_with_names(graph, mapping):
    """
    Return a new directed graph whose nodes are relabelled through `mapping`.

    Nodes without an entry in `mapping` keep their original label. Node and
    edge attributes are carried over; the input graph is left unchanged.
    """
    def label(code):
        return mapping.get(code, code)

    renamed = nx.DiGraph()

    # Nodes first, so attributes are attached even to nodes without edges.
    for code, node_attrs in graph.nodes(data=True):
        renamed.add_node(label(code), **node_attrs)

    for source, target, edge_attrs in graph.edges(data=True):
        renamed.add_edge(label(source), label(target), **edge_attrs)

    return renamed

In [10]:
# Concept classes present among the DAG terms. 'Disorder' is left out here
# because the graph-building loop below handles it together with
# 'Clinical Finding'. Filtering (instead of list.remove) also works when no
# term has the 'Disorder' class.
concept_classes = [
    concept_class
    for concept_class in concept_df[concept_df.concept_id.isin(my_terms)].concept_class_id.unique()
    if concept_class != 'Disorder'
]
concept_classes

['Substance',
 'Clinical Finding',
 'Procedure',
 'Observable Entity',
 'Context-dependent',
 'Morph Abnormality',
 'Social Context',
 'Event']

In [11]:
full_graph = nx.DiGraph()

# Rows of the concept table that belong to the DAG terms. This does not depend
# on the concept class, so it is computed once rather than on every iteration.
dag_concepts = concept_df[concept_df.concept_id.isin(my_terms)]

for concept_class in concept_classes:

    # Group clinical finding and disorder together
    # NOTE(review): 'Disorder' terms are only picked up through this branch; if
    # no term has the class 'Clinical Finding' they are skipped - confirm that
    # this cannot happen for the data at hand.
    if concept_class == 'Clinical Finding':
        class_label = ['Clinical Finding', 'Disorder']
        class_mask = dag_concepts.concept_class_id.isin(class_label)
    else:
        class_label = concept_class
        class_mask = dag_concepts.concept_class_id == concept_class
    filtered_terms = set(dag_concepts[class_mask].concept_id)

    # Per-class pipeline: relationship graph -> LCA hierarchy -> DAG -> pruned DAG -> names.
    G_x = create_SNOMED_CT_graph_based_on_terms(filtered_terms, concept_df, concept_relationship_df, selected_relationship_ids)

    hierarchy_x = build_hierarchy(filtered_terms, G_x, compute_lcas_and_distances, find_lca_and_distance)
    graph_x = hierarchy_to_graph(hierarchy_x, filtered_terms)
    pruned_graph_x = prune_redundant_edges(graph_x)
    graph_with_names_x = replace_nodes_with_names(pruned_graph_x, code_to_name)

    # Merge this class's hierarchy into the combined graph.
    full_graph = nx.compose(full_graph, graph_with_names_x)
    print(f'Adding {graph_with_names_x.number_of_nodes()} terms to the graph for concept class {class_label}.\n')
print(f'Total number of terms in the Snomed DAG are {full_graph.number_of_nodes()} with {full_graph.number_of_edges()} edges.')

Adding 11 terms to the graph for concept class Substance.

Adding 145 terms to the graph for concept class ['Clinical Finding', 'Disorder'].

Adding 45 terms to the graph for concept class Procedure.

Adding 9 terms to the graph for concept class Observable Entity.

Adding 2 terms to the graph for concept class Context-dependent.

Adding 8 terms to the graph for concept class Morph Abnormality.

Adding 3 terms to the graph for concept class Social Context.

Adding 1 terms to the graph for concept class Event.

Total number of terms in the Snomed DAG are 224 with 224 edges.


In [12]:
# Persist the combined hierarchy graph in GEXF format; the relative path means
# the file is written to the notebook's current working directory.
nx.write_gexf(full_graph, 'snomed_alternative_grouping_2.gexf')

### Slider

In [13]:
import networkx as nx

def synchronized_expansion_with_reappearance(full_graph):
    """
    Perform synchronized hierarchical expansion on a DAG,
    keeping nodes active and visible in every iteration until all their children are expanded.

    The expansion starts from the root nodes (in-degree 0). A node is expanded
    the first time it becomes visible; a child becomes visible in the iteration
    after the last of its parents has been expanded. A node stays in the
    visible set until every one of its children is visible or already done.

    Parameters:
    - full_graph: A NetworkX DiGraph object representing the DAG.

    Returns:
    - iterations: A list of lists, where each inner list contains the nodes
      visible at that iteration. The order inside an inner list is unspecified.

    Raises:
    - ValueError: If the expansion cannot make progress, which happens when a
      node reachable from a root lies on or behind a cycle.
    """
    # Number of parents of each node that have not been expanded yet.
    pending_parents = {node: full_graph.in_degree(node) for node in full_graph.nodes}

    # Root nodes (no incoming edges) form the first level.
    current_level = {node for node, pending in pending_parents.items() if pending == 0}

    expanded = set()   # nodes whose outgoing edges have already been counted
    processed = set()  # nodes that no longer need to be shown
    iterations = []

    while current_level:
        iterations.append(list(current_level))
        next_level = set(current_level)  # visible nodes stay visible by default
        progressed = False

        # Pass 1: expand nodes that are new at this level. Each parent -> child
        # edge is counted exactly once, so a child only becomes eligible after
        # all of its distinct parents have been expanded.
        for node in current_level:
            if node in expanded:
                continue
            expanded.add(node)
            progressed = True
            for child in full_graph.successors(node):
                pending_parents[child] -= 1
                if pending_parents[child] == 0:
                    next_level.add(child)

        # Pass 2: retire nodes whose children are all visible or done. Running
        # this after every expansion of the level keeps the result independent
        # of set iteration order.
        for node in current_level:
            if all(
                child in processed or child in next_level
                for child in full_graph.successors(node)
            ):
                processed.add(node)
                progressed = True

        if not progressed:
            raise ValueError(
                "Expansion stalled: the graph contains a cycle, but a directed acyclic graph is required."
            )

        current_level = next_level - processed

    return iterations

# Example usage: compute the expansion levels of the combined hierarchy graph
# and print the nodes that are visible at each iteration.
iterations = synchronized_expansion_with_reappearance(full_graph)
for i, level in enumerate(iterations):
    print(f"Iteration {i}: {level}")

Iteration 0: ['Substance', 'Observable entity', 'Death', 'Procedure', 'FH: Cardiovascular disease', 'Clinical finding', 'Morphologically abnormal structure', 'Social context']
Iteration 1: ['Hypercoagulability state', 'Bleeding', 'Patient position finding', 'Pregnancy', 'Immunosuppression', 'Pain', 'Mechanical lesion', 'Increased body mass index', 'Surgical procedure', 'Body temperature above reference range', 'Risk of cardiovascular disease', 'Measurement finding', 'Disability', 'Dysfunction of urinary bladder', 'Leuko-araiosis', 'Aphasia', 'Procedure with a clinical finding focus', 'Able to cope', 'Vascular sclerosis', 'Cardiac syndrome X', 'Glomerular filtration rate', 'Dietary finding', 'Social worker involved', 'Activities of daily living assessment', 'Blood-brain barrier', 'Functional finding', 'Age factor', 'Procedure by method', 'Economic status', 'Catheter procedure', 'General clinical state', 'Weight loss', 'Racial group', 'Physical fitness state', 'Lesion size', 'Drug or med

In [28]:
def precompute_depths(graph):
    """
    Precompute the depths (max depth to leaf) for all nodes in the graph.

    A leaf (no successors) has depth 0; any other node has depth one more than
    the deepest of its children.

    Parameters:
    - graph: A NetworkX DiGraph; it must be acyclic because the nodes are
      visited in topological order.

    Returns:
    - depths (dict): Maps every node to its depth.
    """
    depths = {}

    # In a topological order parents precede their children, so walking it
    # backwards guarantees that every child's depth is known before its
    # parent is visited. This avoids recursion, which could exceed Python's
    # recursion limit on deep hierarchies.
    for node in reversed(list(nx.topological_sort(graph))):
        deepest_child = max((depths[child] for child in graph.successors(node)), default=-1)
        depths[node] = deepest_child + 1

    return depths

In [29]:
def visualize_reachability_expansion_with_yifan_hu(graph):
    """
    Create an interactive visualization for reachability-based expansion of a graph,
    using the Yifan Hu layout algorithm for node positioning.

    Starting from the root nodes (in-degree 0), each iteration expands the
    visible node with the highest reachability score (out-degree plus depth to
    the deepest leaf) and replaces it with its children. A slider selects the
    iteration to draw; nodes are coloured black (already expanded), red
    (expanded at this iteration), blue (roots and children of expanded nodes
    that are not expanded yet) or gray (not reached yet).

    Parameters:
    - graph: A NetworkX DiGraph to visualize.

    Returns:
    - None. The widget is displayed as a side effect. If no node can be
      expanded (for example, the graph has no edges) a message is printed and
      nothing is drawn.
    """
    # Precompute depths for efficiency
    depth_cache = precompute_depths(graph)

    def calculate_reachability(node):
        """Calculate the reachability for a node, which is the depth underneath it and the number of out-degrees (children) it has."""
        return graph.out_degree(node) + depth_cache[node]

    # Identify root nodes (nodes with in-degree 0)
    root_nodes = {node for node in graph.nodes if graph.in_degree(node) == 0}

    # Plan the expansion: one (visible nodes, expanded node) entry per iteration.
    active_nodes = set(root_nodes)
    processed_nodes = set()
    iterations = []

    while active_nodes:
        # Leaf nodes cannot be expanded, so they are never highlighted.
        expandable_nodes = [node for node in active_nodes if graph.out_degree(node) > 0]
        if not expandable_nodes:
            break  # No more nodes to expand

        # Ties are broken by the node's string form so that the sequence does
        # not depend on set iteration order and is the same on every run.
        node_to_expand = max(
            expandable_nodes,
            key=lambda node: (calculate_reachability(node), str(node)),
        )

        # Record the current iteration
        iterations.append((list(active_nodes), node_to_expand))

        # The expanded node is replaced by its children; nodes expanded earlier
        # never become active again.
        processed_nodes.add(node_to_expand)
        active_nodes = (active_nodes | set(graph.successors(node_to_expand))) - processed_nodes

    if not iterations:
        # A slider with max=-1 would be invalid, so stop before building the widget.
        print("Nothing to visualize: the graph has no expandable nodes.")
        return

    # Generate positions using Yifan Hu layout (requires pygraphviz)
    pos = nx.nx_agraph.graphviz_layout(graph, prog='sfdp', args='-Goverlap=false -Gsplines=true')

    # Slider callback function
    def update(iteration):
        plt.figure(figsize=(20, 15))
        _, highlighted_node = iterations[iteration]
        earlier_iterations = iterations[:iteration]

        # Nodes expanded before the current iteration
        processed_up_to_now = {expanded_node for _, expanded_node in earlier_iterations}

        # Roots plus children of already expanded nodes that are not expanded themselves
        blue_nodes_up_to_now = set(root_nodes)
        for _, parent_node in earlier_iterations:
            for child in graph.successors(parent_node):
                if child not in processed_up_to_now:
                    blue_nodes_up_to_now.add(child)

        # Determine node colors
        node_colors = []
        for node in graph.nodes:
            if node in processed_up_to_now:  # Already expanded nodes
                color = 'black'
            elif node == highlighted_node:  # Node being split
                color = 'red'
            elif node in blue_nodes_up_to_now:  # Nodes below the split node and not yet expanded
                color = 'blue'
            else:  # Unprocessed nodes
                color = 'gray'
            node_colors.append(color)

        # Draw the graph with highlighting
        nx.draw(
            graph,
            pos,
            with_labels=True,
            node_color=node_colors,
            node_size=500,
            font_size=10,
        )

        plt.title(f"Iteration {iteration}")
        plt.show()

    # Create the slider
    interact(update, iteration=IntSlider(min=0, max=len(iterations) - 1, step=1, value=0))

# Usage: show the interactive expansion slider for the combined hierarchy graph.
visualize_reachability_expansion_with_yifan_hu(full_graph)





interactive(children=(IntSlider(value=0, description='iteration', max=66), Output()), _dom_classes=('widget-in…