In [0]:
%pip install -r requirements.txt

In [0]:
# Restart the notebook's Python process via the Databricks utility.
# NOTE(review): presumably needed so the packages installed by the previous
# cell are importable — confirm against the Databricks docs.
dbutils.library.restartPython()

In [0]:
import mlflow
from mlflow.models import ModelConfig

# Turn on MLflow autologging for LangChain.
mlflow.langchain.autolog()
# Load the notebook settings from config.yml; later cells read individual
# values with config.get(<key>).
config = ModelConfig(development_config="config.yml")

In [0]:
# Example config.yml values read below:
#   catalog: main
#   schema: default
#   add_function: test_add
#   multiply_function: test_multiply

CATALOG = config.get("catalog")
SCHEMA = config.get("schema")
FUNCTION1 = config.get("add_function")
FUNCTION2 = config.get("multiply_function")


def _create_int_binary_function(function_name: str, comment: str, expression: str):
    """Create or replace a Python SQL function taking two INT operands.

    The function is created as ``{CATALOG}.{SCHEMA}.{function_name}`` with
    parameters ``a`` and ``b`` and an INT return type.

    Args:
        function_name: Unqualified name of the function to create.
        comment: Text stored in the function's COMMENT clause.
        expression: Python expression over ``a`` and ``b`` that the function
            returns (e.g. ``"a + b"``).

    Returns:
        Whatever ``spark.sql`` returns for the DDL statement.
    """
    # Values are interpolated verbatim into the DDL; they come from the
    # notebook's own config and the literals below, not from user input.
    return spark.sql(f"""
CREATE OR REPLACE FUNCTION {CATALOG}.{SCHEMA}.{function_name} (
  a INT COMMENT 'the first operand.',
  b INT COMMENT 'the second operand.'
)
RETURNS INT
LANGUAGE PYTHON
COMMENT '{comment}'
AS $$
  return {expression}
$$
""")


_create_int_binary_function(FUNCTION1, "Adds two operands.", "a + b")
_create_int_binary_function(FUNCTION2, "Multiplies two operands.", "a * b")

In [0]:
from databricks.sdk import WorkspaceClient
from typing import Optional

def _optional_config(key):
    """Return ``config.get(key)``, or ``None`` if the lookup raises."""
    try:
        return config.get(key)
    except Exception:
        return None


# Optional setting, e.g. warehouse_name = "Starter Warehouse UC"
warehouse_name = _optional_config("warehouse_name")

w = WorkspaceClient()

# Optional setting: when the id is absent from the config, the warehouse is
# looked up by name instead.
config_warehouse_id = _optional_config("warehouse_id")

def get_warehouse_id(warehouse_id: Optional[str], warehouse_name: Optional[str]) -> str:
    """Resolve the id of the SQL warehouse to use.

    Resolution order:
      1. ``warehouse_id`` itself, when it is not ``None``.
      2. The id of the first warehouse returned by ``w.warehouses.list()``
         whose name equals ``warehouse_name``.
      3. The id of a warehouse newly created under ``warehouse_name`` (size
         "Small", serverless, one cluster, auto-stop after 1 minute).

    Args:
        warehouse_id: Explicitly configured warehouse id, or ``None``.
        warehouse_name: Name used to find or create a warehouse; required
            when ``warehouse_id`` is ``None``.

    Returns:
        The resolved warehouse id.

    Raises:
        AssertionError: If both arguments are ``None``, if creating the
            warehouse fails (the underlying error is chained as
            ``__cause__``), or if the created warehouse has no id.
    """
    if warehouse_id is not None:
        return warehouse_id

    # Raised explicitly (rather than via `assert`) so the check is not
    # stripped when Python runs with -O.
    if warehouse_name is None:
        raise AssertionError("warehouse_name is None unable to find warehouse")

    # Stop at the first match instead of iterating over every warehouse.
    existing_id = next(
        (wh.id for wh in w.warehouses.list() if wh.name == warehouse_name),
        None,
    )
    if existing_id is not None:
        return existing_id

    try:
        warehouse = w.warehouses.create_and_wait(
            cluster_size="Small",
            auto_stop_mins=1,
            name=warehouse_name,
            enable_serverless_compute=True,
            min_num_clusters=1,
            max_num_clusters=1,
        )
    except Exception as exc:
        print("you probably do not have permissions to do this")
        # Keep the original failure attached instead of discarding it.
        raise AssertionError(
            f"Warehouse id is None: could not create warehouse {warehouse_name!r} ({exc})"
        ) from exc

    if warehouse.id is None:
        raise AssertionError("Warehouse id is None")
    return warehouse.id

# Resolve the warehouse id from the configured id / name, then echo it as the
# cell output. The value is passed to DatabricksSQLWarehouse when the model is
# logged.
warehouse_id = get_warehouse_id(config_warehouse_id, warehouse_name)
warehouse_id

In [0]:
# Log the model to MLflow
import os
import mlflow
from mlflow.models.resources import DatabricksServingEndpoint, DatabricksSQLWarehouse, DatabricksFunction

# Sample chat request logged alongside the model as its input example.
input_example = {"messages": [{"role": "user", "content": "What is an LLM agent?"}]}


def _qualified_function_name(function_key):
    """Build ``<catalog>.<schema>.<function>`` from the values in ``config``."""
    return "{}.{}.{}".format(
        config.get('catalog'), config.get('schema'), config.get(function_key)
    )


function_1 = _qualified_function_name('add_function')
function_2 = _qualified_function_name('multiply_function')

# Location passed as lc_model: the "model" entry in the current working directory.
agent_code_path = os.path.join(os.getcwd(), 'model')

with mlflow.start_run():
    # Databricks resources declared for the logged model: the LLM endpoint,
    # the SQL warehouse and the two functions created earlier.
    agent_resources = [
        DatabricksServingEndpoint(endpoint_name=config.get("llm_endpoint")),
        DatabricksSQLWarehouse(warehouse_id=warehouse_id),
        DatabricksFunction(function_name=function_1),
        DatabricksFunction(function_name=function_2),
    ]
    logged_agent_info = mlflow.langchain.log_model(
        lc_model=agent_code_path,
        resources=agent_resources,
        pip_requirements="requirements.txt",
        model_config="config.yml",
        artifact_path='model',
        input_example=input_example,
    )

In [0]:
import pandas as pd

# Single-row evaluation set: one chat request plus the answer expected for it.
user_turn = {"role": "user", "content": "What is 2+2?"}
eval_examples = [
    {"request": {"messages": [user_turn]}, "expected_response": "4"},
]

eval_dataset = pd.DataFrame(eval_examples)
display(eval_dataset)

In [0]:
import mlflow

# Evaluate the logged agent inside the run that logged it, so the evaluation
# is attached to the same MLflow run.
with mlflow.start_run(run_id=logged_agent_info.run_id):
    eval_results = mlflow.evaluate(
        # Use the URI MLflow reported for the logged model instead of
        # rebuilding it from a hard-coded artifact path; the registration
        # cell relies on the same attribute.
        logged_agent_info.model_uri,
        data=eval_dataset,  # Your evaluation dataset
        model_type="databricks-agent",  # Enable Mosaic AI Agent Evaluation
    )

# Review the evaluation results in the MLflow UI (see console output), or access them in place:
display(eval_results.tables['eval_results'])

In [0]:
from databricks.sdk import WorkspaceClient
import re

# Workspace API client; the cells below use it to look up the current user and
# to talk to the serving endpoint.
w = WorkspaceClient()

def normalize(inp: str) -> str:
    """Turn ``inp`` into an underscore-separated identifier.

    Every run of non-word characters becomes a single underscore, repeated
    underscores are collapsed into one, and leading/trailing underscores are
    removed.
    """
    underscored = re.sub(r'\W+', '_', inp)
    collapsed = re.sub(r'_+', '_', underscored)
    return collapsed.strip("_")

def get_current_user_normalized():
    """Return the normalized local part of the current user's user name.

    The user name is read from ``w.current_user.me()``; everything before the
    first ``@`` is kept and passed through ``normalize``.
    """
    user_name = w.current_user.me().user_name
    local_part = user_name.split("@")[0]
    return normalize(local_part)

# Normalized user-name prefix; used further down as the first part of the
# registered model name. Echoed as the cell output.
prefix = get_current_user_normalized()
prefix

In [0]:
import os

# Derive a normalized identifier from the name of the current working
# directory; it becomes the second part of the registered model name.
current_directory = os.getcwd()
current_folder_name = os.path.split(current_directory)[1]
normalized_current_folder = normalize(current_folder_name)
normalized_current_folder

In [0]:
# Point the MLflow model registry at "databricks-uc".
mlflow.set_registry_uri("databricks-uc")

# Register the model in the same catalog/schema that the notebook reads from
# config.yml for the agent's functions, so the two cannot drift apart.
# Override these two values here if the model must live somewhere else.
catalog = config.get("catalog")
schema = config.get("schema")
# Model name: <normalized user prefix>_<normalized folder name>.
model_name = f"{prefix}_{normalized_current_folder}"
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME)

In [0]:
from databricks import agents

# Deploy the model to the review app and a model serving endpoint.
# The version deployed is the one returned by the mlflow.register_model call
# for UC_MODEL_NAME.
deployment = agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version)

In [0]:
# Show the object returned by agents.deploy as the cell output.
deployment

In [0]:
# Wait for the deployed serving endpoint; going by the SDK method name, this
# returns once the endpoint is no longer updating.
result = w.serving_endpoints.wait_get_serving_endpoint_not_updating(
    name=deployment.endpoint_name
)

In [0]:
# Stop the notebook when the endpoint reports that its config update failed.
config_update_status = result.state.config_update.value
if config_update_status == "UPDATE_FAILED":
    raise Exception("Deployment failed")

In [0]:
from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
# Smoke-test the deployed endpoint with a single user message; the response is
# shown as the cell output.
user_message = ChatMessage(role=ChatMessageRole.USER, content="Hello, what is 2+2?")
messages = [user_message]
w.serving_endpoints.query(name=deployment.endpoint_name, messages=messages)