In [1]:
import json
import warnings
from typing import Any, Literal

import numpy as np
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

# Named rich styles, referenced elsewhere via console.print(..., style="info") etc.
custom_theme = Theme(
    styles=dict(
        white="#FFFFFF",  # Bright white
        info="#00FF00",  # Bright green
        warning="#FFD700",  # Bright gold
        error="#FF1493",  # Deep pink
        success="#00FFFF",  # Cyan
        highlight="#FF4500",  # Orange-red
    )
)
console = Console(theme=custom_theme)

# Visualization
# import matplotlib.pyplot as plt

# NumPy settings
np.set_printoptions(precision=4)

# Pandas settings
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

# Polars settings
pl.Config.set_fmt_str_lengths(1_000)
pl.Config.set_tbl_cols(n=1_000)
pl.Config.set_tbl_rows(n=200)

# NOTE(review): this silences every warning for the whole kernel session
# (including deprecation warnings from the LLM libraries); consider narrowing
# it with a `category=` / `module=` filter.
warnings.filterwarnings("ignore")

# Black code formatter (Optional)
%load_ext lab_black

# auto reload imports
# `%autoreload 2` re-imports edited modules (e.g. under src/) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
def go_up_from_current_directory(*, go_up: int = 1) -> None:
    """Prepend an ancestor of the current working directory to ``sys.path``.

    This lets the notebook import project packages (e.g. ``src``) that live
    ``go_up`` directory levels above the directory the kernel runs in.

    Params:
    -------
    go_up: int, default=1
        Number of directory levels to climb from the current working directory.
        ``0`` adds the current working directory itself.

    Returns:
    --------
    None

    Raises:
    -------
    ValueError
        If ``go_up`` is negative.
    """
    import os
    import sys

    if go_up < 0:
        raise ValueError(f"go_up must be >= 0, got {go_up}")

    # Resolve relative to the kernel's working directory. ``__name__`` is a
    # module name ("__main__"), not a path, so it must not be fed to dirname().
    target_dir = os.path.abspath(os.path.join(os.getcwd(), *([os.pardir] * go_up)))

    # Move the entry to the front instead of adding another copy, so re-running
    # the cell keeps sys.path clean while still giving it top import priority.
    while target_dir in sys.path:
        sys.path.remove(target_dir)
    sys.path.insert(0, target_dir)
    print(target_dir)


# Demo (Prevents ruff from removing the unused module import)
# The bare annotations below bind no values; they only reference `Any` and
# `Literal` so those imports count as used. `json.loads(...)` is the cell's
# last expression, so its parsed dict is what the cell displays.
name: Any
category: Literal["A", "B", "C"]
json.loads('{"name": "Smart-RAG", "version": "1.0"}')

{'name': 'Smart-RAG', 'version': '1.0'}

In [3]:
# Put the directory one level above the notebook's working directory on
# sys.path so the `src` package below can be imported.
go_up_from_current_directory(go_up=1)

from src.config import app_config, app_settings  # noqa: E402
from src.utilities.model_config import RemoteModel  # noqa: E402

# Short alias used by later cells for API keys and endpoint URLs.
settings = app_settings

/Users/mac/Desktop/Projects/smart-rag


In [4]:
from src.utilities.vectorstores import ai_vectorstore

# Display the store object. Per the log output, importing this module builds
# both the 'ai_news' and 'football_news' Qdrant collections as a side effect.
ai_vectorstore

2025-10-25 19:53:33 - vectorstores - [INFO] - AI news filepath: /Users/mac/Desktop/Projects/smart-rag/data/ai_news
2025-10-25 19:53:34 - vectorstores - [INFO] - Loaded 43 documents from 5 filepaths.
2025-10-25 19:53:36 - vectorstores - [INFO] - Qdrant vector store set up with collection 'ai_news' and vector size 768
2025-10-25 19:53:40 - vectorstores - [INFO] - Embedded and stored 168 documents.
2025-10-25 19:53:40 - vectorstores - [INFO] - AI news vector store setup complete.
2025-10-25 19:53:44 - vectorstores - [INFO] - Loaded 174 documents from 7 filepaths.
2025-10-25 19:53:44 - vectorstores - [INFO] - Qdrant vector store set up with collection 'football_news' and vector size 768
2025-10-25 19:53:54 - vectorstores - [INFO] - Embedded and stored 567 documents.
2025-10-25 19:53:54 - vectorstores - [INFO] - Football news vector store setup complete.


<langchain_qdrant.qdrant.QdrantVectorStore at 0x1502b20f0>

# Adaptive RAG

Adaptive RAG is an advanced retrieval-augmented generation (RAG) architecture that intelligently combines **traditional RAG techniques** with **self-reflection** and **external tool usage** to enhance the quality and reliability of generated answers.

<br>

[![image.png](https://i.postimg.cc/65nZVgN9/image.png)](https://postimg.cc/ZCY04fRg)

[Source: LangChain](https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_adaptive_rag/)

<br>

### Breakdown of the above diagram

#### 1.) Query Analysis (Red Box)

- The user's question first goes to a Query Analysis step (likely an LLM prompt or classifier).
- It determines if the question is [related to index] (the internal knowledge base) or [unrelated to index] (requiring external tools).

#### 2.) RAG + Self-Reflection (Top Dashed Box)

- If related to the index, the query proceeds through the RAG workflow.
- Retrieve & Grade: The system fetches documents (Retrieve Node), and then an LLM agent grades their relevance.
- Decision Loop:
  - If the documents are relevant, it proceeds to the Generate node and then checks the answer for hallucinations (the "Hallucinations?" decision).
  - If the answer is free of hallucinations and passes the "Answers question?" check, the process stops and returns the final answer.
  - If the documents are not relevant or the generated answer fails the self-reflection checks, the question is sent to the Re-write question (Node) and the process loops back to Retrieve. This allows the agent to iteratively improve its search query.

#### 3.) Tool Use (Bottom Green Path)

- If the original Query Analysis determined the question was [unrelated to index], it activates the Web search tool, generates the answer using the web results, and provides the Answer w/ web search.

In [None]:
from langchain_openai import ChatOpenAI

# Chat model reached through the OpenRouter endpoint via the OpenAI-compatible client.
remote_llm = ChatOpenAI(
    api_key=settings.OPENROUTER_API_KEY.get_secret_value(),  # type: ignore
    base_url=settings.OPENROUTER_URL,
    temperature=0.0,  # minimise sampling randomness for more repeatable answers
    model=RemoteModel.GEMINI_2_0_FLASH_001,
)


# Test the LLMs
# Smoke test: this performs a real network request to the model provider.
response = remote_llm.invoke("Tell me a very short joke.")
response.pretty_print()

### Load documents


In [None]:
from glob import glob

from langchain_community.document_loaders import PyPDFLoader
from langchain_core.documents.base import Document


def load_pdf_doc(filepath: str, engine: str = "pypdfloader") -> list[Document]:
    """Load a single PDF file into LangChain ``Document`` objects.

    Parameters
    ----------
    filepath : str
        Path to the PDF file.
    engine : str, default="pypdfloader"
        Loader backend. Only ``"pypdfloader"`` (``PyPDFLoader``) is supported.

    Returns
    -------
    list[Document]
        The documents produced by the loader.

    Raises
    ------
    ValueError
        If ``engine`` is not a supported loader name.
    """
    if engine != "pypdfloader":
        # Without this guard an unknown engine fell through to ``loader.load()``
        # and crashed with an UnboundLocalError; fail with a clear message instead.
        raise ValueError(f"Unsupported engine {engine!r}; expected 'pypdfloader'.")
    loader = PyPDFLoader(filepath)
    docs: list[Document] = loader.load()
    return docs


# Load a single doc
# Note: the filename contains a typographic apostrophe (’), not an ASCII quote.
fp: str = "../data/ai_news/Inside_San_Francisco’s_AI_school.pdf"
docs = load_pdf_doc(filepath=fp)
docs

In [None]:
# Load multiple docs
# ``glob`` returns paths in arbitrary, filesystem-dependent order; sort them so
# document order (and the previews below) is reproducible across machines and runs.
ai_filepaths: list[str] = sorted(glob("../data/ai_news/*.pdf"))
docs_ai = [doc for fp in ai_filepaths for doc in load_pdf_doc(filepath=fp)]
print(len(docs_ai))

football_filepaths: list[str] = sorted(glob("../data/football_news/*.pdf"))
docs_football = [doc for fp in football_filepaths for doc in load_pdf_doc(filepath=fp)]
print(len(docs_football))

In [None]:
# Preview the first five loaded Document objects (content + metadata).
docs_ai[:5]

In [None]:
# Pretty-print one Document with rich for easier reading of long text.
console.print(docs_ai[0])

### Set up document embeddings

In [None]:
import os
from typing import Any, List

import together
from langchain_core.embeddings import Embeddings
from langchain_core.utils import convert_to_secret_str
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    SecretStr,
    model_validator,
)


def set_together_api(value: str | None = None) -> SecretStr:
    """Wrap a Together API key in a ``SecretStr``.

    When ``value`` is ``None`` the key is read from the ``TOGETHER_API_KEY``
    environment variable, defaulting to an empty string if it is unset.
    """
    raw_key = os.getenv("TOGETHER_API_KEY", "") if value is None else value
    return convert_to_secret_str(raw_key)


class TogetherEmbeddings(BaseModel, Embeddings):
    """LangChain ``Embeddings`` implementation backed by the Together AI API.

    The API key comes from ``together_api_key`` (``str`` or ``SecretStr``) or,
    if omitted, from the ``TOGETHER_API_KEY`` environment variable. Unless a
    ``client`` is passed explicitly, a ``together.Together`` client is built
    from that resolved key.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    client: together.Together = Field(default_factory=together.Together)
    # Pass the factory function itself; wrapping it in ``lambda: set_together_api``
    # would make the default the function object rather than a SecretStr.
    together_api_key: SecretStr = Field(default_factory=set_together_api)
    model: str = Field(default="togethercomputer/m2-bert-80M-32k-retrieval")

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Resolve the API key and build the Together client before validation."""
        # Handle API key setup
        api_key = values.get("together_api_key") or os.getenv("TOGETHER_API_KEY", "")
        if isinstance(api_key, str):
            api_key = set_together_api(api_key)
        values["together_api_key"] = api_key

        # Build the client from the resolved key so an explicitly passed key is
        # honoured even when it is not in the environment. An empty key becomes
        # None so the SDK applies its usual missing-key handling.
        if values.get("client") is None:
            values["client"] = together.Together(
                api_key=api_key.get_secret_value() or None
            )

        # Set global API key
        together.api_key = api_key.get_secret_value()

        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs, returning one vector per entry in the API response."""
        if not texts:
            # Nothing to embed: skip the network round-trip entirely.
            return []
        response = self.client.embeddings.create(input=texts, model=self.model)
        return [item.embedding for item in response.data]  # type: ignore

    def embed_query(self, text: str) -> List[float]:
        """Embed query text (a batch of one)."""
        return self.embed_documents([text])[0]

In [None]:
# Test with a known working Together AI model
# Overrides the class default model. The next cell sizes the Qdrant vectors from
# this model's output length.
embeddings = TogetherEmbeddings(
    model="BAAI/bge-base-en-v1.5",  # Using known working model
    together_api_key=settings.TOGETHER_API_KEY,
)

# Test the embedding
try:
    test_text = "This is a test embedding"
    result = embeddings.embed_query(test_text)
    console.print(f"✅ Embedding successful! Dimension: {len(result)}", style="success")
    console.print(f"First 5 values: {result[:5]}", style="info")

except Exception as e:
    # Only reports the failure; later cells call `embeddings` again and would
    # fail there if the API is unreachable or the key is wrong.
    console.print(f"❌ Embedding failed: {e}", style="error")

In [None]:
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

# In-process Qdrant instance: its data lives only as long as this kernel.
client = QdrantClient(":memory:")

# Probe the embedding model once to size the collections' vectors.
vector_size = len(embeddings.embed_query("sample text"))


def _create_vectorstore(collection_name: str) -> QdrantVectorStore:
    """Create ``collection_name`` (cosine distance) if missing and wrap it in a LangChain store."""
    if not client.collection_exists(collection_name):
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )
    return QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=embeddings,
    )


collection_name_ai = "ai_news"
vectorstore_ai = _create_vectorstore(collection_name_ai)

collection_name_football = "football_news"
vectorstore_football = _create_vectorstore(collection_name_football)

### Split Documents Into Chunks

In [None]:
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Both corpora are chunked with the same settings; keep them in one place so
# the two splitters cannot drift apart.
_SPLITTER_KWARGS: dict[str, Any] = {
    "chunk_size": 500,  # chunk size (characters)
    "chunk_overlap": 50,  # chunk overlap (characters)
    "add_start_index": True,  # track index in original document
}

text_splitter_ai = RecursiveCharacterTextSplitter(**_SPLITTER_KWARGS)
all_splits_ai = text_splitter_ai.split_documents(docs_ai)

print(f"Split AI news into {len(all_splits_ai)} sub-documents.")


text_splitter_football = RecursiveCharacterTextSplitter(**_SPLITTER_KWARGS)
all_splits_football = text_splitter_football.split_documents(docs_football)

print(f"Split football news into {len(all_splits_football)} sub-documents.")

### Embed The Document Chunks And Add to Vector Store

In [None]:
def _add_if_empty(
    store: QdrantVectorStore, collection_name: str, documents: list[Document]
) -> list[str]:
    """Embed ``documents`` into ``store`` unless its collection already holds points.

    Re-running this cell would otherwise insert a second copy of every chunk,
    because no explicit ids are passed to ``add_documents``.
    """
    existing = client.count(collection_name=collection_name).count
    if existing:
        print(f"Skipping '{collection_name}': already contains {existing} points.")
        return []
    return store.add_documents(documents=documents)


document_ids_ai: list[str] = _add_if_empty(
    vectorstore_ai, collection_name_ai, all_splits_ai
)
print(document_ids_ai[:3])

document_ids_football: list[str] = _add_if_empty(
    vectorstore_football, collection_name_football, all_splits_football
)

# Kept for compatibility: `document_ids` previously ended up holding the football IDs.
document_ids: list[str] = document_ids_football

<br>

### Alternative Method

- Fewer lines of code

In [None]:
# Skipped by default (RUN = False): shows the one-call `from_documents`
# alternative to the manual collection setup + `add_documents` above.
RUN: bool = False

if RUN:
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,  # chunk size (characters)
        chunk_overlap=100,  # chunk overlap (characters)
        add_start_index=True,  # track index in original document
    )

    # `docs` is the single PDF loaded earlier, not the full AI/football corpora.
    all_splits = text_splitter.split_documents(docs)

    vector_store = QdrantVectorStore.from_documents(
        documents=all_splits,
        embedding=embeddings,
        location=":memory:",  # for in-memory Qdrant instance
        # OR specify url and api_key for external Qdrant instance
        # url="http://localhost:6333",  # assuming local Qdrant server is running
        # api_key="your_qdrant_api_key",  # if needed (for cloud instances)
        collection_name="test",
    )

In [None]:
# Classifier model name read from the app config.
# NOTE(review): the graph nodes in a later cell use a hard-coded
# `classifier_model = RemoteModel.GPT_OSS_20B` rather than this value;
# confirm which one is intended.
c_model: RemoteModel = RemoteModel(
    app_config.llm_model_config.classifier_model.model_name
)


# from src.utilities.vectorstores import VectorStoreSetup

# my_vectorstore_setup = VectorStoreSetup(
#     collection_name="test_collection",
#     filepaths=ai_filepaths,
#     embeddings=embeddings,
#     client=client,
# )
# my_vectorstore_setup.setup()
# vectorstore_instance = my_vectorstore_setup.embed_and_store()

### Test Retrieval from Vector Store

In [None]:
# Sanity-check retrieval: the 2 nearest chunks from the AI collection.
# NOTE(review): `query` becomes a notebook-wide global here; any later function
# that reads `query` without defining it locally will silently pick this up.
query: str = "Has Nvidia broken any laws?"

retrieved_docs = vectorstore_ai.similarity_search(query, k=2)
formatted_docs: str = "\n\n".join(
    (f"Source: {doc.metadata}\nContent: {doc.page_content}") for doc in retrieved_docs
)

console.print(formatted_docs)

### Convert Retriever To A Tool

- I eventually did NOT use this approach in the final implementation, but it's good to know how to do it.
- This is because when the agent uses the tool, it returns the output as a string, which makes it difficult to pass the retrieved documents to the LLM call node for further processing.


In [None]:
from langchain.tools import tool
from langchain_core.documents.base import Document


# The docstring below is left verbatim: @tool exposes it to the model as the
# tool description.
@tool(response_format="content_and_artifact")
def retriever_ai_tool(query: str) -> tuple[str, list[Document]]:
    """Retrieve information related to AI news to help answer a query.
    AI news include ANY info relating to: nvidia, openai, google, chinese tech news, etc.

    Parameters:
    -----------
    query: str
        The search query to retrieve relevant documents.

    Returns:
    --------
    tuple[str, list[Document]]
        A tuple containing formatted string of retrieved documents and the list of Document objects.
    """
    hits = vectorstore_ai.similarity_search(query, k=2)
    sections = [f"Source: {hit.metadata}\nContent: {hit.page_content}" for hit in hits]
    # (text shown to the model, raw Documents kept as the tool artifact)
    return "\n\n".join(sections), hits


# The docstring below is left verbatim: @tool exposes it to the model as the
# tool description.
@tool(response_format="content_and_artifact")
def retriever_football_tool(query: str) -> tuple[str, list[Document]]:
    """Retrieve information related to football news to help answer a query.

    Parameters:
    -----------
    query: str
        The search query to retrieve relevant documents.

    Returns:
    --------
    tuple[str, list[Document]]
        A tuple containing formatted string of retrieved documents and the list of Document objects.
    """
    matches = vectorstore_football.similarity_search(query, k=2)
    blocks: list[str] = []
    for match in matches:
        blocks.append(f"Source: {match.metadata}\nContent: {match.page_content}")
    text = "\n\n".join(blocks)
    return text, matches

In [None]:
# Let the model choose between the two retriever tools. Top-level `await`
# relies on IPython's autoawait support.
llm_with_tools = remote_llm.bind_tools([retriever_ai_tool, retriever_football_tool])
response = await llm_with_tools.ainvoke("Has Nvidia broken any laws?")

# Only shows which tool calls the model requested; the tools are not executed here.
response.tool_calls

In [None]:
# Inspect the full message object returned by the previous cell.
response

#### Create Additional Tools

In [None]:
from langchain_tavily import TavilySearch

# Web-search client (at most 2 results per query) used for the
# "unrelated to index" path.
tavily_search = TavilySearch(
    api_key=settings.TAVILY_API_KEY.get_secret_value(),
    max_results=2,
    topic="general",
)
search_response = tavily_search.invoke({"query": "Has Nvidia broken any laws?"})

In [None]:
# Content text of the first search hit (raises IndexError if there are no results).
search_response["results"][0]["content"]

In [None]:
# Full raw response returned by Tavily.
search_response

In [None]:
# NOTE: @tool exposes this docstring to the model as the tool description, so
# its stated default must match the signature (it previously said 1000).
@tool(response_format="content")
async def search_tool(query: str, max_chars: int = 500) -> str:
    """Perform a search using TavilySearch tool.

    Parameters:
    -----------
    query: str
        The search query.
    max_chars: int, default=500
        The maximum number of characters per source to return from the search results.

    Returns:
    --------
    str
        The formatted search results.
    """
    separator: str = "\n\n"

    tavily_search = TavilySearch(
        api_key=settings.TAVILY_API_KEY.get_secret_value(),
        max_results=3,
        topic="general",
    )
    search_response = await tavily_search.ainvoke({"query": query})

    def _format(result: dict[str, Any]) -> str:
        content: str = result["content"]
        # Only flag content that was actually cut short.
        marker = " [truncated]" if len(content) > max_chars else ""
        return (
            f"Title: {result['title']}\nContent: {content[:max_chars]}{marker}"
            f"\nURL: {result['url']}{separator}"
        )

    # A response without a "results" key yields an empty string instead of a KeyError.
    formatted_results: str = "\n\n".join(
        _format(result) for result in search_response.get("results", [])
    )
    return formatted_results

In [None]:
# Inspect the tool object the @tool decorator produced.
console.print(search_tool)

In [None]:
# `.coroutine` is the wrapped async function; calling it directly bypasses
# tool-invocation plumbing. Note this rebinds the notebook-global `response`.
response = await search_tool.coroutine("who is pope leo?")
console.print(response)

In [None]:
# NOTE(review): rebinds `search_response` from the raw Tavily dict (earlier cell)
# to this formatted string; earlier cells that index it will no longer work if re-run after this one.
search_response = await search_tool.coroutine("Has Nvidia broken any laws?")
console.print(search_response)

In [None]:
from enum import Enum

from langchain_core.messages import (
    HumanMessage,
    SystemMessage,
)


# ==================================================================
# ============================= TYPES ==============================
# ==================================================================
class YesOrNo(str, Enum):
    # Binary verdict used by the grading schemas. Because it subclasses `str`,
    # members compare equal to their plain string values ("yes"/"no").
    YES = "yes"
    NO = "no"

class DataSource(str, Enum):
    # Where classify_query_node routes a query: internal vector store or web search.
    VECTORSTORE = "vectorstore"
    WEBSEARCH = "websearch"


class VectorSearchType(str, Enum):
    # Which collection to search; the values also serve as the topic labels
    # interpolated into query_analysis_prompt.
    FOOTBALL = "football news"  # "(arsenal news | chelsea news | liverpool news)"
    AI = "ai news"  # "(ai news | ai browser | nvidia | openai| tech in china)"


# ==================================================================
# ============================ SCHEMAS =============================
# ==================================================================
class RouteQuerySchema(BaseModel):
    """Route query model."""

    # Produced by the classifier LLM in classify_query_node.
    data_source: DataSource = Field(description="The data source to use for the query.")


class VectorSearchTypeSchema(BaseModel):
    """Vector search type model."""

    # Chooses the collection that retrieve_documents queries.
    vector_search_type: VectorSearchType = Field(
        description="The vector search type to use for the query."
    )


class GradeRetrievalSchema(BaseModel):
    """Grade retrieval model."""

    # Per-document relevance verdict produced in grade_documents.
    is_relevant: YesOrNo = Field(
        description="Whether the retrieved documents are relevant to the user query."
    )


class GradeResponseSchema(BaseModel):
    """Grade response model."""

    # NOTE(review): not used in the visible cells; presumably intended for the
    # "Answers question?" check — confirm.
    is_relevant: YesOrNo = Field(
        description="Whether the response is relevant to the user query."
    )


class HallucinationSchema(BaseModel):
    """Check hallucination model."""

    # NOTE(review): hallucination_prompt tells the model to answer 'yes' when the
    # response is NOT relevant to the query, so in practice this field encodes
    # irrelevance rather than content unsupported by the retrieved documents.
    is_hallucinating: YesOrNo = Field(
        description="Whether the response contains hallucinations."
    )


# Plain-string views of the enums, interpolated into the prompts as
# `{topics}` and `{valid_output}`.
topics: list[str] = [_topic.value for _topic in VectorSearchType]
valid_output: list[str] = [_typ.value for _typ in YesOrNo]

In [None]:
# ==================================================================
# ============================ PROMPTS =============================
# ==================================================================
# System prompt for classify_query_node; `{topics}` is filled with the
# VectorSearchType values.
query_analysis_prompt: str = """
<SYSTEM>
    <ROLE>
        You're an expert at determining whether a user query requires information from a vector store or a web search.
    </ROLE>
    <TOPICS>{topics}</TOPICS>

    <GUIDELINES>
    - If the query is related to the topics above, choose 'vectorstore'.
    - If the query is not covered by the topics above, choose 'websearch'.
    - Base your decision solely on the content of the query.
    </GUIDELINES>
</SYSTEM> 
"""

# System prompt for grade_documents. It has only a `{valid_output}` placeholder;
# the document text reaches the grader through query_n_retrieved_docs_prompt.
retrieval_grading_prompt: str = """
<SYSTEM>
    <ROLE>
        You're an expert at determining whether the retrieved documents from a vector store is relevant to the user query.
    </ROLE>
    <VALID_OUTPUT>{valid_output}</VALID_OUTPUT>

    <GUIDELINES>
    - If the documents are relevant to the user query, choose 'yes'.
    - If the documents are not relevant to the user query, choose 'no'.
    </GUIDELINES>

</SYSTEM> 
"""

# User message for grade_documents (one document at a time).
query_n_retrieved_docs_prompt: str = """
<QUERY>{query}</QUERY>
<RETRIEVED_DOCUMENTS>{retrieved_documents}</RETRIEVED_DOCUMENTS>

Are the retrieved documents relevant to the user query?
"""

# User message for check_hallucination_node.
query_n_response_prompt: str = """
<QUERY>{query}</QUERY>
<RESPONSE>{response}</RESPONSE>

Is the response relevant to the user query?
"""

# Single prompt sent to remote_llm by generate_response.
rag_response_generator_prompt: str = """
<ROLE>
    You're an expert at generating accurate and concise answers to user queries based on retrieved documents.
</ROLE>

    <QUERY>{query}</QUERY>
    <RETRIEVED_DOCUMENTS>{retrieved_documents}</RETRIEVED_DOCUMENTS>

    <GUIDELINES>
    - Limit your summary to a maximum of 5 sentences.
    - Use only the information provided in the retrieved documents.
    </GUIDELINES>
"""

# System prompt for check_hallucination_node.
# NOTE(review): the guidelines ask about relevance to the query, and the node's
# messages contain only the query and the response (no retrieved documents), so
# this cannot detect claims unsupported by the sources. Confirm intent.
hallucination_prompt: str = """
<SYSTEM>
    <ROLE>
        You're an expert at determining whether the generated response is accurate and relevant to the user query.
    </ROLE>
    <VALID_OUTPUT>{valid_output}</VALID_OUTPUT>

    <GUIDELINES>
    - If the response is NOT relevant to the user query, choose 'yes'.
    - If the response is relevant to the user query, choose 'no'.
    </GUIDELINES>

</SYSTEM> 
"""

# Presumably used by a query-rewrite node not visible in this section.
query_rewriter_prompt: str = """
<ROLE>
    You're an expert at rewriting user queries to improve vector search retrieval.
</ROLE>

<ORIGINAL_QUERY>{original_query}</ORIGINAL_QUERY>

<GUIDELINES>
- Rewrite the query to be more specific and clear.
- Ensure the rewritten query captures the user's intent accurately.
- There must be no preamble, just the single rewritten query.
</GUIDELINES>
"""

# System prompt for generate_web_search_response.
websearch_prompt: str = """
<SYSTEM>
    <ROLE>
        You are an expert assistant specialized in generating a concise summary of web search results.
    </ROLE>

    <GUIDELINES>
    - Summarize the search results accurately and concisely.
    - Limit your summary to a maximum of 5 sentences.
    </GUIDELINES>

</SYSTEM>
"""

# User message for retrieve_documents. It lists no collections; the model picks
# from the VectorSearchTypeSchema enum values alone.
vectorstore_routing_prompt = """
<INSTR>
    Analyze this query and determine which retriever to use.
    <QUERY>{query}</QUERY>
</INSTR>
    """

In [None]:
# Check the topic labels interpolated into query_analysis_prompt.
topics

In [None]:
# ==================================================================
# ============================= TOOLS ==============================
# ==================================================================
# NOTE: redefines (and replaces) the earlier `search_tool`. @tool exposes the
# docstring below to the model as the tool description.
@tool(response_format="content")
async def search_tool(query: str, max_chars: int = 500) -> str:
    """Perform a search using TavilySearch tool.

    Parameters:
    -----------
    query: str
        The search query.
    max_chars: int, default=500
        The maximum number of characters per source to return from the search results.

    Returns:
    --------
    str
        The formatted search results.
    """
    separator: str = "\n\n"

    tavily_search = TavilySearch(
        api_key=settings.TAVILY_API_KEY.get_secret_value(),
        max_results=3,
        topic="general",
    )
    search_response = await tavily_search.ainvoke({"query": query})

    def _format(result: dict[str, Any]) -> str:
        content: str = result["content"]
        # Only flag content that was actually cut short.
        marker = " [truncated]" if len(content) > max_chars else ""
        return (
            f"Title: {result['title']}\nContent: {content[:max_chars]}{marker}"
            f"\nURL: {result['url']}{separator}"
        )

    # A response without a "results" key yields an empty string instead of a KeyError.
    formatted_results: str = "\n\n".join(
        _format(result) for result in search_response.get("results", [])
    )
    return formatted_results


# NOTE(review): not referenced by any node visible in this section, and it omits
# `search_tool`; confirm whether it is still needed.
tool_names = ["retriever_ai_tool", "retriever_football_tool"]
tool_names

<br>

### Define Workflow

In [None]:
# vectorstore_football.as_retriever(search_kwargs={"k": 5}).invoke("Any news about Caicedo's contract situation?")

### Structured Output

- For structured output, I decided to use `Instructor` because it performed better than LangChain's built-in Pydantic output parser in my tests.

In [None]:
import instructor
from langchain_core.messages import AIMessage
from langsmith import traceable
from openai import AsyncOpenAI

# Raw async OpenAI SDK client pointed at OpenRouter; Instructor wraps it below
# so completions are parsed into Pydantic response models.
_async_client = AsyncOpenAI(
    api_key=settings.OPENROUTER_API_KEY.get_secret_value(),
    base_url=settings.OPENROUTER_URL,
)

aclient = instructor.from_openai(
    _async_client, mode=instructor.Mode.OPENROUTER_STRUCTURED_OUTPUTS
)

# PEP 695 type alias (requires Python 3.12+): "any Pydantic model class".
type PydanticModel = type[BaseModel]


@traceable
async def get_structured_output(
    messages: list[dict[str, Any]],
    model: RemoteModel,
    schema: PydanticModel,
    max_retries: int = 5,
) -> BaseModel:
    """
    Retrieves structured output from a chat completion model.

    Parameters
    ----------
    messages : list[dict[str, Any]]
        OpenAI-style chat messages (``{"role": ..., "content": ...}``).
    model : RemoteModel
        The remote model to request through the OpenRouter-backed Instructor client.
    schema : PydanticModel
        The Pydantic model class the response must be parsed into.
    max_retries : int, default=5
        Forwarded to Instructor as its retry budget. It was previously hard-coded;
        the default keeps the old behaviour.

    Returns
    -------
    BaseModel
        An instance of the provided Pydantic schema containing the structured output.

    Notes
    -----
    This is an asynchronous function that awaits the completion of the API call.
    """
    return await aclient.chat.completions.create(
        model=model,
        response_model=schema,
        messages=messages,
        max_retries=max_retries,
    )


def convert_langchain_messages_to_dicts(
    messages: list[HumanMessage | SystemMessage | AIMessage],
) -> list[dict[str, str]]:
    """Convert LangChain messages to a list of dictionaries.

    Parameters
    ----------
    messages : list[HumanMessage | SystemMessage | AIMessage]
        List of LangChain message objects to convert.

    Returns
    -------
    list[dict[str, str]]
        List of dictionaries with 'role' and 'content' keys.
        Roles are mapped as follows (subclasses inherit their parent's role):
        - HumanMessage -> "user"
        - SystemMessage -> "system"
        - AIMessage -> "assistant"
        Any other message type defaults to "user".

    """
    role_mapping: dict[str, str] = {
        "SystemMessage": "system",
        "HumanMessage": "user",
        "AIMessage": "assistant",
    }

    def _role_for(msg: Any) -> str:
        # Walk the class hierarchy so subclasses (e.g. ``AIMessageChunk``,
        # which derives from ``AIMessage``) map to their parent's role instead
        # of silently falling back to "user".
        for klass in type(msg).__mro__:
            role = role_mapping.get(klass.__name__)
            if role is not None:
                return role
        return "user"  # Default to "user" if unknown

    return [{"role": _role_for(msg), "content": msg.content} for msg in messages]

In [None]:
class Person(BaseModel):
    # Toy schema for the structured-output smoke test. Field names become part
    # of the schema the model is asked to fill, so the former misspelling
    # `exeprience` leaked into the request and the parsed output.
    fullname: str
    salary: float
    experience: int


# Smoke test for get_structured_output: extract a Person from free text.
# Top-level `await` relies on IPython's autoawait support.
await get_structured_output(
    messages=[
        {
            "role": "user",
            "content": "Neidu Emmanuel, earning 30,000 has 4 years of experience",
        }
    ],
    model=RemoteModel.LLAMA_3_3_70B_INSTRUCT,
    schema=Person,
)

In [None]:
import operator as op
from typing import Annotated, TypedDict


# ==================================================================
# ================== CUSTOM REDUCER FOR DICT =======================
# ==================================================================
def merge_dicts(existing: dict[str, Any], new: dict[str, Any]) -> dict[str, Any]:
    """Merge two dictionaries, with new values updating existing ones."""
    if existing is None:
        return new
    # Update existing dict with new dict values
    return {**existing, **new}


# ==================================================================
# ============================= STATE ==============================
# ==================================================================
# Keys the nodes record under State["other_info"] (functional TypedDict form,
# equivalent to the class syntax). NOTE(review): State annotates other_info as
# dict[str, Any], so this type is not currently enforced anywhere visible.
OtherInfo = TypedDict(
    "OtherInfo",
    {
        "source_type": str,
        "retrieval_relevance": str,
        "is_hallucinating": str,
        "rewritten_query": str,
    },
)


class State(TypedDict):
    """Graph state shared by the node functions below."""

    query: str
    # Concatenated (``op.add``) across node updates. Entries are LangChain
    # message objects (llm_call_node) as well as plain strings
    # (retrieve_documents), so the element type is ``Any`` rather than ``str``.
    messages: Annotated[list[Any], op.add]
    runs: int
    other_info: Annotated[dict[str, Any], merge_dicts]  # Use custom merger
    documents: list[Document]
    response: str


# ==================================================================
# ============================= NODES ==============================
# ==================================================================
# Model used by the routing/grading nodes below.
# NOTE(review): hard-coded, while an earlier cell derives `c_model` from
# `app_config.llm_model_config.classifier_model`; the config value therefore
# appears to be ignored here — confirm which is intended.
classifier_model: RemoteModel = RemoteModel.GPT_OSS_20B


async def classify_query_node(state: State) -> dict[str, Any]:
    """Route the user's query to the vector store or to web search.

    Returns a partial state update setting ``other_info["source_type"]`` to the
    chosen ``DataSource`` value.
    """
    print("Calling ===> classify_query_node <===")

    user_query = state.get("query")
    conversation = [
        SystemMessage(content=query_analysis_prompt.format(topics=topics)),
        HumanMessage(content=user_query),
    ]
    routing: RouteQuerySchema = await get_structured_output(
        messages=convert_langchain_messages_to_dicts(conversation),
        model=classifier_model,
        schema=RouteQuerySchema,
    )
    source = routing.data_source.value

    print(f"✅ Classified query to use data source: {source}")
    return {"other_info": {"source_type": source}}


async def llm_call_node(state: State) -> dict[str, Any]:
    """Ask the chat model, with ``search_tool`` bound, to respond to the query.

    Uses ``state["query"]``; when that is empty, falls back to the most recent
    entry of ``state["messages"]``. Returns the resolved query plus the model's
    reply, which is appended to ``messages``.

    Raises
    ------
    ValueError
        If there is neither a query nor any message to respond to.
    """
    print("Calling ===> llm_call_node <===")

    query = state.get("query")
    if not query and "messages" in state:
        messages = state["messages"]
        if isinstance(messages, list):
            # Guard the empty list instead of raising IndexError.
            last = messages[-1] if messages else None
        else:
            last = messages
        # Entries may be message objects or plain strings; unwrap `.content` so
        # the `query` key stays text and the model receives a valid input.
        query = getattr(last, "content", last)

    if not query:
        raise ValueError("llm_call_node needs a non-empty 'query' or 'messages' entry.")

    llm_with_tools = remote_llm.bind_tools([search_tool])
    response = await llm_with_tools.ainvoke(query)

    return {
        "query": query,
        # Messages key is the default key for tools
        "messages": [response],
    }


async def generate_web_search_response(state: State) -> dict[str, Any]:
    """Summarise the latest web-search output into the final response.

    Reads the last entry of ``state["messages"]`` (the search results) and asks
    ``remote_llm`` to summarise it using ``websearch_prompt``. Returns a fallback
    message when there is nothing to summarise.
    """
    print("Calling ===> generate_web_search_response <===")

    # `query` must come from the state: without this line, the returned update
    # picked up the notebook-global `query` from an earlier cell and overwrote
    # the user's question with it.
    query = state.get("query")

    messages = state.get("messages") or []
    last = messages[-1] if messages else None
    # Entries may be message objects (with `.content`) or plain strings.
    message = getattr(last, "content", last)
    if not message:
        return {
            "response": "I couldn't find relevant information to answer your query."
        }
    sys_msg = SystemMessage(content=websearch_prompt)
    prompt: str = f"SEARCH RESULTS:\n{message}"

    response = await remote_llm.ainvoke([sys_msg, HumanMessage(content=prompt)])

    return {
        "query": query,
        "response": response.content,
    }


async def retrieve_documents(state: State) -> dict[str, Any]:
    """Retrieve documents by intelligently selecting the appropriate retriever.

    An LLM call (``classifier_model`` with ``VectorSearchTypeSchema``) picks the
    football or AI collection. The top 3 matches from that collection are
    returned under ``documents``, and a preview (first ``max_chars`` characters
    per document) is appended to ``messages``.
    """
    max_chars: int = 1_000
    print("Calling ===> retrieve_documents <===")

    query = state.get("query")
    prompt: str = vectorstore_routing_prompt.format(query=query)

    user_msg = {"role": "user", "content": prompt}

    # Keep the schema object and the extracted string under separate names so
    # each variable has a single type.
    routing: VectorSearchTypeSchema = await get_structured_output(
        messages=[user_msg],
        model=classifier_model,
        schema=VectorSearchTypeSchema,
    )
    retriever_choice: str = routing.vector_search_type.value

    print(f"✅ Retriever choice: {retriever_choice}")

    # Retrieve documents based on the routing decision
    if retriever_choice == VectorSearchType.FOOTBALL.value:
        retrieved_docs = vectorstore_football.similarity_search(query, k=3)
        print(f"✅ Used football retriever, found {len(retrieved_docs)} documents")
    elif retriever_choice == VectorSearchType.AI.value:
        retrieved_docs = vectorstore_ai.similarity_search(query, k=3)
        print(f"✅ Used AI retriever, found {len(retrieved_docs)} documents")
    else:
        return {"response": "I couldn't find the vectorstore to answer your query."}

    def _preview(doc: Document) -> str:
        content = doc.page_content
        # Only mark documents that were actually shortened.
        marker = " [truncated]" if len(content) > max_chars else ""
        return (
            f"Source: {doc.metadata.get('source', 'Unknown')}\n"
            f"Content: {content[:max_chars]}{marker}\n"
        )

    # Format documents for message display
    formatted_docs = "\n\n".join(_preview(doc) for doc in retrieved_docs)

    return {
        "query": query,
        "documents": retrieved_docs,
        "messages": [f"Retrieved {len(retrieved_docs)} documents:\n{formatted_docs}"],
    }


async def grade_documents(state: State) -> dict[str, Any]:
    """Keep only the retrieved documents the grader model judges relevant.

    Each document is graded on its own. ``other_info["retrieval_relevance"]`` is
    "yes" when at least one document survives, and "no" otherwise (including
    when there was nothing to grade).
    """
    print("Calling ===> grade_documents <===")

    query = state.get("query")
    candidates = state.get("documents", [])

    if not candidates:
        print("⚠️ No documents to grade")
        return {"other_info": {"retrieval_relevance": YesOrNo.NO.value}}

    async def _is_relevant(doc: Document) -> bool:
        doc_content = f"Source: {doc.metadata}\nContent: {doc.page_content}"
        # NOTE(review): retrieval_grading_prompt has no {retrieved_documents}
        # placeholder, so that keyword is ignored by format(); the document
        # text reaches the grader only via the user message.
        conversation = [
            SystemMessage(
                content=retrieval_grading_prompt.format(
                    valid_output=valid_output, retrieved_documents=doc_content
                )
            ),
            HumanMessage(
                content=query_n_retrieved_docs_prompt.format(
                    query=query, retrieved_documents=doc_content
                )
            ),
        ]
        verdict: GradeRetrievalSchema = await get_structured_output(
            messages=convert_langchain_messages_to_dicts(conversation),
            model=classifier_model,
            schema=GradeRetrievalSchema,
        )
        return verdict.is_relevant.value == YesOrNo.YES.value

    # Graded one at a time, in document order.
    kept: list[Document] = []
    for doc in candidates:
        if await _is_relevant(doc):
            kept.append(doc)

    print(f"✅ Graded documents: {len(kept)}/{len(candidates)} relevant")

    relevance = YesOrNo.YES.value if kept else YesOrNo.NO.value
    return {"documents": kept, "other_info": {"retrieval_relevance": relevance}}


def should_continue_to_retrieve(state: State) -> Literal["retrieve", "web_search"]:
    """Route after query analysis.

    Reads ``other_info["source_type"]`` (falling back to
    ``DataSource.WEBSEARCH`` when it is missing) and returns ``"retrieve"``
    only when it equals ``DataSource.VECTORSTORE.value``; any other value
    routes to ``"web_search"``.
    """
    info = state.get("other_info", {})
    selected = info.get("source_type", DataSource.WEBSEARCH)
    return "retrieve" if selected == DataSource.VECTORSTORE.value else "web_search"


def should_continue_to_generate(
    state: State,
) -> Literal["generate", "rewrite", "failed"]:  # type: ignore
    """Route after document grading.

    - Relevant documents (``retrieval_relevance == YesOrNo.YES.value``) always
      go to ``"generate"``. This holds even once the retry budget is spent, so
      the results of the final rewrite/retrieve/grade cycle are not discarded.
    - Otherwise the query is rewritten while ``runs <= max_runs``.
    - Once the budget is exhausted, the run goes to ``"failed"``.

    Termination stays bounded: generation is followed by the hallucination
    check, which increments ``runs`` and only loops back to rewriting while
    the budget allows.
    """
    max_runs = 3  # retry budget; same cap as the post-generation router
    relevance = state.get("other_info", {}).get("retrieval_relevance", YesOrNo.NO)
    runs: int = state.get("runs", 0)

    if relevance == YesOrNo.YES.value:
        return "generate"
    if runs <= max_runs:
        return "rewrite"
    return "failed"


async def generate_response(state: State) -> dict[str, Any]:
    """Generate an answer to the query from the retrieved documents.

    Parameters
    ----------
    state : State
        Graph state; reads ``query`` and ``documents``.

    Returns
    -------
    dict[str, Any]
        ``{"response": ...}``. This is either the content of the
        ``remote_llm`` reply, or a fixed fallback message when there are no
        documents (in which case no LLM call is made).
    """
    print("Calling ===> generate_response <===")

    query = state.get("query")
    documents = state.get("documents", [])

    if not documents:
        return {
            "response": "I couldn't find relevant information to answer your query."
        }

    # Label documents "Document 1", "Document 2", ... in the prompt.
    formatted_docs = "\n\n".join(
        f"Document {i}:\n{doc.page_content}" for i, doc in enumerate(documents, start=1)
    )
    prompt = rag_response_generator_prompt.format(
        query=query, retrieved_documents=formatted_docs
    )

    response = await remote_llm.ainvoke(prompt)

    return {"response": response.content}


async def check_hallucination_node(state: State) -> dict[str, Any]:
    """Ask the classifier whether the generated response is hallucinated.

    The messages built here contain only the query and the response
    (``hallucination_prompt`` + ``query_n_response_prompt``). The retrieved
    documents are not passed to the classifier by this node.

    Returns
    -------
    dict[str, Any]
        ``runs`` incremented by one, and ``other_info["is_hallucinating"]``
        set to the classifier's verdict value.
    """
    print("Calling ===> check_hallucination_node <===")

    user_query = state.get("query")
    current_runs: int = state.get("runs", 0)
    answer = state.get("response")

    system_text = hallucination_prompt.format(valid_output=valid_output)
    user_text = query_n_response_prompt.format(query=user_query, response=answer)
    chat = convert_langchain_messages_to_dicts(
        [SystemMessage(content=system_text), HumanMessage(content=user_text)]
    )
    verdict: HallucinationSchema = await get_structured_output(
        messages=chat,
        model=classifier_model,
        schema=HallucinationSchema,
    )

    flag = verdict.is_hallucinating.value
    print(f"✅ Hallucination check: {flag}")

    return {"runs": current_runs + 1, "other_info": {"is_hallucinating": flag}}


def should_continue_to_final_answer(
    state: State,
) -> Literal["answer", "rewrite", "failed"]:  # type: ignore
    """Route after the hallucination check.

    - A response judged not hallucinating (``is_hallucinating ==
      YesOrNo.NO.value``) always goes to ``"answer"``, even when the retry
      budget is spent. A grounded answer is not labelled a failure.
    - Otherwise the query is rewritten while ``runs <= max_runs``.
    - Once the budget is exhausted, the run goes to ``"failed"``.

    A missing verdict defaults to ``YesOrNo.YES``, so it is never treated as
    an accepted answer.
    """
    max_runs = 3  # retry budget; same cap as the post-grading router
    is_hallucinating = state.get("other_info", {}).get("is_hallucinating", YesOrNo.YES)
    runs: int = state.get("runs", 0)

    if is_hallucinating == YesOrNo.NO.value:
        return "answer"
    if runs <= max_runs:
        return "rewrite"
    return "failed"


async def rewrite_query(state: State) -> dict[str, Any]:
    """Ask ``remote_llm`` for a reformulated version of the current query.

    Returns
    -------
    dict[str, Any]
        ``query`` (unchanged), ``runs`` incremented by one, and
        ``other_info["rewritten_query"]`` holding the LLM's reply content.
    """
    print("Calling ===> rewrite_query <===")
    attempt: int = state.get("runs", 0)

    original = state.get("query")
    llm_reply = await remote_llm.ainvoke(
        query_rewriter_prompt.format(original_query=original)
    )
    rewritten = llm_reply.content

    next_attempt = attempt + 1
    print(f"Original: {original}\nRewritten: {rewritten}")
    print(f"⚠️ Runs: {next_attempt}")

    # NOTE(review): "query" is returned unchanged; the rewrite only lives in
    # other_info["rewritten_query"]. Confirm the retrieve node reads that key,
    # otherwise the rewrite has no effect on retrieval.
    return {
        "query": original,
        "runs": next_attempt,
        "other_info": {"rewritten_query": rewritten},
    }


def failed_node(state: State) -> dict[str, Any]:
    """Terminal node for runs that did not end with an accepted answer.

    Passes through any ``response`` already present in the state; when the
    key is absent, a fixed fallback message is returned instead.
    """
    print("Calling ===> failed_node <===")
    fallback = "I couldn't find relevant information to answer your query."
    # NOTE(review): a response left in state by an earlier generate step is
    # returned as-is, even if the hallucination check flagged it. Confirm
    # this is intended.
    return {"response": state.get("response", fallback)}


def answer_node(state: State) -> dict[str, Any]:
    """Terminal node for an accepted answer.

    Returns the state's ``response``, or a fixed fallback message when the
    key is absent.
    """
    print("Calling ===> answer_node <===")
    default_reply = "I couldn't find relevant information to answer your query."
    final_text = state.get("response", default_reply)
    return {"response": final_text}

In [None]:
from IPython.display import Image, Markdown, display
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.types import RetryPolicy

In [None]:
# from langgraph.prebuilt import ToolNode, tools_condition
# from langgraph.types import RetryPolicy

In [None]:
builder: StateGraph = StateGraph(State)

# Retry settings shared by every node that opts into retries below:
# up to 3 attempts, initial_interval=1.0 seconds between the first retries.
default_retry_policy = RetryPolicy(max_attempts=3, initial_interval=1.0)

# --- Nodes ---------------------------------------------------------------
tool_node = ToolNode([search_tool])

builder.add_node(
    "query_analysis", classify_query_node, retry_policy=default_retry_policy
)
builder.add_node("tools", tool_node, retry_policy=default_retry_policy)
builder.add_node("llm_call_node", llm_call_node)
builder.add_node("retrieve", retrieve_documents)
builder.add_node("grade", grade_documents, retry_policy=default_retry_policy)
builder.add_node("generate", generate_response, retry_policy=default_retry_policy)
builder.add_node("rewrite", rewrite_query, retry_policy=default_retry_policy)
builder.add_node(
    "check_hallucination",
    check_hallucination_node,
    retry_policy=default_retry_policy,
)
builder.add_node(
    "web_search_response",
    generate_web_search_response,
    retry_policy=default_retry_policy,
)
builder.add_node("answer", answer_node)
builder.add_node("failed", failed_node)

# --- Edges ---------------------------------------------------------------
builder.add_edge(START, "query_analysis")
# Vectorstore queries go to retrieval; everything else to the tool-calling LLM.
builder.add_conditional_edges(
    "query_analysis",
    should_continue_to_retrieve,
    {"retrieve": "retrieve", "web_search": "llm_call_node"},
)
# tools_condition picks "tools" when the LLM requested a tool call; otherwise
# the web-search path ends in the "failed" node.
builder.add_conditional_edges(
    "llm_call_node",
    tools_condition,
    {"tools": "tools", END: "failed"},
)
builder.add_edge("tools", "web_search_response")
builder.add_edge("web_search_response", END)
builder.add_edge("retrieve", "grade")
builder.add_conditional_edges(
    "grade",
    should_continue_to_generate,
    {"generate": "generate", "rewrite": "rewrite", "failed": "failed"},
)
builder.add_edge("generate", "check_hallucination")
builder.add_conditional_edges(
    "check_hallucination",
    should_continue_to_final_answer,
    {"answer": "answer", "rewrite": "rewrite", "failed": "failed"},
)
builder.add_edge("rewrite", "retrieve")
# Both terminal nodes end the run explicitly.
builder.add_edge("answer", END)
builder.add_edge("failed", END)

# --- Compile -------------------------------------------------------------
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)

# --- Visualize: Mermaid PNG, then ASCII, then a plain node/edge listing ---
try:
    display(Image(graph.get_graph(xray=1).draw_mermaid_png()))
except Exception as e:
    console.print(f"[yellow]PNG visualization failed: {e}[/yellow]")
    console.print("[cyan]Displaying ASCII representation instead:[/cyan]\n")
    try:
        print(graph.get_graph(xray=1).draw_ascii())
    except ImportError as ie:
        console.print(f"[red]ASCII visualization also failed: {ie}[/red]")
        console.print("[magenta]Showing basic graph structure:[/magenta]\n")
        graph_obj = graph.get_graph(xray=1)
        node_ids = [node.id for node in graph_obj.nodes.values()]
        edge_pairs = [(edge.source, edge.target) for edge in graph_obj.edges]
        console.print(f"Nodes: {node_ids}")
        console.print(f"Edges: {edge_pairs}")

In [None]:
# Chat model used by the LLM-calling nodes above (e.g. generate_response,
# rewrite_query): a ChatOpenAI client configured with the OpenRouter key and
# base URL from `settings`.
# NOTE(review): this cell runs after the graph was compiled; those nodes look
# up the global `remote_llm` at call time, so this definition is the one later
# graph invocations use. Consider moving it above the graph cell so
# Restart & Run All does not depend on an earlier definition.
remote_llm = ChatOpenAI(
    api_key=settings.OPENROUTER_API_KEY.get_secret_value(),  # type: ignore
    base_url=settings.OPENROUTER_URL,
    temperature=0.0,  # minimise sampling randomness
    model=RemoteModel.GPT_OSS_20B,
)


# Smoke test: one short completion to check the client responds.
response = remote_llm.invoke("Tell me a very short joke.")
response.pretty_print()

In [None]:
# Re-build the graph with a fresh in-memory checkpointer, so no checkpointed
# state from earlier invocations is reused.
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)

# thread_id selects which checkpoint thread this invocation reads and writes.
config: dict[str, Any] = {"configurable": {"thread_id": "test-01"}}
response = await graph.ainvoke(
    {"query": "Any news on Liverpool's player meetings?"},
    config=config,
)

In [None]:
# Inspect the full final state returned by graph.ainvoke.
response

In [None]:
# Render only the final answer text as Markdown.
Markdown(f"### Final Response:\n\n{response['response']}")

In [None]:
# `go_up` is keyword-only in go_up_from_current_directory (`*, go_up`); a
# positional argument raises TypeError.
go_up_from_current_directory(go_up=2)