# Personalized Marketing Agent

This notebook guides you through building an AI agent that creates targeted social media campaigns for hotels.

> **✨ Improved Version**: This notebook now includes **automatic resource discovery**, so you only copy the detected values into the widget form once instead of looking them up by hand.

## Setup Overview

This notebook will:
1. **Auto-discover** your LLM endpoint, SQL Warehouse, Unity Catalog, and database
2. **Create** Unity Catalog functions as agent tools
3. **Build** a LangGraph-based agent with MLflow ResponsesAgent interface
4. **Test** the agent locally
5. **Register & Deploy** to Model Serving

## ⚠️ Important: Run Cells in Order

Execute cells sequentially from top to bottom. Most configuration is automatic!

## Step 0: Initial Setup & Configuration

This step will automatically configure your notebook by:
1. Creating configuration widgets
2. Auto-detecting your LLM endpoint, SQL Warehouse, Catalog, and Database
3. Displaying detected values for you to copy into widgets
4. Setting the catalog context

💡 **Important**: You'll need to **copy detected values into the widget form** at the top (one-time setup).  
These widget values will persist across Python restarts!

Run the cells below in order.


In [0]:
# === 0.1: Create Configuration Widgets ===

# Create empty widgets for configuration (these will persist across Python restarts)
dbutils.widgets.text("llm_endpoint", "")
dbutils.widgets.text("catalog", "")
dbutils.widgets.text("database", "")
dbutils.widgets.text("warehouse_id", "")

# User input widgets
dbutils.widgets.text("hotel_to_promote", "")
dbutils.widgets.dropdown(
    "hotel_class", "Resort", ["Resort", "Extended Stay", "Luxury", "Economy", "Airport"]
)

print("✅ Configuration widgets created!")
print("📝 Check the form at the top of the notebook")
print()
print("⏭️  Continue to the next cell to auto-detect values...")

### 0.2: Auto-Detect Configuration

This cell will automatically discover your Databricks resources and display the values.

⚠️ **Important**: Due to a known widget caching issue, the form at the top may not auto-populate. If it doesn't, you will need to **manually copy the detected values** into the widget fields (one-time setup).
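The selection rules the next cell applies can be sketched as pure functions over plain name lists, which makes them easy to check without a workspace. This is a hypothetical refactoring: `pick_llm_endpoint` and `pick_schema` are illustrative names, and the keyword lists are copied from the cell rather than taken from any Databricks API.

```python
# Hypothetical pure-function version of the detection rules in cell 0.2.
PRIORITY_ENDPOINTS = ["databricks-meta-llama-3-3-70b-instruct"]
FALLBACK_KEYWORDS = ["gpt", "llama", "dbrx", "meta"]


def pick_llm_endpoint(endpoint_names: list[str]) -> str:
    """Prefer a known endpoint; otherwise take the first name containing a keyword."""
    for preferred in PRIORITY_ENDPOINTS:
        if preferred in endpoint_names:
            return preferred
    matches = [
        name for name in endpoint_names
        if any(keyword in name.lower() for keyword in FALLBACK_KEYWORDS)
    ]
    return matches[0] if matches else ""


def pick_schema(schema_names: list[str]) -> str:
    """Prefer a Tableflow-style 'lkc-' schema, else the first non-system schema."""
    user_schemas = [s for s in schema_names if s not in ("information_schema", "default")]
    tableflow = [s for s in user_schemas if s.startswith("lkc-")]
    if tableflow:
        return tableflow[0]
    return user_schemas[0] if user_schemas else ""


print(pick_llm_endpoint(["team-embeddings", "my-llama-endpoint"]))
print(pick_schema(["default", "information_schema", "lkc-abc123"]))
```

In the notebook you would feed these `[ep.name for ep in w.serving_endpoints.list()]` and `[s.name for s in w.schemas.list(catalog_name=catalog)]`.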


In [0]:
# === 0.2: Auto-Detect Configuration ===
from databricks.sdk import WorkspaceClient

print("üîç Auto-detecting Databricks resources...")
print()

try:
    w = WorkspaceClient()

    # 1. Detect LLM endpoint with priority order
    endpoints = list(w.serving_endpoints.list())
    endpoint_names = [ep.name for ep in endpoints]

    # Priority order: check for specific endpoints first
    priority_endpoints = [
        "databricks-meta-llama-3-3-70b-instruct",
    ]

    llm_endpoint = ""
    for preferred in priority_endpoints:
        if preferred in endpoint_names:
            llm_endpoint = preferred
            break

    # Fallback: if none of the priority endpoints exist, use original detection
    if not llm_endpoint:
        foundation_endpoints = [
            ep.name
            for ep in endpoints
            if any(x in ep.name.lower() for x in ["gpt", "llama", "dbrx", "meta"])
        ]
        llm_endpoint = foundation_endpoints[0] if foundation_endpoints else ""

    # 2. Detect SQL Warehouse
    warehouses = list(w.warehouses.list())
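    # state is optional in the SDK, so guard against None before reading .value
    running = [wh for wh in warehouses if wh.state is not None and wh.state.value == "RUNNING"]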
    available = running if running else warehouses
    warehouse_id = available[0].id if available else ""

    # 3. Detect Unity Catalog
    catalogs = list(w.catalogs.list())
    tf_cats = [
        c for c in catalogs if any(p in c.name for p in ["tfdb", "tf-dbx", "tableflow"])
    ]

    if tf_cats:
        catalog = tf_cats[0].name
        # 4. Detect Schema/Database
        schemas = list(w.schemas.list(catalog_name=catalog))
        user_schemas = [
            s for s in schemas if s.name not in ["information_schema", "default"]
        ]
        tf_schemas = [s for s in user_schemas if s.name.startswith("lkc-")]
        database = (
            tf_schemas[0].name
            if tf_schemas
            else (user_schemas[0].name if user_schemas else "")
        )
    else:
        catalog = ""
        database = ""

    # Display detected values
    print("=" * 80)
    print("✅ AUTO-DETECTED CONFIGURATION")
    print("=" * 80)
    print()
    print(f"llm_endpoint:  {llm_endpoint}")
    print(f"catalog:       {catalog}")
    print(f"database:      {database}")
    print(f"warehouse_id:  {warehouse_id}")
    print()
    print("=" * 80)
    print("üìã PLEASE COPY THESE VALUES INTO THE WIDGET FORM AT THE TOP")
    print("=" * 80)
    print()
    print("⚠️  Due to a Databricks widget caching issue, the form won't auto-update.")
    print(
        "   Please manually copy the 4 values above into the corresponding widget fields."
    )
    print()
    print("üí° After copying, these values will persist across Python restarts!")
    print()

    if not all([llm_endpoint, catalog, database]):
        print(
            "⚠️  Warning: Some values were not detected. Please verify your Terraform deployment."
        )

except Exception as e:
    print(f"❌ Error during detection: {e}")
    print()
    print("💡 You can manually enter values in the widget form at the top")

### 0.3: Verify Widget Values

This cell checks whether you've successfully copied the values into the widget form.


In [0]:
# === 0.3: Verify Widget Values ===

print("üîç Checking widget values...")
print()

# Read from widgets (user should have copied values)
llm_endpoint = dbutils.widgets.get("llm_endpoint")
catalog = dbutils.widgets.get("catalog")
database = dbutils.widgets.get("database")
warehouse_id = dbutils.widgets.get("warehouse_id")

print("=" * 80)
print("üìã CURRENT WIDGET CONFIGURATION")
print("=" * 80)
print()
print(f"llm_endpoint:  {llm_endpoint or '(empty - please copy from cell above)'}")
print(f"catalog:       {catalog or '(empty - please copy from cell above)'}")
print(f"database:      {database or '(empty - please copy from cell above)'}")
print(f"warehouse_id:  {warehouse_id or '(empty - please copy from cell above)'}")
print()
print("=" * 80)

# Validate
if all([llm_endpoint, catalog, database]):
    print("✅ Configuration complete!")
    print()
    print(
        "💡 These values are now stored in widgets and will persist across Python restarts"
    )
else:
    print("⚠️  Configuration incomplete!")
    print()
    print(
        "📋 Please go back to cell 0.2 and copy the detected values into the widget form"
    )
    print("   (Look for the form at the very top of the notebook)")

print()
print("=" * 80)

### 0.4: Set Catalog Context

This cell sets the Spark SQL context using your configured catalog and database.
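Cell 0.4 below wraps both names in backticks because names such as `lkc-...` contain hyphens, which generally need quoting in Spark SQL. If a name could itself contain a backtick, Spark SQL escapes it by doubling it. A minimal sketch with a hypothetical `quote_identifier` helper (the catalog and schema values are made up):

```python
def quote_identifier(name: str) -> str:
    """Quote a catalog or schema name for Spark SQL (hypothetical helper)."""
    if not name:
        raise ValueError("identifier must be non-empty")
    # A backtick inside a quoted identifier is escaped by doubling it
    return "`" + name.replace("`", "``") + "`"


def use_statements(catalog: str, database: str) -> list[str]:
    """Build the two statements that cell 0.4 passes to spark.sql()."""
    return [
        f"USE CATALOG {quote_identifier(catalog)}",
        f"USE DATABASE {quote_identifier(database)}",
    ]


for stmt in use_statements("tfdb-demo", "lkc-abc123"):
    print(stmt)
```

In the notebook, each statement would be passed to `spark.sql(stmt)`.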


In [0]:
# === 0.4: Set Catalog Context ===

# Get configuration from widgets
catalog = dbutils.widgets.get("catalog")
database = dbutils.widgets.get("database")

if not catalog or not database:
    print("❌ ERROR: Catalog and database must be set!")
    print()
    print("üí° Please:")
    print("   1. Go back to cell 0.2")
    print("   2. Copy the detected values into the widget form at the top")
    print("   3. Re-run this cell")
    raise ValueError("Configuration incomplete: catalog and database required")

# Set Spark context
spark.sql(f"USE CATALOG `{catalog}`")
spark.sql(f"USE DATABASE `{database}`")

print("✅ Catalog context set successfully!")
print()
print(f"   Using: {catalog}.{database}")
print()
print("=" * 80)
print("üéâ Step 0 Complete! Configuration is ready.")
print("=" * 80)
print()
print("üí° Widget values will persist even after Python restarts!")

## Step 1: Create Unity Catalog Functions (Agent Tools)

These SQL functions become the "tools" that your agent can use. We'll create three functions:

1. **get_hotel_to_promote** - Finds the hotel in a class with the fewest total bookings among those with at least 3 reviews and an above-average rating
2. **summarize_customer_reviews** - Generates a summary of why customers like a hotel
3. **identify_target_customers** - Finds potential customers with high interest but few bookings
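The selection rule behind `get_hotel_to_promote` can be restated in plain Python. This is a sketch over made-up rows; `HotelStats` is a hypothetical stand-in whose fields mirror the `hotel_stats` columns. One detail worth noticing: the rating threshold is the average across all hotels, not just the requested class, because the SQL subquery has no class filter.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class HotelStats:
    """Hypothetical in-memory stand-in for one row of hotel_stats."""
    hotel_name: str
    hotel_class: str
    average_review_rating: float
    review_count: int
    total_bookings_count: int


def hotel_to_promote(rows: list[HotelStats], hotel_class: str) -> Optional[HotelStats]:
    if not rows:
        return None
    # Like the SQL subquery, the threshold averages ALL hotels, not just this class
    overall_avg = sum(r.average_review_rating for r in rows) / len(rows)
    candidates = [
        r for r in rows
        if r.hotel_class == hotel_class
        and r.review_count > 2  # "at least 3 reviews"
        and r.average_review_rating > overall_avg
    ]
    # ORDER BY TOTAL_BOOKINGS_COUNT ASC LIMIT 1
    return min(candidates, key=lambda r: r.total_bookings_count, default=None)


rows = [
    HotelStats("Hotel A", "Resort", 4.5, 10, 50),
    HotelStats("Hotel B", "Resort", 4.6, 3, 12),
    HotelStats("Hotel C", "Resort", 4.8, 2, 1),  # excluded: only 2 reviews
    HotelStats("Hotel D", "Economy", 3.0, 20, 5),
]
print(hotel_to_promote(rows, "Resort").hotel_name)  # Hotel B
```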

### 1.1: Create get_hotel_to_promote function

In [0]:
%sql
-- This function takes a hotel class as an input and finds the lowest performing hotel in that class that has at least 3 customer reviews and has an above-average customer satisfaction rating

CREATE OR REPLACE FUNCTION get_hotel_to_promote (input_hotel_class STRING COMMENT 'Hotel class to filter by')
RETURNS TABLE (
  hotel_id STRING,
  hotel_name STRING,
  hotel_city STRING,
  hotel_country STRING,
  hotel_description STRING,
  hotel_class STRING,
  average_review_rating DOUBLE,
  review_count INT
)
LANGUAGE SQL
COMMENT 'This function takes a hotel class as an input and finds the lowest performing hotel in that class that has at least 3 customer reviews and has an above-average customer satisfaction rating'
RETURN
  (SELECT
    `HOTEL_ID`,
    `HOTEL_NAME`,
    `HOTEL_CITY`,
    `HOTEL_COUNTRY`,
    `HOTEL_DESCRIPTION`,
    `HOTEL_CLASS`,
    `AVERAGE_REVIEW_RATING`,
    `REVIEW_COUNT`
    FROM
      hotel_stats
    WHERE HOTEL_CLASS = input_hotel_class
    AND `REVIEW_COUNT` > 2
    AND `AVERAGE_REVIEW_RATING` > (
      SELECT
        AVG(`AVERAGE_REVIEW_RATING`)
      FROM
        hotel_stats
    )
    ORDER BY
      `TOTAL_BOOKINGS_COUNT` ASC
    LIMIT 1
  )

#### Test the function

Verify that the `get_hotel_to_promote` function returns a hotel.


In [0]:
%sql
-- Test the get_hotel_to_promote function
-- Expected: Returns one hotel with good reviews and low bookings

SELECT * FROM get_hotel_to_promote(:hotel_class);

-- ✅ Success: You should see one hotel row with details
-- 📝 Note: Copy the hotel_id into the hotel_to_promote widget to test summarize_customer_reviews in 1.2
-- ⚠️  No results? Try a different hotel_class value in the widget above

### 1.2: Create summarize_customer_reviews function

This function uses AI_GEN to analyze customer reviews and extract top reasons why guests enjoyed their stay.
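A minimal Python sketch of the prompt the function assembles, using made-up reviews. It mirrors the SQL: reviews rated below 3 are dropped, the rest are joined with `" - "`, and the whole prompt (prefix included) is cut to 80,000 characters. In SQL, `COLLECT_LIST` does not guarantee an order, so the reviews may appear in any order and truncation may drop different reviews on different runs; the sketch simply keeps input order.

```python
PROMPT_PREFIX = (
    "Extract the top 3 reasons people like the hotels based on this list of reviews:"
)
MAX_PROMPT_CHARS = 80_000  # SUBSTRING(..., 1, 80000) keeps the first 80,000 characters


def build_review_prompt(reviews: list[tuple[str, int]]) -> str:
    """Rebuild the AI_GEN prompt from (review_text, review_rating) pairs for one hotel."""
    # REVIEW_RATING >= 3 filters out likely-negative reviews
    kept = [text for text, rating in reviews if rating >= 3]
    # The SQL concatenates with || directly, so no space follows the colon
    prompt = PROMPT_PREFIX + " - ".join(kept)
    return prompt[:MAX_PROMPT_CHARS]


print(build_review_prompt([("Great pool", 5), ("Noisy room", 1), ("Friendly staff", 4)]))
```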


In [0]:
%sql
-- This statement creates a function that takes a HOTEL_ID as input and generates a summary of the top 3 reasons why customers enjoyed their hotel stay.

CREATE OR REPLACE FUNCTION
summarize_customer_reviews(input_hotel_id STRING COMMENT 'ID of the hotel to be searched')
RETURNS STRING
LANGUAGE SQL
COMMENT 'This function takes a HOTEL_ID as input and generates a summary of the top 3 reasons why customers enjoyed their hotel stay'
RETURN (
  SELECT AI_GEN(
    SUBSTRING('Extract the top 3 reasons people like the hotels based on this list of reviews:' || ARRAY_JOIN(COLLECT_LIST(REVIEW_TEXT), ' - '), 1, 80000)
  ) AS all_reviews
  FROM denormalized_hotel_bookings
    WHERE `HOTEL_ID` = input_hotel_id
    -- Try to exclude negative reviews
    AND `REVIEW_RATING` >= 3
)

#### Test out summarize_customer_reviews

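Set the `hotel_to_promote` widget to a hotel ID (for example, the `hotel_id` returned by the `get_hotel_to_promote` test), then verify that `summarize_customer_reviews` returns a summary of that hotel's reviews.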

In [0]:
%sql
-- Try out this function to see a summary of the top 3 reasons customers liked a hotel

SELECT summarize_customer_reviews(:hotel_to_promote)

### 1.3: Create identify_target_customers function

This function finds the top 10 customers who have the fewest bookings but show the most interest (via page views and clicks) for a given hotel class. These are prime candidates for targeted marketing campaigns.
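The ordering inside `identify_target_customers` is easy to misread in SQL, so here is a plain-Python sketch of the same ranking over hypothetical `(customer_email, action)` pairs that are assumed to be already restricted to the hotel class. It counts each clickstream action once per customer. Ties are broken by input order here, whereas `ROW_NUMBER()` breaks them arbitrarily.

```python
from collections import Counter, defaultdict

TRACKED_ACTIONS = ("page-view", "page-click", "booking-click")


def rank_target_customers(actions: list[tuple[str, str]], limit: int = 10) -> list[dict]:
    counts: dict[str, Counter] = defaultdict(Counter)
    for email, action in actions:
        if action in TRACKED_ACTIONS:
            counts[email][action] += 1
    rows = [
        {
            "customer_email": email,
            "page_views": c["page-view"],
            "page_clicks": c["page-click"],
            "bookings": c["booking-click"],
        }
        for email, c in counts.items()
    ]
    # Fewest booking clicks first, then most page views, then most page clicks
    rows.sort(key=lambda r: (r["bookings"], -r["page_views"], -r["page_clicks"]))
    return rows[:limit]


actions = [
    ("a@example.com", "page-view"), ("a@example.com", "page-view"),
    ("a@example.com", "booking-click"),
    ("b@example.com", "page-view"), ("b@example.com", "page-click"),
    ("c@example.com", "page-view"), ("c@example.com", "page-view"),
    ("c@example.com", "page-view"),
]
print([r["customer_email"] for r in rank_target_customers(actions)])
```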

In [0]:
%sql
-- This function finds the top 10 customers who transacted the fewest bookings but showed the most interest (via page-views and page-clicks) for a given hotel class

CREATE OR REPLACE FUNCTION identify_target_customers (input_hotel_class STRING COMMENT 'Hotel class to filter by')
RETURNS TABLE (
  customer_email STRING,
  page_views INT,
  page_clicks INT,
  bookings INT
)
LANGUAGE SQL
COMMENT 'This function finds the top 10 customers who transacted the fewest bookings but showed the most interest (via page-views and page-clicks) for a given hotel class'
RETURN
  (
WITH filtered_clickstream AS (
  SELECT
    `CUSTOMER_EMAIL`,
    `ACTION`
  FROM
    `clickstream`
  WHERE
    `ACTION` IN ('page-view', 'page-click', 'booking-click')
),
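filtered_dhb AS (
  -- DISTINCT keeps one row per customer so the join below does not multiply clickstream actions
  SELECT DISTINCT
    `CUSTOMER_EMAIL`
  FROM
    `denormalized_hotel_bookings`
  WHERE
    `HOTEL_CLASS` = input_hotel_class
),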
joined_table AS (
  SELECT
    a.`CUSTOMER_EMAIL`,
    a.`ACTION`
  FROM
    filtered_clickstream a
      JOIN filtered_dhb b
        ON a.`CUSTOMER_EMAIL` = b.`CUSTOMER_EMAIL`
),
ranked_customers AS (
  SELECT
    `CUSTOMER_EMAIL`,
    COUNT(
      CASE
        WHEN `ACTION` = 'page-view' THEN 1
      END
    ) AS `page_views`,
    COUNT(
      CASE
        WHEN `ACTION` = 'page-click' THEN 1
      END
    ) AS `page_clicks`,
    COUNT(
      CASE
        WHEN `ACTION` = 'booking-click' THEN 1
      END
    ) AS `bookings`,
    ROW_NUMBER() OVER (
        ORDER BY
          COUNT(
            CASE
              WHEN `ACTION` = 'booking-click' THEN 1
            END
          ) ASC,
          COUNT(
            CASE
              WHEN `ACTION` = 'page-view' THEN 1
            END
          ) DESC,
          COUNT(
            CASE
              WHEN `ACTION` = 'page-click' THEN 1
            END
          ) DESC
      ) AS `rank`
  FROM
    joined_table
  GROUP BY
    `CUSTOMER_EMAIL`
)
SELECT
  `CUSTOMER_EMAIL`,
  `page_views`,
  `page_clicks`,
  `bookings`
FROM
  ranked_customers
WHERE
  `rank` <= 10
ORDER BY
  `rank`
)

#### Validate the identify_target_customers function

Run this to check that the function returns up to 10 customers with their email addresses and activity counts.


In [0]:
%sql
-- Test out the identify_target_customers function

SELECT * FROM identify_target_customers(:hotel_class);

---

## Step 2: Install Dependencies and Create Agent

This step installs required packages, creates the `agent.py` file, and restarts Python to make the module importable.


### 2.1: Install Dependencies

Install the required Python packages for the LangGraph agent.

⚠️ **Important**: After this cell runs, Python will restart. Continue to the next cell after the restart.



In [0]:
%pip install -U -qqqq databricks-langchain langgraph==0.5.3 uv databricks-agents mlflow-skinny[databricks]
dbutils.library.restartPython()

#### Restore Variables from Widgets

After a Python restart, all variables in memory are cleared. However, `dbutils.widgets` values persist across restarts, so we can retrieve our configuration by reading from the widgets you populated in Step 0.
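Several cells below repeat the same read-and-validate pattern. A hedged sketch of a hypothetical `load_config` helper that takes the widget getter as a parameter (so it can be exercised without `dbutils`) and fails fast listing every empty widget; the `.strip()` call is an added design choice to tolerate whitespace pasted by hand.

```python
from typing import Callable

REQUIRED_WIDGETS = ("catalog", "database", "llm_endpoint")


def load_config(get_widget: Callable[[str], str]) -> dict[str, str]:
    """Read the required widget values and fail fast if any are empty."""
    # .strip() guards against stray whitespace pasted into a widget by hand
    config = {name: (get_widget(name) or "").strip() for name in REQUIRED_WIDGETS}
    missing = [name for name, value in config.items() if not value]
    if missing:
        raise ValueError("Fill these widgets in Step 0 first: " + ", ".join(missing))
    return config


# Stand-in for dbutils.widgets.get; in the notebook: load_config(dbutils.widgets.get)
fake_widgets = {"catalog": " tfdb-demo ", "database": "lkc-abc123", "llm_endpoint": "my-llm"}
print(load_config(fake_widgets.get))
```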


In [0]:
# After Python restart, variables are cleared BUT widgets persist! ✨
print("🔄 Retrieving configuration from widgets (they persist across restarts)...")
print()

# Get configuration from widgets (these survived the restart)
catalog = dbutils.widgets.get("catalog")
database = dbutils.widgets.get("database")
llm_endpoint = dbutils.widgets.get("llm_endpoint")

print("üìã Current configuration:")
print(f"  Catalog:      {catalog or '(empty - please fill widgets)'}")
print(f"  Database:     {database or '(empty - please fill widgets)'}")
print(f"  LLM Endpoint: {llm_endpoint or '(empty - please fill widgets)'}")

if not all([catalog, database, llm_endpoint]):
    print()
    print("⚠️  Please go back to Step 0 and fill the widget form before continuing")
else:
    print()
    print("✅ Ready to create agent.py!")

### 2.2: Write the Agent Code

This cell creates an `agent.py` file that defines a LangGraph-based conversational agent.

**Key Components:**
- **LLM**: Uses ChatDatabricks with your auto-detected endpoint
- **Tools**: Loads the 3 Unity Catalog functions you created in Step 1
- **System Prompt**: Instructs the agent to create targeted hotel marketing campaigns
- **Architecture**: LangGraph StateGraph with tool-calling workflow
- **MLflow Integration**: Wrapped in ResponsesAgent for deployment compatibility
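The configuration reaches `agent.py` through f-string interpolation: `{name}` is replaced with the widget value, while literal braces in the generated code (dict literals, for example) must be written as `{{` and `}}`. A minimal sketch of that mechanism with a hypothetical endpoint name:

```python
# Sketch of how cell 2.2 bakes widget values into agent.py.
# In an f-string, {name} is substituted, while {{ and }} produce literal braces.
llm_endpoint = "my-llm-endpoint"  # hypothetical; the notebook reads this from a widget

agent_source = f'''LLM_ENDPOINT_NAME = "{llm_endpoint}"
SYSTEM_MESSAGE = {{"role": "system", "content": "..."}}
'''

print(agent_source)

# The generated text should be valid Python: compiling it catches mistakes early
compile(agent_source, "agent.py", "exec")
```

Because the value is placed between double quotes verbatim, a value containing a double quote or backslash would break the generated file; interpolating `{json.dumps(llm_endpoint)}` instead would sidestep that, at the cost of a slightly less readable template.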

**How It Works:**
1. Receives user queries (e.g., "Find low-performing hotels in the Resort class")
2. Calls the appropriate UC functions to retrieve data
3. Uses the LLM to synthesize results into marketing campaigns

✅ **No manual editing needed** - configuration is automatically loaded from Step 0!

In [0]:
# === 2.2: Write the Agent Code (with dynamic configuration) ===

# Get configuration from widgets
catalog = dbutils.widgets.get("catalog")
database = dbutils.widgets.get("database")
llm_endpoint = dbutils.widgets.get("llm_endpoint")

# Validate configuration
if not all([catalog, database, llm_endpoint]):
    print("❌ Configuration incomplete!")
    print(f"  catalog:      {catalog or '(empty)'}")
    print(f"  database:     {database or '(empty)'}")
    print(f"  llm_endpoint: {llm_endpoint or '(empty)'}")
    print()
    print("💡 Please go back to Step 0 and fill the widget form")
    raise ValueError("Configuration incomplete")

# Build fully qualified function names
uc_function_get_hotel = f"{catalog}.{database}.get_hotel_to_promote"
uc_function_identify_customers = f"{catalog}.{database}.identify_target_customers"
uc_function_summarize_reviews = f"{catalog}.{database}.summarize_customer_reviews"

print("✅ Configuration loaded from widgets")
print()
print(f"  Catalog:      {catalog}")
print(f"  Database:     {database}")
print(f"  LLM Endpoint: {llm_endpoint}")
print()
print("üìù Creating agent.py with your configuration...")

# Create agent.py with interpolated values
agent_code = f'''import os

import json
from typing import Annotated, Any, Generator, Optional, Sequence, TypedDict, Union
from uuid import uuid4

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    UCFunctionToolkit,
    VectorSearchRetrieverTool,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    convert_to_openai_messages,
)
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.entities import SpanType
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
)

############################################
# CONFIGURATION (Auto-injected from notebook widgets)
############################################
LLM_ENDPOINT_NAME = "{llm_endpoint}"
CATALOG = "{catalog}"
DATABASE = "{database}"

# Build fully qualified Unity Catalog function names
UC_TOOL_NAMES = [
    "{uc_function_get_hotel}",
    "{uc_function_identify_customers}",
    "{uc_function_summarize_reviews}"
]

############################################
# Define LLM and system prompt
############################################
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME, temperature=0.0)

# Review and modify as needed
system_prompt = """You are a helpful assistant for a global hotel company. Your task is to assist the marketing leadership in understanding and acting on their products and sales metrics. You have access to functions that can retrieve and analyze relevant data.

    You have these three main tasks:

    1. Determine which hotel to promote (the 'selected hotel') based on the hotel_class parameter. Use the available function to identify which hotel is underperforming in that class and should be promoted.

    2. Craft a positive social marketing post to promote the selected hotel. Use customer review summaries and hotel descriptions to create compelling content. Mention the hotel by name, highlight positive aspects, but do not mention poor performance or flaws.

    3. Create a list of potential target customers to send the marketing post to. Use the available function to identify customers who might be interested based on their preferences and history.

    Format the results of all three tasks into a single cohesive output.

    Follow these guidelines:
    1. Use the available functions at each step and ensure results are retrieved before proceeding.
    2. Provide clear, coherent responses without mentioning the underlying function names.
    3. Do not reference the hotel by its ID.
    4. Mention the hotel's city and country.
    5. Answer only what the user asks for, no unnecessary information.
    6. If asked to generate social media posts, first determine what customers like most to ensure relevance.
    """

###############################################################################
## Define tools for your agent
###############################################################################
tools = []

# Unity Catalog functions as agent tools
uc_toolkit = UCFunctionToolkit(function_names=UC_TOOL_NAMES)
tools.extend(uc_toolkit.tools)

# Vector search tools (optional)
VECTOR_SEARCH_TOOLS = []
tools.extend(VECTOR_SEARCH_TOOLS)

#####################
## Define agent logic
#####################

class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]
    custom_inputs: Optional[dict[str, Any]]
    custom_outputs: Optional[dict[str, Any]]


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolNode, Sequence[BaseTool]],
    system_prompt: Optional[str] = None,
):
    model = model.bind_tools(tools)

    # Define the function that determines which node to go to
    def should_continue(state: AgentState):
        messages = state["messages"]
        last_message = messages[-1]
        # If there are function calls, continue. else, end
        if isinstance(last_message, AIMessage) and last_message.tool_calls:
            return "continue"
        else:
            return "end"

    if system_prompt:
        preprocessor = RunnableLambda(
            lambda state: [{{"role": "system", "content": system_prompt}}] + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    def call_model(
        state: AgentState,
        config: RunnableConfig,
    ):
        response = model_runnable.invoke(state, config)
        return {{"messages": [response]}}

    workflow = StateGraph(AgentState)

    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ToolNode(tools))

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {{
            "continue": "tools",
            "end": END,
        }},
    )
    workflow.add_edge("tools", "agent")

    return workflow.compile()


class LangGraphResponsesAgent(ResponsesAgent):
    def __init__(self, agent):
        self.agent = agent

    def _langchain_to_responses(self, messages: list[BaseMessage]) -> list[dict[str, Any]]:
        "Convert LangChain messages to Responses output item dictionaries"
        # Collect items for every message: the tools node can return several
        # ToolMessages in one update when the model makes parallel tool calls
        items = []
        for message in messages:
            message = message.model_dump()
            role = message["type"]
            if role == "ai":
                if tool_calls := message.get("tool_calls"):
                    items.extend(
                        self.create_function_call_item(
                            id=message.get("id") or str(uuid4()),
                            call_id=tool_call["id"],
                            name=tool_call["name"],
                            arguments=json.dumps(tool_call["args"]),
                        )
                        for tool_call in tool_calls
                    )
                else:
                    # Handle both string content and list of content blocks
                    content = message["content"]
                    if isinstance(content, list):
                        # Extract text from content blocks (for structured responses)
                        text_parts = []
                        for block in content:
                            if isinstance(block, dict):
                                text_parts.append(block.get("text", ""))
                            else:
                                text_parts.append(str(block))
                        text = " ".join(text_parts)
                    else:
                        text = content

                    items.append(
                        self.create_text_output_item(
                            text=text,
                            id=message.get("id") or str(uuid4()),
                        )
                    )
            elif role == "tool":
                items.append(
                    self.create_function_call_output_item(
                        call_id=message["tool_call_id"],
                        output=message["content"],
                    )
                )
            elif role == "user":
                items.append(message)
        return items

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs)

    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        cc_msgs = self.prep_msgs_for_cc_llm([i.model_dump() for i in request.input])

        for event in self.agent.stream({{"messages": cc_msgs}}, stream_mode=["updates", "messages"]):
            if event[0] == "updates":
                for node_data in event[1].values():
                    for item in self._langchain_to_responses(node_data["messages"]):
                        yield ResponsesAgentStreamEvent(type="response.output_item.done", item=item)
            # filter the streamed messages to just the generated text messages
            elif event[0] == "messages":
                try:
                    chunk = event[1][0]
                    if isinstance(chunk, AIMessageChunk) and (content := chunk.content):
                        yield ResponsesAgentStreamEvent(
                            **self.create_text_delta(delta=content, item_id=chunk.id),
                        )
                except Exception as e:
                    print(e)


# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
mlflow.langchain.autolog()
agent = create_tool_calling_agent(llm, tools, system_prompt)
AGENT = LangGraphResponsesAgent(agent)
mlflow.models.set_model(AGENT)
'''

# Write the file
with open('agent.py', 'w') as f:
    f.write(agent_code)

print()
print("=" * 80)
print("✨ agent.py created with your configuration!")
print("=" * 80)
print()
print("üìã Configuration injected:")
print(f"  - LLM endpoint: {llm_endpoint}")
print(f"  - UC functions: {catalog}.{database}.*")
print()
print("⏭️  Continue to the next cell to restart Python...")

### 2.3: Restart Python

After writing `agent.py`, we restart the Python process so the new module is imported fresh.

**Why is this necessary?**
- The previous cell writes `agent.py` to disk with `open('agent.py', 'w')`
- Python caches modules it has already imported (in `sys.modules`), so if `agent` was imported earlier in this session, `from agent import AGENT` would keep returning the old version
- Import finders also cache directory listings, so in some cases a newly written file is not found until those caches are invalidated
- Restarting clears both caches and guarantees that the next import loads the current `agent.py`

⚠️ **Note**: After the restart, the next cell re-reads the configuration from the widgets you filled in Step 0.
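Restarting is the bluntest option. For comparison, here is a self-contained sketch of the module cache at work, with `importlib.reload` as a narrower alternative, using a throwaway module (`demo_agent_config` is a hypothetical name). Treat it as illustration only: `reload` refreshes a single module and does not update objects other code already imported from it, which makes a full restart the simpler choice here.

```python
import importlib
import sys
import tempfile
from pathlib import Path

sys.dont_write_bytecode = True  # demo only: skip .pyc files so reload always reads the source

# Create a throwaway module in a temporary directory
tmp_dir = Path(tempfile.mkdtemp())
sys.path.insert(0, str(tmp_dir))
module_path = tmp_dir / "demo_agent_config.py"

module_path.write_text('VALUE = "first"\n')
import demo_agent_config  # first import executes the file

module_path.write_text('VALUE = "second version"\n')
import demo_agent_config  # does not re-execute: the module is served from sys.modules
stale = demo_agent_config.VALUE

importlib.invalidate_caches()  # drop cached directory listings
fresh = importlib.reload(demo_agent_config).VALUE  # re-executes the updated file

print(f"{stale} -> {fresh}")
```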

In [0]:
# Restart Python to make the newly created agent.py module importable
# After this restart, we'll re-load configuration and import the agent
dbutils.library.restartPython()

In [0]:
import os

# After Python restart, variables are cleared BUT widgets persist! ✨
print("🔄 Retrieving configuration from widgets (they persist across restarts)...")
print()

# Get configuration from widgets (these survived the restart)
llm_endpoint = dbutils.widgets.get("llm_endpoint")
catalog = dbutils.widgets.get("catalog")
database = dbutils.widgets.get("database")

# Validate configuration
if not all([llm_endpoint, catalog, database]):
    print("❌ Configuration incomplete!")
    print(f"  llm_endpoint: {llm_endpoint or '(empty)'}")
    print(f"  catalog:      {catalog or '(empty)'}")
    print(f"  database:     {database or '(empty)'}")
    print()
    print("üí° Please go back to Step 0 and:")
    print("   1. Run cell 0.2 to auto-detect values")
    print("   2. Copy the values into the widget form at the top")
    print("   3. Come back and re-run this cell")
    raise ValueError("Configuration incomplete")

print("✅ Configuration retrieved from widgets")

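# agent.py already contains these values as constants (written in Step 2.2);
# these environment variables are set for reference only and are not read by agent.py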
os.environ["CATALOG"] = catalog
os.environ["DATABASE"] = database
os.environ["LLM_ENDPOINT"] = llm_endpoint

print()
print("üîß Configuration loaded for agent:")
print(f"  Catalog:      {catalog}")
print(f"  Database:     {database}")
print(f"  LLM Endpoint: {llm_endpoint}")
print()

# Import the agent
from agent import AGENT
print("✨ Agent ready!")


## Step 3: Initialize and Test the Agent

Now that `agent.py` is created (in Step 2), we'll import and test the agent locally before deploying.

**How the agent works:**
- **LangGraph** for agent orchestration
- **ResponsesAgent** interface for MLflow compatibility  
- **UCFunctionToolkit** to load your Unity Catalog functions as tools
- **Configuration** baked into `agent.py` as constants when Step 2.2 wrote the file:
  - `LLM_ENDPOINT_NAME` - Your LLM endpoint name
  - `CATALOG` - Unity Catalog name
  - `DATABASE` - Database/schema name

In the cells below, we can test the agent with sample queries using both **non-streaming** and **streaming** prediction methods.

The agent will automatically call the appropriate UC functions to answer user queries.

### 3.1: Non-Streaming Prediction

Returns the complete agent response in a single call after all processing (tool calls, LLM generation, etc.) is finished. Best for testing, debugging, and batch processing where you need the full result at once rather than incremental updates.
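The response contains tool-call items as well as the final answer. A hedged helper for pulling out just the assistant text, assuming output items follow the Responses shape that `create_text_output_item` produces (`type: "message"` with `output_text` content blocks); `final_text` is a hypothetical name and the sample items are made up.

```python
def final_text(output_items: list[dict]) -> str:
    """Join the text of assistant message items, skipping tool-call items."""
    parts = []
    for item in output_items:
        # function_call and function_call_output items carry no user-facing text
        if item.get("type") != "message":
            continue
        for block in item.get("content", []):
            if block.get("type") == "output_text":
                parts.append(block.get("text", ""))
    return "\n".join(parts)


sample = [
    {"type": "function_call", "name": "get_hotel_to_promote", "arguments": "{}", "call_id": "c1"},
    {"type": "function_call_output", "call_id": "c1", "output": "..."},
    {"type": "message", "role": "assistant",
     "content": [{"type": "output_text", "text": "Promote the example hotel."}]},
]
print(final_text(sample))
```

Assumed usage in the next cell: `print(final_text(result.model_dump(exclude_none=True)["output"]))`.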

In [0]:
from agent import AGENT

# Get hotel class from widget for dynamic testing
hotel_class = dbutils.widgets.get("hotel_class")

result = AGENT.predict(
    {
        "input": [
            {
                "role": "user",
                "content": f"Find the hotel with the fewest bookings and positive customer reviews in the {hotel_class} class.",
            }
        ]
    }
)
display(result.model_dump(exclude_none=True))

### 3.2: Streaming Prediction

Returns the response in chunks as it is generated, which gives a better experience in real-time applications.
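The streaming cell below prints only text deltas. If you also want to see which tools the agent called, a hypothetical `summarize_stream` helper can scan the dumped events. It assumes completed tool calls arrive as `response.output_item.done` events whose `item` has `type: "function_call"` and a `name`, as emitted by the agent's `predict_stream`; the sample events are made up.

```python
def summarize_stream(events: list[dict]) -> tuple[list[str], str]:
    """Return (tool names called, concatenated streamed text) from model_dump()'d events."""
    tools_called: list[str] = []
    text_parts: list[str] = []
    for event in events:
        if event.get("type") == "response.output_text.delta":
            text_parts.append(event.get("delta", ""))
        elif event.get("type") == "response.output_item.done":
            item = event.get("item", {})
            if item.get("type") == "function_call":
                tools_called.append(item.get("name", ""))
    return tools_called, "".join(text_parts)


events = [
    {"type": "response.output_item.done",
     "item": {"type": "function_call", "name": "get_hotel_to_promote", "arguments": "{}", "call_id": "c1"}},
    {"type": "response.output_text.delta", "item_id": "m1", "delta": "Promote "},
    {"type": "response.output_text.delta", "item_id": "m1", "delta": "Example Hotel"},
]
print(summarize_stream(events))
```

Assumed usage: `summarize_stream([c.model_dump(exclude_none=True) for c in AGENT.predict_stream(request)])`.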

In [0]:
from agent import AGENT

# Get hotel class from widget for dynamic testing
hotel_class = dbutils.widgets.get("hotel_class")

# Collect response text
response_text = ""

for chunk in AGENT.predict_stream(
    {
        "input": [
            {
                "role": "user",
                "content": f"Find the hotel with the fewest bookings and positive customer reviews in the {hotel_class} class.",
            }
        ]
    }
):
    chunk_data = chunk.model_dump(exclude_none=True)

    # Extract and display text deltas in real-time
    if chunk_data.get("type") == "response.output_text.delta":
        delta = chunk_data.get("delta", "")
        print(delta, end="", flush=True)
        response_text += delta

print("\n\n" + "=" * 80)
print("✅ Complete response:")
print("=" * 80)
print(response_text)

---

## Step 4: Register and Deploy Agent

Now we'll log, register, and deploy the agent to Databricks Model Serving.

### 4.1: Log the Agent to MLflow

Package the agent as an MLflow model for deployment.
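The cell below pins `langgraph` and `databricks-connect` to the versions installed in this session and fails if either is missing. A hedged alternative, sketched with a hypothetical `pinned_requirements` helper built on the standard-library `importlib.metadata`, pins what is installed and falls back to an unpinned name otherwise. Whether an unpinned fallback is acceptable is a judgment call; failing loudly may be preferable for production.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable


def pinned_requirements(packages: list[str],
                        lookup: Callable[[str], str] = version) -> list[str]:
    """Pin each installed package to its current version; leave others unpinned."""
    reqs = []
    for name in packages:
        try:
            reqs.append(f"{name}=={lookup(name)}")
        except PackageNotFoundError:
            reqs.append(name)  # not installed here: fall back to an unpinned requirement
    return reqs


# e.g. pip_requirements=pinned_requirements(["langgraph", "databricks-connect"])
```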

In [0]:
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
from agent import UC_TOOL_NAMES, VECTOR_SEARCH_TOOLS
import mlflow
from mlflow.models.resources import DatabricksFunction
# Read installed package versions with the standard library (pkg_resources is deprecated)
from importlib.metadata import version

resources = []
for tool in VECTOR_SEARCH_TOOLS:
    resources.extend(tool.resources)
for tool_name in UC_TOOL_NAMES:
    resources.append(DatabricksFunction(function_name=tool_name))

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent",
        python_model="agent.py",
        pip_requirements=[
            "databricks-langchain",
            f"langgraph=={version('langgraph')}",
            f"databricks-connect=={version('databricks-connect')}",
        ],
        resources=resources,
    )

### 4.2: Evaluate the Agent (Optional)

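Use MLflow's GenAI evaluation API (`mlflow.genai.evaluate`), which powers Mosaic AI Agent Evaluation, to score the agent's answers to a small set of sample queries. The cell below uses the built-in `RelevanceToQuery` and `Safety` LLM-judge scorers. It also records an expected response for each query, which reference-based scorers such as `Correctness` can use. Running this before deployment helps you catch quality problems early.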

See Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-evaluation) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-evaluation/)).
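`mlflow.genai.evaluate` passes each record's `inputs` dict to `predict_fn` as keyword arguments. That is why the evaluation lambda takes a parameter named `input`. The sketch below imitates only that calling convention in plain Python; it is not MLflow's implementation. `fake_agent_predict` and `run_predictions` are hypothetical stand-ins for `AGENT.predict` and the evaluation harness.

```python
from typing import Any, Callable


def fake_agent_predict(request: dict) -> dict:
    """Hypothetical stand-in for AGENT.predict: echoes the last user message."""
    last_message = request["input"][-1]["content"]
    return {"output_text": f"echo: {last_message}"}


def run_predictions(records: list, predict_fn: Callable[..., Any]) -> list:
    """Call predict_fn(**record["inputs"]) per record, mirroring how inputs are unpacked."""
    return [predict_fn(**record["inputs"]) for record in records]


records = [
    {
        "inputs": {"input": [{"role": "user", "content": "Which Airport hotel should we promote?"}]},
        "expectations": {"expected_response": "Names one Airport hotel with low bookings."},
    }
]

# The lambda's parameter name must match the key inside "inputs" ("input")
outputs = run_predictions(records, lambda input: fake_agent_predict({"input": input}))
print(outputs)
```

If the parameter name does not match the key, for example `lambda messages: ...`, the call fails with a `TypeError` about an unexpected keyword argument.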


In [0]:
from agent import AGENT

# Optional: Evaluate the agent with sample queries
import mlflow
from mlflow.genai.scorers import RelevanceToQuery, Safety

# Each record needs an "inputs" dict, which is passed to predict_fn as keyword arguments.
# Reference answers go under "expectations".
eval_dataset = [
    {
        "inputs": {
            "input": [
                {
                    "role": "user",
                    "content": "Which hotel in the Airport class should we promote this quarter?",
                }
            ]
        },
        "expectations": {
            "expected_response": "Should identify a specific Airport hotel with low bookings and provide its name, location, and reasoning."
        },
    },
    {
        "inputs": {
            "input": [
                {
                    "role": "user",
                    "content": "Find customers who might be interested in a Resort hotel promotion",
                }
            ]
        },
        "expectations": {
            "expected_response": "Should provide a list of customer IDs or profiles who have shown interest in resorts."
        },
    },
]

# evaluate() calls predict_fn(input=[...]), so the lambda parameter must be named "input"
eval_results = mlflow.genai.evaluate(
    data=eval_dataset,
    predict_fn=lambda input: AGENT.predict({"input": input}),
    scorers=[RelevanceToQuery(), Safety()],
)

# Review the evaluation results in the MLflow UI (see console output)

### 4.3: Pre-deployment Validation

Before registering and deploying, validate that the packaged model works correctly using the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API.

See Databricks documentation ([AWS](https://docs.databricks.com/en/machine-learning/model-serving/model-serving-debug.html#validate-inputs) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/machine-learning/model-serving/model-serving-debug#before-model-deployment-validation-checks)).
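`mlflow.models.predict` with `env_manager="uv"` builds a fresh environment, which can take a while. A quick local check of the request shape can catch malformed payloads first. Below is a minimal sketch using a hypothetical helper, `validate_responses_request`. It assumes plain-string message content, as in every request in this notebook. It also assumes the `user`/`assistant`/`system`/`developer` roles of the OpenAI Responses format. Real requests may also carry structured content parts, which this sketch would reject.

```python
def validate_responses_request(payload: dict) -> list:
    """Return problems found in a ResponsesAgent-style request; an empty list means OK."""
    messages = payload.get("input")
    if not isinstance(messages, list) or not messages:
        return ["'input' must be a non-empty list of messages"]
    problems = []
    for i, message in enumerate(messages):
        if not isinstance(message, dict):
            problems.append(f"input[{i}] must be a dict")
            continue
        # Assumed role set (OpenAI Responses-style), not an exhaustive MLflow rule
        if message.get("role") not in {"user", "assistant", "system", "developer"}:
            problems.append(f"input[{i}] has an unexpected role: {message.get('role')!r}")
        # Assumes plain-string content, as used throughout this notebook
        if not isinstance(message.get("content"), str) or not message["content"].strip():
            problems.append(f"input[{i}] needs non-empty string 'content'")
    return problems


good = {"input": [{"role": "user", "content": "Find the lowest performing hotel with good reviews in the Resort class"}]}
print(validate_responses_request(good))
```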


In [0]:
mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data={
        "input": [
            {
                "role": "user",
                "content": "Find the lowest performing hotel with good reviews in the Resort class",
            }
        ]
    },
    env_manager="uv",
)


### 4.4: Register the Model to Unity Catalog

Now that your agent is tested and validated, it's time to register it to Unity Catalog. This is an important step that bridges development and production.

#### What Unity Catalog Registration Does

When you register an MLflow model to Unity Catalog, you're doing more than just saving it - you're:

- **Creating a Governed Asset**: The model becomes a first-class data asset with permissions, lineage, and audit trails
- **Enabling Version Control**: Each registration creates a new version, making it easy to roll back or compare models
- **Providing Discoverability**: Teams across your organization can find and reuse your agent through the Unity Catalog UI
- **Establishing Lineage**: Unity Catalog tracks where the model came from (training data, code, dependencies) and where it's deployed

#### Why This Matters for Production

In production environments, you need more than just a working model - you need governance:

- **Access Control**: Decide who can view, use, or update your agent
- **Compliance**: Track all changes and access for auditing purposes
- **Collaboration**: Share agents across teams with clear ownership and documentation
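- **Lifecycle Management**: Promote versions with aliases (for example, `champion` or `challenger`); models in Unity Catalog use aliases instead of the Staging → Production stages of the legacy Workspace Model Registry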

Think of Unity Catalog registration as moving your agent from your personal workspace into a shared, governed production environment where teams can rely on it.
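The registration cell builds `UC_MODEL_NAME` as `{catalog}.{schema}.{model_name}` from widget values. Because the widgets are created with empty defaults, an unfilled widget can silently produce an invalid name such as `.schema.model`. Below is a minimal guard, sketched as a hypothetical helper `uc_model_name`. The example values `main` and `hotels` are placeholders, not values from this notebook.

```python
def uc_model_name(catalog: str, schema: str, model: str) -> str:
    """Build a Unity Catalog three-level name, failing fast on empty parts."""
    parts = {"catalog": catalog.strip(), "schema": schema.strip(), "model": model.strip()}
    missing = [label for label, value in parts.items() if not value]
    if missing:
        # Empty widget values would otherwise produce names like ".schema.model"
        raise ValueError(f"Missing Unity Catalog name part(s): {', '.join(missing)}")
    return ".".join(parts.values())


print(uc_model_name("main", "hotels", "river_hotel_marketing_agent"))
```

In the notebook, you would pass `dbutils.widgets.get("catalog")` and `dbutils.widgets.get("database")` as the first two arguments.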

In [0]:
mlflow.set_registry_uri("databricks-uc")

catalog = dbutils.widgets.get("catalog")
schema = dbutils.widgets.get("database")  # Note: we call it "database" in the widget

# TODO: rename the model if you want
model_name = "river_hotel_marketing_agent"
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)



### 4.5: Deploy the Agent to Production

The `agents.deploy()` function automatically creates a production-ready serving endpoint - no manual UI configuration needed!

#### What Gets Created

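- **Agent Endpoint**: A serving endpoint that exposes the model `{catalog}.{schema}.river_hotel_marketing_agent` as a REST API. Endpoint names cannot contain dots, so the endpoint name is derived from the model name; the deployment cell prints the exact value.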
- **Feedback Model**: `{catalog}.{schema}.feedback` - Collects user ratings for evaluation
- **Auto-Scaling Infrastructure**: Compute, load balancing, and zero-downtime updates

#### What Happens

- **First Run**: Creates a new serving endpoint
- **Subsequent Runs**: Updates the existing endpoint to the new model version

#### Testing Your Deployed Agent

After deployment completes, you have two ways to interact with your agent:

1. **AI Playground**: Navigate to the AI Playground in the Databricks UI; your agent will appear in the endpoint dropdown. This provides a chat interface for natural language testing.
2. **Serving Page**: Go to **Serving** → find your endpoint → view status, logs, and metrics, and send test queries via the built-in query UI.

No additional configuration or manual endpoint creation is needed - just deploy and start testing!
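`agents.deploy()` returns before the endpoint is ready to serve traffic, so scripted workflows usually poll its state. Below is a minimal polling sketch. `wait_until_ready` is hypothetical, and its `get_state` callable is an assumption. In practice it might wrap the Databricks SDK's `WorkspaceClient().serving_endpoints.get(...)`, but that wiring and the exact state values are not shown here. `sleep` and `clock` are injectable so the loop can be tested without waiting.

```python
import time
from typing import Callable


def wait_until_ready(
    get_state: Callable[[], str],
    timeout_s: float = 1800.0,
    poll_interval_s: float = 30.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> bool:
    """Poll get_state() until it returns "READY"; return False if the timeout elapses first."""
    deadline = clock() + timeout_s
    while True:
        if get_state() == "READY":
            return True
        if clock() >= deadline:
            return False
        sleep(poll_interval_s)


# Usage with a scripted sequence of states (stand-in for real endpoint polling)
scripted = iter(["NOT_READY", "READY"])
print(wait_until_ready(lambda: next(scripted), poll_interval_s=0.0))
```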

In [0]:
from databricks import agents

# Deploy the registered model version (creates the endpoint on the first run, updates it afterwards)
deployment_info = agents.deploy(
    UC_MODEL_NAME,
    uc_registered_model_info.version,
    tags={"endpointSource": "docs"},
)

# Print deployment details; the endpoint can take several minutes to become ready
print("=" * 80)
print("🚀 AGENT DEPLOYMENT STARTED!")
print("=" * 80)
print(f"\n📦 Model: {UC_MODEL_NAME}")
print(f"🔢 Version: {uc_registered_model_info.version}")
print(f"\n🔗 Endpoint Name: {deployment_info.endpoint_name}")

# Get workspace URL and construct links
workspace_url = spark.conf.get("spark.databricks.workspaceUrl")

# Direct links to the deployed resources (endpoint names never contain dots, so no encoding is needed)
serving_endpoint_url = f"https://{workspace_url}/ml/endpoints/{deployment_info.endpoint_name}"
ai_playground_url = f"https://{workspace_url}/ml/ai-playground"

print("\n🌐 Serving Endpoint UI:")
print(f"   {serving_endpoint_url}")
print("\n🎮 AI Playground (test your agent with chat interface):")
print(f"   {ai_playground_url}")

print("\n" + "=" * 80)

## Recap

Congratulations! You successfully:

- **Built Agent Tools**: Created three Unity Catalog SQL functions to identify underperforming hotels, analyze customer reviews, and target potential customers
- **Packaged with LangGraph**: Wrapped the tools in a LangGraph agent using MLflow's ResponsesAgent interface for production deployment
- **Deployed to Production**: Made the agent accessible via AI Playground and REST API for real-time marketing recommendations

## Next Steps

1. Open the **Serving** page and monitor your endpoint until its status shows that it is ready.
2. Once the endpoint is successfully deployed, then navigate to the AI Playground and 