In [1]:
pip install langgraph langchain langchain-community tavily-python pydantic

Note: you may need to restart the kernel to use updated packages.


In [2]:
from langgraph.graph import StateGraph, END
from typing import TypedDict, Sequence, Annotated
from langchain.schema import BaseMessage, HumanMessage, AIMessage
from langchain.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.tools.tavily_search import TavilySearchResults
from pydantic import BaseModel, Field

In [3]:
import operator
import os

In [4]:
from dotenv import load_dotenv
load_dotenv()

True

In [5]:
os.environ.get('TAVILY_API_KEY')

'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'

In [6]:
# State definition
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    current_agent: str
    validation_attempts: int
    final_output: str

#### Pydantic models for structured output

In [7]:
class TopicSelectionParser(BaseModel):
    Topic: str = Field(description="selected topic")
    Reasoning: str = Field(description='Reasoning behind topic selection')

In [8]:
class SupervisorDecision(BaseModel):
    next_agent: str = Field(description="Next agent to call: RAG, LLM, or Web Crawler")
    reasoning: str = Field(description="Reasoning for the decision")

In [9]:
class ValidationResult(BaseModel):
    is_valid: bool = Field(description="Whether the output is valid")
    feedback: str = Field(description="Feedback on the output")
    confidence_score: float = Field(description="Confidence score (0-1)")

#### Model Config

In [10]:
from langchain_google_genai import ChatGoogleGenerativeAI
model=ChatGoogleGenerativeAI(model='gemini-1.5-flash')

In [11]:
from langchain_huggingface import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en")

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter

In [13]:
loader=DirectoryLoader("../data2",glob="./*.txt",loader_cls=TextLoader)

In [14]:
docs=loader.load()

In [15]:
text_splitter=RecursiveCharacterTextSplitter(
    chunk_size=200,
    chunk_overlap=50
)

In [16]:
new_docs=text_splitter.split_documents(documents=docs)

In [17]:
new_docs

[Document(metadata={'source': '..\\data2\\usa.txt'}, page_content='ðŸ‡ºðŸ‡¸ Overview of the U.S. Economy'),
 Document(metadata={'source': '..\\data2\\usa.txt'}, page_content='The United States of America possesses the largest economy in the world in terms of nominal GDP, making it the most powerful economic force globally. It operates under a capitalist mixed economy,'),
 Document(metadata={'source': '..\\data2\\usa.txt'}, page_content='It operates under a capitalist mixed economy, where the private sector dominates, but the government plays a significant regulatory and fiscal role. With a population of over 335 million people and a'),
 Document(metadata={'source': '..\\data2\\usa.txt'}, page_content='a population of over 335 million people and a high level of technological advancement, the U.S. economy thrives on a foundation of consumer spending, innovation, global trade, and financial services.'),
 Document(metadata={'source': '..\\data2\\usa.txt'}, page_content='innovation, global 

In [18]:
doc_string=[doc.page_content for doc in new_docs]

In [19]:
pip install ChromaDB

Note: you may need to restart the kernel to use updated packages.


In [20]:
db=Chroma.from_documents(new_docs,embeddings)

In [21]:
retriever=db.as_retriever(search_kwargs={"k": 3})

In [61]:
new_docs[3]

Document(metadata={'source': '..\\data2\\usa.txt'}, page_content='a population of over 335 million people and a high level of technological advancement, the U.S. economy thrives on a foundation of consumer spending, innovation, global trade, and financial services.')

In [22]:
def format_docs(docs):
    """Helper function to format documents for RAG"""
    return "\n\n".join(doc.page_content for doc in docs)

In [23]:
# 1. Supervisor Node
def supervisor_node(state: AgentState):
    """
    Supervisor decides which agent to call based on the query and current state
    """
    print("-> RUNNING SUPERVISOR NODE ->")
    
    # Get the original query
    if isinstance(state["messages"][0], str):
        query = state["messages"][0]
    else:
        query = state["messages"][0].content
    
    # Get validation feedback if available
    validation_feedback = ""
    if len(state["messages"]) > 1:
        for msg in state["messages"]:
            if isinstance(msg, str) and "validation failed" in msg.lower():
                validation_feedback = msg
                break
    
    # Supervisor decision logic
    supervisor_prompt = f"""
    You are a supervisor agent that decides which specialized agent should handle a query.
    
    Available agents:
    1. RAG - For questions that require specific document knowledge or factual retrieval
    2. LLM - For general knowledge questions, reasoning, or creative tasks
    3. Web Crawler - For real-time information, current events, or recent data
    
    Query: {query}
    Validation Feedback: {validation_feedback}
    Validation Attempts: {state.get('validation_attempts', 0)}
    
    Decide which agent should handle this query. Consider:
    - If it's about current events or real-time data -> Web Crawler
    - If it requires specific document knowledge -> RAG
    - If it's general knowledge or reasoning -> LLM
    - If validation failed, consider trying a different approach
    
    Respond with just the agent name: RAG, LLM, or Web Crawler
    """
    
    # Simple decision logic (replace with your actual model call)
    decision = router(state)  # Use the existing router function
    
    new_state = {
        "current_agent": decision,
        "messages": state["messages"]
    }
    
    return new_state


In [24]:
# 2. Router Function 
def router(state: AgentState):
    """
    Enhanced router function that decides which agent to call
    """
    print("-> ROUTER ->")
    
    if isinstance(state["messages"][0], str):
        query = state["messages"][0].lower()
    else:
        query = state["messages"][0].content.lower()
    
    # Enhanced routing logic
    if any(keyword in query for keyword in ["current", "latest", "recent", "today", "news", "real-time"]):
        return "Web Crawler"
    elif any(keyword in query for keyword in ["document", "specific", "usa", "policy", "regulation"]):
        return "RAG Call"
    else:
        return "LLM Call"

In [62]:
# 3. RAG Node 
def rag_node(state: AgentState):
    """
    RAG node for document-based question answering
    """
    print("-> RAG NODE ->")
    
    if isinstance(state["messages"][0], str):
        question = state["messages"][0]
    else:
        question = state["messages"][0].content
    
    # RAG processing (replace with your actual RAG implementation)
    try:
        prompt = PromptTemplate(
            template="""You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.

Question: {question}
Context: {context}
Answer:""",
            input_variables=['context', 'question']
        )
        
        # Placeholder for actual RAG chain
        rag_chain =  (
        {
            "context": retriever | format_docs,
            "question": RunnablePassthrough()
        }
        | prompt
        | model
        | StrOutputParser()
    )
        result = rag_chain.invoke(question)
        
    #    debug_rag_node(question, result, state)
        docs = retriever.invoke(question)
        print("Retrieved docs for RAG:", [doc.page_content for doc in docs])  # <-- Add this line
        
        
        return {"messages": [AIMessage(content=result)]}
    
    except Exception as e:
        return {"messages": [AIMessage(content=f"RAG error: {str(e)}")]}

In [26]:
# 4. LLM Node 
def llm_node(state: AgentState):
    """
    LLM node for general knowledge and reasoning
    """
    print("-> LLM NODE ->")
    
    if isinstance(state["messages"][0], str):
        question = state["messages"][0]
    else:
        question = state["messages"][0].content
    
    try:
        # Enhanced LLM call
        complete_query = f"""Answer the following question using your knowledge. Provide a comprehensive and accurate response.

Question: {question}

Answer:"""
        
        response = model.invoke(complete_query)
        result = response.content
        
        
        return {"messages": [AIMessage(content=result)]}
    
    except Exception as e:
        return {"messages": [AIMessage(content=f"LLM error: {str(e)}")]}

In [46]:
def debug_rag_node(state: AgentState):
    """
    RAG node with debugging to see what's being retrieved
    """
    print("-> RAG NODE (DEBUG) ->")
    
    if isinstance(state["messages"][0], str):
        question = state["messages"][0]
    else:
        question = state["messages"][0].content
    
    try:
        # DEBUG: See what chunks are being retrieved
        retrieved_docs = retriever.get_relevant_documents(question)
        print(f"\n=== RETRIEVED {len(retrieved_docs)} CHUNKS ===")
        for i, doc in enumerate(retrieved_docs):
            print(f"Chunk {i+1}: {doc.page_content[:200]}...")
            print(f"Metadata: {doc.metadata}")
            print("-" * 50)
        
        # Format the context
        context = format_docs(retrieved_docs)
        print(f"\n=== FORMATTED CONTEXT ===")
        print(f"Context length: {len(context)} characters")
        print(f"Context preview: {context[:300]}...")
        
        # Create prompt
        prompt = PromptTemplate(
            template="""You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.

Question: {question}
Context: {context}
Answer:""",
            input_variables=['context', 'question']
        )
        
        # Get the final prompt
        final_prompt = prompt.format(question=question, context=context)
        print(f"\n=== FINAL PROMPT ===")
        print(final_prompt[:500] + "..." if len(final_prompt) > 500 else final_prompt)
        
        # Get response
        response = model.invoke(final_prompt)
        result = response.content
        
        print(f"\n=== MODEL RESPONSE ===")
        print(result)
        
        return {"messages": [AIMessage(content=result)]}
    
    except Exception as e:
        print(f"RAG error: {str(e)}")
        return {"messages": [AIMessage(content=f"RAG error: {str(e)}")]}


In [27]:
# 5. Web Crawler Node 
def web_crawler_node(state: AgentState):
    """
    Web crawler node for real-time information retrieval
    """
    print("-> WEB CRAWLER NODE ->")
    
    if isinstance(state["messages"][0], str):
        question = state["messages"][0]
    else:
        question = state["messages"][0].content
    
    try:
        # Web crawler implementation
        TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
        if not TAVILY_API_KEY:
            return {"messages": [AIMessage(content="Tavily API key not found. Please set TAVILY_API_KEY environment variable.")]}
        
        tool = TavilySearchResults(tavily_api_key=TAVILY_API_KEY)
        search_results = tool.invoke(question)
        
        # Process and format the results
        if search_results:
            formatted_results = []
            for result in search_results[:3]:  # Limit to top 3 results
                if isinstance(result, dict):
                    title = result.get('title', 'No title')
                    content = result.get('content', 'No content')
                    formatted_results.append(f"Title: {title}\nContent: {content}")
            
            final_result = f"Web search results for '{question}':\n\n" + "\n\n".join(formatted_results)
        else:
            final_result = f"No web search results found for: {question}"
        
        return {"messages": [AIMessage(content=final_result)]}
    
    except Exception as e:
        return {"messages": [AIMessage(content=f"Web crawler error: {str(e)}")]}

In [28]:
# 6. Validation Node
def validation_node(state: AgentState):
    """
    Validation node to check the quality and accuracy of generated output
    """
    print("-> VALIDATION NODE ->")
    
    # Get the last generated response
    last_message = state["messages"][-1]
    if isinstance(last_message, str):
        generated_output = last_message
    else:
        generated_output = last_message.content
    
    # Get original query
    if isinstance(state["messages"][0], str):
        original_query = state["messages"][0]
    else:
        original_query = state["messages"][0].content
    
    # Validation logic
    validation_prompt = f"""
    Validate the following generated output for the given query:
    
    Original Query: {original_query}
    Generated Output: {generated_output}
    
    Check for:
    1. Relevance to the query
    2. Accuracy of information
    3. Completeness of response
    4. Clarity and coherence
    
    Provide validation result (True/False) and feedback.
    """
    
    # Simple validation logic (replace with your actual validation)
    is_valid = True
    feedback = "Output validation passed"
    confidence_score = 0.8
    
    # Basic validation checks
    if len(generated_output) < 10:
        is_valid = False
        feedback = "Output too short"
        confidence_score = 0.2
    elif "error" in generated_output.lower():
        is_valid = False
        feedback = "Output contains errors"
        confidence_score = 0.3
    elif generated_output.strip() == "":
        is_valid = False
        feedback = "Empty output"
        confidence_score = 0.0
    
    validation_attempts = state.get("validation_attempts", 0) + 1
    
    if is_valid:
        return {
            "messages": [AIMessage(content=f"VALIDATION PASSED: {feedback}")],
            "final_output": generated_output,
            "validation_attempts": validation_attempts
        }
    else:
        return {
            "messages": [AIMessage(content=f"VALIDATION FAILED: {feedback}. Confidence: {confidence_score}")],
            "validation_attempts": validation_attempts
        }

In [29]:
# 7. Workflow Decision Functions
def should_continue(state: AgentState):
    """
    Decide whether to continue to validation or end
    """
    last_message = state["messages"][-1]
    if isinstance(last_message, str):
        content = last_message
    else:
        content = last_message.content
    
    # Check if this is a validation message
    if "VALIDATION" in content:
        return "end"
    else:
        return "validate"

def validation_decision(state: AgentState):
    """
    Decide what to do after validation
    """
    last_message = state["messages"][-1]
    if isinstance(last_message, str):
        content = last_message
    else:
        content = last_message.content
    
    validation_attempts = state.get("validation_attempts", 0)
    
    if "VALIDATION PASSED" in content:
        return "end"
    elif validation_attempts >= 3:  # Max 3 validation attempts
        return "end"
    else:
        return "supervisor"

In [30]:
# 8. Create and Configure Workflow
def create_workflow():
    """
    Create and configure the complete workflow
    """
    workflow = StateGraph(AgentState)
    
    # Add all nodes
    workflow.add_node("supervisor", supervisor_node)
    workflow.add_node("rag", rag_node)
    workflow.add_node("llm", llm_node)
    workflow.add_node("web_crawler", web_crawler_node)
    workflow.add_node("validation", validation_node)
    
    # Set entry point
    workflow.set_entry_point("supervisor")
    
    # Add conditional edges from supervisor
    workflow.add_conditional_edges(
        "supervisor",
        router,
        {
            "RAG Call": "rag",
            "LLM Call": "llm",
            "Web Crawler": "web_crawler"
        }
    )
    
    # Add edges from agent nodes to validation
    workflow.add_conditional_edges(
        "rag",
        should_continue,
        {
            "validate": "validation",
            "end": END
        }
    )
    
    workflow.add_conditional_edges(
        "llm",
        should_continue,
        {
            "validate": "validation",
            "end": END
        }
    )
    
    workflow.add_conditional_edges(
        "web_crawler",
        should_continue,
        {
            "validate": "validation",
            "end": END
        }
    )
    
    # Add conditional edges from validation
    workflow.add_conditional_edges(
        "validation",
        validation_decision,
        {
            "supervisor": "supervisor",
            "end": END
        }
    )
    
    return workflow.compile()

In [31]:
# 9. Main execution function
def run_workflow(query: str):
    """
    Run the complete workflow with a given query
    """
    app = create_workflow()
    
    # Initialize state
    initial_state = {
        "messages": [HumanMessage(content=query)],
        "current_agent": "",
        "validation_attempts": 0,
        "final_output": ""
    }
    
    # Run the workflow
    print(f"Starting workflow for query: {query}")
    print("-" * 5)
    
    result = app.invoke(initial_state)
    
    print("-" * 5)
    print("Workflow completed!")
    print(f"Final output: {result.get('final_output', 'No final output generated')}")
    
    return result

In [64]:
query = "What is the GDP of the USA?"

In [65]:
result = run_workflow(query)

Starting workflow for query: What is the GDP of the USA?
-----
-> SUPERVISOR NODE ->
-> ROUTER ->
-> ROUTER ->
-> RAG NODE ->
Retrieved docs for RAG: ['U.S. GDP â€“ Size, Composition, and Global Share', 'As of 2024, the United Statesâ€™ nominal GDP is estimated to be around $28 trillion USD, accounting for approximately 25% of the global economy. It ranks #1 in the world by nominal GDP, far ahead of', 'ðŸ‡ºðŸ‡¸ Overview of the U.S. Economy']
-> VALIDATION NODE ->
-----
Workflow completed!
Final output: The nominal GDP of the USA is estimated to be around $28 trillion USD as of 2024.  This represents about 25% of the global economy.  It holds the #1 ranking worldwide by nominal GDP.
