In [None]:
# imports
from langchain.messages import AnyMessage
from typing_extensions import TypedDict, List, Annotated, Optional, Literal
from langchain.messages import SystemMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
import operator
from pydantic import BaseModel, Field

import heapq

from helper import get_models
from hybrid_database import hybrid_search

In [None]:
# Load the four shared backends from helper.get_models():
#   - database / embedding_model: passed to hybrid_search in the `retrieve` node
#   - rerank_model: scores (question, text) pairs in `rerank_documents`
#   - llm_model: the chat model behind every LangChain chain below
(
    database,
    embedding_model,
    rerank_model,
    llm_model,
) = get_models()

In [None]:
# define tools


# # Augment the LLM with tools
# tools = []
# tools_by_name = {tool.name: tool for tool in tools}
# model_with_tools = model.bind_tools(tools)

In [None]:
# Placeholder inputs for the smoke-test invocations in the cells below;
# fill them in with a real query / retrieved context before running those cells.
question, documents = "", ""

### LLM chains (prompt templates + structured-output graders)

In [None]:
# router node
class RouteDecision(BaseModel):
    """Routes the user query to the appropriate data source."""
    # Structured output of the router chain (llm_model.with_structured_output).
    # Pydantic turns the class docstring and the Field descriptions into the JSON
    # schema the LLM is constrained to, so this text effectively acts as prompt
    # text - edit it as carefully as the system prompt.
    # Free-text justification; declared first, presumably so the model reasons
    # before emitting the decision fields - keep this field order.
    reasoning: str = Field(
        ..., 
        description="Briefly explain what the user is asking and check if it matches the explicit database topics."
    )
    # Whether the query falls inside the vectorstore's explicit topic list.
    is_in_domain: bool = Field(
        ..., 
        description="True ONLY if the query is strictly about the specific topics in our database (e.g., vLLM, CRAG, Self-RAG). False for general ML topics like Federated Learning, Vision, etc."
    )
    # Final routing label; the Literal restricts it to exactly these three values.
    route: Literal["vectorstore", "websearch", "chitchat"] = Field(
        ..., 
        description="Choose 'vectorstore' if is_in_domain is True. Choose 'websearch' if is_in_domain is False or for current events. Choose 'chitchat' for greetings."
    )

# Prompt text for the router chain, pulled out as named constants so the
# routing policy can be read (and tweaked) separately from the chain wiring.
ROUTER_SYSTEM_PROMPT = "You are an expert routing assistant for a highly specialized AI Engineering knowledge base. \nYour job is to analyze the user\'s query and route it to the correct destination.\n\nCRITICAL CONTEXT: Our \'vectorstore\' is NOT a general Machine Learning database. It ONLY contains 15 specific research papers. \nThe explicit topics covered in the vectorstore are:\n- LLM Architecture & Serving (vLLM, Transformers, PagedAttention)\n- Advanced RAG Methodologies (CRAG, Self-RAG, Adaptive-RAG, Vector Databases)\n- Model Building from Scratch (ConvNeXt)\n\nRouting Rules:\n1. Route to \'vectorstore\' ONLY if the query is explicitly related to the specific topics/papers listed above.\n2. Route to \'websearch\' if the user asks about General ML/AI topics NOT listed above (e.g., Federated Learning, Reinforcement Learning, CNNs, Generative Adversarial Networks).\n3. Route to \'websearch\' if the user asks about real-world current events, news, or live data.\n4. Route to \'chitchat\' if the query is a simple conversational greeting or compliment."
ROUTER_HUMAN_PROMPT = "Route this query: {question}"

router_node_prompt = ChatPromptTemplate(
    [("system", ROUTER_SYSTEM_PROMPT), ("human", ROUTER_HUMAN_PROMPT)],
    input_variables=["question"],
)

# prompt -> LLM constrained to the RouteDecision JSON schema (strict mode).
router_node_llm = router_node_prompt | llm_model.with_structured_output(
    RouteDecision, method="json_schema", strict=True
)

# Smoke test: route the placeholder `question` and show the parsed decision.
result = router_node_llm.invoke({"question": question})
print(type(result))
print(result.route)

In [None]:
# rewrite node
class RewrittenQuery(BaseModel):
    """The optimized search query for the vector database."""
    # Structured output of the query-rewrite chain. Pydantic turns the docstring and
    # Field descriptions into the JSON-schema descriptions the LLM sees, so they
    # double as prompt text.
    # Explanation of the keyword extraction; declared before `query`, presumably so
    # the model reasons before emitting the final query - keep this order.
    reasoning: str = Field(..., description="Briefly explain what keywords were extracted or expanded.")
    # The rewritten, keyword-dense search query string.
    query: str = Field(..., description="The highly optimized, keyword-dense search query.")


# Prompt text for the query-rewrite chain, used after a retrieval that found no
# relevant context.
REWRITE_SYSTEM_PROMPT = "You are an expert Search Query Optimizer. \nThe user previously searched a vector database, but the retrieval failed to find relevant context.\n\nYour task is to rewrite the query to be highly effective for semantic search.\nStrip away conversational filler (e.g., \"Can you tell me about...\"). Extract core technical keywords and expand known AI acronyms to maximize the chance of a database match."
REWRITE_HUMAN_PROMPT = "Original failing query: {question}"

rewrite_node_prompt = ChatPromptTemplate(
    [("system", REWRITE_SYSTEM_PROMPT), ("human", REWRITE_HUMAN_PROMPT)],
    input_variables=["question"],
)

# prompt -> LLM constrained to the RewrittenQuery schema, so `.query` is always present.
rewrite_node_llm = rewrite_node_prompt | llm_model.with_structured_output(
    RewrittenQuery, method="json_schema", strict=True
)

# Smoke test: rewrite the placeholder `question` and show the optimized query.
result = rewrite_node_llm.invoke({"question": question})
print(type(result))
print(result.query)

In [None]:
# context checking before generation
class ContextGap(BaseModel):
    """Evaluates if there is missing foundational knowledge in the retrieved documents."""
    # Structured output of the context-gap (multi-hop) check chain. Pydantic turns the
    # docstring and Field descriptions into the JSON schema the LLM is constrained to,
    # so they act as prompt text.
    # Explanation of which concepts are used but left undefined in the documents.
    reasoning: str = Field(..., description="Explain if any concepts mentioned in the documents are required to answer the prompt, but are not defined.")
    # True when a follow-up search for a missing definition is needed.
    has_gap: bool = Field(..., description="True if a critical concept/acronym is missing its definition. False if the context is complete.")
    # Short follow-up search query; "" when has_gap is False.
    # NOTE(review): because this field has a default, Pydantic does not list it as
    # required in the generated schema. OpenAI's strict json_schema mode requires every
    # property to be required - confirm the serving backend accepts this schema.
    missing_concept: str = Field(default="", description="If has_gap is True, provide a 2-4 word search query to find the missing concept. If False, leave empty.")


# NOTE(review): these two empty messages are not referenced anywhere else in this
# notebook; they look like leftovers from an earlier prompt design.
sys_msg = SystemMessage(content="")
user_msg = HumanMessage(content="")

# Prompt text for the multi-hop context-gap check.
CONTEXT_CHECK_SYSTEM_PROMPT = "You are a Context Gap Analyzer. \nYour job is to read the user\'s question and the currently retrieved documents to determine if a \"Multi-Hop\" search is required.\n\nLook for unexplained technical acronyms or foundational concepts that are mentioned in the text as part of the answer, but are not actually explained. \nIf explaining that concept is necessary to fully answer the user\'s question, flag a context gap and provide a short search query to retrieve that specific definition."
CONTEXT_CHECK_HUMAN_PROMPT = "User Question: {question}\n\nCurrently Retrieved Documents:\n{documents}"

context_check_node_prompt = ChatPromptTemplate(
    [("system", CONTEXT_CHECK_SYSTEM_PROMPT), ("human", CONTEXT_CHECK_HUMAN_PROMPT)],
    input_variables=["question", "documents"],
)

# prompt -> LLM constrained to the ContextGap schema.
context_check_node_llm = context_check_node_prompt | llm_model.with_structured_output(
    ContextGap, method="json_schema", strict=True
)

# Smoke test on the placeholder question/documents.
result = context_check_node_llm.invoke({"question": question, "documents": documents})
print(type(result))
print(result.has_gap)
print(result.missing_concept)


In [None]:
# generate node
generate_node_prompt = ChatPromptTemplate(
    [
        ("system", "You are a highly technical AI Research Assistant.\nAnswer the user\'s question strictly using the provided context documents.\n\nRules:\n1. If the answer is not contained within the context documents, state: \"I cannot answer this based on the provided context.\"\n2. Do not use outside knowledge or hallucinate facts.\n3. Keep the answer concise, professional, and directly address the user\'s core question.\n4. When applicable, briefly cite the document or paper title you are referencing."),
        ("human", "Context Documents:\n{documents}\n\n---------------------\nUser Question: {question}"),
    ],
    input_variables = ["documents", "question"],
)

# Plain-text answer chain (no structured output): prompt -> LLM -> str.
generate_node_llm = generate_node_prompt | llm_model | StrOutputParser()

# Smoke test on the placeholder question/documents.
result = generate_node_llm.invoke({"question": question, "documents": documents})

print(result)
# Hand the answer just produced to the hallucination / relevance checks in the
# following cells; resetting it to "" would make those checks grade an empty answer.
generation = result

In [None]:
# hallucination checker
class HallucinationScore(BaseModel):
    """Checks if the generated answer contains hallucinations."""
    # Structured output of the groundedness (hallucination) grader. Pydantic turns the
    # docstring and Field descriptions into the JSON-schema descriptions the LLM sees.
    # Fact-by-fact comparison; declared before the verdict, presumably so the model
    # reasons first - keep this order.
    reasoning: str = Field(..., description="Briefly compare the generated facts against the source documents.")
    # Verdict: True only when every fact in the generation appears in the documents.
    is_grounded: bool = Field(..., description="True if all facts in the generation are present in the documents. False if any outside info was added.")

# Prompt text for the Self-RAG style groundedness grader.
HALLUCINATION_SYSTEM_PROMPT = "You are a strict grading assistant executing a Self-RAG critique.\nYour task is to evaluate if an AI-generated answer is completely grounded in the provided source documents.\n\nA generation is ungrounded (hallucinated) if it includes statistics, methodology names, or factual claims that do not appear anywhere in the source text."
HALLUCINATION_HUMAN_PROMPT = "Source Documents:\n{documents}\n\n---------------------\nAI Generation:\n{generation}"

hallucination_check_node_prompt = ChatPromptTemplate(
    [("system", HALLUCINATION_SYSTEM_PROMPT), ("human", HALLUCINATION_HUMAN_PROMPT)],
    input_variables=["documents", "generation"],
)

# prompt -> LLM constrained to the HallucinationScore schema.
hallucination_check_node_llm = hallucination_check_node_prompt | llm_model.with_structured_output(
    HallucinationScore, method="json_schema", strict=True
)

# Smoke test: grade the current `generation` against the placeholder documents.
result = hallucination_check_node_llm.invoke({"documents": documents, "generation": generation})
print(type(result))
print(result.is_grounded)

In [None]:
# relevance checker
class RelevanceScore(BaseModel):
    """Checks if the generated answer actually addresses the user's question."""
    # Structured output of the answer-relevance grader. Pydantic turns the docstring
    # and Field descriptions into the JSON-schema descriptions the LLM sees.
    # Explanation of how well the answer addresses the question; declared before the
    # verdict, presumably so the model reasons first - keep this order.
    reasoning: str = Field(..., description="Explain how the generated text does or does not answer the core premise of the user's question.")
    # Verdict: False for evasive or off-topic answers.
    is_relevant: bool = Field(..., description="True if the answer directly addresses the question. False if it is evasive or off-topic.")

# Prompt text for the answer-relevance grader.
RELEVANCE_SYSTEM_PROMPT = "You are a strict grading assistant. \nYour task is to evaluate if an AI-generated answer successfully and directly addresses the user\'s original question.\n\nFail the generation if it just summarizes the documents without answering the specific user prompt, or if it is overly evasive."
RELEVANCE_HUMAN_PROMPT = "User Question: {question}\n\n---------------------\nAI Generation: {generation}"

relevance_check_node_prompt = ChatPromptTemplate(
    [("system", RELEVANCE_SYSTEM_PROMPT), ("human", RELEVANCE_HUMAN_PROMPT)],
    input_variables=["question", "generation"],
)

# prompt -> LLM constrained to the RelevanceScore schema.
relevance_check_node_llm = relevance_check_node_prompt | llm_model.with_structured_output(
    RelevanceScore, method="json_schema", strict=True
)

# Smoke test: grade the current `generation` against the placeholder question.
result = relevance_check_node_llm.invoke({"question": question, "generation": generation})
print(type(result))
print(result.is_relevant)

In [None]:
# graph state
class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the (possibly rewritten) user question
        generation: LLM generation
        documents: list of documents; a node that returns this key replaces the
            previous list (last write wins), which is what lets the rerank step
            narrow the retrieved set
        missing_concepts: follow-up search query for a concept flagged as
            undefined by the context-gap check (presumably ContextGap.missing_concept)
        loop_count: number to track loops
        gen_retries: number of generates for hallucinations
    """
    question: str
    generation: str
    # Plain `list`, no `operator.add` reducer: with an additive reducer, every node
    # that returns `documents` (rerank, generate, rewrite) appends to the existing
    # list instead of replacing it, so reranking could never shrink the set and
    # nodes echoing the list back doubled it.
    documents: list
    missing_concepts: str
    loop_count: int
    gen_retries: int

In [None]:
# Nodes
def retrieve(state):
    """
    Retrieve candidate documents for the current question via hybrid search.

    Args:
        state (dict): Current graph state; reads "question" and, if present,
            "loop_count" (treated as 0 when missing).

    Returns:
        dict: state update with the retrieved "documents", the unchanged
        "question", and "loop_count" incremented by one.
    """
    print("\n---RETRIEVE---")
    question = state["question"]

    # Each retrieval counts as one loop iteration; default to 0 so the very first
    # call works even if the caller did not seed the counter.
    loop_count = state.get("loop_count", 0) + 1
    print(f"incrementing loop_count: {loop_count}\n")

    print(f"Question: {question}")
    # sparse_weight / dense_weight are passed straight through to hybrid_search
    # (presumably score weights); up to 20 candidates are fetched and the rerank
    # node later keeps at most 5 of them.
    documents = hybrid_search(database, embedding_model, question, sparse_weight=0.7, dense_weight=1, limit=20)
    # Log only after the search: `documents` is local to this function, so
    # printing it before the assignment raised UnboundLocalError on every call.
    print(f"Documents: {documents}")
    return {"documents": documents, "question": question, "loop_count": loop_count}

def rerank_documents(state):
    """
    Rerank the retrieved documents, drop low scorers, and keep the top 5.

    Args:
        state (dict): Current graph state; reads "question", "documents" (each a
            dict with a "text" key) and, if present, "loop_count".

    Returns:
        dict: state update whose "documents" holds at most 5 documents with a
        normalized rerank score above 0.5, highest score first, plus the
        unchanged "question" and "loop_count".
    """
    import numbers

    print("\n---RERANK---")

    question = state["question"]
    docs = state["documents"]
    loop_count = state.get("loop_count", 0)

    # Nothing to score: some rerankers fail on an empty batch (e.g. by indexing
    # the first pair), so short-circuit with an empty result.
    if not docs:
        print("Reranked document: []")
        return {"documents": [], "question": question, "loop_count": loop_count}

    question_and_docs = [[question, doc["text"]] for doc in docs]
    scores = rerank_model.compute_score(question_and_docs, normalize=True)
    # Some rerankers (e.g. FlagEmbedding's FlagReranker) return a bare number
    # instead of a list when given a single pair; wrap it so zip() lines up.
    if isinstance(scores, numbers.Number):
        scores = [scores]

    # Keep only documents scoring above 0.5 (scores are normalized via
    # normalize=True), then the 5 highest-scoring of those.
    filtered_pairs = ((doc, score) for doc, score in zip(docs, scores) if score > 0.5)
    reranked_docs = heapq.nlargest(5, filtered_pairs, key=lambda pair: pair[1])
    documents = [doc for doc, _ in reranked_docs]
    print(f"Reranked document: {documents}")
    return {"documents": documents, "question": question, "loop_count": loop_count}

def generate(state):
    """
    generate answer for query
    
    Args:
        state (dict): Current graph state

    Returns:
        state (dict): retrieved documents added to the state
    """
    print("\n---GENERATE---")

    question = state["question"]
    documents = state["documents"]
    loop_count = state["loop_count"]

    prompt = {
        "system": "",
        "user"  : f""
    }
    completion = llm_model.chat.completions.create(
        model="Qwen/Qwen3-8B-AWQ",
        messages=[
            {"role": "system", "content": prompt["system"]},
            {"role": "user", "content": prompt["user"]},
        ],
    )
    generation = completion.choices[0].message

    
    return {"documents": documents, "question": question, "generation": generation, "loop_count": loop_count}

def rewrite_query(state):
    """
    Rewrite the question into a keyword-dense search query after a failed retrieval.

    Args:
        state (dict): Current graph state; reads "question" and, if present,
            "loop_count" (treated as 0 when missing).

    Returns:
        dict: state update with "question" replaced by the rewritten query string
        and "loop_count" unchanged.
    """
    print("\n---REWRITE_QUERY---")

    question = state["question"]
    loop_count = state.get("loop_count", 0)

    # llm_model is a LangChain chat model (it exposes with_structured_output), so the
    # rewrite goes through the structured rewrite chain; store the plain `.query`
    # string, because downstream retrieval searches with `question` as text.
    rewritten = rewrite_node_llm.invoke({"question": question})
    print(f"Rewritten query: {rewritten.query}")

    # Only the fields this node changes are returned: echoing "documents" back would
    # duplicate them if that state key uses an additive reducer such as operator.add.
    return {"question": rewritten.query, "loop_count": loop_count}

def web_search(state):
    """
    Placeholder for the web-search node (not implemented yet).

    Intended to look up context for the query on the web; for now it performs
    no search and hands back the state object it received.

    Args:
        state (dict): Current graph state.

    Returns:
        dict: the very same ``state`` object, unmodified.
    """
    unchanged_state = state
    return unchanged_state

In [None]:
# Edges

def query_router(state):
    """
    [Adaptive RAG]: decide which source should answer the user's question.

    Conditional-edge function: asks the router chain for a RouteDecision and
    returns its route label, which the graph uses to pick the next node.

    Args:
        state (dict): Current graph state; reads "question".

    Returns:
        str: the decision's route - one of "vectorstore", "websearch" or
        "chitchat" (the values allowed by RouteDecision.route).
    """
    question = state["question"]
    decision = router_node_llm.invoke({"question": question})
    print(f"\n---ROUTE QUESTION: {decision.route}---")
    # Return the label itself; returning the state discarded the routing decision.
    return decision.route

def check_context(state):
    """
    Stub for the pre-generation context check (not implemented yet).

    Intended to decide, before generation, whether the retrieved documents are
    sufficient or whether more information about missing elements must be
    retrieved. Currently it only reads the state fields, logs a banner, and
    returns None.

    Args:
        state (dict): Current graph state; must contain "question", "documents"
            and "loop_count" ("gen_retries" is optional, defaulting to 0).

    Returns:
        None: no decision is produced yet.
    """
    # NOTE(review): context_check_node_llm looks like the intended decision-maker
    # here - confirm once the routing labels are settled.
    question, documents, loop_count = (
        state["question"],
        state["documents"],
        state["loop_count"],
    )
    gen_retries = state.get("gen_retries", 0)

    print("\n---CHECKING CONTEXT OF DOCUMENTS---")
    return None

def rerank_router(state):
    """
    Placeholder routing hook for the rerank stage (not implemented yet).

    Judging by its name, it is meant to choose the next step after reranking;
    today it makes no decision and simply returns the state it was given.

    Args:
        state (dict): Current graph state.

    Returns:
        dict: the same ``state`` object, unmodified.
    """
    passthrough = state
    return passthrough

def rewrite_router(state):
    """
    Placeholder routing hook for the query-rewrite stage (not implemented yet).

    Judging by its name, it is meant to choose the next step after a query
    rewrite; today it makes no decision and simply returns the state it was given.

    Args:
        state (dict): Current graph state.

    Returns:
        dict: the same ``state`` object, unmodified.
    """
    passthrough = state
    return passthrough

def hallucinations_and_relevence_router(state):
    """
    Placeholder routing hook for the post-generation checks (not implemented yet).

    Judging by its name, it is meant to route on the hallucination (groundedness)
    and relevance grades of a generation; today it makes no decision and simply
    returns the state it was given.

    Args:
        state (dict): Current graph state.

    Returns:
        dict: the same ``state`` object, unmodified.
    """
    passthrough = state
    return passthrough