In [0]:
%pip install -q -r requirements.txt
# Restart the Python process after the install so the packages from requirements.txt
# are the ones imported by the cells below.
dbutils.library.restartPython()

In [0]:
import os 
import yaml
from dbruntime.databricks_repl_context import get_context

# For development purposes only: expose the notebook's API token through the
# DATABRICKS_TOKEN environment variable.
os.environ["DATABRICKS_TOKEN"] = get_context().apiToken

# Load the smoke-test configuration; safe_load only builds plain Python objects.
with open('./smoke_test_config.yaml', 'r') as file:
    config = yaml.safe_load(file)

# Top-level sections of the YAML file.
databricks_config = config["databricks_config"]
input_example_config = config["input_examples"]
agent_config = config["agent_config"]
vector_search_config = config["vector_search_config"]

# Unity Catalog location used to qualify the index and model names below.
catalog = databricks_config["catalog"]
schema = databricks_config["schema"]

# MLflow and agent config
mlflow_experiment_name = databricks_config["mlflow_experiment_name"]
agent_name = agent_config["uc_agent_name"]
llm_endpoint_name = agent_config["llm_endpoint_name"]

# Vector search config
embedding_endpoint_name = vector_search_config["embedding_model_endpoint_name"]
vector_search_index_name = vector_search_config["vector_search_index_name"]

In [0]:
%run ./02a_rag_agent

In [0]:
# Wrap each configured example question in a single-turn chat payload
# ({"messages": [{"role": "user", "content": ...}]}).
rag_input_example = dict(
    messages=[
        dict(role="user", content=input_example_config["rag_input_example"]),
    ]
)
genie_input_example = dict(
    messages=[
        dict(role="user", content=input_example_config["genie_input_example"]),
    ]
)

In [0]:
# Call the agent directly with the RAG example; the response is shown as the cell output.
# NOTE(review): AGENT is not defined in this notebook — presumably provided by the
# `%run ./02a_rag_agent` cell; confirm.
AGENT.predict(rag_input_example)

In [0]:
# Call the agent directly with the Genie example; the response is shown as the cell output.
# NOTE(review): AGENT is not defined in this notebook — presumably provided by the
# `%run ./02a_rag_agent` cell; confirm.
AGENT.predict(genie_input_example)

In [0]:
import mlflow

# Make sure the configured MLflow experiment exists.
# get_experiment_by_name returns None (it does not raise) when no experiment has that
# name, so the lookup result is checked explicitly. Genuine failures (auth, network)
# now surface instead of being swallowed by a bare `except`.
experiment = mlflow.get_experiment_by_name(mlflow_experiment_name)
if experiment is None:
    mlflow.create_experiment(mlflow_experiment_name)
    # Re-read so `experiment` always refers to an existing experiment after this cell.
    experiment = mlflow.get_experiment_by_name(mlflow_experiment_name)
    print(f"Experiment created: {mlflow_experiment_name}")
# NOTE(review): the experiment is only looked up / created here; nothing in this notebook
# activates it (e.g. mlflow.set_experiment) — confirm where the runs below should be logged.

In [0]:
import mlflow
from mlflow.models.resources import (
  DatabricksVectorSearchIndex,
  DatabricksServingEndpoint,
  DatabricksSQLWarehouse,
  DatabricksFunction,
  DatabricksGenieSpace,
  DatabricksTable,
  DatabricksUCConnection
)

# Log the agent as an MLflow pyfunc model inside a dedicated "smoke_test" run.
# `logged_chain_info` is reused by the evaluation, prediction and registration cells below.
with mlflow.start_run(run_name="smoke_test"):
    logged_chain_info = mlflow.pyfunc.log_model(
        python_model=os.path.join(os.getcwd(), "02a_rag_agent"),
        model_config=os.path.join(os.getcwd(), "smoke_test_config.yaml"),  # Same YAML file this notebook loads its settings from
        artifact_path="agent",  # Required by MLflow; later cells build `runs:/<run_id>/agent` URIs from it
        input_example=rag_input_example,
        # Databricks resources declared for the logged model: the vector search index
        # plus the LLM and embedding serving endpoints.
        resources=[
        DatabricksVectorSearchIndex(index_name=f"{catalog}.{schema}.{vector_search_index_name}"),
        DatabricksServingEndpoint(endpoint_name=llm_endpoint_name),
        DatabricksServingEndpoint(endpoint_name=embedding_endpoint_name)
        ],
        pip_requirements=["-r requirements.txt"],
        example_no_conversion=True,  # Required by MLflow to use the input_example as the chain's schema
    )

In [0]:
import os
import pandas as pd
# Read the tab-separated smoke-test evaluation set (read_table uses '\t' as its separator).
evals_df = pd.read_table(f"{os.getcwd()}/data/smoke_test_evals.tsv")

In [0]:
# Evaluate the logged agent against the smoke-test evaluation set in its own run.
with mlflow.start_run(run_name="smoke_test_eval"):
    # Run the "databricks-agent" evaluator with an explicit list of metrics.
    eval_results = mlflow.evaluate(
        data=evals_df,
        model=f"runs:/{logged_chain_info.run_id}/agent",  # `agent` must match the artifact_path passed to log_model
        model_type="databricks-agent",
        evaluator_config={
            "databricks-agent": {
                "metrics": [
                    "chunk_relevance",
                    "context_sufficiency",
                    "correctness",
                    "safety",
                    "groundedness",
                ],
            }
        },
    )

In [0]:
# Run the RAG example through the logged model artifact (not the in-notebook AGENT object),
# using the "uv" environment manager.
# `agent` in the URI must match the artifact_path passed to log_model.
mlflow.models.predict(
    model_uri=f"runs:/{logged_chain_info.run_id}/agent",
    input_data=rag_input_example,
    env_manager="uv",
)

In [0]:
# Point the MLflow model registry at Unity Catalog.
mlflow.set_registry_uri("databricks-uc")

# Three-level model name: <catalog>.<schema>.<agent_name>
UC_MODEL_NAME = f"{catalog}.{schema}.{agent_name}"

# Register the logged model to UC; the returned info (notably `.version`) is used for deployment.
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_chain_info.model_uri, name=UC_MODEL_NAME
)

In [0]:
from databricks import agents
# Deploy the model version that was just registered; `deployment_info.endpoint_name`
# is used later to wait for the endpoint.
deployment_info = agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version)

In [0]:
import time
from typing import Dict

from mlflow.deployments import get_deploy_client

client = get_deploy_client("databricks")  # uses workspace creds automatically

def wait_for_endpoint_ready(
    endpoint_name: str, sleep_seconds: int = 30, timeout_seconds: int = 1_800
) -> Dict:
    """Poll a serving endpoint until it is READY, its update fails, or the timeout elapses.

    Uses the module-level deployment ``client`` to fetch the endpoint metadata and
    prints one timestamped status line per poll.

    Args:
        endpoint_name: Name of the serving endpoint to poll.
        sleep_seconds: Pause between two consecutive polls.
        timeout_seconds: Maximum total time to wait before giving up.

    Returns:
        The endpoint metadata from the poll that reported ``state.ready == "READY"``.

    Raises:
        RuntimeError: The endpoint is not ready and its ``state.config_update`` reports
            a failed or cancelled update, so waiting longer cannot succeed.
        TimeoutError: The endpoint did not become READY within ``timeout_seconds``.
    """
    # config_update values after which a not-ready endpoint will not become ready by itself.
    terminal_updates = {"UPDATE_FAILED", "UPDATE_CANCELED"}

    # Monotonic clock: the elapsed time is unaffected by wall-clock adjustments.
    started = time.monotonic()

    while True:
        # Grab the latest metadata for the endpoint
        meta = client.get_endpoint(endpoint=endpoint_name)
        state = meta.get("state") or {}

        # 'state.ready' is "READY" once the endpoint can receive traffic
        ready_state = state.get("ready", "UNKNOWN")
        print(f"[{time.strftime('%H:%M:%S')}] {endpoint_name}: {ready_state}")

        if ready_state == "READY":
            print("✅ Endpoint is RUNNING and ready to serve requests.")
            return meta

        # Fail fast instead of polling until the timeout when the deployment already failed.
        config_update = state.get("config_update")
        if config_update in terminal_updates:
            raise RuntimeError(
                f"Endpoint {endpoint_name} will not become READY: "
                f"config_update is {config_update}."
            )

        # Abort if we've waited too long
        remaining = timeout_seconds - (time.monotonic() - started)
        if remaining <= 0:
            raise TimeoutError(
                f"Endpoint {endpoint_name} was not READY after "
                f"{timeout_seconds} seconds."
            )

        # Otherwise wait, but never sleep past the deadline, then try again
        time.sleep(min(sleep_seconds, remaining))

In [0]:
# Block until the endpoint created by agents.deploy reports READY (or the wait times out).
wait_for_endpoint_ready(deployment_info.endpoint_name)

In [0]:
# Completion banner. A raw string is used because the ASCII art contains backslashes
# that are not valid escape sequences ("\(" and "\ "); the printed text is unchanged.
print(r"""
             Smoke test complete!
            /
        (o_o)/
        <)  )
        /  \                \(o_o)/ \(o_o)/ \(o_o)/ \(o_o)/
      """)