In [0]:
%pip install -U google-adk litellm databricks-sdk langchain databricks-langchain langchain==0.3.7 faiss-cpu wikipedia langgraph==0.5.3  databricks_langchain

Collecting langchain
  Downloading langchain-1.2.3-py3-none-any.whl.metadata (4.9 kB)
Collecting databricks-langchain
  Downloading databricks_langchain-0.12.1-py3-none-any.whl.metadata (3.0 kB)
Collecting langchain
  Downloading langchain-0.3.7-py3-none-any.whl.metadata (7.1 kB)
Collecting faiss-cpu
  Downloading faiss_cpu-1.13.2-cp310-abi3-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl.metadata (7.6 kB)
Collecting wikipedia
  Downloading wikipedia-1.4.0.tar.gz (27 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting langgraph==0.5.3
  Downloading langgraph-0.5.3-py3-none-any.whl.metadata (6.9 kB)
Collecting langchain-core<0.4.0,>=0.3.15 (from langchain)
  Downloading langchain_core-0.3.83-py3-none-any.whl.metadata (3.2 kB)
Collecting langchain-text-splitters<0.4.0,>=0.3.0 (from langchain)
  Downloading langchain_text_splitters-0.3.11-py3-none-any.whl.metadata (1.8 kB)
Collecting langsmith<0.2.0,>=0.1.17 (from langchai

In [0]:
dbutils.library.restartPython()

### Configs

In [0]:
import os
import json
import requests
from dotenv import load_dotenv

from langchain.prompts import ChatPromptTemplate
from databricks_langchain import ChatDatabricks


In [0]:

LLM_ENDPOINT_NAME = "databricks-meta-llama-3-1-8b-instruct" # Model Serving endpoint name; other option see "Serving" under AI/ML tab (e.g. databricks-gpt-oss-20b)

llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME, temperature=0.2)

GOOGLE_API_KEY="AI**"


In [0]:
# -------------------------------------------------------------------
# LLM Instances by Role (Enterprise Pattern)
# -------------------------------------------------------------------

# Deterministic classifier
classifier_llm = ChatDatabricks(
    endpoint=LLM_ENDPOINT_NAME,
    temperature=0.0,   # STRICT determinism
)

# Simple direct answers
simple_llm = ChatDatabricks(
    endpoint=LLM_ENDPOINT_NAME,
    temperature=0.2,
)

# Multi-step reasoning
reasoning_llm = ChatDatabricks(
    endpoint=LLM_ENDPOINT_NAME,
    temperature=0.2,
)

# Internet-search-style synthesis (prompt-based)
search_llm = ChatDatabricks(
    endpoint=LLM_ENDPOINT_NAME,
    temperature=0.2,
)

print("LLM roles initialized: classifier, simple, reasoning, search")

LLM roles initialized: classifier, simple, reasoning, search


In [0]:
classification_prompt = ChatPromptTemplate.from_template("""
You are an enterprise routing controller.

Classify the user prompt into EXACTLY ONE category:
- simple
- reasoning
- internet_search

Definitions:
- simple: directly answerable; no multi-step reasoning; no fresh data needed
- reasoning: requires multi-step logic, comparison, synthesis, or structured thought
- internet_search: requires up-to-date or externally verifiable information

Hard rules:
- Return STRICT JSON only (no markdown, no extra text)
- Use ONLY the two keys: classification, reasoning
- classification must be exactly one of: "simple", "reasoning", "internet_search"
- reasoning must be a short sentence

User prompt:
{user_prompt}

Return JSON (exact keys only):
{{
  "classification": "simple | reasoning | internet_search",
  "reasoning": "..."
}}
""")

classification_chain = classification_prompt | classifier_llm
print("Classification chain initialized (JSON-only).")

Classification chain initialized (JSON-only).


In [0]:
search_prompt = ChatPromptTemplate.from_template("""
You are a research assistant operating WITHOUT web access.

The question requires up-to-date or externally verifiable information.
Generate a structured evidence pack to help answer it.

Rules:
- Do NOT claim real-time access
- Clearly mark uncertainty
- Be concise and factual
- Return STRICT JSON only
- Do NOT include markdown or extra text

User question:
{user_prompt}

Return JSON in this exact format:
{{
  "key_facts": [
    "Important factual points relevant to the question"
  ],
  "likely_current_answer": "Best estimate based on general knowledge (may be outdated)",
  "what_to_verify": [
    "What a human should verify using an authoritative source"
  ],
  "confidence_level": "low | medium | high"
}}
""")

search_chain = search_prompt | search_llm
print("Search (evidence generation) chain initialized.")

Search (evidence generation) chain initialized.


In [0]:
# -------------------------------------------------------------------
# Answer generation prompts
# -------------------------------------------------------------------

simple_answer_prompt = ChatPromptTemplate.from_template("""
Answer the question directly and concisely.

Question:
{user_prompt}
""")

reasoning_answer_prompt = ChatPromptTemplate.from_template("""
Answer the question using clear, step-by-step reasoning.
Keep it structured and avoid unnecessary verbosity.

Question:
{user_prompt}
""")

search_answer_prompt = ChatPromptTemplate.from_template("""
You do NOT have web access. Use the evidence pack below to answer.
Be explicit about uncertainty and what should be verified.

Evidence pack (JSON):
{evidence_pack}

Question:
{user_prompt}

Answer:
""")


# -------------------------------------------------------------------
# Generate response based on routing classification
# -------------------------------------------------------------------

def generate_response(user_prompt: str, classification: str, evidence_pack: dict | None = None) -> str:
    classification = classification.strip().lower()

    if classification == "simple":
        chain = simple_answer_prompt | simple_llm
        msg = chain.invoke({"user_prompt": user_prompt})
        return msg.content

    if classification == "reasoning":
        chain = reasoning_answer_prompt | reasoning_llm
        msg = chain.invoke({"user_prompt": user_prompt})
        return msg.content

    if classification == "internet_search":
        if evidence_pack is None:
            raise ValueError("internet_search requires an evidence_pack, but none was provided.")
        chain = search_answer_prompt | search_llm
        msg = chain.invoke({
            "user_prompt": user_prompt,
            "evidence_pack": json.dumps(evidence_pack, ensure_ascii=False, indent=2),
        })
        return msg.content

    raise ValueError(f"Unknown classification: {classification}")


In [0]:
# -------------------------------------------------------------------
# JSON parsing helper (robust to occasional extra text)
# -------------------------------------------------------------------

def parse_json_message(msg) -> dict:
    """
    Extract JSON object from an LLM message safely.
    Fails fast if no valid JSON object is found.
    """
    text = msg.content if hasattr(msg, "content") else str(msg)
    text = text.strip()

    # remove accidental code fences
    if text.startswith("```"):
        text = text.strip("`").strip()

    start = text.find("{")
    end = text.rfind("}") + 1
    if start == -1 or end <= start:
        raise ValueError(f"No JSON object found. Raw output:\n{text}")

    return json.loads(text[start:end])


def validate_classifier_output(obj: dict) -> None:
    """
    Enforce strict contract for classifier:
    - only keys: classification, reasoning
    - classification in allowed set
    """
    if not isinstance(obj, dict):
        raise ValueError("Classifier output must be a JSON object.")

    allowed_keys = {"classification", "reasoning"}
    extra = set(obj.keys()) - allowed_keys
    missing = allowed_keys - set(obj.keys())

    if missing:
        raise ValueError(f"Classifier JSON missing keys: {missing}")
    if extra:
        raise ValueError(f"Classifier JSON has extra keys (not allowed): {extra}")

    if obj["classification"] not in {"simple", "reasoning", "internet_search"}:
        raise ValueError(f"Invalid classification: {obj['classification']}")


def validate_evidence_pack(obj: dict) -> None:
    """
    Enforce strict contract for evidence pack.
    """
    if not isinstance(obj, dict):
        raise ValueError("Evidence pack must be a JSON object.")

    required = {"key_facts", "likely_current_answer", "what_to_verify", "confidence_level"}
    missing = required - set(obj.keys())
    if missing:
        raise ValueError(f"Evidence pack JSON missing keys: {missing}")

    if obj["confidence_level"] not in {"low", "medium", "high"}:
        raise ValueError(f"Invalid confidence_level: {obj['confidence_level']}")


# -------------------------------------------------------------------
# End-to-end router
# -------------------------------------------------------------------

def run_prompt(user_prompt: str) -> dict:
    # 1) classify
    cls_msg = classification_chain.invoke({"user_prompt": user_prompt})
    cls_obj = parse_json_message(cls_msg)
    validate_classifier_output(cls_obj)

    classification = cls_obj["classification"]
    routing_reason = cls_obj["reasoning"]

    # 2) evidence pack if needed
    evidence_pack = None
    if classification == "internet_search":
        ev_msg = search_chain.invoke({"user_prompt": user_prompt})
        evidence_pack = parse_json_message(ev_msg)
        validate_evidence_pack(evidence_pack)

    # 3) generate final response
    final_answer = generate_response(
        user_prompt=user_prompt,
        classification=classification,
        evidence_pack=evidence_pack
    )

    return {
        "user_prompt": user_prompt,
        "classification": classification,
        "routing_reason": routing_reason,
        "evidence_pack": evidence_pack,
        "response": final_answer
    }


print("Router initialized: run_prompt(user_prompt)")

Router initialized: run_prompt(user_prompt)


In [0]:
# -------------------------------------------------------------------
# Interactive runner
# -------------------------------------------------------------------

user_prompt = input("Enter a prompt: ").strip()

result = run_prompt(user_prompt)

print("\n" + "=" * 80)
print("PROMPT:", result["user_prompt"])

print("\nClassification:", result["classification"])
print("Routing reason:", result["routing_reason"])

if result["evidence_pack"] is not None:
    print("\nEvidence pack:")
    print(json.dumps(result["evidence_pack"], indent=2))

print("\nResponse:")
print(result["response"])

Enter a prompt:  what is p_value?


PROMPT: what is p_value?

Classification: simple
Routing reason: A p-value is a statistical measure of the probability of observing results at least as extreme as those observed during an experiment, assuming that the null hypothesis is true.

Response:
p_value is the probability of observing a result as extreme or more extreme than the one observed, assuming the null hypothesis is true.


### Agents

In [0]:
import json
from dataclasses import dataclass
from typing import Callable, Dict, Any, Optional

# LangChain
from langchain.prompts import ChatPromptTemplate

# Databricks LLM
from databricks_langchain import ChatDatabricks

In [0]:
# -------------------------------------------------------------------
# LLM instances by role (enterprise pattern)
# -------------------------------------------------------------------
classifier_llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME, temperature=0.0)  # deterministic routing
simple_llm     = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME, temperature=0.2)
reasoning_llm  = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME, temperature=0.2)
search_llm     = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME, temperature=0.2)

print("Initialized LLM endpoint:", LLM_ENDPOINT_NAME)
print("Roles: classifier_llm, simple_llm, reasoning_llm, search_llm")

Initialized LLM endpoint: databricks-meta-llama-3-1-8b-instruct
Roles: classifier_llm, simple_llm, reasoning_llm, search_llm


In [0]:
@dataclass
class Agent:
    """
    Lightweight agent abstraction (Databricks-native).
    Each agent exposes a `run()` callable that returns a structured dict output.
    """
    name: str
    description: str
    run: Callable[..., Dict[str, Any]]


# Registry placeholder (we'll fill this after defining agents)
AGENTS: Dict[str, Agent] = {}

print("Agent abstraction ready.")


Agent abstraction ready.


In [0]:
# -------------------------------------------------------------------
# Robust JSON extraction from LLM messages
# -------------------------------------------------------------------
def parse_json_message(msg) -> dict:
    """
    Extract the first JSON object from an LLM message.
    Fails fast if no valid JSON object is found.
    """
    text = msg.content if hasattr(msg, "content") else str(msg)
    text = text.strip()

    # Remove accidental code fences
    if text.startswith("```"):
        text = text.strip("`").strip()

    start = text.find("{")
    end = text.rfind("}") + 1
    if start == -1 or end <= start:
        raise ValueError(f"No JSON object found. Raw output:\n{text}")

    return json.loads(text[start:end])


def validate_classifier_output(obj: dict) -> None:
    """
    Enforce strict classifier contract:
    - only keys: classification, reasoning
    - classification in allowed set
    """
    if not isinstance(obj, dict):
        raise ValueError("Classifier output must be a JSON object.")

    allowed_keys = {"classification", "reasoning"}
    extra = set(obj.keys()) - allowed_keys
    missing = allowed_keys - set(obj.keys())

    if missing:
        raise ValueError(f"Classifier JSON missing keys: {missing}")
    if extra:
        raise ValueError(f"Classifier JSON has extra keys (not allowed): {extra}")

    if obj["classification"] not in {"simple", "reasoning", "internet_search"}:
        raise ValueError(f"Invalid classification: {obj['classification']}")


def validate_evidence_pack(obj: dict) -> None:
    """
    Enforce strict evidence pack contract.
    """
    if not isinstance(obj, dict):
        raise ValueError("Evidence pack must be a JSON object.")

    required = {"key_facts", "likely_current_answer", "what_to_verify", "confidence_level"}
    missing = required - set(obj.keys())
    if missing:
        raise ValueError(f"Evidence pack JSON missing keys: {missing}")

    if obj["confidence_level"] not in {"low", "medium", "high"}:
        raise ValueError(f"Invalid confidence_level: {obj['confidence_level']}")


print("JSON parsing + validation utilities ready.")

JSON parsing + validation utilities ready.


In [0]:
# Minimal stub to mimic LangChain message objects with .content
class FakeMsg:
    def __init__(self, content: str):
        self.content = content

def run_tests():
    print("Running tests...\n")

    # -------------------------
    # parse_json_message tests
    # -------------------------
    msg1 = FakeMsg('{"classification":"simple","reasoning":"Direct definition."}')
    assert parse_json_message(msg1)["classification"] == "simple"

    msg2 = FakeMsg('```json\n{"classification":"reasoning","reasoning":"Needs steps."}\n```')
    assert parse_json_message(msg2)["classification"] == "reasoning"

    msg3 = FakeMsg('Some text before\n{"classification":"internet_search","reasoning":"Needs fresh info."}\nSome text after')
    assert parse_json_message(msg3)["classification"] == "internet_search"

    # -------------------------
    # validate_classifier_output tests
    # -------------------------
    good_cls = {"classification": "simple", "reasoning": "Direct."}
    validate_classifier_output(good_cls)

    try:
        validate_classifier_output({"classification": "simple"})  # missing reasoning
        raise AssertionError("Expected missing key error for classifier")
    except ValueError as e:
        print("OK (classifier missing key):", e)

    try:
        validate_classifier_output({"classification": "simple", "reasoning": "x", "extra": 1})
        raise AssertionError("Expected extra key error for classifier")
    except ValueError as e:
        print("OK (classifier extra key):", e)

    try:
        validate_classifier_output({"classification": "unknown", "reasoning": "x"})
        raise AssertionError("Expected invalid classification error")
    except ValueError as e:
        print("OK (classifier invalid label):", e)

    # -------------------------
    # validate_evidence_pack tests
    # -------------------------
    good_ev = {
        "key_facts": ["A", "B"],
        "likely_current_answer": "May be outdated.",
        "what_to_verify": ["Verify with BLS."],
        "confidence_level": "low"
    }
    validate_evidence_pack(good_ev)

    try:
        validate_evidence_pack({"key_facts": []})  # missing most fields
        raise AssertionError("Expected missing key error for evidence pack")
    except ValueError as e:
        print("OK (evidence missing keys):", e)

    try:
        bad_ev = {
            "key_facts": [],
            "likely_current_answer": "",
            "what_to_verify": [],
            "confidence_level": "certain"  # invalid
        }
        validate_evidence_pack(bad_ev)
        raise AssertionError("Expected invalid confidence_level error")
    except ValueError as e:
        print("OK (evidence invalid confidence_level):", e)

    print("\nAll tests passed ✅")

run_tests()


Running tests...

OK (classifier missing key): Classifier JSON missing keys: {'reasoning'}
OK (classifier extra key): Classifier JSON has extra keys (not allowed): {'extra'}
OK (classifier invalid label): Invalid classification: unknown
OK (evidence missing keys): Evidence pack JSON missing keys: {'confidence_level', 'what_to_verify', 'likely_current_answer'}
OK (evidence invalid confidence_level): Invalid confidence_level: certain

All tests passed ✅


In [0]:
# -------------------------------------------------------------------
# Router prompt (STRICT JSON) — remember to escape braces for LangChain
# -------------------------------------------------------------------
router_prompt = ChatPromptTemplate.from_template("""
You are an enterprise routing controller.

Classify the user prompt into EXACTLY ONE category:
- simple
- reasoning
- internet_search

Definitions:
- simple: directly answerable; no multi-step reasoning; no fresh data needed
- reasoning: requires multi-step logic, comparison, synthesis, or structured thought
- internet_search: requires up-to-date or externally verifiable information

Hard rules:
- Return STRICT JSON only (no markdown, no extra text)
- Use ONLY the two keys: classification, reasoning
- classification must be exactly one of: "simple", "reasoning", "internet_search"
- reasoning must be a short sentence

User prompt:
{user_prompt}

Return JSON:
{{
  "classification": "simple | reasoning | internet_search",
  "reasoning": "..."
}}
""")

router_chain = router_prompt | classifier_llm


def router_agent_run(user_prompt: str) -> Dict[str, Any]:
    msg = router_chain.invoke({"user_prompt": user_prompt})
    obj = parse_json_message(msg)
    validate_classifier_output(obj)
    return obj  # {"classification": ..., "reasoning": ...}


RouterAgent = Agent(
    name="RouterAgent",
    description="Deterministic router that classifies prompts into simple/reasoning/internet_search.",
    run=router_agent_run
)

AGENTS["router"] = RouterAgent

print("RouterAgent registered as AGENTS['router']")

RouterAgent registered as AGENTS['router']


In [0]:
# -------------------------------------------------------------------
# SimpleAgent prompt
# -------------------------------------------------------------------
simple_prompt = ChatPromptTemplate.from_template("""
Answer the question directly and concisely.

Question:
{user_prompt}
""")

simple_chain = simple_prompt | simple_llm


def simple_agent_run(user_prompt: str) -> Dict[str, Any]:
    msg = simple_chain.invoke({"user_prompt": user_prompt})
    return {"response": msg.content}


SimpleAgent = Agent(
    name="SimpleAgent",
    description="Fast and efficient agent for straightforward questions.",
    run=simple_agent_run
)

AGENTS["simple"] = SimpleAgent

print("SimpleAgent registered as AGENTS['simple']")

SimpleAgent registered as AGENTS['simple']


In [0]:
# -------------------------------------------------------------------
# ReasoningAgent prompt
# -------------------------------------------------------------------
reasoning_prompt = ChatPromptTemplate.from_template("""
Answer the question using clear, step-by-step reasoning.
Structure your answer logically and avoid unnecessary verbosity.

Question:
{user_prompt}
""")

reasoning_chain = reasoning_prompt | reasoning_llm


def reasoning_agent_run(user_prompt: str) -> Dict[str, Any]:
    msg = reasoning_chain.invoke({"user_prompt": user_prompt})
    return {"response": msg.content}


ReasoningAgent = Agent(
    name="ReasoningAgent",
    description="Handles prompts requiring multi-step reasoning or synthesis.",
    run=reasoning_agent_run
)

AGENTS["reasoning"] = ReasoningAgent

print("ReasoningAgent registered as AGENTS['reasoning']")


ReasoningAgent registered as AGENTS['reasoning']


In [0]:
# -------------------------------------------------------------------
# SearchAgent (Evidence Pack + Synthesis, No External APIs) — retry
# -------------------------------------------------------------------

evidence_prompt = ChatPromptTemplate.from_template("""
You are a research assistant operating WITHOUT web access.

The question requires up-to-date or externally verifiable information.
Generate a structured evidence pack to help answer it.

Rules:
- Do NOT claim real-time access
- Clearly mark uncertainty
- Be concise and factual
- Return STRICT JSON only
- Do NOT include markdown or extra text

User question:
{user_prompt}

Return JSON in this exact format:
{{
  "key_facts": [
    "Important factual points relevant to the question"
  ],
  "likely_current_answer": "Best estimate based on general knowledge (may be outdated)",
  "what_to_verify": [
    "What a human should verify using an authoritative source"
  ],
  "confidence_level": "low | medium | high"
}}
""")

evidence_chain = evidence_prompt | search_llm


synthesis_prompt = ChatPromptTemplate.from_template("""
You do NOT have web access. Use the evidence pack below to answer.
Be explicit about uncertainty and what should be verified.

Evidence pack (JSON):
{evidence_pack}

Question:
{user_prompt}

Answer:
""")

synthesis_chain = synthesis_prompt | search_llm


def search_agent_run(user_prompt: str) -> Dict[str, Any]:
    ev_msg = evidence_chain.invoke({"user_prompt": user_prompt})
    evidence_pack = parse_json_message(ev_msg)
    validate_evidence_pack(evidence_pack)

    synth_msg = synthesis_chain.invoke({
        "user_prompt": user_prompt,
        "evidence_pack": json.dumps(evidence_pack, ensure_ascii=False, indent=2),
    })

    return {"evidence_pack": evidence_pack, "response": synth_msg.content}


SearchAgent = Agent(
    name="SearchAgent",
    description="Generates a prompt-based evidence pack (no web) and synthesizes an answer with uncertainty + verification steps.",
    run=search_agent_run
)

AGENTS["internet_search"] = SearchAgent

print("SearchAgent registered ✅")
print("Registered agents now:", sorted(AGENTS.keys()))

SearchAgent registered ✅
Registered agents now: ['internet_search', 'reasoning', 'router', 'simple']


In [0]:
# -------------------------------------------------------------------
# Orchestrator: router -> dispatch to specialized agent
# -------------------------------------------------------------------

AGENT_MAP = {
    "simple": AGENTS["simple"],
    "reasoning": AGENTS["reasoning"],
    "internet_search": AGENTS["internet_search"],
}

def run_prompt_with_agents(user_prompt: str) -> Dict[str, Any]:
    # 1) Route
    route = AGENTS["router"].run(user_prompt)
    classification = route["classification"]
    routing_reason = route["reasoning"]

    # 2) Dispatch
    if classification not in AGENT_MAP:
        raise ValueError(f"No agent registered for classification: {classification}")

    agent = AGENT_MAP[classification]
    out = agent.run(user_prompt)

    # 3) Structured result (good for logging / observability)
    return {
        "user_prompt": user_prompt,
        "classification": classification,
        "routing_reason": routing_reason,
        "agent_used": agent.name,
        **out,
    }

print("Orchestrator ready ✅  Use: run_prompt_with_agents(user_prompt)")


Orchestrator ready ✅  Use: run_prompt_with_agents(user_prompt)


In [0]:
tests = [
    ("simple", "What is a p-value?"),
    ("reasoning", "If precision is high but recall is low, what does that imply? Give an example."),
    ("internet_search", "What is the current inflation rate in the US?")
]

for expected, prompt in tests:
    print("\n" + "=" * 90)
    print(f"Expected: {expected}")
    print("Prompt:", prompt)

    out = run_prompt_with_agents(prompt)

    print("\nClassification:", out["classification"])
    print("Agent used:", out["agent_used"])
    print("Routing reason:", out["routing_reason"])

    if out["classification"] == "internet_search":
        print("\nEvidence pack:")
        print(json.dumps(out["evidence_pack"], indent=2))

    print("\nResponse:\n", out["response"])


Expected: simple
Prompt: What is a p-value?

Classification: simple
Agent used: SimpleAgent
Routing reason: A p-value is a statistical measure of the probability of observing results at least as extreme as those observed during the experiment, assuming that the null hypothesis is true.

Response:
 The p-value is the probability of observing a result as extreme or more extreme than the one observed, assuming the null hypothesis is true.

Expected: reasoning
Prompt: If precision is high but recall is low, what does that imply? Give an example.

Classification: simple
Agent used: SimpleAgent
Routing reason: It implies that the model is good at making correct predictions, but misses many actual instances.

Response:
 It implies that the model is good at making correct predictions (high precision) but misses many actual instances (low recall). Example: A medical test that correctly identifies 90% of cancer patients (high precision) but misses 70% of actual cancer cases (low recall).

Expec

In [0]:
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, AsyncGenerator

@dataclass
class Message:
    role: str
    text: str

@dataclass
class InvocationContext:
    """
    Minimal ADK-like invocation context.
    Holds the current message plus optional metadata (request_id, user_id, etc.).
    """
    current_message: Message
    metadata: Dict[str, Any] = field(default_factory=dict)

@dataclass
class Event:
    """
    Minimal ADK-like event for observability.
    """
    author: str
    type: str  # e.g. "route", "tool", "final", "debug"
    content: Any
    metadata: Dict[str, Any] = field(default_factory=dict)

print("Event + InvocationContext ready ✅")


Event + InvocationContext ready ✅


In [0]:
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Callable


class AsyncBaseAgent(ABC):
    """
    ADK-like async agent interface that yields Events.
    """

    name: str = "BaseAgent"
    description: str = ""

    async def run(self, context: InvocationContext) -> AsyncGenerator[Event, None]:
        async for ev in self.run_async_impl(context):
            yield ev

    @abstractmethod
    async def run_async_impl(self, context: InvocationContext) -> AsyncGenerator[Event, None]:
        ...


async def _run_in_thread(fn: Callable, *args, **kwargs):
    """
    Run sync functions without blocking the event loop.
    Databricks notebooks are sync by default; this makes async composition possible.
    """
    return await asyncio.to_thread(fn, *args, **kwargs)


class SyncAgentAdapter(AsyncBaseAgent):
    """
    Wraps an existing *sync* agent (our Agent dataclass with .run(user_prompt)->dict)
    into an ADK-like async agent that yields Events.
    """

    def __init__(self, agent: Agent, event_type: str = "final"):
        self._agent = agent
        self.name = agent.name
        self.description = agent.description
        self._event_type = event_type

    async def run_async_impl(self, context: InvocationContext) -> AsyncGenerator[Event, None]:
        yield Event(
            author=self.name,
            type="debug",
            content="starting",
            metadata={"agent": self.name}
        )

        # Execute the underlying sync agent off-thread
        out = await _run_in_thread(self._agent.run, context.current_message.text)

        yield Event(
            author=self.name,
            type=self._event_type,
            content=out,
            metadata={"agent": self.name}
        )


print("AsyncBaseAgent + SyncAgentAdapter ready ✅")

AsyncBaseAgent + SyncAgentAdapter ready ✅


In [0]:
class QueryRouterAgent(AsyncBaseAgent):
    name: str = "QueryRouter"
    description: str = "Routes user queries to the appropriate agent based on classification (with optional heuristics)."

    def __init__(
        self,
        router_agent: Agent,
        simple_agent: Agent,
        reasoning_agent: Agent,
        search_agent: Agent,
        *,
        enable_length_fallback: bool = True,
        short_query_word_threshold: int = 20,
    ):
        self.router_agent = router_agent
        self.enable_length_fallback = enable_length_fallback
        self.short_query_word_threshold = short_query_word_threshold

        # Wrap sync agents into async, event-yielding agents
        self.simple_async = SyncAgentAdapter(simple_agent, event_type="final")
        self.reasoning_async = SyncAgentAdapter(reasoning_agent, event_type="final")
        self.search_async = SyncAgentAdapter(search_agent, event_type="final")

    async def run_async_impl(self, context: InvocationContext) -> AsyncGenerator[Event, None]:
        user_query = context.current_message.text
        query_len = len(user_query.split())

        # 1) Emit debug info
        yield Event(
            author=self.name,
            type="debug",
            content={"query_word_count": query_len},
            metadata={"threshold": self.short_query_word_threshold}
        )

        # 2) Try deterministic classification first (RouterAgent)
        classification = None
        routing_reason = None
        try:
            route_obj = await _run_in_thread(self.router_agent.run, user_query)
            classification = route_obj["classification"]
            routing_reason = route_obj["reasoning"]

            yield Event(
                author=self.name,
                type="route",
                content={"classification": classification, "reasoning": routing_reason},
            )

        except Exception as e:
            # Router failed (parse error, etc.)
            yield Event(
                author=self.name,
                type="error",
                content=f"RouterAgent failed; will {'use' if self.enable_length_fallback else 'not use'} fallback. Error: {e}",
            )

            if not self.enable_length_fallback:
                raise

        # 3) Optional fallback heuristic if RouterAgent failed
        if classification is None and self.enable_length_fallback:
            # simple heuristic: short => simple, long => reasoning
            classification = "simple" if query_len < self.short_query_word_threshold else "reasoning"
            routing_reason = f"Fallback heuristic used: word_count={query_len}, threshold={self.short_query_word_threshold}"

            yield Event(
                author=self.name,
                type="route",
                content={"classification": classification, "reasoning": routing_reason},
                metadata={"fallback": True}
            )

        # 4) Dispatch to the chosen agent and stream its events
        if classification == "simple":
            async for ev in self.simple_async.run(context):
                yield ev

        elif classification == "reasoning":
            async for ev in self.reasoning_async.run(context):
                yield ev

        elif classification == "internet_search":
            async for ev in self.search_async.run(context):
                yield ev

        else:
            yield Event(
                author=self.name,
                type="error",
                content=f"Unknown classification: {classification}",
                metadata={"classification": classification}
            )
            raise ValueError(f"Unknown classification: {classification}")


# Instantiate router with your existing agents from AGENTS registry
query_router_agent = QueryRouterAgent(
    router_agent=AGENTS["router"],
    simple_agent=AGENTS["simple"],
    reasoning_agent=AGENTS["reasoning"],
    search_agent=AGENTS["internet_search"],
    enable_length_fallback=True,
    short_query_word_threshold=20
)

print("QueryRouterAgent ready ✅ (async, event-streaming)")

QueryRouterAgent ready ✅ (async, event-streaming)


In [0]:
async def run_with_events(user_text: str):
    ctx = InvocationContext(current_message=Message(role="user", text=user_text))

    events = []
    async for ev in query_router_agent.run(ctx):
        events.append(ev)

    return events


def pretty_print_events(events):
    print("\n" + "=" * 90)
    for i, ev in enumerate(events, 1):
        print(f"[{i}] {ev.type.upper()}  | author={ev.author}")
        if ev.metadata:
            print("    metadata:", ev.metadata)
        print("    content:", ev.content)
        print("-" * 90)


async def run_agent_async(user_text: str):
    ctx = InvocationContext(current_message=Message(role="user", text=user_text))

    events = []
    async for ev in query_router_agent.run(ctx):
        events.append(ev)

    # pretty print
    print("\n" + "=" * 90)
    for i, ev in enumerate(events, 1):
        print(f"[{i}] {ev.type.upper()}  | author={ev.author}")
        if ev.metadata:
            print("    metadata:", ev.metadata)
        print("    content:", ev.content)
        print("-" * 90)

    return events


print("Notebook-safe runner ready ✅  Use: await run_agent_async('your question')")

Notebook-safe runner ready ✅  Use: await run_agent_async('your question')


In [0]:
events = await run_agent_async("What is a p-value?")


[1] DEBUG  | author=QueryRouter
    metadata: {'threshold': 20}
    content: {'query_word_count': 4}
------------------------------------------------------------------------------------------
[2] ROUTE  | author=QueryRouter
    content: {'classification': 'simple', 'reasoning': 'A p-value is a statistical measure of the probability of observing results at least as extreme as those observed during the experiment, assuming that the null hypothesis is true.'}
------------------------------------------------------------------------------------------
[3] DEBUG  | author=SimpleAgent
    metadata: {'agent': 'SimpleAgent'}
    content: starting
------------------------------------------------------------------------------------------
[4] FINAL  | author=SimpleAgent
    metadata: {'agent': 'SimpleAgent'}
    content: {'response': 'A p-value is the probability of observing a result at least as extreme as the one observed, assuming the null hypothesis is true.'}
---------------------------------

In [0]:
from time import perf_counter

def extract_result_from_events(events):
    """
    Convert an event stream into a clean result dict:
      - classification
      - routing_reason
      - agent_used
      - response
      - evidence_pack (if any)
    """
    result = {
        "classification": None,
        "routing_reason": None,
        "agent_used": None,
        "response": None,
        "evidence_pack": None,
    }

    # routing info
    for ev in events:
        if ev.type == "route" and isinstance(ev.content, dict):
            result["classification"] = ev.content.get("classification")
            result["routing_reason"] = ev.content.get("reasoning")

    # final output from specialist agent
    for ev in reversed(events):
        if ev.type == "final" and isinstance(ev.content, dict):
            result["agent_used"] = ev.author
            # SearchAgent returns {"evidence_pack": ..., "response": ...}
            if "response" in ev.content:
                result["response"] = ev.content["response"]
            if "evidence_pack" in ev.content:
                result["evidence_pack"] = ev.content["evidence_pack"]
            break

    return result


async def run_and_extract(user_text: str):
    """
    Convenience wrapper: run QueryRouterAgent and return the cleaned result dict.
    """
    events = await run_agent_async(user_text)
    return extract_result_from_events(events)


print("Result extractor ready ✅  Use: result = await run_and_extract('...')")

Result extractor ready ✅  Use: result = await run_and_extract('...')


In [0]:
result = await run_and_extract("What is a p-value?")
result


[1] DEBUG  | author=QueryRouter
    metadata: {'threshold': 20}
    content: {'query_word_count': 4}
------------------------------------------------------------------------------------------
[2] ROUTE  | author=QueryRouter
    content: {'classification': 'simple', 'reasoning': 'A p-value is a statistical measure of the probability of observing results at least as extreme as those observed during the experiment, assuming that the null hypothesis is true.'}
------------------------------------------------------------------------------------------
[3] DEBUG  | author=SimpleAgent
    metadata: {'agent': 'SimpleAgent'}
    content: starting
------------------------------------------------------------------------------------------
[4] FINAL  | author=SimpleAgent
    metadata: {'agent': 'SimpleAgent'}
    content: {'response': 'The p-value is the probability of observing a result as extreme or more extreme than the one observed, assuming that the null hypothesis is true.'}
-----------------

{'classification': 'simple',
 'routing_reason': 'A p-value is a statistical measure of the probability of observing results at least as extreme as those observed during the experiment, assuming that the null hypothesis is true.',
 'agent_used': 'SimpleAgent',
 'response': 'The p-value is the probability of observing a result as extreme or more extreme than the one observed, assuming that the null hypothesis is true.',
 'evidence_pack': None}

In [0]:
class TimedQueryRouterAgent(QueryRouterAgent):
    """
    Same as QueryRouterAgent, but emits timing events.
    """

    async def run_async_impl(self, context: InvocationContext) -> AsyncGenerator[Event, None]:
        user_query = context.current_message.text
        query_len = len(user_query.split())

        yield Event(
            author=self.name,
            type="debug",
            content={"query_word_count": query_len},
            metadata={"threshold": self.short_query_word_threshold}
        )

        # ---- Routing timing ----
        t0 = perf_counter()
        classification = None
        routing_reason = None

        try:
            route_obj = await _run_in_thread(self.router_agent.run, user_query)
            classification = route_obj["classification"]
            routing_reason = route_obj["reasoning"]
        finally:
            t1 = perf_counter()
            yield Event(
                author=self.name,
                type="timing",
                content={"stage": "routing", "seconds": round(t1 - t0, 4)}
            )

        yield Event(
            author=self.name,
            type="route",
            content={"classification": classification, "reasoning": routing_reason},
        )

        # ---- Dispatch timing ----
        t2 = perf_counter()
        try:
            if classification == "simple":
                async for ev in self.simple_async.run(context):
                    yield ev
            elif classification == "reasoning":
                async for ev in self.reasoning_async.run(context):
                    yield ev
            elif classification == "internet_search":
                async for ev in self.search_async.run(context):
                    yield ev
            else:
                raise ValueError(f"Unknown classification: {classification}")
        finally:
            t3 = perf_counter()
            yield Event(
                author=self.name,
                type="timing",
                content={"stage": "dispatch_total", "seconds": round(t3 - t2, 4)},
                metadata={"classification": classification}
            )


# Replace the router with the timed version (same config)
timed_query_router_agent = TimedQueryRouterAgent(
    router_agent=AGENTS["router"],
    simple_agent=AGENTS["simple"],
    reasoning_agent=AGENTS["reasoning"],
    search_agent=AGENTS["internet_search"],
    enable_length_fallback=True,
    short_query_word_threshold=20
)

# Update runner to use timed agent
async def run_agent_async_timed(user_text: str):
    ctx = InvocationContext(current_message=Message(role="user", text=user_text))

    events = []
    async for ev in timed_query_router_agent.run(ctx):
        events.append(ev)

    # pretty print
    print("\n" + "=" * 90)
    for i, ev in enumerate(events, 1):
        print(f"[{i}] {ev.type.upper()}  | author={ev.author}")
        if ev.metadata:
            print("    metadata:", ev.metadata)
        print("    content:", ev.content)
        print("-" * 90)

    return events


print("Timed router ready ✅  Use: events = await run_agent_async_timed('...')")

Timed router ready ✅  Use: events = await run_agent_async_timed('...')


In [0]:
events = await run_agent_async_timed("What is the current inflation rate in the US?")



[1] DEBUG  | author=QueryRouter
    metadata: {'threshold': 20}
    content: {'query_word_count': 9}
------------------------------------------------------------------------------------------
[2] TIMING  | author=QueryRouter
    content: {'stage': 'routing', 'seconds': 0.5534}
------------------------------------------------------------------------------------------
[3] ROUTE  | author=QueryRouter
    content: {'classification': 'internet_search', 'reasoning': 'This question requires access to current and up-to-date economic data to provide an accurate answer.'}
------------------------------------------------------------------------------------------
[4] DEBUG  | author=SearchAgent
    metadata: {'agent': 'SearchAgent'}
    content: starting
------------------------------------------------------------------------------------------
[5] FINAL  | author=SearchAgent
    metadata: {'agent': 'SearchAgent'}
    content: {'evidence_pack': {'key_facts': ['The US inflation rate is typically me

In [0]:
# -------------------------------------------------------------------
# CriticAgent prompt (STRICT JSON) — escape braces
# -------------------------------------------------------------------

CRITIC_SYSTEM_PROMPT = """
You are the **Critic Agent**, serving as the quality assurance arm of our collaborative assistant system.
Your primary function is to meticulously review and challenge information from the answer-generating agent,
guaranteeing accuracy, completeness, and unbiased presentation.

Your duties:
- Assess findings for factual correctness, thoroughness, and potential bias.
- Identify missing data, assumptions, or inconsistencies in reasoning.
- Raise critical questions that could refine or expand understanding.
- Offer constructive suggestions for improvement or alternative angles.
- Validate that the final output is comprehensive and balanced.

All criticism must be constructive. Your goal is to fortify the work, not invalidate it.

Return STRICT JSON only. No markdown. No extra keys.
"""

critic_prompt = ChatPromptTemplate.from_template(
    CRITIC_SYSTEM_PROMPT
    + """

Review the following assistant output.

User prompt:
{user_prompt}

Routing classification:
{classification}

Evidence pack (may be null):
{evidence_pack}

Assistant draft answer:
{draft_answer}

Return JSON in this exact format:
{{
  "verdict": "pass | needs_revision",
  "major_issues": ["..."],
  "minor_issues": ["..."],
  "missing_info_questions": ["..."],
  "suggested_improvements": ["..."]
}}
"""
)

critic_chain = critic_prompt | reasoning_llm  # can switch to a stronger endpoint later


def critic_agent_run(
    user_prompt: str,
    classification: str,
    draft_answer: str,
    evidence_pack: Optional[dict] = None
) -> Dict[str, Any]:
    msg = critic_chain.invoke({
        "user_prompt": user_prompt,
        "classification": classification,
        "draft_answer": draft_answer,
        "evidence_pack": json.dumps(evidence_pack, ensure_ascii=False, indent=2) if evidence_pack else "null"
    })
    obj = parse_json_message(msg)

    # Required contract
    required = {
        "verdict",
        "major_issues",
        "minor_issues",
        "missing_info_questions",
        "suggested_improvements"
    }

    # ✅ Robustness: drop any extra keys the model adds (e.g. "confidence_level")
    obj = {k: obj.get(k) for k in required}

    missing = [k for k, v in obj.items() if v is None]
    if missing:
        raise ValueError(f"Critic JSON missing keys: {missing}")

    if obj["verdict"] not in {"pass", "needs_revision"}:
        raise ValueError(f"Invalid critic verdict: {obj['verdict']}")

    return obj


# Re-register CriticAgent with the patched run()
CriticAgent = Agent(
    name="CriticAgent",
    description="Quality assurance agent that reviews drafted answers; returns structured feedback.",
    run=critic_agent_run
)
AGENTS["critic"] = CriticAgent

print("CriticAgent patched ✅ (extra keys ignored)")


CriticAgent patched ✅ (extra keys ignored)


In [0]:
async def run_with_critic(user_text: str) -> Dict[str, Any]:
    ctx = InvocationContext(current_message=Message(role="user", text=user_text))

    events = []
    async for ev in timed_query_router_agent.run(ctx):
        events.append(ev)

    draft_result = extract_result_from_events(events)

    classification = draft_result["classification"]
    draft_answer = draft_result["response"]
    evidence_pack = draft_result.get("evidence_pack")

    critic_feedback = await _run_in_thread(
        AGENTS["critic"].run,
        user_text,
        classification,
        draft_answer,
        evidence_pack
    )

    return {
        "user_prompt": user_text,
        "classification": classification,
        "routing_reason": draft_result["routing_reason"],
        "agent_used": draft_result["agent_used"],
        "evidence_pack": evidence_pack,
        "draft_answer": draft_answer,
        "critic_feedback": critic_feedback,
        "events": events,
    }

In [0]:
reviser_prompt = ChatPromptTemplate.from_template("""
You are the Reviser Agent.

Goal: Improve the draft answer by addressing the Critic feedback.
Rules:
- Do NOT invent facts. If uncertain, say so explicitly.
- Keep the answer clear and concise.
- If classification is internet_search and evidence is low-confidence, emphasize verification steps.
- Address major issues first; incorporate suggested improvements.
- Return plain text only.

User prompt:
{user_prompt}

Routing classification:
{classification}

Evidence pack (may be null):
{evidence_pack}

Draft answer:
{draft_answer}

Critic feedback (JSON):
{critic_feedback}

Write the revised answer:
""")

reviser_chain = reviser_prompt | reasoning_llm


async def run_with_critic_and_revise(
    user_text: str,
    *,
    max_revision_rounds: int = 1,
    re_critic_after_revision: bool = True
) -> Dict[str, Any]:
    bundle = await run_with_critic(user_text)

    classification = bundle["classification"]
    evidence_pack = bundle.get("evidence_pack")
    revised_answer = bundle["draft_answer"]
    critic_feedback = bundle["critic_feedback"]

    critic_after = None

    for _ in range(max_revision_rounds):
        if critic_feedback.get("verdict") == "pass":
            break

        msg = reviser_chain.invoke({
            "user_prompt": user_text,
            "classification": classification,
            "evidence_pack": json.dumps(evidence_pack, ensure_ascii=False, indent=2) if evidence_pack else "null",
            "draft_answer": revised_answer,
            "critic_feedback": json.dumps(critic_feedback, ensure_ascii=False, indent=2),
        })
        revised_answer = msg.content

        if re_critic_after_revision:
            critic_after = await _run_in_thread(
                AGENTS["critic"].run,
                user_text,
                classification,
                revised_answer,
                evidence_pack
            )
            critic_feedback = critic_after

    return {
        **bundle,
        "revised_answer": revised_answer,
        "critic_after_revision": critic_after
    }

In [0]:
result = await run_with_critic_and_revise("What is the current inflation rate in the US?")
print(result["critic_feedback"])
print(result["revised_answer"])
print(result["critic_after_revision"])


{'missing_info_questions': ['What is the current CPI data release from the BLS website?', 'What are the recent inflation trends according to news articles or economic reports from reputable sources?'], 'suggested_improvements': ["The assistant could provide a more direct answer to the user's question, such as 'The current inflation rate in the US is around 2-3%, but this should be confirmed through more recent and reliable sources.'", 'The assistant could also provide more specific information about the sources used to estimate the inflation rate and the recent inflation trends.'], 'minor_issues': ['The assistant could provide more context about the historical trends and pre-pandemic data used to estimate the inflation rate.', 'The assistant could also provide more specific information about the sources used to estimate the inflation rate.'], 'verdict': 'needs_revision', 'major_issues': ["The assistant's answer is too vague and doesn't provide a clear conclusion about the current infla