In [1]:
import json
import sys
import warnings
from pathlib import Path
from typing import Annotated, Any, Callable, Coroutine, Literal

import numpy as np
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

custom_theme = Theme(
    {
        "white": "#FFFFFF",  # Bright white
        "info": "#00FF00",  # Bright green
        "warning": "#FFD700",  # Bright gold
        "error": "#FF1493",  # Deep pink
        "success": "#00FFFF",  # Cyan
        "highlight": "#FF4500",  # Orange-red
    }
)
console = Console(theme=custom_theme)

# Visualization
# import matplotlib.pyplot as plt

# NumPy settings
np.set_printoptions(precision=4)

# Pandas settings
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

# Polars settings
pl.Config.set_fmt_str_lengths(1_000)
pl.Config.set_tbl_cols(n=1_000)
pl.Config.set_tbl_rows(n=200)

warnings.filterwarnings("ignore")

# Black code formatter (Optional)
%load_ext lab_black

# auto reload imports
%load_ext autoreload
%autoreload 2

In [2]:
def go_up_from_current_directory(*, go_up: int = 1) -> None:
    """This is used to up a number of directories.

    Params:
    -------
    go_up: int, default=1
        This indicates the number of times to go back up from the current directory.

    Returns:
    --------
    None
    """
    import os

    CONST: str = "../"
    NUM: str = CONST * go_up

    # Goto the previous directory
    prev_directory = os.path.join(os.path.dirname(__name__), NUM)
    # Get the 'absolute path' of the previous directory
    abs_path_prev_directory = os.path.abspath(prev_directory)

    # Add the path to the System paths
    sys.path.insert(0, abs_path_prev_directory)
    print(abs_path_prev_directory)


# Demo (Prevents ruff from removing the unused module import)
my_coroutine: Coroutine
my_path: Path = Path(".")
obj: Annotated[list[Any], "This is an annotated list of any type"]
category: Literal["A", "B", "C"]
json.loads('{"name": "Smart-RAG", "version": "1.0"}')

{'name': 'Smart-RAG', 'version': '1.0'}

In [3]:
go_up_from_current_directory(go_up=1)

from src.config import app_settings  # noqa: E402 # type: ignore
from src.utilities.model_config import RemoteModel  # noqa: E402 # type: ignore

settings = app_settings

/Users/mac/Desktop/Projects/smart-rag


In [4]:
from langchain_openai import ChatOpenAI

remote_llm = ChatOpenAI(
    api_key=settings.OPENROUTER_API_KEY.get_secret_value(),  # type: ignore
    base_url=settings.OPENROUTER_URL,
    temperature=0.0,
    seed=1,
    model=RemoteModel.GPT_OSS_120B,
)


# Test the LLMs
response = remote_llm.invoke("Tell me a very short joke.")
response.pretty_print()


Why don’t scientists trust atoms?  

Because they make up everything.


In [5]:
console.print(response)

In [6]:
import asyncio

import uvloop

# Use Uvloop's implementation (Place this at the entrypoint)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

In [7]:
import httpx


class HTTPXClient:
    def __init__(
        self,
        base_url: str = "",
        timeout: int = 30,
        http2: bool = True,
        max_connections: int = 20,
        max_keepalive_connections: int = 5,
    ) -> None:
        self.base_url = base_url
        self.timeout = timeout
        self.http2 = http2
        self.max_connections = max_connections
        self.max_keepalive_connections = max_keepalive_connections
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout,
            http2=self.http2,
            limits=httpx.Limits(
                max_connections=self.max_connections,
                max_keepalive_connections=self.max_keepalive_connections,
            ),
        )

    async def __aenter__(self) -> "HTTPXClient":
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        await self.client.aclose()

    async def get(
        self,
        url: str,
        params: dict[str, Any] | None = None,
        headers: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Perform an asynchronous GET request."""
        try:
            response = await self.client.get(url, params=params, headers=headers)
            return self._parse_response(response)
        except Exception as e:
            return self._handle_exception(e)

    async def post(
        self,
        url: str,
        data: dict[str, Any] | None = None,
        params: dict[str, Any] | None = None,
        headers: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Perform an asynchronous POST request."""
        try:
            response = await self.client.post(
                url, data=data, params=params, headers=headers
            )
            return self._parse_response(response)
        except Exception as e:
            return self._handle_exception(e)

    def _parse_response(self, response: httpx.Response) -> dict[str, Any]:
        """Parse the HTTPX response and return a standardized dictionary."""
        try:
            data = response.json()
        except json.JSONDecodeError:
            data = response.text

        return {
            "success": response.status_code < 400,
            "status_code": response.status_code,
            "data": data,
            "headers": dict(response.headers),
            "error": (
                None
                if response.status_code < 400
                else f"HTTP {response.status_code} Error"
            ),
        }

    def _handle_exception(self, e: Exception) -> dict[str, Any]:
        """Handle exceptions and return a standardized error response."""
        if isinstance(e, httpx.ConnectError):
            error_msg = f"Connection Error: {str(e)}"
        elif isinstance(e, httpx.TimeoutException):
            error_msg = f"Request Timeout: {str(e)}"
        else:
            error_msg = f"Unexpected Error: {str(e)}"

        return {
            "success": False,
            "status_code": None,
            "data": None,
            "headers": None,
            "error": error_msg,
        }

In [8]:
async with HTTPXClient() as client:  # type: ignore
    response = await client.get(
        "https://www.bbc.com/sport/football/articles/cwy543n274wo"
    )
    print(response)



In [9]:
response["data"]



In [10]:
from markdownify import markdownify as md

console.print(md(response["data"])[3000:5000])

<br>

# RAG Pipeline


## Step 0

- Download and prepare your documents.

In [11]:
import re
import unicodedata
from pathlib import Path
from typing import Any

from bs4 import BeautifulSoup
from markdownify import markdownify as md

_ESCAPE_RE: re.Pattern[str] = re.compile(
    r"(?:\\x[0-9a-fA-F]{2}|\\u[0-9a-fA-F]{4}|\\U[0-9a-fA-F]{8})"
)
_ZERO_WIDTH: set[str] = {"\u200b", "\u200c", "\u200d", "\u2060", "\ufeff"}


def normalize_header_string(text: str) -> str:
    """Normalize header-like strings with minimal, safe transforms.

    Applies targeted unicode-escape decoding when present, replaces NBSP, removes
    zero-width characters, normalizes (NFKC), and collapses whitespace.
    """
    # Targeted backslash-escape decoding (avoid decoding unrelated backslashes)
    if _ESCAPE_RE.search(text):

        def _sub(m: re.Match[str]) -> str:
            token = m.group(0)
            try:
                return token.encode().decode("unicode_escape")
            except Exception:
                return token

        text = _ESCAPE_RE.sub(_sub, text)

    # Replace non‑breaking space with normal space
    text = text.replace("\u00a0", " ")

    # Remove zero‑width characters
    if any(ch in text for ch in _ZERO_WIDTH):
        text = "".join(ch for ch in text if ch not in _ZERO_WIDTH)

    # Unicode normalize and collapse whitespace
    text = unicodedata.normalize("NFKC", text)
    return " ".join(text.split())


def clean_xbrl_noise(text: str) -> str:
    """Aggressively remove XBRL noise while preserving document structure.

    This function removes all XBRL/XML metadata and keeps only the meaningful
    HTML content that can be converted to readable markdown.
    """

    body_match = re.search(r"<body[^>]*>(.*)</body>", text, re.DOTALL | re.IGNORECASE)
    if body_match:
        text = "<body>" + body_match.group(1) + "</body>"

    try:
        soup = BeautifulSoup(text, "html.parser")

        # Remove <head> entirely - it contains most XBRL metadata
        for head in soup.find_all("head"):
            head.decompose()

        # Remove all script and style tags
        for tag in soup(["script", "style", "meta", "link"]):
            tag.decompose()

        # Remove XML/XBRL namespaced elements (tags with colons)
        for tag in soup.find_all():
            if tag.name and ":" in tag.name:
                tag.decompose()

        # Remove hidden XBRL data elements (usually display:none or specific XBRL classes)
        for tag in soup.find_all(style=re.compile(r"display:\s*none", re.I)):
            tag.decompose()

        for tag in soup.find_all(class_=re.compile(r"xbrl|hidden", re.I)):
            tag.decompose()

        # Remove specific XBRL attribute clutter
        for tag in soup.find_all():
            if tag.name:
                # Remove XBRL attributes
                attrs_to_remove = []
                for attr in tag.attrs:
                    if (
                        ":" in attr
                        or attr.startswith("xmlns")
                        or attr in ["contextref", "unitref", "decimals"]
                    ):
                        attrs_to_remove.append(attr)  # noqa: PERF401
                for attr in attrs_to_remove:
                    del tag[attr]

        # Get the cleaned HTML
        cleaned: str = str(soup)

    except Exception as e:
        print(f"Warning: HTML parsing failed: {e}")
        cleaned = text

    # Post-processing regex cleanup for any remaining XBRL noise

    # Remove namespace URLs that got left behind
    cleaned = re.sub(
        r'http://[^\s<>"]+(?:xbrl|fasb|sec\.gov)[^\s<>"]*', "", cleaned, flags=re.I
    )

    # Remove XBRL namespace tokens (us-gaap:Something, iso4217:USD, etc.)
    cleaned = re.sub(
        r"\b(?:us-gaap|nvda|srt|stpr|fasb|xbrli|iso4217|xbrl|dei|ix|country|xbrldi|link):[A-Za-z0-9_\-:()]+(?:Member)?\b",
        "",
        cleaned,
        flags=re.I,
    )

    # Remove long numeric strings (CIK numbers, etc.) - 10+ digits
    cleaned = re.sub(r"\b\d{10,}\b", "", cleaned)
    # Remove date patterns that are concatenated without separators (2023-01-292022-01-30)
    cleaned = re.sub(r"(?:\d{4}-\d{2}-\d{2}){2,}", "", cleaned)
    # Remove very long alphanumeric strings (40+ chars) that indicate concatenated tags
    cleaned = re.sub(r"\b[A-Za-z0-9_\-]{40,}\b", "", cleaned)
    # Remove XML/namespace declarations
    cleaned = re.sub(r'xmlns[:\w]*="[^"]*"', "", cleaned)
    cleaned = re.sub(r'xml:\w+="[^"]*"', "", cleaned)
    # Remove "pure" standalone (XBRL unit)
    cleaned = re.sub(r"\bpure\b(?!\s+\w)", "", cleaned)
    # Clean up multiple colons and extra punctuation
    cleaned = re.sub(r":{2,}", ":", cleaned)
    return re.sub(r"\s*:\s*:\s*", " ", cleaned)


async def download_and_parse_data(
    url: str,
    raw_doc_path: Path | str,
    cleaned_doc_path: Path | str,
    force_download: bool = False,
) -> None:
    """Download and parse HTML/XBRL documents with aggressive noise removal.

    Parameters
    ----------
        url : str
            The remote URL to download
        raw_doc_path : Path | str
            Output path for the raw bytes/text
        cleaned_doc_path : Path | str
            Output path for the cleaned markdown/text
        force_download : bool, default=False
            When True, re-download and re-clean even if file(s) exist

    Returns
    -------
        None
    """
    if isinstance(raw_doc_path, str):
        raw_doc_path = Path(raw_doc_path)
    if isinstance(cleaned_doc_path, str):
        cleaned_doc_path = Path(cleaned_doc_path)

    # Safe, identifiable user agent:
    USER_AGENT: str = (
        "MyCompany MyDownloader/1.0 (+https://mycompany.example; dev@mycompany.example)"
    )
    headers: dict[str, str] = {"User-Agent": USER_AGENT, "Accept": "application/json"}

    # If raw document exists and we are not forcing re-download
    if raw_doc_path.exists() and raw_doc_path.is_file() and not force_download:
        print(f"Raw file already exists: {raw_doc_path}. Skipping download.")
    else:
        # Ensure the path exists
        raw_doc_path.parent.mkdir(parents=True, exist_ok=True)

        async with HTTPXClient() as client:
            response: dict[str, Any] = await client.get(url, headers=headers)

        if not response["success"]:
            print(f"Failed to download {url}: {response.get('error')}")
            return

        # Response data may be a dict or string; store as text
        raw_content: Any = response["data"]
        if not isinstance(raw_content, str):
            # Coerce to text safely
            try:
                raw_content = json.dumps(raw_content, ensure_ascii=False)
            except Exception:
                raw_content = str(raw_content)

        raw_doc_path.write_text(raw_content, encoding="utf-8")
        print(f"Saved raw content to {raw_doc_path}")

    # Convert the raw HTML/text into a cleaned markdown or plain text
    raw_text: str = raw_doc_path.read_text(encoding="utf-8")

    # Use the aggressive cleaner to remove XBRL noise
    cleaned_html = clean_xbrl_noise(raw_text)

    # For HTML content, convert to markdown with better formatting
    try:
        # Configure markdownify to preserve more structure
        cleaned_text: str = md(
            cleaned_html,
            heading_style="ATX",  # Use # for headers
            bullets="-",  # Use - for bullet points
            strong_em_symbol="**",  # Use ** for bold
            strip=["script", "style"],  # Remove script and style tags
        )
    except Exception as e:
        # If markdownify fails, try basic text extraction
        print(f"Warning: Markdown conversion failed: {e}")
        try:
            soup = BeautifulSoup(cleaned_html, "html.parser")
            cleaned_text = soup.get_text("\n", strip=True)
        except Exception:
            cleaned_text = cleaned_html

    # Post-processing cleanup on the markdown text
    # Remove lines that are mostly XBRL noise (lots of colons, short tokens)
    lines: list[str] = cleaned_text.split("\n")
    cleaned_lines: list[str] = []
    for line in lines:
        # Skip lines with excessive XBRL patterns
        if len(line) < 10:  # Keep very short lines (might be intentional)
            cleaned_lines.append(line)
            continue

        # Count suspicious patterns
        colon_count = line.count(":")
        token_count = len(
            re.findall(r"\b[A-Z][a-z]+(?:[A-Z][a-z]+)+\b", line)
        )  # CamelCase tokens

        # If line has too many colons or camelCase tokens relative to length, skip it
        if colon_count > len(line) / 20 or (token_count > 5 and len(line.split()) < 20):
            continue

        cleaned_lines.append(line)

    cleaned_text = "\n".join(cleaned_lines)

    # Remove excessive blank lines (more than 2 consecutive)
    cleaned_text = re.sub(r"\n{3,}", "\n\n", cleaned_text)

    # Remove leading/trailing whitespace from each line
    cleaned_text = "\n".join(line.strip() for line in cleaned_text.split("\n"))

    # Final whitespace cleanup
    cleaned_text = cleaned_text.strip()

    # Ensure the path exists
    cleaned_doc_path.parent.mkdir(parents=True, exist_ok=True)
    cleaned_doc_path.write_text(cleaned_text, encoding="utf-8")

    print(f"Saved cleaned content to {cleaned_doc_path}")
    return

In [12]:
url: str = "https://www.sec.gov/Archives/edgar/data/1045810/000104581023000017/nvda-20230129.htm"

await download_and_parse_data(
    url=url, raw_doc_path="raw_doc.txt", cleaned_doc_path="cleaned_doc.txt"
)

Raw file already exists: raw_doc.txt. Skipping download.
Saved cleaned content to cleaned_doc.txt


In [13]:
fp: str = "cleaned_doc.txt"

with Path(fp).open("r", encoding="utf-8") as file:
    cleaned_doc = file.read()

In [14]:
console.print(cleaned_doc[500:1_500])

<br>

## Step 1
- Split documents into chunks (with metadata) using text splitter.
- Create embeddings.

In [15]:
# from langchain_community.document_loaders import CSVLoader
from langchain_community.document_loaders import TextLoader

loader = TextLoader(fp)  # Integration-specific parameters here

# Load all documents
documents = loader.load()

# For large datasets, lazily load documents
# for document in loader.lazy_load():
#     print(document)

In [16]:
len(documents)

1

In [17]:
from re import Match, Pattern

# Extract 10-K sections with title and content separately (line-by-line comments)
# Get the entire document text from the TextLoader's first document
raw_text: str = documents[0].page_content  # the string to search for ITEM headers

# Header pattern: match 'ITEM 1.' or 'ITEM 1A.' etc. at the beginning of a line
# ^\s*            -> allow leading whitespace before the header
# ITEM\s+         -> the literal word ITEM followed by at least one space
# \d+             -> the item number (one or more digits)
# [A-Z]?           -> optional letter (A, B, etc.) after the number
# \.               -> period following the number (escaped dot)
# [\t ]+          -> at least one whitespace char (tab/space) after the dot
# [^\n\r]*        -> the remainder of the heading line (until newline)
# re.MULTILINE     -> ^ anchors at the beginning of each line
header_pattern: Pattern[str] = re.compile(
    r"^\s*(ITEM\s+\d+[A-Z]?\.[\t ]+[^\n\r]*)", re.MULTILINE
)

# run finditer which returns match objects with start()/end() locations
matches: list[Match[str]] = list(
    header_pattern.finditer(raw_text)
)  # convert to list for indexing

# Prepare lists to hold the results
section_titles: list[str] = []  # will store the header lines like 'ITEM 1. BUSINESS'
# will store the textual content of each section (no header)
section_content: list[str] = []

# Walk through each header match, capturing both title and the content after it
for i, match in enumerate(matches):
    title: str = match.group(1).strip()  # capture the heading text and strip whitespace
    # Normalize the header to handle NBSP/zero-width and consistent spacing
    title = normalize_header_string(title)
    section_titles.append(title)

    # The content begins right after the matched heading line
    start_pos: int = match.end()  # numeric index where this header finishes

    # Determine where this section ends: next header start or the end of the document
    if i + 1 < len(matches):
        end_pos: int = matches[i + 1].start()  # next header's start position
    else:
        end_pos: int = len(raw_text)  # or EOF if this is the last header

    # Use the start/end slices to get the body text and strip leading/trailing whitespace
    content: str = raw_text[start_pos:end_pos].strip()  # remove extra whitespace
    section_content.append(content)  # store the cleaned body in the sections list

# Confirmation print for quick inspection when the cell runs
print(f"Found {len(section_titles)} ITEM sections.")

Found 21 ITEM sections.


### Create Metadata-rich Chunks

- Split documents using the section headers (title, subsection, etc.) and contents as references.
- Using each section and corresponding title as metadata, create Document objects for each chunk.
- This ensures that each chunk retains context about its origin within the larger document.

In [18]:
from uuid import uuid4

from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1_000,  # chunk size (characters)
    chunk_overlap=100,  # chunk overlap (characters)
    add_start_index=True,  # track index in original document
)

doc_chunks_with_metadata: list[Document] = []

# Loop thru each section's content and its title
for title, content in zip(section_titles, section_content):
    section_chunks: list[str] = text_splitter.split_text(content)

    # Loop thru each chunk to add metadata
    for chunk in section_chunks:
        chunk_id: str = str(uuid4())  # unique ID for this chunk
        doc_chunks_with_metadata.append(  # noqa: PERF401
            Document(
                page_content=chunk,
                metadata={
                    "source_doc": fp,  # original document path
                    # Ensure section titles are normalized in metadata
                    "section": normalize_header_string(title),
                    "chunk_id": chunk_id,  # unique chunk ID
                },
            )
        )

print(f"Created {len(doc_chunks_with_metadata)} document chunks with metadata.")

Created 374 document chunks with metadata.


In [19]:
console.print(doc_chunks_with_metadata[51])

In [20]:
section_titles

['ITEM 1. BUSINESS',
 'ITEM 1A. RISK FACTORS',
 'ITEM 1B. UNRESOLVED STAFF COMMENTS',
 'ITEM 2. PROPERTIES',
 'ITEM 3. LEGAL PROCEEDINGS',
 'ITEM 4. MINE SAFETY DISCLOSURES',
 'ITEM 5. MARKET FOR REGISTRANT’S COMMON EQUITY, RELATED STOCKHOLDER MATTERS AND ISSUER PURCHASES OF EQUITY SECURITIES',
 'ITEM 6. [RESERVED]',
 'ITEM 7. MANAGEMENT’S DISCUSSION AND ANALYSIS OF FINANCIAL CONDITION AND RESULTS OF OPERATIONS',
 'ITEM 7A. QUANTITATIVE AND QUALITATIVE DISCLOSURES ABOUT MARKET RISK',
 'ITEM 8. FINANCIAL STATEMENTS AND SUPPLEMENTARY DATA',
 'ITEM 9. CHANGES IN AND DISAGREEMENTS WITH ACCOUNTANTS ON ACCOUNTING AND FINANCIAL DISCLOSURE',
 'ITEM 9A. CONTROLS AND PROCEDURES',
 'ITEM 9C. DISCLOSURE REGARDING FOREIGN JURISDICTIONS THAT PREVENT INSPECTIONS',
 'ITEM 10. DIRECTORS, EXECUTIVE OFFICERS AND CORPORATE GOVERNANCE',
 'ITEM 11. EXECUTIVE COMPENSATION',
 'ITEM 12. SECURITY OWNERSHIP OF CERTAIN BENEFICIAL OWNERS AND MANAGEMENT AND RELATED STOCKHOLDER MATTERS',
 'ITEM 13. CERTAIN RELATIONS

In [21]:
# Test the Metadata-aware chunking: e.g. 'Risk Factors' should be in the section
sample_chunk = (
    chunk
    for chunk in doc_chunks_with_metadata
    if "risk factors" in chunk.metadata.get("section", "").lower()
)
console.print(next(sample_chunk))

In [22]:
console.print(next(sample_chunk))

In [23]:
import os
from typing import Any

from langchain_core.embeddings import Embeddings
from langchain_core.utils import convert_to_secret_str
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    SecretStr,
    model_validator,
)

from src.utilities.openrouter.client import AsyncOpenRouterClient, OpenRouterClient


def set_openrouter_api(value: str | None = None) -> SecretStr:
    """Set the OpenRouter API key"""
    if value is None:
        return convert_to_secret_str(os.getenv("OPENROUTER_API_KEY", ""))
    return convert_to_secret_str(value)


class OpenRouterEmbeddings(BaseModel, Embeddings):
    """Using Field with default_factory for automatic client creation."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    client: OpenRouterClient = Field(default_factory=OpenRouterClient)
    aclient: AsyncOpenRouterClient = Field(default_factory=AsyncOpenRouterClient)

    openrouter_api_key: SecretStr = Field(default_factory=set_openrouter_api)
    model: str = Field(default="openai/text-embedding-3-small")

    @model_validator(mode="after")
    def validate_environment(self) -> "OpenRouterEmbeddings":
        """Validate the environment and set up the OpenRouter client."""
        _api_key: SecretStr | str = self.openrouter_api_key or os.getenv(
            "OPENROUTER_API_KEY", ""
        )
        if not _api_key:
            raise ValueError(
                "OpenRouter API key not found. Please set the OPENROUTER_API_KEY environment variable."
            )

        if isinstance(_api_key, str):
            _api_key = convert_to_secret_str(_api_key)

        # Set up the OpenRouter client if not already set
        self.client = OpenRouterClient(
            api_key=_api_key.get_secret_value(),  # type: ignore
            default_model=self.model,
        )
        self.aclient = AsyncOpenRouterClient(
            api_key=_api_key.get_secret_value(),  # type: ignore
            default_model=self.model,
        )
        return self

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs."""
        response: dict[str, Any] = self.client.embeddings.create(
            input=texts, model=self.model
        )
        return [emb["embedding"] for emb in response["data"]]

    def embed_query(self, text: str) -> list[float]:
        """Embed query text."""
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs."""
        response: dict[str, Any] = await self.aclient.aembeddings.create(
            input=texts, model=self.model
        )

        return [emb["embedding"] for emb in response["data"]]

    async def aembed_query(self, text: str) -> list[float]:
        """Embed query text."""
        return (await self.aembed_documents([text]))[0]


embeddings = OpenRouterEmbeddings()
result: list[list[float]] = await embeddings.aembed_documents(texts=["Hello there!"])

In [24]:
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

client = QdrantClient(":memory:")

vector_size: int = len(await embeddings.aembed_query("sample text"))
collection_name: str = "smart_rag_collection"

if not client.collection_exists(collection_name):
    client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
    )
vectorstore = QdrantVectorStore(
    client=client,
    collection_name=collection_name,
    embedding=embeddings,
)
# Embed all the documents
document_ids: list[str] = await vectorstore.aadd_documents(
    documents=doc_chunks_with_metadata
)
print(document_ids[:3])

['9cc210ca3a824d498f92c62333253d7f', '9206a90af45040f99d312cabc64ea7c5', 'dd8dae79c43a415e806e604ef5192ae4']


In [25]:
doc_chunks_with_metadata[0].model_dump()

{'id': None,
 'metadata': {'source_doc': 'cleaned_doc.txt',
  'section': 'ITEM 1. BUSINESS',
  'chunk_id': '5843ee57-5c09-471c-9d07-2b068c2b5320'},
 'page_content': 'Our Company\n\nNVIDIA pioneered accelerated computing to help solve the most challenging computational problems. Since our original focus on PC graphics, we have expanded to several other large and important computationally intensive fields. Fueled by the sustained demand for exceptional 3D graphics and the scale of the gaming market, NVIDIA has leveraged its GPU architecture to create platforms for scientific computing, artificial intelligence, or AI, data science, autonomous vehicles, or AV, robotics, metaverse and 3D internet applications.',
 'type': 'Document'}

In [30]:
client = QdrantClient(host="localhost", port=6333)

vector_size: int = len(await embeddings.aembed_query("sample text"))
collection_name: str = "smart_rag_collection"

if not client.collection_exists(collection_name):
    client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
    )
vectorstore = QdrantVectorStore(
    client=client,
    collection_name=collection_name,
    embedding=embeddings,
)
# Embed all the documents
document_ids: list[str] = await vectorstore.aadd_documents(
    documents=doc_chunks_with_metadata
)
print(document_ids[:3])

['2f67d0b2f5e04830b43144b121b895ff', 'fa5ccd06225e4b0ab1968f2b9fb0c523', '155de15248f54e1eac06a504fe572c3a']


In [31]:
from src.utilities.vectorstores import VectorStoreSetup

In [34]:
vs_setup = VectorStoreSetup()
docs_custom = vs_setup.chunk_documents(
    documents=documents, source="my_source", split_by_sections=True
)

Found 21 ITEM sections.
Created 374 document chunks with metadata.


In [36]:
console.print(docs_custom[0])

## Create Tools

<br>

### VectorSearch (Filtering Chunks by Metadata)

In [None]:
from qdrant_client import QdrantClient, models

query: str = "According to NVIDIA's management, what assurance level do disclosure controls and internal controls actually provide?"

retrieved_docs = vectorstore.similarity_search(
    query,
    k=3,
    filter=models.Filter(
        must=[
            models.FieldCondition(
                key="metadata.section",
                match=models.MatchValue(value="ITEM 9A. CONTROLS AND PROCEDURES"),
            )
        ]
    ),
)
formatted_docs: str = "\n\n".join(
    (f"Source: {doc.metadata}\nContent: {doc.page_content}") for doc in retrieved_docs
)

console.print(formatted_docs)

In [None]:
from langchain.tools import tool
from qdrant_client.models import Filter


@tool
async def avector_search(
    query: str, filter: str | None = None, k: int = 3
) -> list[Document]:
    """Perform a vector search with metadata filtering.

    Parameters
    ----------
    query : str
        The search query string.
    filter : str or None, default=None
        The metadata filter value for 'metadata.section'.
    k : int, default=3
        The number of top similar documents to retrieve.

    Returns
    -------
    list[Document]
        A list of retrieved Document objects.
    """
    key: str = "metadata.section"
    _filter: Filter | None = (
        models.Filter(
            must=[models.FieldCondition(key=key, match=models.MatchValue(value=filter))]
        )
        if filter
        else None
    )
    return await vectorstore.asimilarity_search(query, k=k, filter=_filter)


async def avector_search_tool(
    query: str, filter: str | None = None, k: int = 3
) -> list[Document]:
    """Perform a vector search with metadata filtering.

    Parameters
    ----------
    query : str
        The search query string.
    filter : str or None, default=None
        The metadata filter value for 'metadata.section'.
    k : int, default=3
        The number of top similar documents to retrieve.

    Returns
    -------
    list[Document]
        A list of retrieved Document objects.
    """
    key: str = "metadata.section"
    _filter: Filter | None = (
        models.Filter(
            must=[models.FieldCondition(key=key, match=models.MatchValue(value=filter))]
        )
        if filter
        else None
    )
    return await vectorstore.asimilarity_search(query, k=k, filter=_filter)

In [None]:
console.print(avector_search)

In [None]:
avector_search.coroutine

#### Note

```py
# Without the @tool decorator
result = await avector_search(query=query, filter=None, k=3)

# Using the @tool decorator
result = await avector_search.coroutine(query=query, filter=None, k=3)
```

In [None]:
# Without the @tool decorator
# result = await avector_search(query=query, filter=None, k=3)

# Using the @tool decorator
result = await avector_search.coroutine(query=query, filter=None, k=3)


console.print(result)

In [None]:
result = await avector_search_tool(
    query=query, filter="ITEM 9A. CONTROLS AND PROCEDURES", k=3
)

console.print(result)

### Keyword Search

In [None]:
from rank_bm25 import BM25Okapi

corpus: list[str] = [
    "Hello there good man!",
    "It is quite windy in London",
    "How is the weather today?",
]

tokenized_corpus: list[list[str]] = [doc.split(" ") for doc in corpus]

bm25 = BM25Okapi(tokenized_corpus)
query: str = "windy London"
tokenized_query: list[str] = query.split(" ")  # type: ignore

doc_scores = bm25.get_scores(tokenized_query)
print(doc_scores)
# Sort in descending order of scores
sorted_idxs = np.argsort(doc_scores)[::-1]
sorted_idxs

In [None]:
[corpus[idx] for idx in sorted_idxs]

In [None]:
import re

from tokenizers import (  # type: ignore
    Regex,
    Tokenizer,
    normalizers,
)
from tokenizers import (
    models as t_models,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class CustomTokenizer:
    """A class for ..."""

    pattern_digits: str = r"[0-9]+"
    pattern_punctuation: str = r"[^\w\s\\\/]"  # Includes `\`, `/`
    pattern_spaces: str = r"\s{2,}"
    pattern_split: str = r"\W"

    unk_str: str = "[UNK]"

    def __init__(self, to_lower: bool = False) -> None:
        """Initialize with a WordPiece tokenizer and normalizer sequence."""
        self.to_lower = to_lower
        self.tokenizer = Tokenizer(t_models.WordPiece(unk_token=self.unk_str))  # type: ignore

        # Create the custom normalizer
        transformations_list = []
        if self.to_lower:
            transformations_list.append(normalizers.Lowercase())

        transformations_list.extend(
            [  # type: ignore
                normalizers.NFD(),
                normalizers.Replace(Regex(self.pattern_digits), " "),
                normalizers.Replace(Regex(self.pattern_punctuation), " "),
                normalizers.StripAccents(),
                normalizers.Strip(),
                # Last step
                normalizers.Replace(Regex(self.pattern_spaces), " "),
            ]
        )
        self.tokenizer.normalizer = normalizers.Sequence(  # type: ignore
            transformations_list  # type: ignore
        )

    def split_on_patterns(self, text: str) -> str:
        """Split a string on a pattern and join the parts with spaces.

        Parameters
        ----------
        text : str
            Input text to be split.

        Returns
        -------
        str
            Processed text with pattern-based splits.
        """
        parts: list[str] = re.split(self.pattern_split, text, flags=re.I)
        # Remove empty strings and join by spaces
        output: str = " ".join(filter(lambda x: x != "", [p.strip() for p in parts]))
        return output

    def format_data(self, data: str) -> str:
        """Format a single text string using pattern splitting and normalization.

        Parameters
        ----------
        data : str
            Input text to be formatted.

        Returns
        -------
        str
            Normalized and formatted text.
        """
        text: str = self.split_on_patterns(data)
        return self.tokenizer.normalizer.normalize_str(text)

    def batch_format_data(self, data: list[str]) -> list[str]:
        """Format a batch of text strings.

        Parameters
        ----------
        data : list[str]
            List of input texts to be formatted.

        Returns
        -------
        list[str]
            List of normalized and formatted texts.
        """
        return [self.format_data(row) for row in data]

In [None]:
custom_tokenizer = CustomTokenizer()
custom_tokenizer.batch_format_data(corpus)

In [None]:
print("\nBuilding BM25 index for keyword search...")

# Create a list where each element is a list of words from a document
tokenized_corpus = [
    custom_tokenizer.format_data(doc.page_content).split(" ")
    for doc in doc_chunks_with_metadata
]

# Create a list of all unique document IDs
doc_ids: list[str] = [doc.metadata["chunk_id"] for doc in doc_chunks_with_metadata]

# Create a mapping from a document's ID back to the full Document object for easy lookup
doc_dict: dict[str, Document] = {
    doc.metadata["chunk_id"]: doc for doc in doc_chunks_with_metadata
}

# Initialize the BM25Okapi index with our tokenized corpus
bm25 = BM25Okapi(tokenized_corpus)

In [None]:
import asyncio


def keyword_search(query: str, k: int = 3) -> list[Document]:
    """Perform keyword search using BM25 and return top k documents."""
    # Tokenize the query
    tokenized_query: list[str] = custom_tokenizer.format_data(query).split()
    doc_scores = bm25.get_scores(tokenized_query)
    # Sort in descending order and select the top k
    top_k_idxs: np.ndarray = np.argsort(doc_scores)[::-1][:k]

    return [doc_dict[doc_ids[i]] for i in top_k_idxs]


@tool
async def akeyword_search(
    query: str, filter: str | None = None, k: int = 3
) -> list[Document]:  # noqa: ARG001
    """Perform keyword search asynchronously using BM25 and return top k documents.

    Parameters
    ----------
    query : str
        The search query string.
    filter : str | None, default=None
        For function signature compatibility. (Not used in keyword search)
    k : int, default=3
        The number of top similar documents to retrieve.

    Returns
    -------
    list[Document]
        A list of retrieved Document objects.
    """
    return await asyncio.to_thread(keyword_search, query, k)


async def akeyword_search_tool(
    query: str,
    filter: str | None = None,  # noqa: ARG001
    k: int = 3,
) -> list[Document]:
    """Perform keyword search asynchronously using BM25 and return top k documents.

    Parameters
    ----------
    query : str
        The search query string.
    filter : str | None, default=None
        For function signature compatibility. (Not used in keyword search)
    k : int, default=3
        The number of top similar documents to retrieve.

    Returns
    -------
    list[Document]
        A list of retrieved Document objects.
    """
    return await asyncio.to_thread(keyword_search, query, k)

In [None]:
query: str = "According to NVIDIA's management, what assurance level do disclosure controls and internal controls actually provide?"

retrieved_docs = keyword_search(query=query, k=3)

console.print(retrieved_docs)

In [None]:
query: str = "According to NVIDIA's management, what assurance level do disclosure controls and internal controls actually provide?"

retrieved_docs = await akeyword_search_tool(query=query, k=3)

console.print(retrieved_docs)

### Hybrid Search

- Keyword + Vector Search

In [None]:
_filter = "ITEM 9A. CONTROLS AND PROCEDURES"
k: int = 3
tasks: list[Coroutine[Any, Any, list[Document]]] = [
    avector_search_tool(query=query, filter=_filter, k=k),
    akeyword_search_tool(query=query, k=k),
]

semantic_docs, kw_docs = await asyncio.gather(*tasks)

semantic_docs, kw_docs

In [None]:
from langchain_core.documents.base import Document


async def ahybrid_search_tool(
    query: str, filter: str | None = None, k: int = 5
) -> list[Document]:
    """
    Asynchrounously combine vector and keyword search results using Reciprocal Rank Fusion (RRF).

    Parameters
    ----------
    query : str
        The search query string.
    filter : str or None, optional
        Optional filter expression passed to the vector search, by default None.
    k : int, optional
        Maximum number of documents to return, by default 5.

    Returns
    -------
    list[Document]
        Top-k documents ranked by fused scores.

    Notes
    -----
    RRF is a simple, unsupervised method for merging ranked lists.
    The constant ``K`` (set to 61) controls the steepness of the rank
    discount curve and is taken from the original RRF paper.
    """
    K: int = 61  # Default for RRF

    tasks: list[Coroutine[Any, Any, list[Document]]] = [
        avector_search_tool(query=query, filter=filter, k=k),  # type: ignore
        akeyword_search_tool(query=query, k=k),  # type: ignore
    ]
    semantic_docs, kw_docs = await asyncio.gather(*tasks)

    # Results of vector and kw search
    res_ids: list[list[str]] = [
        [doc.metadata["chunk_id"] for doc in semantic_docs],
        [doc.metadata["chunk_id"] for doc in kw_docs],
    ]
    # Calculate Reciprocal Rank Fusion (RRF)
    rrf_dict: dict[str, float] = {}

    for doc_list in res_ids:
        # Grab each doc_id
        for idx, doc_id in enumerate(doc_list):
            if doc_id not in rrf_dict:
                rrf_dict[doc_id] = 0
            # Add (1 / (idx + k)) to each retrieved doc
            rrf_dict[doc_id] += 1 / (idx + K)
    # Sort result using RRF score in descending order
    ranked_ids: list[str] = sorted(
        rrf_dict.keys(), key=lambda x: rrf_dict[x], reverse=True
    )[:k]

    return [doc_dict[_id] for _id in ranked_ids]

In [None]:
result = await ahybrid_search_tool(
    query=query, filter="ITEM 9A. CONTROLS AND PROCEDURES", k=5
)
console.print(result)

<br>

### Re-Ranker

- Re-rank retrieved chunks based on relevance to the query.


In [None]:
from sentence_transformers import CrossEncoder

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
# reranker

In [None]:
def rerank_documents(
    query: str, documents: list[Document], k: int = 3
) -> list[Document]:
    """Rerank documents by relevance to query using CrossEncoder.

    Parameters
    ----------
    query : str
        The search query string.
    documents : list[Document]
        List of Document objects to rerank.
    k : int, optional
        Maximum number of documents to return, by default 3.

    Returns
    -------
    list[Document]
        Documents sorted by relevance score in descending order.
    """
    # Prepare pairs of (query, document content) for scoring
    pairs: list[tuple[str, str]] = [(query, doc.page_content) for doc in documents]
    # Get relevance scores from the CrossEncoder
    scores: list[float] | np.ndarray = reranker.predict(pairs)

    # Combine documents with their scores
    doc_score_pairs: list[tuple[Document, float]] = list(zip(documents, scores))
    # Sort documents by score in descending order
    ranked_docs: list[Document] = [
        doc for doc, _ in sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
    ][:k]
    return ranked_docs


async def arerank_documents(
    query: str, documents: list[Document], k: int = 3
) -> list[Document]:
    """Asynchronously rerank documents by relevance to query using CrossEncoder.

    Parameters
    ----------
    query : str
        The search query string.
    documents : list[Document]
        List of Document objects to rerank.
    k : int, optional
        Maximum number of documents to return, by default 3.

    Returns
    -------
    list[Document]
        Documents sorted by relevance score in descending order.
    """
    return await asyncio.to_thread(rerank_documents, query, documents, k)

In [None]:
rerank_doc = await arerank_documents(query, documents=result)
print(f"Query: {query}\n")
print("Reranked Documents:\n")
console.print(rerank_doc)

### Web Search

- Using Tavily Search

In [None]:
from langchain_community.tools import DuckDuckGoSearchResults


def truncate_content(content: str | None, max_chars: int | None = None) -> str | None:
    """Truncate content to max_chars with ellipsis indicator."""
    if not content:
        return None

    if max_chars:
        return (
            f"{content[:max_chars]} [truncated]..."
            if len(content) > max_chars
            else content
        )
    return content


def extract_main_content_from_html(content: str) -> str:
    """Extract main content from HTML by removing noise and finding article body.

    Parameters
    ----------
    content : str
        Raw HTML content string.

    Returns
    -------
    str
        BeautifulSoup element containing the main content area.
        Falls back to body element if no main content found.

    Notes
    -----
    Removes scripts, styles, navigation, headers, footers, and ads.
    Searches for common content containers: main, article, or content divs.
    """
    soup = BeautifulSoup(content, "html.parser")

    # Remove unwanted elements
    for tag in soup(
        [
            "script",
            "style",
            "nav",
            "header",
            "footer",
            "aside",
            "iframe",
            "noscript",
        ]
    ):
        tag.decompose()

    # Try to find main content area (common patterns)
    main_content = None
    for selector in [
        soup.find("main"),
        soup.find("article"),
        soup.find(
            "div",
            class_=lambda x: x
            and any(
                c in str(x).lower()  # type: ignore
                for c in ["content", "article", "post", "story"]
            ),
        ),
        soup.find(
            "div",
            id=lambda x: x
            and any(
                c in str(x).lower()  # type: ignore
                for c in ["content", "article", "post", "main"]
            ),
        ),
    ]:
        if selector and selector.get_text(strip=True):
            main_content = selector
            break

    # Fall back to body if no main content found
    if not main_content:
        main_content = soup.find("body") or soup

    return str(main_content)


async def afetch_raw_content(url: str) -> str | None:
    """Fetch HTML content from a URL and convert to markdown.

    Parameters
    ----------
    url : str
        The URL to fetch content from.

    Returns
    -------
    str | None
        Markdown-converted content if successful, None otherwise.

    Notes
    -----
    Uses browser-like headers to avoid bot detection and a 15-second timeout.
    Extracts main content from common article/content tags.
    """
    # Browser-like headers to avoid bot detection
    headers: dict[str, str] = {
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8",
        "Accept-Language": "en-US,en;q=0.9",
        "Accept-Encoding": "gzip, deflate, br",
        "DNT": "1",
        "Connection": "keep-alive",
        "Upgrade-Insecure-Requests": "1",
        "Sec-Fetch-Dest": "document",
        "Sec-Fetch-Mode": "navigate",
        "Sec-Fetch-Site": "none",
        "Cache-Control": "max-age=0",
    }

    try:
        async with HTTPXClient(timeout=15) as client:
            response = await client.get(url, headers=headers)

            # Check if request was successful
            if not response.get("success"):
                return None

            # Response might be dict or str
            if isinstance(response["data"], dict):
                content = json.dumps(response["data"])
            else:
                content = response["data"]
            html_content = content

            # Parse HTML and extract main content
            main_content: str = extract_main_content_from_html(content=html_content)

            # Convert to markdown
            markdown_content = md(
                str(main_content),
                heading_style="ATX",
                bullets="-",
                strip=["script", "style"],
            )

            # Clean up excessive whitespace
            lines: list[str] = [
                line.strip() for line in markdown_content.split("\n") if line.strip()
            ]
            cleaned = "\n\n".join(lines)

            return cleaned if cleaned and len(cleaned) > 100 else None

    except Exception as e:
        print(f"Warning: Failed to fetch full page content for {url}: {str(e)}")
        return None


async def aduckduckgo_search(
    query: str, fetch_full_page: bool = False, k: int = 5, max_chars: int | None = None
) -> dict[str, list[dict[str, Any]]]:
    """Search DuckDuckGo and optionally fetch full page content.

    Parameters
    ----------
    query : str
        The search query string.
    fetch_full_page : bool, default=False
        If True, fetch and parse full HTML content for each result.
    k : int, optional
        Maximum number of documents to return, by default 5.
    max_chars : int or None, default=None
        Maximum characters to return per result. If None, no truncation.

    Returns
    -------
    dict[str, list[dict[str, Any]]]
        Dictionary with "results" key containing list of search results.
        Each result has title, url, content, and raw_content fields.

    Notes
    -----
    When fetch_full_page=True, uses browser-like headers and smart content
    extraction to avoid bot detection and JS-blocking issues.
    """

    try:
        search = DuckDuckGoSearchResults(output_format="list", num_results=k)
        raw_results = await search.ainvoke(query)

        # format the data
        raw_results: list[dict[str, Any]] = [
            {
                "title": row["title"],
                "url": row["link"],
                "content": row["snippet"],
                "raw_content": row["snippet"],
            }
            for row in raw_results
        ]

        if fetch_full_page:
            # Fetch full pages concurrently for better performance
            tasks: list[Coroutine] = [
                afetch_raw_content(row["url"]) for row in raw_results
            ]
            full_contents = await asyncio.gather(*tasks)

            raw_results: list[dict[str, Any]] = [
                {
                    **row,
                    "raw_content": truncate_content(
                        content=full_content
                        or row["content"],  # Fall back to content if fetch fails
                        max_chars=max_chars,
                    ),
                }
                for row, full_content in zip(raw_results, full_contents)
            ]
        return {"results": raw_results}

    except Exception as e:
        print(f"Duckduckgo search failed: {str(e)}")
        return {"results": []}

In [None]:
search_result = await aduckduckgo_search(
    query="Has Nvidia broken any laws?",
    fetch_full_page=True,
    max_chars=5_000,
)

console.print(search_result)

In [None]:
async def aduckduckgo_web_search_tool(
    query: str, fetch_full_page: bool = False, k: int = 5, max_chars: int | None = None
) -> list[Document]:
    """Asynchronously search DuckDuckGo and optionally fetch full page content.

    Parameters
    ----------
    query : str
        The search query string.
    fetch_full_page : bool, default=False
        If True, fetch and parse full HTML content for each result.
    k : int, optional
        Maximum number of documents to return, by default 5.
    max_chars : int or None, default=None
        Maximum characters to return per result. If None, no truncation.

    Returns
    -------
    list[Document]
    """
    max_chars = 8_000 if not max_chars else max_chars

    search_response: dict[str, list[dict[str, Any]]] = await aduckduckgo_search(
        query=query, fetch_full_page=fetch_full_page, k=k, max_chars=max_chars
    )

    formatted_results: list[Document] = [
        Document(
            page_content=f"Title: {result['title']}\nContent: {result['raw_content']}",
            metadata={
                "url": result["url"],
                "title": result["title"],
            },
        )
        for result in search_response["results"]
    ]
    return formatted_results

In [None]:
search_result = await aduckduckgo_web_search_tool(
    query="Has Nvidia broken any laws?",
    fetch_full_page=True,
    max_chars=5_000,
)

console.print(search_result)

In [None]:
from langchain_tavily import TavilySearch

tavily_search = TavilySearch(
    api_key=settings.TAVILY_API_KEY.get_secret_value(),
    max_results=2,
    topic="general",
)
search_response = await tavily_search.ainvoke({"query": "Has Nvidia broken any laws?"})

console.print(search_response)

In [None]:
async def tavily_search_tool(
    query: str, fetch_full_page: bool = False, k: int = 5, max_chars: int | None = None
) -> dict[str, list[dict[str, Any]]]:
    """Search the web using TavilySearch and return formatted results.

    Parameters
    ----------
    query : str
        The search query string.
    fetch_full_page : bool, default=False
        If True, include full raw_content from Tavily results.
    k : int, optional
        Maximum number of documents to return, by default 5.
    max_chars : int or None, default=None
        Maximum characters to return per result. If None, no truncation.

    Returns
    -------
    dict[str, list[dict[str, Any]]]
        Dictionary with "results" key containing list of search results.
        Each result has title, url, content, and raw_content fields.

    Notes
    -----
    Tavily automatically provides raw_content when available.
    No additional fetching needed - Tavily handles this internally.
    """

    tavily_search = TavilySearch(
        api_key=settings.TAVILY_API_KEY.get_secret_value(),
        max_results=k,
        topic="general",
        # include_raw_content tells Tavily to fetch full page content
        include_raw_content=fetch_full_page,
    )
    search_response = await tavily_search.ainvoke({"query": query})
    # return search_response
    raw_results: list[dict[str, Any]] = [
        {
            "title": row["title"],
            "url": row["url"],
            "content": row["content"],
            "raw_content": truncate_content(
                content=row.get("raw_content") or row["content"],
                max_chars=max_chars,
            ),
        }
        for row in search_response["results"]
    ]

    return {"results": raw_results}


async def atavily_web_search_tool(
    query: str, fetch_full_page: bool = False, k: int = 5, max_chars: int | None = None
) -> list[Document]:
    """Asynchronously search using Tavily and optionally fetch full page content.

    Parameters
    ----------
    query : str
        The search query string.
    fetch_full_page : bool, default=False
        If True, fetch and parse full HTML content for each result.
    k : int, optional
        Maximum number of documents to return, by default 5.
    max_chars : int or None, default=None
        Maximum characters to return per result. If None, no truncation.

    Returns
    -------
    list[Document]
    """
    max_chars = 8_000 if not max_chars else max_chars
    search_response: dict[str, list[dict[str, Any]]] = await tavily_search_tool(
        query=query, fetch_full_page=fetch_full_page, k=k, max_chars=max_chars
    )

    formatted_results: list[Document] = [
        Document(
            page_content=f"Title: {result['title']}\nContent: {result['raw_content']}",
            metadata={
                "url": result["url"],
                "title": result["title"],
            },
        )
        for result in search_response["results"]
    ]
    return formatted_results

In [None]:
search_response = await tavily_search_tool(
    query="Has Nvidia broken any laws?",
    fetch_full_page=True,
    max_chars=5_000,
)

console.print(search_response)

In [None]:
search_response = await atavily_web_search_tool(
    query="Has Nvidia broken any laws?",
    fetch_full_page=True,
    max_chars=None,
)

console.print(search_response)

### Create States

<br>

#### 1.) Step

- The smallest unit. Multiple Steps make up a Plan.
- A Step has:
  - question: The question being asked.
  - rationale: The reasoning behind the question.
  - tool: The tool to be used to answer the question. (e.g. web_search or vectorstore_lookup)
  - search_keywords: Keywords to use for searching.
  - target_section: The section of the document to focus on. (Only for vectorstore_lookup tool)


In [None]:
from enum import StrEnum
from typing import TypedDict


class RetrieverMethodType(StrEnum):
    """The type of retrieval method to use for internal document search."""

    VECTOR_SEARCH = "vector_search"
    KEYWORD_SEARCH = "keyword_search"
    HYBRID_SEARCH = "hybrid_search"


class ToolsType(StrEnum):
    """The type of tool to use for each step."""

    VECTOR_STORE = "vector_store"
    WEB_SEARCH = "web_search"


class NextAction(StrEnum):
    """Tells the executor what to do after the current planning step."""

    CONTINUE = "continue"
    FINISH = "finish"


class Step(BaseModel):
    """A single step in the multi-step reasoning process."""

    question: str = Field(description="The question to be answered by the step.")
    rationale: str = Field(description="The brief reasoning behind the question.")
    tool: ToolsType = Field(
        description="The tool to use for this step. For information found ONLY in internal documents, "
        "use 'vector_store'. For the latest information found on the web, use 'web_search'.",
    )
    search_keywords: list[str] = Field(
        description="Critical keywords and phrases to use for web search or vector store "
        "retrieval to ensure quality results are returned.",
    )
    target_section: str | None = Field(
        default=None,
        description="The target section in the document to focus on. This is ONLY required when "
        "the tool is 'vector_store'. e.g., 'ITEM 1A. RISK FACTORS'.",
    )
    depends_on: list[int] = Field(
        default_factory=list,
        description="List of step indices (0-based) that this step depends on. "
        "Leave empty if this step can run immediately.",
    )


class ReWrittenQuery(BaseModel):
    question: str = Field(description="Original query to be re-written.")
    rewritten_query: list[str] = Field(description="The re-written query.")
    rationale: str = Field(description="The brief reasoning behind the decision.")


class ValidateQuery(BaseModel):
    is_related_to_context: bool = Field(
        description="Whether the query is related to the context."
    )
    next_action: NextAction = Field(description="The next action to take.")
    rationale: str = Field(description="The brief reasoning behind the decision.")


class RetrieverMethod(BaseModel):
    method: RetrieverMethodType = Field(
        description="The retrieval method to use for retrieving internal documents.",
    )
    rationale: str = Field(description="The brief reasoning behind the decision.")


class Decision(BaseModel):
    next_action: NextAction = Field(..., description="The next action to take.")
    rationale: str = Field(description="The brief reasoning behind the decision.")

#### 2.) Plan

- A Plan is a sequence of Steps to achieve a goal.


#### 2.b) PastStep

- This is used to store the history of executed steps in the plan.


#### 3.) State

- A State represents the current status of the RAG process. It includes:
  - original_question: The initial question posed by the user.
  - plan: The current plan being executed.
  - past_steps: A list of PastStep objects representing the history of executed steps.
  - current_step_index: The index of the current step in the plan.
  - retrieved_docs: A list of Document objects that have been retrieved so far.
  - reranked_docs: A list of Document objects that have been reranked based on relevance.
  - synthesized_content: The content synthesized from the reranked documents.
  - final_answer: The final answer generated for the original question.

In [None]:
class Plan(BaseModel):
    """A multi-step plan for answering a complex question."""

    steps: list[Step] = Field(description="A list of steps to execute in the plan.")


class StepState(TypedDict):
    """State of a completed step in the multi-step reasoning process."""

    step_index: int  # Index of the step in the plan
    question: str  # The question asked in this step
    rewritten_queries: list[str]  # Re-written queries for this step
    retrieved_documents: list[Document]  # Documents retrieved for this step
    summary: str  # Summary of the step's findings


class State(TypedDict):
    """State of the multi-step reasoning process."""

    original_question: str  # The original complex question
    is_related_to_context: bool  # Whether the question is related to the context
    plan: Plan  # The multi-step plan
    step_state: list[StepState]  # List of completed steps
    current_step_index: int  # Index of the current step being executed
    retrieved_documents: list[Document]  # Documents retrieved in the current step
    reranked_documents: list[Document]  # Documents reranked based on relevance
    num_iterations: int  # Number of iterations completed
    synthesized_context: str  # Synthesized context from reranked documents
    final_answer: str  # The final answer to the original question


# NB: If you want to append to a key and return the entire updated list, you need to use operator.add
# Otherwise return ONLY the new items to append and LangGraph will handle the appending internally.
# e.g.,
# class State(TypedDict):
#     """State of the multi-step reasoning process."""

#     original_question: str  # The original complex question
#     ... other fields ...
#     retrieved_documents: Annotated[list[Document], operator.add]  # Documents retrieved in the current step
#     reranked_documents: Annotated[list[Document], operator.add]  # Documents reranked based on relevance

# Within the function, return the FULL state object with updated lists.
# return State(
#     original_question=state["original_question"],
#     ... other fields ...
#     retrieved_documents=state["retrieved_documents"] + new_retrieved_docs,
#     reranked_documents=state["reranked_documents"] + new_reranked_docs,
#   )

### 1.) To append ONLY new items to a list in the State
```py
class State(TypedDict):
    """State of the multi-step reasoning process."""

    original_question: str  # The original complex question
    plan: Plan  # The multi-step plan
    step_state: list[StepState]  # List of completed steps
    current_step_index: int  # Index of the current step being executed
    retrieved_documents: list[Document]  # Documents retrieved in the current step
    reranked_documents: list[Document]  # Documents reranked based on relevance
    num_iterations: int  # Number of iterations completed
    synthesized_context: str  # Synthesized context from reranked documents
    final_answer: str  # The final answer to the original question

# Within the function, return ONLY the NEW items to append to the lists as a dict.
# LangGraph will handle the appending internally.

async def my_function(state: State) -> dict[str, list[Document]]:
    new_retrieved_docs: list[Document] = [...]  # New documents retrieved in this step
    new_reranked_docs: list[Document] = [...]  # New documents reranked in this step

  return {
      "retrieved_documents": new_retrieved_docs,
      "reranked_documents": new_reranked_docs,  
  }
```

### 2.) To append and return the FULL updated list in the State

```py
# If you want to append to a key and return the entire updated list, you need to use operator.add 
# e.g.,
class State(TypedDict):
    """State of the multi-step reasoning process."""

    original_question: str  # The original complex question
    ... other fields ...
    retrieved_documents: Annotated[list[Document], operator.add]  # Append to this list
    reranked_documents: Annotated[list[Document], operator.add]  # Append to this list

async def my_function(state: State) -> State:
    new_retrieved_docs: list[Document] = [...]  # New documents retrieved in this step
    new_reranked_docs: list[Document] = [...]  # New documents

  # Within the function, return the FULL state object with updated lists.
  return State(
      original_question=state["original_question"],
      ... other fields ...
      retrieved_documents=state["retrieved_documents"] + new_retrieved_docs,
      reranked_documents=state["reranked_documents"] + new_reranked_docs,
    )
```

In [None]:
"""Prompt templates for various agent interactions."""

query_validation_prompt: str = """
<SYSTEM>
    <ROLE>Expert analyzing and validating user questions.</ROLE>
    <TOPICS>{topics}</TOPICS>
    <TASK>Determine if the question relates to provided topics and decide next action.</TASK>
    <GUIDELINES>
        - Set `is_related_to_context` to True if question is relevant to topics, False otherwise
        - Set `next_action` to `Continue` if relevant, `Finish` if not relevant
    </GUIDELINES>
</SYSTEM>
"""

planner_prompt: str = """
<SYSTEM>
    <ROLE>Expert decomposing user queries into efficient multi-step plans.</ROLE>

    <GUIDELINES>
        - Create 2-5 logical steps that build upon each other (use more ONLY if absolutely necessary)
        - Each step should be atomic and answer a specific question
        - Mix `web_search` and `vector_store` tools appropriately
        - Do NOT include summarization/synthesis steps (handled separately)
        - Each step needs clear rationale for why it's necessary
        - Make questions specific and focused for targeted retrieval
        - For `vector_store`, ALWAYS specify `target_section`
    </GUIDELINES>

    <SECTIONS>{section_titles}</SECTIONS>

    <TOOLS>
        - web_search: Search web for up-to-date information
        - vector_store: Search internal documents by section
    </TOOLS>

    <OUTPUT>
        Return Plan with Steps containing: question, rationale, tool, search_keywords (3-5), target_section
    </OUTPUT>
</SYSTEM>
"""


retriever_type_prompt: str = """
<ROLE>You are an expert at selecting optimal retrieval methods based on query characteristics.</ROLE>

<QUERY>{question}</QUERY>

<METHODS>
    Choose:
    - vector_search if: Query is conceptual, uses natural language, seeks related information
    - keyword_search if: Query has specific terms, proper nouns, technical codes, exact phrases required
    - hybrid_search if: Query needs both semantic context and precise term matching
</METHODS>
"""

query_rewriter_prompt: str = """
<ROLE>Query optimizer for document retrieval and web search.</ROLE>

<GUIDELINES>
    - Extract core intent, remove ambiguity
    - Use specific, domain-relevant terms
    - Retain critical details (names, dates, figures)
    - Output 5-10 keywords/phrases
</GUIDELINES>

<QUERY>{question}</QUERY>
<KEYWORDS>{search_keywords}</KEYWORDS>

<OUTPUT>Return 3-7 query variations capturing original intent.</OUTPUT>
"""

decision_prompt: str = """
<SYSTEM>
    <ROLE>
        Master strategist evaluating research progress and determining optimal next actions.
    </ROLE>

    <TASK>
        Analyze completed research against the original question to decide whether 
        to continue execution or finalize the answer.
    </TASK>

    <DECISION_CRITERIA>
        <FINISH_IF>
            - All critical aspects of the original question are COMPLETELY addressed
            - Sufficient evidence and data have been collected
            - Remaining plan steps would add minimal value
        </FINISH_IF>

        <CONTINUE_IF>
            - Key parts of the question remain unanswered
            - Critical dependencies in the plan are not yet satisfied
            - Collected information has gaps or lacks specificity
        </CONTINUE_IF>
    </DECISION_CRITERIA>

    <EVALUATION_PROCESS>
        1. Review the original question's requirements
        2. Assess what information has been gathered in completed steps
        3. Identify gaps between collected findings and question needs
        4. Consider whether remaining plan steps address those gaps
    </EVALUATION_PROCESS>

    <GUIDELINES>
        - Prioritize answer completeness over plan completion
        - A partial plan execution can be sufficient if the question is answered
        - Don't continue simply to complete all steps if information is adequate
    </GUIDELINES>

    <OUTPUT_FORMAT>
        Respond with:
        - Decision: [FINISH | CONTINUE]
        - Rationale: Brief explanation (1-2 sentences) of why this decision is optimal
    </OUTPUT_FORMAT>

    <QUERY>{question}</QUERY>

    <INITIAL_PLAN>{plan}</INITIAL_PLAN>

</SYSTEM>
"""

compression_prompt: str = """
<SYSTEM>
    <ROLE>Expert analyst — compress retrieved content into a single, dense, factual paragraph.</ROLE>

    <QUERY>{question}</QUERY>

    <REQUIREMENTS>
        - Exactly one paragraph, 3–6 sentences
        - Include all key facts, figures, dates, names, and precise details
        - Focus only on information most relevant to the query
        - Remain 100% objective — no interpretation, opinions, or added commentary
        - Use precise language; never paraphrase numbers or technical terms
        - Start directly with the content (never "The document states…", "According to…", etc.)
    </REQUIREMENTS>

    <OUTPUT_FORMAT>
        First line: [Source: [<concise title if none exists>](<URL>)]
        Second line: Content: <single dense paragraph>

        <EXAMPLE>
            [Source: [NVIDIA 2023 10-K Risk Factors](https://nvidia.com/10k-2023.pdf)]
            Content: NVIDIA faces intense competition...
        </EXAMPLE>
    </OUTPUT_FORMAT>
</SYSTEM>
"""

summarization_prompt: str = """
<SYSTEM>
    <ROLE>
        Research assistant creating concise summaries of retrieved findings 
        for multi-step reasoning continuity.
    </ROLE>

    <TASK>
        Summarize the key findings from the context in ONE clear sentence that:
        - Directly answers the sub-question
        - Includes specific facts, numbers, or conclusions with citations
        - Can be referenced by subsequent reasoning steps
        - Remains factual without interpretation
    </TASK>

    <FORMAT>
        Write a single declarative sentence. Avoid phrases like "The context shows..." 
        or "According to the document..." Start directly with the finding.
    </FORMAT>

    <EXAMPLES>
        Query: "What were Apple's R&D expenses in 2023?"
        Good: "Apple's R&D expenses were $29.9 billion in fiscal 2023, representing 7.8% of net sales."
        Poor: "The context indicates that Apple spent money on research and development."

        Query: "What are the company's main competitive risks?"
        Good: "The company faces competitive risks from pricing pressure, rapid technological change, and new market 
        entrants in emerging economies."
        Poor: "There are several competitive risks mentioned in the document."
    </EXAMPLES>

    <OUTPUT_FORMAT>
        First line: [Source: [<concise title if none exists>](<URL>)]
        Second line: Content: <ssummarized content>

        <EXAMPLE>
            [Source: [NVIDIA 2023 10-K Risk Factors](https://nvidia.com/10k-2023.pdf)]
            Content: NVIDIA faces intense competition...
        </EXAMPLE>
    </OUTPUT_FORMAT>

    <QUERY>{question}</QUERY>
</SYSTEM>
"""

final_answer_prompt: str = """
<SYSTEM>
    Expert at synthesizing research from multiple sources into brief, well-cited answers.

    <TASK>
        Integrate internal documents and web sources into a coherent narrative answering the user's question.
    </TASK>

    <GUIDELINES>
        <STRUCTURE>
            - 1-3 paragraphs based on query complexity
            - Prioritize concise responses (2-6 sentences) unless complexity demands more
            - Lead with direct answer, support with evidence
            - Organize: facts → analysis → implications
            - Always conclude with key takeaways when appropriate.
                i.e. **Key Takeaways**
                    * point 1
                    * point 2, etc
        </STRUCTURE>

        <CITATIONS>
            - Cite every sentence with specific facts/data/claims
            - [Source: [<TITLE>](<URL>)]
            - Don't cite general knowledge or transitions
            - Citations MUST be SECTION TITLES or URLs only
        </CITATIONS>

        <STANDARDS>
            - Ground all claims in provided context—no speculation
            - Use precise figures and dates
            - Maintain professional, objective tone
            - Address all parts of the question
            - Acknowledge gaps if context insufficient
        </STANDARDS>

    </GUIDELINES>

    <AVOID>
        - Uncited factual claims
        - Vague statements when specifics available
        - Bullet points (use prose)
        - Unnecessary preambles like "Based on the research...", "The document states...", etc.
        - Mixed citation formats
    </AVOID>

    <QUERY>{question}</QUERY>
</SYSTEM>
"""

In [None]:
hello_prompt: str = """{section_titles}"""
hello_prompt.format(section_titles=" | ".join(section_titles))

In [None]:
# utils.py

from enum import Enum

import instructor
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langsmith import traceable
from openai import AsyncOpenAI

from src.config import app_settings

# from src.schemas.types import OpenRouterModels, PydanticModel


class OpenRouterModels(str, Enum):
    """OpenRouter LLMs."""

    GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-001"
    GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
    GEMINI_2_5_FLASH_LITE = "google/gemini-2.5-flash-lite"
    GPT_OSS_120B = "openai/gpt-oss-120b"
    GPT_OSS_20B = "openai/gpt-oss-20b"
    GPT_5_NANO = "openai/gpt-5-nano"
    LLAMA_3_3_70B_INSTRUCT = "meta-llama/llama-3.3-70b-instruct"
    LLAMA_3_8B_INSTRUCT = "meta-llama/llama-3-8b-instruct"
    NEMOTRON_NANO_9B_V2 = "nvidia/nemotron-nano-9b-v2"
    QWEN3_30B_A3B = "qwen/qwen3-30b-a3b"
    QWEN3_NEXT_80B_A3B_INSTRUCT = "qwen/qwen3-next-80b-a3b-instruct"
    QWEN3_32B = "qwen/qwen3-32b"
    SAO10K_L3_LUNARIS_8B = "sao10k/l3-lunaris-8b"
    X_AI_GROK_4_FAST = "x-ai/grok-4-fast"
    X_AI_GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
    Z_AI_GLM_4_5 = "z-ai/glm-4.5"


_async_client = AsyncOpenAI(
    api_key=app_settings.OPENROUTER_API_KEY.get_secret_value(),
    base_url=app_settings.OPENROUTER_URL,
)

aclient = instructor.from_openai(
    _async_client,
    mode=instructor.Mode.OPENROUTER_STRUCTURED_OUTPUTS,
)
type PydanticModel = type[BaseModel]


@traceable
async def get_structured_output(
    messages: list[dict[str, Any]],
    schema: PydanticModel,
    model: OpenRouterModels | None = None,
) -> PydanticModel:
    """
    Retrieves structured output from a chat completion model.

    Parameters
    ----------
    messages : list[dict[str, Any]]
        The list of messages to send to the model for the chat completion.
    model : RemoteModel
        The remote model to use for the chat completion (e.g., 'gpt-4o').
    schema : PydanticModel
        The Pydantic schema to enforce for the structured output.

    Returns
    -------
    BaseModel
        An instance of the provided Pydantic schema containing the structured output.

    Notes
    -----
    This is an asynchronous function that awaits the completion of the API call.
    """
    model = model if model else RemoteModel.GEMINI_2_5_FLASH_LITE

    return await aclient.chat.completions.create(
        model=model,
        response_model=schema,
        messages=messages,  # type: ignore
        temperature=0.0,
        seed=42,
        max_retries=5,
    )


def convert_langchain_messages_to_dicts(
    messages: list[HumanMessage | SystemMessage | AIMessage],
) -> list[dict[str, str]]:
    """Convert LangChain messages to a list of dictionaries.

    Parameters
    ----------
    messages : list[HumanMessage | SystemMessage | AIMessage]
        List of LangChain message objects to convert.

    Returns
    -------
    list[dict[str, str]]
        List of dictionaries with 'role' and 'content' keys.
        Roles are mapped as follows:
        - HumanMessage -> "user"
        - SystemMessage -> "system"
        - AIMessage -> "assistant"

    """
    role_mapping: dict[str, str] = {
        "SystemMessage": "system",
        "HumanMessage": "user",
        "AIMessage": "assistant",
    }

    converted_messages: list[dict[str, str]] = []
    for msg in messages:
        message_type: str = msg.__class__.__name__
        # Default to "user" if unknown
        role: str = role_mapping.get(message_type, "user")
        converted_messages.append({"role": role, "content": msg.content})  # type: ignore

    return converted_messages


def append_memory(existing: dict[str, Any], new: dict[str, Any]) -> dict[str, Any]:
    """Merge new memory data into existing memory, appending lists and merging dicts.

    Parameters
    ----------
    existing: dict[str, Any]
        The existing memory data.
    new: dict[str, Any]
        The new memory data to merge.

    Returns
    -------
    dict[str, Any]
        The merged memory data.
    """
    result: dict[str, Any] = existing.copy()

    for key, new_value in new.items():
        # Skip None or empty values
        if new_value is None or new_value == "" or new_value == []:
            continue

        existing_value = result.get(key)

        # If key doesn't exist, just add it
        if existing_value is None:
            result[key] = new_value
            continue

        # Lists: combine and remove duplicates
        if isinstance(new_value, list):
            combined = existing_value + new_value
            # Preserve order, remove duplicates
            # ["a", "b", "a"] -> ["a", "b"]
            result[key] = list(dict.fromkeys(combined))

        # Dicts: merge
        elif isinstance(new_value, dict):
            result[key] = {**existing_value, **new_value}

        # Everything else: new value overwrites
        else:
            result[key] = new_value

    return result

In [None]:
from typing import Any, cast

from langchain_core.documents.base import Document

type RetrieverFn = Callable[[str, str | None, int], Coroutine[Any, Any, list[Document]]]
retrieval_method_dicts: dict[str, RetrieverFn] = {
    RetrieverMethodType.VECTOR_SEARCH: avector_search_tool,
    RetrieverMethodType.KEYWORD_SEARCH: akeyword_search_tool,
    RetrieverMethodType.HYBRID_SEARCH: ahybrid_search_tool,
}


# =========================================================
# ============== HELPER FUNCTIONS FOR NODES ===============
# =========================================================
def deduplicate(documents: list[Document]) -> list[Document]:
    """Deduplicate documents based on 'chunk_id' in metadata."""
    docs_dict: dict[str, Document] = {}

    if not documents[0].metadata:
        raise ValueError(
            "Cannot deduplicate documents without 'chunk_id' in metadata. Please ensure documents have "
            "'chunk_id' in their metadata."
        )
    for doc in documents:
        if (_id := doc.metadata["chunk_id"]) not in docs_dict:
            docs_dict[_id] = doc
    return list(docs_dict.values())


def format_documents(documents: list[Document]) -> str:
    """Format documents for synthesis input."""
    delimiter: str = "===" * 20
    try:
        docs: list[str] = [
            f"[Source]: {doc.metadata['source_doc']}\n[Content]: {doc.page_content}\n{delimiter}"
            for doc in documents
        ]
    except KeyError:
        docs = [
            f"[Source]: {doc.metadata['url']}\n[Content]: {doc.page_content}\n{delimiter}"
            for doc in documents
        ]
    formated_docs: str = "\n\n".join(docs)

    return formated_docs


async def aretrieve_internal_documents(
    method: RetrieverMethodType | str,
    rewritten_queries: list[str],
    target_section: str | None,
    k: int,
) -> list[Document]:
    """Retrieve internal documents using the specified retrieval method.

    Parameters
    ----------
    method : RetrieverMethodType | str
        Retrieval method to use (`vector_search`, `keyword_search`, or `hybrid_search`).
    rewritten_queries : list[str]
        Query variations produced by the query rewriter for this step.
    target_section : str | None
        Target section filter for internal document search. Only applied when
        method is `vector_search` or `hybrid_search`; ignored for pure keyword search.
    k : int
        Number of top documents to retrieve per query before deduplication.

    Returns
    -------
    list[Document]
        List of unique retrieved documents across all query variations.

    Raises
    ------
    ValueError
        If the provided method is not supported.
    """
    method: RetrieverMethodType = (
        method
        if isinstance(method, RetrieverMethodType)
        else RetrieverMethodType(method)
    )
    retrieval_fn = retrieval_method_dicts.get(method)
    if retrieval_fn is None:
        raise ValueError(f"Unsupported retrieval method: {method}")

    tasks = [
        retrieval_fn(
            query=query,
            filter=target_section,
            k=k,
        )
        for query in rewritten_queries
    ]
    all_docs: list[list[Document]] = await asyncio.gather(*tasks)
    # Flatten the docs
    retrieved_docs: list[Document] = [doc for sublist in all_docs for doc in sublist]  # type: ignore

    return deduplicate(documents=retrieved_docs)


async def query_rewriter(question: str, search_keywords: list[str]) -> ReWrittenQuery:
    """Re-write the user's question into multiple query variations."""
    prompt = query_rewriter_prompt.format(
        question=question, search_keywords=", ".join(search_keywords)
    )
    messages = convert_langchain_messages_to_dicts(messages=[HumanMessage(prompt)])
    response = await get_structured_output(
        messages=messages, model=None, schema=ReWrittenQuery
    )
    return cast(ReWrittenQuery, response)


async def determine_retrieval_type(question: str) -> RetrieverMethod:
    """Determine the optimal retrieval method for the given question."""
    prompt = retriever_type_prompt.format(question=question)
    messages = convert_langchain_messages_to_dicts(messages=[HumanMessage(prompt)])
    response = await get_structured_output(
        messages=messages, model=None, schema=RetrieverMethod
    )
    return cast(RetrieverMethod, response)


def convert_context_to_str(state_state: list[StepState]) -> str:
    """This function converts the list of StepState dictionaries into a single string.

    Parameters
    ----------
    state_state : list[StepState]
        The list of StepState dictionaries representing the research history.

    Returns
    -------
    str
        A single string representation of the research history.
    """
    return "\n\n".join(
        [
            f"Step {s['step_index']}: {s['question']}\nSummary: {s['summary']}"
            for s in state_state
        ]
    )


def format_plan(plan: Plan | None) -> str:
    """Format the plan into a string representation.

    Parameters
    ----------
    plan : Plan
        The multi-step plan to be formatted.

    Returns
    -------
    str
        A string representation of the plan.
    """
    if plan is None:
        return ""
    return json.dumps([step.model_dump() for step in plan.steps])


async def get_decision(question: str, plan: Plan | None, history: str) -> Decision:
    """This node is used to determine whether to continue with the plan or finish.

    Parameters
    ----------
    question : str
        The original user question.
    plan : Plan
        The multi-step plan object.
    history : str
        The history of completed steps.

    Returns
    -------
    Decision
        The decision object containing the next action and rationale.
    """
    sys_msg = decision_prompt.format(question=question, plan=format_plan(plan=plan))
    history_query: str = f"<COMPLETED_STEPS>{history}</COMPLETED_STEPS>"

    messages: list[dict[str, str]] = convert_langchain_messages_to_dicts(
        messages=[SystemMessage(sys_msg), HumanMessage(history_query)]
    )
    response = await get_structured_output(
        messages=messages, model=None, schema=Decision
    )
    return cast(Decision, response)


async def rerank_retrieved_documents(state: State) -> dict[str, Any]:
    """Rerank documents by relevance to query."""
    k: int = 3
    question: str = state["original_question"]
    retrieved_documents: list[Document] = state["retrieved_documents"]
    # Get the details of the current step
    current_step_idx: int = state["current_step_index"]
    current_step: Step = state["plan"].steps[current_step_idx]
    console.print(
        f"Retrieving documents for reranking for Step {current_step_idx}: {current_step.question}"
    )

    reranked_docs: list[Document] = await arerank_documents(
        query=question, documents=retrieved_documents, k=k
    )

    return {"reranked_documents": reranked_docs}


async def compression_documents(state: State) -> dict[str, Any]:
    """Synthesize final context from reranked documents."""
    reranked_documents: list[Document] = state["reranked_documents"]
    # Get the details of the current step
    current_step_idx: int = state["current_step_index"]
    current_step: Step = state["plan"].steps[current_step_idx]

    # Format document
    user_query: str = f"<DOCUMENT>{format_documents(reranked_documents)}</DOCUMENT>"
    sys_msg: str = compression_prompt.format(question=current_step.question)
    console.print(
        f"Synthesizing documents for Step {current_step_idx}: {current_step.question}"
    )
    response = await remote_llm.ainvoke(
        [SystemMessage(sys_msg), HumanMessage(user_query)]
    )

    return {"synthesized_context": response.content}


# =========================================================
# ========================= NODES =========================
# =========================================================


@traceable
async def validate_query_node(state: State) -> dict[str, Any]:
    """Validate the user's query to ensure it is relevant to the specified topics.

    Parameters
    ----------
    state : State
        Current state containing the original_question.

    Returns
    -------
    dict[str, Any]
        Updated state with validation results.
    """
    topics: str = "NVIDIA's financial performance, form 10-K internal documents, news related to NVIDIA, and industry trends, "
    user_question: str = state["original_question"]
    user_query: str = f"<USER_QUESTION>{user_question}</USER_QUESTION>"
    sys_msg = query_validation_prompt.format(
        topics=topics, section_titles=" | ".join(section_titles)
    )
    messages = convert_langchain_messages_to_dicts(
        messages=[SystemMessage(content=sys_msg), HumanMessage(content=user_query)]
    )
    console.print("🚨 Validating user question against context topics...")
    response = await get_structured_output(
        messages=messages, model=None, schema=ValidateQuery
    )
    response = cast(ValidateQuery, response)
    console.print(
        f"🚨 Related to topic?: {response.is_related_to_context} | "
        f"Next Action: {response.next_action} | Rationale: {response.rationale}"
    )

    step_state: list[StepState] = [
        StepState(
            step_index=-1,
            question=user_question,
            rewritten_queries=[],
            retrieved_documents=[],
            summary=response.rationale,
        )
    ]

    return {
        "current_step_index": -1,
        "is_related_to_context": response.is_related_to_context,
        "step_state": step_state,
        "plan": None,
    }


@traceable
async def generate_plan_node(state: State) -> dict[str, Any]:
    """Generate a multi-step plan based on the user's question.

    Parameters
    ----------
    state : State
        Current state containing the original_question.

    Returns
    -------
    dict[str, Any]
        Updated state with the generated plan.
    """
    # If plan already exists, return empty update (no overwrite)
    if state.get("plan"):
        return {}

    user_question: str = state["original_question"]
    user_query: str = f"<USER_QUESTION>{user_question}</USER_QUESTION>"

    sys_msg = planner_prompt.format(section_titles=" | ".join(section_titles))

    messages = convert_langchain_messages_to_dicts(
        messages=[SystemMessage(content=sys_msg), HumanMessage(content=user_query)]
    )
    response = await get_structured_output(
        messages=messages, model="x-ai/grok-4.1-fast:free", schema=Plan
    )
    response = cast(Plan, response)
    console.print(f"Number of steps: {len(response.steps)}...")

    return {
        "is_related_to_context": True,
        "plan": response,
        "step_state": [],
        "current_step_index": 0,
    }


@traceable
async def retrieve_internal_docs_node(state: State) -> dict[str, Any]:
    """Retrieve internal documents node."""
    k: int = 5
    # Get the details of the current step
    current_step_idx: int = state["current_step_index"]
    current_step: Step = state["plan"].steps[current_step_idx]
    console.print(
        f"🛢 Using Vector DB\nRetrieving documents for Step {current_step_idx}: {current_step.question}"
    )

    # Re-write the query and determine retrieval method concurrently
    re_written_query_obj, retriever_method = await asyncio.gather(
        query_rewriter(
            question=current_step.question,
            search_keywords=current_step.search_keywords,
        ),
        determine_retrieval_type(question=current_step.question),
    )

    rewritten_queries: list[str] = re_written_query_obj.rewritten_query
    console.print(f"Re-written queries: {rewritten_queries}")
    console.print(
        f"Selected retrieval method: {retriever_method.method};\nRationale: {retriever_method.rationale}"
    )

    # Retrieve documents based on the selected method
    retrieved_docs: list[Document] = await aretrieve_internal_documents(
        method=retriever_method.method,
        rewritten_queries=rewritten_queries,
        target_section=current_step.target_section,
        k=k,
    )
    step_state = StepState(
        step_index=current_step_idx,
        question=current_step.question,
        rewritten_queries=rewritten_queries,
        retrieved_documents=retrieved_docs,
        summary="",
    )
    # Update the state with retrieved docs and step_state
    return {
        "step_state": [step_state],
        "retrieved_documents": retrieved_docs,
    }


@traceable
async def internet_search_node(state: State) -> list[Document]:
    """Retrieve documents from the web using re-written queries.

    Parameters
    ----------
    state : State
        Current state of the agent.

    Returns
    -------
    list[Document]
        The retrieved documents
    """
    k: int = 5
    # Get the details of the current step
    current_step_idx: int = state["current_step_index"]
    current_step: Step = state["plan"].steps[current_step_idx]
    console.print(
        f"Retrieving documents for Step {current_step_idx}: {current_step.question}"
    )

    # Re-write the query using the query re-writer
    re_written_query_obj: ReWrittenQuery = await query_rewriter(
        question=current_step.question,
        search_keywords=current_step.search_keywords,
    )
    rewritten_queries: list[str] = re_written_query_obj.rewritten_query
    console.print(f"🌐 WEB SEARCH\nRe-written queries: {rewritten_queries}")

    tasks: list[Coroutine[Any, Any, list[Document]]] = [
        atavily_web_search_tool(
            query=query,
            fetch_full_page=False,
            k=k,
            max_chars=None,
        )
        for query in rewritten_queries
    ]
    all_docs: list[list[Document]] = await asyncio.gather(*tasks)
    # Flatten the docs
    retrieved_docs: list[Document] = [doc for sublist in all_docs for doc in sublist]  # type: ignore

    step_state = StepState(
        step_index=current_step_idx,
        question=current_step.question,
        rewritten_queries=rewritten_queries,
        retrieved_documents=retrieved_docs,
        summary="",
    )
    # Update the state with retrieved docs and step_state
    return {
        "step_state": [step_state],
        "retrieved_documents": retrieved_docs,
    }


@traceable
async def rerank_and_compress_node(state: State) -> dict[str, Any]:
    """Rerank documents and then synthesize final context.

    Parameters
    ----------
    state : State
        The current state containing retrieved documents and other info.

    Returns
    -------
    dict[str, Any]
        Updated state with reranked documents and synthesized context.
    """
    rerank_result = await rerank_retrieved_documents(state)
    # Update state with reranked documents
    updated_state = {**state, **rerank_result}
    compression_result = await compression_documents(updated_state)
    return {**rerank_result, **compression_result}


@traceable
async def summarization_node(state: State) -> dict[str, Any]:
    """Synthesize final context from reranked documents. This node is also responsible for moving to
    the next step in the multi-step plan.
    """

    synthesized_context: str = state["synthesized_context"]
    # Get the details of the current step
    current_step_idx: int = state["current_step_index"]
    current_step: Step = state["plan"].steps[current_step_idx]
    rewritten_queries = [
        step for step in state["step_state"] if current_step_idx == step["step_index"]
    ][0]["rewritten_queries"]

    # Format document
    context: str = f"<CONTEXT>{synthesized_context}</CONTEXT>"
    sys_msg: str = summarization_prompt.format(question=current_step.question)
    console.print(f"Summarizing for Step {current_step_idx}: {current_step.question}")
    response = await remote_llm.ainvoke([SystemMessage(sys_msg), HumanMessage(context)])

    new_step_state: StepState = StepState(
        step_index=current_step_idx,
        question=current_step.question,
        rewritten_queries=rewritten_queries,
        retrieved_documents=state["retrieved_documents"],
        summary=response.content,
    )
    console.print(
        f"⚠️ Number of steps completed: {current_step_idx + 1} | Num iterations: {state.get('num_iterations', 0) + 1}"
    )

    # Append the new step state to the existing list and increment current step index
    return {
        "step_state": state.get("step_state", []) + [new_step_state],
        "current_step_index": current_step_idx + 1,
        "synthesized_context": response.content,
        "num_iterations": state.get("num_iterations", 0) + 1,
    }


@traceable
async def final_answer_node(state: State) -> dict[str, Any]:
    """Generate the final answer with citations based on all collected evidence."""

    console.print("--- ✅: Generating Final Answer with Citations ---")
    # Gather all the evidence we've collected from ALL steps.
    final_context: str = ""
    for i, step in enumerate(state["step_state"]):
        final_context += f"\n--- Findings from Research Step {i + 1} ---\n"
        # Include the source metadata (section or URL) for each document to enable citations.
        for doc in step["retrieved_documents"]:
            source: str = doc.metadata.get("source_doc") or doc.metadata.get("url")
            final_context += f"Source: {source}\nContent: {doc.page_content}\n\n"

    prompt: str = final_answer_prompt.format(question=state["original_question"])
    context: str = f"<CONTEXT>{final_context}</CONTEXT>"

    final_answer = await remote_llm.ainvoke(
        [SystemMessage(prompt), HumanMessage(context)]
    )
    # Update the state with the final answer and reset num_iterations
    return {
        "final_answer": final_answer.content,
        "num_iterations": 0,  # Reset counter for next question
    }


def unrelated_query_node(state: State) -> dict[str, Any]:  # noqa: ARG001
    """Handle unrelated queries by providing a default response.

    Parameters
    ----------
    state : State
        The current state of the agent.

    Returns
    -------
    dict[str, Any]
        Updated state with a default final answer.
    """
    console.print("🚨 Query unrelated to context. Generating default response...")
    default_response: str = (
        "I'm sorry, but your question does not relate to the available information "
        "about NVIDIA's financial performance, form 10-K, news related to NVIDIA, "
        "or industry trends. Please ask a question relevant to these topics."
    )
    return {
        "plan": None,
        "final_answer": default_response,
        "num_iterations": 0,  # Reset counter for next question
    }


# =========================================================
# =================== CONDITIONAL NODES ===================
# =========================================================
def route_by_tool_condition(state: State) -> ToolsType:
    """Determine the tool type for the current step.

    Parameters
    ----------
    state : State
        The current state of the agent.

    Returns
    -------
    ToolsType
        The tool type for the current step.
    """
    current_step_idx: int = state["current_step_index"]
    current_step: Step = state["plan"].steps[current_step_idx]
    return current_step.tool


async def should_continue_condition(
    state: State, max_reasoning_interations: int = 8
) -> NextAction:
    """Determine if the current step should be retried.

    Parameters
    ----------
    state : State
        The current state of the agent.
    max_reasoning_interations : int, optional
        The maximum number of reasoning iterations allowed, by default 8.

    Returns
    -------
    NextAction
        The next action to take (CONTINUE or FINISH).
    """
    print("--- Evaluating Multi Step Reasoning Policy ---")
    is_related_to_context: bool = state.get("is_related_to_context", True)
    current_step_idx: int = state["current_step_index"]
    num_iterations: int = state.get("num_iterations", 0)

    # Checks
    # If query does NOT relate to the topics, finish immediately
    if not is_related_to_context:
        console.print(" -> Query not related to context. Finishing...")
        return NextAction.FINISH

    # Are all the steps completed?
    if state["plan"] and (current_step_idx >= len(state["plan"].steps)):
        console.print(f" -> Plan complete. {num_iterations} iterations. Finishing...")
        return NextAction.FINISH

    # Is the max num of iterations exhausted?
    if num_iterations >= max_reasoning_interations:
        console.print(
            f" -> Max iterations reached. {num_iterations} iterations. Finishing..."
        )
        return NextAction.FINISH

    # Last retrieval step failed to find any docs
    if state.get("reranked_documents") is not None and not state["reranked_documents"]:
        console.print(
            "⚠️ -> Retrieval failed for the last step. Continuing with next step in plan."
        )
        return NextAction.CONTINUE

    # If the conditions above are NOT met
    history: str = convert_context_to_str(state["step_state"])

    # Get decision from LLM
    decision = await get_decision(
        question=state["original_question"],
        plan=state["plan"],
        history=history,
    )
    console.print(
        f" -> Decision: {decision.next_action} | Rationale: {decision.rationale}"
    )

    if decision.next_action == NextAction.FINISH:
        return NextAction.FINISH
    return NextAction.CONTINUE

In [None]:
from IPython.display import Image, Markdown, display
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.types import RetryPolicy

In [None]:
max_attempts: int = 3
initial_interval: float = 1.0

builder: StateGraph = StateGraph(State)

builder.add_node(
    "validate_query",
    validate_query_node,
    retry_policy=RetryPolicy(
        max_attempts=max_attempts, initial_interval=initial_interval
    ),
)
builder.add_node(
    "unrelated_query",
    unrelated_query_node,
    retry_policy=RetryPolicy(
        max_attempts=max_attempts, initial_interval=initial_interval
    ),
)

builder.add_node(
    "generate_plan",
    generate_plan_node,
    retry_policy=RetryPolicy(
        max_attempts=max_attempts, initial_interval=initial_interval
    ),
)
builder.add_node(
    "retrieve_internal_docs",
    retrieve_internal_docs_node,
    retry_policy=RetryPolicy(
        max_attempts=max_attempts, initial_interval=initial_interval
    ),
)
builder.add_node(
    "internet_search",
    internet_search_node,
    retry_policy=RetryPolicy(
        max_attempts=max_attempts, initial_interval=initial_interval
    ),
)
builder.add_node(
    "rerank_and_compress",
    rerank_and_compress_node,
    retry_policy=RetryPolicy(
        max_attempts=max_attempts, initial_interval=initial_interval
    ),
)

builder.add_node(
    "summarize",
    summarization_node,
    retry_policy=RetryPolicy(
        max_attempts=max_attempts, initial_interval=initial_interval
    ),
)
builder.add_node(
    "final_answer",
    final_answer_node,
    retry_policy=RetryPolicy(
        max_attempts=max_attempts, initial_interval=initial_interval
    ),
)

# Add edges

builder.add_edge(START, "validate_query")
builder.add_conditional_edges(
    "validate_query",
    should_continue_condition,
    {
        NextAction.CONTINUE: "generate_plan",
        NextAction.FINISH: "unrelated_query",
    },
)
builder.add_conditional_edges(
    "generate_plan",
    route_by_tool_condition,  # function to determine which tool to use
    {
        ToolsType.VECTOR_STORE: "retrieve_internal_docs",
        ToolsType.WEB_SEARCH: "internet_search",
    },
)

builder.add_edge("retrieve_internal_docs", "rerank_and_compress")
builder.add_edge("internet_search", "rerank_and_compress")
builder.add_edge("rerank_and_compress", "summarize")
builder.add_conditional_edges(
    "summarize",
    should_continue_condition,  # function to determine next action
    {NextAction.CONTINUE: "generate_plan", NextAction.FINISH: "final_answer"},
)
builder.add_edge("final_answer", END)
builder.add_edge("unrelated_query", END)

# Compile the graph
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)

# Visualize the graph with ASCII fallback
try:
    display(Image(graph.get_graph(xray=1).draw_mermaid_png()))
except Exception as e:
    console.print(f"[yellow]PNG visualization failed: {e}[/yellow]")
    console.print("[cyan]Displaying ASCII representation instead:[/cyan]\n")
    try:
        print(graph.get_graph(xray=1).draw_ascii())
    except ImportError as ie:
        console.print(f"[red]ASCII visualization also failed: {ie}[/red]")
        console.print("[magenta]Showing basic graph structure:[/magenta]\n")
        graph_obj = graph.get_graph(xray=1)
        console.print(f"Nodes: {[node.id for node in graph_obj.nodes.values()]}")
        console.print(f"Edges: {[(e.source, e.target) for e in graph_obj.edges]}")

In [None]:
# Re-build the graph
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)
user_query: str = (
    "Based on NVIDIA's 2023 10-K filing, identify their key risks related to competition. "
    "Then, find recent news (post-filing, from 2024) about AMD's AI chip strategy and explain "
    "how this new strategy directly addresses or exacerbates one of NVIDIA's stated risks."
)

config: dict[str, Any] = {"configurable": {"thread_id": "test-01"}}
response = await graph.ainvoke(
    {"original_question": user_query},
    config=config,
)

In [None]:
console.print(f"Original Question: {user_query}\n")
Markdown(response.get("final_answer"))

In [None]:
console.print(response["synthesized_context"])

In [None]:
user_query: str = """
How does NVIDIA ensure they remain on top in the industry? What was their tax records in 2023 according to 10-k?
How are activities financed and what was the shares and dividends distribution in 2023 and 2024. 
How much was the gross and marginal profit in 2023 and 2024?
"""

# Re-build the graph
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)

# config: dict[str, Any] = {"configurable": {"thread_id": "test-01"}}
response = await graph.ainvoke(
    {"original_question": user_query},
    config=config,
)

In [None]:
console.print(f"Original Question: {user_query}\n")
Markdown(response.get("final_answer"))

In [None]:
response

In [None]:
user_query: str = """
Identify a major product category (e.g., specific hardware or platform) that NVIDIA includes in its Product Sales Revenue 
according to the search results. Then, contrast the recognition method for this product category with the industry standard 
for recognizing revenue from the associated perpetual software licenses—specifically explaining why the revenue for the 
perpetual license is often recognized up front when the software is made available to the customer.
"""
user_query: str = """
Who is Neidu? Does he work at NVIDIA?
"""
# Re-build the graph
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)

config: dict[str, Any] = {"configurable": {"thread_id": "test-01"}}
response = await graph.ainvoke(
    {"original_question": user_query},
    config=config,
)

In [None]:
Markdown(response.get("final_answer"))

In [None]:
response

In [None]:
# curl -s --compressed "https://api.search.brave.com/res/v1/web/search?q=brave+search" \
#   -H "Accept: application/json" \
#   -H "Accept-Encoding: gzip" \
#   -H "X-Subscription-Token: $BRAVE_SEARCH_API_KEY"


# curl -s --compressed "https://api.search.brave.com/res/v1/web/search?q=weather+in+munich&enable_rich_callback=1" \
#   -H "Accept: application/json" \
#   -H "Accept-Encoding: gzip" \
#   -H "X-Subscription-Token: $BRAVE_SEARCH_API_KEY"

In [None]:
async def brave_search_tool(
    url: str = "/res/v1/web/search",
    params: dict[str, Any] | None = None,
    headers: dict[str, str] | None = None,
    method: str = "GET",
    data: dict[str, Any] | None = None,
    base_url: str = "https://api.search.brave.com",
    timeout: int = 30,
) -> dict[str, Any]:
    """Send HTTP request with query parameters using HTTPXClient.

    Parameters
    ----------
    url : str
        The endpoint URL (can be relative if base_url is provided).
    params : dict[str, Any] or None, default=None
        Query parameters to append to the URL.
    headers : dict[str, str] or None, default=None
        HTTP headers to include in the request.
    method : str, default="GET"
        HTTP method (GET, POST, etc.).
    data : dict[str, Any] or None, default=None
        Request body data for POST requests.
    base_url : str, default=""
        The base URL for the API (e.g., "https://api.search.brave.com").
    timeout : int, default=30
        Request timeout in seconds.

    Returns
    -------
    dict[str, Any]
        Standardized response with keys: success, status_code, data, headers, error.

    Examples
    --------
    >>> # Brave Search API example
    >>> params = {"q": "weather in munich", "enable_rich_callback": 1}
    >>> headers = {
    ...     "Accept": "application/json",
    ...     "Accept-Encoding": "gzip",
    ...     "X-Subscription-Token": app_settings.BRAVE_SEARCH_API_KEY.get_secret_value(),
    ... }
    >>> result = await brave_search_tool(
    ...     url="/res/v1/web/search",
    ...     base_url="https://api.search.brave.com",
    ...     params=params,
    ...     headers=headers,
    ... )
    >>> if result["success"]:
    ...     data = result["data"]
    """
    async with HTTPXClient(base_url=base_url, timeout=timeout) as client:
        if method.upper() == "GET":
            return await client.get(url=url, params=params, headers=headers)
        if method.upper() == "POST":
            return await client.post(url=url, data=data, params=params, headers=headers)
        return {
            "success": False,
            "status_code": None,
            "data": None,
            "headers": None,
            "error": f"Unsupported HTTP method: {method}",
        }

In [None]:
params: dict[str, Any] = {
    "q": "has Nvidia broken any laws in China?",
    "enable_rich_callback": 1,
}
headers: dict[str, str] = {
    "Accept": "application/json",
    "Accept-Encoding": "gzip",
    "X-Subscription-Token": app_settings.BRAVE_SEARCH_API_KEY.get_secret_value(),
}

await brave_search_tool(url="/res/v1/news/search", params=params, headers=headers)