# Assignment 4: Multi-Agent System with Supervisor and Validation

## Project Overview
This project implements a multi-agent system using LangGraph that includes a supervisor node, multiple specialized nodes, and a validation mechanism with feedback loops.

## Architecture

### 1. Supervisor Node
- **Purpose**: Central coordinator that makes routing decisions
- **Responsibility**: Determines which node to call next based on the current state and requirements
- **Decision Logic**: Routes requests to appropriate specialized nodes

### 2. Router Function
- **Purpose**: Implements the routing logic for the supervisor
- **Functionality**: Analyzes input and context to decide the next step
- **Integration**: Works closely with the supervisor node

### 3. Specialized Nodes

#### 3.1 LLM Node
- **Purpose**: Handles language model calls for text generation
- **Functionality**: Processes natural language queries and generates responses
- **Use Cases**: Text completion, question answering, content generation

#### 3.2 RAG Node (Retrieval-Augmented Generation)
- **Purpose**: Combines retrieval and generation for enhanced responses
- **Functionality**: Retrieves relevant information from knowledge base and generates contextual responses
- **Components**: Document retriever + LLM generator

#### 3.3 Web Crawler Node
- **Purpose**: Fetches real-time information from the internet
- **Functionality**: Performs web searches, scrapes content, and extracts relevant data
- **Use Cases**: Current events, real-time data, up-to-date information

### 4. Validation Node
- **Purpose**: Validates the quality and accuracy of generated outputs
- **Validation Methods**:
  - Content relevance checking
  - Factual accuracy verification
  - Format and structure validation
  - Completeness assessment

### 5. Feedback Loop Mechanism
- **Flow**: If validation fails → Return to supervisor node
- **Decision**: Supervisor determines corrective action (retry with same node or switch to different node)
- **Iteration**: Process continues until validation passes

### 6. Final Output Generation
- **Condition**: Only generates final output after successful validation
- **Quality Assurance**: Ensures all outputs meet quality standards

## Workflow
```
Input → Supervisor Node → Router Function → Specialized Node (LLM/RAG/Web Crawler)
                ↑                                           ↓
                ←←←←←←← Validation Fails ←←←←← Validation Node
                                                          ↓
                                              Validation Passes
                                                          ↓
                                                   Final Output
```

## Implementation Goals
1. Build a robust supervisor-agent architecture
2. Implement intelligent routing between different processing nodes
3. Create specialized nodes for different types of information processing
4. Develop comprehensive validation mechanisms
5. Establish effective feedback loops for quality control
6. Ensure reliable final output generation

In [1]:
# importing all the packages needed for the code to run
import os
import json
import operator
from typing import List
from pydantic import BaseModel , Field
from langchain.prompts import PromptTemplate
from typing import TypedDict, Annotated, Sequence
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.graph import StateGraph,END
from langchain.output_parsers import PydanticOutputParser
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter

In [2]:
def setup_rag_system(data_directory: str, chunk_size: int = 200, chunk_overlap: int = 50, k: int = 3):
    """
    Set up a complete RAG (Retrieval-Augmented Generation) system.
    
    Args:
        data_directory (str): Path to the directory containing text files
        chunk_size (int): Size of text chunks for splitting documents
        chunk_overlap (int): Overlap between chunks
        k (int): Number of documents to retrieve
    
    Returns:
        tuple: (model, retriever) - Configured model and retriever
    """
    # Config the model
    model = ChatGoogleGenerativeAI(model='gemini-1.5-flash')
    
    # Config the embedding model
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en")
    
    # Load documents
    loader = DirectoryLoader(
        data_directory, 
        glob="./*.txt", 
        loader_cls=TextLoader
    )
    docs = loader.load()
    
    # Split documents
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )
    new_docs = text_splitter.split_documents(documents=docs)
    
    # Create vector database
    db = Chroma.from_documents(new_docs, embeddings)
    retriever = db.as_retriever(search_kwargs={"k": k})
    
    return model, retriever


In [3]:
class TopicSelectionParser(BaseModel):
    """Parser for topic selection with reasoning."""
    Topic: str = Field(description="selected topic")
    Reasoning: str = Field(description='Reasoning behind topic selection')

class AgentState(TypedDict):
    """Agent state for StateGraph management."""
    messages: Annotated[Sequence[BaseMessage], operator.add]
    validation_status: str  # "pending", "passed", "failed"
    next_action: str       # "supervisor", "final_output"

def create_topic_selection_system():
    """
    Create a topic selection parser and agent state for the LangGraph system.
    
    Returns:
        tuple: (parser, AgentState, format_instructions)
    """
    # Create parser
    parser = PydanticOutputParser(pydantic_object=TopicSelectionParser)
    
    # Get format instructions
    format_instructions = parser.get_format_instructions()
    
    return parser, AgentState, format_instructions

def format_docs(docs):
    """
    Format documents by joining their page content.
    
    Args:
        docs: List of document objects
        
    Returns:
        str: Formatted document content
    """
    return "\n\n".join(doc.page_content for doc in docs)

def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)



In [4]:
def function_1(state:AgentState):
    
    question=state["messages"][-1]
    
    print("Question",question)
    
    template="""
    Your task is to classify the given user query into one of the following categories: [USA,Not Related, Websearch]. 
    Only respond with the category name and nothing else.

    User query: {question}
    {format_instructions}
    """
    
    prompt= PromptTemplate(
        template=template,
        input_variable=["question"],
        partial_variables={"format_instructions": parser.get_format_instructions()}
    )
    
    
    chain= prompt | model | parser
    
    response = chain.invoke({"question":question})
    
    print("Parsed response:", response)
    
    return {"messages": [response.Topic]}

In [5]:
def enhanced_router(state: AgentState):
    """Enhanced router that handles validation flow"""
    
    # Check if we have validation status
    if "validation_status" in state:
        if state["validation_status"] == "failed":
            print("🔄 Validation failed - routing back to supervisor")
            return "supervisor"
        elif state["validation_status"] == "passed":
            print("✅ Validation passed - generating final output")
            return "final_output"
    
    # Original routing logic
    last_message = state["messages"][-1]
    if isinstance(last_message, dict):
        last_message = last_message.get("content", "")
    
    if "supervisor" in str(last_message).lower():
        return "supervisor"
    elif any(keyword in str(last_message).lower() for keyword in ["rag", "document", "file"]):
        return "rag"
    elif any(keyword in str(last_message).lower() for keyword in ["web", "search", "latest", "news"]):
        return "webcrawl"
    elif any(keyword in str(last_message).lower() for keyword in ["validation", "validate"]):
        return "validation"
    else:
        return "llm"

In [6]:
def router(state:AgentState):
    print("-> ROUTER ->")
    
    last_message=state["messages"][-1]
    print("last_message:", last_message)
    
    if "usa" in last_message.lower():
        return "RAG Call"
    elif "not related" in last_message.lower():
        return "LLM Call"
    else:
        return "WEBCRAWL Call"

In [7]:
# RAG Function
def function_2(state:AgentState):
    print("-> RAG Call ->")
    
    question = state["messages"][0]
    
    prompt=PromptTemplate(
        template="""You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\nQuestion: {question} \nContext: {context} \nAnswer:""",
        
        input_variables=['context', 'question']
    )
    
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )
    result = rag_chain.invoke(question)
    return  {"messages": [result]}

In [8]:
# LLM Function
def function_3(state:AgentState):
    print("-> LLM Call ->")
    question = state["messages"][0]
    
    # Normal LLM call
    complete_query = "Anwer the follow question with you knowledge of the real world. Following is the user question: " + question
    response = model.invoke(complete_query)
    return {"messages": [response.content]}

In [9]:
# WebCrawling Function
from dotenv import load_dotenv
load_dotenv()
API_KEY=os.getenv("TAVILY_API_KEY")
from langchain_community.tools.tavily_search import TavilySearchResults

def function_4(state:AgentState):
    print("-> WEBCRAWL Call ->")
    question = state["messages"][0]

    tool=TavilySearchResults(tavily_api_key=API_KEY)
    
    # Normal LLM call
    complete_query = "Anwer the follow question with you knowledge of the real world. Following is the user question: " + question
    response = tool.invoke({"query": complete_query})
    return {"messages": response}

In [10]:
# Enhanced function_4 with clickable links
def function_4(state: AgentState):
    print("-> Web Search & Summarization Started")
    
    # Get the question from state
    if isinstance(state["messages"], list) and len(state["messages"]) > 0:
        question = state["messages"][-1]
        print("Last message in state:", question)
        if isinstance(question, dict):
            question = question.get("content", "")
    else:
        question = str(state["messages"])
    
    try:
        # Search the web
        tool = TavilySearchResults(tavily_api_key=API_KEY, max_results=3)
        results = tool.invoke({"query": question})
        
        if results and len(results) > 0:
            # Collect content from multiple sources
            all_content = []
            sources = []
            
            for i, result in enumerate(results[:3]):  # Top 3 results
                content = result.get('content', '')
                title = result.get('title', f'Source {i+1}')
                url = result.get('url', '')
                
                if content:
                    all_content.append(content)
                    # Include clickable links in sources
                    if url:
                        sources.append(f"• [{title}]({url})")
                    else:
                        sources.append(f"• {title}")
            
            # Combine and summarize content
            combined_content = " ".join(all_content)
            summary = summarize_content(combined_content, question)
            
            # Format final answer with clickable links
            answer = f"""📰 **Web Search Summary**

**Question:** {question}

**Summary:**
{summary}

**Sources with Links:**
{chr(10).join(sources[:3])}

---
*Search completed at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*
*Click on source titles to visit the original articles*"""
            
        else:
            answer = "❌ No search results found for your query."
            
    except Exception as e:
        answer = f"🚨 Search error: {str(e)}"
    
    print("-> Web Search & Summarization Complete")
    # Ensure the result is formatted as markdown
    #print(answer)
    # Return in the expected format for your agent state
    return {"messages": state["messages"] + [answer]}

In [11]:
# # Test the web search function
# state = {"messages": ["find the latest news about AI technology"]}
# print("Testing web search function...")
# print("Input state:", state)
# print("-" * 50)

# result = function_4(state)
# print("Result:", result)

In [12]:
def create_enhanced_langgraph_workflow(
    agent_state_class,
    supervisor_func,
    rag_func,
    llm_func,
    webcrawl_func,
    validation_func,
    final_output_func,
    router_func
):
    workflow = StateGraph(agent_state_class)
    
    # Add all nodes
    workflow.add_node("supervisor", supervisor_func)
    workflow.add_node("rag", rag_func)
    workflow.add_node("llm", llm_func)
    workflow.add_node("webcrawl", webcrawl_func)
    workflow.add_node("validation", validation_func)
    workflow.add_node("final_output", final_output_func)
    
    # Define the flow
    workflow.set_entry_point("supervisor")
    
    # From supervisor to task nodes
    workflow.add_conditional_edges(
        "supervisor",
        router_func,
        {
            "rag": "rag",
            "llm": "llm", 
            "webcrawl": "webcrawl"
        }
    )
    
    # All task nodes go to validation
    workflow.add_edge("rag", "validation")
    workflow.add_edge("llm", "validation")
    workflow.add_edge("webcrawl", "validation")
    
    # From validation - conditional routing
    workflow.add_conditional_edges(
        "validation",
        lambda state: state.get("next_action", "supervisor"),
        {
            "supervisor": "supervisor",  # If validation failed
            "final_output": "final_output"  # If validation passed
        }
    )
    
    # Final output ends the workflow
    workflow.add_edge("final_output", END)
    
    return workflow.compile()

In [13]:
def create_langgraph_workflow(agent_state_class, supervisor_func, rag_func, llm_func, router_func, webcrawl_func):
    """
    Create a modular LangGraph workflow with supervisor, RAG, and LLM nodes.
    
    Args:
        agent_state_class: The AgentState class for state management
        supervisor_func: Function for the supervisor node
        rag_func: Function for the RAG node
        llm_func: Function for the LLM node
        router_func: Router function for conditional edges
        webcrawl_func: Function for the web crawling node
    
    Returns:
        compiled workflow app
    """
    # Initialize workflow
    workflow = StateGraph(agent_state_class)
    
    # Add nodes
    workflow.add_node("Supervisor", supervisor_func)
    workflow.add_node("RAG", rag_func)
    workflow.add_node("LLM", llm_func)
    workflow.add_node("WEBCRAWL", webcrawl_func)
    
    # Set entry point
    workflow.set_entry_point("Supervisor")
    
    # Add conditional edges from supervisor
    workflow.add_conditional_edges(
        "Supervisor",
        router_func,
        {
            "RAG Call": "RAG",
            "LLM Call": "LLM",
            "WEBCRAWL Call": "WEBCRAWL"
        }
    )
    
    # Add terminal edges
    workflow.add_edge("RAG", END)
    workflow.add_edge("LLM", END)
    workflow.add_edge("WEBCRAWL", END)
    
    # Compile and return the workflow
    app = workflow.compile()
    return app



In [14]:
def function_5_validation(state: AgentState):
    """Validation function to check the quality of generated output"""
    print("-> VALIDATION Started")
    
    # Get the last message (the output to validate)
    if isinstance(state["messages"], list) and len(state["messages"]) > 0:
        last_message = state["messages"][-1]
        if isinstance(last_message, dict):
            content = last_message.get("content", "")
        else:
            content = str(last_message)
    else:
        content = str(state["messages"])
    
    # Get the original question (first message)
    original_question = ""
    if isinstance(state["messages"], list) and len(state["messages"]) > 0:
        first_message = state["messages"][0]
        if isinstance(first_message, dict):
            original_question = first_message.get("content", "")
        else:
            original_question = str(first_message)
    
    print(f"Validating response for: {original_question[:50]}...")
    
    # Validation criteria
    validation_results = validate_response(content, original_question)
    
    if validation_results["is_valid"]:
        print("✅ Validation PASSED")
        # Add validation stamp to the response
        validated_response = f"""
{content}

---
✅ **Validation Status:** PASSED
📊 **Quality Score:** {validation_results['score']}/10
⏰ **Validated at:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
"""
        return {
            "messages": state["messages"][:-1] + [validated_response],
            "validation_status": "passed",
            "next_action": "final_output"
        }
    else:
        print("❌ Validation FAILED")
        # Add validation failure message
        failure_message = f"""
❌ **Validation Failed**

**Issues Found:**
{chr(10).join(f'• {issue}' for issue in validation_results['issues'])}

**Score:** {validation_results['score']}/10
**Recommendation:** {validation_results['recommendation']}
"""
        return {
            "messages": state["messages"] + [failure_message],
            "validation_status": "failed",
            "next_action": "supervisor"
        }

def validate_response(content: str, question: str) -> dict:
    """Validate the response quality"""
    issues = []
    score = 10
    
    # Check 1: Content length
    if len(content) < 50:
        issues.append("Response too short")
        score -= 3
    
    # Check 2: Contains error messages
    error_keywords = ["error", "failed", "could not", "unable to", "no results"]
    if any(keyword in content.lower() for keyword in error_keywords):
        issues.append("Contains error messages")
        score -= 2
    
    # Check 3: Relevance to question
    question_words = set(question.lower().split())
    content_words = set(content.lower().split())
    relevance = len(question_words.intersection(content_words)) / len(question_words)
    
    if relevance < 0.2:
        issues.append("Low relevance to original question")
        score -= 3
    
    # Check 4: Has proper formatting
    if "**" not in content and "#" not in content:
        issues.append("Poor formatting")
        score -= 1
    
    # Check 5: Contains sources (for web search)
    if "web search" in question.lower() or "latest" in question.lower():
        if "source" not in content.lower() and "http" not in content.lower():
            issues.append("Missing sources for web search")
            score -= 2
    
    # Determine recommendation
    if score >= 7:
        recommendation = "Response quality is good"
    elif score >= 5:
        recommendation = "Response needs minor improvements"
    else:
        recommendation = "Response needs significant improvement - retry with different approach"
    
    return {
        "is_valid": score >= 6,  # Pass threshold
        "score": max(0, score),
        "issues": issues,
        "recommendation": recommendation
    }

In [15]:
def function_6_final_output(state: AgentState):
    """Generate the final polished output"""
    print("-> FINAL OUTPUT Generation")
    
    # Get the validated response
    validated_response = state["messages"][-1]
    
    final_output = f"""
# 🎯 Final AI Agent Response

{validated_response}

---
🤖 **Generated by:** Multi-Agent LangGraph System
🔍 **Processing Steps:** Supervisor → Task Execution → Validation → Final Output
✅ **Quality Assured:** Response validated and approved
"""
    
    print("✅ Final output generated successfully!")
    
    return {
        "messages": state["messages"] + [final_output],
        "validation_status": "completed",
        "next_action": "end"
    }

In [16]:
model, retriever = setup_rag_system("/Users/kausik/Desktop/Agentic-AI/3-langgraph/data2")
parser, AgentState, format_instructions = create_topic_selection_system()
app = create_langgraph_workflow(
    agent_state_class=AgentState,
    supervisor_func=function_1,
    rag_func=function_2,
    llm_func=function_3,
    webcrawl_func=function_4,
    router_func=router
)



  from .autonotebook import tqdm as notebook_tqdm


In [17]:
# Create the enhanced app with validation
import datetime
enhanced_app = create_enhanced_langgraph_workflow(
    agent_state_class=AgentState,
    supervisor_func=function_1,
    rag_func=function_2,
    llm_func=function_3,
    webcrawl_func=function_4,
    validation_func=function_5_validation,
    final_output_func=function_6_final_output,
    router_func=enhanced_router
)

In [18]:
# Test with validation
state = {
    "messages": ["web search can you tell me the industrial growth of world's most powerful economy?"],
    "validation_status": "pending",
    "next_action": "supervisor"
}

result = enhanced_app.invoke(state)
print("Final result:", result)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Question web search can you tell me the industrial growth of world's most powerful economy?
Parsed response: Topic='USA' Reasoning="The query asks about the industrial growth of the world's most powerful economy, which is generally considered to be the USA."
-> LLM Call ->
-> VALIDATION Started
Validating response for: web search can you tell me the industrial growth o...
✅ Validation PASSED


AttributeError: module 'datetime' has no attribute 'now'

In [None]:
# Install required packages (run once)
# !pip install graphviz
# !pip install pygraphviz

# Then use this for best results:
from IPython.display import Image, display
display(Image(app.get_graph().draw_mermaid_png()))

In [None]:
# Test different queries
state1 = {"messages": ["latest AI technology news"]}
result1 = app.invoke(state1)

# state2 = {"messages": ["current stock market updates"]}
# result2 = app.invoke(state2)

# state3 = {"messages": ["recent climate change developments"]}
# result3 = app.invoke(state3)

In [None]:
state={"messages":["what is a gdp of usa in web?"]}

app.invoke(state)

In [None]:
state={"messages":["web search can you tell me the industrial growth of world's most powerful economy search in web?"]}
app.invoke(state)

In [None]:
state={"messages":["can you tell me the industrial growth of world's poor economy?"]}
result=app.invoke(state)

In [None]:
state={"messages":["What's the capital of India from web?"]}
result=app.invoke(state)