In [None]:
import json
import sys
import warnings
from pathlib import Path
from typing import Any, Coroutine, Literal

import numpy as np
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

custom_theme = Theme(
    {
        "white": "#FFFFFF",  # Bright white
        "info": "#00FF00",  # Bright green
        "warning": "#FFD700",  # Bright gold
        "error": "#FF1493",  # Deep pink
        "success": "#00FFFF",  # Cyan
        "highlight": "#FF4500",  # Orange-red
    }
)
console = Console(theme=custom_theme)

# Visualization
# import matplotlib.pyplot as plt

# NumPy settings
np.set_printoptions(precision=4)

# Pandas settings
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

# Polars settings
pl.Config.set_fmt_str_lengths(1_000)
pl.Config.set_tbl_cols(n=1_000)
pl.Config.set_tbl_rows(n=200)

warnings.filterwarnings("ignore")

# Black code formatter (Optional)
%load_ext lab_black

# auto reload imports
%load_ext autoreload
%autoreload 2

In [None]:
def go_up_from_current_directory(*, go_up: int = 1) -> None:
    """This is used to up a number of directories.

    Params:
    -------
    go_up: int, default=1
        This indicates the number of times to go back up from the current directory.

    Returns:
    --------
    None
    """
    import os

    CONST: str = "../"
    NUM: str = CONST * go_up

    # Goto the previous directory
    prev_directory = os.path.join(os.path.dirname(__name__), NUM)
    # Get the 'absolute path' of the previous directory
    abs_path_prev_directory = os.path.abspath(prev_directory)

    # Add the path to the System paths
    sys.path.insert(0, abs_path_prev_directory)
    print(abs_path_prev_directory)


# Demo (Prevents ruff from removing the unused module import)
my_coroutine: Coroutine
my_path: Path = Path(".")
name: Any
category: Literal["A", "B", "C"]
json.loads('{"name": "Smart-RAG", "version": "1.0"}')

{'name': 'Smart-RAG', 'version': '1.0'}

In [3]:
go_up_from_current_directory(go_up=1)

from src.config import app_settings  # noqa: E402
from src.utilities.model_config import RemoteModel  # noqa: E402

settings = app_settings

/Users/mac/Desktop/Projects/smart-rag


In [4]:
from langchain_openai import ChatOpenAI

remote_llm = ChatOpenAI(
    api_key=settings.OPENROUTER_API_KEY.get_secret_value(),  # type: ignore
    base_url=settings.OPENROUTER_URL,
    temperature=0.0,
    model=RemoteModel.GEMINI_2_5_FLASH_LITE,
)


# Test the LLMs
response = remote_llm.invoke("Tell me a very short joke.")
response.pretty_print()


Why did the scarecrow win an award?

Because he was outstanding in his field!


In [5]:
# Use aiohttp for async requests
# Create pipeline for downloading the data
# - add tqdm for progress bar

In [None]:
import httpx


class HTTPXClient:
    def __init__(
        self,
        base_url: str = "",
        timeout: int = 30,
        http2: bool = True,
        max_connections: int = 20,
        max_keepalive_connections: int = 5,
    ) -> None:
        self.base_url = base_url
        self.timeout = timeout
        self.http2 = http2
        self.max_connections = max_connections
        self.max_keepalive_connections = max_keepalive_connections
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout,
            http2=self.http2,
            limits=httpx.Limits(
                max_connections=self.max_connections,
                max_keepalive_connections=self.max_keepalive_connections,
            ),
        )

    async def __aenter__(self) -> "HTTPXClient":
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        await self.client.aclose()

    async def get(
        self,
        url: str,
        params: dict[str, Any] | None = None,
        headers: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Perform an asynchronous GET request."""
        try:
            response = await self.client.get(url, params=params, headers=headers)
            return self._parse_response(response)
        except Exception as e:
            return self._handle_exception(e)

    async def post(
        self,
        url: str,
        data: dict[str, Any] | None = None,
        params: dict[str, Any] | None = None,
        headers: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Perform an asynchronous POST request."""
        try:
            response = await self.client.post(
                url, data=data, params=params, headers=headers
            )
            return self._parse_response(response)
        except Exception as e:
            return self._handle_exception(e)

    def _parse_response(self, response: httpx.Response) -> dict[str, Any]:
        """Parse the HTTPX response and return a standardized dictionary."""
        try:
            data = response.json()
        except json.JSONDecodeError:
            data = response.text

        return {
            "success": response.status_code < 400,
            "status_code": response.status_code,
            "data": data,
            "headers": dict(response.headers),
            "error": (
                None
                if response.status_code < 400
                else f"HTTP {response.status_code} Error"
            ),
        }

    def _handle_exception(self, e: Exception) -> dict[str, Any]:
        """Handle exceptions and return a standardized error response."""
        if isinstance(e, httpx.ConnectError):
            error_msg = f"Connection Error: {str(e)}"
        elif isinstance(e, httpx.TimeoutException):
            error_msg = f"Request Timeout: {str(e)}"
        else:
            error_msg = f"Unexpected Error: {str(e)}"

        return {
            "success": False,
            "status_code": None,
            "data": None,
            "headers": None,
            "error": error_msg,
        }

In [None]:
async with HTTPXClient() as client:
    response = await client.get(
        "https://www.bbc.com/sport/football/articles/cwy543n274wo"
    )
    print(response)



In [8]:
response["data"]



In [9]:
from markdownify import markdownify as md

console.print(md(response["data"])[3000:5000])

<br>

# RAG Pipeline


## Step 0

- Download and prepare your documents.

In [None]:
import re
import unicodedata
from pathlib import Path
from typing import Any

from bs4 import BeautifulSoup
from markdownify import markdownify as md

# Simplified header string normalizer for common cases
# - decodes \xNN / \uNNNN / \U00NNNNNN escapes if present
# - replaces non‑breaking space with normal space
# - strips zero‑width chars
# - normalizes unicode and collapses whitespace

_ESCAPE_RE: re.Pattern[str] = re.compile(
    r"(?:\\x[0-9a-fA-F]{2}|\\u[0-9a-fA-F]{4}|\\U[0-9a-fA-F]{8})"
)
_ZERO_WIDTH: set[str] = {"\u200b", "\u200c", "\u200d", "\u2060", "\ufeff"}


def normalize_header_string(text: str) -> str:
    """Normalize header-like strings with minimal, safe transforms.

    Applies targeted unicode-escape decoding when present, replaces NBSP, removes
    zero-width characters, normalizes (NFKC), and collapses whitespace.
    """
    # Targeted backslash-escape decoding (avoid decoding unrelated backslashes)
    if _ESCAPE_RE.search(text):

        def _sub(m: re.Match[str]) -> str:
            token = m.group(0)
            try:
                return token.encode().decode("unicode_escape")
            except Exception:
                return token

        text = _ESCAPE_RE.sub(_sub, text)

    # Replace non‑breaking space with normal space
    text = text.replace("\u00a0", " ")

    # Remove zero‑width characters
    if any(ch in text for ch in _ZERO_WIDTH):
        text = "".join(ch for ch in text if ch not in _ZERO_WIDTH)

    # Unicode normalize and collapse whitespace
    text = unicodedata.normalize("NFKC", text)
    return " ".join(text.split())


def clean_xbrl_noise(text: str) -> str:
    """Aggressively remove XBRL noise while preserving document structure.

    This function removes all XBRL/XML metadata and keeps only the meaningful
    HTML content that can be converted to readable markdown.
    """

    body_match = re.search(r"<body[^>]*>(.*)</body>", text, re.DOTALL | re.IGNORECASE)
    if body_match:
        text = "<body>" + body_match.group(1) + "</body>"

    try:
        soup = BeautifulSoup(text, "html.parser")

        # Remove <head> entirely - it contains most XBRL metadata
        for head in soup.find_all("head"):
            head.decompose()

        # Remove all script and style tags
        for tag in soup(["script", "style", "meta", "link"]):
            tag.decompose()

        # Remove XML/XBRL namespaced elements (tags with colons)
        for tag in soup.find_all():
            if tag.name and ":" in tag.name:
                tag.decompose()

        # Remove hidden XBRL data elements (usually display:none or specific XBRL classes)
        for tag in soup.find_all(style=re.compile(r"display:\s*none", re.I)):
            tag.decompose()

        for tag in soup.find_all(class_=re.compile(r"xbrl|hidden", re.I)):
            tag.decompose()

        # Remove specific XBRL attribute clutter
        for tag in soup.find_all():
            if tag.name:
                # Remove XBRL attributes
                attrs_to_remove = []
                for attr in tag.attrs:
                    if (
                        ":" in attr
                        or attr.startswith("xmlns")
                        or attr in ["contextref", "unitref", "decimals"]
                    ):
                        attrs_to_remove.append(attr)  # noqa: PERF401
                for attr in attrs_to_remove:
                    del tag[attr]

        # Get the cleaned HTML
        cleaned: str = str(soup)

    except Exception as e:
        print(f"Warning: HTML parsing failed: {e}")
        cleaned = text

    # Post-processing regex cleanup for any remaining XBRL noise

    # Remove namespace URLs that got left behind
    cleaned = re.sub(
        r'http://[^\s<>"]+(?:xbrl|fasb|sec\.gov)[^\s<>"]*', "", cleaned, flags=re.I
    )

    # Remove XBRL namespace tokens (us-gaap:Something, iso4217:USD, etc.)
    cleaned = re.sub(
        r"\b(?:us-gaap|nvda|srt|stpr|fasb|xbrli|iso4217|xbrl|dei|ix|country|xbrldi|link):[A-Za-z0-9_\-:()]+(?:Member)?\b",
        "",
        cleaned,
        flags=re.I,
    )

    # Remove long numeric strings (CIK numbers, etc.) - 10+ digits
    cleaned = re.sub(r"\b\d{10,}\b", "", cleaned)
    # Remove date patterns that are concatenated without separators (2023-01-292022-01-30)
    cleaned = re.sub(r"(?:\d{4}-\d{2}-\d{2}){2,}", "", cleaned)
    # Remove very long alphanumeric strings (40+ chars) that indicate concatenated tags
    cleaned = re.sub(r"\b[A-Za-z0-9_\-]{40,}\b", "", cleaned)
    # Remove XML/namespace declarations
    cleaned = re.sub(r'xmlns[:\w]*="[^"]*"', "", cleaned)
    cleaned = re.sub(r'xml:\w+="[^"]*"', "", cleaned)
    # Remove "pure" standalone (XBRL unit)
    cleaned = re.sub(r"\bpure\b(?!\s+\w)", "", cleaned)
    # Clean up multiple colons and extra punctuation
    cleaned = re.sub(r":{2,}", ":", cleaned)
    return re.sub(r"\s*:\s*:\s*", " ", cleaned)


async def download_and_parse_data(
    url: str,
    raw_doc_path: Path | str,
    cleaned_doc_path: Path | str,
    force_download: bool = False,
) -> None:
    """Download and parse HTML/XBRL documents with aggressive noise removal.

    Parameters
    ----------
        url : str
            The remote URL to download
        raw_doc_path : Path | str
            Output path for the raw bytes/text
        cleaned_doc_path : Path | str
            Output path for the cleaned markdown/text
        force_download : bool, default=False
            When True, re-download and re-clean even if file(s) exist

    Returns
    -------
        None
    """
    if isinstance(raw_doc_path, str):
        raw_doc_path = Path(raw_doc_path)
    if isinstance(cleaned_doc_path, str):
        cleaned_doc_path = Path(cleaned_doc_path)

    # Safe, identifiable user agent:
    USER_AGENT: str = (
        "MyCompany MyDownloader/1.0 (+https://mycompany.example; dev@mycompany.example)"
    )
    headers: dict[str, str] = {"User-Agent": USER_AGENT, "Accept": "application/json"}

    # If raw document exists and we are not forcing re-download
    if raw_doc_path.exists() and raw_doc_path.is_file() and not force_download:
        print(f"Raw file already exists: {raw_doc_path}. Skipping download.")
    else:
        # Ensure the path exists
        raw_doc_path.parent.mkdir(parents=True, exist_ok=True)

        async with HTTPXClient() as client:
            response: dict[str, Any] = await client.get(url, headers=headers)

        if not response["success"]:
            print(f"Failed to download {url}: {response.get('error')}")
            return

        # Response data may be a dict or string; store as text
        raw_content: Any = response["data"]
        if not isinstance(raw_content, str):
            # Coerce to text safely
            try:
                raw_content = json.dumps(raw_content, ensure_ascii=False)
            except Exception:
                raw_content = str(raw_content)

        raw_doc_path.write_text(raw_content, encoding="utf-8")
        print(f"Saved raw content to {raw_doc_path}")

    # Convert the raw HTML/text into a cleaned markdown or plain text
    raw_text: str = raw_doc_path.read_text(encoding="utf-8")

    # Use the aggressive cleaner to remove XBRL noise
    cleaned_html = clean_xbrl_noise(raw_text)

    # For HTML content, convert to markdown with better formatting
    try:
        # Configure markdownify to preserve more structure
        cleaned_text: str = md(
            cleaned_html,
            heading_style="ATX",  # Use # for headers
            bullets="-",  # Use - for bullet points
            strong_em_symbol="**",  # Use ** for bold
            strip=["script", "style"],  # Remove script and style tags
        )
    except Exception as e:
        # If markdownify fails, try basic text extraction
        print(f"Warning: Markdown conversion failed: {e}")
        try:
            soup = BeautifulSoup(cleaned_html, "html.parser")
            cleaned_text = soup.get_text("\n", strip=True)
        except Exception:
            cleaned_text = cleaned_html

    # Post-processing cleanup on the markdown text
    # Remove lines that are mostly XBRL noise (lots of colons, short tokens)
    lines: list[str] = cleaned_text.split("\n")
    cleaned_lines: list[str] = []
    for line in lines:
        # Skip lines with excessive XBRL patterns
        if len(line) < 10:  # Keep very short lines (might be intentional)
            cleaned_lines.append(line)
            continue

        # Count suspicious patterns
        colon_count = line.count(":")
        token_count = len(
            re.findall(r"\b[A-Z][a-z]+(?:[A-Z][a-z]+)+\b", line)
        )  # CamelCase tokens

        # If line has too many colons or camelCase tokens relative to length, skip it
        if colon_count > len(line) / 20 or (token_count > 5 and len(line.split()) < 20):
            continue

        cleaned_lines.append(line)

    cleaned_text = "\n".join(cleaned_lines)

    # Remove excessive blank lines (more than 2 consecutive)
    cleaned_text = re.sub(r"\n{3,}", "\n\n", cleaned_text)

    # Remove leading/trailing whitespace from each line
    cleaned_text = "\n".join(line.strip() for line in cleaned_text.split("\n"))

    # Final whitespace cleanup
    cleaned_text = cleaned_text.strip()

    # Ensure the path exists
    cleaned_doc_path.parent.mkdir(parents=True, exist_ok=True)
    cleaned_doc_path.write_text(cleaned_text, encoding="utf-8")

    print(f"Saved cleaned content to {cleaned_doc_path}")
    return

In [None]:
url: str = "https://www.sec.gov/Archives/edgar/data/1045810/000104581023000017/nvda-20230129.htm"

await download_and_parse_data(
    url=url, raw_doc_path="raw_doc.txt", cleaned_doc_path="cleaned_doc.txt"
)

Raw file already exists: raw_doc.txt. Skipping download.
Saved cleaned content to cleaned_doc.txt


In [12]:
fp: str = "cleaned_doc.txt"

with Path(fp).open("r", encoding="utf-8") as file:
    cleaned_doc = file.read()

In [13]:
console.print(cleaned_doc[500:1_500])

<br>

## Step 1
- Split documents into chunks and create embeddings.

In [14]:
# from langchain_community.document_loaders import CSVLoader
from langchain_community.document_loaders import TextLoader

loader = TextLoader(fp)  # Integration-specific parameters here

# Load all documents
documents = loader.load()

# For large datasets, lazily load documents
# for document in loader.lazy_load():
#     print(document)

In [15]:
len(documents)

1

In [None]:
from re import Match, Pattern

# Extract 10-K sections with title and content separately (line-by-line comments)
# Get the entire document text from the TextLoader's first document
raw_text: str = documents[0].page_content  # the string to search for ITEM headers

# Header pattern: match 'ITEM 1.' or 'ITEM 1A.' etc. at the beginning of a line
# ^\s*            -> allow leading whitespace before the header
# ITEM\s+         -> the literal word ITEM followed by at least one space
# \d+             -> the item number (one or more digits)
# [A-Z]?           -> optional letter (A, B, etc.) after the number
# \.               -> period following the number (escaped dot)
# [\t ]+          -> at least one whitespace char (tab/space) after the dot
# [^\n\r]*        -> the remainder of the heading line (until newline)
# re.MULTILINE     -> ^ anchors at the beginning of each line
header_pattern: Pattern[str] = re.compile(
    r"^\s*(ITEM\s+\d+[A-Z]?\.[\t ]+[^\n\r]*)", re.MULTILINE
)

# run finditer which returns match objects with start()/end() locations
matches: list[Match[str]] = list(
    header_pattern.finditer(raw_text)
)  # convert to list for indexing

# Prepare lists to hold the results
section_titles: list[str] = []  # will store the header lines like 'ITEM 1. BUSINESS'
# will store the textual content of each section (no header)
section_content: list[str] = []

# Walk through each header match, capturing both title and the content after it
for i, match in enumerate(matches):
    title: str = match.group(1).strip()  # capture the heading text and strip whitespace
    # Normalize the header to handle NBSP/zero-width and consistent spacing
    title = normalize_header_string(title)
    section_titles.append(title)

    # The content begins right after the matched heading line
    start_pos: int = match.end()  # numeric index where this header finishes

    # Determine where this section ends: next header start or the end of the document
    if i + 1 < len(matches):
        end_pos: int = matches[i + 1].start()  # next header's start position
    else:
        end_pos: int = len(raw_text)  # or EOF if this is the last header

    # Use the start/end slices to get the body text and strip leading/trailing whitespace
    content: str = raw_text[start_pos:end_pos].strip()  # remove extra whitespace
    section_content.append(content)  # store the cleaned body in the sections list

# Confirmation print for quick inspection when the cell runs
print(f"Found {len(section_titles)} ITEM sections.")

Found 21 ITEM sections.


### Create Metadata-rich Chunks

In [17]:
from uuid import uuid4

from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1_000,  # chunk size (characters)
    chunk_overlap=50,  # chunk overlap (characters)
    add_start_index=True,  # track index in original document
)

doc_chunks_with_metadata: list[Document] = []

# Loop thru each section's content and its title
for content, title in zip(section_content, section_titles):
    section_chunks: list[str] = text_splitter.split_text(content)

    # Loop thru each chunk to add metadata
    for chunk in section_chunks:
        chunk_id: str = str(uuid4())  # unique ID for this chunk
        doc_chunks_with_metadata.append(  # noqa: PERF401
            Document(
                page_content=chunk,
                metadata={
                    "source_doc": fp,  # original document path
                    # Ensure section titles are normalized in metadata
                    "section": normalize_header_string(title),
                    "chunk_id": chunk_id,  # unique chunk ID
                },
            )
        )

print(f"Created {len(doc_chunks_with_metadata)} document chunks with metadata.")

Created 371 document chunks with metadata.


In [18]:
console.print(doc_chunks_with_metadata[51])

In [19]:
section_titles

['ITEM 1. BUSINESS',
 'ITEM 1A. RISK FACTORS',
 'ITEM 1B. UNRESOLVED STAFF COMMENTS',
 'ITEM 2. PROPERTIES',
 'ITEM 3. LEGAL PROCEEDINGS',
 'ITEM 4. MINE SAFETY DISCLOSURES',
 'ITEM 5. MARKET FOR REGISTRANT’S COMMON EQUITY, RELATED STOCKHOLDER MATTERS AND ISSUER PURCHASES OF EQUITY SECURITIES',
 'ITEM 6. [RESERVED]',
 'ITEM 7. MANAGEMENT’S DISCUSSION AND ANALYSIS OF FINANCIAL CONDITION AND RESULTS OF OPERATIONS',
 'ITEM 7A. QUANTITATIVE AND QUALITATIVE DISCLOSURES ABOUT MARKET RISK',
 'ITEM 8. FINANCIAL STATEMENTS AND SUPPLEMENTARY DATA',
 'ITEM 9. CHANGES IN AND DISAGREEMENTS WITH ACCOUNTANTS ON ACCOUNTING AND FINANCIAL DISCLOSURE',
 'ITEM 9A. CONTROLS AND PROCEDURES',
 'ITEM 9C. DISCLOSURE REGARDING FOREIGN JURISDICTIONS THAT PREVENT INSPECTIONS',
 'ITEM 10. DIRECTORS, EXECUTIVE OFFICERS AND CORPORATE GOVERNANCE',
 'ITEM 11. EXECUTIVE COMPENSATION',
 'ITEM 12. SECURITY OWNERSHIP OF CERTAIN BENEFICIAL OWNERS AND MANAGEMENT AND RELATED STOCKHOLDER MATTERS',
 'ITEM 13. CERTAIN RELATIONS

In [None]:
# Test the Metadata-aware chunking: e.g. 'Risk Factors' should be in the section
sample_chunk = (
    chunk
    for chunk in doc_chunks_with_metadata
    if "risk factors" in chunk.metadata.get("section", "").lower()
)
console.print(next(sample_chunk))

In [None]:
import os
from typing import Any

from langchain_core.embeddings import Embeddings
from langchain_core.utils import convert_to_secret_str
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    SecretStr,
    model_validator,
)

from src.utilities.openrouter.client import AsyncOpenRouterClient, OpenRouterClient


def set_openrouter_api(value: str | None = None) -> SecretStr:
    """Set the OpenRouter API key"""
    if value is None:
        return convert_to_secret_str(os.getenv("OPENROUTER_API_KEY", ""))
    return convert_to_secret_str(value)


class OpenRouterEmbeddings(BaseModel, Embeddings):
    """Using Field with default_factory for automatic client creation."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    client: OpenRouterClient = Field(default_factory=OpenRouterClient)
    aclient: AsyncOpenRouterClient = Field(default_factory=AsyncOpenRouterClient)

    openrouter_api_key: SecretStr = Field(default_factory=set_openrouter_api)
    model: str = Field(default="sentence-transformers/paraphrase-minilm-l6-v2")

    @model_validator(mode="after")
    def validate_environment(self) -> "OpenRouterEmbeddings":
        """Validate the environment and set up the OpenRouter client."""
        _api_key: SecretStr | str = self.openrouter_api_key or os.getenv(
            "OPENROUTER_API_KEY", ""
        )
        if not _api_key:
            raise ValueError(
                "OpenRouter API key not found. Please set the OPENROUTER_API_KEY environment variable."
            )

        if isinstance(_api_key, str):
            _api_key = convert_to_secret_str(_api_key)

        # Set up the OpenRouter client if not already set
        self.client = OpenRouterClient(
            api_key=_api_key.get_secret_value(),  # type: ignore
            default_model=self.model,
        )
        self.aclient = AsyncOpenRouterClient(
            api_key=_api_key.get_secret_value(),  # type: ignore
            default_model=self.model,
        )
        return self

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs."""
        response: dict[str, Any] = self.client.embeddings.create(
            input=texts, model=self.model
        )
        return [emb["embedding"] for emb in response["data"]]

    def embed_query(self, text: str) -> list[float]:
        """Embed query text."""
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs."""
        response: dict[str, Any] = await self.aclient.aembeddings.create(
            input=texts, model=self.model
        )

        return [emb["embedding"] for emb in response["data"]]

    async def aembed_query(self, text: str) -> list[float]:
        """Embed query text."""
        return (await self.aembed_documents([text]))[0]


embeddings = OpenRouterEmbeddings()
result: list[list[float]] = await embeddings.aembed_documents(texts=["Hello there!"])

In [None]:
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

client = QdrantClient(":memory:")

vector_size: int = len(await embeddings.aembed_query("sample text"))
collection_name: str = "smart_rag_collection"

if not client.collection_exists(collection_name):
    client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
    )
vectorstore = QdrantVectorStore(
    client=client,
    collection_name=collection_name,
    embedding=embeddings,
)
# Embed all the documents
document_ids: list[str] = await vectorstore.aadd_documents(
    documents=doc_chunks_with_metadata
)
print(document_ids[:3])

['2164bdb76f7d4d41877efdcbfafa18cc', '3070a7af612b4a5cb0881fe1b1cb722b', 'b5a5482f5f764234a58793599b03d9e4']


In [23]:
doc_chunks_with_metadata[0].model_dump()

{'id': None,
 'metadata': {'source_doc': 'cleaned_doc.txt',
  'section': 'ITEM 1. BUSINESS',
  'chunk_id': 'e6396624-bbe5-4b78-872e-13cd4b01cea5'},
 'page_content': 'Our Company\n\nNVIDIA pioneered accelerated computing to help solve the most challenging computational problems. Since our original focus on PC graphics, we have expanded to several other large and important computationally intensive fields. Fueled by the sustained demand for exceptional 3D graphics and the scale of the gaming market, NVIDIA has leveraged its GPU architecture to create platforms for scientific computing, artificial intelligence, or AI, data science, autonomous vehicles, or AV, robotics, metaverse and 3D internet applications.',
 'type': 'Document'}

## Create Tools

<br>

### VectorSearch (Filtering Chunks by Metadata)

In [None]:
from qdrant_client import QdrantClient, models

query: str = "According to NVIDIA's management, what assurance level do disclosure controls and internal controls actually provide?"

retrieved_docs = vectorstore.similarity_search(
    query,
    k=3,
    filter=models.Filter(
        must=[
            models.FieldCondition(
                key="metadata.section",
                match=models.MatchValue(value="ITEM 9A. CONTROLS AND PROCEDURES"),
            )
        ]
    ),
)
formatted_docs: str = "\n\n".join(
    (f"Source: {doc.metadata}\nContent: {doc.page_content}") for doc in retrieved_docs
)

console.print(formatted_docs)

In [None]:
from langchain.tools import tool
from qdrant_client.models import Filter


@tool
async def avector_search(
    query: str, filter: str | None = None, k: int = 3
) -> list[Document]:
    """Perform a vector search with metadata filtering.

    Parameters
    ----------
    query : str
        The search query string.
    filter : str or None, default=None
        The metadata filter value for 'metadata.section'.
    k : int, default=3
        The number of top similar documents to retrieve.

    Returns
    -------
    list[Document]
        A list of retrieved Document objects.
    """
    key: str = "metadata.section"
    _filter: Filter | None = (
        models.Filter(
            must=[models.FieldCondition(key=key, match=models.MatchValue(value=filter))]
        )
        if filter
        else None
    )
    return await vectorstore.asimilarity_search(query, k=k, filter=_filter)

In [26]:
console.print(avector_search)

In [27]:
avector_search.coroutine

<function __main__.avector_search(query: str, filter: str | None = None, k: int = 3) -> list[langchain_core.documents.base.Document]>

#### Note

```py
# Without the @tool decorator
result = await avector_search(query=query, filter=None, k=3)

# Using the @tool decorator
result = await avector_search.coroutine(query=query, filter=None, k=3)
```

In [28]:
# Without the @tool decorator
# result = await avector_search(query=query, filter=None, k=3)

# Using the @tool decorator
result = await avector_search.coroutine(query=query, filter=None, k=3)


console.print(result)

In [None]:
result = await avector_search.coroutine(
    query=query, filter="ITEM 9A. CONTROLS AND PROCEDURES", k=3
)

console.print(result)

### Keyword Search

In [30]:
from rank_bm25 import BM25Okapi

corpus: list[str] = [
    "Hello there good man!",
    "It is quite windy in London",
    "How is the weather today?",
]

tokenized_corpus: list[list[str]] = [doc.split(" ") for doc in corpus]

bm25 = BM25Okapi(tokenized_corpus)
query: str = "windy London"
tokenized_query: list[str] = query.split(" ")  # type: ignore

doc_scores = bm25.get_scores(tokenized_query)
print(doc_scores)
# Sort in descending order of scores
sorted_idxs = np.argsort(doc_scores)[::-1]
sorted_idxs

[0.         0.93729472 0.        ]


array([1, 2, 0])

In [31]:
[corpus[idx] for idx in sorted_idxs]

['It is quite windy in London',
 'How is the weather today?',
 'Hello there good man!']

In [32]:
import re

from tokenizers import (  # type: ignore
    Regex,
    Tokenizer,
    normalizers,
)
from tokenizers import (
    models as t_models,
)

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class CustomTokenizer:
    """A class for ..."""

    pattern_digits: str = r"[0-9]+"
    pattern_punctuation: str = r"[^\w\s\\\/]"  # Includes `\`, `/`
    pattern_spaces: str = r"\s{2,}"
    pattern_split: str = r"\W"

    unk_str: str = "[UNK]"

    def __init__(self, to_lower: bool = False) -> None:
        """Initialize with a WordPiece tokenizer and normalizer sequence."""
        self.to_lower = to_lower
        self.tokenizer = Tokenizer(t_models.WordPiece(unk_token=self.unk_str))  # type: ignore

        # Create the custom normalizer
        transformations_list = []
        if self.to_lower:
            transformations_list.append(normalizers.Lowercase())

        transformations_list.extend(
            [  # type: ignore
                normalizers.NFD(),
                normalizers.Replace(Regex(self.pattern_digits), " "),
                normalizers.Replace(Regex(self.pattern_punctuation), " "),
                normalizers.StripAccents(),
                normalizers.Strip(),
                # Last step
                normalizers.Replace(Regex(self.pattern_spaces), " "),
            ]
        )
        self.tokenizer.normalizer = normalizers.Sequence(  # type: ignore
            transformations_list  # type: ignore
        )

    def split_on_patterns(self, text: str) -> str:
        """Split a string on a pattern and join the parts with spaces.

        Parameters
        ----------
        text : str
            Input text to be split.

        Returns
        -------
        str
            Processed text with pattern-based splits.
        """
        parts: list[str] = re.split(self.pattern_split, text, flags=re.I)
        # Remove empty strings and join by spaces
        output: str = " ".join(filter(lambda x: x != "", [p.strip() for p in parts]))
        return output

    def format_data(self, data: str) -> str:
        """Format a single text string using pattern splitting and normalization.

        Parameters
        ----------
        data : str
            Input text to be formatted.

        Returns
        -------
        str
            Normalized and formatted text.
        """
        text: str = self.split_on_patterns(data)
        return self.tokenizer.normalizer.normalize_str(text)

    def batch_format_data(self, data: list[str]) -> list[str]:
        """Format a batch of text strings.

        Parameters
        ----------
        data : list[str]
            List of input texts to be formatted.

        Returns
        -------
        list[str]
            List of normalized and formatted texts.
        """
        return [self.format_data(row) for row in data]

In [33]:
custom_tokenizer = CustomTokenizer()
custom_tokenizer.batch_format_data(corpus)

['Hello there good man',
 'It is quite windy in London',
 'How is the weather today']

In [None]:
print("\nBuilding BM25 index for keyword search...")

# Create a list where each element is a list of words from a document
tokenized_corpus = [
    custom_tokenizer.format_data(doc.page_content).split(" ")
    for doc in doc_chunks_with_metadata
]

# Create a list of all unique document IDs
doc_ids: list[str] = [doc.metadata["chunk_id"] for doc in doc_chunks_with_metadata]

# Create a mapping from a document's ID back to the full Document object for easy lookup
doc_dict: dict[str, Document] = {
    doc.metadata["chunk_id"]: doc for doc in doc_chunks_with_metadata
}

# Initialize the BM25Okapi index with our tokenized corpus
bm25 = BM25Okapi(tokenized_corpus)


Building BM25 index for keyword search...


In [35]:
import asyncio


def keyword_search(query: str, k: int = 3) -> list[Document]:
    """Perform keyword search using BM25 and return top k documents."""
    # Tokenize the query
    tokenized_query: list[str] = custom_tokenizer.format_data(query).split()
    doc_scores = bm25.get_scores(tokenized_query)
    # Sort in descending order and select the top k
    top_k_idxs: np.ndarray = np.argsort(doc_scores)[::-1][:k]

    return [doc_dict[doc_ids[i]] for i in top_k_idxs]


@tool
async def akeyword_search(query: str, k: int = 3) -> list[Document]:
    """Perform keyword search asynchronously using BM25 and return top k documents."""
    return await asyncio.to_thread(keyword_search, query, k)

In [None]:
query: str = "According to NVIDIA's management, what assurance level do disclosure controls and internal controls actually provide?"

retrieved_docs = keyword_search(query=query, k=3)

console.print(retrieved_docs)

In [None]:
query: str = "According to NVIDIA's management, what assurance level do disclosure controls and internal controls actually provide?"

retrieved_docs = await akeyword_search.coroutine(query=query, k=3)

console.print(retrieved_docs)

### Hybrid Search

- Keyword + Vector Search

In [41]:
_filter = "ITEM 9A. CONTROLS AND PROCEDURES"
k: int = 3
tasks: list[Coroutine[Any, Any, list[Document]]] = [
    avector_search.coroutine(query=query, filter=_filter, k=k),
    akeyword_search.coroutine(query=query, k=k),
]

semantic_docs, kw_docs = await asyncio.gather(*tasks)

semantic_docs, kw_docs

([Document(metadata={'source_doc': 'cleaned_doc.txt', 'section': 'ITEM 9A. CONTROLS AND PROCEDURES', 'chunk_id': 'a557844f-e1de-481c-9177-8daff38cb7ed', '_id': '971b2711f8714714b583fce1b3de4a53', '_collection_name': 'smart_rag_collection'}, page_content='Controls and Procedures\n\nDisclosure Controls and Procedures\n\nBased on their evaluation as of January\xa029, 2023, our management, including our Chief Executive Officer and Chief Financial Officer, has concluded that our disclosure controls and procedures (as defined in Rule 13a-15(e) under the Exchange Act) were effective to provide reasonable assurance.\n\nManagement’s Annual Report on Internal Control Over Financial Reporting'),
  Document(metadata={'source_doc': 'cleaned_doc.txt', 'section': 'ITEM 9A. CONTROLS AND PROCEDURES', 'chunk_id': '48ee360c-80c7-4800-85cf-68f26def9eca', '_id': '78cb33434b594a05bec9c083aff6e658', '_collection_name': 'smart_rag_collection'}, page_content='Our management, including our Chief Executive Offic

In [None]:
from langchain_core.documents.base import Document


@tool
async def ahybrid_search(
    query: str, k: int = 5, filter: str | None = None
) -> list[Document]:
    """
    Asynchrounously combine vector and keyword search results using Reciprocal Rank Fusion (RRF).

    Parameters
    ----------
    query : str
        The search query string.
    k : int, optional
        Maximum number of documents to return, by default 5.
    filter : str or None, optional
        Optional filter expression passed to the vector search, by default None.

    Returns
    -------
    list[Document]
        Top-k documents ranked by fused scores.

    Notes
    -----
    RRF is a simple, unsupervised method for merging ranked lists.
    The constant ``K`` (set to 61) controls the steepness of the rank
    discount curve and is taken from the original RRF paper.
    """
    K: int = 61  # Default for RRF

    tasks: list[Coroutine[Any, Any, list[Document]]] = [
        avector_search.coroutine(query=query, filter=filter, k=k),  # type: ignore
        akeyword_search.coroutine(query=query, k=k),  # type: ignore
    ]
    semantic_docs, kw_docs = await asyncio.gather(*tasks)

    # Results of vector and kw search
    res_ids: list[list[str]] = [
        [doc.metadata["chunk_id"] for doc in semantic_docs],
        [doc.metadata["chunk_id"] for doc in kw_docs],
    ]
    # Calculate Reciprocal Rank Fusion (RRF)
    rrf_dict: dict[str, float] = {}

    for doc_list in res_ids:
        # Grab each doc_id
        for idx, doc_id in enumerate(doc_list):
            if doc_id not in rrf_dict:
                rrf_dict[doc_id] = 0
            # Add (1 / (idx + k)) to each retrieved doc
            rrf_dict[doc_id] += 1 / (idx + K)
    # Sort result using RRF score in descending order
    ranked_ids: list[str] = sorted(
        rrf_dict.keys(), key=lambda x: rrf_dict[x], reverse=True
    )[:k]

    return [doc_dict[_id] for _id in ranked_ids]

In [None]:
result = await ahybrid_search.coroutine(
    query=query, filter="ITEM 9A. CONTROLS AND PROCEDURES", k=5
)
console.print(result)

<br>

### Re-Ranker

- Re-rank retrieved chunks based on relevance to the query.


In [43]:
from sentence_transformers import CrossEncoder

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
# reranker

In [None]:
def rerank_documents(query: str, documents: list[Document]) -> list[Document]:
    """Rerank documents by relevance to query using CrossEncoder.

    Parameters
    ----------
    query : str
        The search query string.
    documents : list[Document]
        List of Document objects to rerank.

    Returns
    -------
    list[Document]
        Documents sorted by relevance score in descending order.
    """
    # Prepare pairs of (query, document content) for scoring
    pairs: list[tuple[str, str]] = [(query, doc.page_content) for doc in documents]
    # Get relevance scores from the CrossEncoder
    scores: list[float] | np.ndarray = reranker.predict(pairs)

    # Combine documents with their scores
    doc_score_pairs: list[tuple[Document, float]] = list(zip(documents, scores))
    # Sort documents by score in descending order
    ranked_docs: list[Document] = [
        doc for doc, _ in sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
    ]
    return ranked_docs

In [45]:
rerank_doc = rerank_documents(query, documents=result)
print(f"Query: {query}\n")
print("Reranked Documents:\n")
console.print(rerank_doc)

Query: According to NVIDIA's management, what assurance level do disclosure controls and internal controls actually provide?

Reranked Documents:



### Web Search

- Using Tavily Search

In [None]:
from langchain_community.tools import DuckDuckGoSearchResults


def truncate_content(content: str | None, max_chars: int | None = None) -> str | None:
    """Truncate content to max_chars with ellipsis indicator."""
    if not content:
        return None

    if max_chars:
        return (
            f"{content[:max_chars]} [truncated]..."
            if len(content) > max_chars
            else content
        )
    return content


def extract_main_content_from_html(content: str) -> str:
    """Extract main content from HTML by removing noise and finding article body.

    Parameters
    ----------
    content : str
        Raw HTML content string.

    Returns
    -------
    str
        BeautifulSoup element containing the main content area.
        Falls back to body element if no main content found.

    Notes
    -----
    Removes scripts, styles, navigation, headers, footers, and ads.
    Searches for common content containers: main, article, or content divs.
    """
    soup = BeautifulSoup(content, "html.parser")

    # Remove unwanted elements
    for tag in soup(
        [
            "script",
            "style",
            "nav",
            "header",
            "footer",
            "aside",
            "iframe",
            "noscript",
        ]
    ):
        tag.decompose()

    # Try to find main content area (common patterns)
    main_content = None
    for selector in [
        soup.find("main"),
        soup.find("article"),
        soup.find(
            "div",
            class_=lambda x: x
            and any(
                c in str(x).lower()  # type: ignore
                for c in ["content", "article", "post", "story"]
            ),
        ),
        soup.find(
            "div",
            id=lambda x: x
            and any(
                c in str(x).lower()  # type: ignore
                for c in ["content", "article", "post", "main"]
            ),
        ),
    ]:
        if selector and selector.get_text(strip=True):
            main_content = selector
            break

    # Fall back to body if no main content found
    if not main_content:
        main_content = soup.find("body") or soup

    return str(main_content)


async def afetch_raw_content(url: str) -> str | None:
    """Fetch HTML content from a URL and convert to markdown.

    Parameters
    ----------
    url : str
        The URL to fetch content from.

    Returns
    -------
    str | None
        Markdown-converted content if successful, None otherwise.

    Notes
    -----
    Uses browser-like headers to avoid bot detection and a 15-second timeout.
    Extracts main content from common article/content tags.
    """
    # Browser-like headers to avoid bot detection
    headers: dict[str, str] = {
        "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8",
        "Accept-Language": "en-US,en;q=0.9",
        "Accept-Encoding": "gzip, deflate, br",
        "DNT": "1",
        "Connection": "keep-alive",
        "Upgrade-Insecure-Requests": "1",
        "Sec-Fetch-Dest": "document",
        "Sec-Fetch-Mode": "navigate",
        "Sec-Fetch-Site": "none",
        "Cache-Control": "max-age=0",
    }

    try:
        async with HTTPXClient(timeout=15) as client:
            response = await client.get(url, headers=headers)

            # Check if request was successful
            if not response.get("success"):
                return None

            # Response might be dict or str
            if isinstance(response["data"], dict):
                content = json.dumps(response["data"])
            else:
                content = response["data"]
            html_content = content

            # Parse HTML and extract main content
            main_content: str = extract_main_content_from_html(content=html_content)

            # Convert to markdown
            markdown_content = md(
                str(main_content),
                heading_style="ATX",
                bullets="-",
                strip=["script", "style"],
            )

            # Clean up excessive whitespace
            lines: list[str] = [
                line.strip() for line in markdown_content.split("\n") if line.strip()
            ]
            cleaned = "\n\n".join(lines)

            return cleaned if cleaned and len(cleaned) > 100 else None

    except Exception as e:
        print(f"Warning: Failed to fetch full page content for {url}: {str(e)}")
        return None


async def aduckduckgo_search(
    query: str, fetch_full_page: bool = False, max_chars: int | None = None
) -> dict[str, list[dict[str, Any]]]:
    """Search DuckDuckGo and optionally fetch full page content.

    Parameters
    ----------
    query : str
        The search query string.
    fetch_full_page : bool, default=False
        If True, fetch and parse full HTML content for each result.
    max_chars : int or None, default=None
        Maximum characters to return per result. If None, no truncation.

    Returns
    -------
    dict[str, list[dict[str, Any]]]
        Dictionary with "results" key containing list of search results.
        Each result has title, url, content, and raw_content fields.

    Notes
    -----
    When fetch_full_page=True, uses browser-like headers and smart content
    extraction to avoid bot detection and JS-blocking issues.
    """

    try:
        search = DuckDuckGoSearchResults(output_format="list", num_results=6)
        raw_results = await search.ainvoke(query)

        # format the data
        raw_results: list[dict[str, Any]] = [
            {
                "title": row["title"],
                "url": row["link"],
                "content": row["snippet"],
                "raw_content": row["snippet"],
            }
            for row in raw_results
        ]

        if fetch_full_page:
            # Fetch full pages concurrently for better performance
            tasks: list[Coroutine] = [
                afetch_raw_content(row["url"]) for row in raw_results
            ]
            full_contents = await asyncio.gather(*tasks)

            raw_results = [
                {
                    **row,
                    "raw_content": truncate_content(
                        content=full_content
                        or row["content"],  # Fall back to content if fetch fails
                        max_chars=max_chars,
                    ),
                }
                for row, full_content in zip(raw_results, full_contents)
            ]
        return {"results": raw_results}

    except Exception as e:
        print(f"Duckduckgo search failed: {str(e)}")
        return {"results": []}

In [84]:
search_result = await aduckduckgo_search(
    query="Has Nvidia broken any laws?",
    fetch_full_page=True,
    max_chars=5_000,
)

console.print(search_result)

In [None]:
@tool
async def aduckduckgo_web_search(
    query: str, fetch_full_page: bool = False, max_chars: int | None = None
) -> list[Document]:
    """Asynchronously search DuckDuckGo and optionally fetch full page content.

    Parameters
    ----------
    query : str
        The search query string.
    fetch_full_page : bool, default=False
        If True, fetch and parse full HTML content for each result.
    max_chars : int or None, default=None
        Maximum characters to return per result. If None, no truncation.

    Returns
    -------
    list[Document]
    """
    max_chars = 8_000 if not max_chars else max_chars

    search_response: dict[str, list[dict[str, Any]]] = await aduckduckgo_search(
        query=query, fetch_full_page=fetch_full_page, max_chars=max_chars
    )

    formatted_results: list[Document] = [
        Document(
            page_content=f"Title: {result['title']}\nContent: {result['raw_content']}\nURL: {result['url']}",
            metadata={
                "url": result["url"],
                "title": result["title"],
                "chunk_id": str(uuid4()),
            },
        )
        for result in search_response["results"]
    ]
    return formatted_results

In [89]:
search_result = await aduckduckgo_web_search.coroutine(
    query="Has Nvidia broken any laws?",
    fetch_full_page=True,
    max_chars=5_000,
)

console.print(search_result)

In [None]:
from langchain_tavily import TavilySearch

tavily_search = TavilySearch(
    api_key=settings.TAVILY_API_KEY.get_secret_value(),
    max_results=2,
    topic="general",
)
search_response = await tavily_search.ainvoke({"query": "Has Nvidia broken any laws?"})

console.print(search_response)

In [None]:
async def tavily_search_tool(
    query: str, fetch_full_page: bool = False, max_chars: int | None = None
) -> dict[str, list[dict[str, Any]]]:
    """Search the web using TavilySearch and return formatted results.

    Parameters
    ----------
    query : str
        The search query string.
    fetch_full_page : bool, default=False
        If True, include full raw_content from Tavily results.
    max_chars : int or None, default=None
        Maximum characters to return per result. If None, no truncation.

    Returns
    -------
    dict[str, list[dict[str, Any]]]
        Dictionary with "results" key containing list of search results.
        Each result has title, url, content, and raw_content fields.

    Notes
    -----
    Tavily automatically provides raw_content when available.
    No additional fetching needed - Tavily handles this internally.
    """

    tavily_search = TavilySearch(
        api_key=settings.TAVILY_API_KEY.get_secret_value(),
        max_results=3,
        topic="general",
        # include_raw_content tells Tavily to fetch full page content
        include_raw_content=fetch_full_page,
    )
    search_response = await tavily_search.ainvoke({"query": query})
    # return search_response
    raw_results: list[dict[str, Any]] = [
        {
            "title": row["title"],
            "url": row["url"],
            "content": row["content"],
            "raw_content": truncate_content(
                content=row.get("raw_content") or row["content"],
                max_chars=max_chars,
            ),
        }
        for row in search_response["results"]
    ]

    return {"results": raw_results}


@tool
async def atavily_web_search(
    query: str, fetch_full_page: bool = False, max_chars: int | None = None
) -> list[Document]:
    """Asynchronously search using Tavily and optionally fetch full page content.

    Parameters
    ----------
    query : str
        The search query string.
    fetch_full_page : bool, default=False
        If True, fetch and parse full HTML content for each result.
    max_chars : int or None, default=None
        Maximum characters to return per result. If None, no truncation.

    Returns
    -------
    list[Document]
    """
    max_chars = 8_000 if not max_chars else max_chars
    search_response: dict[str, list[dict[str, Any]]] = await tavily_search_tool(
        query=query, fetch_full_page=fetch_full_page, max_chars=max_chars
    )

    formatted_results: list[Document] = [
        Document(
            page_content=f"Title: {result['title']}\nContent: {result['raw_content']}\nURL: {result['url']}",
            metadata={
                "url": result["url"],
                "title": result["title"],
                "chunk_id": str(uuid4()),
            },
        )
        for result in search_response["results"]
    ]
    return formatted_results

In [93]:
search_response = await tavily_search_tool(
    query="Has Nvidia broken any laws?",
    fetch_full_page=True,
    max_chars=5_000,
)

console.print(search_response)

In [95]:
search_response = await atavily_web_search.coroutine(
    query="Has Nvidia broken any laws?",
    fetch_full_page=True,
    max_chars=None,
)

console.print(search_response)

In [68]:
raw_results: list[dict[str, Any]] = [
    {
        "title": row["title"],
        "url": row["url"],
        "content": row["content"],
        "raw_content": row["raw_content"],
    }
    for row in search_response["results"]
]
raw_results

[{'title': 'Nvidia broke antitrust law, China says on second day ...',
  'url': 'https://www.msn.com/en-us/money/companies/nvidia-broke-antitrust-law-china-says-on-second-day-of-u-s-trade-talks/ar-AA1MzY66?apiversion=v2&domshim=1&noservercache=1&noservertelemetry=1&batchservertelemetry=1&renderwebcomponents=1&wcseo=1',
  'content': "Beijing said Monday that Nvidia, the American artificial intelligence chip giant, broke China's antitrust law, ratcheting up tensions with the United States",
  'raw_content': None},
 {'title': 'Nvidia Broke Antitrust Law, China Says, as Tensions With ...',
  'url': 'https://www.nytimes.com/2025/09/15/technology/nvidia-china-antitrust.html',
  'content': "China's antimonopoly regulator said on Monday that Nvidia, America's leading chip maker, had violated the country's antitrust law, the latest",
  'raw_content': None},
 {'title': 'China says Nvidia violated antitrust laws',
  'url': 'https://finance.yahoo.com/news/china-says-nvidia-violated-antitrust-09372

### Create States

<br>

#### 1.) Step

- The smallest unit. Multiple Steps make up a Plan.
- A Step has:
  - question: The question being asked.
  - rationale: The reasoning behind the question.
  - tool: The tool to be used to answer the question. (e.g. web_search or vectorstore_lookup)
  - search_keywords: Keywords to use for searching.
  - target_section: The section of the document to focus on. (Only for vectorstore_lookup tool)


In [None]:
from typing import TypedDict

# class Step(BaseModel):
#     """A single step in the multi-step reasoning process."""

#     question: str = Field(..., description="The question to be answered by the step.")
#     rewritten_queries: list[str] = Field(
#         ...,
#         description="3-5 alternative phrasings of the question optimized for retrieval. ",
#     )
#     rationale: str = Field(..., description="The brief reasoning behind the question.")
#     tool: Literal["web_search", "vector_store", "keyword_search", "hybrid_search"] = (
#         Field(
#             ...,
#             description="The tool to use for this step. Hybrid search combines vector and keyword search.",
#         )
#     )
#     search_keywords: list[str] = Field(
#         ...,
#         description="Critical keywords and phrases to use for web search or vector store "
#         "retrieval to ensure quality results are returned.",
#     )
#     target_section: str | None = Field(
#         default=None,
#         description="The target section in the document to focus on. This is ONLY required when "
#         "the tool is 'vector_store'. e.g., 'ITEM 1A. RISK FACTORS'.",
#     )
#     depends_on: list[int] = Field(
#         default_factory=list,
#         description="List of step indices (0-based) that this step depends on. "
#         "Leave empty if this step can run immediately.",
#     )


class Step(BaseModel):
    """A single step in the multi-step reasoning process."""

    question: str = Field(..., description="The question to be answered by the step.")
    rationale: str = Field(..., description="The brief reasoning behind the question.")
    tool: Literal["web_search", "vector_store", "keyword_search", "hybrid_search"] = (
        Field(
            ...,
            description="The tool to use for this step. Hybrid search combines vector and keyword search.",
        )
    )
    search_keywords: list[str] = Field(
        ...,
        description="Critical keywords and phrases to use for web search or vector store "
        "retrieval to ensure quality results are returned.",
    )
    target_section: str | None = Field(
        default=None,
        description="The target section in the document to focus on. This is ONLY required when "
        "the tool is 'vector_store'. e.g., 'ITEM 1A. RISK FACTORS'.",
    )


class ReWrittenQuery(BaseModel):
    question: str = Field(..., description="Original query to be re-written.")
    rewritten_query: list[str] = Field(..., description="The re-written query.")
    rationale: str = Field(..., description="The brief reasoning behind the decision.")


class RetrieverMethod(BaseModel):
    method: Literal["vector_search", "web_search", "hybrid_search"]
    rationale: str = Field(..., description="The brief reasoning behind the decision.")

#### 2.) Plan

- A Plan is a sequence of Steps to achieve a goal.


#### 2.b) PastStep

- This is used to store the history of executed steps in the plan.


#### 3.) State

- A State represents the current status of the RAG process. It includes:
  - original_question: The initial question posed by the user.
  - plan: The current plan being executed.
  - past_steps: A list of PastStep objects representing the history of executed steps.
  - current_step_index: The index of the current step in the plan.
  - retrieved_docs: A list of Document objects that have been retrieved so far.
  - reranked_docs: A list of Document objects that have been reranked based on relevance.
  - synthesized_content: The content synthesized from the reranked documents.
  - final_answer: The final answer generated for the original question.

In [None]:
class Plan(BaseModel):
    """A multi-step plan for answering a complex question."""

    steps: list[Step] = Field(
        ..., description="A list of steps to execute in the plan."
    )


class StepState(TypedDict):
    """State of a completed step in the multi-step reasoning process."""

    step_index: int  # Index of the step in the plan
    question: str  # The question asked in this step
    re_written_queries: list[str]  # Re-written queries for this step
    retrieved_documents: list[Document]  # Documents retrieved for this step
    summary: str  # Summary of the step's findings


class State(TypedDict):
    """State of the multi-step reasoning process."""

    original_question: str  # The original complex question
    plan: Plan | None  # The multi-step plan
    step_state: list[StepState]  # List of completed steps
    current_step_index: int  # Index of the current step being executed
    retrieved_documents: list[Document]  # Documents retrieved in the current step
    reranked_documents: list[Document]  # Documents reranked based on relevance
    completed_steps_indices: set[int]  # Indices of completed steps
    steps_in_progress_indices: set[int]  # Indices of steps currently in progress
    validation_result: dict | None  # Validation output
    synthesized_context: str  # Synthesized context from reranked documents
    final_answer: str  # The final answer to the original question

In [108]:
tool_descriptions: dict[str, str] = {
    "web_search": "Use this tool to search the web for up-to-date information.",
    "vector_store": "Use this tool to search the vector store for relevant document sections.",
    # "keyword_search": "Use this tool to perform keyword-based search over the document chunks.",
    # "hybrid_search": "Use this tool to perform a combined vector and keyword search over the document chunks.",
}

json.dumps(tool_descriptions)

'{"web_search": "Use this tool to search the web for up-to-date information.", "vector_store": "Use this tool to search the vector store for relevant document sections."}'

In [None]:
"""This module contains prompt templates used for various interactions within the application."""

planner_prompt: str = """
<SYSTEM>
    <ROLE>Expert query planner for multi-step reasoning and retrieval optimization.</ROLE>

    <GUIDELINES>
        - Create 2-5 atomic, logical steps
        - Extract 3-5 critical keywords per step for query rewriting
    </GUIDELINES>

    <SECTIONS>{section_titles}</SECTIONS>
    <TOOLS>{tool_descriptions}</TOOLS>

    <EXAMPLE>
        Q: "What were Apple's R&D expenses and how do they compare to competitors?"

        Step 1:
        - question: "What were Apple's research and development expenses?"
        - rationale: "Gather Apple's R&D data from financial docs"
        - tool: "vector_store"
        - search_keywords: ["R&D", "research", "development", "innovation", "investment"]
        - target_section: "ITEM 7. MANAGEMENT'S DISCUSSION"

        Step 2:
        - question: "What are competitors' R&D expenses?"
        - rationale: "Get competitor data for comparison"
        - tool: "web_search"
        - search_keywords: ["competitor", "R&D spending", "Microsoft", "Google", "budget"]
        - target_section: null
    </EXAMPLE>

    <OUTPUT_FORMAT>
        Return Plan with Step objects containing:
        question, rationale, tool, search_keywords (3-5), target_section
    </OUTPUT_FORMAT>
</SYSTEM>
"""


retriever_type_prompt: str = """
<ROLE>Expert at selecting optimal retrieval methods based on query characteristics.</ROLE>

<QUERY>{question}</QUERY>

<METHODS>
    Choose:
    - vector_search if: Query is conceptual, uses natural language, seeks related information
    - keyword_search if: Query has specific terms, proper nouns, technical codes, exact phrases required
    - hybrid_search if: Query needs both semantic context and precise term matching
</METHODS>
"""

query_re_writer_prompt: str = """
<ROLE>Query optimizer for document retrieval and web search.</ROLE>

<GUIDELINES>
    - Extract core intent, remove ambiguity
    - Use specific, domain-relevant terms
    - Output 5-10 keywords/phrases
</GUIDELINES>

<QUERY>{question}</QUERY>
<KEYWORDS>{search_keywords}</KEYWORDS>

<OUTPUT>Return 2-5 query variations capturing original intent.</OUTPUT>
"""

In [100]:
hello_prompt: str = """{section_titles}"""
hello_prompt.format(section_titles=" | ".join(section_titles))

'ITEM 1. BUSINESS | ITEM 1A. RISK FACTORS | ITEM 1B. UNRESOLVED STAFF COMMENTS | ITEM 2. PROPERTIES | ITEM 3. LEGAL PROCEEDINGS | ITEM 4. MINE SAFETY DISCLOSURES | ITEM 5. MARKET FOR REGISTRANT’S COMMON EQUITY, RELATED STOCKHOLDER MATTERS AND ISSUER PURCHASES OF EQUITY SECURITIES | ITEM 6. [RESERVED] | ITEM 7. MANAGEMENT’S DISCUSSION AND ANALYSIS OF FINANCIAL CONDITION AND RESULTS OF OPERATIONS | ITEM 7A. QUANTITATIVE AND QUALITATIVE DISCLOSURES ABOUT MARKET RISK | ITEM 8. FINANCIAL STATEMENTS AND SUPPLEMENTARY DATA | ITEM 9. CHANGES IN AND DISAGREEMENTS WITH ACCOUNTANTS ON ACCOUNTING AND FINANCIAL DISCLOSURE | ITEM 9A. CONTROLS AND PROCEDURES | ITEM 9C. DISCLOSURE REGARDING FOREIGN JURISDICTIONS THAT PREVENT INSPECTIONS | ITEM 10. DIRECTORS, EXECUTIVE OFFICERS AND CORPORATE GOVERNANCE | ITEM 11. EXECUTIVE COMPENSATION | ITEM 12. SECURITY OWNERSHIP OF CERTAIN BENEFICIAL OWNERS AND MANAGEMENT AND RELATED STOCKHOLDER MATTERS | ITEM 13. CERTAIN RELATIONSHIPS AND RELATED TRANSACTIONS, AND 

In [152]:
# utils.py

from enum import Enum

import instructor
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from openai import AsyncOpenAI

from src.config import app_settings

# from src.schemas.types import OpenRouterModels, PydanticModel


class OpenRouterModels(str, Enum):
    """OpenRouter LLMs."""

    GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-001"
    GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
    GEMINI_2_5_FLASH_LITE = "google/gemini-2.5-flash-lite"
    GPT_OSS_120B = "openai/gpt-oss-120b"
    GPT_OSS_20B = "openai/gpt-oss-20b"
    GPT_5_NANO = "openai/gpt-5-nano"
    LLAMA_3_3_70B_INSTRUCT = "meta-llama/llama-3.3-70b-instruct"
    LLAMA_3_8B_INSTRUCT = "meta-llama/llama-3-8b-instruct"
    NEMOTRON_NANO_9B_V2 = "nvidia/nemotron-nano-9b-v2"
    QWEN3_30B_A3B = "qwen/qwen3-30b-a3b"
    QWEN3_NEXT_80B_A3B_INSTRUCT = "qwen/qwen3-next-80b-a3b-instruct"
    QWEN3_32B = "qwen/qwen3-32b"
    SAO10K_L3_LUNARIS_8B = "sao10k/l3-lunaris-8b"
    X_AI_GROK_4_FAST = "x-ai/grok-4-fast"
    X_AI_GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
    Z_AI_GLM_4_5 = "z-ai/glm-4.5"


_async_client = AsyncOpenAI(
    api_key=app_settings.OPENROUTER_API_KEY.get_secret_value(),
    base_url=app_settings.OPENROUTER_URL,
)

aclient = instructor.from_openai(
    _async_client,
    mode=instructor.Mode.OPENROUTER_STRUCTURED_OUTPUTS,
)
type PydanticModel = type[BaseModel]


async def get_structured_output(
    messages: list[dict[str, Any]],
    schema: PydanticModel,
    model: OpenRouterModels | None = None,
) -> PydanticModel:
    """
    Retrieves structured output from a chat completion model.

    Parameters
    ----------
    messages : list[dict[str, Any]]
        The list of messages to send to the model for the chat completion.
    model : RemoteModel
        The remote model to use for the chat completion (e.g., 'gpt-4o').
    schema : PydanticModel
        The Pydantic schema to enforce for the structured output.

    Returns
    -------
    BaseModel
        An instance of the provided Pydantic schema containing the structured output.

    Notes
    -----
    This is an asynchronous function that awaits the completion of the API call.
    """
    model = model if model else RemoteModel.GEMINI_2_5_FLASH_LITE

    return await aclient.chat.completions.create(
        model=model,
        response_model=schema,
        messages=messages,  # type: ignore
        max_retries=5,
    )


def convert_langchain_messages_to_dicts(
    messages: list[HumanMessage | SystemMessage | AIMessage],
) -> list[dict[str, str]]:
    """Convert LangChain messages to a list of dictionaries.

    Parameters
    ----------
    messages : list[HumanMessage | SystemMessage | AIMessage]
        List of LangChain message objects to convert.

    Returns
    -------
    list[dict[str, str]]
        List of dictionaries with 'role' and 'content' keys.
        Roles are mapped as follows:
        - HumanMessage -> "user"
        - SystemMessage -> "system"
        - AIMessage -> "assistant"

    """
    role_mapping: dict[str, str] = {
        "SystemMessage": "system",
        "HumanMessage": "user",
        "AIMessage": "assistant",
    }

    converted_messages: list[dict[str, str]] = []
    for msg in messages:
        message_type: str = msg.__class__.__name__
        # Default to "user" if unknown
        role: str = role_mapping.get(message_type, "user")
        converted_messages.append({"role": role, "content": msg.content})  # type: ignore

    return converted_messages


def append_memory(existing: dict[str, Any], new: dict[str, Any]) -> dict[str, Any]:
    """Merge new memory data into existing memory, appending lists and merging dicts.

    Parameters
    ----------
    existing: dict[str, Any]
        The existing memory data.
    new: dict[str, Any]
        The new memory data to merge.

    Returns
    -------
    dict[str, Any]
        The merged memory data.
    """
    result: dict[str, Any] = existing.copy()

    for key, new_value in new.items():
        # Skip None or empty values
        if new_value is None or new_value == "" or new_value == []:
            continue

        existing_value = result.get(key)

        # If key doesn't exist, just add it
        if existing_value is None:
            result[key] = new_value
            continue

        # Lists: combine and remove duplicates
        if isinstance(new_value, list):
            combined = existing_value + new_value
            # Preserve order, remove duplicates
            # ["a", "b", "a"] -> ["a", "b"]
            result[key] = list(dict.fromkeys(combined))

        # Dicts: merge
        elif isinstance(new_value, dict):
            result[key] = {**existing_value, **new_value}

        # Everything else: new value overwrites
        else:
            result[key] = new_value

    return result

In [None]:
from typing import cast


async def generate_plan_node(state: State) -> State:
    """Generate a multi-step plan based on the user's question.

    Parameters
    ----------
    state : State
        Current state containing the original_question.

    Returns
    -------
    State
        Updated state with generated plan and initialized fields.
    """
    company: str = "NVIDIA"
    user_question: str = state["original_question"]
    user_query: str = f"<USER_QUESTION>{user_question}</USER_QUESTION>"

    query = planner_prompt.format(
        company=company,
        user_question=user_question,
        tool_descriptions=json.dumps(tool_descriptions),
        section_titles=" | ".join(section_titles),
    )

    messages = convert_langchain_messages_to_dicts(
        messages=[SystemMessage(content=query), HumanMessage(content=user_query)]
    )
    response = await get_structured_output(messages=messages, model=None, schema=Plan)
    response = cast(Plan, response)

    return State(
        original_question=state["original_question"] or "",
        plan=response,
        step_state=state["step_state"] or [],
        current_step_index=state["current_step_index"] or 0,
        retrieved_documents=state["retrieved_documents"] or [],
        reranked_documents=state["reranked_documents"] or [],
        completed_steps_indices=state["completed_steps_indices"] or set(),
        steps_in_progress_indices=state["steps_in_progress_indices"] or set(),
        validation_result=state["validation_result"] or None,
        synthesized_context=state["synthesized_context"] or "",
        final_answer=state["final_answer"] or "",
    )


async def query_rewriter(question: str, search_keywords: list[str]) -> ReWrittenQuery:
    prompt = query_re_writer_prompt.format(
        question=question, search_keywords=", ".join(search_keywords)
    )
    messages = convert_langchain_messages_to_dicts(messages=[HumanMessage(prompt)])
    response = await get_structured_output(
        messages=messages, model=None, schema=ReWrittenQuery
    )
    return cast(ReWrittenQuery, response)


async def determine_retrieval_type(question: str) -> RetrieverMethod:
    prompt = retriever_type_prompt.format(question=question)
    messages = convert_langchain_messages_to_dicts(messages=[HumanMessage(prompt)])
    response = await get_structured_output(
        messages=messages, model=None, schema=RetrieverMethod
    )
    return cast(RetrieverMethod, response)


async def _retrieve_and_rerank_node(state):
    """Retrieve documents and rerank them based on relevance.

    Parameters
    ----------
    state : State
        Current state with query information.

    Returns
    -------
    State
        Updated state with retrieved and reranked documents.
    """
    ...


async def retrieve(**kwargs):
    """Execute retrieval based on step configuration."""
    ...


async def deduplicate(**kwargs):
    """Remove duplicate documents from retrieved results."""
    ...


async def rerank(**kwargs):
    """Rerank documents by relevance to query."""
    ...


async def summarize_step_findings(**kwargs):
    """Summarize findings from a completed step."""
    ...


async def synthesize_node(state):
    """Synthesize final context from reranked documents."""
    ...


async def validate_node(state):
    """Validate the quality and completeness of the answer."""
    ...


async def should_continue_retrieval(state):
    """Determine if more retrieval steps are needed."""
    ...


async def should_retry(state):
    """Determine if the current step should be retried."""
    ...

In [157]:
user_query: str = (
    "Based on NVIDIA's 2023 10-K filing, identify their key risks related to competition. "
    "Then, find recent news (post-filing, from 2024) about AMD's AI chip strategy and explain "
    "how this new strategy directly addresses or exacerbates one of NVIDIA's stated risks."
)
state: State = {
    "original_question": user_query,
    "plan": None,
    "step_state": [],
    "current_step_index": 0,
    "retrieved_documents": [],
    "reranked_documents": [],
    "completed_steps_indices": set(),
    "steps_in_progress_indices": set(),
    "validation_result": None,
    "synthesized_context": "",
    "final_answer": "",
}
response = await generate_plan_node(state)

In [158]:
console.print(response["plan"])

In [177]:
first_step = response["plan"].steps[0]
qs = first_step.question
search_kws = first_step.search_keywords

qs, search_kws

("What are NVIDIA's key risks related to competition, according to their 2023 10-K filing?",
 ['NVIDIA', '2023 10-K', 'competition', 'risks', 'market share'])

In [178]:
res = await query_rewriter(question=qs, search_keywords=search_kws)

console.print(res)

In [185]:
qs: str = " | ".join(res.rewritten_query)


res_1 = await determine_retrieval_type(question=qs)

console.print(res_1)

In [187]:
qs: str = "Tell me about Neidu's brilliance"


res_1 = await determine_retrieval_type(question=qs)

console.print(res_1)