### Explore graph algorithms in Network X

In [1]:
from dotenv import load_dotenv
from arango import ArangoClient
import os

load_dotenv()

database = ArangoClient(hosts=os.getenv("ARANGO_HOST")).db(
    username=os.getenv("ARANGO_USERNAME"),
    password=os.getenv("ARANGO_PASSWORD"),
    verify=True,
)
for graph in database.graphs():
    print(graph["id"])

_graphs/Christmas_Carol
_graphs/CVE
_graphs/OPEN_INTELLIGENCE
_graphs/FLIGHTS
_graphs/SYNTHEA_P100


In [2]:
import nx_arangodb as nxadb
import networkx as nx

graph = nxadb.Graph(db=database, name="SYNTHEA_P100")

[08:58:50 +0700] [INFO]: NetworkX-cuGraph is unavailable: No module named 'cupy'.
[08:58:51 +0700] [INFO]: Graph 'SYNTHEA_P100' exists.
[08:58:52 +0700] [INFO]: Default node type set to 'allergies'


In [3]:
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from typing import Type
import json
import sys
from io import StringIO


class NetworkxAnalysisSchema(BaseModel):
    code: str = Field(
        ...,
        description="The Python code to execute NetworkX algorithms on the graph variable",
    )


class NetworkxAnalysisTool(BaseTool):
    name: str = "networkx_analysis"
    description: str = (
        "Execute Python code to run NetworkX algorithms on the graph"
    )
    args_schema: Type[BaseModel] = NetworkxAnalysisSchema

    def _run(self, code: str) -> str:
        """Execute Python code with NetworkX algorithms in a controlled environment"""
        try:
            env = {'G': graph, 'nx': nx}

            # Redirect stdout to capture print statements
            old_stdout = sys.stdout
            sys.stdout = StringIO()

            # Execute the provided Python code
            exec(code, env)

            # Capture output
            output = sys.stdout.getvalue()
            sys.stdout = old_stdout

            # Check if 'result' is defined in the environment (expected output variable)
            if "result" in env:
                return json.dumps({"results": env["result"]}, indent=2)
            elif output:
                return json.dumps({"results": output.strip()}, indent=2)
            else:
                return json.dumps(
                    {"results": "No result variable or output produced"}, indent=2
                )

        except Exception as e:
            return json.dumps(
                {"error": f"NetworkX analysis failed: {str(e)}"}, indent=2
            )


class AQLQuerySchema(BaseModel):
    code: str = Field(
        ..., description="The AQL query to execute on the graph database"
    )


class AQLQueryTool(BaseTool):
    name: str = "aql_query"
    description: str = (
        "Execute AQL queries to traverse the graph and extract structured relationships"
    )
    args_schema: Type[BaseModel] = AQLQuerySchema

    def _run(self, code: str) -> str:
        """Execute an AQL query on ArangoDB"""
        try:
            results = graph.query(code)
            return json.dumps({"results": list(results)}, indent=2)
        except Exception as e:
            return f"AQL Code Execution failed: {str(e)}"


TOOLKIT = [NetworkxAnalysisTool(), AQLQueryTool()]

In [5]:
from langchain_core.prompts import (
    HumanMessagePromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    ChatPromptTemplate,
)
from langchain_core.messages import SystemMessage

with open(r'prompts/graph_system_prompt.txt', 'r', encoding='utf-8') as f:
    SYSTEM_PROMPT = f.read()

with open("prompts/graph_schema.txt", "r") as f:
    GRAPH_SCHEMA = f.read()

PROMPTS = [
    SystemMessage(content=SYSTEM_PROMPT),
    MessagesPlaceholder(variable_name="chat_history", optional=True),
    HumanMessagePromptTemplate(
        prompt=PromptTemplate(
            input_variables=["query", "graph_schema"],
            template="KNOWLEDGE GRAPH SCHEMA: {graph_schema} \n USER QUERY: {query} ",
        )
    ),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
]

CHAT_PROMPT = ChatPromptTemplate.from_messages(PROMPTS)

In [6]:
from langchain.agents import create_tool_calling_agent
from langchain.agents import AgentExecutor
from langchain_openai import ChatOpenAI

model = ChatOpenAI(temperature=0, model="gpt-4o")
agent = create_tool_calling_agent(model, TOOLKIT, CHAT_PROMPT)
agent_executor = AgentExecutor(agent=agent, tools=TOOLKIT, verbose=False)

In [8]:
response = agent_executor.invoke(
    {
        "graph_schema": GRAPH_SCHEMA,
        "query": "Find the shortest path from patient '7c2e78bd-52cf-1fce-acc3-0ddd93104abe' to the condition 'Medication review due (situation)' (CODE: 314529007) through their encounters.",
    }
)
print(response["output"])

The shortest path from the patient '7c2e78bd-52cf-1fce-acc3-0ddd93104abe' to the condition 'Medication review due (situation)' (CODE: 314529007) through their encounters is as follows:

1. **Patient Node**: `patients/7c2e78bd-52cf-1fce-acc3-0ddd93104abe`
2. **Encounter Node**: `encounters/0c5d5b18-165f-1ca2-ba19-d5aad055b6a8`
3. **Payer Node**: `payers/734afbd6-4794-363b-9bc0-6a3981533ed5`
4. **Encounter Node**: `encounters/ddd07218-f991-4103-7382-3237a1c7b220`
5. **Organization Node**: `organizations/901c2d40-1ca3-3879-9a20-c663b8adc0a9`
6. **Encounter Node**: `encounters/e18ae848-0962-bbb8-0942-3e52f3144f7d`
7. **Condition Node**: `conditions/3282`

This path shows the sequence of nodes traversed from the patient to the condition through various encounters and related entities.


### Test AQL & Python Code

In [8]:
import networkx as nx
from collections import defaultdict

# Step 3: Compute Degree Centrality (number of encounters per organization)
degree_centrality = nx.degree_centrality(graph)
org_degree_centrality = {org_id: score for org_id, score in degree_centrality.items() if org_id.startswith("organizations/")}

# Step 4: Compute PageRank (importance of organizations in the network)
pagerank = nx.pagerank(graph)
org_pagerank = {org_id: score for org_id, score in pagerank.items() if org_id.startswith("organizations/")}

[15:08:17 +0000] [INFO]: Graph 'SYNTHEA_P100' load took 14.992738723754883s
INFO:nx_arangodb:Graph 'SYNTHEA_P100' load took 14.992738723754883s


In [9]:
# Step 5: Combine metrics and rank organizations
combined_scores = defaultdict(dict)
for org_id in org_degree_centrality:
    combined_scores[org_id]["name"] = graph.nodes[org_id]["NAME"]
    combined_scores[org_id]["degree_centrality"] = org_degree_centrality[org_id]
    combined_scores[org_id]["pagerank"] = org_pagerank[org_id]
    # Optional: Add a combined score (e.g., weighted average)
    combined_scores[org_id]["combined_score"] = (org_degree_centrality[org_id] * 0.5) + (org_pagerank[org_id] * 0.5)

# Sort by combined score
sorted_orgs = sorted(combined_scores.items(), key=lambda x: x[1]["combined_score"], reverse=True)

# Step 6: Output results
print("Organizations Ranked by Overburden (Degree Centrality + PageRank):")
print("Rank | Organization Name | Degree Centrality | PageRank | Combined Score")
print("-" * 80)
for rank, (org_id, metrics) in enumerate(sorted_orgs[:30], 1):
    print(f"{rank:<4} | {metrics['name']:<17} | {metrics['degree_centrality']:.4f}")

Organizations Ranked by Overburden (Degree Centrality + PageRank):
Rank | Organization Name | Degree Centrality | PageRank | Combined Score
--------------------------------------------------------------------------------
1    | Lynn Community Based Outpatient Clinic (CBOC) | 0.0048
2    | NEW ENGLAND SINAI HOSPITAL | 0.0039
3    | CARNEY HOSPITAL   | 0.0038
4    | STEWARD GOOD SAMARITAN MEDICAL CENTER  INC. | 0.0036
5    | GOOD SAMARITAN MEDICAL CENTER | 0.0035
6    | WHITTIER REHABILTATION HOSPITAL | 0.0022
7    | AP MEDICAL LLC    | 0.0014
8    | FALMOUTH HOSPITAL ASSOCIATION INC | 0.0012
9    | Fitchburg Outpatient Clinic | 0.0011
10   | MOUNT AUBURN HOSPITAL | 0.0010
11   | CARECENTRAL URGENT CARE MEDICAL GROUP PC | 0.0009
12   | BAYSTATE NOBLE HOSPITAL CORPORATION | 0.0008
13   | LYNN URGENT CARE LLC | 0.0008
14   | TEWKSBURY HOSPITAL | 0.0008
15   | TUFTS MEDICAL CENTER | 0.0006
16   | LAWRENCE GENERAL HOSPITAL | 0.0006
17   | BROCKTON HOSPITAL  INC. | 0.0006
18   | BETH ISRAEL D

In [11]:
result = graph.query('''
WITH encounters, patients, providers, patients_to_encounters, providers_to_encounters
FOR enc IN encounters
  FILTER enc.DESCRIPTION == "Well child visit (procedure)"
  LET patients = (
    FOR pat IN 1..1 INBOUND enc patients_to_encounters
      LIMIT 10
      RETURN { patient_id: pat._key, name: CONCAT(pat.FIRST, " ", pat.LAST) }
  )
  LET providers = (
    FOR prov IN 1..1 INBOUND enc providers_to_encounters
      LIMIT 10
      RETURN { provider_id: prov._key, name: prov.NAME }
  )
  LIMIT 10
  RETURN { encounter_id: enc._key, patients: patients, providers: providers }
  ''')
list(result)[:5]

[{'encounter_id': '84d6f5d3-569c-be5d-56a4-55f6e91c8a34',
  'patients': [{'patient_id': '7c2e78bd-52cf-1fce-acc3-0ddd93104abe',
    'name': 'Shila857 Kshlerin58'}],
  'providers': [{'provider_id': '788c9178-9c9f-322d-bb05-c24373197a6f',
    'name': 'Erica194 Goyette777'}]},
 {'encounter_id': '23abd072-3cd2-0ce9-3699-d88f347b7cf4',
  'patients': [{'patient_id': 'ee070281-5df4-601c-8660-d40e7ea76def',
    'name': 'Dallas143 Mueller846'}],
  'providers': [{'provider_id': '84425bd3-a175-3886-89eb-8176ebc6a8e0',
    'name': 'Mayola305 Hauck852'}]},
 {'encounter_id': '35c39791-b641-0c48-8ef4-5743fb8004b9',
  'patients': [{'patient_id': '7c2e78bd-52cf-1fce-acc3-0ddd93104abe',
    'name': 'Shila857 Kshlerin58'}],
  'providers': [{'provider_id': '788c9178-9c9f-322d-bb05-c24373197a6f',
    'name': 'Erica194 Goyette777'}]},
 {'encounter_id': 'e2f42507-c0f7-5a84-fbca-6a24314211d1',
  'patients': [{'patient_id': '7c2e78bd-52cf-1fce-acc3-0ddd93104abe',
    'name': 'Shila857 Kshlerin58'}],
  'provide

### Test subgraph function and algorithm scalability

In [None]:
AQL_QUERY = '''
WITH patients, patients_to_encounters, encounters,
     encounters_to_conditions, encounters_to_medications,
     encounters_to_procedures, conditions, medications, procedures

// Start with a sample of 10 patients
LET patients = (
    FOR p IN patients
    LIMIT 10
    RETURN p._id
)

// Get their encounters
LET encounters = (
    FOR p IN patients
        FOR e IN patients_to_encounters
            FILTER e._from == p
            RETURN e._to
)

// Get related conditions, medications, and procedures
LET conditions = (
    FOR e IN encounters
        FOR c IN encounters_to_conditions
            FILTER c._from == e
            LIMIT 50
            RETURN c._to
)

LET medications = (
    FOR e IN encounters
        FOR m IN encounters_to_medications
            FILTER m._from == e
            LIMIT 50
            RETURN m._to
)

LET procedures = (
    FOR e IN encounters
        FOR pr IN encounters_to_procedures
            FILTER pr._from == e
            LIMIT 50
            RETURN pr._to
)

// Collect all vertex IDs
LET vertexIds = UNION(
    patients,
    encounters,
    conditions,
    medications,
    procedures
)

// Explicitly collect edges from each edge collection
LET edgeIds = UNION(
    // patients to encounters edges
    (FOR e IN patients_to_encounters
        FILTER e._from IN patients AND e._to IN encounters
        RETURN e._id),
    // encounters to conditions edges
    (FOR e IN encounters_to_conditions
        FILTER e._from IN encounters AND e._to IN conditions
        RETURN e._id),
    // encounters to medications edges
    (FOR e IN encounters_to_medications
        FILTER e._from IN encounters AND e._to IN medications
        RETURN e._id),
    // encounters to procedures edges
    (FOR e IN encounters_to_procedures
        FILTER e._from IN encounters AND e._to IN procedures
        RETURN e._id)
)

// Return the subgraph
RETURN {
    vertices: vertexIds,
    edges: edgeIds
}
'''

cursor = graph.query(AQL_QUERY)
result = list(cursor)[0]
print('no. of vertices: ', len(result['vertices']))
print('no. of edges: ', len(result['edges']))

no. of vertices:  907
no. of edges:  897


In [None]:
import networkx as nx

# Create an empty directed graph
G = nx.DiGraph()

# Function to get document attributes
def get_document_attributes(doc_id):
    collection_name = doc_id.split('/')[0]
    cursor = database.collection(collection_name).get({'_id': doc_id})
    return cursor

# Add nodes with attributes
for vertex_id in result['vertices']:
    attrs = get_document_attributes(vertex_id)
    if attrs:
        # Remove system attributes
        attrs.pop('_key', None)
        attrs.pop('_id', None)
        attrs.pop('_rev', None)
        G.add_node(vertex_id, **attrs)

# Add edges with attributes
for edge_id in result['edges']:
    edge_doc = get_document_attributes(edge_id)
    if edge_doc:
        from_vertex = edge_doc['_from']
        to_vertex = edge_doc['_to']
        # Remove system attributes
        edge_attrs = {k: v for k, v in edge_doc.items() if not k.startswith('_')}
        G.add_edge(from_vertex, to_vertex, edge_id=edge_id, **edge_attrs)

# Basic graph info
print(f"Number of nodes: {G.number_of_nodes()}")
print(f"Number of edges: {G.number_of_edges()}")

# Example of accessing attributes
for node in G.nodes(data=True):
    print(f"Node {node[0]} attributes: {node[1]}")
    break
for edge in G.edges(data=True):
    print(f"Edge {edge[0]} -> {edge[1]} attributes: {edge[2]}")
    break

Number of nodes: 907
Number of edges: 897
Node patients/01fd0320-1260-3613-95fb-7703f53e6a08 attributes: {'BIRTHDATE': '1951-08-10', 'SSN': '999-99-9040', 'DRIVERS': 'S99991675', 'PASSPORT': 'X19696800X', 'PREFIX': 'Mr.', 'FIRST': 'Frankie174', 'LAST': 'Schinner682', 'MARITAL': 'M', 'RACE': 'white', 'ETHNICITY': 'nonhispanic', 'GENDER': 'M', 'BIRTHPLACE': 'North Adams  Massachusetts  US', 'ADDRESS': '989 Kunze Orchard Unit 36', 'CITY': 'Lynn', 'STATE': 'Massachusetts', 'COUNTY': 'Essex County', 'FIPS': 25009, 'ZIP': 1907, 'LAT': 42.4402758250332, 'LON': -70.95449283370179, 'HEALTHCARE_EXPENSES': 276780.33, 'HEALTHCARE_COVERAGE': 568451.33, 'INCOME': 124300}
Edge patients/01fd0320-1260-3613-95fb-7703f53e6a08 -> encounters/dd606bd7-d284-099e-3161-558e11cd5058 attributes: {'edge_id': 'patients_to_encounters/147068', 'CODE': 185345009}


In [None]:
import time
from community import community_louvain

# Function to time and run an algorithm
def time_algorithm(name, func, *args, **kwargs):
    start_time = time.time()
    result = func(*args, **kwargs)
    end_time = time.time()
    runtime = end_time - start_time
    print(f"{name}: {runtime:.3f} seconds")
    return result

# Test 1: Degree Centrality
degree_centrality = time_algorithm(
    "Degree Centrality",
    nx.degree_centrality,
    G
)
print(f"Sample degree centrality: {list(degree_centrality.items())[:3]}")

# Test 2: Betweenness Centrality
betweenness_centrality = time_algorithm(
    "Betweenness Centrality",
    nx.betweenness_centrality,
    G
)
print(f"Top 3 betweenness: {sorted(betweenness_centrality.items(), key=lambda x: x[1], reverse=True)[:3]}")

# Test 3: Closeness Centrality
closeness_centrality = time_algorithm(
    "Closeness Centrality",
    nx.closeness_centrality,
    G
)
print(f"Sample closeness centrality: {list(closeness_centrality.items())[:3]}")

# Test 4: Louvain Community Detection (requires python-louvain)
louvain_partition = time_algorithm(
    "Louvain Community Detection",
    community_louvain.best_partition,
    G.to_undirected()  # Louvain needs undirected graph
)
communities = {}
for node, comm_id in louvain_partition.items():
    communities.setdefault(comm_id, []).append(node)
print(f"Number of communities: {len(communities)}, Sample: {list(communities.items())[:2]}")

# Test 5: Shortest Path (single source)
source_node = next(iter(G.nodes))  # First node as source
shortest_paths = time_algorithm(
    "Shortest Path (Single Source)",
    nx.shortest_path_length,
    G,
    source=source_node
)
print(f"Shortest path lengths from {source_node}: {list(shortest_paths.items())[:3]}")

Degree Centrality: 0.001 seconds
Sample degree centrality: [('patients/01fd0320-1260-3613-95fb-7703f53e6a08', 0.44150110375275936), ('patients/0213869d-8ae4-2a7c-40bd-9a80e9560189', 0.039735099337748346), ('patients/05e5e51a-5dd4-84eb-7811-fd2c93bdaa7e', 0.046357615894039736)]
Betweenness Centrality: 0.148 seconds
Top 3 betweenness: [('encounters/94676bc6-f507-a1e5-e25d-06fbfdfea94f', 1.585501201322064e-05), ('encounters/43d4594c-f2f2-83dd-7b15-79046fbedc7f', 1.463539570451136e-05), ('encounters/2acd0294-4344-b09b-bb8c-15f9d70acfb3', 1.341577939580208e-05)]
Closeness Centrality: 0.020 seconds
Sample closeness centrality: [('patients/01fd0320-1260-3613-95fb-7703f53e6a08', 0.0), ('patients/0213869d-8ae4-2a7c-40bd-9a80e9560189', 0.0), ('patients/05e5e51a-5dd4-84eb-7811-fd2c93bdaa7e', 0.0)]
Louvain Community Detection: 0.017 seconds
Number of communities: 54, Sample: [(0, ['patients/01fd0320-1260-3613-95fb-7703f53e6a08', 'encounters/4a95dc57-711a-5319-02d6-5759315d768f', 'encounters/35aba7