In [19]:
import inspect
import json
import os
from typing import List, Optional

import gradio as gr
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from chromadb import PersistentClient

from langchain_anthropic import ChatAnthropic
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage, BaseMessage

from transcriber import transcriber


# Load settings from a local .env file; override=True makes .env values take
# precedence over variables already present in the environment.
# NOTE(review): presumably this is where ANTHROPIC_API_KEY comes from — confirm.
load_dotenv(override=True)

MODEL = "claude-haiku-4-5"  # Claude model behind `llm` (answers, reranking, tool calls)
DB_NAME = "preprocessed_db"  # path of the on-disk Chroma store (PersistentClient)
COLLECTION_NAME = "docs"  # must already exist and be non-empty in DB_NAME
EMBEDDING_MODEL = "all-MiniLM-L6-v2"  # embeds queries; TODO confirm it matches the model used to build DB_NAME
RETRIEVAL_K = 8  # nearest chunks requested from Chroma per question
# Max Chroma distance for a chunk to count as relevant; when no retrieved chunk is
# within it, rag_search reports NO_RELEVANT_KB_CONTEXT. The scale depends on the
# collection's distance metric, which is chosen where the store is built — TODO confirm.
MAX_KB_DISTANCE = 1.1


# temperature=0.0 so answers and rerank orderings are as deterministic as the API allows.
llm = ChatAnthropic(model=MODEL, temperature=0.0, max_tokens=1800)
embedding_client = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
web_search_tool = DuckDuckGoSearchRun()  # shared runner behind the web_search tool

In [20]:
class Result(BaseModel):
    """A retrieved knowledge-base chunk: its text plus the metadata stored with it."""

    page_content: str  # chunk text taken from the Chroma query's documents
    metadata: dict  # chunk metadata; the "source" key, when present, labels the extract


# Structured-output schema used by rerank_chunks via llm.with_structured_output.
# NOTE(review): LangChain typically forwards the class docstring and the Field
# description to the model as part of the schema, so treat them as prompt text.
class RankOrder(BaseModel):
    """Represents reranked chunk ids in descending relevance order."""

    # 1-based ids matching the "# CHUNK ID: n" labels built in rerank_chunks.
    order: List[int] = Field(
        description="The order of relevance of chunks, from most relevant to least relevant, by chunk id number"
    )

In [22]:
def get_collection_or_raise():
    """Open the pre-built Chroma collection, failing loudly if it is unusable.

    Connects to the on-disk store at ``DB_NAME`` and returns the collection named
    ``COLLECTION_NAME``. The collection is only looked up (``get_collection``), never
    created, so this check cannot leave an empty collection behind as a side effect.

    Returns:
        The ``COLLECTION_NAME`` collection.

    Raises:
        ValueError: if the collection does not exist, or exists but holds no records.
    """
    chroma = PersistentClient(path=DB_NAME)
    # list_collections() yields Collection objects in some chromadb releases and
    # bare name strings in others; normalise to names so both shapes work.
    collection_names = {
        getattr(listed, "name", listed) for listed in chroma.list_collections()
    }
    if COLLECTION_NAME not in collection_names:
        raise ValueError(
            f"Collection '{COLLECTION_NAME}' was not found in '{DB_NAME}'. Build vectors first."
        )

    collection = chroma.get_collection(COLLECTION_NAME)
    if collection.count() == 0:
        raise ValueError(
            f"Collection '{COLLECTION_NAME}' is empty in '{DB_NAME}'. Build vectors first."
        )
    return collection


def rerank_chunks(question: str, chunks: List[Result]) -> List[Result]:
    """Reorder retrieved chunks by relevance to ``question`` using Claude structured output.

    Each chunk is shown to the model under a 1-based ``# CHUNK ID`` label and the model
    replies with a ``RankOrder``. Its ordering is sanitised before use: ids outside
    ``1..len(chunks)`` and repeated ids are ignored, and any chunk the model left out
    is appended afterwards in its original retrieval order. Every input chunk therefore
    appears exactly once in the output.

    Args:
        question: The user query the chunks are ranked against.
        chunks: Retrieved chunks, in retrieval order.

    Returns:
        The same chunks, reordered. Fewer than two chunks are returned as-is without
        an LLM call; if the model yields no usable ids, the original order is kept.
    """
    if len(chunks) < 2:
        # Nothing to rank; skip the LLM round-trip.
        return list(chunks)

    prompt_lines = [
        f"Question:\n{question}\n",
        "Chunks:\n",
    ]
    for index, chunk in enumerate(chunks, start=1):
        prompt_lines.append(f"# CHUNK ID: {index}\n\n{chunk.page_content}\n")

    structured_llm = llm.with_structured_output(RankOrder)
    rank_order = structured_llm.invoke(
        [
            SystemMessage(
                content=(
                    "You are a document re-ranker. Rank chunk ids by relevance and return only JSON: "
                    "{\"order\": [<id_1>, <id_2>, ...]} including all chunk ids once."
                )
            ),
            HumanMessage(content="\n".join(prompt_lines).strip()),
        ]
    )

    # Tolerate a missing parse result as well as an empty order.
    proposed_order = getattr(rank_order, "order", None) or []
    seen_ids = set()
    ranked: List[Result] = []
    for chunk_id in proposed_order:
        if 1 <= chunk_id <= len(chunks) and chunk_id not in seen_ids:
            seen_ids.add(chunk_id)
            ranked.append(chunks[chunk_id - 1])

    if not ranked:
        return chunks
    # The prompt asks for every id once; keep any chunk the model forgot rather than
    # silently discarding retrieved context.
    ranked.extend(
        chunk for chunk_id, chunk in enumerate(chunks, start=1) if chunk_id not in seen_ids
    )
    return ranked


def format_rag_context(question: str) -> str:
    """Retrieve, filter, rerank and format knowledge-base context for ``question``.

    Pipeline:
      1. Embed ``question`` and ask the ``COLLECTION_NAME`` collection for up to
         ``RETRIEVAL_K`` nearest chunks, with documents, metadatas and distances.
      2. Drop hits that have no document text or whose distance exceeds
         ``MAX_KB_DISTANCE``. This runs *before* reranking, so no LLM call is spent
         when nothing in the knowledge base is close enough.
      3. Rerank the surviving chunks with ``rerank_chunks``.
      4. Render each as an ``Extract from <source>:`` header line followed by the
         chunk text, with extracts separated by blank lines.

    Args:
        question: The user's query text.

    Returns:
        The formatted context string, or the sentinel ``"NO_RELEVANT_KB_CONTEXT"`` when
        no chunk passes the filters (SYSTEM_PROMPT tells the agent to fall back to web
        search on exactly this value).

    Raises:
        ValueError: from ``get_collection_or_raise`` if the store is missing or empty.
    """
    collection = get_collection_or_raise()
    query_vector = embedding_client.embed_query(question)
    results = collection.query(
        query_embeddings=[query_vector],
        n_results=RETRIEVAL_K,
        include=["documents", "metadatas", "distances"],
    )

    # Results hold one inner list per query embedding; exactly one was sent.
    # `or [[]]` also covers a key that is present but set to None.
    documents = (results.get("documents") or [[]])[0]
    metadatas = (results.get("metadatas") or [[]])[0]
    distances = (results.get("distances") or [[]])[0]

    chunks: List[Result] = []
    for position, document in enumerate(documents):
        if not document:
            continue  # record stored without text: nothing to show the model
        distance = distances[position] if position < len(distances) else None
        if distance is not None and distance > MAX_KB_DISTANCE:
            continue  # too far from the question to count as relevant KB context
        metadata = metadatas[position] if position < len(metadatas) else None
        # Chroma may yield None for records saved without metadata; Result needs a dict.
        chunks.append(Result(page_content=document, metadata=metadata or {}))

    if not chunks:
        return "NO_RELEVANT_KB_CONTEXT"

    chunks = rerank_chunks(question, chunks)
    return "\n\n".join(
        f"Extract from {chunk.metadata.get('source', 'unknown')}:\n{chunk.page_content}"
        for chunk in chunks
    )

In [23]:
@tool
def rag_search(query: str) -> str:
    """Search the Insurellm knowledge base and return relevant context text."""
    # The docstring above doubles as the tool description the model sees.
    return format_rag_context(query)


@tool
def web_search(query: str) -> str:
    """Search the public web for general information when KB context is missing."""
    # Thin adapter over the module-level DuckDuckGo runner (web_search_tool).
    search_results = web_search_tool.run(query)
    return search_results


@tool
def transcribe_audio(path_to_file: str) -> str:
    """Transcribe a local audio file into plain text."""
    # NOTE(review): when invoked by the agent, the path is chosen by the model, so any
    # file readable by this process can be requested; consider restricting it if the
    # app is reachable by untrusted users.
    transcript = transcriber(path_to_file)
    return transcript


# Tools exposed to the model. bind_tools attaches their names, argument schemas and
# docstring descriptions to agent_llm calls; handle_tool_call dispatches on the same names.
tools = [rag_search, web_search, transcribe_audio]
agent_llm = llm.bind_tools(tools)


def handle_tool_call(tool_call: dict) -> ToolMessage:
    """Run the tool named in ``tool_call`` and wrap its output in a ``ToolMessage``.

    Args:
        tool_call: A dict with ``name``, ``args`` and ``id`` keys, shaped like an entry
            of ``AIMessage.tool_calls``; missing keys default to ``""``/``{}``/``""``.

    Returns:
        A ``ToolMessage`` carrying the tool output and the originating call id. An
        unrecognised tool name yields an ``"Unknown tool: <name>"`` message rather
        than raising. Exceptions raised by a tool itself propagate to the caller.
    """
    dispatch = {
        "rag_search": rag_search,
        "web_search": web_search,
        "transcribe_audio": transcribe_audio,
    }
    requested_name = tool_call.get("name", "")
    call_args = tool_call.get("args", {})
    call_id = tool_call.get("id", "")

    selected_tool = dispatch.get(requested_name)
    output = (
        f"Unknown tool: {requested_name}"
        if selected_tool is None
        else selected_tool.invoke(call_args)
    )
    return ToolMessage(content=output, tool_call_id=call_id)


def _history_content_to_text(content) -> str:
    """Flatten one Gradio chat ``content`` value into plain text.

    Strings pass through and ``None`` becomes ``""``. A dict contributes its ``"text"``
    entry, so non-text parts such as files contribute nothing. Lists and tuples of
    parts are flattened recursively and joined with newlines. Anything else falls back
    to ``str``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        text = content.get("text")
        return text if isinstance(text, str) else ""
    if isinstance(content, (list, tuple)):
        parts = (_history_content_to_text(part) for part in content)
        return "\n".join(part for part in parts if part)
    return str(content)


def convert_history_to_messages(history: List[dict]) -> List[BaseMessage]:
    """Convert Gradio chat history into LangChain ``HumanMessage``/``AIMessage`` objects.

    Accepts two history formats:
    - The "messages" format: dicts with ``role`` and ``content``.
    - The legacy "tuples" format: ``[user_text, assistant_text]`` pairs.

    ``content`` may be a plain string or a list of content parts; text is extracted
    rather than rendering the raw Python repr. Other roles are ignored. Turns with
    empty or whitespace-only text are skipped, because they carry nothing and the
    Anthropic API rejects empty text content.

    Args:
        history: Prior turns, oldest first. ``None`` is treated as empty.

    Returns:
        LangChain messages in the same order.
    """
    messages: List[BaseMessage] = []
    for item in history or []:
        if isinstance(item, dict):
            turns = [(item.get("role", ""), item.get("content", ""))]
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            # Legacy Gradio "tuples" history: [user_text, assistant_text].
            turns = [("user", item[0]), ("assistant", item[1])]
        else:
            continue
        for role, content in turns:
            text = _history_content_to_text(content)
            if not text.strip():
                continue
            if role == "user":
                messages.append(HumanMessage(content=text))
            elif role == "assistant":
                messages.append(AIMessage(content=text))
    return messages

In [24]:
# System instructions for the agent; run_agent strips surrounding whitespace before use.
# `NO_RELEVANT_KB_CONTEXT` must match the sentinel format_rag_context returns, and the
# rule-3 disclaimer must match the prefix run_agent prepends after any web_search call.
SYSTEM_PROMPT = """
You are an Insurellm assistant that must use tools when needed.

Rules:
1) For company questions, call `rag_search` first.
2) If `rag_search` returns `NO_RELEVANT_KB_CONTEXT`, call `web_search`.
3) When using web fallback, be transparent and begin with:
   "There is no such information in our Insurellm knowledge base. Here is general information from the internet:"
4) If the user provides an audio file path, call `transcribe_audio` first.
5) Be concise, accurate, and user-friendly.
"""


def _ai_content_to_text(content) -> str:
    """Extract the user-visible text from an ``AIMessage.content`` value.

    ``content`` is either a string or a list of content blocks (strings or dicts).
    Only string blocks and dict blocks with ``type == "text"`` contribute, so tool-use
    blocks never leak into the reply as a Python repr.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for block in content:
            if isinstance(block, str):
                pieces.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                pieces.append(str(block.get("text", "")))
        return "".join(pieces)
    return str(content)


def run_agent(
    user_text: str,
    history: Optional[List[dict]] = None,
    max_steps: int = 4,
) -> str:
    """Answer ``user_text`` with a bounded tool-calling loop and return the reply text.

    Each step sends the conversation to ``agent_llm``: the system prompt, the converted
    ``history``, the new user message, and every tool exchange so far. If the model
    requests tools, each call is executed and its result appended. Otherwise the
    model's text is returned.

    Args:
        user_text: The new user message.
        history: Prior Gradio chat turns; ``None`` means no history.
        max_steps: Maximum number of model calls before giving up (default 4).

    Returns:
        The assistant's reply. If ``web_search`` was called at any point, the reply is
        made to start with the knowledge-base-miss disclaimer. If the model is still
        requesting tools after ``max_steps`` calls, a fixed apology is returned instead.
    """
    if history is None:
        history = []

    used_web_fallback = False
    messages: List[BaseMessage] = [SystemMessage(content=SYSTEM_PROMPT.strip())]
    messages.extend(convert_history_to_messages(history))
    messages.append(HumanMessage(content=user_text))

    for _ in range(max_steps):
        ai_message = agent_llm.invoke(messages)
        messages.append(ai_message)

        if not ai_message.tool_calls:
            final_text = _ai_content_to_text(ai_message.content)
            if used_web_fallback:
                prefix = (
                    "There is no such information in our Insurellm knowledge base. "
                    "Here is general information from the internet:\n\n"
                )
                # Enforce the disclaimer even if the model forgot it.
                if not final_text.lstrip().startswith("There is no such information in our Insurellm knowledge base"):
                    final_text = prefix + final_text
            return final_text

        for tool_call in ai_message.tool_calls:
            if tool_call.get("name") == "web_search":
                used_web_fallback = True
            try:
                tool_message = handle_tool_call(tool_call)
            except Exception as exc:
                # Every requested tool call must be answered with a tool result, or the
                # next request is invalid. Report the failure so the model can recover
                # or explain it, instead of aborting the whole turn.
                tool_message = ToolMessage(
                    content=f"Tool '{tool_call.get('name', '')}' failed: {exc}",
                    tool_call_id=tool_call.get("id", ""),
                )
            messages.append(tool_message)

    return "I could not complete the request after several tool steps. Please rephrase your question."

In [25]:
def check_vector_store_status() -> str:
    """Describe whether the ChromaDB vector store is usable, for the UI status box.

    Never raises. Any failure while opening or inspecting the collection is turned
    into a ``"Vector store error: ..."`` string so the UI can display it.
    """
    try:
        chunk_count = get_collection_or_raise().count()
        return f"Vector store ready: {chunk_count} chunks in '{COLLECTION_NAME}' from '{DB_NAME}'."
    except Exception as exc:  # deliberate best-effort: show the problem, don't crash the UI
        return f"Vector store error: {exc}"


def create_audio_tool_call(path_to_file: str) -> dict:
    """Build a tool-call dict that routes ``path_to_file`` to the ``transcribe_audio`` tool.

    The dict has the ``name``/``args``/``id`` shape that ``handle_tool_call`` reads. The
    id is a fixed placeholder because no model issued this call.
    """
    call_args = {"path_to_file": path_to_file}
    return dict(name="transcribe_audio", args=call_args, id="manual_audio_transcribe")


def chat_handler(user_text: str, audio_path: str, history: List[dict]):
    """Gradio callback for the Send button.

    Merges the typed message with a transcription of the recorded audio, if any, and
    asks the agent for a reply. Returns ``(updated_history, "", None)`` so the chatbot
    shows the new turn while the text box and audio input are cleared. When there is
    nothing to send, the incoming history is returned unchanged.
    """
    prior_turns = [] if history is None else history
    message = (user_text or "").strip()

    if audio_path:
        transcription = handle_tool_call(create_audio_tool_call(audio_path))
        spoken_text = str(transcription.content).strip()
        message = f"{message}\n\nTranscribed audio: {spoken_text}" if message else spoken_text

    if not message:
        return prior_turns, "", None

    reply = run_agent(message, prior_turns)
    new_turns = [
        {"role": "user", "content": message},
        {"role": "assistant", "content": reply},
    ]
    return prior_turns + new_turns, "", None


def build_gradio_app() -> gr.Blocks:
    """Assemble (but do not launch) the Gradio UI for the Insurellm RAG agent.

    Components:
    - A read-only vector-store status box, filled on page load and by a refresh button.
    - A chatbot and a message box.
    - An optional audio input, recorded with the microphone or uploaded, and passed to
      the handler as a file path.
    - A Send button wired to ``chat_handler``, which also clears the message and audio
      inputs.

    Returns:
        The configured ``gr.Blocks`` app.
    """
    # chat_handler reads and writes history as {"role", "content"} dicts. Gradio
    # releases whose Chatbot takes a `type` argument must be told to use that
    # "messages" format; pass it only when this installed version accepts it.
    chatbot_kwargs = {"label": "Assistant"}
    if "type" in inspect.signature(gr.Chatbot.__init__).parameters:
        chatbot_kwargs["type"] = "messages"

    with gr.Blocks(title="Insurellm RAG Agent") as demo:
        gr.Markdown("# Insurellm RAG Agent")
        gr.Markdown(
            "Uses preprocessed vectors from ChromaDB, falls back to web search when KB is missing details, and supports microphone/upload transcription."
        )

        status_box = gr.Textbox(label="Vector Store Status", interactive=False)
        refresh_status_button = gr.Button("Refresh Vector Store Status")

        chatbot = gr.Chatbot(**chatbot_kwargs)
        user_text = gr.Textbox(label="Message", placeholder="Ask about Insurellm...")
        audio_input = gr.Audio(
            label="Optional audio question",
            # Enable upload as well, matching the description's "microphone/upload".
            sources=["microphone", "upload"],
            type="filepath",
        )
        send_button = gr.Button("Send")

        send_button.click(
            fn=chat_handler,
            inputs=[user_text, audio_input, chatbot],
            outputs=[chatbot, user_text, audio_input],
        )

        refresh_status_button.click(fn=check_vector_store_status, outputs=status_box)
        demo.load(fn=check_vector_store_status, outputs=status_box)

    return demo


# Build the Blocks app once; it is launched separately in a later cell.
demo_app = build_gradio_app()

In [None]:
# share=True publishes an unauthenticated public URL. Anyone with it could spend this
# notebook's Anthropic credits and steer the agent's tools, including transcribe_audio,
# which accepts any local path. Stay local unless GRADIO_SHARE explicitly opts in.
share_publicly = os.getenv("GRADIO_SHARE", "").strip().lower() in {"1", "true", "yes"}
demo_app.launch(share=share_publicly)