You may find this series of notebooks at [databricks-solutions/realtime-rag-agents-databricks-youcom](https://github.com/databricks-solutions/realtime-rag-agents-databricks-youcom). For more information about this solution accelerator, visit the [blog post](https://you.com/articles/unlocking-real-time-intelligence-for-ai-agents-with-you.com-and-databricks).

# Agent definition

The Unity Catalog (UC) function created earlier is used as a tool by the agent.

The code below shows one way to wrap the UC function call in an MLflow trace span and to convert each You.com `hit` into the `Document` structure that the span expects. This [RETRIEVER](https://mlflow.org/docs/latest/genai/tracing/concepts/span/#retriever-spans) span enables the [RetrievalGroundedness scorer](https://docs.databricks.com/aws/en/mlflow3/genai/eval-monitor/concepts/judges/is_grounded), used in the evaluation steps of this notebook, to evaluate the content retrieved from [You.com](https://you.com).

---
Pay attention to the `NOTE` comments left in the code: they mark the places where you can customize your agent going forward.

In [None]:
import os

# Declare the job parameters as notebook widgets: (name, default value, label).
dbutils.widgets.text("catalog", "main", "UC Catalog")
dbutils.widgets.text("schema", "default", "UC Schema")
dbutils.widgets.text("llm_endpoint_name", "databricks-claude-3-7-sonnet", "LLM Endpoint Name")

# agent.py is an imported module that reads its settings from environment
# variables, so copy each widget value into the variable it expects.
os.environ.update(
    AGENT_CATALOG=dbutils.widgets.get("catalog"),
    AGENT_SCHEMA=dbutils.widgets.get("schema"),
    AGENT_LLM_ENDPOINT=dbutils.widgets.get("llm_endpoint_name"),
)

# Echo the effective configuration.
print(
    "\n".join(
        [
            "Agent configuration:",
            f"  Catalog: {os.environ['AGENT_CATALOG']}",
            f"  Schema: {os.environ['AGENT_SCHEMA']}",
            f"  LLM Endpoint: {os.environ['AGENT_LLM_ENDPOINT']}",
        ]
    )
)

In [None]:
%%writefile agent.py
from typing import Any, Generator, Optional, Sequence, Union

import mlflow
import json
from databricks_langchain import (
    ChatDatabricks,
    VectorSearchRetrieverTool,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
from mlflow.entities import SpanType, Document
import os

# Turn on MLflow's LangChain autologging so LangChain/LangGraph calls made by
# this module are traced automatically.
mlflow.langchain.autolog()

# Runtime configuration. The notebook that imports this module exports
# AGENT_CATALOG / AGENT_SCHEMA / AGENT_LLM_ENDPOINT; the fallbacks below apply
# whenever a variable is absent.
catalog = os.getenv("AGENT_CATALOG", "main")
schema = os.getenv("AGENT_SCHEMA", "default")
llm_endpoint_name = os.getenv("AGENT_LLM_ENDPOINT", "databricks-claude-3-7-sonnet")

# Fully qualified Unity Catalog name (<catalog>.<schema>.<function>) of the
# You.com search function that is exposed to the agent as a tool.
function_name = "you_com_search_function"
function_path = ".".join((catalog, schema, function_name))
# NOTE(review): presumably the UC connection used by the search function;
# it is not referenced elsewhere in this module.
connection = "you_com_connection"

## Helpers to trace a UC function call as an MLflow RETRIEVER span.
def _parse_you_com_hits(raw_response):
    """Return the list of You.com ``hits`` contained in a UC tool response.

    The UC tool function returns a JSON envelope whose ``"value"`` field holds
    the UC function's own return value. For the You.com search function that
    value is expected to be a JSON document serialized as a string (an
    already-parsed dict is accepted too) with a ``"hits"`` list.

    Args:
        raw_response: JSON string returned by the wrapped UC tool function.

    Returns:
        The list of hits; an empty list when the payload has no ``"hits"`` key
        or it is null.

    Raises:
        ValueError: if the envelope carries no usable ``"value"`` (for example
            when the UC call reported an error instead of a result), or if
            either JSON layer is malformed (``json.JSONDecodeError`` is a
            ``ValueError``).
    """
    envelope = json.loads(raw_response)
    payload = envelope.get("value") if isinstance(envelope, dict) else None
    if payload is None or payload == "":
        # json.loads("") would fail with an opaque decode error; report what
        # the envelope actually said (e.g. an "error" field) instead.
        detail = envelope.get("error") if isinstance(envelope, dict) else None
        raise ValueError(
            f"UC function returned no search payload: {detail or raw_response}"
        )
    if isinstance(payload, str):
        payload = json.loads(payload)
    return payload.get("hits") or []


def _hit_to_document(hit):
    """Convert one You.com hit into an MLflow ``Document``.

    Title, description and snippets are JSON-encoded into ``page_content``;
    the hit's URL becomes ``metadata["doc_uri"]``. Missing fields fall back to
    empty values instead of raising ``KeyError``.
    """
    content = {
        "title": hit.get("title", ""),
        "description": hit.get("description", ""),
        "snippets": hit.get("snippets", []),
    }
    return Document(
        page_content=json.dumps(content), metadata={"doc_uri": hit.get("url")}
    )


def make_retriever(structured_tool):
    """Wrap a UC function tool so each call is traced as a RETRIEVER span.

    The tool's ``func`` is replaced in place by a wrapper that runs the original
    function, parses the You.com hits from its JSON response, and returns them
    as a list of MLflow ``Document`` objects. That is the output shape RETRIEVER
    spans use and that retrieval scorers such as RetrievalGroundedness read.

    Args:
        structured_tool: a tool from ``UCFunctionToolkit.tools``; its ``func``
            attribute is mutated.

    Returns:
        The same ``structured_tool`` object, so it can be used in comprehensions.
    """
    def _make(func):
        # @mlflow.trace records the wrapper's inputs and its return value as the
        # span outputs, so no manual span.set_outputs() is needed. Naming the
        # span after the tool keeps traces readable (by default it would carry
        # the wrapper's own name, "apply").
        @mlflow.trace(name=structured_tool.name, span_type=SpanType.RETRIEVER)
        def apply(*args, **kwargs):
            hits = _parse_you_com_hits(func(*args, **kwargs))
            return [_hit_to_document(hit) for hit in hits]

        return apply

    structured_tool.func = _make(structured_tool.func)
    return structured_tool

# Register a DatabricksFunctionClient (kept in `client`) as the UC function client.
set_uc_function_client(client := DatabricksFunctionClient())

############################################
# Define your LLM endpoint and system prompt
############################################
# NOTE - Choose a supported model serving LLM on Databricks
llm = ChatDatabricks(endpoint=llm_endpoint_name)

# NOTE - Amend System Prompt to tailor the agent to a specific use case
system_prompt = """You are a helpful AI assistant with access to real-time information through web search.

Your capabilities:
- Access current, up-to-date information through web search
- Answer questions about recent events, breaking news, and real-time data
- Provide accurate information with citations to sources

When answering questions:
1. Use the search tool to find relevant, current information
2. Synthesize information from multiple search results when appropriate
3. Cite your sources by mentioning the URLs or titles from search results
4. If information is time-sensitive, mention the recency of the data
5. Be clear when information may have changed since your search

Always strive to provide accurate, helpful, and well-sourced responses."""

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html
###############################################################################
# UDFs in Unity Catalog can serve as agent tools; every UC tool is passed
# through make_retriever so its calls are traced as RETRIEVER spans.
uc_tool_names = [function_path]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
tools = list(map(make_retriever, uc_toolkit.tools))


#####################
## Define agent logic
#####################


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[Sequence[BaseTool], ToolNode],
    system_prompt: Optional[str] = None,
) -> CompiledGraph:
    """Build and compile a LangGraph loop of model calls and tool executions.

    The "agent" node calls ``model`` (with ``tools`` bound). If its reply
    requests tool calls, the "tools" node runs them and control returns to
    "agent"; otherwise the graph ends. When ``system_prompt`` is truthy it is
    prepended as a system message on every model call.
    """
    bound_model = model.bind_tools(tools)

    # Branch selector: loop into the tools while the newest message asks for
    # tool calls, otherwise finish.
    def should_continue(state: ChatAgentState):
        return "continue" if state["messages"][-1].get("tool_calls") else "end"

    preprocessor = (
        RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}]
            + state["messages"]
        )
        if system_prompt
        else RunnableLambda(lambda state: state["messages"])
    )
    model_runnable = preprocessor | bound_model

    # The `config` parameter name lets RunnableLambda forward the run config.
    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        return {"messages": [model_runnable.invoke(state, config)]}

    workflow = StateGraph(ChatAgentState)
    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {"continue": "tools", "end": END},
    )
    workflow.add_edge("tools", "agent")
    return workflow.compile()


class LangGraphChatAgent(ChatAgent):
    """MLflow ``ChatAgent`` adapter around a compiled LangGraph graph.

    Both entry points stream the graph with ``stream_mode="updates"`` and turn
    every message emitted by any node into MLflow chat-agent messages.
    """

    def __init__(self, agent: CompiledStateGraph):
        # The compiled graph, e.g. the result of create_tool_calling_agent().
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the graph to completion and return every message it produced.

        ``context`` and ``custom_inputs`` are part of the ChatAgent interface
        but are not used by this agent.
        """
        request = {"messages": self._convert_messages_to_dict(messages)}

        # Separate name so the `messages` argument is not shadowed.
        output_messages = []
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                output_messages.extend(
                    ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                )
        return ChatAgentResponse(messages=output_messages)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """Yield each message as a ``ChatAgentChunk`` as graph nodes emit updates.

        ``context`` and ``custom_inputs`` are accepted but unused.
        """
        request = {"messages": self._convert_messages_to_dict(messages)}
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                # Same tolerance as predict(): a node update without a
                # "messages" key contributes nothing instead of raising KeyError.
                for msg in node_data.get("messages", []):
                    yield ChatAgentChunk(delta=msg)


# Build the compiled LangGraph agent and wrap it in the MLflow ChatAgent
# adapter. Registering the wrapper with mlflow.models.set_model() marks it as
# the object MLflow loads back for inference.
agent = create_tool_calling_agent(llm, tools, system_prompt=system_prompt)
AGENT = LangGraphChatAgent(agent=agent)
mlflow.models.set_model(AGENT)