In [None]:
from typing import Annotated


def merge_results(left: Dict[str, str], right: Dict[str, str]) -> Dict[str, str]:
    """Reducer for ``GraphState.results``: merge two partial result dicts.

    Nodes b, c and d run in the same graph step and each writes
    ``{"results": {node: ...}}``. LangGraph only accepts several writes to one
    state key within a single step when that key has a reducer, so the partial
    dicts are combined here. Keys from *right* win on collision; ``None`` on
    either side is treated as an empty dict. Neither input is mutated.
    """
    return {**(left or {}), **(right or {})}


class GraphState(TypedDict):
    """Shared state flowing through the a -> {b, c, d} -> aggregate workflow."""

    # The user's question; read by every search/LLM node.
    query: str
    # Per-branch outputs keyed by node name ("b", "c", "d"), merged across the
    # parallel branches by merge_results.
    # NOTE(review): values are whatever `prompt | llm` returns (a str or a message
    # object depending on `llm`) — confirm the declared value type.
    results: Annotated[Dict[str, str], merge_results]
    # Branch names chosen by node_a and returned by router.
    next_steps: List[str]
    # Synthesized answer written by aggregate_results; must be declared here or
    # LangGraph drops it from the final state.
    final_result: Any

# Node functions
def node_a(state: GraphState) -> Dict[str, Any]:
    """Entry node: pick the analysis branches to run next.

    The incoming state is not inspected; all three branches ("b", "c", "d")
    are always selected and stored under ``next_steps`` for the router.
    """
    print("Executing node_a")
    selected_branches = ["b", "c", "d"]
    return {"next_steps": selected_branches}

async def search_and_generate(node: str, state: GraphState, prompt: PromptTemplate) -> Dict[str, str]:
    """Retrieve documents for the query and let the LLM answer from them.

    Args:
        node: Branch name ("b", "c", ...). Used as the key of the returned dict
            and forwarded to ``run_vespa_search`` as ``tags``.
        state: Current graph state; only ``state["query"]`` is read.
        prompt: Template filled with ``query`` and ``search_results``.

    Returns:
        ``{node: response}``, where ``response`` is the output of ``prompt | llm``.
    """
    print(f"Executing {node}")
    query = state["query"]

    # Retrieve the top 5 hits for this branch.
    # NOTE(review): `tags` receives the bare node name string — confirm that is
    # what run_vespa_search expects.
    search_results = await run_vespa_search(query, top_k=5, tags=node)

    chain = prompt | llm
    # Use the async entry point: a synchronous `invoke` here would block the
    # event loop that is driving this coroutine for the whole LLM call.
    response = await chain.ainvoke({"query": query, "search_results": search_results})
    return {node: response}

def _run_coroutine_sync(coro):
    """Run *coro* to completion from synchronous code and return its result.

    ``asyncio.run`` refuses to start when the calling thread already has a
    running event loop (e.g. a Jupyter kernel's main thread). In that case the
    coroutine is run on a fresh loop in a short-lived worker thread instead.
    """
    import asyncio
    import concurrent.futures

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so the simple path is safe.
        return asyncio.run(coro)
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


def create_node_function(node: str, prompt: PromptTemplate):
    """Build a synchronous LangGraph node that runs ``search_and_generate``.

    Args:
        node: Branch name; also the key under which the output is stored.
        prompt: Template passed through to ``search_and_generate``.

    Returns:
        A callable ``(state) -> {"results": {node: output}}`` that is suitable
        for ``StateGraph.add_node``.
    """
    def node_function(state: GraphState) -> Dict[str, Any]:
        result = _run_coroutine_sync(search_and_generate(node, state, prompt))
        return {"results": {node: result[node]}}
    return node_function

# Prompt templates for the three analysis branches. Each one is filled with
# {query} and {search_results} by the chain in search_and_generate.
prompt_b = PromptTemplate.from_template(
    "Analyze the following search results for the query: {query}\n\n"
    "Search Results: {search_results}\n\n"
    "Provide a concise summary focusing on the main points."
)

prompt_c = PromptTemplate.from_template(
    "Given the query: {query}\n\n"
    "And the search results: {search_results}\n\n"
    "Identify any conflicting information or controversies in the results."
)

prompt_d = PromptTemplate.from_template(
    "For the query: {query}\n\n"
    "Based on these search results: {search_results}\n\n"
    "Provide potential future developments or implications."
)

# One node callable per branch, each bound to its own prompt template.
node_b, node_c, node_d = (
    create_node_function(branch, branch_prompt)
    for branch, branch_prompt in (("b", prompt_b), ("c", prompt_c), ("d", prompt_d))
)

def aggregate_results(state: GraphState) -> Dict[str, Any]:
    """Fuse the per-branch outputs into a single answer with the LLM.

    Reads ``state["results"]`` (one line per branch, in dict order) and
    ``state["query"]``, and returns ``{"final_result": <llm output>}``.
    NOTE(review): LangGraph keeps only keys declared on the state schema —
    confirm GraphState declares ``final_result``.
    """
    print("Aggregating results")
    combined_response = "\n".join(
        f"Node {branch}: {output}" for branch, output in state["results"].items()
    )

    synthesis_template = (
        "Synthesize a comprehensive answer based on these results:\n{combined_response}\n\n"
        "Provide a well-structured and coherent response that addresses the original query: {query}"
    )
    synthesis_chain = PromptTemplate.from_template(synthesis_template) | llm
    answer = synthesis_chain.invoke(
        {"combined_response": combined_response, "query": state["query"]}
    )
    return {"final_result": answer}

# Router function
def router(state: GraphState) -> List[str]:
    """Path function for the conditional edges leaving node "a".

    Returns the branch names that node_a stored in ``next_steps``. Each name
    is resolved through the path map given to ``add_conditional_edges``.
    """
    chosen = state["next_steps"]
    return chosen

# Assemble the workflow: a -> {b, c, d} (chosen by router) -> aggregate -> END.
workflow = StateGraph(GraphState)

for _node_name, _node_fn in (
    ("a", node_a),
    ("b", node_b),
    ("c", node_c),
    ("d", node_d),
    ("aggregate", aggregate_results),
):
    workflow.add_node(_node_name, _node_fn)

_branch_names = ("b", "c", "d")
# Fan out from "a"; the router's return values map one-to-one onto node names.
workflow.add_conditional_edges("a", router, {name: name for name in _branch_names})
# Fan back in: every branch feeds the aggregator.
for _branch in _branch_names:
    workflow.add_edge(_branch, "aggregate")
workflow.add_edge("aggregate", END)

workflow.set_entry_point("a")

graph = workflow.compile()

# Execute the compiled graph once on a sample question and show the answer.
query = "What are the latest advancements in quantum computing?"
initial_state = dict(query=query, results={}, next_steps=[])
result = graph.invoke(initial_state)
final_answer = result["final_result"]
print(f"Final result: {final_answer}")