In [3]:
import asyncio
import json
from typing import Dict, List, Any

from langchain_core.output_parsers import CommaSeparatedListOutputParser
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langgraph.graph import Graph, END
from neo4j import AsyncGraphDatabase  # Assuming use of neo4j for graph database

# Import subgraph implementations
from subgraphs import run_ect_subgraph, run_ac_subgraph

# Assume llm is imported and configured
from your_llm_config import llm

class AgentState(Dict):
    """Dictionary-based state handed from node to node in the finance-agent graph.

    Every node function in this file takes the state, reads and/or writes
    some of the keys listed below, and returns it.  The annotations only
    document the expected keys; a ``Dict`` subclass does not enforce them at
    runtime.
    """

    # Original user query; read by aggregate_results.
    query: str
    # Query text read by graph_db_router, query_graph_db and expand_queries.
    # Presumably filled in by preprocess_query (not shown here) -- confirm.
    processed_query: str
    # Set by graph_db_router; query_graph_db returns early when it is falsy.
    use_graph_db: bool
    # Written by query_graph_db; read by expand_queries, execute_subgraphs
    # and aggregate_results.
    # NOTE(review): query_graph_db stores a list of records under this key
    # although the annotation says Dict -- confirm the intended type.
    graph_db_result: Dict
    # Written by expand_queries; read by subgraph_router.
    expanded_queries: List[str]
    subgraphs_to_execute: Dict[str, List[str]]  # Maps queries to subgraphs
    # Written by execute_subgraphs: query -> {subgraph name -> subgraph output}.
    subgraph_results: Dict[str, Any]
    # Written by aggregate_results; returned to the caller by run_finance_agent.
    final_response: str

# ... (keep preprocess_query function as before)

async def graph_db_router(state: AgentState) -> AgentState:
    """Ask the LLM whether the processed query can make use of graph relationships.

    Args:
        state: Agent state; ``state["processed_query"]`` is inserted into the
            routing prompt.

    Returns:
        The same state object, with ``state["use_graph_db"]`` set to ``True``
        when the LLM's answer starts with "yes" (case-insensitive, leading
        whitespace ignored) and ``False`` otherwise.
    """
    router_prompt = PromptTemplate.from_template(
        """Analyze the following query and determine if it can utilize graph relationships:

        Query: {query}

        Respond with either 'Yes' or 'No' followed by a brief explanation.
        """
    )

    # Compose the runnables first and await the chain's invocation.  Writing
    # ``await prompt.aformat(...) | llm | str`` awaits only the formatting
    # step and then tries to pipe a plain string into the LLM, which fails.
    # StrOutputParser extracts the text whether ``llm`` returns a string or
    # a chat message.
    chain = router_prompt | llm | StrOutputParser()
    answer = await chain.ainvoke({"query": state["processed_query"]})

    # Models frequently emit leading whitespace/newlines before the verdict.
    state["use_graph_db"] = answer.strip().lower().startswith("yes")
    return state

async def query_graph_db(state: AgentState) -> AgentState:
    """Fetch nodes whose name contains the processed query from Neo4j.

    Does nothing when ``state["use_graph_db"]`` is falsy.

    Args:
        state: Agent state; reads ``use_graph_db`` and ``processed_query``.

    Returns:
        The same state object.  When the graph database was queried,
        ``state["graph_db_result"]`` holds a list with one dict per returned
        record (at most 5).
    """
    if not state["use_graph_db"]:
        return state

    # Replace with your actual Neo4j connection details
    uri = "neo4j://localhost:7687"
    user = "neo4j"
    password = "password"

    # This is a placeholder Cypher query. Replace with actual query generation logic.
    # The search text is bound as the $name parameter instead of being
    # interpolated into the statement, so quotes or Cypher fragments in the
    # user's query cannot break or alter the query.
    cypher_query = "MATCH (n) WHERE n.name CONTAINS $name RETURN n LIMIT 5"

    async with AsyncGraphDatabase.driver(uri, auth=(user, password)) as driver:
        async with driver.session() as session:
            result = await session.run(
                cypher_query,
                parameters={"name": state["processed_query"]},
            )
            # result.data() already yields a list of plain dicts; its items
            # are not Record objects and have no .data() method of their own.
            state["graph_db_result"] = await result.data()

    return state

async def expand_queries(state: AgentState) -> AgentState:
    """Have the LLM derive a list of expanded queries from the query and graph results.

    Args:
        state: Agent state; reads ``processed_query`` and ``graph_db_result``.

    Returns:
        The same state object, with ``state["expanded_queries"]`` set to the
        list produced by splitting the LLM's comma-separated answer.
    """
    expansion_prompt = PromptTemplate.from_template(
        """Based on the original query and the graph database results, generate a list of expanded queries 
        that cover different aspects of the information needed. Separate queries with commas.

        Original Query: {query}
        Graph Database Results: {graph_results}

        Expanded Queries:"""
    )

    # Build the full chain and await its invocation; ``await`` binds tighter
    # than ``|``, so awaiting ``aformat(...)`` inline would pipe a plain
    # string into the LLM instead of running prompt -> llm -> parser.
    chain = expansion_prompt | llm | CommaSeparatedListOutputParser()
    state["expanded_queries"] = await chain.ainvoke(
        {
            "query": state["processed_query"],
            "graph_results": state["graph_db_result"],
        }
    )
    return state

def _parse_subgraph_routing(raw: str) -> Dict[str, List[str]]:
    """Turn the router LLM's reply into a ``query -> [subgraph names]`` mapping.

    Raises:
        ValueError: If no JSON object can be found or parsed in ``raw``, or a
            value is neither a list nor a single string.
    """
    # Replies are often wrapped in prose or Markdown code fences; keep only
    # the outermost {...} span before handing it to the JSON parser.
    start, end = raw.find("{"), raw.rfind("}")
    if start == -1 or end < start:
        raise ValueError(f"Router response contains no JSON object: {raw!r}")

    try:
        parsed = json.loads(raw[start:end + 1])
    except json.JSONDecodeError as err:
        raise ValueError(f"Router response is not valid JSON: {raw!r}") from err

    routing: Dict[str, List[str]] = {}
    for query, subgraphs in parsed.items():
        if isinstance(subgraphs, str):
            # Tolerate a single subgraph given as a bare string.
            subgraphs = [subgraphs]
        if not isinstance(subgraphs, list):
            raise ValueError(
                f"Expected a list of subgraphs for query {query!r}, got {subgraphs!r}"
            )
        routing[query] = [str(name) for name in subgraphs]
    return routing


async def subgraph_router(state: AgentState) -> AgentState:
    """Ask the LLM which subgraphs should be executed for each expanded query.

    Args:
        state: Agent state; reads ``expanded_queries``.

    Returns:
        The same state object, with ``state["subgraphs_to_execute"]`` set to
        a dict mapping each query to a list of subgraph names.

    Raises:
        ValueError: If the LLM's reply cannot be parsed as the requested JSON
            object.
    """
    router_prompt = PromptTemplate.from_template(
        """For each of the following queries, determine which subgraphs should be executed.
        Available subgraphs: ect_subgraph (Earning Call Transcripts), ac_subgraph (Analyst Commentary).
        
        Queries:
        {queries}
        
        Respond with a JSON object where keys are the queries and values are lists of subgraphs to execute.
        """
    )

    # Compose prompt -> llm -> text and await the chain; awaiting
    # ``aformat(...)`` inline would pipe a plain string into the LLM.
    chain = router_prompt | llm | StrOutputParser()
    routing_result = await chain.ainvoke(
        {"queries": "\n".join(state["expanded_queries"])}
    )

    # The reply is model-generated text: parse it as data with json instead
    # of executing it with eval().
    state["subgraphs_to_execute"] = _parse_subgraph_routing(routing_result)
    return state

async def execute_subgraphs(state: AgentState) -> AgentState:
    """Run every routed subgraph for every query, one after another.

    Args:
        state: Agent state; reads ``subgraphs_to_execute`` and
            ``graph_db_result``.

    Returns:
        The same state object, with ``state["subgraph_results"]`` set to a
        dict of ``query -> {subgraph name -> subgraph output}``.  Subgraph
        names that are not known here are skipped.
    """
    runners = {"ect_subgraph": run_ect_subgraph, "ac_subgraph": run_ac_subgraph}

    collected = {}
    for query, requested in state["subgraphs_to_execute"].items():
        # Each runner is awaited in turn, in the order the router listed them.
        collected[query] = {
            name: await runners[name](query, state["graph_db_result"])
            for name in requested
            if name in runners
        }

    state["subgraph_results"] = collected
    return state

async def aggregate_results(state: AgentState) -> AgentState:
    """Have the LLM synthesize the final answer from everything gathered so far.

    Args:
        state: Agent state; reads ``query``, ``graph_db_result`` and
            ``subgraph_results``.

    Returns:
        The same state object, with ``state["final_response"]`` set to the
        LLM's answer as text.
    """
    aggregation_prompt = PromptTemplate.from_template(
        """Synthesize a comprehensive response based on the following information:
        
        Original Query: {original_query}
        Graph Database Results: {graph_db_result}
        Expanded Queries and Subgraph Results: {subgraph_results}
        
        Provide a coherent and informative response that addresses the original query.
        """
    )

    # Compose prompt -> llm -> text and await the chain.  Awaiting
    # ``aformat(...)`` inline would pipe a plain string into the LLM, and a
    # trailing ``| str`` would stringify a chat message object rather than
    # return its text; StrOutputParser handles both model kinds.
    chain = aggregation_prompt | llm | StrOutputParser()
    state["final_response"] = await chain.ainvoke(
        {
            "original_query": state["query"],
            "graph_db_result": state["graph_db_result"],
            "subgraph_results": state["subgraph_results"],
        }
    )

    return state

def create_finance_agent_graph():
    """Build and compile the finance-agent workflow as a straight pipeline.

    The nodes run in the order listed below, each feeding the next, and the
    last one is connected to ``END``.

    Returns:
        The compiled graph returned by ``Graph.compile()``.
    """
    workflow = Graph()

    # (node name, handler) pairs in execution order.
    stages = [
        ("preprocess", preprocess_query),
        ("graph_db_router", graph_db_router),
        ("query_graph_db", query_graph_db),
        ("expand_queries", expand_queries),
        ("subgraph_router", subgraph_router),
        ("execute_subgraphs", execute_subgraphs),
        ("aggregate_results", aggregate_results),
    ]
    for node_name, handler in stages:
        workflow.add_node(node_name, RunnableLambda(handler))

    # Chain each stage to its successor; the final stage leads to END.
    route = [node_name for node_name, _ in stages] + [END]
    for source, target in zip(route, route[1:]):
        workflow.add_edge(source, target)

    workflow.set_entry_point("preprocess")

    return workflow.compile()

async def run_finance_agent(query: str):
    """Run the finance-agent graph for one query and return its final response.

    Args:
        query: The user's question.

    Returns:
        The ``"final_response"`` entry of the state produced by the graph.
    """
    agent = create_finance_agent_graph()

    # Start from a state where only the raw query is known; every other key
    # begins empty/false for the nodes to fill in.
    starting_state = AgentState(
        {
            "query": query,
            "processed_query": "",
            "use_graph_db": False,
            "graph_db_result": {},
            "expanded_queries": [],
            "subgraphs_to_execute": {},
            "subgraph_results": {},
            "final_response": "",
        }
    )

    outcome = await agent.ainvoke(starting_state)
    return outcome["final_response"]


INFO:numexpr.utils:NumExpr defaulting to 8 threads.


ModuleNotFoundError: No module named 'langchain.schema.runnable'; 'langchain.schema' is not a package