In [2]:
from typing import List, Optional

import gradio as gr
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from chromadb import PersistentClient

from langchain_anthropic import ChatAnthropic
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
from langchain_community.tools import DuckDuckGoSearchRun


# --- Configuration ------------------------------------------------------------
MODEL: str = "claude-haiku-4-5"            # Anthropic chat model used for every LLM call below
DB_NAME: str = "preprocessed_db"           # path of the Chroma persistent store on disk
COLLECTION_NAME: str = "docs"              # collection holding the pre-embedded chunks
EMBEDDING_MODEL: str = "all-MiniLM-L6-v2"  # presumably must match the model used to build the collection
RETRIEVAL_K: int = 8                       # nearest chunks fetched per query
MAX_KB_DISTANCE: float = 1.1               # nearest-chunk distance above which the KB counts as irrelevant
                                           # (units follow the collection's distance metric — TODO confirm)

# Read environment variables (presumably the Anthropic API key) from .env
# before any client is constructed.
load_dotenv(override=True)

# --- Shared clients -----------------------------------------------------------
llm = ChatAnthropic(model=MODEL, temperature=0.0, max_tokens=1800)
embedding_client = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
web_search_tool = DuckDuckGoSearchRun()

In [3]:
class Result(BaseModel):
    """Represents a retrievable chunk plus metadata."""

    # Text of the chunk as returned by the vector store.
    page_content: str
    # Metadata stored alongside the chunk; make_rag_messages reads its "source"
    # key (falling back to "unknown"). Must be a dict: a None value (which Chroma
    # may return for records stored without metadata) fails validation here.
    metadata: dict


class RankOrder(BaseModel):
    """Represents reranked chunk ids in descending relevance order."""

    # NOTE: this model is passed to llm.with_structured_output in rerank(), which
    # presumably forwards the class docstring and Field descriptions to the model
    # as the output schema — treat their wording as part of the prompt.
    # Values are the 1-based "# CHUNK ID" numbers rerank() shows the model;
    # rerank() discards IDs outside 1..len(chunks).
    order: List[int] = Field(
        description="The order of relevance of chunks, from most relevant to least relevant, by chunk id number"
    )


class RelevanceDecision(BaseModel):
    """Represents whether KB context can answer the user question."""

    # NOTE: used with llm.with_structured_output in judge_context_relevance(); the
    # docstring and Field descriptions are presumably sent to the model as the
    # output schema, so edit their wording with care. The same class is also
    # constructed directly for the cheap pre-LLM rejections.
    # True routes the question to the KB answer path; False to the web fallback.
    has_relevant_context: bool = Field(
        description="True when the context has enough information to answer accurately"
    )
    # Free-text justification; not used for routing.
    rationale: str = Field(description="Short explanation for the decision")

In [4]:
def normalize_message_content(content: object) -> str:
    """Flatten a chat-message ``content`` value into stripped plain text.

    Handles the content shapes chat frameworks commonly produce:

    * ``None`` -> ``""`` (never the literal string ``"None"``)
    * ``str`` -> the string, stripped
    * ``dict`` -> its ``"text"`` entry, stripped; ``""`` when missing or ``None``
    * ``list`` -> every ``str`` item and every dict item's ``"text"`` entry,
      stripped; non-empty parts are joined with newlines and any other item
      types are ignored
    * anything else -> ``str(content)``, stripped

    Args:
        content: Raw message content.

    Returns:
        The normalized text, possibly an empty string.
    """

    def block_text(block: dict) -> str:
        # A missing *or* explicit-None "text" entry means "no text".
        value = block.get("text")
        return "" if value is None else str(value).strip()

    if content is None:
        return ""

    if isinstance(content, str):
        return content.strip()

    if isinstance(content, dict):
        return block_text(content)

    if isinstance(content, list):
        parts = []
        for item in content:
            if isinstance(item, str):
                text = item.strip()
            elif isinstance(item, dict):
                text = block_text(item)
            else:
                continue
            if text:
                parts.append(text)
        # Parts are already stripped and non-empty, so the join needs no strip.
        return "\n".join(parts)

    return str(content).strip()


def get_collection_or_raise():
    """Open the pre-built Chroma collection, failing loudly if it is unusable.

    Opens ``PersistentClient(path=DB_NAME)`` and returns its ``COLLECTION_NAME``
    collection. No collection is created here; the store must already have been
    built.

    Returns:
        The Chroma collection object.

    Raises:
        ValueError: if the collection is absent from ``DB_NAME`` or holds no
            records.
    """
    chroma = PersistentClient(path=DB_NAME)
    # chromadb < 0.6 returns Collection objects from list_collections(), while
    # >= 0.6 returns plain name strings; accept both shapes.
    collection_names = {
        getattr(entry, "name", entry) for entry in chroma.list_collections()
    }
    if COLLECTION_NAME not in collection_names:
        raise ValueError(
            f"Collection '{COLLECTION_NAME}' not found in {DB_NAME}. Build the vector DB first."
        )

    # Existence is confirmed, so fetch it rather than risk creating a new one.
    collection = chroma.get_collection(COLLECTION_NAME)
    if collection.count() == 0:
        raise ValueError(
            f"Collection '{COLLECTION_NAME}' is empty in {DB_NAME}. Build embeddings first."
        )
    return collection


def rewrite_query(question: str, history: Optional[List[dict]] = None) -> str:
    """Condense ``question`` (plus chat history) into one KB search query.

    History turns are rendered as ``User: ...`` / ``Assistant: ...`` lines so the
    model can turn follow-ups into a standalone query. Any role other than
    "user" (case-insensitive) is labelled Assistant, and turns whose content
    normalizes to empty text are skipped.

    Args:
        question: The user's current message.
        history: Prior turns as ``{"role", "content"}`` dicts; ``None`` means none.

    Returns:
        The rewritten query with surrounding whitespace and double quotes
        removed, or ``question`` itself if the model returns nothing usable.
    """
    lines = []
    for message in history or []:
        # ``or ""`` also covers an explicit ``"role": None``.
        role_name = str(message.get("role") or "").lower()
        role = "User" if role_name == "user" else "Assistant"
        content = normalize_message_content(message.get("content", ""))
        if content:
            lines.append(f"{role}: {content}")
    # Use the placeholder whenever nothing printable remains, not only when the
    # history list itself is empty.
    history_str = "\n".join(lines) if lines else "(no history)"

    system_prompt = f"""
You are in a conversation with a user about Insurellm.
Rewrite the current user question into one short, specific search query for a knowledge base.

Conversation history:
{history_str}

Current question:
{question}

Respond with the rewritten query only.
"""

    response = llm.invoke(
        [
            SystemMessage(content=system_prompt.strip()),
            HumanMessage(content="Rewrite now."),
        ]
    )
    # Message content may be a list of content blocks rather than a plain str.
    rewritten = normalize_message_content(response.content).strip('"').strip()
    return rewritten or question


def fetch_context_unranked(question: str) -> tuple[List[Result], List[float]]:
    """Retrieve the ``RETRIEVAL_K`` nearest chunks for ``question`` (no reranking).

    The question is embedded with ``embedding_client`` and queried against the
    pre-built Chroma collection.

    Returns:
        ``(chunks, distances)`` in the order Chroma returns them;
        ``distances[i]`` belongs to ``chunks[i]``. ``distances`` is empty if the
        store returned none.

    Raises:
        ValueError: propagated from ``get_collection_or_raise`` when the store
            is missing or empty.
    """
    collection = get_collection_or_raise()
    query_vector = embedding_client.embed_query(question)
    results = collection.query(
        query_embeddings=[query_vector],
        n_results=RETRIEVAL_K,
        include=["documents", "metadatas", "distances"],
    )

    # Chroma returns one inner list per query embedding; exactly one was sent.
    documents = (results.get("documents") or [[]])[0]
    metadatas = (results.get("metadatas") or [[]])[0]
    distances = list((results.get("distances") or [[]])[0])

    # Records stored without metadata come back as None, which Result's
    # ``metadata: dict`` field would reject; substitute an empty dict.
    chunks = [
        Result(page_content=document, metadata=metadata or {})
        for document, metadata in zip(documents, metadatas)
    ]
    return chunks, distances


def rerank(question: str, chunks: List[Result]) -> List[Result]:
    """Reorder ``chunks`` by LLM-judged relevance to ``question``.

    Chunks are shown to the model with 1-based IDs, and the model answers with a
    ``RankOrder``. Its order is sanitized before use:
    * out-of-range IDs and repeated IDs are dropped;
    * any chunk the model omitted is appended in its original order.

    The result is therefore always a permutation of ``chunks``, built from the
    same ``Result`` objects.

    Args:
        question: Query the chunks are ranked against.
        chunks: Retrieved chunks in retrieval order.

    Returns:
        The reranked chunks; ``[]`` for empty input. Retrieval order is kept if
        the model gives no structured answer.
    """
    if not chunks:
        return []

    system_prompt = """
You are a document re-ranker.
Rank chunk IDs by relevance to the question.
Reply ONLY as JSON: {"order": [<id_1>, <id_2>, ...]}.
Include all IDs exactly once.
"""

    prompt_lines = [
        f"Question:\n{question}\n",
        "Chunks:\n",
    ]
    for index, chunk in enumerate(chunks, start=1):
        prompt_lines.append(f"# CHUNK ID: {index}\n{chunk.page_content}\n")

    structured_llm = llm.with_structured_output(RankOrder)
    response_obj = structured_llm.invoke(
        [
            SystemMessage(content=system_prompt.strip()),
            HumanMessage(content="\n".join(prompt_lines).strip()),
        ]
    )
    # Presumably None when the model replies without using the schema; keep
    # the retrieval order rather than crash.
    if response_obj is None:
        return chunks

    ranked_ids: List[int] = []
    seen = set()
    for chunk_id in response_obj.order:
        # The prompt asks for every ID exactly once, but that is not guaranteed.
        if 1 <= chunk_id <= len(chunks) and chunk_id not in seen:
            seen.add(chunk_id)
            ranked_ids.append(chunk_id)
    # Never lose a retrieved chunk just because the model forgot to rank it.
    ranked_ids.extend(
        chunk_id for chunk_id in range(1, len(chunks) + 1) if chunk_id not in seen
    )
    return [chunks[chunk_id - 1] for chunk_id in ranked_ids]


def fetch_context(question: str) -> tuple[List[Result], List[float]]:
    """Retrieve chunks for ``question`` and rerank them with the LLM.

    Returns:
        ``(chunks, distances)``, both in reranked order, with ``distances[i]``
        being the vector-store distance of ``chunks[i]``. Previously the
        distances stayed in retrieval order and did not line up with the chunks.
    """
    chunks, distances = fetch_context_unranked(question)
    reranked_chunks = rerank(question, chunks)
    # rerank() returns the same Result objects it received, so object identity
    # maps each reranked chunk back to its retrieval distance.
    distance_by_chunk = {id(chunk): distance for chunk, distance in zip(chunks, distances)}
    reranked_distances = [
        distance_by_chunk[id(chunk)]
        for chunk in reranked_chunks
        if id(chunk) in distance_by_chunk
    ]
    return reranked_chunks, reranked_distances


def judge_context_relevance(question: str, chunks: List[Result], distances: List[float]) -> RelevanceDecision:
    """Decide whether the retrieved KB context is enough to answer ``question``.

    Gates are applied cheapest first:
    1. No chunks -> not relevant.
    2. The nearest chunk's distance exceeds ``MAX_KB_DISTANCE`` -> not relevant.
       This gate is skipped when ``distances`` is empty.
    3. Otherwise an LLM judge reads a preview (first 4 chunks, first 600
       characters each) and returns a structured ``RelevanceDecision``.

    If the judge yields no structured decision, the context is treated as not
    relevant. The caller then takes the transparent web fallback instead of
    crashing on ``None``.
    """
    if not chunks:
        return RelevanceDecision(has_relevant_context=False, rationale="No retrieved chunks")

    if distances:
        nearest = min(distances)
        if nearest > MAX_KB_DISTANCE:
            return RelevanceDecision(
                has_relevant_context=False,
                rationale=f"Nearest chunk distance {nearest:.3f} exceeds threshold",
            )

    context_preview = "\n\n".join(chunk.page_content[:600] for chunk in chunks[:4])
    structured_llm = llm.with_structured_output(RelevanceDecision)
    decision = structured_llm.invoke(
        [
            SystemMessage(
                content=(
                    "You are a strict relevance judge. Determine if the provided knowledge-base context "
                    "contains enough evidence to answer the user question accurately."
                )
            ),
            HumanMessage(
                content=(
                    f"Question:\n{question}\n\n"
                    f"Context:\n{context_preview}\n\n"
                    "Respond with has_relevant_context=true only when the context directly supports an answer."
                )
            ),
        ]
    )
    if decision is None:
        # Strict default: without a verdict, prefer the clearly-labelled web answer.
        return RelevanceDecision(
            has_relevant_context=False,
            rationale="Relevance judge returned no structured decision",
        )
    return decision

In [5]:
# System prompt for KB-grounded answers. make_rag_messages fills ``{context}``
# with str.format, so any other literal braces added to this template must be
# doubled ({{ and }}).
KB_SYSTEM_PROMPT = """
You are a knowledgeable, friendly assistant representing Insurellm.
Answer only from the provided Insurellm knowledge-base extracts.
If information is missing from context, say you do not know.

Context:
{context}
"""


# System prompt for the web-search fallback. The raw search results and the
# user question are supplied in the human message, not in this prompt.
WEB_SYSTEM_PROMPT = """
You are a helpful assistant.
You are given raw web search findings.
Summarize them accurately and clearly.
If uncertain, say so.
"""


def make_rag_messages(question: str, history: List[dict], chunks: List[Result]) -> List[dict]:
    """Build the dict-format message list for a KB-grounded answer.

    The list contains, in order:
    * a system message: ``KB_SYSTEM_PROMPT`` with each chunk rendered as an
      "Extract from <source>:" section (blank-line separated);
    * the ``history`` turns, unchanged;
    * ``question`` as the final user turn.
    """
    extracts = []
    for chunk in chunks:
        source = chunk.metadata.get('source', 'unknown')
        extracts.append(f"Extract from {source}:\n{chunk.page_content}")

    system_message = {
        "role": "system",
        "content": KB_SYSTEM_PROMPT.format(context="\n\n".join(extracts)),
    }
    user_message = {"role": "user", "content": question}
    return [system_message, *history, user_message]


def build_langchain_messages(message_dicts: List[dict]) -> List[BaseMessage]:
    """Convert ``{"role", "content"}`` dicts into LangChain message objects.

    Roles are matched case-insensitively:
    * "system" -> ``SystemMessage``
    * "user" -> ``HumanMessage``
    * "assistant" -> ``AIMessage``

    Messages with any other role, or with no role, are dropped. User and
    assistant messages whose content normalizes to empty text are also dropped,
    because the Anthropic Messages API rejects empty-content turns.

    Args:
        message_dicts: Messages in OpenAI-style dict form.

    Returns:
        The converted messages, in input order.
    """
    message_types = {
        "system": SystemMessage,
        "user": HumanMessage,
        "assistant": AIMessage,
    }
    messages: List[BaseMessage] = []
    for message in message_dicts:
        # ``or ""`` also guards an explicit ``"role": None``.
        role = str(message.get("role") or "").lower()
        message_cls = message_types.get(role)
        if message_cls is None:
            continue
        content = normalize_message_content(message.get("content", ""))
        if not content and role != "system":
            continue
        messages.append(message_cls(content=content))
    return messages


def answer_from_web_stream(question: str):
    """Answer ``question`` from a DuckDuckGo search, streaming cumulative text.

    Each yielded value is the whole answer so far, not a delta. The first yield
    is a disclaimer that the answer does not come from the Insurellm knowledge
    base. The LLM's summary of the raw search results is then appended as it
    streams in.

    If the search call raises (network error, rate limiting, ...), one
    explanatory message is yielded instead of letting the exception abort the
    chat turn.

    Args:
        question: Used both as the search query and in the summary prompt.
    """
    try:
        web_results = web_search_tool.run(question)
    except Exception as exc:  # external service: report the failure instead of crashing the UI
        yield (
            "There is no relevant information about this in the Insurellm knowledge base, "
            f"and the web search fallback failed ({type(exc).__name__}: {exc})."
        )
        return

    intro = (
        "There is no relevant information about this in the Insurellm knowledge base. "
        "Here is general information from the internet:\n\n"
    )
    partial = intro
    yield partial

    messages = [
        SystemMessage(content=WEB_SYSTEM_PROMPT.strip()),
        HumanMessage(
            content=(
                f"User question:\n{question}\n\n"
                f"Web search results:\n{web_results}\n\n"
                "Provide a concise, factual summary."
            )
        ),
    ]

    for chunk in llm.stream(messages):
        if chunk.content:
            partial += chunk.content
            yield partial


def answer_question_stream(question: str, history: Optional[List[dict]] = None):
    """Answer ``question`` from the KB or the web fallback, streaming text.

    Pipeline:
    1. Rewrite the question into a search query using ``history``.
    2. Retrieve and rerank KB chunks.
    3. Let the judge decide whether those chunks suffice.
    4. If they do, stream an LLM answer grounded on them; otherwise forward the
       web fallback's stream.

    Yields:
        The cumulative answer text after each streamed chunk.
    """
    history = [] if history is None else history

    search_query = rewrite_query(question, history)
    kb_chunks, kb_distances = fetch_context(search_query)
    # NOTE(review): the judge and the web fallback get the raw question, not the
    # history-aware search_query, so context-dependent follow-ups may be judged
    # or searched poorly — consider passing search_query instead.
    verdict = judge_context_relevance(question, kb_chunks, kb_distances)

    if verdict.has_relevant_context:
        llm_messages = build_langchain_messages(make_rag_messages(question, history, kb_chunks))
        answer_so_far = ""
        for piece in llm.stream(llm_messages):
            if piece.content:
                answer_so_far += piece.content
                yield answer_so_far
    else:
        yield from answer_from_web_stream(question)

In [7]:
def gradio_chat_handler(message: str, history: List[dict]):
    """Gradio ChatInterface callback: clean the history and stream an answer.

    ``history`` may arrive in either Gradio chat-history format:
    * "messages": a list of ``{"role": ..., "content": ...}`` dicts;
    * "tuples": a list of ``[user_text, assistant_text]`` pairs, where the
      assistant entry can be ``None``.

    Which one Gradio sends depends on the ChatInterface configuration and the
    installed Gradio version, so both are accepted. Only user and assistant
    turns with non-empty text are forwarded; ``None`` contents are skipped.

    Yields:
        The cumulative answer text produced by ``answer_question_stream``.
    """
    rag_history: List[dict] = []

    def add_turn(role, raw_content) -> None:
        # A None placeholder must not become the literal text "None".
        if raw_content is None:
            return
        content = normalize_message_content(raw_content)
        if role in {"user", "assistant"} and content:
            rag_history.append({"role": role, "content": content})

    for item in history or []:
        if isinstance(item, dict):
            add_turn(item.get("role", ""), item.get("content", ""))
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            add_turn("user", item[0])
            add_turn("assistant", item[1])

    yield from answer_question_stream(message, rag_history)


def check_vector_store_status() -> str:
    """Summarize the vector store's readiness as a one-line status message.

    Never raises. Any failure while opening or counting the collection is
    rendered into the returned text so the UI can display it.
    """
    try:
        store = get_collection_or_raise()
        chunk_count = store.count()
    except Exception as exc:  # best-effort UI status: show the error text
        return f"Vector store error: {exc}"
    return f"Vector store ready: {chunk_count} chunks in '{COLLECTION_NAME}' from '{DB_NAME}'."


def build_gradio_app() -> gr.Blocks:
    """Assemble the Blocks UI: title, vector-store status panel and chat.

    The status textbox is filled when the page loads and again whenever the
    refresh button is clicked. The chat panel streams answers from
    ``gradio_chat_handler``.
    """
    intro_text = (
        "Uses preprocessed ChromaDB vectors for Insurellm answers and transparently falls back to web search when KB context is not relevant."
    )

    with gr.Blocks(title="Insurellm RAG Agent") as app:
        gr.Markdown("# Insurellm RAG Agent")
        gr.Markdown(intro_text)

        status_box = gr.Textbox(label="Vector Store Status", interactive=False)
        refresh_button = gr.Button("Refresh Vector Store Status")

        gr.ChatInterface(
            fn=gradio_chat_handler,
            title="Chat",
            description="Streaming responses enabled.",
        )

        # Show the status on page load, then on demand.
        app.load(fn=check_vector_store_status, outputs=status_box)
        refresh_button.click(fn=check_vector_store_status, outputs=status_box)

    return app


# Build the UI once here; the following cell only starts the server.
demo_app = build_gradio_app()

In [None]:
# Start the Gradio server for the app built in the previous cell.
demo_app.launch()