# Tool-calling Agent

This is an auto-generated notebook created by an AI playground export. In this notebook, you will:
- Author a tool-calling agent with MLflow's [`ResponsesAgent`](https://mlflow.org/docs/latest/api_reference/python_api/mlflow.pyfunc.html#mlflow.pyfunc.ResponsesAgent) and the OpenAI client
- Manually test the agent's output
- Evaluate the agent with Mosaic AI Agent Evaluation
- Log and deploy the agent

This notebook should be run on serverless compute or on a cluster with a Databricks Runtime (DBR) version earlier than 17.

 **_NOTE:_**  This notebook uses the OpenAI SDK, but AI Agent Framework is compatible with any agent authoring framework, including LlamaIndex or LangGraph. To learn more, see the [Authoring Agents](https://docs.databricks.com/generative-ai/agent-framework/author-agent) Databricks documentation.

## Prerequisites

- Address all `TODO`s in this notebook.

In [0]:
%pip install -U -qqqq backoff databricks-openai uv databricks-agents mlflow-skinny[databricks] gepa
dbutils.library.restartPython()

## Define the agent in code
Below we define our agent code in a single cell, enabling us to easily write it to a local Python file for subsequent logging and deployment using the `%%writefile` magic command.

For more examples of tools to add to your agent, see [docs](https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html).

In [0]:
%%writefile agent.py
import json
from typing import Any, Callable, Generator, Optional
from uuid import uuid4
import warnings

import backoff
import mlflow
import mlflow.genai
import openai
from databricks.sdk import WorkspaceClient
from databricks_openai import UCFunctionToolkit, VectorSearchRetrieverTool
from mlflow.entities import SpanType, Document
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
    output_to_responses_items_stream,
    to_chat_completions_input,
)
from openai import OpenAI
from pydantic import BaseModel
from unitycatalog.ai.core.base import get_uc_function_client
from mlflow.tracking import MlflowClient

############################################
# Define your LLM endpoint and prompt from registry
############################################
LLM_ENDPOINT_NAME = "databricks-gpt-5-mini"

# Specify the prompt name from the MLflow prompt registry
# In Databricks, prompts are registered in Unity Catalog format: catalog.schema.promptname
PROMPT_NAME = "jack_demos_classic.data_catalogue_demo.catalog_rag_system_prompt"

# Load the system prompt from the prompt registry
try:
    # Get the latest version number first
    client = MlflowClient()
    response = client.search_prompt_versions(PROMPT_NAME)
    if response.prompt_versions:
        latest_version = max(v.version for v in response.prompt_versions)
        # Load the latest version using mlflow.genai.load_prompt
        prompt_obj = mlflow.genai.load_prompt(name_or_uri=PROMPT_NAME, version=latest_version)
        SYSTEM_PROMPT = prompt_obj.template
        print(f"Loaded prompt '{PROMPT_NAME}' (version {prompt_obj.version}) from registry")
    else:
        raise Exception("No versions found")
except Exception as e:
    # Fallback to default prompt if registry prompt doesn't exist
    print(f"Warning: Could not load prompt '{PROMPT_NAME}' from registry: {e}")
    print("Using default system prompt.")
    SYSTEM_PROMPT = """you are a helpful assistant, when prompted with a query search for related entries in the vector search. You report back with tables and datasets related to the user question."""


###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html
###############################################################################
class ToolInfo(BaseModel):
    """
    Class representing a tool for the agent.
    - "name" (str): The name of the tool.
    - "spec" (dict): JSON description of the tool (matches OpenAI Responses format)
    - "exec_fn" (Callable): Function that implements the tool logic
    - "is_retriever" (bool): Whether this tool is a retriever (for proper span typing)
    """

    name: str
    spec: dict
    exec_fn: Callable
    is_retriever: bool = False


def create_tool_info(tool_spec, exec_fn_param: Optional[Callable] = None, is_retriever: bool = False):
    tool_spec["function"].pop("strict", None)
    tool_name = tool_spec["function"]["name"]
    udf_name = tool_name.replace("__", ".")

    # Define a wrapper that accepts kwargs for the UC tool call,
    # then passes them to the UC tool execution client
    def exec_fn(**kwargs):
        function_result = uc_function_client.execute_function(udf_name, kwargs)
        if function_result.error is not None:
            return function_result.error
        else:
            return function_result.value
    return ToolInfo(name=tool_name, spec=tool_spec, exec_fn=exec_fn_param or exec_fn, is_retriever=is_retriever)


def create_vs_tool_wrapper(vs_tool_instance):
    """Create a proper closure for vector search tool execution."""
    def execute_wrapper(**kwargs):
        return vs_tool_instance.execute(**kwargs)
    return execute_wrapper


TOOL_INFOS = []

# You can use UDFs in Unity Catalog as agent tools
# TODO: Add additional tools
UC_TOOL_NAMES = []

uc_toolkit = UCFunctionToolkit(function_names=UC_TOOL_NAMES)
uc_function_client = get_uc_function_client()
for tool_spec in uc_toolkit.tools:
    TOOL_INFOS.append(create_tool_info(tool_spec))


# Use Databricks vector search indexes as tools
# See the [Databricks Documentation](https://docs.databricks.com/generative-ai/agent-framework/unstructured-retrieval-tools.html) for details
VECTOR_SEARCH_TOOLS = []
VECTOR_SEARCH_TOOLS.append(
    VectorSearchRetrieverTool(
        index_name="jack_demos_classic.data_catalogue_demo.metadata_docs_index",
        disable_notice=True,
        # TODO: specify index description for better agent tool selection
        # tool_description=""
    )
)
for vs_tool in VECTOR_SEARCH_TOOLS:
    TOOL_INFOS.append(create_tool_info(vs_tool.tool, create_vs_tool_wrapper(vs_tool), is_retriever=True))



class ToolCallingAgent(ResponsesAgent):
    """
    Class representing a tool-calling Agent
    """

    def __init__(self, llm_endpoint: str, tools: list[ToolInfo]):
        """Initializes the ToolCallingAgent with tools."""
        self.llm_endpoint = llm_endpoint
        self.workspace_client = WorkspaceClient()
        self.model_serving_client: OpenAI = (
            self.workspace_client.serving_endpoints.get_open_ai_client()
        )
        self._tools_dict = {tool.name: tool for tool in tools}

    def get_tool_specs(self) -> list[dict]:
        """Returns tool specifications in the format OpenAI expects."""
        return [tool_info.spec for tool_info in self._tools_dict.values()]

    def execute_tool(self, tool_name: str, args: dict) -> Any:
        """Executes the specified tool with the given arguments."""
        tool_info = self._tools_dict[tool_name]
        
        # Use RETRIEVER span for retrieval tools, TOOL span for others
        span_type = SpanType.RETRIEVER if tool_info.is_retriever else SpanType.TOOL
        
        with mlflow.start_span(name=tool_name, span_type=span_type) as span:
            result = tool_info.exec_fn(**args)
            
            # For retriever tools, set the retrieved documents as span outputs
            if tool_info.is_retriever and result:
                try:
                    # Parse the result - VectorSearchRetrieverTool returns JSON string
                    if isinstance(result, str):
                        parsed_result = json.loads(result)
                    else:
                        parsed_result = result
                    
                    # Extract documents and convert to Document objects
                    documents = []
                    if isinstance(parsed_result, list):
                        for doc in parsed_result:
                            if isinstance(doc, dict):
                                # Extract page_content and metadata
                                page_content = doc.get("page_content", doc.get("content", doc.get("text", str(doc))))
                                metadata = doc.get("metadata", {})
                                doc_id = doc.get("id")
                                
                                # Create Document object
                                doc_obj = Document(
                                    page_content=page_content,
                                    metadata=metadata
                                )
                                if doc_id:
                                    doc_obj.id = doc_id
                                    
                                documents.append(doc_obj)
                    
                    # Set documents as span outputs for RAG scorers
                    if documents:
                        span.set_outputs(documents)
                        
                except Exception as e:
                    # If parsing fails, log but don't break the flow
                    print(f"Warning: Could not parse retriever results: {e}")
            
            return result

    def call_llm(self, messages: list[dict[str, Any]]) -> Generator[dict[str, Any], None, None]:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="PydanticSerializationUnexpectedValue")
            for chunk in self.model_serving_client.chat.completions.create(
                model=self.llm_endpoint,
                messages=to_chat_completions_input(messages),
                tools=self.get_tool_specs(),
                stream=True,
            ):
                chunk_dict = chunk.to_dict()
                if len(chunk_dict.get("choices", [])) > 0:
                    yield chunk_dict

    def handle_tool_call(
        self,
        tool_call: dict[str, Any],
        messages: list[dict[str, Any]],
    ) -> ResponsesAgentStreamEvent:
        """
        Execute a tool call, append its output to the running message history,
        and return a ResponsesAgentStreamEvent carrying the tool output.
        """
        arguments_raw = tool_call.get("arguments")
        
        # Handle different argument formats
        if arguments_raw is None or arguments_raw == "":
            args = {}
        elif isinstance(arguments_raw, dict):
            args = arguments_raw
        elif isinstance(arguments_raw, str):
            try:
                args = json.loads(arguments_raw)
            except json.JSONDecodeError:
                # Fall back to the first complete JSON value if several are concatenated
                try:
                    args, _ = json.JSONDecoder().raw_decode(arguments_raw)
                except Exception:
                    args = {}
        else:
            args = {}
        
        result = str(self.execute_tool(tool_name=tool_call["name"], args=args))

        tool_call_output = self.create_function_call_output_item(tool_call["call_id"], result)
        messages.append(tool_call_output)
        return ResponsesAgentStreamEvent(type="response.output_item.done", item=tool_call_output)

    def call_and_run_tools(
        self,
        messages: list[dict[str, Any]],
        max_iter: int = 10,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        for _ in range(max_iter):
            last_msg = messages[-1]
            if last_msg.get("role", None) == "assistant":
                return
            elif last_msg.get("type", None) == "function_call":
                yield self.handle_tool_call(last_msg, messages)
            else:
                yield from output_to_responses_items_stream(
                    chunks=self.call_llm(messages), aggregator=messages
                )

        yield ResponsesAgentStreamEvent(
            type="response.output_item.done",
            item=self.create_text_output_item("Max iterations reached. Stopping.", str(uuid4())),
        )

    def _build_trace_metadata(self, request: ResponsesAgentRequest) -> dict[str, Any]:
        metadata: dict[str, Any] = {}

        session_id = None
        if request.custom_inputs and "session_id" in request.custom_inputs:
            session_id = request.custom_inputs.get("session_id")
        elif request.context and request.context.conversation_id:
            session_id = request.context.conversation_id

        if session_id:
            metadata["mlflow.trace.session"] = session_id

        if request.custom_inputs and "assistant_message_id" in request.custom_inputs:
            assistant_message_id = request.custom_inputs.get("assistant_message_id")
            if assistant_message_id:
                metadata["app.assistant_message_id"] = assistant_message_id

        return metadata

    def _apply_trace_metadata(self, request: ResponsesAgentRequest) -> None:
        metadata = self._build_trace_metadata(request)
        if metadata:
            mlflow.update_current_trace(metadata=metadata)

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        self._apply_trace_metadata(request)

        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        # ResponsesAgentResponse exposes `custom_outputs`; echo the request's custom inputs there
        return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs)

    def predict_stream(self, request: ResponsesAgentRequest) -> Generator[ResponsesAgentStreamEvent, None, None]:
        self._apply_trace_metadata(request)

        messages = to_chat_completions_input([i.model_dump() for i in request.input])
        if SYSTEM_PROMPT:
            messages.insert(0, {"role": "system", "content": SYSTEM_PROMPT})
        yield from self.call_and_run_tools(messages=messages)


# Enable OpenAI autologging and declare the agent instance as the model to log from this file
mlflow.openai.autolog()
AGENT = ToolCallingAgent(llm_endpoint=LLM_ENDPOINT_NAME, tools=TOOL_INFOS)
mlflow.models.set_model(AGENT)
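
The argument handling in `handle_tool_call` above can be read in isolation. The following is a minimal standalone sketch of the same idea using only the standard library; `parse_tool_arguments` is a hypothetical helper name, not part of `agent.py`, and it is slightly stricter than the agent code in that it also discards parsed values that are not JSON objects.

```python
import json
from typing import Any


def parse_tool_arguments(arguments_raw: Any) -> dict:
    """Best-effort conversion of raw tool-call arguments into a dict (hypothetical helper)."""
    if isinstance(arguments_raw, dict):
        return arguments_raw
    if not isinstance(arguments_raw, str) or not arguments_raw.strip():
        return {}
    try:
        parsed = json.loads(arguments_raw)
    except json.JSONDecodeError:
        try:
            # raw_decode parses the first JSON value and ignores trailing text,
            # which recovers from concatenated objects such as '{"a": 1}{"a": 2}'.
            parsed, _ = json.JSONDecoder().raw_decode(arguments_raw)
        except json.JSONDecodeError:
            return {}
    # Tool arguments are passed as keyword arguments, so only a dict is usable.
    return parsed if isinstance(parsed, dict) else {}


print(parse_tool_arguments('{"query": "crm"}'))
print(parse_tool_arguments('{"query": "crm"}{"query": "sales"}'))
print(parse_tool_arguments("not json"))
```

Falling back to an empty dict keeps the agent loop running when the model emits malformed arguments, at the cost of calling the tool without parameters.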

## Prompt Registry Setup

The agent has been refactored to use **MLflow Prompt Registry** for managing system prompts. This provides:

* **Version Control**: Track changes to prompts over time
* **Centralized Management**: Update prompts without modifying code
* **Collaboration**: Share and iterate on prompts across teams
* **Rollback**: Easily revert to previous prompt versions

### How it works:

1. **Register Prompt**: Create a prompt in Unity Catalog format (`catalog.schema.promptname`)
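2. **Agent Loads Prompt**: `agent.py` loads the latest version when the module is imported, that is, whenever the agent is loaded
3. **Update Prompt**: Register new versions to iterate without changing the agent code; a new version takes effect the next time the agent is loaded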
4. **Fallback**: If the prompt isn't found, the agent uses a default prompt
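
The load-then-fall-back behavior in steps 2 and 4 can be sketched without any registry dependency. In this illustration the registry call is replaced by a caller-supplied function; `resolve_system_prompt`, `DEFAULT_PROMPT`, and `failing_loader` are hypothetical names introduced only for the example.

```python
from typing import Callable

# Hypothetical default used only in this sketch.
DEFAULT_PROMPT = "You are a helpful assistant."


def resolve_system_prompt(
    load_latest: Callable[[], str], default: str = DEFAULT_PROMPT
) -> tuple[str, str]:
    """Return (prompt_text, source), where source is 'registry' or 'default'."""
    try:
        text = load_latest()
        if not text:
            raise ValueError("empty prompt")
        return text, "registry"
    except Exception as exc:
        # Any loading problem falls back to the default instead of failing startup.
        print(f"Warning: falling back to default prompt: {exc}")
        return default, "default"


def failing_loader() -> str:
    # Stand-in for a registry lookup that finds no versions.
    raise LookupError("No versions found")


print(resolve_system_prompt(lambda: "Registry prompt v3"))
print(resolve_system_prompt(failing_loader))
```

Returning the source alongside the text makes it easy to log which prompt a given agent instance is actually using.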

### Prompt Name Format:

In Databricks, prompts must be registered using Unity Catalog three-level naming:
```
catalog.schema.promptname
```

Example: `jack_demos_classic.data_catalogue_demo.catalog_rag_system_prompt`
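
A quick local check of the three-level format can catch typos before a registry call is made. This is a minimal sketch; `split_prompt_name` is a hypothetical helper, and it assumes plain identifiers without backtick-quoted parts that themselves contain dots.

```python
def split_prompt_name(full_name: str) -> tuple[str, str, str]:
    """Split 'catalog.schema.promptname' into its three parts or raise ValueError."""
    parts = full_name.split(".")
    if len(parts) != 3 or not all(part.strip() for part in parts):
        raise ValueError(
            f"Expected 'catalog.schema.promptname', got {full_name!r}"
        )
    catalog, schema, prompt = parts
    return catalog, schema, prompt


print(split_prompt_name("jack_demos_classic.data_catalogue_demo.catalog_rag_system_prompt"))
```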

In [0]:
import mlflow.genai

# Define the prompt name and initial template
# In Databricks, prompts are registered in Unity Catalog format: catalog.schema.promptname
prompt_name = "jack_demos_classic.data_catalogue_demo.catalog_rag_system_prompt"
initial_prompt_template = """you are a helpful assistant, when prompted with a query search for related entries in the vector search. You report back with tables and datasets related to the user question."""

# Create and register the prompt in MLflow prompt registry
try:
    # Try to register the prompt - if it exists, this will create a new version
    print(f"Registering prompt '{prompt_name}' in the registry...")
    prompt = mlflow.genai.register_prompt(
        name=prompt_name,
        template=initial_prompt_template,
        commit_message="Initial system prompt for catalog RAG agent",
        tags={
            "task": "catalog_search",
            "agent": "catalog_rag_agent"
        }
    )
    print(f"✓ Prompt '{prompt.name}' (version {prompt.version}) registered successfully!")
    print(f"Template: {initial_prompt_template}")
except Exception as e:
    print(f"Error registering prompt: {e}")
    import traceback
    traceback.print_exc()

In [0]:
# Optional: Update the prompt template in the registry
# Uncomment and modify the template below to update the prompt

# new_prompt_template = """You are an expert data catalog assistant. When users ask about data, 
# search the vector index for relevant tables and datasets. Provide clear, structured responses 
# that include:
# - Table names and their purpose
# - Key columns and data types
# - Relationships between tables
# - Usage examples when relevant
# """

# # Register a new version of the prompt
# updated_prompt = mlflow.genai.register_prompt(
#     name="jack_demos_classic.data_catalogue_demo.catalog_rag_system_prompt",  # Must match the prompt name used in agent.py
#     template=new_prompt_template,
#     commit_message="Enhanced prompt with structured output format"
# )
# print(f"Prompt updated to version {updated_prompt.version}!")

print("To update the prompt, uncomment the code above and modify the template.")
print("Each update creates a new version in the prompt registry for version control.")
print("Note: In Databricks, prompts use Unity Catalog format: catalog.schema.promptname")

## Test the agent

Interact with the agent to test its output. Since we manually traced methods within `ResponsesAgent`, you can view the trace for each step the agent takes, with any LLM calls made via the OpenAI SDK automatically traced by autologging.

Replace this placeholder input with an appropriate domain-specific example for your agent.
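
The test calls below pass a `session_id` through `custom_inputs`. As a sketch of how `_build_trace_metadata` in `agent.py` resolves it, the function below mirrors that precedence with plain dictionaries: `custom_inputs["session_id"]` wins when the key is present, otherwise the conversation ID is used. The standalone `build_trace_metadata` function and its flat `conversation_id` parameter are simplifications for illustration, not the agent's actual signature.

```python
from typing import Any, Optional


def build_trace_metadata(
    custom_inputs: Optional[dict[str, Any]],
    conversation_id: Optional[str] = None,
) -> dict[str, Any]:
    """Mirror the session and message ID resolution used for trace metadata."""
    metadata: dict[str, Any] = {}

    session_id = None
    if custom_inputs and "session_id" in custom_inputs:
        # An explicit session_id key takes precedence over the conversation ID.
        session_id = custom_inputs.get("session_id")
    elif conversation_id:
        session_id = conversation_id

    if session_id:
        metadata["mlflow.trace.session"] = session_id

    assistant_message_id = (custom_inputs or {}).get("assistant_message_id")
    if assistant_message_id:
        metadata["app.assistant_message_id"] = assistant_message_id

    return metadata


print(build_trace_metadata({"session_id": "test-session-123"}))
print(build_trace_metadata(None, conversation_id="conv-1"))
```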

In [0]:
dbutils.library.restartPython()

In [0]:
from agent import AGENT

AGENT.predict(
    {"input": [{"role": "user", "content": "what is 4*3 in python"}], "custom_inputs": {"session_id": "test-session-123"}},
)

In [0]:
for chunk in AGENT.predict_stream(
    {"input": [{"role": "user", "content": "What is 4*3 in Python?"}], "custom_inputs": {"session_id": "test-session-123"}}
):
    print(chunk.model_dump(exclude_none=True))

### Log the `agent` as an MLflow model
Determine the Databricks resources to specify for automatic authentication passthrough at deployment time.
- **TODO**: If your Unity Catalog Function queries a [vector search index](https://docs.databricks.com/generative-ai/agent-framework/unstructured-retrieval-tools.html) or leverages [external functions](https://docs.databricks.com/generative-ai/agent-framework/external-connection-tools.html), you need to include the dependent vector search index and UC connection objects, respectively, as resources. See [docs](https://docs.databricks.com/generative-ai/agent-framework/log-agent.html#specify-resources-for-automatic-authentication-passthrough) for more details.

Log the agent as code from the `agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).
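
When collecting resources, the logging cell maps each tool name back to a Unity Catalog function name by replacing double underscores with dots. The sketch below isolates that conversion; `tool_name_to_udf_name` and the example tool name are hypothetical.

```python
def tool_name_to_udf_name(tool_name: str) -> str:
    """Convert a tool name like 'catalog__schema__function' to 'catalog.schema.function'."""
    return tool_name.replace("__", ".")


# Hypothetical tool name used only for illustration.
print(tool_name_to_udf_name("main__default__lookup_customer"))

# Caveat: a function whose own name contains a double underscore is split too.
print(tool_name_to_udf_name("main__default__lookup__customer"))
```

Because the replacement is purely textual, function names that contain `__` themselves will not round-trip and would need to be listed as resources by hand.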

In [0]:
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
import mlflow
from agent import LLM_ENDPOINT_NAME, VECTOR_SEARCH_TOOLS, uc_toolkit
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
from mlflow.tracking import MlflowClient
from pkg_resources import get_distribution

# Store runs in a shared experiment path.
experiment_path = "/Shared/rag_catalog_agent"

mlflow_client = MlflowClient()
experiment = mlflow_client.get_experiment_by_name(experiment_path)
if experiment is None:
    experiment_id = mlflow_client.create_experiment(experiment_path)
else:
    experiment_id = experiment.experiment_id

mlflow.set_experiment(experiment_id=experiment_id)
print(f"Using MLflow experiment: {experiment_path}")

resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
for tool in VECTOR_SEARCH_TOOLS:
    resources.extend(tool.resources)
for tool in uc_toolkit.tools:
    # TODO: If the UC function includes dependencies like external connection or vector search, please include them manually.
    # See the TODO in the markdown above for more information.
    udf_name = tool.get("function", {}).get("name", "").replace("__", ".")
    resources.append(DatabricksFunction(function_name=udf_name))

input_example = {
    "input": [
        {
            "role": "user",
            "content": "crm"
        }
    ],
    "custom_inputs": {
        "session_id": "test-session"
    }
}

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent",
        python_model="agent.py",
        input_example=input_example,
        pip_requirements=[
            "databricks-openai",
            "backoff",
            f"databricks-connect=={get_distribution('databricks-connect').version}",
        ],
        resources=resources,
    )

## Evaluate the agent with [Agent Evaluation](https://docs.databricks.com/mlflow3/genai/eval-monitor)

You can edit the requests or expected responses in your evaluation dataset and rerun the evaluation as you iterate on your agent, using MLflow to track the computed quality metrics.

Evaluate your agent with one of our [predefined LLM scorers](https://docs.databricks.com/mlflow3/genai/eval-monitor/predefined-judge-scorers), or try adding [custom metrics](https://docs.databricks.com/mlflow3/genai/eval-monitor/custom-scorers).
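
As an illustration of what a custom metric for this use case might compute, the plain-Python function below measures the fraction of expected table names that appear in a response. It is a sketch only: `expected_tables_mentioned` is a hypothetical name, the sample response is invented for the example, and wiring the function into MLflow as a scorer is not shown here.

```python
def expected_tables_mentioned(response_text: str, expected_tables: list[str]) -> float:
    """Return the fraction of expected table names found in the response (case-insensitive)."""
    if not expected_tables:
        return 1.0
    lowered = response_text.lower()
    hits = sum(1 for table in expected_tables if table.lower() in lowered)
    return hits / len(expected_tables)


# Table names taken from the evaluation dataset in this notebook.
expected = ["FINANCE_DB.AP.INVOICES", "FINANCE_DB.GL.GENERAL_LEDGER"]
sample_response = "Try FINANCE_DB.AP.INVOICES for accounts payable invoices."
print(expected_tables_mentioned(sample_response, expected))  # 0.5
```

A deterministic check like this complements the LLM-based scorers: it is cheap to run and makes regressions on specific table names easy to spot.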

In [0]:
import mlflow
from mlflow.genai.scorers import RelevanceToQuery, Safety, RetrievalRelevance, RetrievalGroundedness, Correctness

eval_dataset = [
    {
        "inputs": {
            "input": [
                {
                    "role": "user",
                    "content": "Show me all tables related to CRM and sales opportunities"
                }
            ]
        },
        "expectations": {
            "expected_response": "should include the SALES_DB.CRM.OPPORTUNITIES table and related CRM tables"
        }
    },
    {
        "inputs": {
            "input": [
                {
                    "role": "user",
                    "content": "What tables contain financial data like invoices and general ledger entries?"
                }
            ]
        },
        "expectations": {
            "expected_response": "should include FINANCE_DB.AP.INVOICES for accounts payable invoices and FINANCE_DB.GL.GENERAL_LEDGER for accounting journal entries"
        }
    },
    {
        "inputs": {
            "input": [
                {
                    "role": "user",
                    "content": "I need to analyze inventory levels by warehouse location"
                }
            ]
        },
        "expectations": {
            "expected_response": "should recommend SUPPLY_DB.WAREHOUSE.INVENTORY_SNAPSHOT which contains daily inventory levels by SKU and location"
        }
    },
    {
        "inputs": {
            "input": [
                {
                    "role": "user",
                    "content": "Where can I find marketing campaign performance metrics and spend data?"
                }
            ]
        },
        "expectations": {
            "expected_response": "should reference MARKETING_DB.ATTRIBUTION.CAMPAIGN_PERFORMANCE which tracks daily campaign performance by channel including spend in USD"
        }
    }
]

eval_results = mlflow.genai.evaluate(
    data=eval_dataset,
    predict_fn=lambda input: AGENT.predict({"input": input, "custom_inputs": {"session_id": "evaluation-session"}}),
    scorers=[
        RelevanceToQuery(),           # Response addresses the query
        Safety(),                      # No harmful content
        Correctness(),                 # Matches expected response
        RetrievalRelevance(),          # Retrieved docs are relevant
        RetrievalGroundedness()        # Response grounded in retrieved context
    ],
)

# Review the evaluation results in the MLflow UI (see console output)

## Optimize the System Prompt

Use MLflow's **Prompt Optimization** to automatically improve your agent's system prompt based on evaluation metrics. This uses the GEPA (Genetic-Pareto) algorithm to iteratively refine prompts through LLM-driven reflection and automated feedback.

### How it works:

1. **Analyzes Performance**: Reviews your agent's outputs on the evaluation dataset
2. **Identifies Issues**: Uses LLM reflection to understand where the prompt falls short
3. **Generates Improvements**: Creates an optimized prompt that addresses the issues
4. **Validates**: Tests the new prompt against your evaluation metrics
5. **Registers**: Automatically saves the optimized prompt as a new version in the registry

### Benefits:

* **Data-Driven**: Uses your actual evaluation data to guide optimization
* **Automatic**: No manual prompt engineering required
* **Version Controlled**: New prompt versions are tracked in the registry
* **Measurable**: Shows improvement in evaluation metrics
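
The steps above follow a general propose, score, and select pattern. The toy sketch below shows only that loop shape under a fixed scoring budget, with a deterministic keyword-based stand-in for the scorers. It is not GEPA and not MLflow's implementation; `pick_best_prompt`, `toy_score`, and the candidate prompts are hypothetical.

```python
from typing import Callable


def pick_best_prompt(
    candidates: list[str],
    score_fn: Callable[[str], float],
    max_metric_calls: int,
) -> tuple[str, float]:
    """Score at most max_metric_calls candidates and return the best (prompt, score)."""
    if not candidates or max_metric_calls < 1:
        raise ValueError("need at least one candidate and a positive budget")
    best_prompt, best_score = candidates[0], float("-inf")
    for prompt in candidates[:max_metric_calls]:
        score = score_fn(prompt)
        if score > best_score:
            best_prompt, best_score = prompt, score
    return best_prompt, best_score


def toy_score(prompt: str) -> float:
    """Stand-in metric: fraction of key instructions the prompt mentions."""
    keywords = ("table", "column", "search")
    lowered = prompt.lower()
    return sum(keyword in lowered for keyword in keywords) / len(keywords)


candidates = [
    "You are a helpful assistant.",
    "Search the index and list each table.",
    "Search the index; list each table and its key columns.",
]
print(pick_best_prompt(candidates, toy_score, max_metric_calls=3))
```

The budget parameter plays the same role as `max_metric_calls` in the optimization cell: a small budget ends the search early, so the best candidate may never be scored.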

In [0]:
from mlflow.genai.optimize import GepaPromptOptimizer
from mlflow.genai.scorers import RelevanceToQuery, Correctness
import mlflow.genai
from agent import PROMPT_NAME
from mlflow.tracking import MlflowClient

# Get the latest version number
client = MlflowClient()
response = client.search_prompt_versions(PROMPT_NAME)
latest_version = max(v.version for v in response.prompt_versions)

# Define a predict function that explicitly loads and formats the prompt
def predict_with_prompt(input):
    """Prediction function that loads the prompt from registry and uses it."""
    # Load the specific version of the prompt that we're optimizing
    prompt = mlflow.genai.load_prompt(f"prompts:/{PROMPT_NAME}/{latest_version}")
    
    # Call format() to signal to the optimizer that this prompt is being used
    _ = prompt.format()
    
    # Call the agent. Note: agent.py reads SYSTEM_PROMPT once at import time, so AGENT
    # keeps using that prompt rather than the `prompt` object loaded above.
    return AGENT.predict({
        "input": input,
        "custom_inputs": {"session_id": "optimization-session"}
    })

# Prepare training data in the format expected by optimize_prompts
train_data = [
    {
        "inputs": item["inputs"],
        "expectations": item["expectations"]
    }
    for item in eval_dataset
]

print(f"Optimizing prompt: {PROMPT_NAME}")
print(f"Current version: {latest_version}")
print(f"Training data size: {len(train_data)} examples")
print("\nStarting prompt optimization... This may take several minutes.\n")

# Run prompt optimization
# Note: Only using scorers that return single numerical values
# RetrievalRelevance and RetrievalGroundedness return lists and cannot be used for optimization
optimization_result = mlflow.genai.optimize_prompts(
    predict_fn=predict_with_prompt,
    train_data=train_data,
    prompt_uris=[f"prompts:/{PROMPT_NAME}/{latest_version}"],
    optimizer=GepaPromptOptimizer(
        reflection_model="databricks:/databricks-gpt-5",
        max_metric_calls=10,
    ),
    scorers=[
        RelevanceToQuery(),
        Correctness(),
    ],
)

print("\n" + "="*80)
print("OPTIMIZATION COMPLETE")
print("="*80)
print(f"\nOptimized prompt has been registered as a new version in: {PROMPT_NAME}")
print(f"\nOriginal prompt version: {latest_version}")

if hasattr(optimization_result, 'metrics'):
    print(f"\nPerformance improvement:")
    for metric_name, scores in optimization_result.metrics.items():
        if isinstance(scores, dict) and 'optimized' in scores and 'original' in scores:
            print(f"  {metric_name}: {scores['optimized']:.3f} (was {scores['original']:.3f})")
        else:
            print(f"  {metric_name}: {scores}")

if hasattr(optimization_result, 'optimized_prompt'):
    print("\n" + "="*80)
    print("OPTIMIZED PROMPT")
    print("="*80)
    print(optimization_result.optimized_prompt)
    print("\n" + "="*80)

## Perform pre-deployment validation of the agent
Before registering and deploying the agent, run pre-deployment checks with the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API. See the [documentation](https://docs.databricks.com/machine-learning/model-serving/model-serving-debug.html#validate-inputs) for details.

In [0]:
mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data={"input": [{"role": "user", "content": "Hello!"}], "custom_inputs": {"session_id": "validation-session"}},
    env_manager="uv",
)

## Register the model to Unity Catalog

Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog.

In [0]:
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
catalog = "jack_demos_classic"
schema = "data_catalogue_demo"
model_name = "catalog_rag_agent"
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

## Deploy the agent

In [0]:
from databricks import agents
# NOTE: pass scale_to_zero=True to agents.deploy() to enable scale-to-zero for cost savings.
# This is not recommended for production workloads, as capacity is not guaranteed when scaled to zero.
# Scaled to zero endpoints may take extra time to respond when queried, while they scale back up.
agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version, tags={"endpointSource": "playground"})

## Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See the [docs](https://docs.databricks.com/generative-ai/deploy-agent.html) for details.