# Agent notebook

This is an auto-generated notebook created by an AI Playground export. The export generated three files in the same folder:
- [**agent**]($./agent): contains the code to build the agent.
- [config.yml]($./config.yml): contains the configurations.
- [driver]($./driver): logs, evaluates, registers, and deploys the agent.

This notebook uses Mosaic AI Agent Framework ([AWS](https://docs.databricks.com/en/generative-ai/retrieval-augmented-generation.html) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/retrieval-augmented-generation)) to recreate your agent from the AI Playground. It defines a LangChain agent that has access to the tools from the source Playground session.

Use this notebook to iterate on and modify the agent. For example, you could add more tools or change the system prompt.

 **_NOTE:_**  This notebook uses LangChain, but AI Agent Framework is also compatible with other agent frameworks such as Pyfunc and LlamaIndex.

## Prerequisites

- Address all `TODO`s in this notebook.
- Review the contents of [config.yml]($./config.yml) as it defines the tools available to your agent, the LLM endpoint, and the agent prompt.
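
As a quick pre-flight check, a sketch like the following can confirm that a loaded configuration contains the keys this notebook reads. The key names come from the `config.get(...)` calls in this notebook; the helper `missing_config_keys` and the example values are hypothetical and only meant as an illustration.

```python
from typing import Any, Dict, List

# Keys this notebook reads without a fallback. `agent_prompt` is left out
# because the notebook treats it as optional.
REQUIRED_KEYS = ["llm_endpoint", "uc_functions", "warehouse_id"]


def missing_config_keys(cfg: Dict[str, Any]) -> List[str]:
    """Return the required keys that are absent or set to None in cfg."""
    return [key for key in REQUIRED_KEYS if cfg.get(key) is None]


# Hypothetical values, for illustration only.
example_cfg = {
    "llm_endpoint": "my-llm-endpoint",
    "uc_functions": ["main.default.my_function"],
}
print(missing_config_keys(example_cfg))
```

An empty `uc_functions` list is not reported as missing here, since an agent without UC function tools may still be a valid setup.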

## Next steps

After testing and iterating on your agent in this notebook, go to the auto-generated [driver]($./driver) notebook in this folder to log, register, and deploy the agent.

In [0]:
%pip install -U -qqqq mlflow-skinny langchain==0.2.16 langgraph-checkpoint==1.0.12 langchain_core langchain-community==0.2.16 langgraph==0.2.16 pydantic
dbutils.library.restartPython()

Note: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.


## Import and setup

Use `mlflow.langchain.autolog()` to set up [MLflow traces](https://docs.databricks.com/en/mlflow/mlflow-tracing.html).

In [0]:
import mlflow
from mlflow.models import ModelConfig

mlflow.langchain.autolog()
config = ModelConfig(development_config="config.yml")

## Define the chat model and tools
Create a LangChain chat model that supports [LangGraph tool calling](https://langchain-ai.github.io/langgraph/how-tos/tool-calling/).

Change the tools your agent has access to by editing the `uc_functions` list in [config.yml]($./config.yml). Tools that are not UC functions can be defined in this notebook. See [LangChain - How to create tools](https://python.langchain.com/v0.2/docs/how_to/custom_tools/) and [LangChain - Using built-in tools](https://python.langchain.com/v0.2/docs/how_to/tools_builtin/).

In [0]:
from langchain_community.chat_models import ChatDatabricks
from langchain_community.tools.databricks import UCFunctionToolkit

# Create the llm
llm = ChatDatabricks(endpoint=config.get("llm_endpoint"))

uc_functions = config.get("uc_functions")

uc_function_tools = (
    UCFunctionToolkit(warehouse_id=config.get("warehouse_id"))
    .include(*uc_functions)
    .get_tools()
)


# TODO: fill in implementations for functions search_databricks_index

## Output parsers
Databricks interfaces, such as the AI Playground, can optionally display pretty-printed tool calls.

Use the following helper functions to parse the LLM's output into the expected format.

In [0]:
from typing import Iterator, Dict, Any
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    ToolMessage,
    MessageLikeRepresentation,
)

import json

def stringify_tool_call(tool_call: Dict[str, Any]) -> str:
    """
    Convert a raw tool call into a formatted string that the playground UI expects if there is enough information in the tool_call
    """
    try:
        request = json.dumps(
            {
                "id": tool_call.get("id"),
                "name": tool_call.get("name"),
                "arguments": json.dumps(tool_call.get("args", {})),
            },
            indent=2,
        )
        return f"<tool_call>{request}</tool_call>"
    except Exception:
        return str(tool_call)


def stringify_tool_result(tool_msg: ToolMessage) -> str:
    """
    Convert a ToolMessage into a formatted string that the playground UI expects if there is enough information in the ToolMessage
    """
    try:
        result = json.dumps(
            {"id": tool_msg.tool_call_id, "content": tool_msg.content}, indent=2
        )
        return f"<tool_call_result>{result}</tool_call_result>"
    except Exception:
        return str(tool_msg)


def parse_message(msg) -> str:
    """Parse different message types into their string representations"""
    # tool call result
    if isinstance(msg, ToolMessage):
        return stringify_tool_result(msg)
    # tool call
    elif isinstance(msg, AIMessage) and msg.tool_calls:
        tool_call_results = [stringify_tool_call(call) for call in msg.tool_calls]
        return "".join(tool_call_results)
    # normal HumanMessage or AIMessage (reasoning or final answer)
    elif isinstance(msg, (AIMessage, HumanMessage)):
        return msg.content
    else:
        print(f"Unexpected message type: {type(msg)}")
        return str(msg)


def wrap_output(stream: Iterator[MessageLikeRepresentation]) -> Iterator[str]:
    """
    Process and yield formatted outputs from the message stream.
    The invoke and stream langchain functions produce different output formats.
    This function handles both cases.
    """
    for event in stream:
        # the agent was called with invoke()
        if "messages" in event:
            for msg in event["messages"]:
                yield parse_message(msg) + "\n\n"
        # the agent was called with stream()
        else:
            for node in event:
                for key, messages in event[node].items():
                    if isinstance(messages, list):
                        for msg in messages:
                            yield parse_message(msg) + "\n\n"
                    else:
                        print(f"Unexpected value {messages} for key {key}. Expected a list of `MessageLikeRepresentation`'s")
                        yield str(messages)
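
To make the target format concrete, the sketch below reproduces the strings that `stringify_tool_call` and `stringify_tool_result` build, using only plain dictionaries and the standard library. The tool call ID, tool arguments, and result text are hypothetical.

```python
import json

# Hypothetical tool call, shaped like the dicts in AIMessage.tool_calls.
tool_call = {
    "id": "call_1",
    "name": "search_databricks_index",
    "args": {"query": "laptop"},
}

call_payload = json.dumps(
    {
        "id": tool_call.get("id"),
        "name": tool_call.get("name"),
        # The arguments are serialized to a JSON string before embedding.
        "arguments": json.dumps(tool_call.get("args", {})),
    },
    indent=2,
)
formatted_call = f"<tool_call>{call_payload}</tool_call>"

# Hypothetical tool result, keyed by the ID of the call it answers.
result_payload = json.dumps(
    {"id": tool_call["id"], "content": "3 rows found"}, indent=2
)
formatted_result = f"<tool_call_result>{result_payload}</tool_call_result>"

print(formatted_call)
print(formatted_result)
```

Note that `arguments` is a JSON string nested inside the outer JSON object, so a consumer has to decode it a second time to recover the argument dictionary.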

## Create the agent
Here we provide a simple graph that uses the model and tools defined by [config.yml]($./config.yml). This graph is adapted from [this LangGraph guide](https://langchain-ai.github.io/langgraph/how-tos/react-agent-from-scratch/).


To further customize your LangGraph agent, you can refer to:
* [LangGraph - Quick Start](https://langchain-ai.github.io/langgraph/tutorials/introduction/) for explanations of the concepts used in this LangGraph agent
* [LangGraph - How-to Guides](https://langchain-ai.github.io/langgraph/how-tos/) to expand the functionality of your agent
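
Before reading the graph definition, it may help to see the control flow it encodes: an agent step, a check for tool calls, a tools step, and a loop back to the agent. The following is a plain-Python sketch of that loop and does not use LangGraph. The function `run_agent_loop`, the `max_steps` guard, and the stand-in model and tool are all hypothetical.

```python
from typing import Any, Callable, Dict, List

Message = Dict[str, Any]


def run_agent_loop(
    call_model: Callable[[List[Message]], Message],
    tools: Dict[str, Callable[..., str]],
    messages: List[Message],
    max_steps: int = 5,  # assumption: a simple guard against endless loops
) -> List[Message]:
    """Alternate between a model step and a tools step until no tool calls remain."""
    for _ in range(max_steps):
        reply = call_model(messages)  # corresponds to the "agent" node
        messages = messages + [reply]
        tool_calls = reply.get("tool_calls") or []
        if not tool_calls:  # corresponds to should_continue returning "end"
            break
        for call in tool_calls:  # corresponds to the "tools" node
            result = tools[call["name"]](**call["args"])
            messages = messages + [{"role": "tool", "content": result}]
    return messages


# Hypothetical stand-in for the LLM: requests one tool call, then answers.
def fake_model(history: List[Message]) -> Message:
    if any(m["role"] == "tool" for m in history):
        return {"role": "assistant", "content": "Done."}
    return {
        "role": "assistant",
        "content": "",
        "tool_calls": [{"name": "echo", "args": {"text": "hi"}}],
    }


history = run_agent_loop(
    fake_model,
    {"echo": lambda text: text.upper()},
    [{"role": "user", "content": "Say hi"}],
)
print([m["role"] for m in history])
```

The graph in the next cell expresses the same loop declaratively, with a conditional edge out of the agent node and an unconditional edge from the tools node back to the agent.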


In [0]:
from typing import (
    Annotated,
    Optional,
    Sequence,
    TypedDict,
    Union,
)

from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import (
    BaseMessage,
    SystemMessage,
)
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool

from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_executor import ToolExecutor
from langgraph.prebuilt.tool_node import ToolNode

from pyspark.sql import SparkSession  
from pyspark.sql.functions import col  

class DataFetchingAgent:
    """Fetches rows whose `content` column contains a query string."""

    def __init__(self, catalog: str, schema: str, index_name: str):
        self.catalog = catalog
        self.schema = schema
        self.index_name = index_name
        self.spark = (
            SparkSession.builder.appName("DataFetchingAgent").getOrCreate()
        )

    def fetch_data(self, query: str):
        try:
            # Assumes the index is registered as a table in the metastore,
            # for example databrickshackathon.default.spentinsights.
            index_path = f"{self.catalog}.{self.schema}.{self.index_name}"
            df = self.spark.table(index_path).filter(col("content").contains(query))
            return df.collect()
        except Exception as e:
            return f"Error fetching data: {str(e)}"

# We create the AgentState that we will pass around
# This simply involves a list of messages
class AgentState(TypedDict):
    """The state of the agent."""

    messages: Annotated[Sequence[BaseMessage], add_messages]


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolExecutor, Sequence[BaseTool]],
    agent_prompt: Optional[str] = None,
) -> CompiledGraph:
    model = model.bind_tools(tools)

    # Define the function that determines which node to go to
    def should_continue(state: AgentState):
        messages = state["messages"]
        last_message = messages[-1]
        # If there is no function call, then we finish
        if not last_message.tool_calls:
            return "end"
        else:
            return "continue"

    if agent_prompt:
        system_message = SystemMessage(content=agent_prompt)
        preprocessor = RunnableLambda(
            lambda state: [system_message] + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    # Define the function that calls the model
    def call_model(
        state: AgentState,
        config: RunnableConfig,
    ):
        response = model_runnable.invoke(state, config)
        return {"messages": [response]}

    workflow = StateGraph(AgentState)

    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ToolNode(tools))

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        # First, we define the start node. We use agent.
        # This means these are the edges taken after the agent node is called.
        "agent",
        # Next, we pass in the function that will determine which node is called next.
        should_continue,
        # The mapping below will be used to determine which node to go to
        {
            # If tools, then we call the tool node.
            "continue": "tools",
            # END is a special node marking that the graph should finish.
            "end": END,
        },
    )
    # We now add an unconditional edge from tools to agent.
    workflow.add_edge("tools", "agent")

    return workflow.compile()

In [0]:
from langchain_core.runnables import RunnableGenerator
from mlflow.langchain.output_parsers import ChatCompletionsOutputParser

# Create the agent with the system message if it exists
try:
    agent_prompt = config.get("agent_prompt")
    agent_with_raw_output = create_tool_calling_agent(
        llm, uc_function_tools, agent_prompt=agent_prompt
    )
except KeyError:
    agent_with_raw_output = create_tool_calling_agent(llm, uc_function_tools)
agent = agent_with_raw_output | RunnableGenerator(wrap_output) | ChatCompletionsOutputParser()

## Test the agent

Interact with the agent to test its output. Because this notebook calls `mlflow.langchain.autolog()`, you can view the trace for each step the agent takes.

Replace this placeholder input with an appropriate domain-specific example for your agent.

In [0]:
# TODO: replace this placeholder input example with an appropriate domain-specific example for your agent
for event in agent.stream({"messages": [{"role": "user", "content": "What is the cheapest laptop?"}]}):
    print(event, "---" * 20 + "\n")

{'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': "The cheapest laptop can vary depending on several factors such as the retailer, location, and availability. However, I can give you some general ideas on what to expect.\n\n**Under $200:**\n\n* Acer Aspire 3: A basic laptop with an Intel Celeron processor, 4GB of RAM, and a 64GB eMMC storage. ($150-$200)\n* Lenovo IdeaPad 130S: A compact laptop with an Intel Celeron processor, 4GB of RAM, and a 64GB eMMC storage. ($150-$200)\n* HP Stream 14: A budget-friendly laptop with an Intel Celeron processor, 4GB of RAM, and a 64GB eMMC storage. ($150-$200)\n\n**Under $300:**\n\n* Acer Aspire E 15: A budget laptop with an Intel Core i3 processor, 6GB of RAM, and a 256GB hard drive. ($250-$300)\n* Lenovo IdeaPad 330S: A slim laptop with an Intel Core i3 processor, 4GB of RAM, and a 256GB hard drive. ($250-$300)\n* Dell Inspiron 15 3000: A budget laptop with an Intel Core i3 processor, 4GB of RAM, and a 256GB hard drive. ($250-$

Trace(request_id=tr-f8537f0adb6e4fbb83cd891936113583)
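
Each streamed event above is a dictionary with a `choices` list, so printing the raw event is hard to read. A small helper along these lines can pull out just the assistant text. The function `extract_content` and the sample chunk are hypothetical; the sample only mirrors the keys visible in the output above.

```python
from typing import Any, Dict


def extract_content(chunk: Dict[str, Any]) -> str:
    """Return the assistant text from one chat-completions-style chunk."""
    choices = chunk.get("choices") or []
    if not choices:
        return ""
    return choices[0].get("message", {}).get("content", "")


# Hypothetical chunk with the same keys as the output above.
chunk = {
    "choices": [
        {"index": 0, "message": {"role": "assistant", "content": "Hello"}}
    ]
}
print(extract_content(chunk))
```
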

In [0]:
mlflow.models.set_model(agent)

## Next steps

You can rerun the cells above to iterate and test the agent.

Go to the auto-generated [driver]($./driver) notebook in this folder to log, register, and deploy the agent.