### ReAct agent that uses a RAG pipeline as a tool to retrieve products

In [None]:
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PayloadSchemaType, PointStruct, SparseVectorParams, Document, Prefetch, FusionQuery
from qdrant_client.http.models import models
import pandas as pd
import openai
import fastembed
from langsmith import traceable, get_current_run_tree

from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, ToolMessage
from langchain_core.messages import convert_to_openai_messages, convert_to_messages

from jinja2 import Template
from langchain_core.tools import tool
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode
from instructor import from_openai
from openai import OpenAI
from langgraph.graph.message import add_messages

from utils.utils import get_tool_descriptions, format_ai_message
from typing import Annotated, List, Any, Dict
from pydantic import Field, BaseModel
from operator import add
import instructor

from IPython.display import Image, display
from pprint import pprint
import json

In [None]:
from typing import Annotated, List, Any, Dict
from pydantic import Field
from operator import add

# Structured record of a single tool invocation: the tool name plus its arguments.
# Documented with `#` comments rather than docstrings so the pydantic schema
# (which includes class docstrings) stays unchanged.
# NOTE(review): Toolcall and SearchAgentResponse are not referenced by any cell
# shown here — confirm they are still needed before relying on them.
class Toolcall(BaseModel):
    # Name of the tool to call.
    name: str
    # Keyword arguments for the tool; keys/values are not validated further.
    args: dict
    
# Response shape for a search agent: a free-text answer plus requested tool calls.
class SearchAgentResponse(BaseModel):
    answer: str
    # default_factory gives each instance its own fresh list.
    tool_calls: List[Toolcall] = Field(default_factory=list)

# Graph state shared by every node of the LangGraph pipeline.
# Nodes return partial dicts; fields annotated with a reducer (add_messages,
# operator.add) are merged by that reducer instead of being overwritten.
# NOTE(review): the `[]` defaults rely on pydantic copying mutable defaults per
# instance rather than sharing them — confirm for the pydantic version in use.
class State(BaseModel):
    # Message history; add_messages merges each node's returned messages in
    # (e.g. the agent's tool-calling message and the tool node's results).
    messages: Annotated[List[Any], add_messages] = []
    # The user's original request; required when constructing the state.
    user_query: str
    # Standalone product-search queries written by query_rewriter_node.
    expanded_queries: List[str] = []
    # Text returned to the user: router_node's reason when the graph stops at the
    # router, otherwise aggregation_node's answer (which overwrites it).
    answer: str = ""
    # Product strings gathered for the answer; operator.add concatenates each
    # node's returned list onto the existing one.
    retrieved_contextdata: Annotated[List[Any], add] = []
    # Set by router_node; router_conditional_edge only continues when True.
    query_relevant: bool = False

# Structured output of router_node. Field order mirrors the "### Schema" block
# in the router prompt. Documented with `#` comments (not a docstring) so the
# schema sent to the model is not altered by documentation.
class QueryRelevanceResponse(BaseModel):
    # True when the query is about shopping/products.
    query_relevant: bool
    # The router prompt's schema requires this flag (true when the user must be
    # asked a clarifying question). It was previously missing, so the model's
    # decision was dropped. Defaulting to False keeps existing constructions valid.
    clarification_needed: bool = False
    # Explanation of the verdict, or the clarification question to ask the user.
    reason: str

# Structured output of query_rewriter_node: one standalone search query per
# distinct need found in the user's request.
class QueryRewriteResponse(BaseModel):
    search_queries: List[str]
    
# Structured output of aggregation_node. The Field description is part of the
# model's JSON schema, so it is left exactly as-is.
class AggregationResponse(BaseModel):
    answer: str = Field(description="The answer to the question in a list format.")

In [None]:
def create_embeddings(text, model="text-embedding-3-small"):
    """Embed ``text`` with the OpenAI embeddings API and return the first vector.

    Args:
        text: Input forwarded unchanged to ``openai.embeddings.create``.
        model: Embedding model name.

    Returns:
        The ``embedding`` of the first entry in the response's ``data`` list
        (only that first entry is returned, even if several inputs were sent).
    """
    result = openai.embeddings.create(input=text, model=model)
    first_item = result.data[0]
    return first_item.embedding

In [None]:
@traceable(name="query_rewriter_node", 
description="This function rewrites the query to be more specific to include multiple statements",
run_type="prompt"
)
def query_rewriter_node(state: State) -> Dict[str, Any]:
    """
    Split the user's request into distinct, standalone product-search queries.

    Args:
        state: Graph state; only ``state.user_query`` is read.

    Returns:
        Partial state update ``{"expanded_queries": [...]}``. Blank entries are
        dropped, and if the model produces no usable query the original
        ``user_query`` is used as the single search query, so downstream nodes
        always have something to search for.
    """
    
    prompt_template = """
    You are a Shopping Intent Extraction Agent.
    Your task is to analyze a user's complex request and split it into distinct, standalone product search statements.

    ### INSTRUCTIONS
    1. **Identify Entities:** Look for distinct people or needs mentioned (e.g., "for me", "for my kid", "for my wife").
    2. **Segment Requests:** If the user asks for multiple different items, separate them completely. Do not combine them.
    3. **Refine & Standardize:** Convert colloquial phrases into clear, searchable product categories.
        - If the user says "nice toys", convert to "toys for kids" or specific categories if implied.
        - If the user says "I need", infer the target audience based on the context (e.g., "for adults" or "men/women").
    4. **Output List:** Return a raw JSON object containing the list of search statements.
    5. Note: Don't includ below examples in your output.

    ### EXAMPLES

    <Question>
    I need smart watch. My kid needs nice toys. My wife want home appliances.
    </Question>
    <Response>
    {
        "search_queries": [
            "Smart watch for adults",
            "Toys for kids",
            "Home appliances for women"
        ]
    }
    </Response>

    <Question>
    Looking for a gaming laptop for myself and a pink ipad for my daughter.
    </Question>
    <Response>
    {
        "search_queries": [
            "Gaming laptop",
            "Pink iPad for girls"
        ]
    }
    </Response>

    <Question>
    {{ query }}
    </Question>
    <Response>
    """
    
    prompt = Template(prompt_template).render(query=state.user_query)
    
    client = instructor.from_openai(OpenAI())
    
    # instructor validates the reply against QueryRewriteResponse; the raw
    # completion is not needed here.
    response, _raw_response = client.chat.completions.create_with_completion(
        model="gpt-4o-mini",
        response_model=QueryRewriteResponse,
        messages=[{"role": "system", "content": prompt}],
        temperature=0.4
    )

    # Guard against an empty or blank-only list: the downstream agent is forced
    # to call a tool, so it needs at least one real query to dispatch.
    queries = [q.strip() for q in response.search_queries if q and q.strip()]
    if not queries:
        queries = [state.user_query]

    return {
        "expanded_queries": queries
    }

In [None]:
# retrieve_embedding is exposed as a LangChain tool so the agent can issue one call per expanded query
from typing import Dict

@tool("retrieve_embedding", description="Retrieve embedding data from Qdrant for a given query", response_format="content_and_artifact")
def retrieve_embedding(query: str) -> tuple[str, List[str]]:
    """
    Hybrid-search Qdrant for products matching ``query``.

    Dense (OpenAI embedding) and sparse (BM25) candidates are prefetched
    separately and fused with reciprocal rank fusion (RRF); the top ``k``
    fused points are formatted as product strings.

    Because the tool is declared with ``response_format="content_and_artifact"``,
    it returns a 2-tuple: the JSON-encoded list is the tool message content, and
    the raw list is attached as the message artifact.

    Args:
        query (str): The user's search query for desired product(s).

    Returns:
        tuple[str, List[str]]: ``(json.dumps(products), products)``, where each
        product string is formatted as
            'Product ID: <ASIN> - Description: <description> - Rating: <rating>'
    """
    collection_name = "amazon_items-collection-hybrid-02"
    k = 5                 # fused results returned to the agent
    prefetch_limit = 20   # candidates pulled from each retriever before fusion

    query_embedding = create_embeddings(query)

    # A client is created per call; close it afterwards so connections are not
    # leaked across many tool invocations.
    qd_client = QdrantClient(url="http://localhost:6333")
    try:
        response = qd_client.query_points(
            collection_name=collection_name,
            prefetch=[
                Prefetch(
                    query=query_embedding,
                    using="text-embedding-3-small",
                    limit=prefetch_limit,
                ),
                Prefetch(
                    query=Document(text=query, model="qdrant/bm25"),
                    using="bm25",
                    limit=prefetch_limit,
                ),
            ],
            query=FusionQuery(fusion="rrf"),
            limit=k,
        )
    finally:
        qd_client.close()

    retrieved_contextdata = [
        f"Product ID: {point.payload['parent_asin']} - "
        f"Description: {point.payload['description']} - "
        f"Rating: {point.payload['average_rating']}"
        for point in response.points
    ]

    return json.dumps(retrieved_contextdata), retrieved_contextdata

In [None]:
@traceable(name="aggregation_node", 
description="This function aggregates the retrieved context data and returns the final answer",
run_type="retriever"
)
def aggregation_node(state: State) -> Dict[str, Any]:
    """
    Compose the final answer from the products retrieved by the tool calls.

    Product strings are read from the ``artifact`` of every ``ToolMessage`` in
    ``state.messages``, de-duplicated in first-seen order, and rendered into the
    prompt together with the user question and the expanded queries.

    Returns:
        Partial state update with the model's ``answer`` and the
        ``retrieved_contextdata`` collected from the tool messages.
    """
    # The same product can be returned for several expanded queries;
    # dict.fromkeys removes repeats while keeping retrieval order.
    product_descriptions = list(dict.fromkeys(
        item
        for msg in state.messages
        if isinstance(msg, ToolMessage) and msg.artifact
        for item in msg.artifact
    ))
    
    prompt_template = """
    You are a specialized Product Expert Assistant. Your goal is to answer customer questions accurately using ONLY the provided product information.

      ### Instructions:
      1. **Source of Truth:** Answer strictly based on the provided "Available Products" section below. Do not use outside knowledge or make assumptions.
      2. **Handling Missing Info:** If the answer cannot be found in the provided products, politely state that you do not have that information. Do not make up features.
      3. **Tone:** Be helpful, professional, and concise.
      4. **Terminology:** Never refer to the text below as "context" or "data." Refer to it naturally as "our current inventory" or "available products."
      5. **output format** - An output of the following format is expected:
          
          1. **Answer:** The answer to the question.
          2. **Context:** The list of the IDs of the chunks that were used to answer the question. Only return the ones that are used in the answer.
          3. **Description:** Short description (1-2 sentences) of the item based on the description provided in the context.

      ### Available Products:
      <inventory_data>
      {{ preprocessed_context }}
      </inventory_data>

      ### Customer Question:
      {{ question }}
      
      ### Expanded Queries by you on the above question:
      {{ expanded_queries }}

      ### Answer:
      """

    # Bug fix: the prompt used to render state.retrieved_contextdata, which is
    # only filled by this node's own return value and so was empty here — the
    # model never saw any products. Render the products just collected (after
    # any entries already in state), one per line.
    context_items = list(dict.fromkeys(
        [*map(str, state.retrieved_contextdata), *product_descriptions]
    ))

    prompt = Template(prompt_template).render(
        question=state.user_query,
        expanded_queries=state.expanded_queries,
        preprocessed_context="\n".join(context_items),
    )
    
    client = instructor.from_openai(OpenAI())
    
    response, _raw_response = client.chat.completions.create_with_completion(
        model="gpt-4.1-mini",
        messages=[{"role": "system", "content": prompt}],
        response_model=AggregationResponse,
        temperature=0.4,
    )
    
    return {
        "answer": response.answer,
        "retrieved_contextdata": product_descriptions,
    }

In [None]:
# add router node to evaluate the user query and decide the next node to execute
@traceable(name="router_node", run_type="chain")
def router_node(state: State) -> Dict[str, Any]:
    """
    Decide whether ``state.user_query`` should go through the product pipeline.

    Returns a partial state update:
        - ``query_relevant``: True only when the model judges the query relevant
          AND no clarification is needed; router_conditional_edge ends the graph
          otherwise.
        - ``answer``: the model's ``reason`` (its explanation or clarification
          question), which is the final answer whenever the graph stops here.
    """
    prompt_template = """
    You are a Query Relevance Validator for a specific e-commerce product catalog.
    Your job is to classify the user's intent and determine if we need to ask for clarification before proceeding.

    ### Instructions:
    1. **Analyze Intent:** Look at the User Query below.
    2. **Determine Relevance:** - **RELEVANT (true):** The user is asking about buying products, features, comparisons, prices, or inventory.
    - **NOT RELEVANT (false):** The user is asking about topics completely unrelated to shopping (e.g., "How's the weather?").
    3. **Stock & Availability Check:** - If the user asks specifically about **product stock** or **availability** (e.g., "Is this in stock?", "Do you have inventory?"), you MUST request clarification.
    - Set `clarification_needed` to `true`.
    - In the `reason` field, draft a polite clarification question (e.g., "Could you specify which product or store location you are asking about?").
    4. **Output Format:** You must output valid JSON only.

    ### Schema:
    {
        "query_relevant": boolean,
        "clarification_needed": boolean,
        "reason": "string"
    }

    ### User Query:
    {{ question }}

    ### JSON Response:
    """
    
    prompt = Template(prompt_template).render(question=state.user_query)
    
    client = instructor.from_openai(OpenAI())
    
    response, _raw_response = client.chat.completions.create_with_completion(
        model="gpt-4o-mini",
        response_model=QueryRelevanceResponse,
        messages=[{"role": "system", "content": prompt}],
        temperature=0.4
    )

    # The prompt lets the model request clarification, with the question in
    # ``reason``. Treat that as "stop here" so the question reaches the user
    # (router_conditional_edge -> END) instead of being overwritten by the
    # aggregation answer. getattr keeps this working if the response model
    # does not declare the field.
    needs_clarification = bool(getattr(response, "clarification_needed", False))

    return {
        "query_relevant": response.query_relevant and not needs_clarification,
        "answer": response.reason,
    }

In [None]:
# define the conditional edge to decide the query expansion or end state
from typing import Literal
def router_conditional_edge(state: State) -> Literal["query_rewriter", END]:
    """
    Pick the node that follows the router.

    Relevant queries continue to "query_rewriter"; anything else ends the graph,
    leaving the router's explanation in ``state.answer``.
    """
    return "query_rewriter" if state.query_relevant else END

In [None]:
# define agent node which can use RAG pipeline to perform search on the products 
from langgraph.prebuilt import tools_condition
from langchain_openai import ChatOpenAI

@traceable(name="agent_node", 
description="This function uses the RAG pipeline to perform search on the products",
run_type="retriever"
)
def agent_node(state: State) -> Dict[str, Any]:
    """
    Ask the LLM to emit one ``retrieve_embedding`` tool call per expanded query.

    The model is bound to that single tool with ``tool_choice="required"``, so its
    reply carries tool calls rather than prose; the graph's "tools" node runs them.

    Returns:
        ``{"messages": [response]}``, merged into ``state.messages`` by the
        ``add_messages`` reducer.
    """
    search_agent_prompt = """
    You are a query dispatcher. Your ONLY job is to call the `retrieve_embedding` tool for every input query in the list below.


    ### STRICT RULES:
    1. **DO NOT** output text, markdown, or python code.
    2. **DO NOT** number the list or add bullet points.
    3. **MUST** generate a separate Tool Call for each line item.

    ### Input Queries:
    {{ expanded_queries }}
    """

    # The prompt asks for one tool call per line item, so render the queries one
    # per line rather than as a single-line Python list repr.
    prompt = Template(search_agent_prompt).render(
        expanded_queries="\n".join(state.expanded_queries)
    )

    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.4).bind_tools(
        [retrieve_embedding], tool_choice="required"
    )
    response = llm.invoke(prompt)

    return {
        "messages": [response]
    }

# define custom route edge: send the agent's tool calls to the tools node, otherwise go straight to aggregation
def custome_route_edge(state: State) -> Literal["aggregation", "tools"]:
    """
    Route after agent_node: to "tools" when the agent requested tool calls,
    otherwise to "aggregation".

    Tool calls are counted across every AIMessage in ``state.messages``. The
    "tools" route is taken only when that count is non-zero, does not exceed the
    number of expanded queries, and ``tools_condition`` agrees that the last
    message requests tools.
    NOTE(review): when the count exceeds the number of expanded queries, the tool
    step is skipped and aggregation runs with no retrieved products.
    """
    # Only AIMessages carry ``tool_calls``; check the type first so other message
    # kinds can never raise AttributeError.
    tool_calls_count = sum(
        len(m.tool_calls)
        for m in state.messages
        if isinstance(m, AIMessage) and m.tool_calls
    )

    # `0 < count` short-circuits before tools_condition, which would otherwise
    # inspect the last message of a possibly empty list.
    if (
        0 < tool_calls_count <= len(state.expanded_queries)
        and tools_condition(state.messages) == "tools"
    ):
        return "tools"

    return "aggregation"
    

In [None]:
# Assemble the pipeline:
#   START -> router -(relevant)-> query_rewriter -> agent_node
#   agent_node -(tool calls)-> tools -> aggregation -> END
#   agent_node -(no tool calls)-> aggregation
# The router ends the graph directly when router_conditional_edge returns END.
graphbuilder2 = StateGraph(State)

tools_node = ToolNode(tools=[retrieve_embedding])
graphbuilder2.add_node("router", router_node)
graphbuilder2.add_node("query_rewriter", query_rewriter_node)
graphbuilder2.add_node("agent_node", agent_node)
graphbuilder2.add_node("aggregation", aggregation_node)
graphbuilder2.add_node("tools", tools_node)

graphbuilder2.add_edge(START, "router")
graphbuilder2.add_conditional_edges("router", router_conditional_edge, {"query_rewriter": "query_rewriter", END: END})
graphbuilder2.add_edge("query_rewriter", "agent_node")
# custome_route_edge only returns "tools" or "aggregation"; the former
# END -> "aggregation" entry was unreachable and misleading, so it is dropped.
graphbuilder2.add_conditional_edges("agent_node", custome_route_edge, {"tools": "tools", "aggregation": "aggregation"})
# Tool results go straight to aggregation (no loop back to agent_node).
graphbuilder2.add_edge("tools", "aggregation")
graphbuilder2.add_edge("aggregation", END)

agg_graph_1 = graphbuilder2.compile()

display(Image(agg_graph_1.get_graph().draw_mermaid_png()))

In [None]:
# Demo run: one request that mixes three distinct shopping needs.
demo_query = "top laptop under 2000$. Show me the kid toys as well. Nice bags for my daughter"
initial_state = State(user_query=demo_query)

# invoke() runs the whole graph; the final state is read here by key.
result = agg_graph_1.invoke(initial_state)
pprint(result["answer"])