# Text-to-SQL Agent Driver Notebook

This notebook demonstrates how to:
1. Load and test the Text-to-SQL agent
2. Log the agent as an MLflow model
3. Evaluate the agent with MLflow
4. Register and deploy the agent to Unity Catalog

For more information, see the [Databricks Mosaic AI Agent Framework documentation](https://docs.databricks.com/aws/en/generative-ai/agent-framework/index.html).

## Setup

This notebook is designed to run in a Databricks notebook environment with access to Unity Catalog.

In [None]:
import os
import mlflow

# Load configuration from the local development config file. The same
# config.yml is later packaged with the logged model as its model_config.
configs = mlflow.models.ModelConfig(development_config="./config.yml")
databricks_config = configs.get("databricks")
agent_config = configs.get("agent")
tools_config = configs.get("tools")

# Databricks workspace / Unity Catalog settings.
CATALOG = databricks_config["catalog"]
SCHEMA = databricks_config["schema"]
UC_MODEL = databricks_config["model"]
WORKSPACE_URL = databricks_config["workspace_url"]
SQL_WAREHOUSE_ID = databricks_config["sql_warehouse_id"]
MLFLOW_EXPERIMENT_ID = databricks_config["mlflow_experiment_id"]

# Tool settings. UC functions are optional and are resolved in the same
# catalog.schema as the model.
UC_TABLES = tools_config["tables"]
UC_FUNCTIONS = tools_config.get("uc_functions", [])
UC_FUNCTIONS_SCHEMA = f"{CATALOG}.{SCHEMA}"
UC_CONNECTION = tools_config["uc_connection"]["name"]
LLM_ENDPOINT = agent_config["llm"]["endpoint"]

# Secret scope/key holding the Databricks PAT. Index the section directly
# (rather than chaining .get()) so a missing "databricks_pat" entry fails with
# a KeyError naming the key instead of an AttributeError on None.
databricks_pat_config = databricks_config["databricks_pat"]
SECRET_SCOPE_NAME = databricks_pat_config["secret_scope_name"]
SECRET_KEY_NAME = databricks_pat_config["secret_key_name"]

# Expose the workspace URL and token to the agent code via environment
# variables. `dbutils` is provided by the Databricks notebook runtime.
os.environ["DB_MODEL_SERVING_HOST_URL"] = WORKSPACE_URL
os.environ["DATABRICKS_TOKEN"] = dbutils.secrets.get(
    scope=SECRET_SCOPE_NAME, key=SECRET_KEY_NAME
)

In [None]:
import mlflow

# Use Unity Catalog as the model registry and the Databricks workspace as the
# tracking server.
mlflow.set_registry_uri("databricks-uc")
mlflow.set_tracking_uri("databricks")

try:
    # Confirm the configured experiment exists before making it active.
    mlflow.get_experiment(experiment_id=MLFLOW_EXPERIMENT_ID)
    mlflow.set_experiment(experiment_id=MLFLOW_EXPERIMENT_ID)
    print(f"Set to existing experiment: {MLFLOW_EXPERIMENT_ID}")
except mlflow.exceptions.RestException as e:
    # Only a "does not exist" error is tolerated; anything else propagates
    # with its original traceback.
    if "does not exist" not in str(e):
        raise
    # NOTE(review): execution continues here, so later runs are logged to
    # whatever experiment is currently active. Consider raising instead.
    print(f"Experiment {MLFLOW_EXPERIMENT_ID} not found. Must create one first.")

## Load & Test Agent

As you develop and debug the agent, open the MLflow experiment to inspect the trace data.

In [None]:
from agent import AGENT

In [None]:
# Questions used to exercise the text-to-SQL agent interactively.
sample_questions = [
    "What tables are available?",
    "Show me the schema for the balance_sheet table.",
    "What were the annual net income over the last 10 years for AAPL?",
    "Compare annual total assets over the last 10 years between Apple and Bank of America.",
]

# A single-turn request built from the first sample question. It is reused as
# the model's input example when logging.
input_example = dict(
    input=[dict(role="user", content=sample_questions[0])],
)

In [None]:
# Smoke-test the non-streaming path. model_dump(exclude_none=True) drops
# fields that are None to keep the printed response compact.
result = AGENT.predict(input_example)
result_payload = result.model_dump(exclude_none=True)
print(result_payload)

In [None]:
# Smoke-test the streaming path. Print each streamed event followed by a
# separator line.
stream_separator = "-----------\n"
for streamed_event in AGENT.predict_stream(input_example):
    print(streamed_event, stream_separator)

## Log the Agent as an MLflow Model

In [None]:
from mlflow.models.resources import (
    DatabricksFunction,
    DatabricksSQLWarehouse,
    DatabricksServingEndpoint,
    DatabricksTable,
    DatabricksUCConnection,
)

# Databricks resources the agent depends on. The list is passed to
# mlflow.pyfunc.log_model(resources=...) when the agent is logged.
resources = [
    DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT),
    DatabricksSQLWarehouse(warehouse_id=SQL_WAREHOUSE_ID),
    DatabricksUCConnection(connection_name=UC_CONNECTION),
]

# UC functions from config, qualified with the catalog.schema they live in.
# Reuse UC_FUNCTIONS from the setup cell instead of re-reading the config, so
# both places always agree.
resources.extend(
    DatabricksFunction(function_name=f"{UC_FUNCTIONS_SCHEMA}.{function_name}")
    for function_name in UC_FUNCTIONS
)

# Tables the agent may query, taken as-is from config.
resources.extend(DatabricksTable(table_name=table_name) for table_name in UC_TABLES)

print("Resources:", resources)

In [None]:
# Log the agent as a models-from-code pyfunc model:
#   - agent.py is the model definition,
#   - config.yml is attached as the model_config,
#   - system_prompt.md is packaged via code_paths (presumably read by agent.py).
# All paths are resolved against the notebook's working directory.
notebook_dir = os.getcwd()

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent",
        python_model=os.path.join(notebook_dir, "agent.py"),
        model_config=os.path.join(notebook_dir, "config.yml"),
        code_paths=[os.path.join(notebook_dir, "system_prompt.md")],
        input_example=input_example,
        resources=resources,
        pip_requirements=["-r ../requirements.txt"],
    )

print(f"Logged agent: {logged_agent_info.model_uri}")

## Evaluate the Agent with MLflow

Load the evaluation questions and use MLflow's GenAI scorers to evaluate the agent's performance.

In [None]:
import json

# Load the evaluation dataset passed as `data` to mlflow.genai.evaluate in the
# next cell (presumably a list of records; verify against the file). The file
# is required: a missing file raises FileNotFoundError here rather than
# failing later during evaluation.
evals_json_path = "./evals/eval-questions.json"

# Read as UTF-8 explicitly so results don't depend on the platform's default
# encoding.
with open(evals_json_path, "r", encoding="utf-8") as f:
    eval_dataset_list = json.load(f)

In [None]:
import mlflow
from mlflow.genai.scorers import Correctness, RelevanceToQuery, Safety

# Built-in LLM-judge scorers applied to every evaluation record.
genai_scorers = [Correctness(), RelevanceToQuery(), Safety()]

# mlflow.genai.evaluate calls predict_fn with each record's inputs as keyword
# arguments. The lambda parameter is therefore named after the input key
# (`input`), even though it shadows the builtin.
eval_results = mlflow.genai.evaluate(
    data=eval_dataset_list,
    predict_fn=lambda input: AGENT.predict({"input": input}),
    scorers=genai_scorers,
)

print("Evaluation complete. Check MLflow UI for detailed results.")

## Run Pre-Deployment Agent Validation

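Before deploying, load the logged model in an isolated `uv`-managed environment and run a sample request through it. This confirms the model loads and responds correctly outside the notebook.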

In [None]:
# Run one request through the logged model in a fresh uv-managed environment
# to catch missing dependencies or packaging errors before registration.
# Use the model URI returned by log_model (the same one registered below)
# rather than a hand-built runs:/ URI, and reuse the logged input example.
mlflow.models.predict(
    model_uri=logged_agent_info.model_uri,
    input_data=input_example,
    env_manager="uv",
)

## Register the Model to Unity Catalog

In [None]:
# Three-level Unity Catalog name: <catalog>.<schema>.<model>.
FULL_UC_MODEL_NAME = f"{CATALOG}.{SCHEMA}.{UC_MODEL}"

# Register the logged model as a new version under the UC model name.
uc_registered_model_info = mlflow.register_model(
    name=FULL_UC_MODEL_NAME,
    model_uri=logged_agent_info.model_uri,
)

registered_version = uc_registered_model_info.version
print(f"Registered model: {FULL_UC_MODEL_NAME} (version {registered_version})")

## Deploy the Agent

Deploy the agent to a Model Serving endpoint for production use.

In [None]:
from databricks import agents

# Deploy the registered model version to a Model Serving endpoint. The PAT is
# passed as a secret reference ({{secrets/<scope>/<key>}}), not as a literal
# token value.
deployment = agents.deploy(
    FULL_UC_MODEL_NAME,
    uc_registered_model_info.version,
    tags={"endpointSource": "docs"},
    environment_vars={
        # Doubled braces in the f-string render as literal "{{" and "}}".
        "DATABRICKS_TOKEN": f"{{{{secrets/{SECRET_SCOPE_NAME}/{SECRET_KEY_NAME}}}}}"
    },
)

print(f"Deployed {FULL_UC_MODEL_NAME} version {uc_registered_model_info.version}")

# Serving endpoint names cannot contain dots, so the endpoint agents.deploy
# creates is not named after the UC model. Link to the endpoint it reports.
# Also strip any trailing slash so the URL has no "//".
endpoint_url = f"{WORKSPACE_URL.rstrip('/')}/ml/endpoints/{deployment.endpoint_name}"
print(f"Access the endpoint at: {endpoint_url}")

## Next Steps

* Test the agent endpoint in the AI Playground or the Review App
* Continue to iterate on the agent based on evaluation results
* Create more comprehensive evaluation datasets
* Monitor agent performance using Inference Tables
* Set up alerts for query failures or performance issues