## RAG Project


<img src="image-rag.png" alt="RAG project diagram">

In [None]:
# Load variables from a local .env file into the process environment.
# Presumably this supplies the OpenAI API key needed by the OpenAIEmbeddings /
# ChatOpenAI clients created in later cells — confirm the .env contents.
from dotenv import load_dotenv
load_dotenv()

In [None]:
from langchain.schema import Document
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Chroma

# Create embeddings
# Constructed with no arguments, so the embedding model is the library default
# and credentials are presumably picked up from the environment (.env above).
embedding_function = OpenAIEmbeddings()

In [None]:
# Example TechZone documents.
# Each entry pairs the pseudo source file name (stored as metadata) with the
# text that gets embedded and retrieved.
_techzone_sources = [
    (
        "owner.txt",
        "TechZone is owned by Priya Sharma, a technology entrepreneur with over 15 years of experience in AI and cloud computing. She founded TechZone to make advanced technology accessible to businesses of all sizes.",
    ),
    (
        "pricing.txt",
        "TechZone offers a variety of subscription plans: Starter at ₹1,000/month, Professional at ₹5,000/month, and Enterprise custom plans starting from ₹20,000/month.",
    ),
    (
        "support_hours.txt",
        "TechZone's support hours are Monday to Friday, 9:00 AM to 8:00 PM IST, and Saturday from 10:00 AM to 4:00 PM IST. No support on Sundays.",
    ),
    (
        "services.txt",
        "TechZone provides multiple services including AI-powered analytics, cloud hosting, and API integrations for payment and customer management.",
    ),
]

docs = [
    Document(page_content=text, metadata={"source": source_name})
    for source_name, text in _techzone_sources
]

In [None]:
# Create vector store and retriever.
# Each document gets a stable id (its unique "source" file name). Without ids,
# re-running this cell in a long-lived kernel may add another copy of every
# document to the default in-memory Chroma collection, so the top-k results
# could end up being duplicates of one another. With fixed ids a re-run
# targets the same entries instead of creating new ones.
db = Chroma.from_documents(
    docs,
    embedding_function,
    ids=[doc.metadata["source"] for doc in docs],
)
# Return the 2 nearest documents per query.
retriever = db.as_retriever(search_kwargs={"k": 2})

In [None]:
from langchain_core.prompts import ChatPromptTemplate

# Answer-generation prompt; {history}, {context} and {question} are filled in
# by the caller of rag_chain.
template = """Answer the question based on the following context and the Chat history.
Especially take the latest question into consideration:

Chat history: {history}

Context: {context}

Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)

# Base LLM. Piping the prompt into it yields a sequence whose output is the
# model's chat message.
llm = ChatOpenAI(model="gpt-4o-mini")
rag_chain = prompt.pipe(llm)

In [None]:
from typing import TypedDict, List
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from pydantic import BaseModel, Field
from langgraph.graph import StateGraph, END

In [None]:
# Agent state definition
# Shared state handed to every node of the LangGraph workflow; each node
# mutates this dict and returns it.
class AgentState(TypedDict):
    # Chat history for the thread: user questions and AI replies, in order.
    messages: List[BaseMessage]
    # Documents retrieved for the current question (pruned by retrieval_grader).
    documents: List[Document]
    # Classifier verdict for the current question, expected 'Yes' or 'No'.
    on_topic: str
    # Standalone / refined query used for retrieval and answer generation.
    rephrased_question: str
    # True when at least one retrieved document was graded relevant.
    proceed_to_generate: bool
    # How many refine_question passes have run for the current question.
    rephrase_count: int
    # The user's latest question as supplied to graph.invoke.
    question: HumanMessage

# Classifier output schema
# NOTE: documented with comments rather than a class docstring on purpose — a
# docstring would become part of the schema description that
# with_structured_output sends to the model.
class GradeQuestion(BaseModel):
    # This Field description is part of the schema the LLM sees; it asks for a
    # literal 'Yes' / 'No' answer.
    score: str = Field(
        description="Is the question about specified TechZone topics? If yes -> 'Yes', otherwise -> 'No'."
    )

In [None]:
# Step 1: Question rewriter
def question_rewriter(state: AgentState):
    """Reset per-turn fields, record the new question, and build a standalone query.

    The incoming ``state["question"]`` is appended to ``state["messages"]``
    (the chat history kept for the thread). When earlier turns exist, an LLM
    rewrites the question so it can be understood without that history;
    otherwise the question text is used verbatim.

    Returns the mutated ``state`` with ``rephrased_question`` populated.
    """
    print(f"Entering question_rewriter with state: {state}")

    # Reset fields left over from the previous turn on this thread.
    state["documents"] = []
    state["on_topic"] = ""
    state["rephrased_question"] = ""
    state["proceed_to_generate"] = False
    state["rephrase_count"] = 0

    if state.get("messages") is None:
        state["messages"] = []

    # Only guard against recording the *same* turn twice (question already the
    # last message). A user repeating an earlier question later in the thread
    # is a new turn and must be appended, otherwise the history slice below
    # would drop the most recent AI reply instead of the question.
    if not state["messages"] or state["messages"][-1] != state["question"]:
        state["messages"].append(state["question"])

    if len(state["messages"]) > 1:
        conversation = state["messages"][:-1]
        current_question = state["question"].content
        messages = [
            SystemMessage(content="Rephrase the user's question so it is standalone and optimized for retrieval."),
            *conversation,
            HumanMessage(content=current_question),
        ]
        # Invoke with the message list itself. ChatPromptTemplate.format()
        # returns one flattened "System: ... Human: ..." string, which the chat
        # model receives as a single human message, losing the roles.
        llm = ChatOpenAI(model="gpt-4o-mini")
        response = llm.invoke(messages)
        state["rephrased_question"] = response.content.strip()
    else:
        state["rephrased_question"] = state["question"].content

    print(f"Rephrased question: {state['rephrased_question']}")
    return state

In [None]:
# Step 2: Question classifier
def question_classifier(state: AgentState):
    """Decide whether the rephrased question is on one of the supported topics.

    An LLM bound to the GradeQuestion schema answers 'Yes' or 'No'; the
    whitespace-stripped answer is written to ``state["on_topic"]``.

    NOTE(review): the services.txt document is not among the listed topics, so
    questions about services are classified off-topic — confirm intended.
    """
    print("Entering question_classifier")
    grade_prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content="""You are a classifier.
Check if the question is about one of the following TechZone topics:
1. Information about the owner (Priya Sharma)
2. Subscription pricing
3. Support hours
Answer only 'Yes' or 'No'."""),
            HumanMessage(content=f"User question: {state['rephrased_question']}"),
        ]
    )

    classifier = ChatOpenAI(model="gpt-4o-mini").with_structured_output(GradeQuestion)
    verdict = (grade_prompt | classifier).invoke({})

    state["on_topic"] = verdict.score.strip()
    print(f"on_topic = {state['on_topic']}")
    return state

In [None]:
# Router based on topic
def on_topic_router(state: AgentState):
    """Send 'yes' verdicts (any case) to retrieval; everything else off-topic."""
    verdict = state.get("on_topic", "")
    if verdict.lower() != "yes":
        return "off_topic_response"
    return "retrieve"

# Step 3: Retrieve docs
def retrieve(state: AgentState):
    """Store the retriever's hits for the rephrased question in ``state["documents"]``."""
    state["documents"] = retriever.invoke(state["rephrased_question"])
    print(f"Retrieved {len(state['documents'])} documents")
    return state

# Document relevance grader schema
# NOTE: no class docstring on purpose — it would be included in the schema that
# with_structured_output sends to the model.
class GradeDocument(BaseModel):
    # Field description is shown to the LLM; expected values are 'Yes' / 'No'.
    score: str = Field(description="Relevant? 'Yes' or 'No'")

In [None]:
# Step 4: Document relevance check
def retrieval_grader(state: AgentState):
    """Keep only the retrieved documents an LLM grades as relevant.

    Each document is graded in its own call against the rephrased question.
    ``state["documents"]`` is replaced with the survivors and
    ``state["proceed_to_generate"]`` is set to whether any survived.
    """
    grader = ChatOpenAI(model="gpt-4o-mini").with_structured_output(GradeDocument)
    instruction = SystemMessage(content="Grade if document is relevant to the question. Only 'Yes' or 'No'.")
    question = state["rephrased_question"]

    def _is_relevant(doc):
        # One prompt per document: the fixed instruction plus question/doc pair.
        pair = HumanMessage(content=f"User question: {question}\nDoc: {doc.page_content}")
        chain = ChatPromptTemplate.from_messages([instruction, pair]) | grader
        return chain.invoke({}).score.strip().lower() == "yes"

    kept = [doc for doc in state["documents"] if _is_relevant(doc)]

    state["documents"] = kept
    state["proceed_to_generate"] = bool(kept)
    return state

In [None]:
# Router based on retrieval
def proceed_router(state: AgentState):
    """Pick the next node after grading.

    Answer when relevant documents survived; otherwise refine and retry until
    two refinements have been made, then give up.
    """
    if state.get("proceed_to_generate", False):
        return "generate_answer"
    attempts = state.get("rephrase_count", 0)
    return "cannot_answer" if attempts >= 2 else "refine_question"

In [None]:
# Step 5: Refine question if needed
def refine_question(state: AgentState):
    """Ask the LLM for a slightly reworded query and bump the refinement counter.

    Replaces ``state["rephrased_question"]`` with the model's rewrite and
    increments ``state["rephrase_count"]`` (treated as 0 when absent).
    Returns the mutated ``state``.
    """
    rephrase_count = state.get("rephrase_count", 0)
    # Send real system/human messages. Going through
    # ChatPromptTemplate.format() would collapse both into one
    # "System: ... Human: ..." string delivered as a single human message.
    messages = [
        SystemMessage(content="Refine the question slightly for better search."),
        HumanMessage(content=state["rephrased_question"]),
    ]
    llm = ChatOpenAI(model="gpt-4o-mini")
    response = llm.invoke(messages)
    state["rephrased_question"] = response.content.strip()
    state["rephrase_count"] = rephrase_count + 1
    return state

In [None]:
# Step 6: Generate answer
def generate_answer(state: AgentState):
    """Answer the rephrased question from the graded documents and chat history.

    The documents' ``page_content`` and the chat history are rendered as plain
    text before being substituted into ``rag_chain``'s prompt. The model's
    reply is appended to ``state["messages"]`` as an ``AIMessage``.

    Returns the mutated ``state``.
    """
    # Passing the raw lists would interpolate their Python reprs
    # (e.g. "[Document(metadata={...}, page_content='...')]") into the prompt.
    context = "\n\n".join(doc.page_content for doc in state["documents"])
    history = "\n".join(
        f"{message.type}: {message.content}" for message in state["messages"]
    )
    response = rag_chain.invoke(
        {"history": history, "context": context, "question": state["rephrased_question"]}
    )
    state["messages"].append(AIMessage(content=response.content.strip()))
    return state

In [None]:
# Step 7: Fallback responses
def cannot_answer(state: AgentState):
    """Append a fixed "no relevant information" reply to the chat history."""
    apology = AIMessage(content="Sorry, I couldn't find any relevant information.")
    state["messages"].append(apology)
    return state

def off_topic_response(state: AgentState):
    """Append a fixed reply listing the topics this assistant can handle."""
    refusal = AIMessage(content="I can only answer questions about TechZone's owner, pricing, or support hours.")
    state["messages"].append(refusal)
    return state

In [None]:
# Build workflow
from langgraph.checkpoint.memory import MemorySaver

# In-memory checkpointer; state is stored per thread_id given in the invoke config.
checkpointer = MemorySaver()

workflow = StateGraph(AgentState)

# Register nodes; each node name matches the function implementing it.
for node_name, node_fn in (
    ("question_rewriter", question_rewriter),
    ("question_classifier", question_classifier),
    ("off_topic_response", off_topic_response),
    ("retrieve", retrieve),
    ("retrieval_grader", retrieval_grader),
    ("generate_answer", generate_answer),
    ("refine_question", refine_question),
    ("cannot_answer", cannot_answer),
):
    workflow.add_node(node_name, node_fn)

# Rewrite -> classify, then branch on whether the question is on topic.
workflow.add_edge("question_rewriter", "question_classifier")
workflow.add_conditional_edges(
    "question_classifier",
    on_topic_router,
    {label: label for label in ("retrieve", "off_topic_response")},
)

# Retrieval loop: grade, then answer, refine and retry, or give up.
workflow.add_edge("retrieve", "retrieval_grader")
workflow.add_conditional_edges(
    "retrieval_grader",
    proceed_router,
    {label: label for label in ("generate_answer", "refine_question", "cannot_answer")},
)
workflow.add_edge("refine_question", "retrieve")

# Every reply-producing node ends the run.
for terminal_node in ("generate_answer", "cannot_answer", "off_topic_response"):
    workflow.add_edge(terminal_node, END)

workflow.set_entry_point("question_rewriter")

graph = workflow.compile(checkpointer=checkpointer)

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Render the compiled graph as a Mermaid PNG (API draw method) and show it inline.
graph_png = graph.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
display(Image(graph_png))

In [None]:
# Example runs
print("\n--- Off topic ---")
off_topic_config = {"configurable": {"thread_id": 1}}
graph.invoke(
    {"question": HumanMessage(content="What's the weather today?")},
    config=off_topic_config,
)

In [None]:
print("\n--- On topic ---")
on_topic_config = {"configurable": {"thread_id": 2}}
graph.invoke(
    {"question": HumanMessage(content="Who is the owner of TechZone?")},
    config=on_topic_config,
)