### Import statements


In [None]:
import os
from typing import List, Literal, TypedDict

from dotenv import load_dotenv
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langgraph.graph import END, StateGraph

load_dotenv()

### Configuration


In [None]:
# Where the persisted Chroma store lives and which collection inside it holds
# the support knowledge base.
CHROMA_DB_DIR, COLLECTION_NAME = "./chroma_db", "customer_support_knowledge"

# Gemini credential taken from the process environment; this is None when the
# GEMINI_API_KEY variable is not set.
GOOGLE_API_KEY = os.environ.get("GEMINI_API_KEY")

### Initializing LLM and embedding model


In [None]:
# Embedding model handed to the Chroma vector store as its embedding function.
embeddings = GoogleGenerativeAIEmbeddings(
    google_api_key=GOOGLE_API_KEY,
    model="models/embedding-001",
)

# Chat model shared by every agent node that prompts Gemini.
llm = ChatGoogleGenerativeAI(
    google_api_key=GOOGLE_API_KEY,
    model="gemini-2.0-flash",
    temperature=0.2,
)

### Defining agent state


In [None]:
class AgentState(TypedDict):
    """Shared state passed between the nodes of the RAG agent workflow.

    Every node receives this mapping and returns an updated copy; the
    supervisor reads ``next_agent_to_call`` to decide where to route next.
    """

    # The user's question exactly as submitted.
    original_query: str
    # Sub-questions produced by the research agent from the original query.
    sub_queries_list: List[str]
    # Position in sub_queries_list of the sub-query being processed.
    current_sub_query_index: int
    # The specific sub-query being worked on now.
    current_sub_query: str
    # Documents returned by the most recent retrieval for the current sub-query.
    retrieved_chunks: List[Document]
    # Evaluator verdict for the current sub-query's retrieved chunks.
    evaluated_sufficiency: bool
    # Evaluator feedback; also carries the error message on failure paths.
    evaluator_feedback: str
    # Number of retrievals attempted for the current sub-query.
    retrieval_attempts: int
    # Answer written by the synthesizer, before formatting.
    final_answer_draft: str
    # Final answer after the formatter has polished the draft.
    report_formatted: str
    # Accumulates chunks from all answered sub-queries.
    accumulated_relevant_chunks: List[Document]
    # Tracks sub-queries that couldn't be answered.
    unanswerable_sub_queries: List[str]
    # Routing signal consumed by the supervisor.
    next_agent_to_call: Literal[
        "research_agent",
        "retriever_agent",
        "evaluator_agent",
        "formatter_agent",
        "synthesizer_agent",
        "END",  # Final completion
        "FATAL_ERROR",  # For unrecoverable errors
        "end_workflow",  # Written by the supervisor when it terminates the graph
    ]

### Load Vector DB

In [None]:
def get_vector_db():
    """Build a Chroma handle for the configured knowledge-base collection.

    The store is opened from ``CHROMA_DB_DIR`` under ``COLLECTION_NAME`` and
    uses the module-level ``embeddings`` object as its embedding function.
    """
    store_settings = {
        "persist_directory": CHROMA_DB_DIR,
        "embedding_function": embeddings,
        "collection_name": COLLECTION_NAME,
    }
    return Chroma(**store_settings)

### Research agent

In [None]:
def _split_sub_queries(raw_response: str) -> List[str]:
    """Split the planner's reply into individual sub-questions.

    The planner is asked for a comma-separated list, but a sub-question may
    itself contain commas. A comma is therefore treated as a separator only
    when it directly follows a question mark; line breaks also separate
    entries. If that yields at most one entry, the reply is split on every
    comma instead.
    """
    candidates = [
        part.strip() for part in raw_response.replace("?,", "?\n").splitlines()
    ]
    candidates = [part for part in candidates if part]
    if len(candidates) > 1:
        return candidates
    return [part.strip() for part in raw_response.split(",") if part.strip()]


def research_agent_node(state: AgentState) -> AgentState:
    """
    Node responsible for breaking down the original query into sub-queries
    and managing the flow of processing them.

    Args:
        state: Current workflow state. ``original_query`` is required;
            ``sub_queries_list`` and ``current_sub_query_index`` default to
            an empty plan and index 0 when absent.

    Returns:
        An updated copy of the state. It routes to ``synthesizer_agent`` once
        every sub-query has been processed, to ``FATAL_ERROR`` if the query
        could not be broken down, and otherwise to ``retriever_agent`` with
        the per-sub-query fields reset for the selected sub-query.
    """
    print("---RESEARCH AGENT: Managing research plan---")

    original_query = state["original_query"]
    sub_queries_list = state.get("sub_queries_list", [])
    current_sub_query_index = state.get("current_sub_query_index", 0)

    # Check if all sub-queries are processed
    if sub_queries_list and current_sub_query_index >= len(sub_queries_list):
        print("---RESEARCH AGENT: All sub-queries processed. Moving to synthesis.---")
        return {
            **state,
            "next_agent_to_call": "synthesizer_agent",
        }

    # Initial breakdown of query if not already done
    if not sub_queries_list:
        print(f"---RESEARCH AGENT: Breaking down original query: '{original_query}'---")
        # The query is supplied through a template variable rather than being
        # interpolated into the template text, so braces in the user's query
        # are not misread as prompt placeholders.
        breakdown_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """You are a research planner. Your task is to break down a complex user query into 3-5 concise, distinct, and answerable sub-questions.
             Each sub-question should be focused enough to be answered by retrieving information.
             Respond with a comma-separated list of sub-questions ONLY. Do not add any other text or numbering.
             Example: "What are the common symptoms of flu?, How is flu transmitted?, What are flu prevention methods?"
             """,
                ),
                ("human", "Break down the query: '{original_query}'"),
            ]
        )
        breakdown_chain = breakdown_prompt | llm | StrOutputParser()

        try:
            raw_sub_queries = breakdown_chain.invoke({"original_query": original_query})
            sub_queries_list = _split_sub_queries(raw_sub_queries)

            if not sub_queries_list:
                raise ValueError("Gemini did not return any sub-queries.")

            print(f"---RESEARCH AGENT: Generated sub-queries: {sub_queries_list}---")

        except Exception as e:
            print(f"---RESEARCH AGENT ERROR: Failed to break down query: {e}---")
            return {
                **state,
                "next_agent_to_call": "FATAL_ERROR",
                "evaluator_feedback": f"Failed to break down query: {e}",
            }

        # A freshly generated plan always starts at its first sub-query; an
        # index left over from an earlier plan could point past the new list.
        current_sub_query_index = 0

    # Set the current sub-query to process
    current_sub_query = sub_queries_list[current_sub_query_index]
    print(
        f"---RESEARCH AGENT: Processing sub-query {current_sub_query_index + 1}/{len(sub_queries_list)}: '{current_sub_query}'---"
    )

    return {
        **state,
        "sub_queries_list": sub_queries_list,
        "current_sub_query_index": current_sub_query_index,
        "current_sub_query": current_sub_query,
        "retrieval_attempts": 0,  # Reset attempts for new sub-query
        "retrieved_chunks": [],  # Clear previous chunks
        "evaluated_sufficiency": False,  # Reset evaluation
        "evaluator_feedback": "",  # Clear previous feedback
        "next_agent_to_call": "retriever_agent",  # Go to retrieval for this sub-query
    }

### Retriever Agent


In [None]:
def retriever_agent_node(state: AgentState) -> AgentState:
    """
    Fetch documents for the current sub-query from the vector database.

    The attempt counter is incremented on every call. On success the state is
    routed to ``evaluator_agent`` with the retrieved documents and cleared
    evaluation fields; if anything raises while loading the store or
    retrieving, the state is routed to ``FATAL_ERROR`` with the error message
    stored in ``evaluator_feedback``.
    """
    print("---RETRIEVER AGENT: Initiating retrieval---")
    query_text = state["current_sub_query"]
    attempt_number = 1 + state.get("retrieval_attempts", 0)

    try:
        store = get_vector_db()
        print(f"ChromaDB loaded successfully from {CHROMA_DB_DIR}.")

        top_k = 1  # Number of documents requested per retrieval
        doc_retriever = store.as_retriever(search_kwargs={"k": top_k})

        print(f"Retrieving for query: '{query_text}' (Attempt: {attempt_number})")
        found_docs = doc_retriever.invoke(query_text)

        print(f"---RETRIEVER AGENT: Retrieved {len(found_docs)} chunks.---")
    except Exception as err:
        print(f"---RETRIEVER AGENT ERROR: {err}---")
        failure_update = {
            "retrieved_chunks": [],
            "retrieval_attempts": attempt_number,
            "next_agent_to_call": "FATAL_ERROR",
            "evaluator_feedback": f"Retrieval failed: {err}",
        }
        return {**state, **failure_update}
    else:
        # Hand the documents to the evaluator with a clean evaluation slate.
        success_update = {
            "retrieved_chunks": found_docs,
            "retrieval_attempts": attempt_number,
            "next_agent_to_call": "evaluator_agent",
            "evaluated_sufficiency": False,
            "evaluator_feedback": "",
        }
        return {**state, **success_update}

### Evaluator agent


In [None]:
def evaluator_agent_node(state: AgentState) -> AgentState:
    """
    Node that uses Gemini to evaluate if the retrieved chunks are sufficient to answer the sub-query.
    Provides feedback if not.

    Args:
        state: Current workflow state. ``current_sub_query``,
            ``retrieved_chunks`` and ``retrieval_attempts`` are required; the
            accumulator lists and the sub-query index default to empty / 0.

    Returns:
        An updated copy of the state. Sufficient chunks are appended to
        ``accumulated_relevant_chunks`` and the index advances; once the
        attempt limit is reached the sub-query is recorded in
        ``unanswerable_sub_queries`` and the index advances; otherwise the
        flow loops back to ``retriever_agent``. A failed evaluation request
        routes to ``FATAL_ERROR`` with the error in ``evaluator_feedback``.
    """
    print("---EVALUATOR AGENT: Evaluating retrieved chunks---")
    current_sub_query = state["current_sub_query"]
    retrieved_chunks = state["retrieved_chunks"]
    retrieval_attempts = state["retrieval_attempts"]
    # Work on copies so the lists held by the incoming state are never
    # mutated in place; updates are published only through the returned dict.
    accumulated_relevant_chunks = list(state.get("accumulated_relevant_chunks", []))
    unanswerable_sub_queries = list(state.get("unanswerable_sub_queries", []))
    current_sub_query_index = state.get("current_sub_query_index", 0)

    MAX_RETRIEVAL_ATTEMPTS = 1  # Define maximum attempts for a single sub-query

    if not retrieved_chunks:
        print("---EVALUATOR AGENT: No chunks retrieved.---")
        if retrieval_attempts >= MAX_RETRIEVAL_ATTEMPTS:
            print(
                f"---EVALUATOR AGENT: Max attempts reached for '{current_sub_query}'. Marking as unanswerable.---"
            )
            unanswerable_sub_queries.append(current_sub_query)
            return {
                **state,
                "evaluated_sufficiency": False,
                "evaluator_feedback": "Max retrieval attempts reached; no relevant documents found.",
                "unanswerable_sub_queries": unanswerable_sub_queries,
                "current_sub_query_index": current_sub_query_index
                + 1,  # Move to next sub-query
                "next_agent_to_call": "research_agent",  # Let research agent determine next sub-query or conclude
            }
        else:
            print("---EVALUATOR AGENT: No chunks retrieved, retrying retrieval.---")
            return {
                **state,
                "evaluated_sufficiency": False,
                "evaluator_feedback": "No relevant documents were retrieved. Reattempting retrieval.",
                "next_agent_to_call": "retriever_agent",  # Loop back to retriever
            }

    # Format chunks for the LLM prompt
    formatted_chunks = "\n---\n".join([doc.page_content for doc in retrieved_chunks])

    # Define the evaluation prompt for Gemini. The sub-query and the chunk
    # text are supplied as template variables rather than interpolated into
    # the template, so braces inside retrieved documents are not misread as
    # prompt placeholders.
    eval_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are an expert evaluator for a RAG system. Your task is to determine if the provided 'CONTEXT' is sufficient and relevant to fully and comprehensively answer the 'SUB-QUERY'.

         If the CONTEXT is sufficient and relevant, respond ONLY with the word 'YES'.
         If the CONTEXT is NOT sufficient or relevant, respond ONLY with the word 'NO', followed by a concise, specific suggestion on how to improve the retrieval for this sub-query. For example:
         - 'NO: The results are too general; focus on error codes.'
         - 'NO: No specific instructions found for this type of issue.'
         - 'NO: The context mentions the topic but lacks actionable steps.'

         Aim for direct, actionable feedback. Do not elaborate beyond 'NO: [feedback]'.
         """,
            ),
            (
                "human",
                "SUB-QUERY: {current_sub_query}\n\nCONTEXT:\n{formatted_chunks}",
            ),
        ]
    )

    eval_chain = eval_prompt | llm | StrOutputParser()

    print(
        f"---EVALUATOR AGENT: Sending evaluation request to Gemini for query '{current_sub_query}'---"
    )
    try:
        gemini_response = eval_chain.invoke(
            {
                "current_sub_query": current_sub_query,
                "formatted_chunks": formatted_chunks,
            }
        )
    except Exception as e:
        # Same convention as the research and retriever nodes: report the
        # failure through the state instead of letting it escape the graph.
        print(f"---EVALUATOR AGENT ERROR: {e}---")
        return {
            **state,
            "evaluated_sufficiency": False,
            "evaluator_feedback": f"Evaluation failed: {e}",
            "next_agent_to_call": "FATAL_ERROR",
        }

    gemini_response = gemini_response.strip()
    print(f"---EVALUATOR AGENT: Gemini's raw response: {gemini_response}---")

    is_sufficient = gemini_response.upper().startswith("YES")
    feedback = ""
    if not is_sufficient:
        # Expected shape is "NO: <feedback>". Drop the "NO:" marker when it is
        # present; if the model answered in some other shape, keep the whole
        # reply rather than chopping off its first characters.
        verdict, _, remainder = gemini_response.partition(":")
        if verdict.strip().upper() == "NO":
            feedback = remainder.strip()
        else:
            feedback = gemini_response

    print(
        f"---EVALUATOR AGENT: Sufficiency: {is_sufficient}, Feedback: '{feedback}'---"
    )

    next_step: Literal["retriever_agent", "research_agent", "FATAL_ERROR"]

    if is_sufficient:
        # Add chunks to accumulated results
        accumulated_relevant_chunks.extend(retrieved_chunks)
        print(
            f"---EVALUATOR AGENT: Chunks deemed sufficient. Accumulated {len(retrieved_chunks)} new chunks. Total accumulated: {len(accumulated_relevant_chunks)}---"
        )
        next_step = (
            "research_agent"  # Move to next sub-query (handled by research_agent)
        )
        current_sub_query_index += 1  # Increment for next sub-query
    elif retrieval_attempts >= MAX_RETRIEVAL_ATTEMPTS:
        print(
            f"---EVALUATOR AGENT: Max retrieval attempts ({MAX_RETRIEVAL_ATTEMPTS}) reached for '{current_sub_query}'. Marking as unanswerable.---"
        )
        unanswerable_sub_queries.append(current_sub_query)
        next_step = (
            "research_agent"  # Move to next sub-query (handled by research_agent)
        )
        current_sub_query_index += 1  # Increment for next sub-query
    else:
        next_step = "retriever_agent"  # Not sufficient, try retrieval again

    return {
        **state,
        "evaluated_sufficiency": is_sufficient,
        "evaluator_feedback": feedback,
        "accumulated_relevant_chunks": accumulated_relevant_chunks,
        "unanswerable_sub_queries": unanswerable_sub_queries,
        "current_sub_query_index": current_sub_query_index,  # Update index if moving on
        "next_agent_to_call": next_step,
    }

### Synthesizer agent

In [None]:
def synthesizer_agent_node(state: AgentState) -> AgentState:
    """
    Node that synthesizes all accumulated relevant chunks into a comprehensive answer
    for the original query.

    Args:
        state: Current workflow state. ``original_query``,
            ``accumulated_relevant_chunks`` and ``unanswerable_sub_queries``
            are required.

    Returns:
        An updated copy of the state with ``final_answer_draft`` set and the
        flow routed to ``formatter_agent``. When no chunks were accumulated
        the draft is a fixed apology naming any unanswerable sub-queries. If
        the synthesis request raises, the flow is routed to ``FATAL_ERROR``
        with the error in ``evaluator_feedback``.
    """
    print("---SYNTHESIZER AGENT: Generating final answer draft---")
    original_query = state["original_query"]
    accumulated_chunks = state["accumulated_relevant_chunks"]
    unanswerable_sub_queries = state["unanswerable_sub_queries"]

    if not accumulated_chunks:
        final_answer_draft = "I could not find sufficient information in my knowledge base to answer your query. "
        if unanswerable_sub_queries:
            final_answer_draft += f"Specifically, I couldn't find answers for: {', '.join(unanswerable_sub_queries)}. "
        final_answer_draft += (
            "Please ensure the information is available in the provided documents."
        )
        print("---SYNTHESIZER AGENT: No chunks accumulated, generating apology.---")
        return {
            **state,
            "final_answer_draft": final_answer_draft,
            "next_agent_to_call": "formatter_agent",
        }

    # Combine all relevant content
    combined_content = "\n\n".join([doc.page_content for doc in accumulated_chunks])
    unanswerable_text = (
        ", ".join(unanswerable_sub_queries) if unanswerable_sub_queries else "None"
    )

    # Craft the synthesis prompt for Gemini. Query, context and unanswerable
    # list are supplied as template variables rather than interpolated into
    # the template, so braces inside document text are not misread as prompt
    # placeholders.
    synthesis_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a helpful customer support AI. Your task is to synthesize the provided 'CONTEXT' to answer the 'ORIGINAL_QUERY' comprehensively and clearly.
         
         Only use information directly present in the CONTEXT. Do not make up information.
         Structure your answer professionally, starting with a direct response, potentially using bullet points for steps or lists, and concluding politely.
         
         If any parts of the ORIGINAL_QUERY could not be addressed by the CONTEXT (and are listed as 'UNANSWERABLE_SUB_QUERIES'), acknowledge this gracefully.

         Example of good structure:
         "Thank you for your question about [Topic]. Based on the information I have, here is the answer:
         [Direct answer and details, using bullet points for steps]
         
         Regarding [unanswerable part], I was unable to find specific details in my knowledge base. Please check the latest manual or contact support for further assistance.
         
         I hope this helps!"
         """,
            ),
            (
                "human",
                "ORIGINAL_QUERY: {original_query}\n\nCONTEXT:\n{combined_content}\n\nUNANSWERABLE_SUB_QUERIES: {unanswerable_sub_queries}",
            ),
        ]
    )

    synthesis_chain = synthesis_prompt | llm | StrOutputParser()

    print(
        f"---SYNTHESIZER AGENT: Sending synthesis request to Gemini for query '{original_query}'---"
    )
    try:
        final_answer_draft = synthesis_chain.invoke(
            {
                "original_query": original_query,
                "combined_content": combined_content,
                "unanswerable_sub_queries": unanswerable_text,
            }
        )
    except Exception as e:
        # Same convention as the research and retriever nodes: report the
        # failure through the state instead of letting it escape the graph.
        print(f"---SYNTHESIZER AGENT ERROR: {e}---")
        return {
            **state,
            "evaluator_feedback": f"Synthesis failed: {e}",
            "next_agent_to_call": "FATAL_ERROR",
        }

    print("---SYNTHESIZER AGENT: Draft generated.---")

    return {
        **state,
        "final_answer_draft": final_answer_draft,
        "next_agent_to_call": "formatter_agent",  # Move to formatter for polish
    }

### Formatter agent

In [None]:
def formatter_agent_node(state: AgentState) -> AgentState:
    """
    Node that refines and formats the final answer draft for clarity and presentation.

    Args:
        state: Current workflow state. ``original_query`` and
            ``final_answer_draft`` are required.

    Returns:
        An updated copy of the state with ``report_formatted`` set and the
        flow routed to ``END``. An empty draft yields a fixed error message.
        If the formatting request raises, the unpolished draft is returned as
        the report and the error is recorded in ``evaluator_feedback``.
    """
    print("---FORMATTER AGENT: Polishing final answer---")
    original_query = state["original_query"]
    final_answer_draft = state["final_answer_draft"]

    if not final_answer_draft:
        print("---FORMATTER AGENT: No draft to format.---")
        return {
            **state,
            "report_formatted": "An error occurred during answer generation.",
            "next_agent_to_call": "END",
        }

    # The query and the draft are supplied as template variables rather than
    # interpolated into the template, so braces in the generated draft are
    # not misread as prompt placeholders.
    format_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a polite and professional customer support assistant. Your task is to take a raw answer draft and refine it for clarity, grammar, and tone.
         Ensure it is easy to read, uses appropriate formatting (like bolding keywords, clear paragraphs, bullet points for steps), and maintains a helpful, empathetic tone.
         Do not add new information or remove critical details. Simply rephrase and format.
         Ensure a polite opening and closing.
         """,
            ),
            (
                "human",
                "Refine the following answer draft for the query '{original_query}':\n\n{final_answer_draft}",
            ),
        ]
    )

    format_chain = format_prompt | llm | StrOutputParser()

    print("---FORMATTER AGENT: Sending formatting request to Gemini.---")
    try:
        report_formatted = format_chain.invoke(
            {"original_query": original_query, "final_answer_draft": final_answer_draft}
        )
    except Exception as e:
        # Formatting is only a polish step; the draft already holds the full
        # answer, so deliver it as-is rather than losing it.
        print(f"---FORMATTER AGENT ERROR: {e}---")
        return {
            **state,
            "report_formatted": final_answer_draft,
            "evaluator_feedback": f"Formatting failed: {e}",
            "next_agent_to_call": "END",
        }

    print("---FORMATTER AGENT: Answer formatted.---")

    return {
        **state,
        "report_formatted": report_formatted,
        "next_agent_to_call": "END",  # All done!
    }

### Supervisor agent

In [None]:
def supervisor_node(state: AgentState) -> dict:
    """
    The supervisor node orchestrates the flow based on the 'next_agent_to_call' in the state.

    Args:
        state: Current workflow state; only ``next_agent_to_call`` is read.

    Returns:
        A state update containing just ``next_agent_to_call``. ``END`` and
        ``FATAL_ERROR`` are both rewritten to ``end_workflow``; any other
        value is passed through unchanged.
    """
    requested_agent = state["next_agent_to_call"]
    print(f"\n---SUPERVISOR: Directing flow to: {requested_agent}---")

    # If an agent signals END or FATAL_ERROR, the supervisor transitions to the graph END
    if requested_agent in ("END", "FATAL_ERROR"):
        print(
            "---SUPERVISOR: Workflow complete or fatal error detected. Ending workflow.---"
        )
        return {"next_agent_to_call": "end_workflow"}  # Transition to END

    # Otherwise, direct to the agent specified in next_agent_to_call
    return {"next_agent_to_call": requested_agent}

### Defining workflow

In [None]:
workflow = StateGraph(AgentState)

# Register the supervisor first, then each worker agent under its routing name.
workflow.add_node("supervisor", supervisor_node)
for _agent_name, _agent_fn in (
    ("research_agent", research_agent_node),
    ("retriever_agent", retriever_agent_node),
    ("evaluator_agent", evaluator_agent_node),
    ("synthesizer_agent", synthesizer_agent_node),
    ("formatter_agent", formatter_agent_node),
):
    workflow.add_node(_agent_name, _agent_fn)

# Execution always begins at the supervisor.
workflow.set_entry_point("supervisor")

# Every worker agent hands control straight back to the supervisor.
for _agent_name in (
    "research_agent",
    "retriever_agent",
    "evaluator_agent",
    "synthesizer_agent",
    "formatter_agent",
):
    workflow.add_edge(_agent_name, "supervisor")

# After the supervisor runs, route on the value it left in
# 'next_agent_to_call': agent names map to the node of the same name, and
# every termination signal maps to the graph's END.
workflow.add_conditional_edges(
    "supervisor",
    lambda state: state["next_agent_to_call"],
    {
        **{
            agent: agent
            for agent in (
                "research_agent",
                "retriever_agent",
                "evaluator_agent",
                "synthesizer_agent",
                "formatter_agent",
            )
        },
        **{signal: END for signal in ("END", "FATAL_ERROR", "end_workflow")},
    },
)

# Compile the graph into a runnable application.
app = workflow.compile()

app

In [None]:
print("\n--- Running Phase 4: Synthesis & Refinement ---")

# Query used for this run. An out-of-scope alternative for exercising the
# no-data path is: "What is the history of the internet?"
test_original_query = "My QuantumFlow purifier is showing a red light on its filter status indicator. What does this mean, and what should I do"

# Blank starting state: nothing planned, retrieved or drafted yet, and the
# first hop goes to the research agent so it can break the query down.
initial_state: AgentState = {
    "original_query": test_original_query,
    "sub_queries_list": [],
    "current_sub_query_index": 0,
    "current_sub_query": "",
    "retrieved_chunks": [],
    "evaluated_sufficiency": False,
    "evaluator_feedback": "",
    "retrieval_attempts": 0,
    "accumulated_relevant_chunks": [],
    "unanswerable_sub_queries": [],
    "final_answer_draft": "",
    "report_formatted": "",
    "next_agent_to_call": "research_agent",
}

# Running picture of the state, seeded from the initial state and refreshed
# with whatever each streamed step reports.
final_state_summary = dict(initial_state)

for streamed_step in app.stream(initial_state, config={"recursion_limit": 200}):
    # Each streamed item is keyed by node name; merge every dict payload.
    for node_update in streamed_step.values():
        if isinstance(node_update, dict):
            final_state_summary.update(node_update)

    last_node_executed = list(streamed_step)[-1]
    pending_agent = final_state_summary.get("next_agent_to_call", "N/A")
    print(
        f"\nState after '{last_node_executed}': next_agent_to_call={pending_agent}"
    )


print("\n--- Phase 4 Execution Complete ---")
if not final_state_summary:
    print("No final state accumulated.")
else:
    print(
        f"\nFinal Answer for Original Query: '{final_state_summary['original_query']}'"
    )
    sub_query_total = len(final_state_summary["sub_queries_list"])
    print(f"  Total Sub-Queries Generated: {sub_query_total}")
    chunk_total = len(final_state_summary["accumulated_relevant_chunks"])
    print(f"  Accumulated Relevant Chunks: {chunk_total} chunks")
    unanswered = final_state_summary["unanswerable_sub_queries"]
    if unanswered:
        print(f"  Unanswerable Sub-Queries: {unanswered}")

    print("\n--- GENERATED ANSWER ---")
    print(final_state_summary["report_formatted"])
    print("\n------------------------")

    print(f"\nFinal Next Agent: {final_state_summary['next_agent_to_call']}")