# Agent-Based RAG with LangGraph

This notebook demonstrates an agent-based approach to RAG (Retrieval Augmented Generation) using LangGraph.
The agent can dynamically decide when to search for information and can make multiple searches if needed.

In [56]:
# Cell 0: Tool Definition
from confluence_rag_integration import query_customer
from langchain.tools import tool
from typing import List, Optional
import json

@tool
def retrieve_knowledge(query: str, customer_id: str, top_k) -> str:
    """
    Search and retrieve relevant documentation from Confluence knowledge base.
    Use this when you need to find information about products, procedures, 
    troubleshooting steps, or any documented knowledge to help answer questions.
    
    Args:
        query: The search query or question to find relevant documents
        customer_id: The customer ID to search within (default: acme_corp)
        top_k: Number of top results to return (default: 3)
    Returns:
        A formatted string containing the retrieved documents
    """
    result = query_customer(customer_id, query, top_k)
    
    # Format results for the agent
    formatted_results = []
    for i, doc in enumerate(result.documents, 1):
        source = doc.get("source", "Unknown")
        content = doc.get("content", "")
        formatted_results.append(f"Document {i}:\nSource: {source}\n{content}")
    
    if not formatted_results:
        return "No relevant documents found for the query."
    
    return "\n\n---\n\n".join(formatted_results)

In [57]:
# Cell 1: Agent State Definition
from typing import TypedDict, Annotated, List, Optional
import operator
from langchain_core.messages import BaseMessage

class AgentState(TypedDict):
    """State for the RAG agent."""
    messages: Annotated[List[BaseMessage], operator.add]
    customer_id: str
    thread_id: Optional[str]  # For conversation persistence

In [58]:
# Cell 2: Initialize LLM with Tool Binding
from langchain.chat_models import init_chat_model

# Initialize the LLM
llm = init_chat_model("gemini-2.5-flash", model_provider="google_genai")

# Bind the retrieval tool to the LLM
tools = [retrieve_knowledge]
llm_with_tools = llm.bind_tools(tools)

In [59]:
# Cell 3: Define Graph Nodes
from langchain_core.messages import ToolMessage

def agent_node(state: AgentState) -> dict:
    """Agent node that decides whether to search or respond."""
    messages = state["messages"]
    
    # Add system message for better agent behavior
    system_prompt = """You are a helpful support agent with access to a knowledge retrieval tool.
    Use the retrieve_knowledge tool to search for relevant documentation when answering questions.
    You can call the tool multiple times if needed to gather comprehensive information.
    Always search for information before providing an answer unless the question is clearly conversational."""
    
    # Invoke LLM with tools
    response = llm_with_tools.invoke([
        {"role": "system", "content": system_prompt},
        *messages
    ])
    
    return {"messages": [response]}

def tool_node(state: AgentState) -> dict:
    """Execute tool calls from the last message."""
    last_message = state["messages"][-1]
    tool_messages = []
    
    # Execute each tool call
    for tool_call in last_message.tool_calls:
        tool_name = tool_call["name"]
        tool_args = tool_call["args"]
        
        # Add customer_id to args if not present
        if "customer_id" not in tool_args:
            tool_args["customer_id"] = state.get("customer_id", "acme_corp")
        
        # Find and execute the tool
        if tool_name == "retrieve_knowledge":
            result = retrieve_knowledge.invoke(tool_args)
        else:
            result = f"Error: Unknown tool {tool_name}"
        
        # Create tool message
        tool_message = ToolMessage(
            content=result,
            tool_call_id=tool_call["id"],
            name=tool_name
        )
        tool_messages.append(tool_message)
    
    return {"messages": tool_messages}

In [60]:
# Cell 4: Define Conditional Routing
def should_continue(state: AgentState) -> str:
    """Decide whether to call tools or end."""
    last_message = state["messages"][-1]
    
    # Check if the last message has tool calls
    if hasattr(last_message, "tool_calls") and last_message.tool_calls:
        return "tools"
    return "end"

In [None]:
# Cell 5: Graph Construction
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode

memory = MemorySaver()

graph_builder = StateGraph(agent_node)

graph_builder.add_node("preprocess_query", agent_node)
graph_builder.add_node("retrieve_knowledge", retrieve_knowledge)
tools = ToolNode([retrieve_knowledge])
graph_builder.add_node("synthesize_response", synthesize_response)

graph_builder.set_entry_point("preprocess_query")
graph_builder.add_edge("preprocess_query", "retrieve_knowledge")
graph_builder.add_edge("retrieve_knowledge", "synthesize_response")
graph_builder.add_edge("synthesize_response", END)


In [61]:
# Cell 6: Usage Examples
from langchain_core.messages import HumanMessage

# Example 1: Single query
print("=== Example 1: Password Reset Query ===")
result = app.invoke({
    "messages": [HumanMessage(content="How do I reset my password?")],
    "customer_id": "acme_corp"
}, {"configurable": {"thread_id": "example-1"}})

# Print the final response
print("\nAgent Response:")
result["messages"][-1].pretty_print()

# Show tool calls made
print("\nTool calls made:")
for msg in result["messages"]:
    if hasattr(msg, "tool_calls") and msg.tool_calls:
        for call in msg.tool_calls:
            print(f"- {call['name']}: {call['args']['query']}")

=== Example 1: Password Reset Query ===

Agent Response:

To reset your password, please follow the steps for the relevant account type:

**For Business Systems (SSO) Password Reset:**

1.  **Start a Password Reset:**
    *   Go to [iam.ucsd.edu/ResetPassword](https://iam.ucsd.edu/ResetPassword/)
    *   Enter your User ID (full UCSD Email address or your mainframe ID).
    *   Click **Continue**.
2.  **Validate your Identity:**
    *   Enter your Employee ID number (without leading zeros, found on your campus paycheck stub or UCPath).
    *   Enter the last 4 digits of your Social Security number.
    *   Enter your Date of Birth in *mm/dd/yyyy* format.
    *   Click **Continue**.
3.  **Confirm Password Reset:**
    *   Check your UCSD email tied to your Business Systems account for an email containing a temporary password. Your Department Security Administrator (DSA) will also be notified.
4.  **Set a New Password:**
    *   Go to [iam.ucsd.edu/ResetPassword/login](https://iam.ucsd.e

In [62]:
# Cell 6: Usage Examples
from langchain_core.messages import HumanMessage

# Example 1: Single query
print("=== Example 1: Password Reset Query ===")
result = app.invoke({
    "messages": [HumanMessage(content="How do I reset my password?")],
    "customer_id": "acme_corp"
}, {"configurable": {"thread_id": "support-456"}})

# Print the final response
print("\nAgent Response:")
result["messages"][-1].pretty_print()

# Show tool calls made
print("\nTool calls made:")
for msg in result["messages"]:
    if hasattr(msg, "tool_calls") and msg.tool_calls:
        for call in msg.tool_calls:
            print(f"- {call['name']}: {call['args']['query']}")

=== Example 1: Password Reset Query ===

Agent Response:

I can help you with that! Please specify which type of account you need to reset the password for:

*   **Business Systems (Single Sign-On) Account:**
    1.  Go to [iam.ucsd.edu/ResetPassword](https://iam.ucsd.edu/ResetPassword/)
    2.  Enter your User ID (full UCSD Email address or mainframe ID).
    3.  Click **Continue**.
    4.  Validate your identity by entering your Employee ID number, last 4 digits of your Social Security number, and Date of Birth (mm/dd/yyyy).
    5.  Click **Continue**.
    6.  Check your UCSD email for a temporary password.
    7.  Go to [iam.ucsd.edu/ResetPassword/login](https://iam.ucsd.edu/ResetPassword/login) and log in with the temporary password.
    8.  You will be prompted to create a new, permanent password. Your new password must be 6-8 characters, include a mix of upper/lower case, numbers, and symbols (@, #, $), and not contain dictionary words, your name, or username, and be different fr

In [63]:
# Cell 8: Conversation with History
print("=== Example 3: Multi-turn Conversation ===")

# Use config with thread_id - MemorySaver handles conversation history
config = {"configurable": {"thread_id": "support-456"}}

# Simulate a conversation
queries = [
    "What authentication methods do you support?",
    "Can you tell me more about SSO setup?",
    "What are the common SAML errors?"
]

for query in queries:
    print(f"\n👤 User: {query}")
    
    # Invoke with just the new message - MemorySaver handles history
    result = app.invoke({
        "messages": [HumanMessage(content=query)],
        "customer_id": "acme_corp"
    }, config)
    
    # Print agent response
    print("\n🤖 Agent:")
    # Find the last AI message (not a tool call)
    for msg in reversed(result["messages"]):
        if msg.type == "ai" and not hasattr(msg, "tool_calls"):
            print(msg.content[:500] + "..." if len(msg.content) > 500 else msg.content)
            break

=== Example 3: Multi-turn Conversation ===

👤 User: What authentication methods do you support?

🤖 Agent:

👤 User: Can you tell me more about SSO setup?

🤖 Agent:

👤 User: What are the common SAML errors?

🤖 Agent:


In [64]:
# Cell 9: Analyze Agent Behavior
print("=== Agent Behavior Analysis ===")

# Test different query types
test_queries = [
    "Hello!",  # Greeting - should not trigger search
    "How do I configure LDAP?",  # Technical - should search
    "What's the weather today?",  # Off-topic - might not search
    "I'm having login issues with error code 403"  # Specific error - should search
]

for i, query in enumerate(test_queries):
    print(f"\nQuery: '{query}'")
    
    # Each test gets its own thread to avoid state pollution
    result = app.invoke({
        "messages": [HumanMessage(content=query)],
        "customer_id": "acme_corp"
    }, {"configurable": {"thread_id": f"test-{i}"}})
    
    # Count tool calls
    tool_calls = 0
    for msg in result["messages"]:
        if hasattr(msg, "tool_calls") and msg.tool_calls:
            tool_calls += len(msg.tool_calls)
    
    print(f"Tool calls made: {tool_calls}")
    
    # Show brief response
    final_msg = result["messages"][-1]
    brief_response = final_msg.content[:100] + "..." if len(final_msg.content) > 100 else final_msg.content
    print(f"Response preview: {brief_response}")

=== Agent Behavior Analysis ===

Query: 'Hello!'
Tool calls made: 0
Response preview: Hello! How can I assist you today?

Query: 'How do I configure LDAP?'
Tool calls made: 1
Response preview: I am sorry, but I could not find any information on how to configure LDAP in the knowledge base. The...

Query: 'What's the weather today?'
Tool calls made: 0
Response preview: I apologize, but I cannot provide real-time weather information as my capabilities are limited to re...

Query: 'I'm having login issues with error code 403'
Tool calls made: 1
Response preview: Error code 403 often indicates an access forbidden issue, which could mean your account is locked or...


In [65]:
# Cell 10: Memory Inspection
print("=== Memory Inspection ===")

# Show conversation history for a thread
thread_id = "support-456"
state = app.get_state({"configurable": {"thread_id": thread_id}})

if state.values:
    print(f"\nConversation history for thread '{thread_id}':")
    messages = state.values.get("messages", [])
    
    for i, msg in enumerate(messages):
        if msg.type == "human":
            print(f"\n{i+1}. 👤 User: {msg.content}")
        elif msg.type == "ai" and not hasattr(msg, "tool_calls"):
            preview = msg.content[:150] + "..." if len(msg.content) > 150 else msg.content
            print(f"   🤖 Agent: {preview}")
        elif msg.type == "ai" and hasattr(msg, "tool_calls"):
            print(f"   🔍 Agent searched for: {', '.join([call['args']['query'] for call in msg.tool_calls])}")
        elif msg.type == "tool":
            print(f"   📚 Retrieved {len(msg.content)} characters of documentation")
    
    print(f"\nTotal messages in thread: {len(messages)}")
else:
    print(f"\nNo conversation history found for thread '{thread_id}'")

=== Memory Inspection ===

Conversation history for thread 'support-456':

1. 👤 User: How do I reset my password?
   🔍 Agent searched for: reset password
   📚 Retrieved 5452 characters of documentation
   🔍 Agent searched for: 

5. 👤 User: What authentication methods do you support?
   🔍 Agent searched for: authentication methods
   📚 Retrieved 3055 characters of documentation
   🔍 Agent searched for: 

9. 👤 User: Can you tell me more about SSO setup?
   🔍 Agent searched for: SSO setup
   📚 Retrieved 2322 characters of documentation
   🔍 Agent searched for: 

13. 👤 User: What are the common SAML errors?
   🔍 Agent searched for: SAML errors
   📚 Retrieved 2554 characters of documentation
   🔍 Agent searched for: 

Total messages in thread: 16


In [None]:
# Cell 9: Analyze Agent Behavior
print("=== Agent Behavior Analysis ===")

# Test different query types
test_queries = [
    "Hello!",  # Greeting - should not trigger search
    "How do I configure LDAP?",  # Technical - should search
    "What's the weather today?",  # Off-topic - might not search
    "I'm having login issues with error code 403"  # Specific error - should search
]

for query in test_queries:
    print(f"\nQuery: '{query}'")
    
    result = app.invoke({
        "messages": [HumanMessage(content=query)],
        "customer_id": "acme_corp"
    }, {"configurable": {"thread_id": "example-1"}})
    
    # Count tool calls
    tool_calls = 0
    for msg in result["messages"]:
        if hasattr(msg, "tool_calls") and msg.tool_calls:
            tool_calls += len(msg.tool_calls)
    
    print(f"Tool calls made: {tool_calls}")
    
    # Show brief response
    final_msg = result["messages"][-1]
    brief_response = final_msg.content[:100] + "..." if len(final_msg.content) > 100 else final_msg.content
    print(f"Response preview: {brief_response}")

=== Agent Behavior Analysis ===

Query: 'Hello!'


ValueError: Checkpointer requires one or more of the following 'configurable' keys: thread_id, checkpoint_ns, checkpoint_id