In [0]:
%pip install --upgrade --quiet mlflow databricks-sdk langgraph databricks-langchain databricks-agents openai gepa
dbutils.library.restartPython()

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


## Sales Support Multi-Agent Framework

 This notebook creates a multi-agent system for sales support with the following components:
  - **Structured Data Agent**: Queries structured sales data (opportunities, accounts, activities)
  - **Vector search Agent**: Retrieves information from unstructured documents (emails, meeting notes, feedback)
  - **Supervisor**: Routes queries to appropriate agents and orchestrates responses


In [0]:
%run ./00-init-requirements

In [0]:
import warnings
import os
os.environ["DATABRICKS_DISABLE_NOTICE"] = "true"
warnings.filterwarnings("ignore", message=".*notebook authentication token.*")

import functools
import json
import time
from typing import Any, Generator, Literal, Optional, Dict
from langchain.agents import create_agent

import mlflow
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import sql
from databricks.vector_search.client import VectorSearchClient
from databricks_langchain import ChatDatabricks, VectorSearchRetrieverTool
from langgraph.graph.state import CompiledStateGraph
from langchain_core.runnables import RunnableLambda
from langgraph.graph import END, StateGraph
from mlflow.entities import Feedback
from mlflow.genai import evaluate, scorer
from mlflow.genai.judges import CategoricalRating
from mlflow.genai.optimize import GepaPromptOptimizer
from databricks_langchain.genie import GenieAgent
from langchain_core.runnables import RunnableLambda
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from mlflow.genai.optimize import GepaPromptOptimizer
from mlflow.genai.scorers import Correctness
from mlflow.entities import SpanType
from mlflow.langchain.chat_agent_langgraph import ChatAgentState
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
from pydantic import BaseModel

# Import libraries for Genie API
from genie_api_classes import *

In [0]:
w = WorkspaceClient()
host = w.config.host

# Set your catalog / schema
CATALOG_NAME = catalog_name  
SCHEMA_NAME = schema_name  

print("Host:", host)
print("Catalog:", CATALOG_NAME)
print("Schema:", SCHEMA_NAME)

# LLM for prompt routing + RAG
# llm = ChatDatabricks(endpoint="databricks-claude-sonnet-4-5")
llm = ChatDatabricks(endpoint="databricks-gpt-5-1")

# MLflow experiment
MLFLOW_EXPERIMENT_NAME = f"/Users/{w.current_user.me().user_name}/multiagent_genie_{CATALOG_NAME}"
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)
print("Experiment:", MLFLOW_EXPERIMENT_NAME)

Host: https://e2-demo-field-eng.cloud.databricks.com
Catalog: andrea_tardif_v2
Schema: workday_demos
Experiment: /Users/andrea.tardif@databricks.com/multiagent_genie_andrea_tardif_v2


In [0]:
## Create SQL Warehouse for the Genie Space

SQL_WAREHOUSE = w.warehouses.create(
    name=f"multiagent-demo-{int(time.time())}",
    auto_stop_mins=5,
    max_num_clusters=1,
    cluster_size="Small",
    enable_serverless_compute=True,
    tags=sql.EndpointTags(
        custom_tags=[sql.EndpointTagPair(key="Demo", value="Multi-Agent Demo Blog")]
    ),
)

print("Created SQL Warehouse:", SQL_WAREHOUSE.id)

Created SQL Warehouse: 6ced32b2405bc6de


In [0]:
## Create Genie Space using the Genie API

client = GenieSpacesClient()

# This JSON file should contain your Genie Space definition
GENIE_SPACE_JSON = "genie_space_blog_demo.json" 

with open(GENIE_SPACE_JSON, "r", encoding="utf-8") as f:
    space_json = json.load(f)

def replace_catalog(obj, old: str, new: str):
    """Recursively replace catalog name in Genie JSON."""
    if isinstance(obj, dict):
        return {k: replace_catalog(v, old, new) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [replace_catalog(i, old, new) for i in obj]
    elif isinstance(obj, str):
        return obj.replace(old, new)
    return obj

space_json = replace_catalog(space_json, "andrea_tardif", CATALOG_NAME)
serialized_space = space_json["serialized_space"]

created_space = client.create_space(
    warehouse_id=SQL_WAREHOUSE.id,
    serialized_space=serialized_space,
    title=f"Sales Support Agent - {CATALOG_NAME}",
)

GENIE_SPACE_ID = created_space["space_id"]
print("Created Genie Space:", GENIE_SPACE_ID)

Created Genie Space: 01f0c710ce8b14fa903c46d6ab98d085


In [0]:
## Create RAG Agent with Vector Search Tools

EMAIL_INDEX = f"{CATALOG_NAME}.{SCHEMA_NAME}.email_communications_index"
MEETING_NOTES_INDEX = f"{CATALOG_NAME}.{SCHEMA_NAME}.meeting_notes_index"
CUSTOMER_FEEDBACK_INDEX = f"{CATALOG_NAME}.{SCHEMA_NAME}.customer_feedback_index"

email_retriever = VectorSearchRetrieverTool(
    index_name=EMAIL_INDEX,
    columns=["content", "doc_uri"],
    name="email_search",
    description=(
        "Searches through email communications between sales reps and customers. "
        "Use this to find pricing discussions, objections, follow-ups, proposals, "
        "and customer correspondence."
    ),
    disable_notice=True,
)

meeting_notes_retriever = VectorSearchRetrieverTool(
    index_name=MEETING_NOTES_INDEX,
    columns=["content", "doc_uri"],
    name="meeting_notes_search",
    description=(
        "Searches meeting notes and summaries from customer calls and demos. "
        "Use this for context from past meetings, agreed actions, and open questions."
    ),
    disable_notice=True,
)

feedback_retriever = VectorSearchRetrieverTool(
    index_name=CUSTOMER_FEEDBACK_INDEX,
    columns=["content", "doc_uri"],
    name="customer_feedback_search",
    description=(
        "Searches customer feedback, NPS surveys, and comments. "
        "Use this for sentiment, complaints, and feature requests."
    ),
    disable_notice=True,
)

rag_tools = [email_retriever, meeting_notes_retriever, feedback_retriever]
print("Configured RAG tools.")

Configured RAG tools.


In [0]:
pat = w.tokens.create(comment=f"genie-agent-{int(time.time())}").token_value

genie_description = (
    "Use GenieAgent for questions that require querying structured, tabular, "
    "numeric sales data from CRM or data warehouse tables. Examples: pipeline "
    "amount, revenue, win rate, deal counts, metrics by region, time, or segment."
)

genie_agent = GenieAgent(
    genie_space_id=GENIE_SPACE_ID,
    genie_agent_name="GenieAgent",
    description=genie_description,
    client=WorkspaceClient(host=host, token=pat),
)

rag_description = (
    "Use RAGAgent for questions that require reading unstructured text like emails, "
    "meeting notes, call transcripts, or NPS feedback. Examples: what someone said, "
    "sentiment, objections raised, qualitative feedback."
)

rag_agent = create_agent(llm, rag_tools)

print("Created GenieAgent and RAGAgent.")

Created GenieAgent and RAGAgent.


Trace(trace_id=tr-8c81843c1f697ff76e522367ef64c4be)

In [0]:
worker_descriptions = { "GenieAgent": genie_description, "RAGAgent": rag_description, }

formatted_descriptions = "\n".join( f"- {name}: {desc}" for name, desc in worker_descriptions.items() )

system_prompt = f"""You are a strategic supervisor coordinating between specialized sales support agents.

Your role is to:
1. Analyze the user's question to determine which agent(s) can best answer it
2. Route to the appropriate agent based on the question type
3. Ensure complete answers without redundant work
4. Synthesize information from multiple agents if needed

Available agents:
{formatted_descriptions}

Routing Guidelines:
- Use GenieAgent for: metrics, numbers, quotas, pipeline values, rep performance, account counts, etc.
- Use RAGAgent for: customer communications, meeting context, feedback, concerns, proposals, etc.
- You can route to multiple agents if the question requires both types of information

Only respond with FINISH when:
- The user's question has been fully answered
- All necessary information has been gathered and processed

Avoid routing to the same agent multiple times for the same information.

Important:
- Do not choose FINISH until at least one specialized agent has been invoked.
- Prefer GenieAgent for numeric/metric queries; RAGAgent for unstructured text queries.
"""

In [0]:
def build_multi_agent_supervisor(system_prompt: str):
    """
    Builds the full multi-agent supervisor system using the given system_prompt.
    Returns:
      - multi_agent: compiled LangGraph graph
      - agent: wrapped LangGraph ChatAgent
    """

    # Register prompt for tracking in Unity Catalog
    prompt_location = f"{CATALOG_NAME}.{SCHEMA_NAME}.sales_multiagent_supervisor"

    supervisor_prompt = mlflow.genai.register_prompt(
        name=prompt_location,
        template=system_prompt,
        commit_message="Supervisor routing prompt (auto-generated).",
    )


    # Supervisor agent definition
    options = ["FINISH"] + list(worker_descriptions.keys())
    FINISH = {"next_node": "FINISH"}

    def load_system_prompt():
        prompt = mlflow.genai.load_prompt(supervisor_prompt.uri)
        return prompt.template

    @mlflow.trace(span_type=SpanType.AGENT, name="supervisor_agent")
    def supervisor_agent(state):
        MAX_ITERATIONS = 4

        system_prompt = load_system_prompt()

        count = state.get("iteration_count", 0) + 1
        if count > MAX_ITERATIONS:
            return FINISH

        class NextNode(BaseModel):
            next_node: Literal[tuple(options)]

        preprocessor = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}] + state["messages"]
        )
        supervisor_chain = preprocessor | llm.with_structured_output(NextNode)
        result = supervisor_chain.invoke(state)
        next_node = result.next_node

        if state.get("next_node") == next_node:
            return FINISH

        return {
            "iteration_count": count,
            "next_node": next_node
        }


    # Agent nodes + final synthesis node
    def agent_node(state, agent, name):
        result = agent.invoke({"messages": state["messages"]})
        return {
            "messages": [{
                "role": "assistant",
                "content": result["messages"][-1].content,
                "name": name,
            }]
        }

    def final_answer(state):
        prompt = (
            "Based on the information gathered by the specialized agents, "
            "provide a comprehensive answer to the user's question."
        )
        preproc = RunnableLambda(
            lambda state: state["messages"] + [{"role": "user", "content": prompt}]
        )
        final_chain = preproc | llm
        return {"messages": [final_chain.invoke(state)]}

    class AgentState(ChatAgentState):
        next_node: str
        iteration_count: int

    # Agents
    rag_node = functools.partial(agent_node, agent=rag_agent, name="RAGAgent")
    genie_node = functools.partial(agent_node, agent=genie_agent, name="GenieAgent")

    # Build workflow graph
    workflow = StateGraph(AgentState)
    workflow.add_node("GenieAgent", genie_node)
    workflow.add_node("RAGAgent", rag_node)
    workflow.add_node("supervisor", supervisor_agent)
    workflow.add_node("final_answer", final_answer)

    workflow.set_entry_point("supervisor")

    for worker in worker_descriptions.keys():
        workflow.add_edge(worker, "supervisor")

    workflow.add_conditional_edges(
        "supervisor",
        lambda x: x["next_node"],
        {**{k: k for k in worker_descriptions.keys()}, "FINISH": "final_answer"},
    )

    workflow.add_edge("final_answer", END)
    multi_agent = workflow.compile()

    # Wrap in Databricks ChatAgent
    class LangGraphChatAgent(ChatAgent):
        def __init__(self, agent):
            self.agent = agent

        def predict(self, messages, context=None, custom_inputs=None):
            request = {"messages": [m.model_dump(exclude_none=True) for m in messages]}
            msgs = []
            for event in self.agent.stream(request, stream_mode="updates"):
                for node_data in event.values():
                    msgs.extend(ChatAgentMessage(**m) for m in node_data.get("messages", []))
            return ChatAgentResponse(messages=msgs)

    agent = LangGraphChatAgent(multi_agent)

    return supervisor_prompt, multi_agent, agent


In [0]:
supervisor_prompt, multi_agent, AGENT = build_multi_agent_supervisor(system_prompt)

def create_predict_fn(prompt_uri: str):
    """
    GEPA-compatible predict_fn:
    - Loads a supervisor prompt from MLflow Prompt Registry.
    - Uses the global AGENT (LangGraphChatAgent) for prediction.
    - Returns a single answer string.
    """
    prompt_obj = mlflow.genai.load_prompt(prompt_uri)

    @mlflow.trace
    def predict_fn(question: str) -> str:
        system_prompt = prompt_obj.template

        msgs = [
            ChatAgentMessage(role="system", content=system_prompt),
            ChatAgentMessage(role="user", content=question),
        ]
        response = AGENT.predict(messages=msgs)
        last = next((m for m in reversed(response.messages) if m.role == "assistant"), None)
        if last:
            return last.content
        return ""

    return predict_fn

In [0]:
import json

with open("training_data_multi_agent_blog.json", "r") as f:
    train_data = json.load(f)

In [0]:
correctness = Correctness(
    reference_key="expected_response", 
    task_type="qa",
)


def run_benchmark(
    prompt_uri: str,
    num_samples: int,
    split: str = "validation",
) -> dict:
    """Run the agent on train_data using multiple scorers."""

    # Use the first N examples from train_data
    eval_data = train_data[:num_samples]

    # Create prediction fn bound to this prompt
    predict_fn = create_predict_fn(prompt_uri)

    print(f"\nRunning evaluation on {len(eval_data)} samples...\n")

    results = evaluate(
        data=eval_data,
        predict_fn=predict_fn,
        scorers=[correctness],
    )

    correctness_acc = results.metrics.get("correctness/mean", 0.0) / 100.0

    return {
        "correctness": correctness_acc,
        "metrics": results.metrics,
        "results": results,
    }


baseline_metrics = run_benchmark(
    prompt_uri=supervisor_prompt.uri,
    num_samples=20,          
)

print("Correctness Accuracy :", f"{baseline_metrics['correctness']:.2%}")

2025/11/21 19:35:43 INFO mlflow.genai.utils.data_validation: Testing model prediction with the first sample in the dataset. To disable this check, set the MLFLOW_GENAI_EVAL_SKIP_TRACE_VALIDATION environment variable to True.



Running evaluation on 20 samples...



Evaluating:   0%|          | 0/20 [Elapsed: 00:00, Remaining: ?] 

Helpfulness Accuracy : 1.00%
Correctness Accuracy : 0.35%


In [0]:
result = mlflow.genai.optimize_prompts(
    predict_fn=create_predict_fn(supervisor_prompt.uri),
    train_data=train_data,
    prompt_uris=[supervisor_prompt.uri],
    optimizer=GepaPromptOptimizer(
        reflection_model="databricks:/databricks-gpt-5-1",
        max_metric_calls=50,
    ),
    scorers=[correctness],
    enable_tracking=True,
)

# Get the optimized prompt URI
optimized_prompt_uri = result.optimized_prompts[0].uri
print(f"  Base prompt: {supervisor_prompt.uri}")
print(f"  Optimized prompt: {optimized_prompt_uri}")

2025/11/21 19:38:51 INFO mlflow.genai.utils.data_validation: Testing model prediction with the first sample in the dataset. To disable this check, set the MLFLOW_GENAI_EVAL_SKIP_TRACE_VALIDATION environment variable to True.
  return _dataset_source_registry.resolve(
  return _dataset_source_registry.resolve(


Iteration 0: Base program full valset score: 0.6 over 20 / 20 examples
Iteration 1: Selected program 0 score: 0.6
Iteration 1: Proposed new text for andrea_tardif_v2.workday_demos.sales_multiagent_supervisor: You are a **strategic supervisor** coordinating between specialized **sales and customer insights agents** in a Databricks / Workday demo environment.

Your primary goal is to:
1. Correctly decide **which specialized agent(s)** to invoke for each user question.
2. Invoke those agents **at least once** before ever returning a final answer.
3. **Synthesize** the agents’ outputs into a concise, decision‑useful answer to the user’s question.
4. **Avoid redundant calls** to the same agent for the same information.

You never directly query data warehouses or document indexes yourself; you reason about what should be done, call the right tools/agents, then interpret their outputs.

---

## Available Specialized Agents

You coordinate between **two** agents:

### 1. GenieAgent (Structure

Catalog created andrea_tardif_v2
Schema created andrea_tardif_v2.workday_demos
Volume created /Volumes/andrea_tardif_v2/workday_demos/workday_unstructure_data


In [0]:
optimized_metrics = run_benchmark(optimized_prompt_uri, num_samples=100)

print(f"Optimized correctness: {optimized_metrics['correctness']:.2%}")

improvement = optimized_metrics['correctness'] - baseline_metrics['correctness']
print(f"Improvement: {improvement:+.2%}")


Running evaluation on 20 samples...



2025/11/21 19:50:40 INFO mlflow.genai.utils.data_validation: Testing model prediction with the first sample in the dataset. To disable this check, set the MLFLOW_GENAI_EVAL_SKIP_TRACE_VALIDATION environment variable to True.


Evaluating:   0%|          | 0/20 [Elapsed: 00:00, Remaining: ?] 

Optimized correctness: 0.37%
Improvement: +0.02%


In [0]:
optimized_prompt = mlflow.genai.load_prompt(optimized_prompt_uri)

supervisor_prompt_opt, multi_agent_opt, AGENT_OPT = build_multi_agent_supervisor(
    optimized_prompt.template
)

In [0]:
def ask_optimized(question: str) -> str:
    """
    Send a single question to the optimized multi-agent supervisor
    and return the final assistant message content.
    """
    messages = [ChatAgentMessage(role="user", content=question)]
    response = AGENT_OPT.predict(messages=messages)

    # Grab the last assistant message
    last = next(
        (m for m in reversed(response.messages) if m.role == "assistant"),
        None,
    )
    return last.content if last else ""

In [0]:
ask_optimized("How many opportunities closed last quarter for the Enterprise segment?")

'Based on the information gathered by the specialized agents, there were **8 opportunities closed last quarter for the Enterprise segment**.'

[Trace(trace_id=tr-0666b9c1620f94f2c41cc7991798f64c), Trace(trace_id=tr-253e0fd37e98a79e2f846bb73939621a), Trace(trace_id=tr-e39fff3c831bcd075d4f59307e199e26)]

In [0]:
ask_optimized("Summarize the main concerns raised by the customer in recent emails about pricing.")

'Here are the main pricing-related concerns the customer has raised in recent emails:\n\n1. **Total cost vs. budget constraints**  \n   - They’re worried the quoted pricing may exceed their current budget.  \n   - They’ve asked whether there is any room to adjust scope, license tiers, or services to better align with their spending limits.\n\n2. **Perceived value and ROI**  \n   - They want clearer justification for the price: what concrete business outcomes, efficiencies, or savings they can expect.  \n   - They’ve requested more quantification of ROI (e.g., cost savings, productivity gains, payback period).\n\n3. **Comparison to alternatives and current setup**  \n   - They’re comparing the proposal against both competitors’ pricing and the cost of maintaining their existing solution.  \n   - They question whether the premium over cheaper options is warranted by additional features or support.\n\n4. **Pricing transparency and potential hidden costs**  \n   - They are concerned about 

[Trace(trace_id=tr-2295af1dacffd7567770aa1f6b229087), Trace(trace_id=tr-e7874853bed1a0c4975049840cbc8d3f), Trace(trace_id=tr-5633b122024f8d2752287d242e4b965b)]