In [0]:
%pip install -r requirements.txt
# Installs the packages listed in requirements.txt for this notebook.

In [0]:
# Restart the Python process. On Databricks this is the documented follow-up to
# %pip so that freshly installed package versions are the ones imported below.
dbutils.library.restartPython()

In [0]:
import mlflow
from mlflow.models import ModelConfig

# Turn on MLflow autologging for LangChain before any chain objects are built.
mlflow.langchain.autolog()
# Configuration object backed by config.yml (passed as the development config);
# every `config.get(...)` call further down reads from it.
config = ModelConfig(development_config="config.yml")

In [0]:
from databricks.sdk import WorkspaceClient
from typing import Optional

def _config_get_or_none(key: str) -> Optional[str]:
    """Return the top-level config value stored under ``key``, or None if it is absent.

    Only ``KeyError`` (what ``ModelConfig.get`` raises for a missing key) is
    treated as "not configured". Any other failure propagates instead of being
    silently turned into None.
    """
    try:
        return config.get(key)
    except KeyError:
        return None


# Both settings are optional, e.g. warehouse_name = "Starter Warehouse UC".
warehouse_name = _config_get_or_none("warehouse_name")

w = WorkspaceClient()

# Put the id in the config, or leave it out and let the warehouse be found by name.
config_warehouse_id = _config_get_or_none("warehouse_id")

def get_warehouse_id(warehouse_id: Optional[str], warehouse_name: Optional[str]) -> str:
    """Resolve the id of the SQL warehouse to use.

    Resolution order:
      1. ``warehouse_id`` when it is not None (returned unchanged, no API call).
      2. The id of an existing warehouse whose name equals ``warehouse_name``.
      3. The id of a newly created warehouse named ``warehouse_name``.

    Steps 2 and 3 use the module-level ``WorkspaceClient`` instance ``w``.

    Args:
        warehouse_id: Explicitly configured warehouse id, or None.
        warehouse_name: Name used to look up (or create) a warehouse when no
            id is configured, or None.

    Returns:
        The resolved warehouse id.

    Raises:
        AssertionError: If both arguments are None, or if no warehouse with
            the given name exists and creating one fails or yields no id. When
            creation fails, the underlying exception is chained as ``__cause__``.
    """
    if warehouse_id is not None:
        return warehouse_id

    # Raised explicitly instead of via `assert` so the check is not stripped by
    # `python -O`; the AssertionError type is kept for backward compatibility.
    if warehouse_name is None:
        raise AssertionError("warehouse_name is None unable to find warehouse")

    # Stop at the first warehouse with a matching name instead of scanning the
    # rest of the listing (warehouse names are expected to be unique).
    existing_id = next(
        (wh.id for wh in w.warehouses.list() if wh.name == warehouse_name),
        None,
    )
    if existing_id is not None:
        return existing_id

    try:
        created_id = w.warehouses.create_and_wait(
            cluster_size="Small",
            auto_stop_mins=1,
            name=warehouse_name,
            enable_serverless_compute=True,
            min_num_clusters=1,
            max_num_clusters=1,
        ).id
    except Exception as exc:
        # Surface the real cause rather than printing a guess and failing later
        # with a message that hides what went wrong.
        raise AssertionError(
            f"Warehouse id is None: no warehouse named {warehouse_name!r} was found "
            f"and creating one failed ({exc!r}); "
            "you probably do not have permissions to do this"
        ) from exc

    if created_id is None:
        raise AssertionError("Warehouse id is None")
    return created_id

# Resolve the warehouse id once from the configured id / name; the toolkit
# construction further down reads this module-level `warehouse_id`.
warehouse_id = get_warehouse_id(config_warehouse_id, warehouse_name)
warehouse_id  # bare expression so the notebook displays the resolved id

In [0]:
import os
from langchain_community.tools.databricks import UCFunctionToolkit

# Timeout for UC tool execution, supplied through the environment
# (presumably seconds -- confirm against the toolkit documentation).
os.environ["UC_TOOL_CLIENT_EXECUTION_TIMEOUT"] = "200"

# Unity Catalog location and function names, all taken from the model config.
CATALOG = config.get("catalog")
SCHEMA = config.get("schema")
FUNCTION1 = config.get("add_function")
FUNCTION2 = config.get("multiply_function")

# Functions are exposed as tools through their fully qualified names.
# "{catalog_name}.{schema_name}.*" would select every function in a schema.
_qualified_function_names = (
    f"{CATALOG}.{SCHEMA}.{FUNCTION1}",
    f"{CATALOG}.{SCHEMA}.{FUNCTION2}",
)

tools = (
    UCFunctionToolkit(warehouse_id=warehouse_id)
    .include(*_qualified_function_names)
    .get_tools()
)
tools

In [0]:
from langchain_databricks import ChatDatabricks

# Create the llm
# The serving endpoint name comes from the `llm_endpoint` entry of the config.
llm = ChatDatabricks(endpoint=config.get("llm_endpoint"))

In [0]:
from langchain import hub

# Get the prompt to use - can be replaced with any prompt that includes variables "agent_scratchpad" and "input"!
# NOTE(review): the prompt is pulled by name with no pinned revision, so this
# cell presumably needs network access every time it runs -- confirm that is
# acceptable wherever this file is loaded.
prompt = hub.pull("hwchase17/openai-tools-agent")
# Print a readable rendering of the prompt, then show the object itself as the cell output.
prompt.pretty_print()
prompt

In [0]:
from langchain.agents import AgentExecutor, create_tool_calling_agent

# Build the tool-calling agent from the LLM, the UC function tools and the prompt.
agent = create_tool_calling_agent(llm, tools, prompt)

# Wrap the agent in an executor over the same tools; verbose=True is passed
# so the executor reports its intermediate steps.
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

In [0]:
from mlflow.langchain.output_parsers import ChatCompletionsOutputParser

# Final stage of the chain: MLflow's chat-completions output parser, applied to
# the agent's extracted response.
output_parser = ChatCompletionsOutputParser()

In [0]:
from langchain_core.runnables import RunnableLambda

def agent_executor_to_just_response(inp):
    """Extract the answer from an agent executor result.

    The agent executor returns a mapping shaped like
    {"input": "...", "output": "..."}; only the "output" entry is passed on.
    """
    response = inp["output"]
    return response
  
def pre_process_input(inp):
    """Wrap the raw chain input in the mapping handed to the agent executor.

    The value is stored unchanged under the "input" key; no other key is set here.
    """
    payload = {}
    payload["input"] = inp
    return payload

# Pipeline: wrap the raw input -> run the agent -> keep only its "output" ->
# format the result with the chat-completions output parser.
chain = (
    RunnableLambda(pre_process_input)
    | agent_executor
    | RunnableLambda(agent_executor_to_just_response)
    | output_parser
)

In [0]:
# chain.invoke("what is an llm?")

In [0]:
# chain.invoke("what is 2 + 2?")

In [0]:
# ensure that langchain is < 0.3.0
# Register `chain` with MLflow as the model object defined by this file
# (used when the notebook is logged as a model from code).
mlflow.models.set_model(chain)