# Agent notebook

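This notebook is very similar to the auto-generated notebook created by an AI Playground export. The export produces three notebooks in the same folder, including:
- [**agent**]($./agent): contains the code to build the agent.
- [**agent-eval-deployment**]($./agent-eval-deployment): logs, registers, evaluates, and deploys the agent.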

This notebook uses Mosaic AI Agent Framework ([AWS](https://docs.databricks.com/en/generative-ai/retrieval-augmented-generation.html) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/retrieval-augmented-generation)) to create your agent. It defines a LangChain agent that has access to tools, which we define in this notebook as well.

Use this notebook to iterate on and modify the agent. For example, you could add more tools or change the system prompt.

**_NOTE:_** This notebook uses LangChain; however, Mosaic AI Agent Framework is also compatible with other agent frameworks, such as Pyfunc and LlamaIndex.

## Next steps

After testing and iterating on your agent in this notebook, go to the auto-generated [agent-eval-deployment]($./agent-eval-deployment) notebook in this folder to log, register, evaluate, and deploy the agent.

In [0]:
%pip install --upgrade databricks-agents unitycatalog-ai[databricks] unitycatalog-langchain[databricks] databricks-langchain databricks-vectorsearch==0.56 langchain==0.3.20 langgraph==0.3.4 pydantic==2.11.7 mlflow[databricks]
#langchain-community langgraph langgraph-checkpoint langchain_core==0.3.67
dbutils.library.restartPython()

## Define the chat model and tools
Create a LangChain chat model that supports [LangGraph tool calling](https://langchain-ai.github.io/langgraph/how-tos/tool-calling/).

We import tools from Unity Catalog (UC) and define a retriever. See [LangChain - How to create tools](https://python.langchain.com/v0.2/docs/how_to/custom_tools/) and [LangChain - Using built-in tools](https://python.langchain.com/v0.2/docs/how_to/tools_builtin/).
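Tool calling works by describing each tool to the model as a name, a natural-language description, and a JSON Schema for its arguments; for the Unity Catalog function defined below, the function and parameter `COMMENT`s supply those descriptions. As a rough illustration, the plain-Python sketch below builds such a description in the OpenAI-style function-calling layout. The helper `make_tool_spec` and the tool name are hypothetical, and the exact spec that `UCFunctionToolkit` and `bind_tools` generate may differ.

```python
import json
from typing import Any


def make_tool_spec(name: str, description: str, params: dict[str, str]) -> dict[str, Any]:
    """Build an OpenAI-style function-calling spec (illustrative layout only).

    `params` maps each parameter name to its description, playing the role of
    the SQL parameter COMMENT. Every parameter is typed as a required string,
    matching the single STRING parameter of the UC function in this notebook.
    """
    return {
        "type": "function",
        "function": {
            "name": name,
            # Plays the role of the function-level COMMENT: tells the model when to call the tool
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {
                    param: {"type": "string", "description": desc}
                    for param, desc in params.items()
                },
                "required": list(params),
            },
        },
    }


spec = make_tool_spec(
    name="search_movie_scripts",  # hypothetical tool name
    description="Search for relevant script chunks from the movie scripts database",
    params={"query": "The query string for searching the movie script database"},
)
print(json.dumps(spec, indent=2))
```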

## Define the retriever tool

We use vector search retrieval through a Unity Catalog function, which enriches each retrieved chunk with additional metadata (title, scene number, and search score) that the agent uses downstream.

In [0]:
import mlflow

# Trace LangChain/LangGraph calls made from this notebook
mlflow.langchain.autolog()

CATALOG = 'media_advertising'
SCHEMA = 'contextual_advertising'

vs_endpoint = 'one-env-shared-endpoint-10'
index_name = f'{CATALOG}.{SCHEMA}.movie_scripts_content_vs'

# Specify the return type schema of our retriever, so that evaluation and UIs can
# automatically display retrieved chunks
mlflow.models.set_retriever_schema(
    primary_key="unique_movie_scene_id",
    text_column="scene_text",
    other_columns=["title", "scene_number"],
    name=index_name,
)

In [None]:
%sql
CREATE OR REPLACE FUNCTION media_advertising.contextual_advertising.search_movie_scripts (
  -- The agent uses this comment to determine how to generate the query string parameter.
  query STRING
  COMMENT 'The query string for searching the movie script database'
) RETURNS TABLE
-- The agent uses this comment to determine when to call this tool. It describes the types of documents and information contained within the index.
COMMENT 'Search for relevant script chunks from the movie scripts database that match the intent of the user request, including elements of the scene that would make it a good fit' RETURN
SELECT
  scene_text as page_content,
  map('title', TRY_CAST(title AS STRING), 'scene_number', TRY_CAST(scene_number AS STRING), 'search_score', TRY_CAST(search_score AS STRING)) as metadata
FROM
  vector_search(
    -- Specify your Vector Search index name here
    index => 'media_advertising.contextual_advertising.movie_scripts_content_vs',
    query_text => query,
    query_type => 'hybrid',
    num_results => 5
  )
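Each row the function returns pairs the scene text (`page_content`) with a string-valued `metadata` map. The following sketch mirrors that `SELECT` list in plain Python so the output shape is easy to see; the `hits` list is made-up sample data standing in for `vector_search` results, and `to_tool_rows` is a hypothetical helper, not part of any Databricks API.

```python
from typing import Any, Optional

METADATA_KEYS = ("title", "scene_number", "search_score")


def to_tool_rows(hits: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Mirror the SELECT list of search_movie_scripts in plain Python."""
    rows = []
    for hit in hits:
        # map(..., TRY_CAST(x AS STRING), ...): values become strings; missing values become None (SQL NULL)
        metadata: dict[str, Optional[str]] = {
            key: None if hit.get(key) is None else str(hit[key]) for key in METADATA_KEYS
        }
        rows.append({"page_content": hit["scene_text"], "metadata": metadata})
    return rows


# Made-up sample standing in for vector_search output
hits = [
    {
        "scene_text": "INT. SUBMARINE - NIGHT. The crew waits in silence.",
        "title": "Example Movie",
        "scene_number": 12,
        "search_score": 0.83,
    }
]
print(to_tool_rows(hits))
```

Because every metadata value arrives as a string, downstream code (or the model) has to parse `search_score` before comparing it with a numeric threshold.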

## Create the agent
This cell writes `agent.py`, which defines a simple graph that uses the LLM endpoint and Unity Catalog function tool configured at the top of that file. The graph is adapted from [this LangGraph guide](https://langchain-ai.github.io/langgraph/how-tos/react-agent-from-scratch/).


To further customize your LangGraph agent, you can refer to:
* [LangGraph - Quick Start](https://langchain-ai.github.io/langgraph/tutorials/introduction/) for explanations of the concepts used in this LangGraph agent
* [LangGraph - How-to Guides](https://langchain-ai.github.io/langgraph/how-tos/) to expand the functionality of your agent
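The graph in `agent.py` alternates between an `agent` node (the LLM) and a `tools` node until the model stops requesting tool calls. The sketch below reproduces that control flow in plain Python with a scripted fake model, which can help when reasoning about routing without an LLM endpoint. `run_agent_loop`, `fake_model`, and `max_steps` are illustrative assumptions; `max_steps` only loosely plays the role of LangGraph's recursion limit.

```python
from typing import Any, Callable

Message = dict[str, Any]


def run_agent_loop(
    call_model: Callable[[list[Message]], Message],
    tools: dict[str, Callable[..., str]],
    messages: list[Message],
    max_steps: int = 10,
) -> list[Message]:
    """Plain-Python sketch of the agent -> tools -> agent cycle (mutates and returns `messages`)."""
    for _ in range(max_steps):
        reply = call_model(messages)            # "agent" node
        messages.append(reply)
        tool_calls = reply.get("tool_calls") or []
        if not tool_calls:                      # should_continue -> "end"
            return messages
        for call in tool_calls:                 # "tools" node runs every requested call
            result = tools[call["name"]](**call["args"])
            messages.append({"role": "tool", "tool_call_id": call["id"], "content": result})
    raise RuntimeError("agent did not finish within max_steps")


def fake_model(messages: list[Message]) -> Message:
    # Scripted stand-in for the LLM: search first, answer once a tool result exists
    if any(m["role"] == "tool" for m in messages):
        return {"role": "assistant", "content": "Use scene 12."}
    return {
        "role": "assistant",
        "content": "",
        "tool_calls": [{"id": "call_1", "name": "search_movie_scripts", "args": {"query": "underwater"}}],
    }


history = run_agent_loop(
    fake_model,
    {"search_movie_scripts": lambda query: f"results for {query}"},
    [{"role": "user", "content": "Where should the ad go?"}],
)
print([m["role"] for m in history])  # ['user', 'assistant', 'tool', 'assistant']
```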


In [0]:
%%writefile agent.py
import json
from typing import (
    Any,
    Annotated,
    Optional,
    Generator,
    Sequence,
    TypedDict,
    Union,
)
from uuid import uuid4

import mlflow
from databricks_langchain import ChatDatabricks
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    convert_to_openai_messages,
)
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from langgraph.graph.message import add_messages
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
)
from unitycatalog.ai.langchain.toolkit import UCFunctionToolkit

# --- Setup LLM and tools ---
function_name = "media_advertising.contextual_advertising.search_movie_scripts"
toolkit = UCFunctionToolkit(
    function_names=[
        function_name
    ]
)
tools = toolkit.tools
llm_endpoint = "databricks-claude-sonnet-4"

index_name = 'media_advertising.contextual_advertising.movie_scripts_content_vs'
llm = ChatDatabricks(endpoint=llm_endpoint)

# --- Agent Graph Definition ---
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], add_messages]
    custom_inputs: Optional[dict[str, Any]]
    custom_outputs: Optional[dict[str, Any]]

def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolNode, Sequence[BaseTool]],
    system_prompt: Optional[str] = None,
) -> CompiledGraph:
    model = model.bind_tools(tools)
    def should_continue(state: AgentState):
        messages = state["messages"]
        last_message = messages[-1]
        if isinstance(last_message, AIMessage) and last_message.tool_calls:
            return "continue"
        else:
            return "end"

    if system_prompt:
        preprocessor = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}] + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])

    model_runnable = preprocessor | model
    def call_model(state: AgentState, config: RunnableConfig):
        response = model_runnable.invoke(state, config)
        return {"messages": [response]}

    workflow = StateGraph(AgentState)
    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ToolNode(tools))
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges("agent", should_continue, {"continue": "tools", "end": END})
    workflow.add_edge("tools", "agent")
    return workflow.compile()

# --- Agent Wrapper Class ---
class LangGraphResponsesAgent(ResponsesAgent):
    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent

    def _responses_to_cc(self, message: dict[str, Any]) -> list[dict[str, Any]]:
        """Convert from a Responses API output item to ChatCompletion messages."""
        msg_type = message.get("type")
        if msg_type == "function_call":
            return [
                {
                    "role": "assistant",
                    "content": "tool call",
                    "tool_calls": [
                        {
                            "id": message["call_id"],
                            "type": "function",
                            "function": {
                                "arguments": message["arguments"],
                                "name": message["name"],
                            },
                        }
                    ],
                }
            ]
        elif msg_type == "message" and isinstance(message["content"], list):
            return [
                {"role": message["role"], "content": content["text"]}
                for content in message["content"]
            ]
        elif msg_type == "reasoning":
            return [{"role": "assistant", "content": json.dumps(message["summary"])}]
        elif msg_type == "function_call_output":
            return [
                {
                    "role": "tool",
                    "content": message["output"],
                    "tool_call_id": message["call_id"],
                }
            ]
        compatible_keys = ["role", "content", "name", "tool_calls", "tool_call_id"]
        filtered = {k: v for k, v in message.items() if k in compatible_keys}
        return [filtered] if filtered else []

    def _langchain_to_responses(self, messages: list[BaseMessage]) -> list[dict[str, Any]]:
        """Convert LangChain messages to Responses API output item dictionaries."""
        # Accumulate items for every message; returning inside the loop would drop
        # all but the first message of a node update (e.g. parallel tool results)
        items: list[dict[str, Any]] = []
        for message in messages:
            message = message.model_dump()
            role = message["type"]
            if role == "ai":
                if tool_calls := message.get("tool_calls"):
                    items.extend(
                        self.create_function_call_item(
                            id=message.get("id") or str(uuid4()),
                            call_id=tool_call["id"],
                            name=tool_call["name"],
                            arguments=json.dumps(tool_call["args"]),
                        )
                        for tool_call in tool_calls
                    )
                else:
                    items.append(
                        self.create_text_output_item(
                            text=message["content"],
                            id=message.get("id") or str(uuid4()),
                        )
                    )
            elif role == "tool":
                items.append(
                    self.create_function_call_output_item(
                        call_id=message["tool_call_id"],
                        output=message["content"],
                    )
                )
            elif role == "user":
                items.append(message)
        return items

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs)

    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        cc_msgs = []
        for msg in request.input:
            cc_msgs.extend(self._responses_to_cc(msg.model_dump()))

        for event in self.agent.stream({"messages": cc_msgs}, stream_mode=["updates", "messages"]):
            if event[0] == "updates":
                for node_data in event[1].values():
                    for item in self._langchain_to_responses(node_data["messages"]):
                        yield ResponsesAgentStreamEvent(type="response.output_item.done", item=item)
            elif event[0] == "messages":
                try:
                    chunk = event[1][0]
                    if isinstance(chunk, AIMessageChunk) and (content := chunk.content):
                        yield ResponsesAgentStreamEvent(
                            **self.create_text_delta(delta=content, item_id=chunk.id),
                        )
                except Exception as e:
                    print(e)

# --- Agent Initialization ---
agent_system_prompt = """You are a RAG chatbot working for a television network. You help recommend optimal placement of advertisements based on program scripts. In this example you have access to movie scripts, but the goal is to air commercials. Based on the user query, call the retriever tool to get the most relevant movie scenes, then synthesize this information with the proposed advertisement into a helpful response that accurately recommends the best scene in which to place the advertisement. The retriever returns the required metadata as well as the similarity score ('search_score').

Make a second request to the retriever tool only if the initially returned scenes are extremely poor matches or have low similarity scores. Treat similarity scores below 0.80 as poor; if any returned scene scores 0.80 or higher, do not make an additional request. Never make more than 2 requests to the retriever tool. Include relevant keywords in each retriever request that would appear in a scene matching the user intent.

The retriever tool returns the movie title as 'title' and the scene number as 'scene_number'. Structure the response exactly as follows, including the markdown formatting, and provide nothing outside this structure:

# Movie Title: INSERT MOVIE
# Scene Number: INSERT SCENE NUMBER
# Scene Description:
  Describe the scene in 1-3 sentences
# Scene Justification:
  Provide a justification for why this scene is the best option in 2-3 sentences
"""
# Enable tracing, build the graph, and wrap it in a ResponsesAgent that MLflow can log as a pyfunc model
mlflow.langchain.autolog()
agent = create_tool_calling_agent(llm, tools, system_prompt=agent_system_prompt)
AGENT = LangGraphResponsesAgent(agent)
mlflow.models.set_model(AGENT)
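`LangGraphResponsesAgent` translates between Responses API items and ChatCompletion-style messages; a tool call and its result stay linked through `call_id` on the Responses side and `tool_call_id` on the ChatCompletion side. The standalone sketch below restates the tool-related branches of `_responses_to_cc` so the mapping can be checked without MLflow installed. `responses_item_to_cc` is a hypothetical name, and the sample items are invented.

```python
import json
from typing import Any


def responses_item_to_cc(item: dict[str, Any]) -> list[dict[str, Any]]:
    """Sketch of the tool-related branches of _responses_to_cc."""
    if item.get("type") == "function_call":
        # An assistant turn that requests a tool; arguments stay a JSON string
        return [{
            "role": "assistant",
            "content": "tool call",
            "tool_calls": [{
                "id": item["call_id"],
                "type": "function",
                "function": {"name": item["name"], "arguments": item["arguments"]},
            }],
        }]
    if item.get("type") == "function_call_output":
        # The tool result, linked back to its request via the same call id
        return [{"role": "tool", "content": item["output"], "tool_call_id": item["call_id"]}]
    # Plain chat messages keep only ChatCompletion-compatible keys
    keys = ("role", "content", "name", "tool_calls", "tool_call_id")
    filtered = {k: v for k, v in item.items() if k in keys}
    return [filtered] if filtered else []


call = {
    "type": "function_call",
    "call_id": "call_1",
    "name": "search_movie_scripts",
    "arguments": json.dumps({"query": "underwater scene"}),
}
result = {"type": "function_call_output", "call_id": "call_1", "output": "scene 12 of Example Movie"}

history = []
for item in [{"role": "user", "content": "Where should the ad go?"}, call, result]:
    history.extend(responses_item_to_cc(item))
print(json.dumps(history, indent=2))
```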


In [None]:
dbutils.library.restartPython()
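The system prompt in `agent.py` asks the model to call the retriever at most twice, and to re-query only when no returned scene reaches a similarity score of 0.80. Because the model follows this rule only through the prompt, a deterministic restatement can be handy for spot-checking traces. A minimal sketch, assuming scores have already been converted to `float` (the UC function returns `search_score` as a string); `should_requery` is a hypothetical helper.

```python
SCORE_THRESHOLD = 0.80   # from the system prompt
MAX_RETRIEVER_CALLS = 2  # from the system prompt


def should_requery(scores: list[float], calls_made: int) -> bool:
    """Return True if the prompt's rules allow (and call for) another retriever request."""
    if calls_made >= MAX_RETRIEVER_CALLS:
        return False
    # Re-query only when no returned scene reaches the threshold;
    # an empty result list also counts as poor under this sketch
    return all(score < SCORE_THRESHOLD for score in scores)


print(should_requery([0.62, 0.71], calls_made=1))
print(should_requery([0.62, 0.85], calls_made=1))
```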

## Test the agent

Interact with the agent to test its output. Because this notebook calls `mlflow.langchain.autolog()`, you can view the trace for each step the agent takes.

In [0]:
from agent import AGENT
result = AGENT.predict({"input": [{"role": "user", "content": "When could I insert a commercial for an alien pet food brand in a movie based underwater"}]})
print(result.model_dump(exclude_none=True))

In [0]:
for event in AGENT.predict_stream({"input": [{"role": "user", "content": "Where should I place the following advertisement: This advertisement for Bricks High Quality Pet Food shows a bulldog with its tongue out standing next to a large bag of their Premium Blend teacup pig food on a grassy outdoor setting. for the following target audience: Spanish-speaking women age 18-34 primarily based in the american southwest"}]}):
    # A bare expression inside a loop is not displayed, so print each event explicitly
    print(event.model_dump(exclude_none=True))

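`predict_stream` emits text-delta events while the model is generating and `response.output_item.done` events for each completed item. The sketch below groups streamed text by `item_id` and collects completed items. It assumes each event has been dumped to a dict, with text deltas carrying `type == "response.output_text.delta"` plus `item_id` and `delta` keys; verify this against the events your MLflow version emits. `collect_stream` is a hypothetical helper.

```python
from collections import defaultdict
from typing import Any, Iterable


def collect_stream(events: Iterable[dict[str, Any]]) -> tuple[dict[str, str], list[dict[str, Any]]]:
    """Group text deltas by item_id and collect completed output items."""
    text_by_item: dict[str, str] = defaultdict(str)
    done_items: list[dict[str, Any]] = []
    for event in events:
        if event.get("type") == "response.output_text.delta":
            text_by_item[event["item_id"]] += event["delta"]
        elif event.get("type") == "response.output_item.done":
            done_items.append(event["item"])
    return dict(text_by_item), done_items


# Invented sample events in the assumed shape
events = [
    {"type": "response.output_text.delta", "item_id": "msg_1", "delta": "# Movie Title: "},
    {"type": "response.output_text.delta", "item_id": "msg_1", "delta": "Example Movie"},
    {"type": "response.output_item.done", "item": {"type": "message", "id": "msg_1"}},
]
text, done = collect_stream(events)
print(text)
print(done)
```

In the notebook, a call such as `collect_stream(e.model_dump(exclude_none=True) for e in AGENT.predict_stream(request))`, with `request` being an input dict like the ones above, would apply the same grouping to live events.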
In [0]:
# Log the model to MLflow
from importlib.metadata import version

import mlflow
from agent import function_name, llm_endpoint, index_name
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint, DatabricksVectorSearchIndex

print(f'mlflow version: {mlflow.__version__}')

input_example = {
    "input": [
        {
            "role": "user",
            "content": "When could I insert a commercial for a light hearted basketball-themed comedy movie we want to promote for next summer?"
        }
    ]
}

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name='agent',
        python_model="agent.py",
        # Pin the packages agent.py imports so the serving environment matches this notebook
        pip_requirements=[
            f"databricks-langchain=={version('databricks-langchain')}",
            f"langgraph=={version('langgraph')}",
            f"unitycatalog-langchain[databricks]=={version('unitycatalog-langchain')}",
            f"databricks-connect=={version('databricks-connect')}",
        ],
        resources=[
            DatabricksServingEndpoint(endpoint_name=llm_endpoint),
            DatabricksVectorSearchIndex(index_name=index_name),
            DatabricksFunction(function_name=function_name),
        ],
        input_example=input_example,
    )
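The `pip_requirements` above pin versions read from the notebook environment, so if one of those packages is missing, the version lookup raises and logging fails. If you would rather degrade to an unpinned requirement than fail at logging time, a small helper built on the standard-library `importlib.metadata`, like the hypothetical `pinned` below, is one option; the trade-off is a less reproducible serving environment.

```python
from importlib.metadata import PackageNotFoundError, version


def pinned(package: str, extras: str = "") -> str:
    """Return 'package[extras]==<installed version>', or the bare requirement if not installed."""
    requirement = f"{package}[{extras}]" if extras else package
    try:
        return f"{requirement}=={version(package)}"
    except PackageNotFoundError:
        # Fall back to an unpinned requirement rather than failing at log time
        return requirement


print(pinned("unitycatalog-langchain", extras="databricks"))
```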


In [None]:
from databricks.sdk import WorkspaceClient
import mlflow

mlflow.set_registry_uri("databricks-uc")

# Use the workspace client to retrieve information about the current user
w = WorkspaceClient()
# user_name holds the login email; display_name is the human-readable full name
user_email = w.current_user.me().user_name
username = user_email.split("@")[0]

# Catalog and schema have been automatically created
catalog_name = 'media_advertising'
schema_name = 'contextual_advertising'

# TODO: define the catalog, schema, and model name for your UC model
model_name = "movie_scripts_placement_agent" # Change to a different model name if desired
UC_MODEL_NAME = f"{catalog_name}.{schema_name}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME)
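Unity Catalog models are addressed by a three-level `catalog.schema.model` name. A loose sanity check before calling `mlflow.register_model` can catch typos such as an empty part or stray whitespace early. The hypothetical helper below is a sketch only; it does not implement Unity Catalog's full naming rules.

```python
def uc_model_name(catalog: str, schema: str, model: str) -> str:
    """Join a three-level Unity Catalog name, rejecting obviously invalid parts.

    Loose check only: each part must be non-empty and contain no dots or whitespace.
    """
    parts = (catalog, schema, model)
    for part in parts:
        if not part or "." in part or any(ch.isspace() for ch in part):
            raise ValueError(f"invalid name part: {part!r}")
    return ".".join(parts)


print(uc_model_name("media_advertising", "contextual_advertising", "movie_scripts_placement_agent"))
```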

In [None]:
import mlflow

client = mlflow.MlflowClient()

# UC_MODEL_NAME and uc_registered_model_info are defined in the registration cell above

uc_registered_model_detail = client.get_model_version(name=UC_MODEL_NAME, version=uc_registered_model_info.version)
print(uc_registered_model_detail)

In [None]:
from databricks import agents

# Deploy the model to the review app and a model serving endpoint
agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version)