# Agent notebook

This is very similar to an auto-generated notebook created by an AI Playground export. There are three notebooks in the same folder:
- [**agent**]($./agent): contains the code to build the agent.

This notebook uses Mosaic AI Agent Framework ([AWS](https://docs.databricks.com/en/generative-ai/retrieval-augmented-generation.html) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/retrieval-augmented-generation)) to create your agent. It defines a LangChain agent that has access to tools, which we define in this notebook as well.

Use this notebook to iterate on and modify the agent. For example, you could add more tools or change the system prompt.

**_NOTE:_** This notebook uses LangChain; however, Mosaic AI Agent Framework is compatible with other agent frameworks such as Pyfunc and LlamaIndex.

## Next steps

After testing and iterating on your agent in this notebook, go to the auto-generated [agent-eval-deployment]($./agent-eval-deployment) notebook in this folder to log, register, evaluate, and deploy the agent.

In [0]:
%pip install --upgrade databricks-agents unitycatalog-ai[databricks] unitycatalog-langchain[databricks] databricks-langchain databricks-vectorsearch==0.56 langchain==0.3.20 langgraph==0.3.4 pydantic==2.11.7 mlflow[databricks]
#langchain-community langgraph langgraph-checkpoint langchain_core==0.3.67
dbutils.library.restartPython()
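
After the Python process restarts, it can be useful to confirm in a later cell that the pinned packages resolved to the expected versions. The following is a minimal sketch using only the standard library; the `PINNED_VERSIONS` dictionary and the `check_pins` helper are hypothetical names introduced here, and the pins are copied from the `%pip install` line above. It compares version strings literally, so treat a reported mismatch as a prompt to look closer rather than as a definitive failure.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Pins copied from the %pip install line above (hypothetical helper data)
PINNED_VERSIONS = {
    "databricks-vectorsearch": "0.56",
    "langchain": "0.3.20",
    "langgraph": "0.3.4",
    "pydantic": "2.11.7",
}


def check_pins(pins: Dict[str, str]) -> Dict[str, Tuple[Optional[str], str]]:
    """Return {package: (installed_or_None, pinned)} for every package that does not match its pin."""
    mismatches = {}
    for package, pinned in pins.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            # The package is not installed in this environment
            installed = None
        if installed != pinned:
            mismatches[package] = (installed, pinned)
    return mismatches


for package, (installed, pinned) in check_pins(PINNED_VERSIONS).items():
    print(f"{package}: installed={installed}, pinned={pinned}")
```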

In [0]:
import mlflow
from mlflow.models import ModelConfig

mlflow.langchain.autolog()

## Define the chat model and tools
Create a LangChain chat model that supports [LangGraph tool calling](https://langchain-ai.github.io/langgraph/how-tos/tool-calling/).

An agent's tools can be imported from Unity Catalog (UC) or defined in the notebook itself; this notebook defines a retriever tool. See [LangChain - How to create tools](https://python.langchain.com/v0.2/docs/how_to/custom_tools/) and [LangChain - Using built-in tools](https://python.langchain.com/v0.2/docs/how_to/tools_builtin/).

In [0]:
from databricks_langchain import ChatDatabricks
from langchain_community.tools.databricks import UCFunctionToolkit
from databricks.sdk import WorkspaceClient

llm_endpoint = "databricks-claude-3-7-sonnet"
llm = ChatDatabricks(endpoint=llm_endpoint)

## Create the retriever tool

In [0]:
from langchain.tools.retriever import create_retriever_tool
from databricks_langchain.vectorstores import DatabricksVectorSearch

CATALOG = 'media_advertising'
SCHEMA = 'contextual_advertising'

vs_endpoint = 'one-env-shared-endpoint-10'
index_name = f'{CATALOG}.{SCHEMA}.movie_scripts_content_vs'

# Connect to an existing Databricks Vector Search endpoint and index
vector_store = DatabricksVectorSearch(
    endpoint=vs_endpoint,
    index_name=index_name,
    columns=[
        "unique_movie_id",
        "title",
        "scene_number",
        "scene_text",
    ],
)

# k sets how many results each query returns; it is an important knob for retrieval tuning
retriever = vector_store.as_retriever(search_kwargs={"k": 5})

# Create a tool object that performs retrieval against our vector search index
retriever_tool = create_retriever_tool(
    retriever,
    name="search_movie_scripts",
    description="Use this tool to search for relevant script chunks from the movie scripts database and provide a recommendation of where to insert ad placements.",
)

# Specify the return type schema of our retriever, so that evaluation and UIs can
# automatically display retrieved chunks
mlflow.models.set_retriever_schema(
    primary_key="unique_movie_id",
    text_column="scene_text",
    other_columns=["title", "scene_number"],
    name=index_name,
)

tools = [retriever_tool]
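
To make the effect of `k` concrete, here is a toy sketch of top-k similarity retrieval over in-memory rows. It is not how Databricks Vector Search is implemented; the `embedding` field, the two-dimensional vectors, the movie title, and the `top_k` helper are all assumptions made for illustration. Only the `title` and `scene_number` field names come from the index columns above.

```python
import math
from typing import Any, Dict, List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either has zero norm)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(
    query_embedding: Sequence[float],
    rows: List[Dict[str, Any]],
    k: int = 5,
) -> List[Dict[str, Any]]:
    """Return the k rows whose embeddings are most similar to the query, best first."""
    ranked = sorted(
        rows,
        key=lambda row: cosine_similarity(query_embedding, row["embedding"]),
        reverse=True,
    )
    return ranked[:k]


# Hypothetical rows; real embeddings have hundreds of dimensions
toy_rows = [
    {"title": "Toy Movie", "scene_number": 1, "embedding": [0.9, 0.1]},
    {"title": "Toy Movie", "scene_number": 2, "embedding": [0.0, 1.0]},
    {"title": "Toy Movie", "scene_number": 3, "embedding": [0.7, 0.7]},
]

results = top_k([1.0, 0.0], toy_rows, k=2)
print([row["scene_number"] for row in results])
```

A larger `k` gives the model more context to choose from at the cost of longer prompts and more loosely related scenes; a smaller `k` does the opposite.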

## Output parsers
Databricks interfaces, such as the AI Playground, can optionally display pretty-printed tool calls.

Use the following helper functions to parse the LLM's output into the expected format.

In [0]:
from typing import Iterator, Dict, Any
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    ToolMessage,
    MessageLikeRepresentation,
)

import json

def stringify_tool_call(tool_call: Dict[str, Any]) -> str:
    """
    Convert a raw tool call into a formatted string that the playground UI expects if there is enough information in the tool_call
    """
    try:
        request = json.dumps(
            {
                "id": tool_call.get("id"),
                "name": tool_call.get("name"),
                "arguments": json.dumps(tool_call.get("args", {})),
            },
            indent=2,
        )
        return f"<tool_call>{request}</tool_call>"
    except Exception:
        return str(tool_call)


def stringify_tool_result(tool_msg: ToolMessage) -> str:
    """
    Convert a ToolMessage into a formatted string that the playground UI expects if there is enough information in the ToolMessage
    """
    try:
        result = json.dumps(
            {"id": tool_msg.tool_call_id, "content": tool_msg.content}, indent=2
        )
        return f"<tool_call_result>{result}</tool_call_result>"
    except Exception:
        return str(tool_msg)


def parse_message(msg) -> str:
    """Parse different message types into their string representations"""
    # tool call result
    if isinstance(msg, ToolMessage):
        return stringify_tool_result(msg)
    # tool call
    elif isinstance(msg, AIMessage) and msg.tool_calls:
        tool_call_results = [stringify_tool_call(call) for call in msg.tool_calls]
        return "".join(tool_call_results)
    # normal HumanMessage or AIMessage (reasoning or final answer)
    elif isinstance(msg, (AIMessage, HumanMessage)):
        return msg.content
    else:
        print(f"Unexpected message type: {type(msg)}")
        return str(msg)


def wrap_output(stream: Iterator[MessageLikeRepresentation]) -> Iterator[str]:
    """
    Process and yield formatted outputs from the message stream.
    The invoke and stream langchain functions produce different output formats.
    This function handles both cases.
    """
    for event in stream:
        # the agent was called with invoke()
        if "messages" in event:
            for msg in event["messages"]:
                yield parse_message(msg) + "\n\n"
        # the agent was called with stream()
        else:
            for node in event:
                for key, messages in event[node].items():
                    if isinstance(messages, list):
                        for msg in messages:
                            yield parse_message(msg) + "\n\n"
                    else:
                        print(f"Unexpected value {messages} for key {key}. Expected a list of `MessageLikeRepresentation`'s")
                        yield str(messages)
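
To see what the tool call format produced by `stringify_tool_call` looks like, the sketch below rebuilds the same string from a plain dictionary and then parses it back. It uses only the standard library. `format_tool_call` mirrors the helper above, while `extract_tool_calls` is a hypothetical helper added for illustration and is not something the playground UI is known to use.

```python
import json
import re
from typing import Any, Dict, List


def format_tool_call(tool_call: Dict[str, Any]) -> str:
    """Mirror of stringify_tool_call above, applied to a plain dict."""
    request = json.dumps(
        {
            "id": tool_call.get("id"),
            "name": tool_call.get("name"),
            # Arguments are serialized twice: once here, once by the outer dumps
            "arguments": json.dumps(tool_call.get("args", {})),
        },
        indent=2,
    )
    return f"<tool_call>{request}</tool_call>"


def extract_tool_calls(text: str) -> List[Dict[str, Any]]:
    """Recover tool call payloads from formatted text (hypothetical helper)."""
    payloads = re.findall(r"<tool_call>(.*?)</tool_call>", text, flags=re.DOTALL)
    calls = []
    for payload in payloads:
        call = json.loads(payload)
        # Undo the inner serialization of the arguments
        call["arguments"] = json.loads(call["arguments"])
        calls.append(call)
    return calls


example_call = {
    "id": "call_1",
    "name": "search_movie_scripts",
    "args": {"query": "dog in a park"},
}
formatted = format_tool_call(example_call)
parsed = extract_tool_calls(formatted)
print(formatted)
```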

## Create the agent
Here we provide a simple graph that uses the model and tools defined above. This graph is adapted from [this LangGraph guide](https://langchain-ai.github.io/langgraph/how-tos/react-agent-from-scratch/).


To further customize your LangGraph agent, you can refer to:
* [LangGraph - Quick Start](https://langchain-ai.github.io/langgraph/tutorials/introduction/) for explanations of the concepts used in this LangGraph agent
* [LangGraph - How-to Guides](https://langchain-ai.github.io/langgraph/how-tos/) to expand the functionality of your agent


In [0]:
from typing import (
    Annotated,
    Optional,
    Sequence,
    TypedDict,
    Union,
)

from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import (
    BaseMessage,
    SystemMessage,
)
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool

from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_node import ToolNode


# We create the AgentState that we will pass around
# This simply involves a list of messages
class AgentState(TypedDict):
    """The state of the agent."""

    messages: Annotated[Sequence[BaseMessage], add_messages]


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolNode, Sequence[BaseTool]],
    agent_prompt: Optional[str] = None,
) -> CompiledGraph:
    model = model.bind_tools(tools)

    # Define the function that determines which node to go to
    def should_continue(state: AgentState):
        messages = state["messages"]
        last_message = messages[-1]
        # If there is no function call, then we finish
        if not last_message.tool_calls:
            return "end"
        else:
            return "continue"

    if agent_prompt:
        system_message = SystemMessage(content=agent_prompt)
        preprocessor = RunnableLambda(
            lambda state: [system_message] + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    # Define the function that calls the model
    def call_model(
        state: AgentState,
        config: RunnableConfig,
    ):
        response = model_runnable.invoke(state, config)
        return {"messages": [response]}

    workflow = StateGraph(AgentState)

    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ToolNode(tools))

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        # First, we define the start node. We use agent.
        # This means these are the edges taken after the agent node is called.
        "agent",
        # Next, we pass in the function that will determine which node is called next.
        should_continue,
        # The mapping below will be used to determine which node to go to
        {
            # If tools, then we call the tool node.
            "continue": "tools",
            # END is a special node marking that the graph should finish.
            "end": END,
        },
    )
    # We now add an unconditional edge from tools to agent.
    workflow.add_edge("tools", "agent")

    return workflow.compile()
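
The control flow that the graph above encodes (call the model, run any requested tools, and repeat until the model stops asking for tools) can be sketched in plain Python without LangGraph. This is an illustration under stated assumptions, not the LangGraph implementation: messages are plain dictionaries, `fake_model` and `fake_tools` are hypothetical stand-ins for the chat model and the retriever tool, and the `max_steps` guard is added for this sketch only.

```python
from typing import Any, Callable, Dict, List

Message = Dict[str, Any]


def run_tool_loop(
    call_model: Callable[[List[Message]], Message],
    tools: Dict[str, Callable[..., str]],
    messages: List[Message],
    max_steps: int = 5,
) -> List[Message]:
    """Alternate between the model and its tools until no tool call is requested."""
    history = list(messages)
    for _ in range(max_steps):
        response = call_model(history)  # corresponds to the "agent" node
        history.append(response)
        tool_calls = response.get("tool_calls") or []
        if not tool_calls:  # should_continue returns "end"
            return history
        for call in tool_calls:  # corresponds to the "tools" node
            result = tools[call["name"]](**call["args"])
            history.append(
                {"role": "tool", "tool_call_id": call["id"], "content": result}
            )
    raise RuntimeError("The agent did not finish within max_steps")


def fake_model(history: List[Message]) -> Message:
    """Hypothetical model: request one search, then answer once a tool result exists."""
    if any(message["role"] == "tool" for message in history):
        return {"role": "assistant", "content": "Place the ad after scene 12.", "tool_calls": []}
    query = history[-1]["content"]
    return {
        "role": "assistant",
        "content": "",
        "tool_calls": [{"id": "call_1", "name": "search_movie_scripts", "args": {"query": query}}],
    }


fake_tools = {"search_movie_scripts": lambda query: f"scene_number=12 matches '{query}'"}

history = run_tool_loop(fake_model, fake_tools, [{"role": "user", "content": "dog food ad"}])
print([message["role"] for message in history])
```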

In [0]:
from langchain_core.runnables import RunnableGenerator
from mlflow.langchain.output_parsers import ChatCompletionsOutputParser

agent_system_prompt = "You are a RAG chatbot working for a television network company. You will be used to help make recommendations for optimal placement of advertisements based on program scripts. For this example, you have access to movie scripts, but the intention is to air commercials. Based on the user query, retrieve the most relevant scripts from the movie scenes and synthesize this information, combined with the proposed advertisement, into a helpful response that accurately recommends where to place an advertisement. Make sure the recommendation includes the title of the movie and the scene number you recommend placing the advertisement after in addition to the context and summary you provide. The scene number is a field returned by the retriever tool you have access to - it will be returned to you as 'scene_number'."

# Create the agent; the system prompt is prepended to the messages on every model call
agent_with_raw_output = create_tool_calling_agent(
    llm,
    tools,
    agent_prompt=agent_system_prompt,
)

agent = agent_with_raw_output | RunnableGenerator(wrap_output) | ChatCompletionsOutputParser()
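
In the pipeline above, `wrap_output` receives whatever the graph emits, and the shape of those events depends on how the agent is called. The sketch below illustrates the two shapes with plain strings standing in for message objects. `flatten_events` is a hypothetical, simplified version of `wrap_output` that skips message formatting, and the sample events are made up for illustration.

```python
from typing import Any, Dict, Iterable, Iterator


def flatten_events(stream: Iterable[Dict[str, Any]]) -> Iterator[Any]:
    """Yield individual messages from invoke()-style or stream()-style events."""
    for event in stream:
        if "messages" in event:
            # invoke(): a single event holding the final state
            yield from event["messages"]
        else:
            # stream(): one event per node, shaped {node_name: {key: [messages]}}
            for node_output in event.values():
                for value in node_output.values():
                    if isinstance(value, list):
                        yield from value


invoke_style = [{"messages": ["question", "answer"]}]
stream_style = [
    {"agent": {"messages": ["tool request"]}},
    {"tools": {"messages": ["tool result"]}},
    {"agent": {"messages": ["answer"]}},
]

print(list(flatten_events(invoke_style)))
print(list(flatten_events(stream_style)))
```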

## Test the agent

Interact with the agent to test its output. Because this notebook calls `mlflow.langchain.autolog()`, you can view the trace for each step the agent takes.

In [0]:
# TODO: replace this placeholder input example with an appropriate domain-specific example for your agent
for event in agent.stream({"messages": [{"role": "user", "content": "When could I insert a commercial for a dog food product targeting 18-35 year old dog owners?"}]}):
    print(event, "---" * 20 + "\n")

In [0]:
for event in agent.stream({"messages": [{"role": "user", "content": "When could I insert a commercial for cat food?"}]}):
    print(event, "---" * 20 + "\n")

In [0]:
# Declare the agent as the model object MLflow should use when this notebook is logged as code
mlflow.models.set_model(agent)