In [3]:
import asyncio
import logging
import os
from typing import Any, Dict, List, Union

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda
from langgraph.graph import Graph, END
from neo4j import AsyncGraphDatabase

# Assume llm is imported and configured
from your_llm_config import llm

# Set up logging. basicConfig gives the root logger a console handler at INFO
# (it does nothing if the root logger already has handlers), while this
# module's own logger is pinned to DEBUG independently. The node-level debug
# traces below are therefore always emitted, without also enabling DEBUG
# output from every third-party library (e.g. the neo4j driver).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

class AgentState(dict):
    """Mutable state dict passed between the nodes of the debug workflow.

    Instances are plain ``dict`` subclasses; the annotations below document
    the expected keys but are not enforced at runtime.
    """

    # The user's natural-language question.
    query: str
    # Routing decision written by graph_db_router.
    use_graph_db: bool
    # Written by query_graph_db: a list of record dicts on success, or
    # {"error": <message>} on failure. Initialised to {} before querying.
    graph_db_result: Union[List[Dict[str, Any]], Dict[str, Any]]

async def graph_db_router(state: AgentState) -> AgentState:
    """Ask the LLM whether the query can make use of graph relationships.

    The LLM is prompted to answer 'Yes' or 'No' followed by an explanation.
    The answer counts as affirmative when, after stripping surrounding
    whitespace, it begins with "yes" (case-insensitive).

    Args:
        state: Workflow state; ``state["query"]`` must hold the user query.

    Returns:
        The same ``state`` object with ``state["use_graph_db"]`` set.
    """
    logger.debug("Entering graph_db_router with query: %s", state["query"])

    router_prompt = PromptTemplate.from_template(
        """Analyze the following query and determine if it can utilize graph relationships:
        
        Query: {query}
        
        Respond with either 'Yes' or 'No' followed by a brief explanation.
        """
    )

    # prompt -> model -> text. StrOutputParser unwraps chat-model messages
    # into their text content (and passes plain-string LLM output through),
    # so `result` is always a str whatever kind of model `llm` is.
    chain = router_prompt | llm | StrOutputParser()
    result = await chain.ainvoke({"query": state["query"]})

    # Models often emit leading whitespace/newlines before the answer word.
    state["use_graph_db"] = result.strip().lower().startswith("yes")

    logger.debug(
        "graph_db_router decision: use_graph_db = %s", state["use_graph_db"]
    )
    logger.debug("LLM explanation: %s", result)

    return state

async def query_graph_db(state: AgentState) -> AgentState:
    """Run a placeholder name search against Neo4j when routing asked for it.

    Does nothing unless ``state["use_graph_db"]`` is truthy. Otherwise it
    matches up to 5 nodes whose ``name`` property contains the query text and
    stores the records in ``state["graph_db_result"]`` as a list of dicts.
    Any exception while connecting or querying is logged and recorded as
    ``{"error": <message>}`` instead of being raised.

    Connection settings come from the ``NEO4J_URI``, ``NEO4J_USERNAME`` and
    ``NEO4J_PASSWORD`` environment variables. If unset, they default to
    ``neo4j://localhost:7687`` with ``neo4j``/``password``.

    Args:
        state: Workflow state; reads ``"query"`` and ``"use_graph_db"``.

    Returns:
        The same ``state`` object, updated in place.
    """
    use_graph_db = state.get("use_graph_db", False)
    logger.debug("Entering query_graph_db with use_graph_db: %s", use_graph_db)

    if not use_graph_db:
        logger.debug("Skipping graph DB query as use_graph_db is False")
        return state

    uri = os.environ.get("NEO4J_URI", "neo4j://localhost:7687")
    user = os.environ.get("NEO4J_USERNAME", "neo4j")
    password = os.environ.get("NEO4J_PASSWORD", "password")

    # Placeholder Cypher query -- replace with actual query generation logic.
    # The user text is bound as the $search parameter rather than spliced into
    # the query string: splicing breaks on apostrophes (e.g. "Apple's") and
    # permits Cypher injection.
    cypher_query = "MATCH (n) WHERE n.name CONTAINS $search RETURN n LIMIT 5"
    parameters = {"search": state["query"]}

    try:
        async with AsyncGraphDatabase.driver(uri, auth=(user, password)) as driver:
            async with driver.session() as session:
                logger.debug(
                    "Executing Cypher query: %s with parameters %s",
                    cypher_query,
                    parameters,
                )
                result = await session.run(cypher_query, parameters)
                # AsyncResult.data() already returns one dict per record.
                state["graph_db_result"] = await result.data()
                logger.debug("Graph DB query result: %s", state["graph_db_result"])
    except Exception as e:
        # Best-effort node: record the failure in state instead of aborting the
        # workflow; logger.exception keeps the traceback for debugging.
        logger.exception("Error querying graph database: %s", e)
        state["graph_db_result"] = {"error": str(e)}

    return state

def create_debug_graph():
    """Build and compile the linear debugging workflow.

    Node order: ``graph_db_router`` -> ``query_graph_db`` -> END, with
    ``graph_db_router`` as the entry point. Each node function is wrapped
    in a ``RunnableLambda``.

    Returns:
        The compiled graph produced by ``Graph.compile()``.
    """
    pipeline = (
        ("graph_db_router", graph_db_router),
        ("query_graph_db", query_graph_db),
    )

    workflow = Graph()
    for node_name, node_fn in pipeline:
        workflow.add_node(node_name, RunnableLambda(node_fn))

    # Chain each node to its successor; the last one hands off to END.
    names = [node_name for node_name, _ in pipeline]
    for src, dst in zip(names, names[1:] + [END]):
        workflow.add_edge(src, dst)

    workflow.set_entry_point(names[0])
    return workflow.compile()

async def run_debug_graph(query: str):
    """Execute the debug workflow for ``query`` and return its final state.

    A fresh graph is compiled on every call. The initial state starts with
    ``use_graph_db=False`` and an empty ``graph_db_result``.
    """
    compiled_graph = create_debug_graph()
    starting_state = AgentState(
        query=query,
        use_graph_db=False,
        graph_db_result={},
    )
    return await compiled_graph.ainvoke(starting_state)

# Example usage
if __name__ == "__main__":
    async def main():
        """Run the debug workflow over sample queries and log each final state."""
        queries = [
            "What were Apple's financial highlights in the last quarter?",
            "List all companies in the tech sector",
            "What is the current stock price of Google?"
        ]

        for query in queries:
            logger.info(f"\nProcessing query: {query}")
            try:
                result = await run_debug_graph(query)
            except Exception:
                # Keep going so one failing query does not hide the others.
                logger.exception(f"Debug graph failed for query '{query}'")
                continue
            logger.info(f"Final state for query '{query}':")
            logger.info(f"use_graph_db: {result['use_graph_db']}")
            logger.info(f"graph_db_result: {result['graph_db_result']}")
            logger.info("-" * 50)

    # asyncio.run() raises RuntimeError when an event loop is already running
    # in this thread, as it is inside a Jupyter/IPython cell. In that case,
    # schedule main() on the existing loop instead of starting a new one.
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is None:
        asyncio.run(main())
    else:
        # Keep a reference so the task is not garbage-collected before it ends.
        main_task = running_loop.create_task(main())

INFO:numexpr.utils:NumExpr defaulting to 8 threads.


ModuleNotFoundError: No module named 'langchain.schema.runnable'; 'langchain.schema' is not a package