In [0]:
%pip install -r requirements.txt

In [0]:
# Restart the Python process so the packages installed by %pip above become importable.
dbutils.library.restartPython()

In [0]:
import mlflow
from mlflow.models import ModelConfig

# Turn on MLflow autologging for LangChain (including traces of the agent's calls).
mlflow.langchain.autolog()
# Agent settings (LLM endpoint, Genie space id/name/description) are read from config.yml.
# NOTE(review): per MLflow's ModelConfig docs, `development_config` is used while developing;
# a config supplied when the model is logged takes precedence — confirm for this deployment.
config = ModelConfig(development_config="config.yml")

In [0]:
from langchain_databricks import ChatDatabricks

# Create the llm: a chat model backed by the Databricks serving endpoint named in config.yml.
# NOTE(review): ChatDatabricks is imported from `langchain_databricks`, while GenieAgent below
# comes from `databricks_langchain`; consider importing both from the same package.
llm = ChatDatabricks(endpoint=config.get("llm_endpoint"))

In [0]:
import re

def clean_string(input_string: str) -> str:
    """Turn arbitrary text into a lowercase, underscore-separated identifier.

    Characters other than ASCII letters, digits and whitespace become ``_``, whitespace
    runs become ``_``, repeated underscores collapse to one, then leading/trailing
    underscores are stripped and the result is lowercased.

    Example: ``"  My--Genie  Agent "`` -> ``"my_genie_agent"``.
    """
    normalized = input_string
    # Applied in order: punctuation -> "_", whitespace runs -> "_", "__..." -> "_".
    for pattern in (r'[^a-zA-Z0-9\s]', r'\s+', r'_+'):
        normalized = re.sub(pattern, '_', normalized)
    return normalized.strip('_').lower()


# Read the Genie settings from config.yml and fail fast if any of them is missing or empty.
genie_space_id = config.get("genie_space_id")
_genie_agent_name = config.get("genie_agent_name")
genie_space_description = config.get("genie_space_description")

assert genie_space_id, f"Configure the genie_space_id in config.yml it is: {genie_space_id}"
assert _genie_agent_name, f"Configure the genie_agent_name in config.yml it is: {_genie_agent_name}"
# Check the description itself (not genie_space_id) so an empty description is caught here
# rather than surfacing later as a broken tool definition.
assert genie_space_description, f"Configure the genie_space_description in config.yml it is: {genie_space_description}"

# Tool name shown to the LLM: the configured name normalized to lowercase snake_case.
genie_agent_name = clean_string(_genie_agent_name)
# A name made only of punctuation/whitespace would clean to "" and yield a nameless tool.
assert genie_agent_name, f"genie_agent_name must contain at least one ASCII letter or digit, it is: {_genie_agent_name!r}"

In [0]:
from langchain_core.tools import BaseTool, StructuredTool, tool
from langchain_core.callbacks.manager import CallbackManagerForToolRun
from databricks_langchain.genie import GenieAgent
from pydantic import BaseModel, Field
import typing as t

class GenieAgentInput(BaseModel):
    # Argument schema the LLM fills in when it calls the Genie tool.
    question: str = Field(description="question to ask the agent")
    # The field description asks the LLM to pass 'No history' when there is nothing to summarize.
    summarized_chat_history: str = Field(
        description=(
            "summarized chat history to provide the agent context of what may have been "
            "talked about. Say 'No history' if there is no history to provide."
        )
    )


# Genie agent that GenieQuestionTool forwards questions to. `_genie_agent_name` is the raw
# config value (the same one config.get("genie_agent_name") returns), read in the config cell.
genie_agent = GenieAgent(
    genie_space_id,
    _genie_agent_name,
    description=genie_space_description,
)


class GenieQuestionTool(BaseTool):
    """LangChain tool that forwards a question, plus summarized chat history, to the Genie agent.

    ``name`` is ``genie_agent_name`` (the configured name passed through ``clean_string``) and
    ``description`` is ``genie_space_description`` from config.yml, so the LLM decides whether
    to call this tool based on the Genie space's own description.
    """

    # Every field override needs a type annotation: langchain_core >= 0.3 tools are pydantic v2
    # models, which reject an un-annotated override such as `name = ...` at class creation.
    name: str = genie_agent_name
    description: str = genie_space_description
    args_schema: t.Type[BaseModel] = GenieAgentInput

    def _run(
        self, question: str, summarized_chat_history: str, run_manager: t.Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Send ``question`` (preceded by ``summarized_chat_history``) to Genie and return its answer.

        Args:
            question: the question to forward to Genie.
            summarized_chat_history: context about the prior conversation, sent ahead of the question.
            run_manager: LangChain callback manager; not used by this tool.

        Returns:
            The text content of the last message in Genie's response. If the response does not
            have that shape, ``str(response)`` is returned so the tool always yields a string.
        """
        response = genie_agent.invoke({
            "messages": [
                {
                    "role": "user",
                    "content": f"ChatHistory: {summarized_chat_history}\nQuestion: {question}",
                }
            ]
        })
        # NOTE(review): assumes GenieAgent returns a LangGraph-style state
        # {"messages": [..., <message with .content>]} — confirm against databricks_langchain.
        # Returning the raw dict would hand the LLM a repr of message objects instead of text.
        messages = response.get("messages") if isinstance(response, dict) else None
        if messages:
            content = getattr(messages[-1], "content", None)
            if isinstance(content, str):
                return content
        return str(response)


# The tool list shared by the agent below.
tools = [GenieQuestionTool()]

In [0]:
from typing import (
    Annotated,
    Optional,
    Sequence,
    TypedDict,
    Union,
)

from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import (
    BaseMessage,
    SystemMessage,
)
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool

from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_executor import ToolExecutor
from langgraph.prebuilt.tool_node import ToolNode


# We create the AgentState that is passed between graph nodes.
# It holds only the conversation's list of messages.
# NOTE(review): only create_tool_calling_agent (below) uses this state type.
class AgentState(TypedDict):
    """The state of the agent."""

    # `add_messages` is LangGraph's message reducer: messages returned by a node are merged
    # into this list (appended, or replacing an existing message with the same id) rather
    # than overwriting it.
    messages: Annotated[Sequence[BaseMessage], add_messages]


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolExecutor, Sequence[BaseTool]],
    agent_prompt: Optional[str] = None,
) -> CompiledGraph:
    """Build and compile a minimal tool-calling loop as a LangGraph ``StateGraph``.

    Graph shape: the entry node "agent" calls the model. If the model's reply requests tool
    calls, control goes to the "tools" node and then back to "agent"; otherwise the graph ends.

    NOTE(review): nothing in this notebook calls this helper; the served agent is created
    with ``create_react_agent`` instead.

    Args:
        model: chat model; ``bind_tools(tools)`` is called on it.
        tools: tools bound to the model and executed by the ``ToolNode``.
        agent_prompt: optional system prompt, prepended as a ``SystemMessage`` on every model call.

    Returns:
        The compiled graph.
    """
    tool_model = model.bind_tools(tools)

    # Optional system prompt, prepended to the running conversation before each model call.
    system_message = SystemMessage(content=agent_prompt) if agent_prompt else None
    preprocessor = (
        RunnableLambda(lambda state: [system_message] + state["messages"])
        if system_message is not None
        else RunnableLambda(lambda state: state["messages"])
    )
    model_runnable = preprocessor | tool_model

    # Router: keep cycling through the tools node while the latest reply asks for tools.
    def should_continue(state: AgentState):
        return "continue" if state["messages"][-1].tool_calls else "end"

    # Node: run the (prompted) model and hand its reply back to be merged into the state.
    def call_model(
        state: AgentState,
        config: RunnableConfig,
    ):
        return {"messages": [model_runnable.invoke(state, config)]}

    graph = StateGraph(AgentState)
    graph.add_node("agent", RunnableLambda(call_model))
    graph.add_node("tools", ToolNode(tools))
    graph.set_entry_point("agent")
    # After "agent": "continue" -> run the tools, "end" -> finish (END is LangGraph's terminal node).
    graph.add_conditional_edges(
        "agent",
        should_continue,
        {"continue": "tools", "end": END},
    )
    # Tool results always flow back to the agent for another turn.
    graph.add_edge("tools", "agent")
    return graph.compile()

In [0]:
import json
from typing import Iterator, Dict, Any

from langgraph.prebuilt import create_react_agent
from langchain_core.runnables import RunnableGenerator
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    ToolMessage,
    MessageLikeRepresentation,
)
from mlflow.langchain.output_parsers import ChatCompletionsOutputParser

# The agent that is actually served (it feeds `chain` below): LangGraph's prebuilt ReAct
# agent over the Genie tool, with a fixed system instruction.
# NOTE(review): newer langgraph releases deprecate `state_modifier` in favour of `prompt`;
# check against the version pinned in requirements.txt before upgrading.
agent = create_react_agent(
    llm,
    tools,
    state_modifier="You are a helpful assistant. Make sure to use tool for information.",
)

def stringify_tool_call(tool_call: Dict[str, Any]) -> str:
    """
    Convert a raw tool call into the ``<tool_call>...</tool_call>`` string the playground UI expects.

    Args:
        tool_call: tool-call mapping; its ``id``, ``name`` and ``args`` entries are read
            (missing ``id``/``name`` become ``null``, missing ``args`` becomes ``{}``).

    Returns:
        ``<tool_call>{json}</tool_call>``, where ``arguments`` is itself a JSON-encoded string.
        Falls back to ``str(tool_call)`` when the call is not dict-like or its arguments
        are not JSON-serializable.
    """
    try:
        request = json.dumps(
            {
                "id": tool_call.get("id"),
                "name": tool_call.get("name"),
                # Nested dumps: arguments are emitted as a JSON string, not a JSON object.
                "arguments": json.dumps(tool_call.get("args", {})),
            },
            indent=2,
        )
    except (AttributeError, TypeError, ValueError):
        # AttributeError: no `.get` (not a mapping); TypeError/ValueError: json.dumps failed.
        # Narrower than a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
        return str(tool_call)
    return f"<tool_call>{request}</tool_call>"
    
def stringify_tool_result(tool_msg: ToolMessage) -> str:
    """
    Convert a ToolMessage into the ``<tool_call_result>...</tool_call_result>`` string the playground UI expects.

    Args:
        tool_msg: the tool's result message; its ``tool_call_id`` and ``content`` are read.

    Returns:
        ``<tool_call_result>{"id": ..., "content": ...}</tool_call_result>`` (indented JSON),
        or ``str(tool_msg)`` when those attributes are missing or the content is not
        JSON-serializable.
    """
    try:
        result = json.dumps(
            {"id": tool_msg.tool_call_id, "content": tool_msg.content}, indent=2
        )
    except (AttributeError, TypeError, ValueError):
        # Narrower than a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
        return str(tool_msg)
    return f"<tool_call_result>{result}</tool_call_result>"


def parse_message(msg) -> str:
    """Parse different message types into their string representations"""
    # tool call result
    if isinstance(msg, ToolMessage):
        return stringify_tool_result(msg)
    # tool call
    elif isinstance(msg, AIMessage) and msg.tool_calls:
        tool_call_results = [stringify_tool_call(call) for call in msg.tool_calls]
        return "".join(tool_call_results)
    # normal HumanMessage or AIMessage (reasoning or final answer)
    elif isinstance(msg, (AIMessage, HumanMessage)):
        return msg.content
    else:
        print(f"Unexpected message type: {type(msg)}")
        return str(msg)

def wrap_output(stream: Iterator[MessageLikeRepresentation]) -> Iterator[str]:
    """
    Process and yield formatted outputs from the message stream.

    The invoke and stream langchain functions produce different output formats, and this
    function handles both:

    * invoke(): each event carries a top-level ``"messages"`` list.
    * stream(): each event maps a node name to that node's update, whose values are
      expected to be message lists.

    Each message is rendered with ``parse_message`` and followed by a blank line. A
    non-list update value is printed as a warning and yielded as ``str(value)``.
    """
    for event in stream:
        # the agent was called with invoke()
        if "messages" in event:
            for msg in event["messages"]:
                yield parse_message(msg) + "\n\n"
        # the agent was called with stream()
        else:
            for node in event:
                for key, messages in event[node].items():
                    if isinstance(messages, list):
                        for msg in messages:
                            yield parse_message(msg) + "\n\n"
                    else:
                        # f-string: without the prefix the placeholders were printed literally.
                        print(f"Unexpected value {messages} for key {key}. Expected a list of `MessageLikeRepresentation`'s")
                        yield str(messages)


In [0]:
# Servable pipeline: agent output -> wrap_output (messages rendered as strings)
# -> ChatCompletionsOutputParser (presumably packages the text as a chat-completions response — confirm).
# TODO: modify wrap input to make this simpler
chain = agent | RunnableGenerator(wrap_output) | ChatCompletionsOutputParser()

In [0]:
# Manual smoke test: uncomment to send a sample question through the full chain.
# chain.invoke({"messages": [{"role": "user", "content": "Which platforms have the highest number of churned users based on the last event before churning?"}]})

In [0]:
# Tell MLflow that `chain` is the model object defined by this notebook (models-from-code logging).
mlflow.models.set_model(chain)