### Import statements

In [None]:
import os
from typing import List, Dict, TypedDict, Literal
from dotenv import load_dotenv

from langchain_core.documents import Document
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from langgraph.graph import StateGraph, END

load_dotenv()

### Configuration

In [None]:
# Directory passed to Chroma as its persist directory, and the collection queried inside it.
CHROMA_DB_DIR = "./chroma_db"
COLLECTION_NAME = "customer_support_knowledge"

# Gemini credentials are read from the GEMINI_API_KEY environment variable (None when unset).
GOOGLE_API_KEY = os.environ.get("GEMINI_API_KEY")

### Initializing LLM and embedding model

In [None]:
# Chat model used by the evaluator node; a low temperature is configured for it.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash",
    temperature=0.2,
    google_api_key=GOOGLE_API_KEY,
)

# Embedding model handed to Chroma when the persisted collection is opened.
embeddings = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001",
    google_api_key=GOOGLE_API_KEY,
)

### Defining agent state

In [None]:
class AgentState(TypedDict):
    """
    State shared by the retriever and evaluator nodes while a single
    sub-query is being processed.
    """
    # The sub-query the nodes are currently working on.
    current_sub_query: str
    # Documents produced by the most recent retrieval.
    retrieved_chunks: List[Document]
    # Evaluator verdict: True once the chunks were judged sufficient.
    evaluated_sufficiency: bool
    # Free-text explanation accompanying the verdict (or a failure message).
    evaluator_feedback: str
    # Number of retrieval passes performed so far.
    retrieval_attempts: int
    # Routing hint written by each node for the step that should follow.
    next_agent_to_call: Literal[
        'retriever_agent',
        'evaluator_agent',
        'END_PHASE2',
        'END_PHASE2_FAILURE',
    ]

### Retriever Agent

In [None]:
def retriever_agent_node(state: AgentState) -> AgentState:
    """
    Retrieval node: query the persisted Chroma collection with the current
    sub-query and pass the resulting documents on to the evaluator.

    Reads ``current_sub_query`` and ``retrieval_attempts`` (treated as 0 when
    absent) from the state. The returned state always has the attempt counter
    incremented by one and the evaluation fields reset. Any exception raised
    while opening the store or querying it is caught: the returned state then
    has no chunks, carries the error text in ``evaluator_feedback`` and sets
    ``next_agent_to_call`` to 'END_PHASE2_FAILURE'.
    """
    print("---RETRIEVER AGENT: Initiating retrieval---")
    sub_query = state["current_sub_query"]
    attempt_number = 1 + state.get("retrieval_attempts", 0)

    def as_state(chunks, next_agent, feedback):
        # Success and failure share one shape; only these three fields vary.
        return {
            "current_sub_query": sub_query,
            "retrieved_chunks": chunks,
            "retrieval_attempts": attempt_number,
            "next_agent_to_call": next_agent,
            "evaluated_sufficiency": False,
            "evaluator_feedback": feedback,
        }

    try:
        # Open the existing on-disk collection; the embedding function has to
        # be the same one the collection was built with.
        store = Chroma(
            persist_directory=CHROMA_DB_DIR,
            embedding_function=embeddings,
            collection_name=COLLECTION_NAME,
        )
        print(f"ChromaDB loaded successfully from {CHROMA_DB_DIR}.")

        # Number of documents requested per query. Fixed for now; the
        # evaluator's feedback is not yet used to tune it.
        top_k = 1
        chunk_retriever = store.as_retriever(search_kwargs={"k": top_k})

        print(f"Retrieving for query: '{sub_query}' (Attempt: {attempt_number})")
        found = chunk_retriever.invoke(sub_query)

        print(f"---RETRIEVER AGENT: Retrieved {len(found)} chunks.---")
    except Exception as e:
        print(f"---RETRIEVER AGENT ERROR: {e}---")
        # Retrieval could not be carried out at all: signal a hard failure.
        return as_state([], 'END_PHASE2_FAILURE', f"Retrieval failed: {e}")
    else:
        # A completed retrieval is always followed by an evaluation.
        return as_state(found, 'evaluator_agent', "")

### Evaluator agent

In [None]:
def evaluator_agent_node(state: AgentState) -> AgentState:
    """
    Node that uses Gemini to evaluate if the retrieved chunks are sufficient to answer the sub-query.
    Provides feedback if not.

    Reads ``current_sub_query``, ``retrieved_chunks`` and ``retrieval_attempts``
    from the state and returns a full state dict whose ``next_agent_to_call`` is:

    - 'END_PHASE2' when the model answers YES;
    - 'retriever_agent' when the chunks are missing or judged insufficient and
      fewer than MAX_RETRIEVAL_ATTEMPTS retrievals have been made;
    - 'END_PHASE2_FAILURE' when the attempt budget is used up, or when the
      incoming state has no chunks and already carries 'END_PHASE2_FAILURE'
      (the retriever's hard-failure signal). In that last case the incoming
      ``evaluator_feedback`` is kept so the error message is not lost.

    The sub-query and chunk text are passed to the prompt as template
    variables, so braces inside them are treated as plain text.
    """
    print("---EVALUATOR AGENT: Evaluating retrieved chunks---")
    current_sub_query = state["current_sub_query"]
    retrieved_chunks = state["retrieved_chunks"]
    retrieval_attempts = state["retrieval_attempts"]

    # Single retry budget used by every branch below.
    MAX_RETRIEVAL_ATTEMPTS = 2

    def build_state(is_sufficient, feedback, next_step):
        # Every exit returns the same shape; only these three fields vary.
        return {
            "current_sub_query": current_sub_query,
            "retrieved_chunks": retrieved_chunks,
            "evaluated_sufficiency": is_sufficient,
            "evaluator_feedback": feedback,
            "retrieval_attempts": retrieval_attempts,
            "next_agent_to_call": next_step,
        }

    # Defensive guard: the budget is already exceeded, so skip the LLM call.
    if retrieval_attempts > MAX_RETRIEVAL_ATTEMPTS:
        print("---EVALUATOR AGENT: Max retrieval attempts reached. Marking as insufficient.---")
        return build_state(
            False,
            "Max retrieval attempts reached. Marking as insufficient.",
            'END_PHASE2_FAILURE',
        )

    if not retrieved_chunks:
        # The retriever marks a hard failure (exception while querying) with
        # 'END_PHASE2_FAILURE'; honour it instead of asking for a retry, and
        # keep its error message.
        if state.get("next_agent_to_call") == 'END_PHASE2_FAILURE':
            print("---EVALUATOR AGENT: Retriever reported a failure. Ending phase for this query.---")
            return build_state(
                False,
                state.get("evaluator_feedback") or "Retrieval failed.",
                'END_PHASE2_FAILURE',
            )

        print("---EVALUATOR AGENT: No chunks retrieved, marking as insufficient.---")
        if retrieval_attempts >= MAX_RETRIEVAL_ATTEMPTS:
            print(f"---EVALUATOR AGENT: Max retrieval attempts ({MAX_RETRIEVAL_ATTEMPTS}) reached. Ending phase for this query.---")
            empty_next_step = 'END_PHASE2_FAILURE'
        else:
            empty_next_step = 'retriever_agent'
        return build_state(
            False,
            "No relevant documents were retrieved for this sub-query. Try rephrasing the sub-query or check if the information exists in the knowledge base.",
            empty_next_step,
        )

    # Format chunks for the LLM prompt
    formatted_chunks = "\n---\n".join(doc.page_content for doc in retrieved_chunks)

    # Define the evaluation prompt for Gemini. The human message is a plain
    # template (not an f-string): the values are supplied at invoke time, so
    # any '{' or '}' in the query or chunk text cannot be mistaken for a
    # template variable.
    eval_prompt = ChatPromptTemplate.from_messages([
        ("system",
         """You are an expert evaluator for a RAG system. Your task is to determine if the provided 'CONTEXT' is sufficient and relevant to fully and comprehensively answer the 'SUB-QUERY'.

         If the CONTEXT is sufficient and relevant, respond ONLY with the word 'YES'.
         If the CONTEXT is NOT sufficient or relevant, respond ONLY with the word 'NO', followed by a concise, specific suggestion on how to improve the retrieval for this sub-query. For example:
         - 'NO: The results are too general; focus on error codes.'
         - 'NO: No specific instructions found for this type of issue.'
         - 'NO: The context mentions the topic but lacks actionable steps.'

         Aim for direct, actionable feedback. Do not elaborate beyond 'NO: [feedback]'.
         """),
        ("human",
         "SUB-QUERY: {current_sub_query}\n\nCONTEXT:\n{formatted_chunks}")
    ])

    eval_chain = eval_prompt | llm | StrOutputParser()

    print(f"---EVALUATOR AGENT: Sending evaluation request to Gemini for query '{current_sub_query}'---")
    gemini_response = eval_chain.invoke({"current_sub_query": current_sub_query, "formatted_chunks": formatted_chunks})

    gemini_response = gemini_response.strip()
    print(f"---EVALUATOR AGENT: Gemini's raw response: {gemini_response}---")

    is_sufficient = gemini_response.upper().startswith("YES")
    feedback = ""
    if not is_sufficient:
        verdict_word, remainder = gemini_response[:2], gemini_response[2:]
        if verdict_word.upper() == "NO" and not remainder[:1].isalnum():
            # "NO", "NO:", "No - ..." etc.: drop the verdict word and any
            # separator punctuation, keep the suggestion.
            feedback = remainder.lstrip(" \t\r\n:.,;-").strip()
        else:
            # Reply did not follow the YES/NO format; keep it whole rather
            # than chopping off its first characters.
            feedback = gemini_response

    print(f"---EVALUATOR AGENT: Sufficiency: {is_sufficient}, Feedback: '{feedback}'---")

    next_step: Literal['retriever_agent', 'evaluator_agent', 'END_PHASE2', 'END_PHASE2_FAILURE']

    if is_sufficient:
        next_step = 'END_PHASE2' # Good enough, proceed to next phase (or final answer generation)
    elif retrieval_attempts >= MAX_RETRIEVAL_ATTEMPTS:
        print(f"---EVALUATOR AGENT: Max retrieval attempts ({MAX_RETRIEVAL_ATTEMPTS}) reached. Ending phase for this query.---")
        next_step = 'END_PHASE2_FAILURE' # Couldn't find sufficient info after retries
    else:
        next_step = 'retriever_agent' # Not sufficient, try retrieval again (perhaps with refined query in future phases)

    return build_state(is_sufficient, feedback, next_step)

In [None]:
# Build the Phase 2 graph: two nodes (retriever, evaluator) that share AgentState.
workflow = StateGraph(AgentState)

# Add nodes
workflow.add_node("retriever_agent", retriever_agent_node)
workflow.add_node("evaluator_agent", evaluator_agent_node)

# Set entry point
workflow.set_entry_point("retriever_agent")

# Define edges (transitions)
# NOTE(review): this edge is unconditional, so the evaluator runs after every
# retrieval -- including one where the retriever set 'next_agent_to_call' to
# 'END_PHASE2_FAILURE'. That value is not consulted at this point in the graph.
workflow.add_edge("retriever_agent", "evaluator_agent") # Retriever always goes to Evaluator

# Conditional edge from evaluator_agent
def decide_next_step(state: AgentState) -> Literal['retriever_agent', 'END_PHASE2', 'END_PHASE2_FAILURE']:
    """
    Routing function for the conditional edge leaving the evaluator.

    Returns 'END_PHASE2' when the state is marked sufficient,
    'END_PHASE2_FAILURE' when the evaluator recorded that value in
    'next_agent_to_call', and 'retriever_agent' in every other case.
    """
    # A positive verdict wins outright; the routing hint is not read at all.
    if state["evaluated_sufficiency"]:
        return 'END_PHASE2'

    # Otherwise either propagate a recorded failure or loop back to the
    # retriever. A future phase could route to a query-refiner node here.
    gave_up = state["next_agent_to_call"] == 'END_PHASE2_FAILURE'
    return 'END_PHASE2_FAILURE' if gave_up else 'retriever_agent'

# Route out of the evaluator: decide_next_step returns one of the keys below,
# and the mapping says which node runs next for that key (END stops the graph).
workflow.add_conditional_edges(
    "evaluator_agent",
    decide_next_step,
    {
        "retriever_agent": "retriever_agent",  # Loop back for another retrieval attempt
        "END_PHASE2": END,                     # Exit the graph if sufficient
        "END_PHASE2_FAILURE": END              # Exit if attempts exhausted or critical failure
    }
)

# Compile the graph
app = workflow.compile()

# Bare expression as the last line of the cell -- presumably left here so the
# notebook displays the compiled graph object as the cell output.
app

In [None]:
print("\n--- Running Phase 2: Core Retrieval & Evaluation Loop ---")

# Test with a sub-query
test_sub_query = "My QuantumFlow purifier is showing a red light on its filter status indicator. What does this mean, and what should I do"
# test_sub_query = "What is the capital of France?" # Example of an out-of-scope query

# Longest chunk excerpt shown in the summary; longer chunks are cut and suffixed with "...".
PREVIEW_CHARS = 200

initial_state: AgentState = {
    "current_sub_query": test_sub_query,
    "retrieved_chunks": [],
    "evaluated_sufficiency": False,
    "evaluator_feedback": "",
    "retrieval_attempts": 0,
    "next_agent_to_call": 'retriever_agent' # Initial call
}

# Run the graph, printing what each step yields so the progression is visible.
final_state = None
for s in app.stream(initial_state):
    print(f"\nCurrent State after node execution: {s}")
    final_state = s

print("\n--- Phase 2 Execution Complete ---")
if final_state:
    # Each streamed item is a dict of {node_name: state_update} for the node
    # that just ran (not an accumulated state). Both nodes in this graph return
    # every AgentState field, so the last item's value can be read as the
    # complete final state.
    last_node_name = list(final_state.keys())[-1]
    final_state_values = final_state[last_node_name]

    print(f"\nFinal State Summary for Query: '{final_state_values['current_sub_query']}'")
    print(f"  Evaluated Sufficient: {final_state_values['evaluated_sufficiency']}")
    print(f"  Retrieval Attempts: {final_state_values['retrieval_attempts']}")
    print(f"  Evaluator Feedback: '{final_state_values['evaluator_feedback']}'")

    if final_state_values['evaluated_sufficiency']:
        print("  Retrieved Chunks (Content Preview):")
        for i, chunk in enumerate(final_state_values['retrieved_chunks'], start=1):
            content = chunk.page_content
            # Show an actual preview: truncate long chunks and only add the
            # ellipsis when something was cut off.
            if len(content) > PREVIEW_CHARS:
                preview = f"{content[:PREVIEW_CHARS]}..."
            else:
                preview = content
            print(f"    Chunk {i} (Source: {chunk.metadata.get('source', 'N/A')}): {preview}")
    else:
        print("  Could not find sufficient information for this query.")