In [0]:
%pip install -r requirements.txt

In [0]:
# Restart the Python process so the packages just installed with %pip are importable.
dbutils.library.restartPython()

In [0]:
import mlflow
from mlflow.models import ModelConfig

# Enable MLflow autologging for LangChain so chain/agent runs below are recorded.
mlflow.langchain.autolog()
# Load the agent settings from config.yml. `development_config` is presumably the
# file used while iterating in this notebook (a logged model may be given its own
# config at load time) — confirm against the MLflow ModelConfig docs.
config = ModelConfig(development_config="config.yml")

In [0]:
# Use the same integration package as the Genie agent below (`databricks_langchain`);
# `langchain_databricks` is its deprecated predecessor and need not be installed.
from databricks_langchain import ChatDatabricks

# Chat model served from the Databricks serving endpoint named in config.yml.
# Validate it up front, consistent with the Genie settings checks.
llm_endpoint = config.get("llm_endpoint")
assert llm_endpoint, f"Configure the llm_endpoint in config.yml it is: {llm_endpoint}"
llm = ChatDatabricks(endpoint=llm_endpoint)

In [0]:
import re

def clean_string(input_string: str) -> str:
    """Normalise free text into a lowercase, underscore-separated identifier.

    Any character outside ``[a-zA-Z0-9]`` (punctuation, whitespace, non-ASCII
    letters, underscores) ends up as ``_``, runs of ``_`` collapse to one, and
    leading/trailing underscores are removed before lowercasing.
    Example: ``"My Genie Agent!"`` -> ``"my_genie_agent"``.
    """
    substitutions = (
        (r'[^a-zA-Z0-9\s]', '_'),  # symbols / non-ASCII -> underscore (whitespace kept for now)
        (r'\s+', '_'),             # each whitespace run -> one underscore
        (r'_+', '_'),              # collapse repeated underscores
    )
    normalised = input_string
    for pattern, replacement in substitutions:
        normalised = re.sub(pattern, replacement, normalised)
    return normalised.strip('_').lower()


# Read the Genie settings from config.yml and fail fast if any are missing/empty.
genie_space_id = config.get("genie_space_id")
_genie_agent_name = config.get("genie_agent_name")
genie_space_description = config.get("genie_space_description")

assert genie_space_id, f"Configure the genie_space_id in config.yml it is: {genie_space_id}"
assert _genie_agent_name, f"Configure the genie_agent_name in config.yml it is: {_genie_agent_name}"
# Check the description itself (this previously re-checked genie_space_id, so a
# missing description was never caught).
assert genie_space_description, f"Configure the genie_space_description in config.yml it is: {genie_space_description}"

# The cleaned name is used as the LangChain tool name; presumably it must be a
# plain identifier for tool-calling APIs, hence the normalisation.
genie_agent_name = clean_string(_genie_agent_name)
# A name made only of symbols/whitespace would clean to "" and yield an unusable tool name.
assert genie_agent_name, f"genie_agent_name must contain at least one ASCII letter or digit, got: {_genie_agent_name!r}"

In [0]:
from langchain_core.tools import BaseTool, StructuredTool, tool
from langchain_core.callbacks.manager import CallbackManagerForToolRun
from databricks_langchain.genie import GenieAgent
from pydantic import BaseModel, Field
import typing as t

# Argument schema for the Genie tool (used as GenieQuestionTool.args_schema).
# Descriptions live in the Field(...) definitions and no class docstring is added,
# since the model's schema (including any docstring) is presumably what LangChain
# exposes to the LLM for the tool call — confirm before adding one.
class GenieAgentInput(BaseModel):
    # The question to forward to the Genie space.
    question: str = Field(description="question to ask the agent")
    # Free-text context; GenieQuestionTool prefixes it to the question as "ChatHistory: ...".
    summarized_chat_history: str = Field(description="summarized chat history to provide the agent context of what may have been talked about. Say 'No history' if there is no history to provide.")


# Runnable wrapping the configured Genie space; GenieQuestionTool delegates to it.
# Reuses the agent name already read (and validated) from config.yml above.
genie_agent = GenieAgent(
    genie_space_id,
    _genie_agent_name,
    description=genie_space_description,
)


def _genie_response_to_text(response: t.Any) -> str:
    """Return the text of the last message in a Genie agent response.

    Assumes (TODO confirm against the installed databricks_langchain version)
    that ``GenieAgent.invoke`` returns ``{"messages": [...]}`` whose items are
    LangChain messages or ``{"role": ..., "content": ...}`` dicts. Any other
    shape is stringified so the tool always returns a ``str``.
    """
    messages = response.get("messages") if isinstance(response, dict) else None
    if not messages:
        return str(response)
    last = messages[-1]
    content = last.get("content") if isinstance(last, dict) else getattr(last, "content", last)
    return content if isinstance(content, str) else str(content)


class GenieQuestionTool(BaseTool):
    # LangChain tool that forwards a question, prefixed with caller-supplied chat
    # context, to the Genie space wrapped by `genie_agent`.
    #
    # BaseTool is a pydantic model: `name` and `description` override base-class
    # fields and must be annotated (pydantic v2 rejects un-annotated overrides).
    name: str = genie_agent_name
    description: str = genie_space_description
    args_schema: t.Type[BaseModel] = GenieAgentInput

    def _run(
        self, question: str, summarized_chat_history: str, run_manager: t.Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Ask Genie ``question`` with ``summarized_chat_history`` as context.

        Args:
            question: The question for the Genie space.
            summarized_chat_history: Context prepended as ``"ChatHistory: ..."``.
            run_manager: LangChain callback manager (unused).

        Returns:
            The text of Genie's final message (see ``_genie_response_to_text``),
            rather than the raw response dict, so the LLM sees only the answer.
        """
        response = genie_agent.invoke({
            "messages": [
                {
                    "role": "user",
                    "content": f"ChatHistory: {summarized_chat_history}\nQuestion: {question}",
                }
            ]
        })
        return _genie_response_to_text(response)


# Tools made available to the tool-calling agent.
tools = [GenieQuestionTool()]

In [0]:
from langchain import hub

# Get the prompt to use - can be replaced with any prompt that includes variables "agent_scratchpad" and "input"!
# (For create_tool_calling_agent, "agent_scratchpad" should be a MessagesPlaceholder.)
# NOTE(review): hub.pull fetches the prompt from LangChain Hub whenever this file
# runs — presumably also when the logged model is loaded for serving. Consider
# pinning a commit or defining the prompt inline if serving has no network access.
prompt = hub.pull("hwchase17/openai-tools-agent")
# Print the prompt's messages, then display the prompt object itself (notebook output only).
prompt.pretty_print()
prompt

In [0]:
from langchain.agents import AgentExecutor, create_tool_calling_agent

# Agent that lets `llm` decide when to call the Genie tool, driven by `prompt`.
agent = create_tool_calling_agent(llm, tools, prompt)

# Executor that runs the agent/tool loop and returns a dict with an "output" key;
# verbose=True logs each intermediate step.
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

In [0]:
from mlflow.langchain.output_parsers import ChatCompletionsOutputParser

# Wraps the agent's final answer string in MLflow's chat-completions response format
# (presumably the format the served model should return — confirm with the logged signature).
output_parser = ChatCompletionsOutputParser()

In [0]:
from langchain_core.runnables import RunnableLambda

def agent_executor_to_just_response(inp):
    """Return only the final answer from an AgentExecutor result.

    AgentExecutor returns a dict like ``{"input": "...", "output": "..."}``;
    the downstream output parser only needs the ``"output"`` string.
    """
    return inp["output"]


def pre_process_input(inp):
    """Adapt the chain's input to the dict AgentExecutor expects.

    AgentExecutor only needs an ``"input"`` key; the tool-calling agent fills
    ``agent_scratchpad`` itself from its intermediate steps.

    - A plain string (e.g. ``chain.invoke("...")``) becomes ``{"input": inp}``.
    - A chat-completions style payload ``{"messages": [...]}`` (presumably what a
      serving endpoint forwards — confirm) uses the last message's content as
      ``"input"`` and any earlier messages as ``"chat_history"``, which the prompt
      can use if it has a ``chat_history`` placeholder (extra keys are ignored
      otherwise). Previously such a payload was rendered into the prompt as a
      raw dict.
    - Anything else (including an empty ``messages`` list) is wrapped unchanged.
    """
    messages = inp.get("messages") if isinstance(inp, dict) else None
    if not isinstance(messages, list) or not messages:
        return {"input": inp}

    def _content(message):
        # Messages may be {"role", "content"} dicts or LangChain message objects.
        if isinstance(message, dict):
            return message.get("content")
        return getattr(message, "content", message)

    processed = {"input": _content(messages[-1])}
    if len(messages) > 1:
        processed["chat_history"] = messages[:-1]
    return processed

# Pipeline: pre_process_input -> agent_executor -> agent_executor_to_just_response
# (pulls out "output") -> output_parser.
chain = RunnableLambda(pre_process_input) | agent_executor | RunnableLambda(agent_executor_to_just_response) | output_parser

In [0]:
# chain.invoke("Which platforms have the highest number of churned users based on the last event before churning?")

In [0]:
# Mark `chain` as the model MLflow should use when this notebook is logged as code.
mlflow.models.set_model(chain)